diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index c33082343..d8a21078e 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -25,6 +25,7 @@ import { import { getGlobalAgentPool, getProxyAgentForProvider } from "@/lib/proxy-agent"; import { SessionManager } from "@/lib/session-manager"; import { CONTEXT_1M_BETA_HEADER, shouldApplyContext1m } from "@/lib/special-attributes"; +import { detectUpstreamErrorFromSseOrJsonText } from "@/lib/utils/upstream-error-detection"; import { isVendorTypeCircuitOpen, recordVendorTypeAllEndpointsTimeout, @@ -84,6 +85,62 @@ const MAX_PROVIDER_SWITCHES = 20; // 保险栓:最多切换 20 次供应商( type CacheTtlOption = CacheTtlPreference | null | undefined; +// 非流式响应体检查的上限(字节):避免上游在 2xx 场景返回超大内容导致内存占用失控。 +// 说明: +// - 该检查仅用于“空响应/假 200”启发式判定,不用于业务逻辑解析; +// - 超过上限时,仍认为“非空”,但会跳过 JSON 内容结构检查(避免截断导致误判)。 +const NON_STREAM_BODY_INSPECTION_MAX_BYTES = 1024 * 1024; // 1 MiB + +async function readResponseTextUpTo( + response: Response, + maxBytes: number +): Promise<{ text: string; truncated: boolean }> { + const reader = response.body?.getReader(); + if (!reader) { + return { text: "", truncated: false }; + } + + const decoder = new TextDecoder(); + const chunks: string[] = []; + let bytesRead = 0; + let truncated = false; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + if (!value || value.byteLength === 0) continue; + + const remaining = maxBytes - bytesRead; + if (remaining <= 0) { + truncated = true; + break; + } + + if (value.byteLength > remaining) { + chunks.push(decoder.decode(value.subarray(0, remaining), { stream: true })); + bytesRead += remaining; + truncated = true; + break; + } + + chunks.push(decoder.decode(value, { stream: true })); + bytesRead += value.byteLength; + } + + const flushed = decoder.decode(); + if (flushed) chunks.push(flushed); + + if (truncated) { + try { + await reader.cancel(); + } catch { + // ignore + } + } + + return { text: chunks.join(""), truncated }; +} + function resolveCacheTtlPreference( keyPref: CacheTtlOption, providerPref: CacheTtlOption @@ -619,7 +676,11 @@ export class ProxyForwarder { // ========== 空响应检测(仅非流式)========== const contentType = response.headers.get("content-type") || ""; - const isSSE = contentType.includes("text/event-stream"); + const normalizedContentType = contentType.toLowerCase(); + const isSSE = normalizedContentType.includes("text/event-stream"); + const isHtml = + normalizedContentType.includes("text/html") || + normalizedContentType.includes("application/xhtml+xml"); // ========== 流式响应:延迟成功判定(避免“假 200”)========== // 背景:上游可能返回 HTTP 200,但 SSE 内容为错误 JSON(如 {"error": "..."})。 @@ -655,29 +716,62 @@ export class ProxyForwarder { return response; } - if (!isSSE) { - // 非流式响应:检测空响应 - const contentLength = response.headers.get("content-length"); + // 非流式响应:检测空响应 + const contentLength = response.headers.get("content-length"); - // 检测 Content-Length: 0 的情况 - if (contentLength === "0") { - throw new EmptyResponseError(currentProvider.id, currentProvider.name, "empty_body"); + // 检测 Content-Length: 0 的情况 + if (contentLength === "0") { + throw new EmptyResponseError(currentProvider.id, currentProvider.name, "empty_body"); + } + + // 200 + text/html(或 xhtml)通常是上游网关/WAF/Cloudflare 的错误页,但被包装成了 HTTP 200。 + // 这种“假 200”会导致: + // - 熔断/故障转移统计被误记为成功; + // - session 智能绑定被更新到不可用 provider(影响后续重试)。 + // 因此这里在进入成功分支前做一次强信号检测:仅当 body 看起来是完整 HTML 文档时才视为错误。 + let inspectedText: string | undefined; + let inspectedTruncated = false; + if (isHtml || !contentLength) { + const clonedResponse = response.clone(); + const inspected = await readResponseTextUpTo( + clonedResponse, + NON_STREAM_BODY_INSPECTION_MAX_BYTES + ); + inspectedText = inspected.text; + inspectedTruncated = inspected.truncated; + } + + if (isHtml && inspectedText !== undefined) { + const detected = detectUpstreamErrorFromSseOrJsonText(inspectedText); + if (detected.isError && detected.code === "FAKE_200_HTML_BODY") { + throw new ProxyError(detected.code, 502, { + body: detected.detail ?? "", + providerId: currentProvider.id, + providerName: currentProvider.name, + }); } + } - // 对于没有 Content-Length 的情况,需要 clone 并检查响应体 - // 注意:这会增加一定的性能开销,但对于非流式响应是可接受的 - if (!contentLength) { - const clonedResponse = response.clone(); - const responseText = await clonedResponse.text(); - - if (!responseText || responseText.trim() === "") { - throw new EmptyResponseError( - currentProvider.id, - currentProvider.name, - "empty_body" - ); - } + // 对于没有 Content-Length 的情况,需要 clone 并检查响应体 + // 注意:这会增加一定的性能开销,但对于非流式响应是可接受的 + if (!contentLength) { + const responseText = inspectedText ?? ""; + + if (!responseText || responseText.trim() === "") { + throw new EmptyResponseError(currentProvider.id, currentProvider.name, "empty_body"); + } + if (inspectedTruncated) { + logger.debug( + "ProxyForwarder: Response body too large for non-stream content check, skipping JSON parse", + { + providerId: currentProvider.id, + providerName: currentProvider.name, + contentType, + maxBytes: NON_STREAM_BODY_INSPECTION_MAX_BYTES, + } + ); + } else { // 尝试解析 JSON 并检查是否有输出内容 try { const responseJson = JSON.parse(responseText) as Record; @@ -722,7 +816,12 @@ export class ProxyForwarder { // 注意:不抛出错误,因为某些请求(如 count_tokens)可能合法地返回 0 output tokens } } - } catch (_parseError) { + } catch (_parseOrContentError) { + // EmptyResponseError 会触发重试/故障转移,不能在这里被当作 JSON 解析错误吞掉。 + if (isEmptyResponseError(_parseOrContentError)) { + throw _parseOrContentError; + } + // JSON 解析失败但响应体不为空,不视为空响应错误 logger.debug("ProxyForwarder: Non-JSON response body, skipping content check", { providerId: currentProvider.id, diff --git a/src/lib/utils/upstream-error-detection.test.ts b/src/lib/utils/upstream-error-detection.test.ts index d1facd969..88b5b7516 100644 --- a/src/lib/utils/upstream-error-detection.test.ts +++ b/src/lib/utils/upstream-error-detection.test.ts @@ -16,6 +16,31 @@ describe("detectUpstreamErrorFromSseOrJsonText", () => { }); }); + test("明显的 HTML 文档视为错误(覆盖 200+text/html 的“假 200”)", () => { + const html = [ + "", + '', + "New API", + "Something went wrong", + "", + ].join("\n"); + const res = detectUpstreamErrorFromSseOrJsonText(html); + expect(res).toEqual({ + isError: true, + code: "FAKE_200_HTML_BODY", + detail: expect.any(String), + }); + }); + + test("纯 JSON:content 内包含 文本不应误判为 HTML 错误", () => { + const body = JSON.stringify({ + type: "message", + content: [{ type: "text", text: "not an error" }], + }); + const res = detectUpstreamErrorFromSseOrJsonText(body); + expect(res.isError).toBe(false); + }); + test("纯 JSON:error 字段非空视为错误", () => { const res = detectUpstreamErrorFromSseOrJsonText('{"error":"当前无可用凭证"}'); expect(res.isError).toBe(true); diff --git a/src/lib/utils/upstream-error-detection.ts b/src/lib/utils/upstream-error-detection.ts index 066f1bc8f..56734b971 100644 --- a/src/lib/utils/upstream-error-detection.ts +++ b/src/lib/utils/upstream-error-detection.ts @@ -18,6 +18,7 @@ import { parseSSEData } from "@/lib/utils/sse"; * * 设计目标(偏保守) * - 仅基于结构化字段做启发式判断:`error` 与 `message`; + * - 对明显的 HTML 文档(doctype/html 标签)做强信号判定,覆盖部分网关/WAF/Cloudflare 返回的“假 200”; * - 不扫描模型生成的正文内容(例如 content/choices),避免把用户/模型自然语言里的 "error" 误判为上游错误; * - message 关键字检测仅对“小体积 JSON”启用,降低误判与性能开销。 * - 返回的 `code` 是语言无关的错误码(便于写入 DB/监控/告警); @@ -53,6 +54,7 @@ const DEFAULT_MESSAGE_KEYWORD = /error/i; const FAKE_200_CODES = { EMPTY_BODY: "FAKE_200_EMPTY_BODY", + HTML_BODY: "FAKE_200_HTML_BODY", JSON_ERROR_NON_EMPTY: "FAKE_200_JSON_ERROR_NON_EMPTY", JSON_ERROR_MESSAGE_NON_EMPTY: "FAKE_200_JSON_ERROR_MESSAGE_NON_EMPTY", JSON_MESSAGE_KEYWORD_MATCH: "FAKE_200_JSON_MESSAGE_KEYWORD_MATCH", @@ -63,6 +65,16 @@ const FAKE_200_CODES = { const MAY_HAVE_JSON_ERROR_KEY = /"error"\s*:/; const MAY_HAVE_JSON_MESSAGE_KEY = /"message"\s*:/; +const HTML_DOC_SNIFF_MAX_CHARS = 1024; +const HTML_DOCTYPE_RE = /^]/i; +const HTML_HTML_TAG_RE = /]/i; + +function isLikelyHtmlDocument(trimmedText: string): boolean { + if (!trimmedText.startsWith("<")) return false; + const head = trimmedText.slice(0, HTML_DOC_SNIFF_MAX_CHARS); + return HTML_DOCTYPE_RE.test(head) || HTML_HTML_TAG_RE.test(head); +} + function isPlainRecord(value: unknown): value is Record { return !!value && typeof value === "object" && !Array.isArray(value); } @@ -194,6 +206,20 @@ export function detectUpstreamErrorFromSseOrJsonText( return { isError: true, code: FAKE_200_CODES.EMPTY_BODY }; } + // 情况 0:明显的 HTML 文档(通常是网关/WAF/Cloudflare 返回的错误页) + // + // 说明: + // - 此处不依赖 Content-Type:部分上游会缺失/错误设置该字段; + // - 仅匹配 doctype/html 标签等“强信号”,避免把普通 `<...>` 文本误判为 HTML 页面。 + if (isLikelyHtmlDocument(trimmed)) { + return { + isError: true, + code: FAKE_200_CODES.HTML_BODY, + // 避免对超大 HTML 做无谓处理:仅截取前段用于脱敏/截断与排查 + detail: truncateForDetail(trimmed.slice(0, 4096)), + }; + } + // 情况 1:纯 JSON(对象) // 上游可能 Content-Type 设置为 SSE,但实际上返回 JSON;此处只处理对象格式({...}), // 不处理数组([...])以避免误判(数组场景的语义差异较大,后续若确认需要再扩展)。 diff --git a/tests/unit/lib/provider-endpoints/probe.test.ts b/tests/unit/lib/provider-endpoints/probe.test.ts index c77b04845..be25071bb 100644 --- a/tests/unit/lib/provider-endpoints/probe.test.ts +++ b/tests/unit/lib/provider-endpoints/probe.test.ts @@ -51,6 +51,8 @@ describe("provider-endpoints: probe", () => { })); vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ recordEndpointFailure: vi.fn(async () => {}), + getEndpointCircuitStateSync: vi.fn(() => "closed"), + resetEndpointCircuit: vi.fn(async () => {}), })); const fetchMock = vi.fn(async (_url: string, init?: RequestInit) => { @@ -91,6 +93,8 @@ describe("provider-endpoints: probe", () => { })); vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ recordEndpointFailure: vi.fn(async () => {}), + getEndpointCircuitStateSync: vi.fn(() => "closed"), + resetEndpointCircuit: vi.fn(async () => {}), })); const fetchMock = vi.fn(async (_url: string, init?: RequestInit) => { @@ -253,6 +257,8 @@ describe("provider-endpoints: probe", () => { })); vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ recordEndpointFailure: recordFailureMock, + getEndpointCircuitStateSync: vi.fn(() => "closed"), + resetEndpointCircuit: vi.fn(async () => {}), })); vi.stubGlobal( @@ -299,6 +305,8 @@ describe("provider-endpoints: probe", () => { })); vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ recordEndpointFailure: recordFailureMock, + getEndpointCircuitStateSync: vi.fn(() => "closed"), + resetEndpointCircuit: vi.fn(async () => {}), })); vi.stubGlobal( @@ -369,6 +377,8 @@ describe("provider-endpoints: probe", () => { })); vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ recordEndpointFailure: recordFailureMock, + getEndpointCircuitStateSync: vi.fn(() => "closed"), + resetEndpointCircuit: vi.fn(async () => {}), })); vi.stubGlobal( @@ -409,6 +419,8 @@ describe("provider-endpoints: probe", () => { })); vi.doMock("@/lib/endpoint-circuit-breaker", () => ({ recordEndpointFailure: recordFailureMock, + getEndpointCircuitStateSync: vi.fn(() => "closed"), + resetEndpointCircuit: vi.fn(async () => {}), })); vi.stubGlobal( diff --git a/tests/unit/proxy/proxy-forwarder-fake-200-html.test.ts b/tests/unit/proxy/proxy-forwarder-fake-200-html.test.ts new file mode 100644 index 000000000..7c1fc4215 --- /dev/null +++ b/tests/unit/proxy/proxy-forwarder-fake-200-html.test.ts @@ -0,0 +1,274 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +const mocks = vi.hoisted(() => { + return { + pickRandomProviderWithExclusion: vi.fn(), + recordSuccess: vi.fn(), + recordFailure: vi.fn(async () => {}), + getCircuitState: vi.fn(() => "closed"), + getProviderHealthInfo: vi.fn(async () => ({ + health: { failureCount: 0 }, + config: { failureThreshold: 3 }, + })), + updateMessageRequestDetails: vi.fn(async () => {}), + isHttp2Enabled: vi.fn(async () => false), + }; +}); + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + trace: vi.fn(), + error: vi.fn(), + fatal: vi.fn(), + }, +})); + +vi.mock("@/lib/config", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + isHttp2Enabled: mocks.isHttp2Enabled, + }; +}); + +vi.mock("@/lib/circuit-breaker", () => ({ + getCircuitState: mocks.getCircuitState, + getProviderHealthInfo: mocks.getProviderHealthInfo, + recordFailure: mocks.recordFailure, + recordSuccess: mocks.recordSuccess, +})); + +vi.mock("@/repository/message", () => ({ + updateMessageRequestDetails: mocks.updateMessageRequestDetails, +})); + +vi.mock("@/lib/endpoint-circuit-breaker", () => ({ + recordEndpointSuccess: vi.fn(async () => {}), + recordEndpointFailure: vi.fn(async () => {}), +})); + +vi.mock("@/app/v1/_lib/proxy/provider-selector", () => ({ + ProxyProviderResolver: { + pickRandomProviderWithExclusion: mocks.pickRandomProviderWithExclusion, + }, +})); + +import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder"; +import { ProxySession } from "@/app/v1/_lib/proxy/session"; +import type { Provider } from "@/types/provider"; + +function createProvider(overrides: Partial = {}): Provider { + return { + id: 1, + name: "p1", + url: "https://provider.example.com", + key: "k", + providerVendorId: null, + isEnabled: true, + weight: 1, + priority: 0, + groupPriorities: null, + costMultiplier: 1, + groupTag: null, + providerType: "openai-compatible", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: 0, + maxRetryAttempts: 1, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1_800_000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30_000, + streamingIdleTimeoutMs: 10_000, + requestTimeoutNonStreamingMs: 1_000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + geminiGoogleSearchPreference: null, + tpm: 0, + rpm: 0, + rpd: 0, + cc: 0, + createdAt: new Date(), + updatedAt: new Date(), + deletedAt: null, + ...overrides, + }; +} + +function createSession(): ProxySession { + const headers = new Headers(); + const session = Object.create(ProxySession.prototype); + + Object.assign(session, { + startTime: Date.now(), + method: "POST", + requestUrl: new URL("https://example.com/v1/messages"), + headers, + originalHeaders: new Headers(headers), + headerLog: JSON.stringify(Object.fromEntries(headers.entries())), + request: { + model: "claude-test", + log: "(test)", + message: { + model: "claude-test", + messages: [{ role: "user", content: "hi" }], + }, + }, + userAgent: null, + context: null, + clientAbortSignal: null, + userName: "test-user", + authState: { success: true, user: null, key: null, apiKey: null }, + provider: null, + messageContext: null, + sessionId: null, + requestSequence: 1, + originalFormat: "claude", + providerType: null, + originalModelName: null, + originalUrlPathname: null, + providerChain: [], + cacheTtlResolved: null, + context1mApplied: false, + specialSettings: [], + cachedPriceData: undefined, + cachedBillingModelSource: undefined, + isHeaderModified: () => false, + }); + + return session as any; +} + +describe("ProxyForwarder - fake 200 HTML body", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + test("200 + text/html 的 HTML 页面应视为失败并切换供应商", async () => { + const provider1 = createProvider({ id: 1, name: "p1", key: "k1", maxRetryAttempts: 1 }); + const provider2 = createProvider({ id: 2, name: "p2", key: "k2", maxRetryAttempts: 1 }); + + const session = createSession(); + session.setProvider(provider1); + + mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); + + const doForward = vi.spyOn(ProxyForwarder as any, "doForward"); + + const htmlBody = [ + "", + "New API", + "blocked", + ].join("\n"); + const okJson = JSON.stringify({ type: "message", content: [{ type: "text", text: "ok" }] }); + + doForward.mockResolvedValueOnce( + new Response(htmlBody, { + status: 200, + headers: { + "content-type": "text/html; charset=utf-8", + "content-length": String(htmlBody.length), + }, + }) + ); + + doForward.mockResolvedValueOnce( + new Response(okJson, { + status: 200, + headers: { + "content-type": "application/json; charset=utf-8", + "content-length": String(okJson.length), + }, + }) + ); + + const response = await ProxyForwarder.send(session); + expect(await response.text()).toContain("ok"); + + expect(doForward).toHaveBeenCalledTimes(2); + expect(doForward.mock.calls[0][1].id).toBe(1); + expect(doForward.mock.calls[1][1].id).toBe(2); + + expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalledWith(session, [1]); + expect(mocks.recordFailure).toHaveBeenCalledWith( + 1, + expect.objectContaining({ message: "FAKE_200_HTML_BODY" }) + ); + expect(mocks.recordSuccess).toHaveBeenCalledWith(2); + expect(mocks.recordSuccess).not.toHaveBeenCalledWith(1); + }); + + test("缺少 content 字段(missing_content)不应被 JSON 解析 catch 吞掉,应触发切换供应商", async () => { + const provider1 = createProvider({ id: 1, name: "p1", key: "k1", maxRetryAttempts: 1 }); + const provider2 = createProvider({ id: 2, name: "p2", key: "k2", maxRetryAttempts: 1 }); + + const session = createSession(); + session.setProvider(provider1); + + mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); + + const doForward = vi.spyOn(ProxyForwarder as any, "doForward"); + + const missingContentJson = JSON.stringify({ type: "message", content: [] }); + const okJson = JSON.stringify({ type: "message", content: [{ type: "text", text: "ok" }] }); + + doForward.mockResolvedValueOnce( + new Response(missingContentJson, { + status: 200, + headers: { + "content-type": "application/json; charset=utf-8", + // 故意不提供 content-length:覆盖 forwarder 的 clone + JSON 内容结构检查分支 + }, + }) + ); + + doForward.mockResolvedValueOnce( + new Response(okJson, { + status: 200, + headers: { + "content-type": "application/json; charset=utf-8", + "content-length": String(okJson.length), + }, + }) + ); + + const response = await ProxyForwarder.send(session); + expect(await response.text()).toContain("ok"); + + expect(doForward).toHaveBeenCalledTimes(2); + expect(doForward.mock.calls[0][1].id).toBe(1); + expect(doForward.mock.calls[1][1].id).toBe(2); + + expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalledWith(session, [1]); + expect(mocks.recordFailure).toHaveBeenCalledWith( + 1, + expect.objectContaining({ reason: "missing_content" }) + ); + expect(mocks.recordSuccess).toHaveBeenCalledWith(2); + expect(mocks.recordSuccess).not.toHaveBeenCalledWith(1); + }); +});