diff --git a/apis/cloudflare/src/env.ts b/apis/cloudflare/src/env.ts
index f9f3d65a..5fdb164f 100644
--- a/apis/cloudflare/src/env.ts
+++ b/apis/cloudflare/src/env.ts
@@ -4,6 +4,7 @@ declare global {
     BRAINTRUST_APP_URL: string;
     WHITELISTED_ORIGINS?: string;
     METRICS_LICENSE_KEY?: string;
+    NATIVE_INFERENCE_SECRET_KEY?: string;
   }
 }
diff --git a/apis/cloudflare/src/proxy.ts b/apis/cloudflare/src/proxy.ts
index 56257c32..5eb614da 100644
--- a/apis/cloudflare/src/proxy.ts
+++ b/apis/cloudflare/src/proxy.ts
@@ -199,6 +199,7 @@ export async function handleProxyV1(
     spanLogger,
     spanId,
     spanExport,
+    nativeInferenceSecretKey: env.NATIVE_INFERENCE_SECRET_KEY,
   };

   const url = new URL(request.url);
diff --git a/packages/proxy/edge/index.ts b/packages/proxy/edge/index.ts
index f7e954a7..f861a9fc 100644
--- a/packages/proxy/edge/index.ts
+++ b/packages/proxy/edge/index.ts
@@ -4,7 +4,7 @@
 import { proxyV1, SpanLogger, LogHistogramFn } from "@lib/proxy";
 import { isEmpty } from "@lib/util";
 import { MeterProvider } from "@opentelemetry/sdk-metrics";
-import { APISecret, getModelEndpointTypes } from "@schema";
+import { APISecret, APISecretSchema, getModelEndpointTypes } from "@schema";
 import { verifyTempCredentials, isTempCredential } from "utils";
 import {
   decryptMessage,
@@ -38,6 +38,7 @@ export interface ProxyOpts {
   spanLogger?: SpanLogger;
   spanId?: string;
   spanExport?: string;
+  nativeInferenceSecretKey?: string;
 }

 const defaultWhitelist: (string | RegExp)[] = [
@@ -113,6 +114,10 @@ export async function digestMessage(message: string) {
   return btoa(String.fromCharCode(...new Uint8Array(hash)));
 }

+function isRecord(value: unknown): value is Record<string, unknown> {
+  return typeof value === "object" && value !== null;
+}
+
 export function makeFetchApiSecrets({
   ctx,
   opts,
@@ -192,11 +197,57 @@
           model,
           org_name,
           mode: "full",
+          ...(opts.nativeInferenceSecretKey
+            ? { can_execute_native_inference: true }
+            : {}),
         }),
       },
     );
     if (response.ok) {
-      secrets = await response.json();
+      const responseJson: unknown = await response.json();
+      if (
+        isRecord(responseJson) &&
+        responseJson.encrypted === true &&
+        typeof responseJson.iv === "string" &&
+        typeof responseJson.data === "string"
+      ) {
+        if (!opts.nativeInferenceSecretKey) {
+          throw new Error(
+            "Received encrypted response but NATIVE_INFERENCE_SECRET_KEY is not configured",
+          );
+        }
+        const keys = opts.nativeInferenceSecretKey
+          .split(",")
+          .map((k) => k.trim())
+          .filter((k) => k.length > 0);
+        let decrypted: string | null | undefined = null;
+        for (const key of keys) {
+          const encryptionKey = await digestMessage(key);
+          try {
+            decrypted = await decryptMessage(
+              encryptionKey,
+              responseJson.iv,
+              responseJson.data,
+            );
+            if (decrypted) break;
+          } catch {}
+        }
+        if (!decrypted) {
+          throw new Error(
+            "Failed to decrypt native inference response (tried all keys)",
+          );
+        }
+        const parsed: unknown = JSON.parse(decrypted);
+        if (!Array.isArray(parsed)) {
+          throw new Error("Decrypted response is not an array");
+        }
+        secrets = parsed.map((s: unknown) => APISecretSchema.parse(s));
+      } else {
+        if (!Array.isArray(responseJson)) {
+          throw new Error("Response is not an array");
+        }
+        secrets = responseJson.map((s: unknown) => APISecretSchema.parse(s));
+      }
     } else {
       lookupFailed = true;
       console.warn("Failed to lookup api key", await response.text());
diff --git a/packages/proxy/schema/index.ts b/packages/proxy/schema/index.ts
index 06afb293..43dbc16a 100644
--- a/packages/proxy/schema/index.ts
+++ b/packages/proxy/schema/index.ts
@@ -632,6 +632,7 @@ export const EndpointProviderToBaseURL: {
   [name in ModelEndpointType]: string | null;
 } = {
   openai: "https://api.openai.com/v1",
+  braintrust: null,
   anthropic: "https://api.anthropic.com/v1",
   perplexity: "https://api.perplexity.ai",
   replicate: "https://openai-proxy.replicate.com/v1",
diff --git a/packages/proxy/schema/model_list.json b/packages/proxy/schema/model_list.json
index 3571ca51..957c7e60 100644
--- a/packages/proxy/schema/model_list.json
+++ b/packages/proxy/schema/model_list.json
@@ -3907,9 +3907,7 @@
     "displayName": "Gemini 3 Pro Preview",
     "reasoning": true,
     "reasoning_budget": true,
-    "locations": [
-      "global"
-    ],
+    "locations": ["global"],
     "max_input_tokens": 1048576,
     "max_output_tokens": 65535
   },
@@ -3923,9 +3921,7 @@
     "displayName": "Gemini 3 Flash Preview",
     "reasoning": true,
     "reasoning_budget": true,
-    "locations": [
-      "global"
-    ],
+    "locations": ["global"],
     "max_input_tokens": 1048576,
     "max_output_tokens": 65535
   },
