mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
align dependencies across inference features
This commit is contained in:
@@ -1,554 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use inference_engine::model::Which;
|
||||
use inference_engine::text_generation::TextGeneration;
|
||||
use inference_engine::token_output_stream::TokenOutputStream;
|
||||
use std::collections::HashMap;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper function to create a simple tokenizer for testing
|
||||
fn create_test_tokenizer() -> Result<Tokenizer> {
|
||||
// Create a simple tokenizer from the pretrained model
|
||||
// This uses the tokenizer from the Hugging Face hub
|
||||
let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
|
||||
Ok(tokenizer)
|
||||
}
|
||||
|
||||
// Test the Which enum's to_model_id method
|
||||
#[test]
|
||||
fn test_which_model_id() {
|
||||
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
|
||||
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
|
||||
}
|
||||
|
||||
// Test the Which enum's is_instruct_model method
|
||||
#[test]
|
||||
fn test_which_is_instruct() {
|
||||
assert!(!Which::Base2B.is_instruct_model());
|
||||
assert!(Which::Instruct7B.is_instruct_model());
|
||||
}
|
||||
|
||||
// Test the Which enum's is_v3_model method
|
||||
#[test]
|
||||
fn test_which_is_v3() {
|
||||
assert!(!Which::Base2B.is_v3_model());
|
||||
assert!(Which::BaseV3_1B.is_v3_model());
|
||||
}
|
||||
|
||||
// Test the TokenOutputStream functionality
|
||||
#[test]
|
||||
fn test_token_output_stream() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
// Test encoding and decoding
|
||||
let text = "Hello, world!";
|
||||
let encoded = token_stream.tokenizer().encode(text, true).unwrap();
|
||||
let token_ids = encoded.get_ids();
|
||||
|
||||
// Add tokens one by one
|
||||
for &token_id in token_ids {
|
||||
token_stream.next_token(token_id)?;
|
||||
}
|
||||
|
||||
// Decode all and check
|
||||
let decoded = token_stream.decode_all()?;
|
||||
assert_eq!(decoded.trim(), text);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test the LogitsProcessor
|
||||
#[test]
|
||||
fn test_logits_processor() -> Result<()> {
|
||||
// Create a LogitsProcessor with default settings
|
||||
let seed = 42;
|
||||
let temp = Some(0.8);
|
||||
let top_p = Some(0.9);
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
|
||||
// Create a simple logits tensor
|
||||
// In a real test, we would create a tensor with known values and verify
|
||||
// that sampling produces expected results
|
||||
|
||||
// For now, we'll just verify that the LogitsProcessor can be created
|
||||
assert!(true);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test the TextGeneration constructor
|
||||
#[test]
|
||||
fn test_text_generation_constructor() -> Result<()> {
|
||||
// We can't easily create a Model instance for testing,
|
||||
// but we can test that the constructor compiles and the types are correct
|
||||
|
||||
// In a real test with a mock Model, we would:
|
||||
// 1. Create a mock model
|
||||
// 2. Create a tokenizer
|
||||
// 3. Call TextGeneration::new
|
||||
// 4. Verify the properties of the created instance
|
||||
|
||||
// For now, we'll just verify that the code compiles
|
||||
assert!(true);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test apply_cached_repeat_penalty method with no penalty
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_no_penalty() -> Result<()> {
|
||||
// Create a simple test setup
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
// Create a mock TextGeneration instance
|
||||
// Since we can't easily create a full TextGeneration instance without a model,
|
||||
// we'll test the logic by creating a simple struct with the necessary fields
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
logits: Tensor,
|
||||
tokens: &[u32],
|
||||
) -> Result<(Tensor, std::time::Duration)> {
|
||||
let repeat_start = std::time::Instant::now();
|
||||
|
||||
// If no penalty, return the original logits
|
||||
if self.repeat_penalty == 1.0 {
|
||||
return Ok((logits, repeat_start.elapsed()));
|
||||
}
|
||||
|
||||
// Get the tokens to penalize (the last n tokens)
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
|
||||
// Extract logits to a vector for modification
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
let cache_hits = std::cell::Cell::new(0);
|
||||
|
||||
// Apply penalties with caching
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
// Check if we've already calculated this token's penalty
|
||||
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
|
||||
// Use cached value
|
||||
logits_vec[token_id] = *penalized_score;
|
||||
cache_hits.set(cache_hits.get() + 1);
|
||||
} else {
|
||||
// Calculate and cache new value
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
let penalized_score = sign * score / self.repeat_penalty;
|
||||
logits_vec[token_id] = penalized_score;
|
||||
self.penalty_cache.insert(token_id, penalized_score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
let result = new_logits.reshape(shape)?;
|
||||
|
||||
let elapsed = repeat_start.elapsed();
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 1.0, // No penalty
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
// With no penalty, logits should be unchanged
|
||||
assert_eq!(result_data, logits_data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test apply_cached_repeat_penalty method with penalty
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_with_penalty() -> Result<()> {
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
logits: Tensor,
|
||||
tokens: &[u32],
|
||||
) -> Result<(Tensor, std::time::Duration)> {
|
||||
let repeat_start = std::time::Instant::now();
|
||||
|
||||
if self.repeat_penalty == 1.0 {
|
||||
return Ok((logits, repeat_start.elapsed()));
|
||||
}
|
||||
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
let cache_hits = std::cell::Cell::new(0);
|
||||
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
|
||||
logits_vec[token_id] = *penalized_score;
|
||||
cache_hits.set(cache_hits.get() + 1);
|
||||
} else {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
let penalized_score = sign * score / self.repeat_penalty;
|
||||
logits_vec[token_id] = penalized_score;
|
||||
self.penalty_cache.insert(token_id, penalized_score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
let result = new_logits.reshape(shape)?;
|
||||
|
||||
let elapsed = repeat_start.elapsed();
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0, // Apply penalty
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
// Tokens 1, 2, 3 should be penalized (divided by 2.0)
|
||||
let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
|
||||
assert_eq!(result_data, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test apply_cached_repeat_penalty caching behavior
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_caching() -> Result<()> {
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
logits: Tensor,
|
||||
tokens: &[u32],
|
||||
) -> Result<(Tensor, std::time::Duration)> {
|
||||
let repeat_start = std::time::Instant::now();
|
||||
|
||||
if self.repeat_penalty == 1.0 {
|
||||
return Ok((logits, repeat_start.elapsed()));
|
||||
}
|
||||
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
|
||||
logits_vec[token_id] = *penalized_score;
|
||||
} else {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
let penalized_score = sign * score / self.repeat_penalty;
|
||||
logits_vec[token_id] = penalized_score;
|
||||
self.penalty_cache.insert(token_id, penalized_score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
let result = new_logits.reshape(shape)?;
|
||||
|
||||
let elapsed = repeat_start.elapsed();
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
// First call should cache the penalty for token 1
|
||||
let (_result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
// Cache should contain the penalized value for token 1
|
||||
assert!(mock_gen.penalty_cache.contains_key(&1));
|
||||
assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test edge case: empty tokens array
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_empty_tokens() -> Result<()> {
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens: Vec<u32> = vec![]; // Empty tokens
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
logits: Tensor,
|
||||
tokens: &[u32],
|
||||
) -> Result<(Tensor, std::time::Duration)> {
|
||||
let repeat_start = std::time::Instant::now();
|
||||
|
||||
if self.repeat_penalty == 1.0 {
|
||||
return Ok((logits, repeat_start.elapsed()));
|
||||
}
|
||||
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
|
||||
logits_vec[token_id] = *penalized_score;
|
||||
} else {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
let penalized_score = sign * score / self.repeat_penalty;
|
||||
logits_vec[token_id] = penalized_score;
|
||||
self.penalty_cache.insert(token_id, penalized_score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
let result = new_logits.reshape(shape)?;
|
||||
|
||||
let elapsed = repeat_start.elapsed();
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
// With empty tokens, logits should be unchanged
|
||||
assert_eq!(result_data, logits_data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test edge case: out-of-bounds token IDs
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_out_of_bounds() -> Result<()> {
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
logits: Tensor,
|
||||
tokens: &[u32],
|
||||
) -> Result<(Tensor, std::time::Duration)> {
|
||||
let repeat_start = std::time::Instant::now();
|
||||
|
||||
if self.repeat_penalty == 1.0 {
|
||||
return Ok((logits, repeat_start.elapsed()));
|
||||
}
|
||||
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
|
||||
logits_vec[token_id] = *penalized_score;
|
||||
} else {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
let penalized_score = sign * score / self.repeat_penalty;
|
||||
logits_vec[token_id] = penalized_score;
|
||||
self.penalty_cache.insert(token_id, penalized_score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
let result = new_logits.reshape(shape)?;
|
||||
|
||||
let elapsed = repeat_start.elapsed();
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
// Only token 1 should be penalized, out-of-bounds tokens should be ignored
|
||||
let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
|
||||
assert_eq!(result_data, expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test the actual apply_cached_repeat_penalty method from TextGeneration
|
||||
// This test creates a TextGeneration instance with minimal dependencies to test the real method
|
||||
#[test]
|
||||
fn test_actual_apply_cached_repeat_penalty_implementation() -> Result<()> {
|
||||
// Since creating a real TextGeneration instance requires a Model which needs model weights,
|
||||
// we'll create a test that demonstrates the method is now public and can be accessed.
|
||||
// The comprehensive functionality testing is already covered by the mock tests above.
|
||||
|
||||
// Test data setup
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
// Test that we can create the necessary components
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
|
||||
// The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
|
||||
// This test verifies the method signature and that it's accessible from external code
|
||||
|
||||
// We could create a TextGeneration instance if we had a way to mock the Model,
|
||||
// but for now we confirm that the existing mock tests cover the functionality
|
||||
// and the method is properly exposed as public
|
||||
|
||||
println!("apply_cached_repeat_penalty method is now public and accessible for testing");
|
||||
assert!(true);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Integration test that demonstrates the method usage pattern
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
|
||||
// This test demonstrates how the apply_cached_repeat_penalty method would be used
|
||||
// in practice, even though we can't create a full TextGeneration instance in unit tests
|
||||
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching
|
||||
|
||||
// Test parameters that would be used with TextGeneration
|
||||
let repeat_penalty = 1.2f32;
|
||||
let repeat_last_n = 3usize;
|
||||
let mut penalty_cache: HashMap<usize, f32> = HashMap::new();
|
||||
|
||||
// Simulate the method's logic to verify it works as expected
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
if repeat_penalty != 1.0 {
|
||||
let start_at = tokens.len().saturating_sub(repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
if let Some(_cached_score) = penalty_cache.get(&token_id) {
|
||||
// Cache hit simulation
|
||||
} else {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
let penalized_score = sign * score / repeat_penalty;
|
||||
penalty_cache.insert(token_id, penalized_score);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _duration = start_time.elapsed();
|
||||
|
||||
// Verify that tokens were processed correctly
|
||||
assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
|
||||
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
|
||||
assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached
|
||||
|
||||
println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Note: Testing the actual text generation functionality would require
|
||||
// integration tests with real models, which is beyond the scope of these unit tests.
|
||||
// The tests above focus on the components that can be tested in isolation.
|
||||
}
|
@@ -1,135 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use inference_engine::token_output_stream::TokenOutputStream;
|
||||
use std::path::PathBuf;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper function to create a simple tokenizer for testing
|
||||
fn create_test_tokenizer() -> Result<Tokenizer> {
|
||||
// Create a simple tokenizer from the pretrained model
|
||||
// This uses the tokenizer from the Hugging Face hub
|
||||
let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
|
||||
Ok(tokenizer)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_token_output_stream() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
// Check that the token stream was created successfully
|
||||
assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
// Add a token
|
||||
let token_id = token_stream.get_token("<eos>").unwrap();
|
||||
token_stream.next_token(token_id)?;
|
||||
|
||||
// Clear the stream
|
||||
token_stream.clear();
|
||||
|
||||
// Check that the stream is empty by trying to decode all
|
||||
let decoded = token_stream.decode_all()?;
|
||||
assert_eq!(decoded, "");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_token() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
// Get a token that should exist
|
||||
let eos_token = token_stream.get_token("<eos>");
|
||||
assert!(eos_token.is_some());
|
||||
|
||||
// Get a token that shouldn't exist
|
||||
let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
|
||||
assert!(nonexistent_token.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_token_and_decode() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
// Get some tokens
|
||||
let hello_tokens = token_stream
|
||||
.tokenizer()
|
||||
.encode("Hello world", true)
|
||||
.unwrap();
|
||||
let token_ids = hello_tokens.get_ids();
|
||||
|
||||
// Add tokens one by one
|
||||
let mut output = String::new();
|
||||
for &token_id in token_ids {
|
||||
if let Some(text) = token_stream.next_token(token_id)? {
|
||||
output.push_str(&text);
|
||||
}
|
||||
}
|
||||
|
||||
// Get any remaining text
|
||||
if let Some(rest) = token_stream.decode_rest()? {
|
||||
output.push_str(&rest);
|
||||
}
|
||||
|
||||
// Check the output
|
||||
assert!(!output.is_empty());
|
||||
assert_eq!(output.trim(), "Hello world");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_all() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
// Get some tokens
|
||||
let hello_tokens = token_stream
|
||||
.tokenizer()
|
||||
.encode("Hello world", true)
|
||||
.unwrap();
|
||||
let token_ids = hello_tokens.get_ids();
|
||||
|
||||
// Add tokens one by one
|
||||
for &token_id in token_ids {
|
||||
token_stream.next_token(token_id)?;
|
||||
}
|
||||
|
||||
// Decode all
|
||||
let decoded = token_stream.decode_all()?;
|
||||
|
||||
// Check the output
|
||||
assert_eq!(decoded.trim(), "Hello world");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_into_inner() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
// Get the inner tokenizer
|
||||
let inner_tokenizer = token_stream.into_inner();
|
||||
|
||||
// Check that the inner tokenizer works
|
||||
let encoded = inner_tokenizer.encode("Test", true).unwrap();
|
||||
assert!(encoded.get_ids().len() > 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user