From b7f02eb4fb2c9a1e3299a0c291cb0de8fd9606b6 Mon Sep 17 00:00:00 2001
From: geoffsee <>
Date: Wed, 18 Jun 2025 14:33:07 -0400
Subject: [PATCH] fix mlx omni provider

---
 bun.lock                                      |  1 +
 packages/client/server/server.ts              |  3 +
 packages/server/api-router.ts                 | 13 +++-
 packages/server/package.json                  |  1 +
 .../server/providers/_ProviderRepository.ts   | 10 ++-
 packages/server/providers/mlx-omni.ts         | 74 ++++++++++++-------
 packages/server/services/AssetService.ts      |  4 +-
 packages/server/services/ChatService.ts       | 12 ++-
 packages/server/tsconfig.json                 |  3 +-
 9 files changed, 78 insertions(+), 43 deletions(-)
 create mode 100644 packages/client/server/server.ts

diff --git a/bun.lock b/bun.lock
index 6a3c989..1f89276 100644
--- a/bun.lock
+++ b/bun.lock
@@ -85,6 +85,7 @@
     "devDependencies": {
       "@anthropic-ai/sdk": "^0.32.1",
       "@cloudflare/workers-types": "^4.20241205.0",
+      "@open-gsio/client": "workspace:*",
       "@open-gsio/env": "workspace:*",
       "@testing-library/jest-dom": "^6.4.2",
       "@testing-library/user-event": "^14.5.2",
diff --git a/packages/client/server/server.ts b/packages/client/server/server.ts
new file mode 100644
index 0000000..324bfbf
--- /dev/null
+++ b/packages/client/server/server.ts
@@ -0,0 +1,3 @@
+import { renderPage } from "vike/server";
+
+export default renderPage;
\ No newline at end of file
diff --git a/packages/server/api-router.ts b/packages/server/api-router.ts
index 60e2cde..e2049a4 100644
--- a/packages/server/api-router.ts
+++ b/packages/server/api-router.ts
@@ -1,5 +1,5 @@
 import { Router, withParams } from "itty-router";
-import { createRequestContext } from "./RequestContext.ts";
+import { createRequestContext } from "./RequestContext";
 
 export function createRouter() {
   return (
@@ -57,13 +57,18 @@ export function createRouter() {
       //   return documentService.handleGetDocument(r)
       // })
 
-      .all("/api/metrics*", async (r, e, c) => {
+      .get("/api/metrics*", async (r, e, c) => {
         const { metricsService } = createRequestContext(e, c);
         return metricsService.handleMetricsRequest(r);
       })
 
-      // renders the app
-      .get("^(?!/api/).*$", async (r, e, c) => {
+      .post("/api/metrics*", async (r, e, c) => {
+        const { metricsService } = createRequestContext(e, c);
+        return metricsService.handleMetricsRequest(r);
+      })
+
+      // renders the app
+      .all("^(?!/api/)(?!/assets/).*$", async (r, e, c) => {
 
         const { assetService } = createRequestContext(e, c);
 
diff --git a/packages/server/package.json b/packages/server/package.json
index 157e38f..03d52a4 100644
--- a/packages/server/package.json
+++ b/packages/server/package.json
@@ -10,6 +10,7 @@
   },
   "devDependencies": {
     "@open-gsio/env": "workspace:*",
+    "@open-gsio/client": "workspace:*",
     "@anthropic-ai/sdk": "^0.32.1",
     "bun-sqlite-key-value": "^1.13.1",
     "@cloudflare/workers-types": "^4.20241205.0",
diff --git a/packages/server/providers/_ProviderRepository.ts b/packages/server/providers/_ProviderRepository.ts
index 1bf6a1b..a801b75 100644
--- a/packages/server/providers/_ProviderRepository.ts
+++ b/packages/server/providers/_ProviderRepository.ts
@@ -1,7 +1,12 @@
+export type GenericEnv = Record<string, any>;
+
 export class ProviderRepository {
   #providers: {name: string, key: string, endpoint: string}[] = [];
 
-  constructor(env: Record<string, any>) {
+  #env: Record<string, any>;
+
+  constructor(env: GenericEnv) {
+    this.#env = env
     this.setProviders(env);
   }
 
@@ -19,7 +24,8 @@
     mlx: "http://localhost:10240/v1",
   }
 
-  static async getModelFamily(model, env: Env) {
+  static async getModelFamily(model: any, env: Env) {
+    console.log(env);
     const allModels = await env.KV_STORAGE.get("supportedModels");
     const models = JSON.parse(allModels);
     const modelData = models.filter(m => m.id === model)
diff --git a/packages/server/providers/mlx-omni.ts b/packages/server/providers/mlx-omni.ts
index f3abdc8..e71f0fa 100644
--- a/packages/server/providers/mlx-omni.ts
+++ b/packages/server/providers/mlx-omni.ts
@@ -1,39 +1,71 @@
 import { OpenAI } from "openai";
-import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider.ts";
+import { Utils } from "../lib/utils";
+import { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions/completions";
+import { BaseChatProvider, CommonProviderParams } from "./chat-stream-provider";
 
 export class MlxOmniChatProvider extends BaseChatProvider {
   getOpenAIClient(param: CommonProviderParams): OpenAI {
     return new OpenAI({
-      baseURL: param.env.MLX_API_ENDPOINT ?? "http://localhost:10240",
+      baseURL: "http://localhost:10240",
       apiKey: param.env.MLX_API_KEY,
     });
   }
 
-  getStreamParams(param: CommonProviderParams, safeMessages: any[]): any {
-    const tuningParams = {
-      temperature: 0.75,
+  getStreamParams(param: CommonProviderParams, safeMessages: any[]): ChatCompletionCreateParamsStreaming {
+    const baseTuningParams = {
+      temperature: 0.86,
+      top_p: 0.98,
+      presence_penalty: 0.1,
+      frequency_penalty: 0.3,
+      max_tokens: param.maxTokens as number,
     };
 
     const getTuningParams = () => {
-      return tuningParams;
+      return baseTuningParams;
     };
 
-    return {
+    let completionRequest: ChatCompletionCreateParamsStreaming = {
       model: param.model,
-      messages: safeMessages,
       stream: true,
-      ...getTuningParams(),
+      messages: safeMessages
     };
+
+    const client = this.getOpenAIClient(param);
+    const isLocal = client.baseURL.includes("localhost");
+
+    if(isLocal) {
+      completionRequest["messages"] = Utils.normalizeWithBlanks(safeMessages);
+      completionRequest["stream_options"] = {
+        include_usage: true
+      };
+    } else {
+      completionRequest = {...completionRequest, ...getTuningParams()};
+    }
+
+    return completionRequest;
   }
 
   async processChunk(chunk: any, dataCallback: (data: any) => void): Promise<boolean> {
-    if (chunk.choices && chunk.choices[0]?.finish_reason === "stop") {
-      dataCallback({ type: "chat", data: chunk });
-      return true;
+    const isLocal = chunk.usage !== undefined;
+
+    if (isLocal && chunk.usage) {
+      dataCallback({
+        type: "chat",
+        data: {
+          choices: [
+            {
+              delta: { content: "" },
+              logprobs: null,
+              finish_reason: "stop",
+            },
+          ],
+        },
+      });
+      return true; // Break the stream
     }
 
     dataCallback({ type: "chat", data: chunk });
-    return false;
+    return false; // Continue the stream
   }
 }
 
@@ -41,16 +73,7 @@ export class MlxOmniChatSdk {
   private static provider = new MlxOmniChatProvider();
 
   static async handleMlxOmniStream(
-    ctx: {
-      openai: OpenAI;
-      systemPrompt: any;
-      preprocessedContext: any;
-      maxTokens: unknown | number | undefined;
-      messages: any;
-      disableWebhookGeneration: boolean;
-      model: any;
-      env: Env;
-    },
+    ctx: any,
     dataCallback: (data: any) => any,
   ) {
     if (!ctx.messages?.length) {
@@ -62,10 +85,9 @@
         systemPrompt: ctx.systemPrompt,
         preprocessedContext: ctx.preprocessedContext,
         maxTokens: ctx.maxTokens,
-        messages: ctx.messages,
+        messages: Utils.normalizeWithBlanks(ctx.messages),
         model: ctx.model,
-        env: ctx.env,
-        disableWebhookGeneration: ctx.disableWebhookGeneration,
+        env: ctx.env
       },
       dataCallback,
     );
diff --git a/packages/server/services/AssetService.ts b/packages/server/services/AssetService.ts
index b8babe1..60cbc2b 100644
--- a/packages/server/services/AssetService.ts
+++ b/packages/server/services/AssetService.ts
@@ -1,5 +1,5 @@
 import { types } from "mobx-state-tree";
-import { renderPage } from "vike/server";
+
 
 export default types
   .model("StaticAssetStore", {})
@@ -17,7 +17,7 @@
     async handleSsr(
       url: string,
       headers: Headers,
-      env: Vike.PageContext["env"],
+      env: Vike.PageContext.env,
     ) {
       console.log("handleSsr");
       const pageContextInit = {
diff --git a/packages/server/services/ChatService.ts b/packages/server/services/ChatService.ts
index 067b6cf..b6c4db4 100644
--- a/packages/server/services/ChatService.ts
+++ b/packages/server/services/ChatService.ts
@@ -13,7 +13,7 @@ import {XaiChatSdk} from "../providers/xai";
 import {CerebrasSdk} from "../providers/cerebras";
 import {CloudflareAISdk} from "../providers/cloudflareAi";
 import {OllamaChatSdk} from "../providers/ollama";
-import {MlxOmniChatSdk} from "../providers/mlx-omni";
+import {MlxOmniChatProvider, MlxOmniChatSdk} from "../providers/mlx-omni";
 import {ProviderRepository} from "../providers/_ProviderRepository";
 
 export interface StreamParams {
@@ -126,7 +126,7 @@ const ChatService = types
 
       // ----- Helpers ----------------------------------------------------------
       const logger = console;
-      const useCache = false;
+      const useCache = true;
 
       if(useCache) {
         // ----- 1. Try cached value ---------------------------------------------
@@ -139,9 +139,10 @@ const ChatService = types
           return new Response(JSON.stringify(parsed), { status: 200 });
         }
         logger.warn('Cache entry malformed – refreshing');
+        throw new Error('Malformed cache entry');
       }
     } catch (err) {
-      logger.error('Error reading/parsing supportedModels cache', err);
+      logger.warn('Error reading/parsing supportedModels cache', err);
     }
   }
 
@@ -260,11 +261,8 @@ const ChatService = types
     }) {
       const {streamConfig, streamParams, controller, encoder, streamId} = params;
 
-      const useModelFamily = () => {
-        return ProviderRepository.getModelFamily(streamConfig.model, self.env)
-      }
 
-      const modelFamily = await useModelFamily();
+      const modelFamily = await ProviderRepository.getModelFamily(streamConfig.model, self.env);
 
       const useModelHandler = () => {
         return modelHandlers[modelFamily]
diff --git a/packages/server/tsconfig.json b/packages/server/tsconfig.json
index 31ee229..761ecfc 100644
--- a/packages/server/tsconfig.json
+++ b/packages/server/tsconfig.json
@@ -11,6 +11,5 @@
     "moduleResolution": "bundler",
     "skipLibCheck": true,
     "jsx": "react-jsx"
-  },
-  "exclude": ["*.test.ts"]
+  }
 }
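
Reviewer note: both getStreamParams and handleMlxOmniStream now pass the message history through Utils.normalizeWithBlanks before calling the local MLX endpoint, but that helper lives in packages/server/lib/utils and is not shown in this diff. A minimal sketch of what such a normalizer might do, assuming the intent is to pad the history so user/assistant turns strictly alternate (a constraint many local OpenAI-compatible servers enforce), could look like this:

    // Hypothetical sketch only; the real Utils.normalizeWithBlanks may differ.
    type ChatMessage = { role: "system" | "user" | "assistant"; content: string };

    export function normalizeWithBlanks(messages: ChatMessage[]): ChatMessage[] {
      const out: ChatMessage[] = [];
      let expected: "user" | "assistant" = "user"; // assume the first non-system turn is a user turn
      for (const msg of messages) {
        if (msg.role === "system") {
          out.push(msg); // pass system prompts through untouched
          continue;
        }
        if (msg.role !== expected) {
          out.push({ role: expected, content: "" }); // blank turn keeps roles alternating
        }
        out.push(msg);
        expected = msg.role === "user" ? "assistant" : "user";
      }
      return out;
    }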
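
For context on the end-of-stream handling in processChunk: because the local request sets stream_options.include_usage, an OpenAI-compatible server emits one final chunk whose usage field is populated, and that is the signal the provider uses to synthesize a finish_reason "stop" delta and stop reading. A rough sketch of the consuming loop (the real loop lives in BaseChatProvider, which this patch does not touch, so everything here other than processChunk is illustrative):

    // Illustrative driver; assumes an OpenAI SDK streaming iterator and a send callback.
    async function drainMlxStream(
      stream: AsyncIterable<any>,
      processChunk: (chunk: any, cb: (data: any) => void) => Promise<boolean>,
      send: (data: any) => void,
    ): Promise<void> {
      for await (const chunk of stream) {
        const done = await processChunk(chunk, send); // true once the usage-bearing final chunk arrives
        if (done) break;
      }
    }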