mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
supports small llama and gemma models
Refactor inference into dedicated crates for Llama and Gemma inferencing; not yet integrated
This commit is contained in:
33
crates/inference-engine/src/inference.rs
Normal file
33
crates/inference-engine/src/inference.rs
Normal file
@@ -0,0 +1,33 @@
use anyhow::Result;
use candle_core::Tensor;
/// ModelInference trait defines the common interface for model inference operations
|
||||
///
|
||||
/// This trait serves as an abstraction for different model implementations (Gemma and Llama)
|
||||
/// to provide a unified interface for the inference engine.
|
||||
pub trait ModelInference {
|
||||
/// Perform model inference for the given input tensor starting at the specified position
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input_ids` - The input tensor containing token IDs
|
||||
/// * `pos` - The position to start generation from
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor containing the logits for the next token prediction
|
||||
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> Result<Tensor>;
|
||||
|
||||
/// Reset the model's internal state, if applicable
|
||||
///
|
||||
/// This method can be used to clear any cached state between inference requests
|
||||
fn reset_state(&mut self) -> Result<()>;
|
||||
|
||||
/// Get the model type name
|
||||
///
|
||||
/// Returns a string identifier for the model type (e.g., "Gemma", "Llama")
|
||||
fn model_type(&self) -> &'static str;
|
||||
}
|
||||
|
||||
/// Signature for factory functions that construct a boxed [`ModelInference`]
/// implementation, returning an error if model construction fails.
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
|
Reference in New Issue
Block a user