/* eslint-disable no-irregular-whitespace */
import { flow, getSnapshot, types } from 'mobx-state-tree';
import OpenAI from 'openai';
import ChatSdk from '../lib/chat-sdk';
import handleStreamData from '../lib/handleStreamData';
import Message from '../models/Message';
import O1Message from '../models/O1Message';
import { ProviderRepository } from '../providers/_ProviderRepository';
import { CerebrasSdk } from '../providers/cerebras';
import { ClaudeChatSdk } from '../providers/claude';
import { CloudflareAISdk } from '../providers/cloudflareAi';
import { FireworksAiChatSdk } from '../providers/fireworks';
import { GoogleChatSdk } from '../providers/google';
import { GroqChatSdk } from '../providers/groq';
import { MlxOmniChatProvider, MlxOmniChatSdk } from '../providers/mlx-omni';
import { OllamaChatSdk } from '../providers/ollama';
import { OpenAiChatSdk } from '../providers/openai';
import { XaiChatSdk } from '../providers/xai';

export interface StreamParams {
  env: Env;
  openai: OpenAI;
  messages: any[];
  model: string;
  systemPrompt: string;
  preprocessedContext: any;
  maxTokens: number;
}

const activeStreamType = types.model({
  name: types.optional(types.string, ''),
  maxTokens: types.optional(types.number, 0),
  systemPrompt: types.optional(types.string, ''),
  model: types.optional(types.string, ''),
  messages: types.optional(types.array(types.frozen()), []),
});

const activeStreamsMap = types.map(activeStreamType);

const ChatService = types
  .model('ChatService', {
    openAIApiKey: types.optional(types.string, ''),
    openAIBaseURL: types.optional(types.string, ''),
    activeStreams: types.optional(activeStreamsMap, {}),
    maxTokens: types.number,
    systemPrompt: types.string,
  })
  .volatile(self => ({
    openai: {} as OpenAI,
    env: {} as Env,
  }))
  .actions(self => {
    // Helper functions
    const createMessageInstance = (message: any) => {
      if (typeof message.content === 'string') {
        return Message.create({
          role: message.role,
          content: message.content,
        });
      }
      if (Array.isArray(message.content)) {
        const m = O1Message.create({
          role: message.role,
          content: message.content.map(item => ({
            type: item.type,
            text: item.text,
          })),
        });
        return m;
      }
      throw new Error('Unsupported message format');
    };

    const createStreamParams = async (
      streamConfig: any,
      dynamicContext: any,
      durableObject: any,
    ): Promise<StreamParams> => {
      return {
        env: self.env,
        openai: self.openai,
        messages: streamConfig.messages.map(createMessageInstance),
        model: streamConfig.model,
        systemPrompt: streamConfig.systemPrompt,
        preprocessedContext: getSnapshot(dynamicContext),
        maxTokens: await durableObject.dynamicMaxTokens(streamConfig.messages, 2000),
      };
    };

    // Dispatch table mapping each model family to its provider-specific stream handler.
    const modelHandlers = {
      openai: (params: StreamParams, dataHandler: (data: any) => any) =>
        OpenAiChatSdk.handleOpenAiStream(params, dataHandler),
      groq: (params: StreamParams, dataHandler: (data: any) => any) =>
        GroqChatSdk.handleGroqStream(params, dataHandler),
      claude: (params: StreamParams, dataHandler: (data: any) => any) =>
        ClaudeChatSdk.handleClaudeStream(params, dataHandler),
      fireworks: (params: StreamParams, dataHandler: (data: any) => any) =>
        FireworksAiChatSdk.handleFireworksStream(params, dataHandler),
      google: (params: StreamParams, dataHandler: (data: any) => any) =>
        GoogleChatSdk.handleGoogleStream(params, dataHandler),
      xai: (params: StreamParams, dataHandler: (data: any) => any) =>
        XaiChatSdk.handleXaiStream(params, dataHandler),
      cerebras: (params: StreamParams, dataHandler: (data: any) => any) =>
        CerebrasSdk.handleCerebrasStream(params, dataHandler),
      cloudflareAI: (params: StreamParams, dataHandler: (data: any) => any) =>
        CloudflareAISdk.handleCloudflareAIStream(params, dataHandler),
      ollama: (params: StreamParams, dataHandler: (data: any) => any) =>
        OllamaChatSdk.handleOllamaStream(params, dataHandler),
      mlx: (params: StreamParams, dataHandler: (data: any) => any) =>
        MlxOmniChatSdk.handleMlxOmniStream(params, dataHandler),
    };

    return {
      getSupportedModels: flow(function* (): Generator<Promise<any>, Response, unknown> {
        // ----- Helpers ----------------------------------------------------------
        const logger = console;
        const useCache = true;

        if (useCache) {
          // ----- 1. Try cached value ---------------------------------------------
          try {
            const cached = yield self.env.KV_STORAGE.get('supportedModels');
            if (cached) {
              const parsed = JSON.parse(cached as string);
              if (Array.isArray(parsed) && parsed.length > 0) {
                logger.info('Cache hit – returning supportedModels from KV');
                return new Response(JSON.stringify(parsed), { status: 200 });
              }
              logger.warn('Cache entry malformed – refreshing');
              throw new Error('Malformed cache entry');
            }
          } catch (err) {
            logger.warn('Error reading/parsing supportedModels cache', err);
          }
        }

        // ----- 2. Build fresh list ---------------------------------------------
        const providerRepo = new ProviderRepository(self.env);
        const providers = providerRepo.getProviders();

        const providerModels = new Map();
        const modelMeta = new Map();

        for (const provider of providers) {
          if (!provider.key) continue;

          logger.info(`Fetching models from «${provider.endpoint}»`);
          const openai = new OpenAI({ apiKey: provider.key, baseURL: provider.endpoint });

          // 2-a. List models
          try {
            const listResp = yield openai.models.list(); // <-- async
            const models = 'data' in listResp ? listResp.data : listResp;
            providerModels.set(provider.name, models);

            // 2-b. Retrieve metadata
            for (const mdl of models) {
              try {
                const meta = yield openai.models.retrieve(mdl.id); // <-- async
                modelMeta.set(mdl.id, { ...mdl, ...meta });
              } catch (err) {
                // logger.error(`Metadata fetch failed for ${mdl.id}`, err);
                modelMeta.set(mdl.id, { provider: provider.name, mdl });
              }
            }
          } catch (err) {
            logger.error(`Model list failed for provider «${provider.name}»`, err);
          }
        }

        // ----- 3. Merge results -------------------------------------------------
        const resultMap = new Map();
        for (const [provName, models] of providerModels) {
          for (const mdl of models) {
            resultMap.set(mdl.id, {
              id: mdl.id,
              provider: provName,
              ...(modelMeta.get(mdl.id) ?? mdl),
            });
          }
        }
        const resultArr = Array.from(resultMap.values());

        // ----- 4. Cache fresh list ---------------------------------------------
        try {
          yield self.env.KV_STORAGE.put(
            'supportedModels',
            JSON.stringify(resultArr),
            { expirationTtl: 60 * 60 * 24 }, // 24 h
          );
          logger.info('supportedModels cache refreshed');
        } catch (err) {
          logger.error('KV put failed for supportedModels', err);
        }

        // ----- 5. Return --------------------------------------------------------
        return new Response(JSON.stringify(resultArr), { status: 200 });
      }),

      setActiveStream(streamId: string, stream: any) {
        const validStream = {
          name: stream?.name || 'Unnamed Stream',
          maxTokens: stream?.maxTokens || 0,
          systemPrompt: stream?.systemPrompt || '',
          model: stream?.model || '',
          messages: stream?.messages || [],
        };
        self.activeStreams.set(streamId, validStream);
      },

      removeActiveStream(streamId: string) {
        self.activeStreams.delete(streamId);
      },

      setEnv(env: Env) {
        self.env = env;
        if (env.OPENAI_API_ENDPOINT && env.OPENAI_API_ENDPOINT.includes('localhost')) {
          self.openai = new OpenAI({
            apiKey: self.env.OPENAI_API_KEY,
            baseURL: self.env.OPENAI_API_ENDPOINT,
          });
        } else {
          self.openai = new OpenAI({
            apiKey: self.openAIApiKey,
            baseURL: self.openAIBaseURL,
          });
        }
      },

      handleChatRequest: async (request: Request) => {
        return ChatSdk.handleChatRequest(request, {
          openai: self.openai,
          env: self.env,
          systemPrompt: self.systemPrompt,
          maxTokens: self.maxTokens,
        });
      },

      async runModelHandler(params: {
        streamConfig: any;
        streamParams: any;
        controller: ReadableStreamDefaultController;
        encoder: TextEncoder;
        streamId: string;
      }) {
        const { streamConfig, streamParams, controller, encoder, streamId } = params;

        const modelFamily = await ProviderRepository.getModelFamily(streamConfig.model, self.env);

        const useModelHandler = () => {
          return modelHandlers[modelFamily];
        };

        const handler = useModelHandler();

        if (handler) {
          try {
            await handler(streamParams, handleStreamData(controller, encoder));
          } catch (error) {
            // Translate well-known upstream failures into client-friendly errors.
            const message = error.message.toLowerCase();
            if (
              message.includes('413 ') ||
              message.includes('maximum') ||
              message.includes('too long') ||
              message.includes('too large')
            ) {
              throw new ClientError(
                `Error! Content length exceeds limits. Try shortening your message or editing an earlier message.`,
                413,
                {
                  model: streamConfig.model,
                  maxTokens: streamParams.maxTokens,
                },
              );
            }
            if (message.includes('429 ')) {
              throw new ClientError(
                `Error! Rate limit exceeded. Wait a few minutes before trying again.`,
                429,
                {
                  model: streamConfig.model,
                  maxTokens: streamParams.maxTokens,
                },
              );
            }
            if (message.includes('404')) {
              throw new ClientError(`Something went wrong, try again.`, 404, {});
            }
            throw error;
          }
        }
      },

      createSseReadableStream(params: {
        streamId: string;
        streamConfig: any;
        savedStreamConfig: string;
        durableObject: any;
      }) {
        const { streamId, streamConfig, savedStreamConfig, durableObject } = params;

        return new ReadableStream({
          async start(controller) {
            const encoder = new TextEncoder();

            try {
              const dynamicContext = Message.create(streamConfig.preprocessedContext);

              // Process the stream data using the appropriate handler
              const streamParams = await createStreamParams(
                streamConfig,
                dynamicContext,
                durableObject,
              );

              await self.runModelHandler({
                streamConfig,
                streamParams,
                controller,
                encoder,
                streamId,
              });
            } catch (error) {
              console.error(`chatService::handleSseStream::${streamId}::Error`, error);
              if (error instanceof ClientError) {
                controller.enqueue(
                  encoder.encode(
                    `data: ${JSON.stringify({ type: 'error', error: error.message })}\n\n`,
                  ),
                );
              } else {
                controller.enqueue(
                  encoder.encode(
                    `data: ${JSON.stringify({
                      type: 'error',
                      error: 'Server error',
                    })}\n\n`,
                  ),
                );
              }
              controller.close();
            } finally {
              try {
                controller.close();
              } catch (_) {
                // Ignore errors when closing the controller, as it might already be closed
              }
            }
          },
        });
      },

      handleSseStream: flow(function* (
        streamId: string,
      ): Generator<Promise<any>, Response, unknown> {
        // Check if a stream is already active for this ID
        if (self.activeStreams.has(streamId)) {
          return new Response('Stream already active', { status: 409 });
        }

        // Retrieve the stream configuration from the durable object
        const objectId = self.env.SERVER_COORDINATOR.idFromName('stream-index');
        const durableObject = self.env.SERVER_COORDINATOR.get(objectId);
        const savedStreamConfig = yield durableObject.getStreamData(streamId);

        if (!savedStreamConfig) {
          return new Response('Stream not found', { status: 404 });
        }

        const streamConfig = JSON.parse(savedStreamConfig);

        const stream = self.createSseReadableStream({
          streamId,
          streamConfig,
          savedStreamConfig,
          durableObject,
        });

        // Use `tee()` to create two streams: one for processing and one for the response
        const [processingStream, responseStream] = stream.tee();

        self.setActiveStream(streamId, {
          ...streamConfig,
        });

        processingStream.pipeTo(
          new WritableStream({
            close() {
              self.removeActiveStream(streamId);
            },
          }),
        );

        // Return the second stream as the response
        return new Response(responseStream, {
          headers: {
            'Content-Type': 'text/event-stream',
            'Cache-Control': 'no-cache',
            Connection: 'keep-alive',
          },
        });
      }),
    };
  });

/**
 * ClientError
 * A custom construct for sending client-friendly errors via the controller in a structured and controlled manner.
 */
export class ClientError extends Error {
  public statusCode: number;
  public details: Record<string, any>;

  constructor(message: string, statusCode: number, details: Record<string, any> = {}) {
    super(message);
    this.name = 'ClientError';
    this.statusCode = statusCode;
    this.details = details;
    Object.setPrototypeOf(this, ClientError.prototype);
  }

  /**
   * Formats the error for SSE-compatible data transmission.
   */
  public formatForSSE(): string {
    return JSON.stringify({
      type: 'error',
      message: this.message,
      details: this.details,
      statusCode: this.statusCode,
    });
  }
}

export default ChatService;
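
/*
 * Usage sketch (illustrative, not part of this module): one way the service might be
 * wired into a separate Cloudflare Worker entry file. The env bindings referenced
 * (OPENAI_API_KEY, OPENAI_API_ENDPOINT, SERVER_COORDINATOR, KV_STORAGE) mirror those
 * used above, but the route paths, default maxTokens, and system prompt below are
 * assumptions made for the example, not the project's actual configuration.
 *
 *   const chatService = ChatService.create({
 *     maxTokens: 2000, // assumed default
 *     systemPrompt: 'You are a helpful assistant.', // placeholder prompt
 *   });
 *
 *   export default {
 *     async fetch(request: Request, env: Env): Promise<Response> {
 *       chatService.setEnv(env); // builds the OpenAI client from env bindings
 *       const url = new URL(request.url);
 *       if (url.pathname === '/api/models') {
 *         return chatService.getSupportedModels();
 *       }
 *       if (url.pathname.startsWith('/api/streams/')) {
 *         return chatService.handleSseStream(url.pathname.split('/').pop()!);
 *       }
 *       return chatService.handleChatRequest(request);
 *     },
 *   };
 */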