diff --git a/packages/proxy/schema/models.ts b/packages/proxy/schema/models.ts
index c9dfaaa0..49e473bb 100644
--- a/packages/proxy/schema/models.ts
+++ b/packages/proxy/schema/models.ts
@@ -16,6 +16,7 @@ export type ModelFormat = (typeof ModelFormats)[number];

 export const ModelEndpointType = [
   "openai",
+  "braintrust",
   "anthropic",
   "google",
   "mistral",
diff --git a/packages/proxy/schema/secrets.ts b/packages/proxy/schema/secrets.ts
index dda8a2f4..9c7dafb6 100644
--- a/packages/proxy/schema/secrets.ts
+++ b/packages/proxy/schema/secrets.ts
@@ -8,6 +8,8 @@ export const BaseMetadataSchema = z
   .object({
     excludeDefaultModels: z.boolean().nullish(),
     additionalHeaders: z.record(z.string(), z.string()).nullish(),
     supportsStreaming: z.boolean().default(true),
+    custom_model: z.string().nullish(),
+    custom_system_prompt: z.string().nullish(),
   })
   .strict();
@@ -89,6 +91,10 @@ export const OpenAIMetadataSchema = BaseMetadataSchema.merge(
       z.null(),
     ]),
     organization_id: z.string().nullish(),
+    // Custom endpoint path to override the default (e.g., "" to use api_base as full URL)
+    endpoint_path: z.string().nullish(),
+    // Auth format for the authorization header (default: "bearer")
+    auth_format: z.enum(["bearer", "api_key"]).nullish(),
   }),
 ).strict();

