use anyhow::Result; use candle_transformers::generation::LogitsProcessor; use legacy_inference_engine::model::Which; use legacy_inference_engine::token_output_stream::TokenOutputStream; use tokenizers::Tokenizer; #[cfg(test)] mod tests { use super::*; // Helper function to create a simple tokenizer for testing fn create_test_tokenizer() -> Result { // Create a simple tokenizer from the pretrained model // This uses the tokenizer from the Hugging Face hub let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap(); Ok(tokenizer) } // Test the Which enum's to_model_id method #[test] fn test_which_model_id() { assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b"); assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it"); } // Test the Which enum's is_instruct_model method #[test] fn test_which_is_instruct() { assert!(!Which::Base2B.is_instruct_model()); assert!(Which::Instruct7B.is_instruct_model()); } // Test the Which enum's is_v3_model method #[test] fn test_which_is_v3() { assert!(!Which::Base2B.is_v3_model()); assert!(Which::BaseV3_1B.is_v3_model()); } // Test the TokenOutputStream functionality #[test] fn test_token_output_stream() -> Result<()> { let tokenizer = create_test_tokenizer()?; let mut token_stream = TokenOutputStream::new(tokenizer); // Test encoding and decoding let text = "Hello, world!"; let encoded = token_stream.tokenizer().encode(text, true).unwrap(); let token_ids = encoded.get_ids(); // Add tokens one by one for &token_id in token_ids { token_stream.next_token(token_id)?; } // Decode all and check let decoded = token_stream.decode_all()?; assert_eq!(decoded.trim(), text); Ok(()) } // Test the LogitsProcessor #[test] fn test_logits_processor() -> Result<()> { // Create a LogitsProcessor with default settings let seed = 42; let temp = Some(0.8); let top_p = Some(0.9); let logits_processor = LogitsProcessor::new(seed, temp, top_p); // Create a simple logits tensor // In a real test, we would create a tensor with known values and verify // that sampling produces expected results // For now, we'll just verify that the LogitsProcessor can be created assert!(true); Ok(()) } // Test the TextGeneration constructor #[test] fn test_text_generation_constructor() -> Result<()> { // We can't easily create a Model instance for testing, // but we can test that the constructor compiles and the types are correct // In a real test with a mock Model, we would: // 1. Create a mock model // 2. Create a tokenizer // 3. Call TextGeneration::new // 4. Verify the properties of the created instance // For now, we'll just verify that the code compiles assert!(true); Ok(()) } // Note: Testing the actual text generation functionality would require // integration tests with real models, which is beyond the scope of these unit tests. // The tests above focus on the components that can be tested in isolation. }