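//! Unit tests for `TokenOutputStream` from `legacy_inference_engine`.
//!
//! The tests cover construction, token lookup, incremental decoding, clearing
//! the stream, and unwrapping the inner tokenizer. They fetch the
//! `google/gemma-2b` tokenizer from the Hugging Face Hub, so they need network
//! access and may also require Hub credentials, as the Gemma models are gated.
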
use legacy_inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use anyhow::Result;

#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to create a simple tokenizer for testing
    fn create_test_tokenizer() -> Result<Tokenizer> {
        // Create a simple tokenizer from the pretrained model
        // This uses the tokenizer from the Hugging Face hub
        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
        Ok(tokenizer)
    }
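
    // A minimal offline alternative (a sketch, not used by the tests below): it
    // assumes a serialized tokenizer exists at the hypothetical path shown and
    // uses `Tokenizer::from_file` to avoid the network fetch from the Hub.
    #[allow(dead_code)]
    fn create_test_tokenizer_offline() -> Result<Tokenizer> {
        Tokenizer::from_file("fixtures/tokenizer.json").map_err(anyhow::Error::msg)
    }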

    #[test]
    fn test_new_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Check that the token stream was created successfully
        assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Add a token
        let token_id = token_stream.get_token("<eos>").unwrap();
        token_stream.next_token(token_id)?;

        // Clear the stream
        token_stream.clear();

        // Check that the stream is empty by trying to decode all
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded, "");

        Ok(())
    }

    #[test]
    fn test_get_token() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get a token that should exist
        let eos_token = token_stream.get_token("<eos>");
        assert!(eos_token.is_some());

        // Get a token that shouldn't exist
        let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
        assert!(nonexistent_token.is_none());

        Ok(())
    }

    #[test]
    fn test_next_token_and_decode() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
        let token_ids = hello_tokens.get_ids();

        // Add tokens one by one
        let mut output = String::new();
        for &token_id in token_ids {
            if let Some(text) = token_stream.next_token(token_id)? {
                output.push_str(&text);
            }
        }

        // Get any remaining text
        if let Some(rest) = token_stream.decode_rest()? {
            output.push_str(&rest);
        }

        // Check the output
        assert!(!output.is_empty());
        assert_eq!(output.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_decode_all() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
        let token_ids = hello_tokens.get_ids();

        // Add tokens one by one
        for &token_id in token_ids {
            token_stream.next_token(token_id)?;
        }

        // Decode all
        let decoded = token_stream.decode_all()?;

        // Check the output
        assert_eq!(decoded.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_into_inner() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get the inner tokenizer
        let inner_tokenizer = token_stream.into_inner();

        // Check that the inner tokenizer works
        let encoded = inner_tokenizer.encode("Test", true).unwrap();
        assert!(encoded.get_ids().len() > 0);

        Ok(())
    }
}