@@ -114,6 +120,7 @@ export const APISecretSchema = z.union([
       type: z.enum([
         "perplexity",
         "anthropic",
+        "braintrust",
         "google",
         "replicate",
         "together",
diff --git a/packages/proxy/src/proxy.ts b/packages/proxy/src/proxy.ts
index 12a644e4..65dbfbb2 100644
--- a/packages/proxy/src/proxy.ts
+++ b/packages/proxy/src/proxy.ts
@@ -268,7 +268,7 @@ export async function proxyV1({
   }

   // Caching is enabled by default, but let the user disable it
-  let useCacheMode = parseEnumHeader(
+  const useCacheMode = parseEnumHeader(
     CACHE_HEADER,
     CACHE_MODES,
     proxyHeaders[CACHE_HEADER],
@@ -328,6 +328,7 @@
   let bodyData = null;
   if (
     url === "/auto" ||
+    url === "/embeddings" ||
     url === "/chat/completions" ||
     url === "/responses" ||
     url === "/completions" ||
@@ -347,6 +348,7 @@
   if (
     method === "POST" &&
     (url === "/auto" ||
+      url === "/embeddings" ||
       url === "/chat/completions" ||
       url === "/completions" ||
       url === "/responses" ||
@@ -458,7 +460,7 @@
   // The data key is used as the encryption key, so unless you have the actual incoming data, you can't decrypt the cache.
   const encryptionKey = await digest(`${dataKey}:${authToken}`);

-  let startTime = getCurrentUnixTimestamp();
+  const startTime = getCurrentUnixTimestamp();
   let spanType: SpanType | undefined = undefined;
   const isStreaming = !!bodyData?.stream;

@@ -501,7 +503,7 @@
       stream = new ReadableStream({
         start(controller) {
           if ("body" in cachedData && cachedData.body) {
-            let splits = cachedData.body.split("\n");
+            const splits = cachedData.body.split("\n");
             for (let i = 0; i < splits.length; i++) {
               controller.enqueue(
                 new TextEncoder().encode(
@@ -553,7 +555,7 @@

   let responseFailed = false;

-  let overridenHeaders: string[] = [];
+  const overridenHeaders: string[] = [];
   const setOverriddenHeader = (name: string, value: string) => {
     overridenHeaders.push(name);
     setHeader(name, value);
@@ -1032,7 +1034,9 @@
     {
       const data = dataRaw as CreateEmbeddingResponse;
       spanLogger?.log({
-        output: { embedding_length: data.data[0].embedding.length },
+        output: {
+          embedding_length: data.data?.[0].embedding.length,
+        },
         metrics: {
           tokens: data.usage?.total_tokens,
           prompt_tokens: data.usage?.prompt_tokens,
@@ -1149,7 +1153,7 @@ const RATE_LIMITING_ERROR_CODES = [
   OVERLOADED_ERROR_CODE,
 ];

-let loopIndex = 0;
+const loopIndex = 0;
 async function fetchModelLoop(
   logHistogram: LogHistogramFn | undefined,
   method: "GET" | "POST",
@@ -1176,6 +1180,54 @@

   // TODO: Make this smarter. For now, just pick a random one.
   const secrets = await getApiSecrets(model);
+
+  const customModelOverride =
+    secrets.length > 0 && secrets[0].metadata?.custom_model
+      ? String(secrets[0].metadata.custom_model)
+      : null;
+  const customSystemPromptContent =
+    secrets.length > 0 && secrets[0].metadata?.custom_system_prompt
+      ? String(secrets[0].metadata.custom_system_prompt)
+      : null;
+
+  if (customModelOverride) {
+    model = customModelOverride;
+    if (bodyData && typeof bodyData === "object" && "model" in bodyData) {
+      bodyData = { ...bodyData, model: customModelOverride };
+
+      // Inject custom system prompt for chat completions (if provided)
+      if (customSystemPromptContent && Array.isArray(bodyData.messages)) {
+        const customSystemPrompt = {
+          role: "system",
+          content: customSystemPromptContent,
+        };
+        const existingMessages = bodyData.messages;
+        const hasSystemMessage =
+          existingMessages.length > 0 && existingMessages[0]?.role === "system";
+        if (hasSystemMessage) {
+          // Prepend custom context to existing system message
+          const existingSystem = existingMessages[0];
+          bodyData = {
+            ...bodyData,
+            messages: [
+              {
+                ...existingSystem,
+                content: `${customSystemPromptContent}\n\n${existingSystem.content ?? ""}`,
+              },
+              ...existingMessages.slice(1),
+            ],
+          };
+        } else {
+          // Add system message at the beginning
+          bodyData = {
+            ...bodyData,
+            messages: [customSystemPrompt, ...existingMessages],
+          };
+        }
+      }
+    }
+  }
+
   const initialIdx = getRandomInt(secrets.length);
   let proxyResponse: ModelResponse | null = null;
   let secretName: string | null | undefined = null;
@@ -1990,7 +2042,13 @@ async function fetchOpenAI(
         `${baseURL}/serving-endpoints/${bodyData.model}/invocations`,
       );
     } else {
-      fullURL = new URL(baseURL + url);
+      const endpointPath =
+        secret.metadata &&
+        "endpoint_path" in secret.metadata &&
+        typeof secret.metadata.endpoint_path === "string"
+          ? secret.metadata.endpoint_path
+          : url;
+      fullURL = new URL(baseURL + endpointPath);
     }
   }

@@ -2012,7 +2070,17 @@
   }

   headers["host"] = fullURL.host;
-  headers["authorization"] = "Bearer " + bearerToken;
+  // Use custom auth format if specified (e.g., "api_key" for Baseten)
+  const authFormat =
+    secret.metadata &&
+    "auth_format" in secret.metadata &&
+    secret.metadata.auth_format === "api_key"
+      ? "api_key"
+      : "bearer";
+  headers["authorization"] =
+    authFormat === "api_key"
+      ? `Api-Key ${bearerToken}`
+      : `Bearer ${bearerToken}`;

   if (secret.type === "azure" && secret.metadata?.api_version) {
     fullURL.searchParams.set("api-version", secret.metadata.api_version);
@@ -2684,7 +2752,7 @@ async function fetchAnthropicChatCompletions({
   if (proxyResponse.ok) {
     if (params.stream) {
       let idx = 0;
-      let usage: Partial<CompletionUsage> = {};
+      const usage: Partial<CompletionUsage> = {};
       stream = stream.pipeThrough(
         createEventStreamTransformer((data) => {
           const ret = anthropicEventToOpenAIEvent(
@@ -2820,7 +2888,7 @@
         break;
     }
   }
-  let out = {
+  const out = {
     tools: params.tools
       ? [
           {
diff --git a/packages/proxy/types/openai.ts b/packages/proxy/types/openai.ts
index 0c010b02..906c2507 100644
--- a/packages/proxy/types/openai.ts
+++ b/packages/proxy/types/openai.ts
@@ -92,7 +92,7 @@ export const completionUsageSchema = z.object({
       reasoning_tokens: z.number().optional(),
       rejected_prediction_tokens: z.number().optional(),
     })
-    .optional(),
+    .nullish(),
   prompt_tokens_details: z
     .object({
       audio_tokens: z.number().optional(),
@@ -104,7 +104,7 @@
         "Extension to support Anthropic `cache_creation_input_tokens`",
       ),
     })
-    .optional(),
+    .nullish(),
 });

 export type OpenAICompletionUsage = z.infer<typeof completionUsageSchema>;