run cargo fmt

geoffsee committed 2025-09-04 13:45:25 -04:00
parent 1e02b12cda
commit c1c583faab
11 changed files with 241 additions and 170 deletions

View File

@@ -1,4 +1,3 @@
 use anyhow::{Error as E, Result};
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
@@ -11,13 +10,13 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
+use std::fmt;
 use std::io::Write;
+use std::str::FromStr;
 use std::sync::mpsc::{self, Receiver, Sender};
 use std::thread;
 use tokenizers::Tokenizer;
 use utils::hub_load_safetensors;
 use utils::token_output_stream::TokenOutputStream;
-use std::str::FromStr;
-use std::fmt;
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 pub enum WhichModel {
@@ -367,7 +366,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let config_filename = repo.get("config.json")?;
     let filenames = match cfg.model {
-        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
+        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
+            vec![repo.get("model.safetensors")?]
+        }
         _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("Retrieved files in {:?}", start.elapsed());
@@ -396,7 +397,8 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
         | Some(WhichModel::InstructV2_2B)
         | Some(WhichModel::BaseV2_9B)
         | Some(WhichModel::InstructV2_9B)
-        | None => { // default to V2 model
+        | None => {
+            // default to V2 model
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
             Model::V2(model)
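
Note on the fallback arm above: when no model variant is specified, the loader treats config.json as a Gemma V2 config and builds Model2 from it. A minimal sketch of that load step, using only the types named in the hunk (the helper name is hypothetical):

    use candle_transformers::models::gemma2::Config as Config2;

    // Hypothetical helper mirroring the fallback arm: open config.json and
    // let serde_json deserialize it into the Gemma 2 config.
    fn load_v2_config(config_path: &std::path::Path) -> anyhow::Result<Config2> {
        let file = std::fs::File::open(config_path)?;
        let config: Config2 = serde_json::from_reader(file)?;
        Ok(config)
    }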

View File

@@ -105,7 +105,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
         .into_iter()
         .filter_map(|e| e.ok())
     {
-        if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("../../../Cargo.toml") {
+        if entry.file_name() == "Cargo.toml"
+            && entry.path() != workspace_root.join("../../../Cargo.toml")
+        {
            if let Ok(service_info) = parse_cargo_toml(entry.path()) {
                services.push(service_info);
            }
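
The iterator chain in this hunk matches the walkdir crate's API (filter_map(|e| e.ok()) drops unreadable entries). A self-contained sketch of the same discovery pattern, assuming walkdir; the exclusion path and helper name are illustrative:

    use std::path::{Path, PathBuf};
    use walkdir::WalkDir;

    // Sketch: collect every Cargo.toml under `workspace_path`, skipping
    // unreadable entries and the workspace's own root manifest.
    fn find_manifests(workspace_path: &Path, root_manifest: &Path) -> Vec<PathBuf> {
        WalkDir::new(workspace_path)
            .into_iter()
            .filter_map(|e| e.ok())
            .filter(|e| e.file_name() == "Cargo.toml" && e.path() != root_manifest)
            .map(|e| e.into_path())
            .collect()
    }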

View File

@@ -102,7 +102,7 @@ impl Default for LlamaInferenceConfig {
             max_tokens: 512,
             // Performance flags
-            no_kv_cache: false, // keep cache ON for speed
+            no_kv_cache: false,    // keep cache ON for speed
             use_flash_attn: false, // great speed boost if supported
             // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
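
Since the hunk header shows impl Default for LlamaInferenceConfig, overriding one flag while keeping these defaults is most naturally done with struct-update syntax. A hedged sketch using only the field names visible above:

    // Sketch: enable flash attention but keep every other default,
    // including the KV cache staying ON.
    fn example_config() -> LlamaInferenceConfig {
        LlamaInferenceConfig {
            use_flash_attn: true,
            ..Default::default()
        }
    }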

View File

@@ -1,5 +1,5 @@
-use candle_transformers::models::mimi::candle;
 use candle_core::{Device, Result, Tensor};
+use candle_transformers::models::mimi::candle;
 
 pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
 pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
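
These are the standard ImageNet per-channel statistics; the usual application is (pixel - mean) / std over a channels-first image. A minimal sketch with candle tensors, assuming an f32 (3, H, W) layout at the call site:

    use candle_core::{Result, Tensor};

    // Sketch: normalize a (3, H, W) f32 image with the constants above,
    // broadcasting mean and std across the spatial dimensions.
    fn normalize(img: &Tensor) -> Result<Tensor> {
        let device = img.device();
        let mean = Tensor::new(&IMAGENET_MEAN, device)?.reshape((3, 1, 1))?;
        let std = Tensor::new(&IMAGENET_STD, device)?.reshape((3, 1, 1))?;
        img.broadcast_sub(&mean)?.broadcast_div(&std)
    }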

View File

@@ -8,8 +8,10 @@ pub mod coco_classes;
 pub mod imagenet;
 pub mod token_output_stream;
 pub mod wav;
-use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
+use candle_core::{
+    utils::{cuda_is_available, metal_is_available},
+    Device, Tensor,
+};
 
 pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
     if cpu {
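
The device helper truncated above selects a backend with the utilities from the reorganized import. A sketch of the usual selection order, probing CUDA, then Metal, then falling back to CPU; the real function's details may differ:

    use candle_core::{
        utils::{cuda_is_available, metal_is_available},
        Device,
    };

    // Sketch: honor an explicit CPU request, otherwise prefer CUDA,
    // then Metal, then fall back to the CPU device.
    fn pick_device(cpu: bool) -> Result<Device, anyhow::Error> {
        if cpu {
            return Ok(Device::Cpu);
        }
        if cuda_is_available() {
            return Ok(Device::new_cuda(0)?);
        }
        if metal_is_available() {
            return Ok(Device::new_metal(0)?);
        }
        Ok(Device::Cpu)
    }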
@@ -126,7 +128,7 @@ pub fn hub_load_safetensors(
             repo.get(v)
                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
         })
-        .collect::<Result<Vec<_>, std::io::Error, >>()?;
+        .collect::<Result<Vec<_>, std::io::Error>>()?;
     Ok(safetensors_files)
 }
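
As a usage note, hub_load_safetensors fetches every shard listed in the index file from the hub. A hedged call-site sketch with the hf-hub sync API used elsewhere in this commit, assuming the helper returns an anyhow-compatible Result; the model id is a placeholder:

    use hf_hub::{api::sync::Api, Repo, RepoType};
    use utils::hub_load_safetensors;

    // Sketch: resolve all shards of a multi-file safetensors model.
    fn fetch_shards() -> anyhow::Result<Vec<std::path::PathBuf>> {
        let api = Api::new()?;
        let repo = api.repo(Repo::new("google/gemma-2-9b".to_string(), RepoType::Model));
        Ok(hub_load_safetensors(&repo, "model.safetensors.index.json")?)
    }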
@@ -136,7 +138,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
 ) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
     let path = path.as_ref();
     let jsfile = std::fs::File::open(path.join(json_file))?;
-    let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
+    let json: serde_json::Value =
+        serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
     let weight_map = match json.get("weight_map") {
         None => anyhow::bail!("no weight map in {json_file:?}"),
         Some(serde_json::Value::Object(map)) => map,
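
For reference, the index file parsed here follows the Hugging Face sharded-checkpoint layout: a weight_map object mapping each tensor name to the shard file that holds it. A tiny illustrative value (tensor and shard names invented):

    // Sketch of the shape hub_load_local_safetensors expects.
    fn example_index() -> serde_json::Value {
        serde_json::json!({
            "weight_map": {
                "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
                "model.norm.weight": "model-00002-of-00002.safetensors"
            }
        })
    }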