Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00
run cargo fmt
@@ -1,4 +1,3 @@
use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};

@@ -11,13 +10,13 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use std::fmt;
use std::str::FromStr;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;
use std::str::FromStr;
use std::fmt;

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel {

@@ -367,7 +366,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let config_filename = repo.get("config.json")?;
    let filenames = match cfg.model {
        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
            vec![repo.get("model.safetensors")?]
        }
        _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("Retrieved files in {:?}", start.elapsed());

@@ -396,7 +397,8 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
        | Some(WhichModel::InstructV2_2B)
        | Some(WhichModel::BaseV2_9B)
        | Some(WhichModel::InstructV2_9B)
        | None => { // default to V2 model
        | None => {
            // default to V2 model
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
            Model::V2(model)

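For orientation, run_gemma_api returns a std::sync::mpsc Receiver whose items are generated text chunks. A minimal caller might look like the sketch below; it assumes GemmaInferenceConfig implements Default, which this diff does not show (only the model and use_flash_attn fields appear above).

    // Sketch only: drain the token stream returned by run_gemma_api.
    // GemmaInferenceConfig::default() is assumed; the other fields are not shown in this diff.
    fn demo() -> anyhow::Result<()> {
        let cfg = GemmaInferenceConfig {
            model: Some(WhichModel::InstructV3_1B),
            ..Default::default()
        };
        for chunk in run_gemma_api(cfg)? {
            print!("{}", chunk?); // each item is a Result<String> sent from the worker thread
        }
        Ok(())
    }
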
@@ -105,7 +105,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>>
        .into_iter()
        .filter_map(|e| e.ok())
    {
        if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("../../../Cargo.toml") {
        if entry.file_name() == "Cargo.toml"
            && entry.path() != workspace_root.join("../../../Cargo.toml")
        {
            if let Ok(service_info) = parse_cargo_toml(entry.path()) {
                services.push(service_info);
            }

@@ -102,7 +102,7 @@ impl Default for LlamaInferenceConfig {
            max_tokens: 512,

            // Performance flags
            no_kv_cache: false, // keep cache ON for speed
            no_kv_cache: false,    // keep cache ON for speed
            use_flash_attn: false, // great speed boost if supported

            // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.

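Because LlamaInferenceConfig implements Default (the impl being reformatted here), callers can override individual fields and inherit the rest. A hypothetical example, using only fields visible in this hunk:

    // Sketch: start from the defaults above and flip a couple of knobs.
    let cfg = LlamaInferenceConfig {
        max_tokens: 1024,
        use_flash_attn: true, // only worthwhile if flash attention is supported on the GPU
        ..Default::default()
    };
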
@@ -1,5 +1,5 @@
use candle_transformers::models::mimi::candle;
use candle_core::{Device, Result, Tensor};
use candle_transformers::models::mimi::candle;

pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];

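These two constants are the standard ImageNet channel statistics. A typical (hypothetical) use is normalizing a (3, H, W) image tensor already scaled to [0, 1]; here img and device are assumed bindings, not part of this commit:

    // Sketch: broadcast-normalize an image with the constants above.
    let mean = Tensor::new(&IMAGENET_MEAN, &device)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&IMAGENET_STD, &device)?.reshape((3, 1, 1))?;
    let normalized = img.broadcast_sub(&mean)?.broadcast_div(&std)?;
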
@@ -8,8 +8,10 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};

use candle_core::{
    utils::{cuda_is_available, metal_is_available},
    Device, Tensor,
};

pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
    if cpu {

@@ -126,7 +128,7 @@ pub fn hub_load_safetensors(
            repo.get(v)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
        })
        .collect::<Result<Vec<_>, std::io::Error, >>()?;
        .collect::<Result<Vec<_>, std::io::Error>>()?;
    Ok(safetensors_files)
}

@@ -136,7 +138,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
    let path = path.as_ref();
    let jsfile = std::fs::File::open(path.join(json_file))?;
    let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
    let json: serde_json::Value =
        serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => anyhow::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,

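Taken together with the device helper above, a local sharded checkpoint could be opened roughly as sketched below; the directory name is illustrative and the utils:: paths are an assumption based on the imports shown earlier in this diff.

    // Sketch: resolve a device, then collect the shard paths named in the local index file.
    let device = utils::device(false)?; // prefer CUDA/Metal when available, else CPU
    let shards = utils::hub_load_local_safetensors("./gemma-2-9b", "model.safetensors.index.json")?;
    println!("found {} shard(s) on {:?}", shards.len(), device);
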