mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
cleanup, add ci
This commit is contained in:
@@ -106,7 +106,7 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
|
||||
// Create a mock TextGeneration instance
|
||||
// Since we can't easily create a full TextGeneration instance without a model,
|
||||
// we'll test the logic by creating a simple struct with the necessary fields
|
||||
@@ -115,7 +115,7 @@ mod tests {
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -167,16 +167,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 1.0, // No penalty
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// With no penalty, logits should be unchanged
|
||||
assert_eq!(result_data, logits_data);
|
||||
Ok(())
|
||||
@@ -189,13 +190,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -238,16 +239,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0, // Apply penalty
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// Tokens 1, 2, 3 should be penalized (divided by 2.0)
|
||||
let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
|
||||
assert_eq!(result_data, expected);
|
||||
@@ -261,13 +263,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -308,20 +310,21 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
|
||||
// First call should cache the penalty for token 1
|
||||
let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (_result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
// Cache should contain the penalized value for token 1
|
||||
assert!(mock_gen.penalty_cache.contains_key(&1));
|
||||
assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -332,13 +335,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens: Vec<u32> = vec![]; // Empty tokens
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -379,16 +382,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// With empty tokens, logits should be unchanged
|
||||
assert_eq!(result_data, logits_data);
|
||||
Ok(())
|
||||
@@ -401,13 +405,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -448,16 +452,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// Only token 1 should be penalized, out-of-bounds tokens should be ignored
|
||||
let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
|
||||
assert_eq!(result_data, expected);
|
||||
@@ -471,52 +476,52 @@ mod tests {
|
||||
// Since creating a real TextGeneration instance requires a Model which needs model weights,
|
||||
// we'll create a test that demonstrates the method is now public and can be accessed.
|
||||
// The comprehensive functionality testing is already covered by the mock tests above.
|
||||
|
||||
|
||||
// Test data setup
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
|
||||
// Test that we can create the necessary components
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
|
||||
|
||||
// The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
|
||||
// This test verifies the method signature and that it's accessible from external code
|
||||
|
||||
|
||||
// We could create a TextGeneration instance if we had a way to mock the Model,
|
||||
// but for now we confirm that the existing mock tests cover the functionality
|
||||
// and the method is properly exposed as public
|
||||
|
||||
|
||||
println!("apply_cached_repeat_penalty method is now public and accessible for testing");
|
||||
assert!(true);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
// Integration test that demonstrates the method usage pattern
|
||||
#[test]
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
|
||||
// This test demonstrates how the apply_cached_repeat_penalty method would be used
|
||||
// in practice, even though we can't create a full TextGeneration instance in unit tests
|
||||
|
||||
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching
|
||||
|
||||
|
||||
// Test parameters that would be used with TextGeneration
|
||||
let repeat_penalty = 1.2f32;
|
||||
let repeat_last_n = 3usize;
|
||||
let mut penalty_cache: HashMap<usize, f32> = HashMap::new();
|
||||
|
||||
|
||||
// Simulate the method's logic to verify it works as expected
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
if repeat_penalty != 1.0 {
|
||||
let start_at = tokens.len().saturating_sub(repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
@@ -531,14 +536,14 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let _duration = start_time.elapsed();
|
||||
|
||||
|
||||
// Verify that tokens were processed correctly
|
||||
assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
|
||||
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
|
||||
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
|
||||
assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached
|
||||
|
||||
|
||||
println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user