align dependencies across inference features

This commit is contained in:
geoffsee
2025-08-31 10:49:04 -04:00
parent f5d2a85f2e
commit e6c417bd83
10 changed files with 17 additions and 3009 deletions

View File

@@ -1,9 +1,6 @@
// Expose modules for testing and library usage
pub mod model;
pub mod openai_types;
pub mod text_generation;
pub mod token_output_stream;
pub mod utilities_lib;
// pub mod cli;
pub mod inference;
pub mod server;
@@ -12,8 +9,6 @@ pub mod server;
pub use inference::ModelInference;
pub use model::{Model, Which};
pub use server::{create_router, AppState};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -1,6 +1,7 @@
use either::Either;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde_json::json;
use utoipa::ToSchema;
/// Inner content structure for messages that can be either a string or key-value pairs

File diff suppressed because it is too large Load Diff

View File

@@ -1,87 +0,0 @@
use candle_core::Result;
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
prev_index: usize,
current_index: usize,
}
impl TokenOutputStream {
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
Self {
tokenizer,
tokens: Vec::new(),
prev_index: 0,
current_index: 0,
}
}
pub fn into_inner(self) -> tokenizers::Tokenizer {
self.tokenizer
}
fn decode(&self, tokens: &[u32]) -> Result<String> {
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
Err(err) => candle_core::bail!("cannot decode: {err}"),
}
}
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() {
// Modified to include all tokens, not just alphanumeric ones
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_rest(&self) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() {
let text = text.split_at(prev_text.len());
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_all(&self) -> Result<String> {
self.decode(&self.tokens)
}
pub fn get_token(&self, token_s: &str) -> Option<u32> {
self.tokenizer.get_vocab(true).get(token_s).copied()
}
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
&self.tokenizer
}
pub fn clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;
self.current_index = 0;
}
}

View File

@@ -1,168 +0,0 @@
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(Device::Cpu)
}
}
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?;
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
let img = match resize_longest {
None => img,
Some(resize_longest) => {
let (height, width) = (img.height(), img.width());
let resize_longest = resize_longest as u32;
let (height, width) = if height < width {
let h = (resize_longest * height) / width;
(h, resize_longest)
} else {
let w = (resize_longest * width) / height;
(resize_longest, w)
};
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
}
};
let (height, width) = (img.height() as usize, img.width() as usize);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
Ok((data, initial_h, initial_w))
}
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
p: P,
width: usize,
height: usize,
) -> Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?
.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let data = img.into_raw();
Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
}
/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle_core::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle_core::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
pub fn save_image_resize<P: AsRef<std::path::Path>>(
img: &Tensor,
p: P,
h: usize,
w: usize,
) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle_core::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle_core::bail!("error saving image {p:?}"),
};
let image = image::DynamicImage::from(image);
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle_core::Error::wrap))
.collect::<Result<Vec<_>>>()?;
Ok(safetensors_files)
}
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file);
}
}
let safetensors_files: Vec<_> = safetensors_files
.into_iter()
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}