diff --git a/src/index.ts b/src/index.ts
index 02c9f26..3ae5fa0 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -24,6 +24,8 @@ export { getGridBounds } from './utils/grid-utils';
 export {
   LLMProvider,
   LLMResponse,
+  LocalLLMProvider,
+  LocalVisionLLMProvider,
   OpenAIProvider,
   AnthropicProvider,
   GLMProvider,
diff --git a/src/llm-provider.ts b/src/llm-provider.ts
index 9816c24..5acf2e3 100644
--- a/src/llm-provider.ts
+++ b/src/llm-provider.ts
@@ -68,6 +68,211 @@ export abstract class LLMProvider {
   }
 }
 
+/**
+ * Local OpenAI-compatible Provider (Ollama / LM Studio / llama.cpp server, etc.)
+ *
+ * This is the TypeScript equivalent of Python's LocalLLMProvider concept, but instead of
+ * embedding a full HF runtime inside Node, it calls a local HTTP server that exposes an
+ * OpenAI-compatible Chat Completions API.
+ *
+ * Examples of compatible local servers:
+ * - Ollama (OpenAI-compatible endpoint)
+ * - LM Studio (OpenAI-compatible endpoint)
+ * - llama.cpp server (OpenAI-compatible endpoint)
+ */
+export class LocalLLMProvider extends LLMProvider {
+  private _modelName: string;
+  private _baseUrl: string;
+  private _apiKey?: string;
+  private _defaultHeaders: Record<string, string>;
+  private _timeoutMs: number;
+
+  constructor(
+    options: {
+      model?: string;
+      baseUrl?: string;
+      apiKey?: string;
+      timeoutMs?: number;
+      headers?: Record<string, string>;
+    } = {}
+  ) {
+    super();
+    this._modelName = options.model ?? process.env.SENTIENCE_LOCAL_LLM_MODEL ?? 'local-model';
+    // Common defaults:
+    // - Ollama OpenAI-compatible: http://localhost:11434/v1
+    // - LM Studio: http://localhost:1234/v1
+    this._baseUrl =
+      options.baseUrl ?? process.env.SENTIENCE_LOCAL_LLM_BASE_URL ?? 'http://localhost:11434/v1';
+    this._apiKey = options.apiKey ?? process.env.SENTIENCE_LOCAL_LLM_API_KEY;
+    this._timeoutMs = options.timeoutMs ?? 60_000;
+    this._defaultHeaders = {
+      'Content-Type': 'application/json',
+      ...(options.headers ?? {}),
+    };
+    if (this._apiKey) {
+      this._defaultHeaders.Authorization = `Bearer ${this._apiKey}`;
+    }
+  }
+
+  supportsJsonMode(): boolean {
+    // Many local OpenAI-compatible servers don't reliably implement response_format=json_object.
+    return false;
+  }
+
+  get modelName(): string {
+    return this._modelName;
+  }
+
+  async generate(
+    systemPrompt: string,
+    userPrompt: string,
+    options: Record<string, any> = {}
+  ): Promise<LLMResponse> {
+    const fetchFn = (globalThis as any).fetch as typeof fetch | undefined;
+    if (!fetchFn) {
+      throw new Error(
+        'Global fetch is not available. Use Node 18+ or polyfill fetch before using LocalLLMProvider.'
+      );
+    }
+
+    const controller = new AbortController();
+    const timeoutId = setTimeout(() => controller.abort(), options.timeoutMs ?? this._timeoutMs);
+
+    const payload: any = {
+      model: this._modelName,
+      messages: [
+        ...(systemPrompt ? [{ role: 'system', content: systemPrompt }] : []),
+        { role: 'user', content: userPrompt },
+      ],
+      temperature: options.temperature ?? 0.0,
+    };
+
+    if (options.max_tokens !== undefined) payload.max_tokens = options.max_tokens;
+    if (options.top_p !== undefined) payload.top_p = options.top_p;
+
+    // Allow pass-through of server-specific fields, but avoid overriding core fields accidentally
+    const { timeoutMs: _ignoredTimeout, ...rest } = options;
+    Object.assign(payload, rest);
+
+    try {
+      const res = await fetchFn(`${this._baseUrl}/chat/completions`, {
+        method: 'POST',
+        headers: this._defaultHeaders,
+        body: JSON.stringify(payload),
+        signal: controller.signal,
+      });
+
+      const text = await res.text();
+      if (!res.ok) {
+        throw new Error(`Local LLM HTTP ${res.status}: ${text.slice(0, 500)}`);
+      }
+
+      const data = JSON.parse(text);
+      const choice = data?.choices?.[0];
+      const content = choice?.message?.content ?? '';
+      const usage = data?.usage;
+
+      return {
+        content,
+        promptTokens: usage?.prompt_tokens,
+        completionTokens: usage?.completion_tokens,
+        totalTokens: usage?.total_tokens,
+        modelName: data?.model ?? this._modelName,
+      };
+    } finally {
+      clearTimeout(timeoutId);
+    }
+  }
+}
+
+/**
+ * Local OpenAI-compatible Vision Provider.
+ *
+ * This is the TypeScript analogue of Python's LocalVisionLLMProvider, but it assumes your
+ * local server supports the OpenAI vision message format (image_url with data: URI).
+ *
+ * If your local stack uses a different schema (e.g., Ollama images array), you can implement
+ * a custom provider by extending LLMProvider.
+ */
+export class LocalVisionLLMProvider extends LocalLLMProvider {
+  supportsVision(): boolean {
+    return true;
+  }
+
+  async generateWithImage(
+    systemPrompt: string,
+    userPrompt: string,
+    imageBase64: string,
+    options: Record<string, any> = {}
+  ): Promise<LLMResponse> {
+    const fetchFn = (globalThis as any).fetch as typeof fetch | undefined;
+    if (!fetchFn) {
+      throw new Error(
+        'Global fetch is not available. Use Node 18+ or polyfill fetch before using LocalVisionLLMProvider.'
+      );
+    }
+
+    const controller = new AbortController();
+    const timeoutId = setTimeout(
+      () => controller.abort(),
+      options.timeoutMs ?? (this as any)._timeoutMs ?? 60_000
+    );
+
+    // Rebuild payload (we cannot reuse LocalLLMProvider.generate because the message shape differs)
+    const modelName = (this as any)._modelName ?? 'local-model';
+    const baseUrl = (this as any)._baseUrl ?? 'http://localhost:11434/v1';
+    const headers = (this as any)._defaultHeaders ?? { 'Content-Type': 'application/json' };
+
+    const payload: any = {
+      model: modelName,
+      messages: [
+        ...(systemPrompt ? [{ role: 'system', content: systemPrompt }] : []),
+        {
+          role: 'user',
+          content: [
+            { type: 'text', text: userPrompt },
+            { type: 'image_url', image_url: { url: `data:image/png;base64,${imageBase64}` } },
+          ],
+        },
+      ],
+      temperature: options.temperature ?? 0.0,
+    };
+
+    if (options.max_tokens !== undefined) payload.max_tokens = options.max_tokens;
+    const { timeoutMs: _ignoredTimeout, ...rest } = options;
+    Object.assign(payload, rest);
+
+    try {
+      const res = await fetchFn(`${baseUrl}/chat/completions`, {
+        method: 'POST',
+        headers,
+        body: JSON.stringify(payload),
+        signal: controller.signal,
+      });
+
+      const text = await res.text();
+      if (!res.ok) {
+        throw new Error(`Local Vision LLM HTTP ${res.status}: ${text.slice(0, 500)}`);
+      }
+
+      const data = JSON.parse(text);
+      const choice = data?.choices?.[0];
+      const content = choice?.message?.content ?? '';
+      const usage = data?.usage;
+
+      return {
+        content,
+        promptTokens: usage?.prompt_tokens,
+        completionTokens: usage?.completion_tokens,
+        totalTokens: usage?.total_tokens,
+        modelName: data?.model ?? modelName,
+      };
+    } finally {
+      clearTimeout(timeoutId);
+    }
+  }
+}
+
 /**
  * OpenAI Provider (GPT-4, GPT-4o, etc.)
  * Requires: npm install openai
diff --git a/tests/local-llm-provider.test.ts b/tests/local-llm-provider.test.ts
new file mode 100644
index 0000000..a7d582c
--- /dev/null
+++ b/tests/local-llm-provider.test.ts
@@ -0,0 +1,75 @@
+import { LocalLLMProvider, LocalVisionLLMProvider } from '../src/llm-provider';
+
+describe('LocalLLMProvider (OpenAI-compatible)', () => {
+  const originalFetch = (globalThis as any).fetch;
+
+  afterEach(() => {
+    (globalThis as any).fetch = originalFetch;
+  });
+
+  it('should call /chat/completions and parse response', async () => {
+    (globalThis as any).fetch = jest.fn(async () => {
+      return {
+        ok: true,
+        status: 200,
+        text: async () =>
+          JSON.stringify({
+            model: 'local-model',
+            choices: [{ message: { content: 'hello' } }],
+            usage: { prompt_tokens: 1, completion_tokens: 2, total_tokens: 3 },
+          }),
+      };
+    });
+
+    const llm = new LocalLLMProvider({
+      baseUrl: 'http://localhost:11434/v1',
+      model: 'local-model',
+    });
+    const resp = await llm.generate('sys', 'user', { temperature: 0.0 });
+
+    expect(resp.content).toBe('hello');
+    expect(resp.modelName).toBe('local-model');
+    expect(resp.totalTokens).toBe(3);
+    expect((globalThis as any).fetch).toHaveBeenCalledTimes(1);
+    expect(((globalThis as any).fetch as any).mock.calls[0][0]).toBe(
+      'http://localhost:11434/v1/chat/completions'
+    );
+  });
+});
+
+describe('LocalVisionLLMProvider (OpenAI-compatible)', () => {
+  const originalFetch = (globalThis as any).fetch;
+
+  afterEach(() => {
+    (globalThis as any).fetch = originalFetch;
+  });
+
+  it('should send image_url message content', async () => {
+    let capturedBody: any = null;
+    (globalThis as any).fetch = jest.fn(async (_url: string, init: any) => {
+      capturedBody = JSON.parse(init.body);
+      return {
+        ok: true,
+        status: 200,
+        text: async () =>
+          JSON.stringify({
+            model: 'local-vision',
+            choices: [{ message: { content: 'YES' } }],
+            usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 },
+          }),
+      };
+    });
+
+    const llm = new LocalVisionLLMProvider({
+      baseUrl: 'http://localhost:1234/v1',
+      model: 'local-vision',
+    });
+
+    const resp = await llm.generateWithImage('sys', 'is there a button?', 'AAAA', {});
+    expect(resp.content).toBe('YES');
+    expect(capturedBody.messages[1].content[1].type).toBe('image_url');
+    expect(capturedBody.messages[1].content[1].image_url.url).toContain(
+      'data:image/png;base64,AAAA'
+    );
+  });
+});
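
For reviewers who want to exercise the new providers end to end, here is a minimal usage sketch; it is not part of the patch. It assumes a local Ollama server with its OpenAI-compatible endpoint at http://localhost:11434/v1; the model names ('llama3.1', 'llava') and the screenshot path are placeholders, and the import path depends on how the package is consumed (the tests above import from '../src/llm-provider').

// usage-sketch.ts (not part of this diff)
import { readFileSync } from 'node:fs';
import { LocalLLMProvider, LocalVisionLLMProvider } from './src/llm-provider';

async function main(): Promise<void> {
  // Text-only generation against a local OpenAI-compatible server.
  const llm = new LocalLLMProvider({
    baseUrl: 'http://localhost:11434/v1', // Ollama's OpenAI-compatible endpoint
    model: 'llama3.1',                    // placeholder: any model your server exposes
    timeoutMs: 120_000,
  });
  const reply = await llm.generate('You are terse.', 'Summarize what an LLM provider does.');
  console.log(reply.content, reply.totalTokens);

  // Vision generation: the image travels as a base64 data: URI in OpenAI's image_url format.
  const vision = new LocalVisionLLMProvider({
    baseUrl: 'http://localhost:11434/v1',
    model: 'llava',                       // placeholder: a vision-capable local model
  });
  const screenshot = readFileSync('screenshot.png').toString('base64');
  const answer = await vision.generateWithImage(
    'Answer YES or NO.',
    'Is there a visible login button?',
    screenshot
  );
  console.log(answer.content);
}

main().catch((err) => {
  console.error(err);
  process.exit(1);
});

Because both classes fall back to the SENTIENCE_LOCAL_LLM_* environment variables, the constructor options can be omitted entirely when those variables are set.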
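The LocalVisionLLMProvider doc comment suggests writing a custom provider when a local server does not accept the OpenAI image_url format. The sketch below shows one possible shape for that; it is not part of the patch, and it targets what is assumed to be Ollama's native chat endpoint (POST /api/chat with a base64 images array on the message and prompt_eval_count / eval_count usage fields). Verify those endpoint and field names against the server's documentation before relying on them.

// ollama-native-vision-provider.ts (sketch only, not part of this diff)
// Assumptions: Node 18+ global fetch; Ollama's native chat API at /api/chat accepting
// { model, messages, stream, options } and attaching base64 images directly to a message.
import { LocalVisionLLMProvider } from './src/llm-provider';
import type { LLMResponse } from './src/llm-provider';

export class OllamaNativeVisionProvider extends LocalVisionLLMProvider {
  private readonly nativeBaseUrl: string;

  constructor(options: { model?: string; nativeBaseUrl?: string; timeoutMs?: number } = {}) {
    super({ model: options.model, timeoutMs: options.timeoutMs });
    // Native API root (no /v1 suffix); hypothetical default, adjust for your setup.
    this.nativeBaseUrl = options.nativeBaseUrl ?? 'http://localhost:11434';
  }

  async generateWithImage(
    systemPrompt: string,
    userPrompt: string,
    imageBase64: string,
    options: Record<string, any> = {}
  ): Promise<LLMResponse> {
    const res = await fetch(`${this.nativeBaseUrl}/api/chat`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        model: this.modelName,
        stream: false,
        messages: [
          ...(systemPrompt ? [{ role: 'system', content: systemPrompt }] : []),
          // Ollama's native schema (assumed) attaches raw base64 images to the message itself.
          { role: 'user', content: userPrompt, images: [imageBase64] },
        ],
        options: { temperature: options.temperature ?? 0.0 },
      }),
    });
    if (!res.ok) {
      throw new Error(`Ollama native HTTP ${res.status}: ${(await res.text()).slice(0, 500)}`);
    }
    const data: any = await res.json();
    return {
      content: data?.message?.content ?? '',
      promptTokens: data?.prompt_eval_count,
      completionTokens: data?.eval_count,
      totalTokens:
        data?.prompt_eval_count !== undefined && data?.eval_count !== undefined
          ? data.prompt_eval_count + data.eval_count
          : undefined,
      modelName: data?.model ?? this.modelName,
    };
  }
}

Extending LocalVisionLLMProvider keeps the text-only generate() path on the OpenAI-compatible endpoint while only the image path switches to the native schema.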