diff --git a/messages/en/auth.json b/messages/en/auth.json index 460d311d0..4feeabdd7 100644 --- a/messages/en/auth.json +++ b/messages/en/auth.json @@ -1,7 +1,10 @@ { "form": { "title": "Login Panel", - "description": "Access the unified admin console with your API Key" + "description": "Access the unified admin console with your API Key", + "apiKeyLabel": "API Key", + "showPassword": "Show password", + "hidePassword": "Hide password" }, "login": { "title": "Login", @@ -20,6 +23,9 @@ "placeholders": { "apiKeyExample": "e.g. sk-xxxxxxxx" }, + "brand": { + "tagline": "Unified API management console" + }, "actions": { "enterConsole": "Enter Console", "viewUsageDoc": "View Usage Documentation" diff --git a/messages/en/settings/providers/batchEdit.json b/messages/en/settings/providers/batchEdit.json index 7abbd6045..c1bc48a12 100644 --- a/messages/en/settings/providers/batchEdit.json +++ b/messages/en/settings/providers/batchEdit.json @@ -5,6 +5,10 @@ "invertSelection": "Invert", "selectedCount": "{count} selected", "editSelected": "Edit Selected", + "selectByType": "Select by Type", + "selectByTypeItem": "{type} ({count})", + "selectByGroup": "Select by Group", + "selectByGroupItem": "{group} ({count})", "actions": { "edit": "Edit", "delete": "Delete", @@ -20,12 +24,33 @@ "next": "Next", "noFieldEnabled": "Please enable at least one field to update" }, + "sections": { + "basic": "Basic Settings", + "routing": "Group & Routing", + "anthropic": "Anthropic Settings" + }, "fields": { - "isEnabled": "Status", + "isEnabled": { + "label": "Status", + "noChange": "No Change", + "enable": "Enable", + "disable": "Disable" + }, "priority": "Priority", "weight": "Weight", "costMultiplier": "Cost Multiplier", - "groupTag": "Group Tag" + "groupTag": { + "label": "Group Tag", + "clear": "Clear" + }, + "modelRedirects": "Model Redirects", + "allowedModels": "Allowed Models", + "thinkingBudget": "Thinking Budget", + "adaptiveThinking": "Adaptive Thinking" + }, + "affectedProviders": { + "title": "Affected Providers", + "more": "+{count} more" }, "confirm": { "title": "Confirm Operation", @@ -34,10 +59,47 @@ "goBack": "Go Back", "processing": "Processing..." }, + "preview": { + "title": "Preview Changes", + "description": "Review changes before applying to {count} providers", + "providerHeader": "{name}", + "fieldChanged": "{field}: {before} -> {after}", + "fieldSkipped": "{field}: Skipped ({reason})", + "excludeProvider": "Exclude", + "summary": "{providerCount} providers, {fieldCount} changes, {skipCount} skipped", + "noChanges": "No changes to apply", + "apply": "Apply Changes", + "back": "Back to Edit", + "loading": "Generating preview..." + }, + "batchNotes": { + "codexOnly": "Codex only", + "claudeOnly": "Claude only", + "geminiOnly": "Gemini only" + }, + "selectionHint": "Select multiple providers for batch operations", + "undo": { + "button": "Undo", + "success": "Operation undone successfully", + "expired": "Undo expired", + "batchDeleteSuccess": "Deleted {count} providers", + "batchDeleteUndone": "Restored {count} providers", + "singleDeleteSuccess": "Provider deleted", + "singleDeleteUndone": "Provider restored", + "singleEditSuccess": "Provider updated", + "singleEditUndone": "Changes reverted", + "failed": "Undo failed" + }, "toast": { "updated": "Updated {count} providers", "deleted": "Deleted {count} providers", "circuitReset": "Reset {count} circuit breakers", - "failed": "Operation failed: {error}" + "failed": "Operation failed: {error}", + "undo": "Undo", + "undoSuccess": "Reverted {count} providers", + "undoFailed": "Undo failed: {error}", + "undoExpired": "Undo window expired", + "previewFailed": "Preview failed: {error}", + "unknownError": "Unknown error" } } diff --git a/messages/ja/auth.json b/messages/ja/auth.json index 113aa9193..68658e5ce 100644 --- a/messages/ja/auth.json +++ b/messages/ja/auth.json @@ -1,7 +1,10 @@ { "form": { "title": "ログインパネル", - "description": "API キーを使用して統一管理コンソールにアクセスします" + "description": "API キーを使用して統一管理コンソールにアクセスします", + "apiKeyLabel": "API Key", + "showPassword": "パスワードを表示", + "hidePassword": "パスワードを非表示" }, "login": { "title": "ログイン", @@ -20,6 +23,9 @@ "placeholders": { "apiKeyExample": "例: sk-xxxxxxxx" }, + "brand": { + "tagline": "統合API管理コンソール" + }, "actions": { "enterConsole": "コンソールに入る", "viewUsageDoc": "使用方法を見る" diff --git a/messages/ja/settings/providers/batchEdit.json b/messages/ja/settings/providers/batchEdit.json index 68f98a0a2..8feb4f198 100644 --- a/messages/ja/settings/providers/batchEdit.json +++ b/messages/ja/settings/providers/batchEdit.json @@ -5,6 +5,10 @@ "invertSelection": "反転", "selectedCount": "{count} 件選択中", "editSelected": "選択項目を編集", + "selectByType": "タイプで選択", + "selectByTypeItem": "{type} ({count})", + "selectByGroup": "グループで選択", + "selectByGroupItem": "{group} ({count})", "actions": { "edit": "編集", "delete": "削除", @@ -20,12 +24,33 @@ "next": "次へ", "noFieldEnabled": "更新するフィールドを少なくとも1つ有効にしてください" }, + "sections": { + "basic": "基本設定", + "routing": "グループとルーティング", + "anthropic": "Anthropic 設定" + }, "fields": { - "isEnabled": "ステータス", + "isEnabled": { + "label": "ステータス", + "noChange": "変更なし", + "enable": "有効", + "disable": "無効" + }, "priority": "優先度", "weight": "重み", "costMultiplier": "価格倍率", - "groupTag": "グループタグ" + "groupTag": { + "label": "グループタグ", + "clear": "クリア" + }, + "modelRedirects": "モデルリダイレクト", + "allowedModels": "許可モデル", + "thinkingBudget": "思考バジェット", + "adaptiveThinking": "アダプティブ思考" + }, + "affectedProviders": { + "title": "影響を受けるプロバイダー", + "more": "+{count} 件" }, "confirm": { "title": "操作の確認", @@ -34,10 +59,47 @@ "goBack": "戻る", "processing": "処理中..." }, + "preview": { + "title": "変更のプレビュー", + "description": "{count} 件のプロバイダーに適用する前に変更内容を確認してください", + "providerHeader": "{name}", + "fieldChanged": "{field}: {before} -> {after}", + "fieldSkipped": "{field}: スキップ ({reason})", + "excludeProvider": "除外", + "summary": "{providerCount} 件のプロバイダー, {fieldCount} 件の変更, {skipCount} 件スキップ", + "noChanges": "適用する変更はありません", + "apply": "変更を適用", + "back": "編集に戻る", + "loading": "プレビューを生成中..." + }, + "batchNotes": { + "codexOnly": "Codex のみ", + "claudeOnly": "Claude のみ", + "geminiOnly": "Gemini のみ" + }, + "selectionHint": "複数のプロバイダーを選択して一括操作を実行", + "undo": { + "button": "元に戻す", + "success": "操作が正常に元に戻されました", + "expired": "元に戻す期限が切れました", + "batchDeleteSuccess": "{count} 件のプロバイダーを削除しました", + "batchDeleteUndone": "{count} 件のプロバイダーを復元しました", + "singleDeleteSuccess": "プロバイダーを削除しました", + "singleDeleteUndone": "プロバイダーを復元しました", + "singleEditSuccess": "プロバイダーを更新しました", + "singleEditUndone": "変更を元に戻しました", + "failed": "元に戻すことに失敗しました" + }, "toast": { "updated": "{count} 件のプロバイダーを更新しました", "deleted": "{count} 件のプロバイダーを削除しました", "circuitReset": "{count} 件のサーキットブレーカーをリセットしました", - "failed": "操作に失敗しました: {error}" + "failed": "操作に失敗しました: {error}", + "undo": "元に戻す", + "undoSuccess": "{count} 件のプロバイダーを復元しました", + "undoFailed": "元に戻す操作に失敗しました: {error}", + "undoExpired": "元に戻す期限が切れました", + "previewFailed": "プレビューに失敗しました: {error}", + "unknownError": "不明なエラー" } } diff --git a/messages/ru/auth.json b/messages/ru/auth.json index 4e6f42542..de91560a7 100644 --- a/messages/ru/auth.json +++ b/messages/ru/auth.json @@ -1,7 +1,10 @@ { "form": { "title": "Панель входа", - "description": "Введите ваш API ключ для доступа к данным" + "description": "Введите ваш API ключ для доступа к данным", + "apiKeyLabel": "API Key", + "showPassword": "Показать пароль", + "hidePassword": "Скрыть пароль" }, "login": { "title": "Вход", @@ -20,6 +23,9 @@ "placeholders": { "apiKeyExample": "например sk-xxxxxxxx" }, + "brand": { + "tagline": "Единая консоль управления API" + }, "actions": { "enterConsole": "Перейти в консоль", "viewUsageDoc": "Просмотреть документацию" diff --git a/messages/ru/settings/providers/batchEdit.json b/messages/ru/settings/providers/batchEdit.json index 3d5c6c4f3..9a620bf1b 100644 --- a/messages/ru/settings/providers/batchEdit.json +++ b/messages/ru/settings/providers/batchEdit.json @@ -5,6 +5,10 @@ "invertSelection": "Инвертировать", "selectedCount": "Выбрано: {count}", "editSelected": "Редактировать выбранные", + "selectByType": "Выбрать по типу", + "selectByTypeItem": "{type} ({count})", + "selectByGroup": "Выбрать по группе", + "selectByGroupItem": "{group} ({count})", "actions": { "edit": "Редактировать", "delete": "Удалить", @@ -20,12 +24,33 @@ "next": "Далее", "noFieldEnabled": "Пожалуйста, включите хотя бы одно поле для обновления" }, + "sections": { + "basic": "Основные настройки", + "routing": "Группы и маршрутизация", + "anthropic": "Настройки Anthropic" + }, "fields": { - "isEnabled": "Статус", + "isEnabled": { + "label": "Статус", + "noChange": "Без изменений", + "enable": "Включить", + "disable": "Отключить" + }, "priority": "Приоритет", "weight": "Вес", "costMultiplier": "Множитель стоимости", - "groupTag": "Тег группы" + "groupTag": { + "label": "Тег группы", + "clear": "Очистить" + }, + "modelRedirects": "Перенаправление моделей", + "allowedModels": "Разрешённые модели", + "thinkingBudget": "Бюджет мышления", + "adaptiveThinking": "Адаптивное мышление" + }, + "affectedProviders": { + "title": "Затронутые поставщики", + "more": "+{count} ещё" }, "confirm": { "title": "Подтвердите операцию", @@ -34,10 +59,47 @@ "goBack": "Назад", "processing": "Обработка..." }, + "preview": { + "title": "Предпросмотр изменений", + "description": "Проверьте изменения перед применением к {count} поставщикам", + "providerHeader": "{name}", + "fieldChanged": "{field}: {before} -> {after}", + "fieldSkipped": "{field}: Пропущено ({reason})", + "excludeProvider": "Исключить", + "summary": "{providerCount} поставщиков, {fieldCount} изменений, {skipCount} пропущено", + "noChanges": "Нет изменений для применения", + "apply": "Применить изменения", + "back": "Вернуться к редактированию", + "loading": "Генерация предпросмотра..." + }, + "batchNotes": { + "codexOnly": "Только Codex", + "claudeOnly": "Только Claude", + "geminiOnly": "Только Gemini" + }, + "selectionHint": "Выберите нескольких поставщиков для массовых операций", + "undo": { + "button": "Отменить", + "success": "Операция успешно отменена", + "expired": "Время отмены истекло", + "batchDeleteSuccess": "Удалено поставщиков: {count}", + "batchDeleteUndone": "Восстановлено поставщиков: {count}", + "singleDeleteSuccess": "Поставщик удалён", + "singleDeleteUndone": "Поставщик восстановлен", + "singleEditSuccess": "Поставщик обновлён", + "singleEditUndone": "Изменения отменены", + "failed": "Ошибка отмены" + }, "toast": { "updated": "Обновлено поставщиков: {count}", "deleted": "Удалено поставщиков: {count}", "circuitReset": "Сброшено прерывателей: {count}", - "failed": "Операция не удалась: {error}" + "failed": "Операция не удалась: {error}", + "undo": "Отменить", + "undoSuccess": "Восстановлено поставщиков: {count}", + "undoFailed": "Отмена не удалась: {error}", + "undoExpired": "Время отмены истекло", + "previewFailed": "Предпросмотр не удался: {error}", + "unknownError": "Неизвестная ошибка" } } diff --git a/messages/zh-CN/auth.json b/messages/zh-CN/auth.json index 9ffb12e4f..9cb3f1934 100644 --- a/messages/zh-CN/auth.json +++ b/messages/zh-CN/auth.json @@ -27,6 +27,9 @@ "placeholders": { "apiKeyExample": "例如 sk-xxxxxxxx" }, + "brand": { + "tagline": "统一 API 管理控制台" + }, "actions": { "enterConsole": "进入控制台", "viewUsageDoc": "查看使用文档" @@ -41,6 +44,9 @@ }, "form": { "title": "登录面板", - "description": "使用您的 API Key 进入统一控制台" + "description": "使用您的 API Key 进入统一控制台", + "apiKeyLabel": "API Key", + "showPassword": "显示密码", + "hidePassword": "隐藏密码" } } diff --git a/messages/zh-CN/settings/providers/batchEdit.json b/messages/zh-CN/settings/providers/batchEdit.json index 87e6d842b..49d938805 100644 --- a/messages/zh-CN/settings/providers/batchEdit.json +++ b/messages/zh-CN/settings/providers/batchEdit.json @@ -5,6 +5,10 @@ "invertSelection": "反选", "selectedCount": "已选 {count} 项", "editSelected": "编辑选中项", + "selectByType": "按类型选择", + "selectByTypeItem": "{type} ({count})", + "selectByGroup": "按分组选择", + "selectByGroupItem": "{group} ({count})", "actions": { "edit": "编辑", "delete": "删除", @@ -20,12 +24,33 @@ "next": "下一步", "noFieldEnabled": "请至少启用一个要更新的字段" }, + "sections": { + "basic": "基本设置", + "routing": "分组与路由", + "anthropic": "Anthropic 设置" + }, "fields": { - "isEnabled": "状态", + "isEnabled": { + "label": "状态", + "noChange": "不修改", + "enable": "启用", + "disable": "禁用" + }, "priority": "优先级", "weight": "权重", "costMultiplier": "价格倍率", - "groupTag": "分组标签" + "groupTag": { + "label": "分组标签", + "clear": "清除" + }, + "modelRedirects": "模型重定向", + "allowedModels": "允许的模型", + "thinkingBudget": "思维预算", + "adaptiveThinking": "自适应思维" + }, + "affectedProviders": { + "title": "受影响的供应商", + "more": "+{count} 更多" }, "confirm": { "title": "确认操作", @@ -34,10 +59,47 @@ "goBack": "返回", "processing": "处理中..." }, + "preview": { + "title": "预览变更", + "description": "将变更应用到 {count} 个供应商前请先确认", + "providerHeader": "{name}", + "fieldChanged": "{field}: {before} -> {after}", + "fieldSkipped": "{field}: 已跳过 ({reason})", + "excludeProvider": "排除", + "summary": "{providerCount} 个供应商, {fieldCount} 项变更, {skipCount} 项跳过", + "noChanges": "没有可应用的变更", + "apply": "应用变更", + "back": "返回编辑", + "loading": "正在生成预览..." + }, + "batchNotes": { + "codexOnly": "仅 Codex", + "claudeOnly": "仅 Claude", + "geminiOnly": "仅 Gemini" + }, + "selectionHint": "选择多个服务商后可进行批量操作", + "undo": { + "button": "撤销", + "success": "操作已成功撤销", + "expired": "撤销窗口已过期", + "batchDeleteSuccess": "已删除 {count} 个供应商", + "batchDeleteUndone": "已恢复 {count} 个供应商", + "singleDeleteSuccess": "供应商已删除", + "singleDeleteUndone": "供应商已恢复", + "singleEditSuccess": "供应商已更新", + "singleEditUndone": "更改已回退", + "failed": "撤销失败" + }, "toast": { "updated": "已更新 {count} 个供应商", "deleted": "已删除 {count} 个供应商", "circuitReset": "已重置 {count} 个熔断器", - "failed": "操作失败: {error}" + "failed": "操作失败: {error}", + "undo": "撤销", + "undoSuccess": "已还原 {count} 个供应商", + "undoFailed": "撤销失败: {error}", + "undoExpired": "撤销窗口已过期", + "previewFailed": "预览失败: {error}", + "unknownError": "未知错误" } } diff --git a/messages/zh-TW/auth.json b/messages/zh-TW/auth.json index 58da807c1..439ca9dca 100644 --- a/messages/zh-TW/auth.json +++ b/messages/zh-TW/auth.json @@ -1,7 +1,10 @@ { "form": { "title": "登錄面板", - "description": "使用您的 API Key 進入統一控制台" + "description": "使用您的 API Key 進入統一控制台", + "apiKeyLabel": "API Key", + "showPassword": "顯示密碼", + "hidePassword": "隱藏密碼" }, "login": { "title": "登錄", @@ -20,6 +23,9 @@ "placeholders": { "apiKeyExample": "例如 sk-xxxxxxxx" }, + "brand": { + "tagline": "統一 API 管理控制台" + }, "actions": { "enterConsole": "進入控制台", "viewUsageDoc": "查看使用文檔" diff --git a/messages/zh-TW/settings/providers/batchEdit.json b/messages/zh-TW/settings/providers/batchEdit.json index 30ac0472a..b8541e6e1 100644 --- a/messages/zh-TW/settings/providers/batchEdit.json +++ b/messages/zh-TW/settings/providers/batchEdit.json @@ -5,6 +5,10 @@ "invertSelection": "反選", "selectedCount": "已選 {count} 項", "editSelected": "編輯選中項", + "selectByType": "按類型選擇", + "selectByTypeItem": "{type} ({count})", + "selectByGroup": "按分組選擇", + "selectByGroupItem": "{group} ({count})", "actions": { "edit": "編輯", "delete": "刪除", @@ -20,12 +24,33 @@ "next": "下一步", "noFieldEnabled": "請至少啟用一個要更新的欄位" }, + "sections": { + "basic": "基本設定", + "routing": "分組與路由", + "anthropic": "Anthropic 設定" + }, "fields": { - "isEnabled": "狀態", + "isEnabled": { + "label": "狀態", + "noChange": "不修改", + "enable": "啟用", + "disable": "停用" + }, "priority": "優先級", "weight": "權重", "costMultiplier": "價格倍率", - "groupTag": "分組標籤" + "groupTag": { + "label": "分組標籤", + "clear": "清除" + }, + "modelRedirects": "模型重新導向", + "allowedModels": "允許的模型", + "thinkingBudget": "思維預算", + "adaptiveThinking": "自適應思維" + }, + "affectedProviders": { + "title": "受影響的供應商", + "more": "+{count} 更多" }, "confirm": { "title": "確認操作", @@ -34,10 +59,47 @@ "goBack": "返回", "processing": "處理中..." }, + "preview": { + "title": "預覽變更", + "description": "將變更應用到 {count} 個供應商前請先確認", + "providerHeader": "{name}", + "fieldChanged": "{field}: {before} -> {after}", + "fieldSkipped": "{field}: 已跳過 ({reason})", + "excludeProvider": "排除", + "summary": "{providerCount} 個供應商, {fieldCount} 項變更, {skipCount} 項跳過", + "noChanges": "沒有可應用的變更", + "apply": "應用變更", + "back": "返回編輯", + "loading": "正在產生預覽..." + }, + "batchNotes": { + "codexOnly": "僅 Codex", + "claudeOnly": "僅 Claude", + "geminiOnly": "僅 Gemini" + }, + "selectionHint": "選擇多個供應商以進行批次操作", + "undo": { + "button": "復原", + "success": "操作已成功復原", + "expired": "復原時限已過期", + "batchDeleteSuccess": "已刪除 {count} 個供應商", + "batchDeleteUndone": "已還原 {count} 個供應商", + "singleDeleteSuccess": "供應商已刪除", + "singleDeleteUndone": "供應商已恢復", + "singleEditSuccess": "供應商已更新", + "singleEditUndone": "變更已還原", + "failed": "復原失敗" + }, "toast": { "updated": "已更新 {count} 個供應商", "deleted": "已刪除 {count} 個供應商", "circuitReset": "已重置 {count} 個熔斷器", - "failed": "操作失敗: {error}" + "failed": "操作失敗: {error}", + "undo": "復原", + "undoSuccess": "已還原 {count} 個供應商", + "undoFailed": "復原失敗: {error}", + "undoExpired": "復原時限已過期", + "previewFailed": "預覽失敗: {error}", + "unknownError": "未知錯誤" } } diff --git a/src/actions/providers.ts b/src/actions/providers.ts index 89cb72f06..181093158 100644 --- a/src/actions/providers.ts +++ b/src/actions/providers.ts @@ -1,6 +1,7 @@ "use server"; import { eq } from "drizzle-orm"; +import { z } from "zod"; import { GeminiAuth } from "@/app/v1/_lib/gemini/auth"; import { isClientAbortError } from "@/app/v1/_lib/proxy/errors"; import { buildProxyUrl } from "@/app/v1/_lib/url"; @@ -16,6 +17,13 @@ import { } from "@/lib/circuit-breaker"; import { PROVIDER_GROUP, PROVIDER_TIMEOUT_DEFAULTS } from "@/lib/constants/provider.constants"; import { logger } from "@/lib/logger"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; +import { + buildProviderBatchApplyUpdates, + hasProviderBatchPatchChanges, + normalizeProviderBatchPatchDraft, + PROVIDER_PATCH_ERROR_CODES, +} from "@/lib/provider-patch-contract"; import { executeProviderTest, type ProviderTestConfig, @@ -32,11 +40,15 @@ import { deleteProviderCircuitConfig, saveProviderCircuitConfig, } from "@/lib/redis/circuit-breaker-config"; +import { RedisKVStore } from "@/lib/redis/redis-kv-store"; import type { Context1mPreference } from "@/lib/special-attributes"; import { maskKey } from "@/lib/utils/validation"; +import { extractZodErrorCode, formatZodError } from "@/lib/utils/zod-i18n"; import { validateProviderUrlForConnectivity } from "@/lib/validation/provider-url"; import { CreateProviderSchema, UpdateProviderSchema } from "@/lib/validation/schemas"; +import { restoreProvidersBatch } from "@/repository"; import { + type BatchProviderUpdates, createProvider, deleteProvider, findAllProviders, @@ -46,6 +58,7 @@ import { resetProviderTotalCostResetAt, updateProvider, updateProviderPrioritiesBatch, + updateProvidersBatch, } from "@/repository/provider"; import { backfillProviderEndpointsFromProviders, @@ -63,7 +76,12 @@ import type { CodexReasoningEffortPreference, CodexReasoningSummaryPreference, CodexTextVerbosityPreference, + Provider, + ProviderBatchApplyUpdates, + ProviderBatchPatch, + ProviderBatchPatchField, ProviderDisplay, + ProviderPatchOperation, ProviderStatisticsMap, ProviderType, } from "@/types/provider"; @@ -664,7 +682,7 @@ export async function editProvider( rpd?: number | null; cc?: number | null; } -): Promise { +): Promise> { try { const session = await getSession(); if (!session || session.user.role !== "admin") { @@ -710,6 +728,30 @@ export async function editProvider( ...(faviconUrl !== undefined && { favicon_url: faviconUrl }), }; + const currentProvider = await findProviderById(providerId); + if (!currentProvider) { + return { ok: false, error: "供应商不存在" }; + } + + const preimageFields: Record = {}; + for (const [field, nextValue] of Object.entries(payload)) { + if (field === "key") { + continue; + } + + const providerKey = SINGLE_EDIT_PREIMAGE_FIELD_TO_PROVIDER_KEY[field]; + if (!providerKey) { + continue; + } + + const currentValue = currentProvider[providerKey]; + if (!hasProviderFieldChangedForUndo(currentValue, nextValue)) { + continue; + } + + preimageFields[providerKey] = currentValue; + } + const provider = await updateProvider(providerId, payload); if (!provider) { @@ -743,7 +785,26 @@ export async function editProvider( // 广播缓存更新(跨实例即时生效) await broadcastProviderCacheInvalidation({ operation: "edit", providerId }); - return { ok: true }; + const undoToken = createProviderPatchUndoToken(); + const operationId = createProviderPatchOperationId(); + + await providerPatchUndoStore.set(undoToken, { + undoToken, + operationId, + providerIds: [providerId], + preimage: { + [providerId]: preimageFields, + }, + patch: EMPTY_PROVIDER_BATCH_PATCH, + }); + + return { + ok: true, + data: { + undoToken, + operationId, + }, + }; } catch (error) { logger.error("更新服务商失败:", error); const message = error instanceof Error ? error.message : "更新服务商失败"; @@ -752,7 +813,9 @@ export async function editProvider( } // 删除服务商 -export async function removeProvider(providerId: number): Promise { +export async function removeProvider( + providerId: number +): Promise> { try { const session = await getSession(); if (!session || session.user.role !== "admin") { @@ -762,6 +825,15 @@ export async function removeProvider(providerId: number): Promise const provider = await findProviderById(providerId); await deleteProvider(providerId); + const undoToken = createProviderPatchUndoToken(); + const operationId = createProviderPatchOperationId(); + + await providerDeleteUndoStore.set(undoToken, { + undoToken, + operationId, + providerIds: [providerId], + }); + // 清除内存缓存(无论 Redis 是否成功都要执行) clearConfigCache(providerId); await clearProviderState(providerId); @@ -793,7 +865,13 @@ export async function removeProvider(providerId: number): Promise // 广播缓存更新(跨实例即时生效) await broadcastProviderCacheInvalidation({ operation: "remove", providerId }); - return { ok: true }; + return { + ok: true, + data: { + undoToken, + operationId, + }, + }; } catch (error) { logger.error("删除服务商失败:", error); const message = error instanceof Error ? error.message : "删除服务商失败"; @@ -1023,6 +1101,925 @@ export async function resetProviderTotalUsage(providerId: number): Promise; +} + +interface ProviderPatchUndoSnapshot { + undoToken: string; + operationId: string; + providerIds: number[]; + preimage: Record>; + patch: ProviderBatchPatch; +} + +interface ProviderDeleteUndoSnapshot { + undoToken: string; + operationId: string; + providerIds: number[]; +} + +const providerBatchPatchPreviewStore = new RedisKVStore({ + prefix: "cch:prov:preview:", + defaultTtlSeconds: PROVIDER_BATCH_PREVIEW_TTL_SECONDS, +}); +const providerPatchUndoStore = new RedisKVStore({ + prefix: "cch:prov:undo-patch:", + defaultTtlSeconds: PROVIDER_PATCH_UNDO_TTL_SECONDS, +}); +const providerDeleteUndoStore = new RedisKVStore({ + prefix: "cch:prov:undo-del:", + defaultTtlSeconds: PROVIDER_DELETE_UNDO_TTL_SECONDS, +}); +type ProviderPatchActionError = Extract; + +const SINGLE_EDIT_PREIMAGE_FIELD_TO_PROVIDER_KEY: Record = { + name: "name", + url: "url", + is_enabled: "isEnabled", + weight: "weight", + priority: "priority", + cost_multiplier: "costMultiplier", + group_tag: "groupTag", + group_priorities: "groupPriorities", + provider_type: "providerType", + preserve_client_ip: "preserveClientIp", + model_redirects: "modelRedirects", + allowed_models: "allowedModels", + limit_5h_usd: "limit5hUsd", + limit_daily_usd: "limitDailyUsd", + daily_reset_mode: "dailyResetMode", + daily_reset_time: "dailyResetTime", + limit_weekly_usd: "limitWeeklyUsd", + limit_monthly_usd: "limitMonthlyUsd", + limit_total_usd: "limitTotalUsd", + limit_concurrent_sessions: "limitConcurrentSessions", + cache_ttl_preference: "cacheTtlPreference", + swap_cache_ttl_billing: "swapCacheTtlBilling", + context_1m_preference: "context1mPreference", + codex_reasoning_effort_preference: "codexReasoningEffortPreference", + codex_reasoning_summary_preference: "codexReasoningSummaryPreference", + codex_text_verbosity_preference: "codexTextVerbosityPreference", + codex_parallel_tool_calls_preference: "codexParallelToolCallsPreference", + anthropic_max_tokens_preference: "anthropicMaxTokensPreference", + anthropic_thinking_budget_preference: "anthropicThinkingBudgetPreference", + anthropic_adaptive_thinking: "anthropicAdaptiveThinking", + gemini_google_search_preference: "geminiGoogleSearchPreference", + max_retry_attempts: "maxRetryAttempts", + circuit_breaker_failure_threshold: "circuitBreakerFailureThreshold", + circuit_breaker_open_duration: "circuitBreakerOpenDuration", + circuit_breaker_half_open_success_threshold: "circuitBreakerHalfOpenSuccessThreshold", + proxy_url: "proxyUrl", + proxy_fallback_to_direct: "proxyFallbackToDirect", + first_byte_timeout_streaming_ms: "firstByteTimeoutStreamingMs", + streaming_idle_timeout_ms: "streamingIdleTimeoutMs", + request_timeout_non_streaming_ms: "requestTimeoutNonStreamingMs", + website_url: "websiteUrl", + favicon_url: "faviconUrl", + mcp_passthrough_type: "mcpPassthroughType", + mcp_passthrough_url: "mcpPassthroughUrl", + tpm: "tpm", + rpm: "rpm", + rpd: "rpd", + cc: "cc", +}; + +const EMPTY_PROVIDER_BATCH_PATCH: ProviderBatchPatch = (() => { + const normalized = normalizeProviderBatchPatchDraft({}); + if (!normalized.ok) { + throw new Error("Failed to initialize empty provider batch patch"); + } + return normalized.data; +})(); + +function hasProviderFieldChangedForUndo(before: unknown, after: unknown): boolean { + if (Object.is(before, after)) { + return false; + } + + if ( + before !== null && + after !== null && + typeof before === "object" && + typeof after === "object" + ) { + try { + return JSON.stringify(before) !== JSON.stringify(after); + } catch { + return true; + } + } + + return true; +} + +function dedupeProviderIds(providerIds: number[]): number[] { + return [...new Set(providerIds)].sort((a, b) => a - b); +} + +function getChangedPatchFields(patch: ProviderBatchPatch): ProviderBatchPatchField[] { + return (Object.keys(patch) as ProviderBatchPatchField[]).filter( + (field) => patch[field].mode !== "no_change" + ); +} + +function isSameProviderIdList(left: number[], right: number[]): boolean { + if (left.length !== right.length) { + return false; + } + + for (let i = 0; i < left.length; i++) { + if (left[i] !== right[i]) { + return false; + } + } + + return true; +} + +function createProviderBatchPreviewToken(): string { + return `provider_patch_preview_${crypto.randomUUID()}`; +} + +function createProviderPatchUndoToken(): string { + return `provider_patch_undo_${crypto.randomUUID()}`; +} + +function createProviderPatchOperationId(): string { + return `provider_patch_apply_${crypto.randomUUID()}`; +} + +function buildActionValidationError(error: z.ZodError): ProviderPatchActionError { + return { + ok: false, + error: formatZodError(error), + errorCode: extractZodErrorCode(error) || PROVIDER_BATCH_PATCH_ERROR_CODES.INVALID_INPUT, + }; +} + +function buildNoChangesError(): ProviderPatchActionError { + return { + ok: false, + error: "没有可应用的变更", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.NOTHING_TO_APPLY, + }; +} + +function mapApplyUpdatesToRepositoryFormat( + applyUpdates: ProviderBatchApplyUpdates +): BatchProviderUpdates { + const result: BatchProviderUpdates = {}; + if (applyUpdates.is_enabled !== undefined) { + result.isEnabled = applyUpdates.is_enabled; + } + if (applyUpdates.priority !== undefined) { + result.priority = applyUpdates.priority; + } + if (applyUpdates.weight !== undefined) { + result.weight = applyUpdates.weight; + } + if (applyUpdates.cost_multiplier !== undefined) { + result.costMultiplier = applyUpdates.cost_multiplier.toString(); + } + if (applyUpdates.group_tag !== undefined) { + result.groupTag = applyUpdates.group_tag; + } + if (applyUpdates.model_redirects !== undefined) { + result.modelRedirects = applyUpdates.model_redirects; + } + if (applyUpdates.allowed_models !== undefined) { + result.allowedModels = applyUpdates.allowed_models; + } + if (applyUpdates.anthropic_thinking_budget_preference !== undefined) { + result.anthropicThinkingBudgetPreference = applyUpdates.anthropic_thinking_budget_preference; + } + if (applyUpdates.anthropic_adaptive_thinking !== undefined) { + result.anthropicAdaptiveThinking = applyUpdates.anthropic_adaptive_thinking; + } + if (applyUpdates.preserve_client_ip !== undefined) { + result.preserveClientIp = applyUpdates.preserve_client_ip; + } + if (applyUpdates.group_priorities !== undefined) { + result.groupPriorities = applyUpdates.group_priorities; + } + if (applyUpdates.cache_ttl_preference !== undefined) { + result.cacheTtlPreference = applyUpdates.cache_ttl_preference; + } + if (applyUpdates.swap_cache_ttl_billing !== undefined) { + result.swapCacheTtlBilling = applyUpdates.swap_cache_ttl_billing; + } + if (applyUpdates.context_1m_preference !== undefined) { + result.context1mPreference = applyUpdates.context_1m_preference; + } + if (applyUpdates.codex_reasoning_effort_preference !== undefined) { + result.codexReasoningEffortPreference = applyUpdates.codex_reasoning_effort_preference; + } + if (applyUpdates.codex_reasoning_summary_preference !== undefined) { + result.codexReasoningSummaryPreference = applyUpdates.codex_reasoning_summary_preference; + } + if (applyUpdates.codex_text_verbosity_preference !== undefined) { + result.codexTextVerbosityPreference = applyUpdates.codex_text_verbosity_preference; + } + if (applyUpdates.codex_parallel_tool_calls_preference !== undefined) { + result.codexParallelToolCallsPreference = applyUpdates.codex_parallel_tool_calls_preference; + } + if (applyUpdates.anthropic_max_tokens_preference !== undefined) { + result.anthropicMaxTokensPreference = applyUpdates.anthropic_max_tokens_preference; + } + if (applyUpdates.gemini_google_search_preference !== undefined) { + result.geminiGoogleSearchPreference = applyUpdates.gemini_google_search_preference; + } + if (applyUpdates.limit_5h_usd !== undefined) { + result.limit5hUsd = + applyUpdates.limit_5h_usd != null ? applyUpdates.limit_5h_usd.toString() : null; + } + if (applyUpdates.limit_daily_usd !== undefined) { + result.limitDailyUsd = + applyUpdates.limit_daily_usd != null ? applyUpdates.limit_daily_usd.toString() : null; + } + if (applyUpdates.daily_reset_mode !== undefined) { + result.dailyResetMode = applyUpdates.daily_reset_mode; + } + if (applyUpdates.daily_reset_time !== undefined) { + result.dailyResetTime = applyUpdates.daily_reset_time; + } + if (applyUpdates.limit_weekly_usd !== undefined) { + result.limitWeeklyUsd = + applyUpdates.limit_weekly_usd != null ? applyUpdates.limit_weekly_usd.toString() : null; + } + if (applyUpdates.limit_monthly_usd !== undefined) { + result.limitMonthlyUsd = + applyUpdates.limit_monthly_usd != null ? applyUpdates.limit_monthly_usd.toString() : null; + } + if (applyUpdates.limit_total_usd !== undefined) { + result.limitTotalUsd = + applyUpdates.limit_total_usd != null ? applyUpdates.limit_total_usd.toString() : null; + } + if (applyUpdates.limit_concurrent_sessions !== undefined) { + result.limitConcurrentSessions = applyUpdates.limit_concurrent_sessions; + } + if (applyUpdates.circuit_breaker_failure_threshold !== undefined) { + result.circuitBreakerFailureThreshold = applyUpdates.circuit_breaker_failure_threshold; + } + if (applyUpdates.circuit_breaker_open_duration !== undefined) { + result.circuitBreakerOpenDuration = applyUpdates.circuit_breaker_open_duration; + } + if (applyUpdates.circuit_breaker_half_open_success_threshold !== undefined) { + result.circuitBreakerHalfOpenSuccessThreshold = + applyUpdates.circuit_breaker_half_open_success_threshold; + } + if (applyUpdates.max_retry_attempts !== undefined) { + result.maxRetryAttempts = applyUpdates.max_retry_attempts; + } + if (applyUpdates.proxy_url !== undefined) { + result.proxyUrl = applyUpdates.proxy_url; + } + if (applyUpdates.proxy_fallback_to_direct !== undefined) { + result.proxyFallbackToDirect = applyUpdates.proxy_fallback_to_direct; + } + if (applyUpdates.first_byte_timeout_streaming_ms !== undefined) { + result.firstByteTimeoutStreamingMs = applyUpdates.first_byte_timeout_streaming_ms; + } + if (applyUpdates.streaming_idle_timeout_ms !== undefined) { + result.streamingIdleTimeoutMs = applyUpdates.streaming_idle_timeout_ms; + } + if (applyUpdates.request_timeout_non_streaming_ms !== undefined) { + result.requestTimeoutNonStreamingMs = applyUpdates.request_timeout_non_streaming_ms; + } + if (applyUpdates.mcp_passthrough_type !== undefined) { + result.mcpPassthroughType = applyUpdates.mcp_passthrough_type; + } + if (applyUpdates.mcp_passthrough_url !== undefined) { + result.mcpPassthroughUrl = applyUpdates.mcp_passthrough_url; + } + return result; +} + +const PATCH_FIELD_TO_PROVIDER_KEY: Record = { + is_enabled: "isEnabled", + priority: "priority", + weight: "weight", + cost_multiplier: "costMultiplier", + group_tag: "groupTag", + model_redirects: "modelRedirects", + allowed_models: "allowedModels", + anthropic_thinking_budget_preference: "anthropicThinkingBudgetPreference", + anthropic_adaptive_thinking: "anthropicAdaptiveThinking", + preserve_client_ip: "preserveClientIp", + group_priorities: "groupPriorities", + cache_ttl_preference: "cacheTtlPreference", + swap_cache_ttl_billing: "swapCacheTtlBilling", + context_1m_preference: "context1mPreference", + codex_reasoning_effort_preference: "codexReasoningEffortPreference", + codex_reasoning_summary_preference: "codexReasoningSummaryPreference", + codex_text_verbosity_preference: "codexTextVerbosityPreference", + codex_parallel_tool_calls_preference: "codexParallelToolCallsPreference", + anthropic_max_tokens_preference: "anthropicMaxTokensPreference", + gemini_google_search_preference: "geminiGoogleSearchPreference", + limit_5h_usd: "limit5hUsd", + limit_daily_usd: "limitDailyUsd", + daily_reset_mode: "dailyResetMode", + daily_reset_time: "dailyResetTime", + limit_weekly_usd: "limitWeeklyUsd", + limit_monthly_usd: "limitMonthlyUsd", + limit_total_usd: "limitTotalUsd", + limit_concurrent_sessions: "limitConcurrentSessions", + circuit_breaker_failure_threshold: "circuitBreakerFailureThreshold", + circuit_breaker_open_duration: "circuitBreakerOpenDuration", + circuit_breaker_half_open_success_threshold: "circuitBreakerHalfOpenSuccessThreshold", + max_retry_attempts: "maxRetryAttempts", + proxy_url: "proxyUrl", + proxy_fallback_to_direct: "proxyFallbackToDirect", + first_byte_timeout_streaming_ms: "firstByteTimeoutStreamingMs", + streaming_idle_timeout_ms: "streamingIdleTimeoutMs", + request_timeout_non_streaming_ms: "requestTimeoutNonStreamingMs", + mcp_passthrough_type: "mcpPassthroughType", + mcp_passthrough_url: "mcpPassthroughUrl", +}; + +const PATCH_FIELD_CLEAR_VALUE: Partial> = { + anthropic_thinking_budget_preference: "inherit", + cache_ttl_preference: "inherit", + context_1m_preference: "inherit", + codex_reasoning_effort_preference: "inherit", + codex_reasoning_summary_preference: "inherit", + codex_text_verbosity_preference: "inherit", + codex_parallel_tool_calls_preference: "inherit", + anthropic_max_tokens_preference: "inherit", + gemini_google_search_preference: "inherit", + mcp_passthrough_type: "none", +}; + +const CLAUDE_ONLY_FIELDS: ReadonlySet = new Set([ + "anthropic_thinking_budget_preference", + "anthropic_adaptive_thinking", + "anthropic_max_tokens_preference", + "context_1m_preference", +]); + +const CODEX_ONLY_FIELDS: ReadonlySet = new Set([ + "codex_reasoning_effort_preference", + "codex_reasoning_summary_preference", + "codex_text_verbosity_preference", + "codex_parallel_tool_calls_preference", +]); + +const GEMINI_ONLY_FIELDS: ReadonlySet = new Set([ + "gemini_google_search_preference", +]); + +function isClaudeProviderType(providerType: ProviderType): boolean { + return providerType === "claude" || providerType === "claude-auth"; +} + +function isCodexProviderType(providerType: ProviderType): boolean { + return providerType === "codex"; +} + +function isGeminiProviderType(providerType: ProviderType): boolean { + return providerType === "gemini" || providerType === "gemini-cli"; +} + +const CLAUDE_ONLY_REPO_KEYS: ReadonlySet = new Set([ + "anthropicThinkingBudgetPreference", + "anthropicAdaptiveThinking", + "anthropicMaxTokensPreference", + "context1mPreference", +]); + +const CODEX_ONLY_REPO_KEYS: ReadonlySet = new Set([ + "codexReasoningEffortPreference", + "codexReasoningSummaryPreference", + "codexTextVerbosityPreference", + "codexParallelToolCallsPreference", +]); + +const GEMINI_ONLY_REPO_KEYS: ReadonlySet = new Set([ + "geminiGoogleSearchPreference", +]); + +function filterRepositoryUpdatesByProviderType( + updates: BatchProviderUpdates, + providerType: string +): BatchProviderUpdates { + const filtered = { ...updates }; + if (!isClaudeProviderType(providerType as ProviderType)) { + for (const key of CLAUDE_ONLY_REPO_KEYS) delete filtered[key]; + } + if (!isCodexProviderType(providerType as ProviderType)) { + for (const key of CODEX_ONLY_REPO_KEYS) delete filtered[key]; + } + if (!isGeminiProviderType(providerType as ProviderType)) { + for (const key of GEMINI_ONLY_REPO_KEYS) delete filtered[key]; + } + return filtered; +} + +function computePreviewAfterValue( + field: ProviderBatchPatchField, + operation: ProviderPatchOperation +): unknown { + if (operation.mode === "set") { + if ( + field === "allowed_models" && + Array.isArray(operation.value) && + operation.value.length === 0 + ) { + return null; + } + return operation.value; + } + if (operation.mode === "clear") { + return PATCH_FIELD_CLEAR_VALUE[field] ?? null; + } + return undefined; +} + +function generatePreviewRows( + providers: Provider[], + patch: ProviderBatchPatch, + changedFields: ProviderBatchPatchField[] +): ProviderBatchPreviewRow[] { + const rows: ProviderBatchPreviewRow[] = []; + + for (const provider of providers) { + for (const field of changedFields) { + const operation = patch[field] as ProviderPatchOperation; + const providerKey = PATCH_FIELD_TO_PROVIDER_KEY[field]; + const before = provider[providerKey]; + const after = computePreviewAfterValue(field, operation); + + const isClaudeOnly = CLAUDE_ONLY_FIELDS.has(field); + const isCodexOnly = CODEX_ONLY_FIELDS.has(field); + const isGeminiOnly = GEMINI_ONLY_FIELDS.has(field); + + let isCompatible = true; + let skipReason = ""; + if (isClaudeOnly && !isClaudeProviderType(provider.providerType)) { + isCompatible = false; + skipReason = `Field "${field}" is only applicable to claude/claude-auth providers`; + } else if (isCodexOnly && !isCodexProviderType(provider.providerType)) { + isCompatible = false; + skipReason = `Field "${field}" is only applicable to codex providers`; + } else if (isGeminiOnly && !isGeminiProviderType(provider.providerType)) { + isCompatible = false; + skipReason = `Field "${field}" is only applicable to gemini/gemini-cli providers`; + } + + if (isCompatible) { + rows.push({ + providerId: provider.id, + providerName: provider.name, + field, + status: "changed", + before, + after, + }); + } else { + rows.push({ + providerId: provider.id, + providerName: provider.name, + field, + status: "skipped", + before, + after, + skipReason, + }); + } + } + } + + return rows; +} + +export async function previewProviderBatchPatch( + input: unknown +): Promise> { + try { + const session = await getSession(); + if (!session || session.user.role !== "admin") { + return { ok: false, error: "无权限执行此操作" }; + } + + const parsed = PreviewProviderBatchPatchSchema.safeParse(input); + if (!parsed.success) { + return buildActionValidationError(parsed.error); + } + + const normalizedPatch = normalizeProviderBatchPatchDraft(parsed.data.patch); + if (!normalizedPatch.ok) { + return { + ok: false, + error: normalizedPatch.error.message, + errorCode: PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE, + }; + } + + if (!hasProviderBatchPatchChanges(normalizedPatch.data)) { + return buildNoChangesError(); + } + + const providerIds = dedupeProviderIds(parsed.data.providerIds); + const changedFields = getChangedPatchFields(normalizedPatch.data); + const nowMs = Date.now(); + + const allProviders = await findAllProvidersFresh(); + const providerIdSet = new Set(providerIds); + const matchedProviders = allProviders.filter((p) => providerIdSet.has(p.id)); + const rows = generatePreviewRows(matchedProviders, normalizedPatch.data, changedFields); + const skipCount = rows.filter((r) => r.status === "skipped").length; + + const previewToken = createProviderBatchPreviewToken(); + const previewRevision = `${nowMs}:${providerIds.join(",")}:${changedFields.join(",")}`; + const previewExpiresAt = nowMs + PROVIDER_BATCH_PREVIEW_TTL_SECONDS * 1000; + + await providerBatchPatchPreviewStore.set(previewToken, { + previewToken, + previewRevision, + providerIds, + patch: normalizedPatch.data, + patchSerialized: JSON.stringify(normalizedPatch.data), + changedFields, + rows, + applied: false, + appliedResultByIdempotencyKey: {}, + }); + + return { + ok: true, + data: { + previewToken, + previewRevision, + previewExpiresAt: new Date(previewExpiresAt).toISOString(), + providerIds, + changedFields, + rows, + summary: { + providerCount: providerIds.length, + fieldCount: changedFields.length, + skipCount, + }, + }, + }; + } catch (error) { + logger.error("预览批量补丁失败:", error); + const message = error instanceof Error ? error.message : "预览批量补丁失败"; + return { ok: false, error: message }; + } +} + +export async function applyProviderBatchPatch( + input: unknown +): Promise> { + try { + const session = await getSession(); + if (!session || session.user.role !== "admin") { + return { ok: false, error: "无权限执行此操作" }; + } + + const parsed = ApplyProviderBatchPatchSchema.safeParse(input); + if (!parsed.success) { + return buildActionValidationError(parsed.error); + } + + const nowMs = Date.now(); + + const snapshot = await providerBatchPatchPreviewStore.get(parsed.data.previewToken); + if (!snapshot) { + return { + ok: false, + error: "预览已过期,请重新预览", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.PREVIEW_EXPIRED, + }; + } + + const normalizedPatch = normalizeProviderBatchPatchDraft(parsed.data.patch); + if (!normalizedPatch.ok) { + return { + ok: false, + error: normalizedPatch.error.message, + errorCode: PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE, + }; + } + + if (!hasProviderBatchPatchChanges(normalizedPatch.data)) { + return buildNoChangesError(); + } + + const providerIds = dedupeProviderIds(parsed.data.providerIds); + const patchSerialized = JSON.stringify(normalizedPatch.data); + const isStale = + parsed.data.previewRevision !== snapshot.previewRevision || + !isSameProviderIdList(providerIds, snapshot.providerIds) || + patchSerialized !== snapshot.patchSerialized; + + if (parsed.data.idempotencyKey) { + const existingResult = snapshot.appliedResultByIdempotencyKey[parsed.data.idempotencyKey]; + if (existingResult) { + return { ok: true, data: existingResult }; + } + } + + if (isStale || snapshot.applied) { + return { + ok: false, + error: "预览内容已失效,请重新预览", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.PREVIEW_STALE, + }; + } + + const excludeSet = new Set(parsed.data.excludeProviderIds ?? []); + const effectiveProviderIds = providerIds.filter((id) => !excludeSet.has(id)); + if (effectiveProviderIds.length === 0) { + return { + ok: false, + error: "排除后无可应用的供应商", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.NOTHING_TO_APPLY, + }; + } + + const updatesResult = buildProviderBatchApplyUpdates(normalizedPatch.data); + if (!updatesResult.ok) { + return { + ok: false, + error: updatesResult.error.message, + errorCode: PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE, + }; + } + + const allProviders = await findAllProvidersFresh(); + const effectiveIdSet = new Set(effectiveProviderIds); + const matchedProviders = allProviders.filter((p) => effectiveIdSet.has(p.id)); + const changedFields = getChangedPatchFields(normalizedPatch.data); + const preimage: Record> = {}; + for (const provider of matchedProviders) { + const fieldValues: Record = {}; + for (const field of changedFields) { + const providerKey = PATCH_FIELD_TO_PROVIDER_KEY[field]; + fieldValues[providerKey] = provider[providerKey]; + } + preimage[provider.id] = fieldValues; + } + + const repositoryUpdates = mapApplyUpdatesToRepositoryFormat(updatesResult.data); + + const hasTypeSpecificFields = changedFields.some( + (f) => CLAUDE_ONLY_FIELDS.has(f) || CODEX_ONLY_FIELDS.has(f) || GEMINI_ONLY_FIELDS.has(f) + ); + + let dbUpdatedCount: number; + if (!hasTypeSpecificFields) { + dbUpdatedCount = await updateProvidersBatch(effectiveProviderIds, repositoryUpdates); + } else { + const providersByType = new Map(); + for (const provider of matchedProviders) { + const type = provider.providerType; + if (!providersByType.has(type)) providersByType.set(type, []); + providersByType.get(type)!.push(provider.id); + } + + dbUpdatedCount = 0; + for (const [type, ids] of providersByType) { + const filtered = filterRepositoryUpdatesByProviderType(repositoryUpdates, type); + if (Object.keys(filtered).length > 0) { + dbUpdatedCount += await updateProvidersBatch(ids, filtered); + } + } + } + + await publishProviderCacheInvalidation(); + + const appliedAt = new Date(nowMs).toISOString(); + const undoToken = createProviderPatchUndoToken(); + const undoExpiresAtMs = nowMs + PROVIDER_PATCH_UNDO_TTL_SECONDS * 1000; + + const applyResult: ApplyProviderBatchPatchResult = { + operationId: createProviderPatchOperationId(), + appliedAt, + updatedCount: dbUpdatedCount, + undoToken, + undoExpiresAt: new Date(undoExpiresAtMs).toISOString(), + }; + + snapshot.applied = true; + if (parsed.data.idempotencyKey) { + snapshot.appliedResultByIdempotencyKey[parsed.data.idempotencyKey] = applyResult; + } + await providerBatchPatchPreviewStore.set(parsed.data.previewToken, snapshot); + + await providerPatchUndoStore.set(undoToken, { + undoToken, + operationId: applyResult.operationId, + providerIds: effectiveProviderIds, + preimage, + patch: normalizedPatch.data, + }); + + return { ok: true, data: applyResult }; + } catch (error) { + logger.error("应用批量补丁失败:", error); + const message = error instanceof Error ? error.message : "应用批量补丁失败"; + return { ok: false, error: message }; + } +} + +export async function undoProviderPatch( + input: unknown +): Promise> { + try { + const session = await getSession(); + if (!session || session.user.role !== "admin") { + return { ok: false, error: "无权限执行此操作" }; + } + + const parsed = UndoProviderPatchSchema.safeParse(input); + if (!parsed.success) { + return buildActionValidationError(parsed.error); + } + + const nowMs = Date.now(); + + const snapshot = await providerPatchUndoStore.get(parsed.data.undoToken); + if (!snapshot) { + return { + ok: false, + error: "撤销窗口已过期", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED, + }; + } + + if (snapshot.operationId !== parsed.data.operationId) { + return { + ok: false, + error: "撤销参数与操作不匹配", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_CONFLICT, + }; + } + + // Delete after validation passes so operationId mismatch doesn't destroy the token + await providerPatchUndoStore.delete(parsed.data.undoToken); + + // Group providers by identical preimage values to minimise DB round-trips + const preimageGroups = new Map(); + + for (const providerId of snapshot.providerIds) { + const providerPreimage = snapshot.preimage[providerId]; + if (!providerPreimage || Object.keys(providerPreimage).length === 0) { + continue; + } + + const updatesObj: Record = {}; + for (const [key, value] of Object.entries(providerPreimage)) { + if (key === "costMultiplier" && typeof value === "number") { + updatesObj[key] = value.toString(); + } else { + updatesObj[key] = value; + } + } + const updates = updatesObj as BatchProviderUpdates; + + const groupKey = JSON.stringify(updates); + const existing = preimageGroups.get(groupKey); + if (existing) { + existing.ids.push(providerId); + } else { + preimageGroups.set(groupKey, { ids: [providerId], updates }); + } + } + + let revertedCount = 0; + for (const { ids, updates } of preimageGroups.values()) { + const count = await updateProvidersBatch(ids, updates); + revertedCount += count; + } + + if (preimageGroups.size > 0) { + await publishProviderCacheInvalidation(); + } + + return { + ok: true, + data: { + operationId: snapshot.operationId, + revertedAt: new Date(nowMs).toISOString(), + revertedCount, + }, + }; + } catch (error) { + logger.error("撤销批量补丁失败:", error); + const message = error instanceof Error ? error.message : "撤销批量补丁失败"; + return { ok: false, error: message }; + } +} export interface BatchUpdateProvidersParams { providerIds: number[]; @@ -1032,6 +2029,10 @@ export interface BatchUpdateProvidersParams { weight?: number; cost_multiplier?: number; group_tag?: string | null; + model_redirects?: Record | null; + allowed_models?: string[] | null; + anthropic_thinking_budget_preference?: AnthropicThinkingBudgetPreference | null; + anthropic_adaptive_thinking?: AnthropicAdaptiveThinkingConfig | null; }; } @@ -1069,6 +2070,22 @@ export async function batchUpdateProviders( repositoryUpdates.costMultiplier = updates.cost_multiplier.toString(); } if (updates.group_tag !== undefined) repositoryUpdates.groupTag = updates.group_tag; + if (updates.model_redirects !== undefined) { + repositoryUpdates.modelRedirects = updates.model_redirects; + } + if (updates.allowed_models !== undefined) { + repositoryUpdates.allowedModels = + Array.isArray(updates.allowed_models) && updates.allowed_models.length === 0 + ? null + : updates.allowed_models; + } + if (updates.anthropic_thinking_budget_preference !== undefined) { + repositoryUpdates.anthropicThinkingBudgetPreference = + updates.anthropic_thinking_budget_preference; + } + if (updates.anthropic_adaptive_thinking !== undefined) { + repositoryUpdates.anthropicAdaptiveThinking = updates.anthropic_adaptive_thinking; + } const updatedCount = await updateProvidersBatch(providerIds, repositoryUpdates); @@ -1097,7 +2114,7 @@ export interface BatchDeleteProvidersParams { export async function batchDeleteProviders( params: BatchDeleteProvidersParams -): Promise> { +): Promise> { try { const session = await getSession(); if (!session || session.user.role !== "admin") { @@ -1114,26 +2131,45 @@ export async function batchDeleteProviders( return { ok: false, error: `单次批量操作最多支持 ${BATCH_OPERATION_MAX_SIZE} 个供应商` }; } + const snapshotProviderIds = dedupeProviderIds(providerIds); + const { deleteProvidersBatch } = await import("@/repository/provider"); - const deletedCount = await deleteProvidersBatch(providerIds); + const deletedCount = await deleteProvidersBatch(snapshotProviderIds); - for (const id of providerIds) { + const undoToken = createProviderPatchUndoToken(); + const operationId = createProviderPatchOperationId(); + + await providerDeleteUndoStore.set(undoToken, { + undoToken, + operationId, + providerIds: snapshotProviderIds, + }); + + for (const id of snapshotProviderIds) { clearProviderState(id); clearConfigCache(id); } await broadcastProviderCacheInvalidation({ operation: "remove", - providerId: providerIds[0], + providerId: snapshotProviderIds[0], }); logger.info("batchDeleteProviders:completed", { - requestedCount: providerIds.length, + requestedCount: snapshotProviderIds.length, deletedCount, + operationId, }); - return { ok: true, data: { deletedCount } }; + return { + ok: true, + data: { + deletedCount, + undoToken, + operationId, + }, + }; } catch (error) { logger.error("批量删除供应商失败:", error); const message = error instanceof Error ? error.message : "批量删除供应商失败"; @@ -1141,6 +2177,66 @@ export async function batchDeleteProviders( } } +export async function undoProviderDelete( + input: unknown +): Promise> { + try { + const session = await getSession(); + if (!session || session.user.role !== "admin") { + return { ok: false, error: "无权限执行此操作" }; + } + + const parsed = UndoProviderDeleteSchema.safeParse(input); + if (!parsed.success) { + return buildActionValidationError(parsed.error); + } + + const nowMs = Date.now(); + + const snapshot = await providerDeleteUndoStore.get(parsed.data.undoToken); + if (!snapshot) { + return { + ok: false, + error: "撤销窗口已过期", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED, + }; + } + + if (snapshot.operationId !== parsed.data.operationId) { + return { + ok: false, + error: "撤销参数与操作不匹配", + errorCode: PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_CONFLICT, + }; + } + + // Delete after validation passes so operationId mismatch doesn't destroy the token + await providerDeleteUndoStore.delete(parsed.data.undoToken); + + const restoredCount = await restoreProvidersBatch(snapshot.providerIds); + + for (const id of snapshot.providerIds) { + clearProviderState(id); + clearConfigCache(id); + } + + await publishProviderCacheInvalidation(); + + return { + ok: true, + data: { + operationId: snapshot.operationId, + restoredAt: new Date(nowMs).toISOString(), + restoredCount, + }, + }; + } catch (error) { + logger.error("撤销批量删除失败:", error); + const message = error instanceof Error ? error.message : "撤销批量删除失败"; + return { ok: false, error: message }; + } +} + export interface BatchResetCircuitParams { providerIds: number[]; } diff --git a/src/app/[locale]/login/loading.tsx b/src/app/[locale]/login/loading.tsx index cc0c65a01..d1dc82c19 100644 --- a/src/app/[locale]/login/loading.tsx +++ b/src/app/[locale]/login/loading.tsx @@ -3,13 +3,32 @@ import { Skeleton } from "@/components/ui/skeleton"; export default function LoginLoading() { return ( -
-
- - - - - +
+ {/* Brand Panel Skeleton - Desktop Only */} +
+
+ + + +
+
+ + {/* Form Panel Skeleton */} +
+ {/* Mobile Brand Skeleton */} +
+ + + +
+ +
+ + + + + +
); diff --git a/src/app/[locale]/login/page.tsx b/src/app/[locale]/login/page.tsx index 170948455..48421f025 100644 --- a/src/app/[locale]/login/page.tsx +++ b/src/app/[locale]/login/page.tsx @@ -1,16 +1,19 @@ "use client"; -import { AlertTriangle, Book, Key, Loader2 } from "lucide-react"; +import { motion } from "framer-motion"; +import { AlertTriangle, Book, ExternalLink, Eye, EyeOff, Key, Loader2 } from "lucide-react"; import { useSearchParams } from "next/navigation"; import { useTranslations } from "next-intl"; -import { Suspense, useEffect, useState } from "react"; +import { Suspense, useEffect, useRef, useState } from "react"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { LanguageSwitcher } from "@/components/ui/language-switcher"; +import { ThemeSwitcher } from "@/components/ui/theme-switcher"; import { Link, useRouter } from "@/i18n/routing"; +import { resolveLoginRedirectTarget } from "./redirect-safety"; export default function LoginPage() { return ( @@ -20,18 +23,92 @@ export default function LoginPage() { ); } +type LoginStatus = "idle" | "submitting" | "success" | "error"; +type LoginType = "admin" | "dashboard_user" | "readonly_user"; + +interface LoginVersionInfo { + current: string; + hasUpdate: boolean; +} + +const DEFAULT_SITE_TITLE = "Claude Code Hub"; + +function parseLoginType(value: unknown): LoginType | null { + if (value === "admin" || value === "dashboard_user" || value === "readonly_user") { + return value; + } + + return null; +} + +function getLoginTypeFallbackPath(loginType: LoginType): string { + return loginType === "readonly_user" ? "/my-usage" : "/dashboard"; +} + +function formatVersionLabel(version: string): string { + const trimmed = version.trim(); + if (!trimmed) return ""; + return /^v/i.test(trimmed) ? `v${trimmed.slice(1)}` : `v${trimmed}`; +} + +const floatAnimation = { + y: [0, -20, 0], + transition: { + duration: 6, + repeat: Number.POSITIVE_INFINITY, + ease: "easeInOut" as const, + }, +}; + +const floatAnimationSlow = { + y: [0, -15, 0], + transition: { + duration: 8, + repeat: Number.POSITIVE_INFINITY, + ease: "easeInOut" as const, + }, +}; + +const brandPanelVariants = { + hidden: { opacity: 0, x: -40 }, + visible: { + opacity: 1, + x: 0, + transition: { type: "spring" as const, stiffness: 300, damping: 30 }, + }, +}; + +const stagger = { + hidden: { opacity: 0, y: 20 }, + visible: (delay: number) => ({ + opacity: 1, + y: 0, + transition: { duration: 0.4, delay, ease: "easeOut" as const }, + }), +}; + function LoginPageContent() { const t = useTranslations("auth"); + const tCustoms = useTranslations("customs"); const router = useRouter(); const searchParams = useSearchParams(); - const from = searchParams.get("from") || "/dashboard"; + const from = searchParams.get("from") || ""; + const apiKeyInputRef = useRef(null); const [apiKey, setApiKey] = useState(""); - const [loading, setLoading] = useState(false); + const [status, setStatus] = useState("idle"); const [error, setError] = useState(""); const [showHttpWarning, setShowHttpWarning] = useState(false); + const [showPassword, setShowPassword] = useState(false); + const [versionInfo, setVersionInfo] = useState(null); + const [siteTitle, setSiteTitle] = useState(DEFAULT_SITE_TITLE); + + useEffect(() => { + if (status === "error" && apiKeyInputRef.current) { + apiKeyInputRef.current.focus(); + } + }, [status]); - // 检测是否为 HTTP(非 localhost) useEffect(() => { if (typeof window !== "undefined") { const isHttp = window.location.protocol === "http:"; @@ -41,10 +118,60 @@ function LoginPageContent() { } }, []); + useEffect(() => { + let active = true; + + void fetch("/api/version") + .then((response) => response.json() as Promise<{ current?: unknown; hasUpdate?: unknown }>) + .then((data) => { + if (!active || typeof data.current !== "string") { + return; + } + + setVersionInfo({ + current: data.current, + hasUpdate: Boolean(data.hasUpdate), + }); + }) + .catch(() => {}); + + return () => { + active = false; + }; + }, []); + + useEffect(() => { + let active = true; + + void fetch("/api/system-settings") + .then((response) => { + if (!response.ok) { + return null; + } + + return response.json() as Promise<{ siteTitle?: unknown }>; + }) + .then((data) => { + if (!active || !data || typeof data.siteTitle !== "string") { + return; + } + + const nextSiteTitle = data.siteTitle.trim(); + if (nextSiteTitle) { + setSiteTitle(nextSiteTitle); + } + }) + .catch(() => {}); + + return () => { + active = false; + }; + }, []); + const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); setError(""); - setLoading(true); + setStatus("submitting"); try { const response = await fetch("/api/auth/login", { @@ -57,121 +184,248 @@ function LoginPageContent() { if (!response.ok) { setError(data.error || t("errors.loginFailed")); + setStatus("error"); return; } - // 登录成功,按服务端返回的目标跳转,回退到原页面 - const redirectTarget = data.redirectTo || from; + setStatus("success"); + const loginType = parseLoginType(data.loginType); + const fallbackPath = loginType ? getLoginTypeFallbackPath(loginType) : from; + const redirectTarget = resolveLoginRedirectTarget(data.redirectTo, fallbackPath); router.push(redirectTarget); router.refresh(); } catch { setError(t("errors.networkError")); - } finally { - setLoading(false); + setStatus("error"); } }; + const isLoading = status === "submitting" || status === "success"; + return ( -
- {/* Language Switcher - Fixed Top Right */} -
+
+ {/* Fullscreen Loading Overlay */} + {isLoading && ( +
+ +

+ {t("login.loggingIn")} +

+
+ )} + + {/* Top Right Controls */} +
+ + + {t("actions.viewUsageDoc")} + + + + +
-
-
-
+ {/* Background Orbs */} +
+ +
-
- - -
-
- -
-
- {t("form.title")} - {t("form.description")} -
+ {/* Main Layout */} +
+ {/* Brand Panel - Desktop Only */} + + {/* Brand Panel Gradient Background */} +
+ + {/* Brand Panel Animated Orb */} + + +
+
+
- - - {showHttpWarning ? ( - - - {t("security.cookieWarningTitle")} - -

{t("security.cookieWarningDescription")}

-
-

{t("security.solutionTitle")}

-
    -
  1. {t("security.useHttps")}
  2. -
  3. {t("security.disableSecureCookies")}
  4. -
+

{siteTitle}

+

{t("brand.tagline")}

+
+
+ + + {/* Form Panel */} +
+ {/* Mobile Brand Header */} +
+
+ +
+
+

{siteTitle}

+

{t("brand.tagline")}

+
+
+ +
+ + + +
+
- - - ) : null} -
-
-
- -
- - setApiKey(e.target.value)} - className="pl-9" - required - disabled={loading} - /> +
+ + {t("form.title")} + + {t("form.description")}
-
- - {error ? ( - - {error} - - ) : null} -
- -
- -

- {t("security.privacyNote")} -

-
- - - {/* 文档页入口 */} -
- - - {t("actions.viewUsageDoc")} - -
- - + + + {showHttpWarning ? ( + + + {t("security.cookieWarningTitle")} + +

{t("security.cookieWarningDescription")}

+
+

{t("security.solutionTitle")}

+
    +
  1. {t("security.useHttps")}
  2. +
  3. {t("security.disableSecureCookies")}
  4. +
+
+
+
+ ) : null} +
+ +
+ +
+ + setApiKey(e.target.value)} + className="pl-9 pr-10" + required + disabled={isLoading} + /> + +
+
+ + {error ? ( + + {error} + + ) : null} +
+ + + +

+ {t("security.privacyNote")} +

+
+
+
+ + +
+
+
+ + {/* Page Footer */} +
+

+ {siteTitle} +

+ + {versionInfo?.current ? ( +
+ {formatVersionLabel(versionInfo.current)} + {versionInfo.hasUpdate ? ( + {tCustoms("version.updateAvailable")} + ) : null} +
+ ) : null}
); diff --git a/src/app/[locale]/login/redirect-safety.ts b/src/app/[locale]/login/redirect-safety.ts new file mode 100644 index 000000000..641ea8a6a --- /dev/null +++ b/src/app/[locale]/login/redirect-safety.ts @@ -0,0 +1,37 @@ +const DEFAULT_REDIRECT_PATH = "/dashboard"; +const PROTOCOL_LIKE_PATTERN = /^[a-zA-Z][a-zA-Z\d+.-]*:/; + +export function sanitizeRedirectPath(from: string): string { + const candidate = from.trim(); + + if (!candidate) { + return DEFAULT_REDIRECT_PATH; + } + + if (!candidate.startsWith("/")) { + return DEFAULT_REDIRECT_PATH; + } + + if (candidate.startsWith("//")) { + return DEFAULT_REDIRECT_PATH; + } + + if (PROTOCOL_LIKE_PATTERN.test(candidate)) { + return DEFAULT_REDIRECT_PATH; + } + + const withoutLeadingSlash = candidate.slice(1); + if (PROTOCOL_LIKE_PATTERN.test(withoutLeadingSlash)) { + return DEFAULT_REDIRECT_PATH; + } + + return candidate; +} + +export function resolveLoginRedirectTarget(redirectTo: unknown, from: string): string { + if (typeof redirectTo === "string" && redirectTo.trim().length > 0) { + return sanitizeRedirectPath(redirectTo); + } + + return sanitizeRedirectPath(from); +} diff --git a/src/app/[locale]/settings/providers/_components/adaptive-thinking-editor.tsx b/src/app/[locale]/settings/providers/_components/adaptive-thinking-editor.tsx new file mode 100644 index 000000000..f7537553c --- /dev/null +++ b/src/app/[locale]/settings/providers/_components/adaptive-thinking-editor.tsx @@ -0,0 +1,178 @@ +"use client"; + +import { Info } from "lucide-react"; +import { useTranslations } from "next-intl"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Switch } from "@/components/ui/switch"; +import { TagInput } from "@/components/ui/tag-input"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import type { + AnthropicAdaptiveThinkingConfig, + AnthropicAdaptiveThinkingEffort, + AnthropicAdaptiveThinkingModelMatchMode, +} from "@/types/provider"; +import { SmartInputWrapper, ToggleRow } from "./forms/provider-form/components/section-card"; + +interface AdaptiveThinkingEditorProps { + enabled: boolean; + config: AnthropicAdaptiveThinkingConfig; + onEnabledChange: (enabled: boolean) => void; + onConfigChange: (config: AnthropicAdaptiveThinkingConfig) => void; + disabled?: boolean; +} + +export function AdaptiveThinkingEditor({ + enabled, + config, + onEnabledChange, + onConfigChange, + disabled = false, +}: AdaptiveThinkingEditorProps) { + const t = useTranslations("settings.providers.form"); + + const handleEffortChange = (effort: AnthropicAdaptiveThinkingEffort) => { + onConfigChange({ + ...config, + effort, + }); + }; + + const handleModeChange = (modelMatchMode: AnthropicAdaptiveThinkingModelMatchMode) => { + onConfigChange({ + ...config, + modelMatchMode, + }); + }; + + const handleModelsChange = (models: string[]) => { + onConfigChange({ + ...config, + models, + }); + }; + + return ( +
+ + + + + {enabled && ( +
+ + + +
+ + +
+
+ +

+ {t("sections.routing.anthropicOverrides.adaptiveThinking.effort.help")} +

+
+
+
+ + + + +
+ + +
+
+ +

+ {t("sections.routing.anthropicOverrides.adaptiveThinking.modelMatchMode.help")} +

+
+
+
+ + {config.modelMatchMode === "specific" && ( + + + +
+ + +
+
+ +

+ {t("sections.routing.anthropicOverrides.adaptiveThinking.models.help")} +

+
+
+
+ )} +
+ )} +
+ ); +} diff --git a/src/app/[locale]/settings/providers/_components/add-provider-dialog.tsx b/src/app/[locale]/settings/providers/_components/add-provider-dialog.tsx index e8d944292..917d30a05 100644 --- a/src/app/[locale]/settings/providers/_components/add-provider-dialog.tsx +++ b/src/app/[locale]/settings/providers/_components/add-provider-dialog.tsx @@ -1,11 +1,11 @@ "use client"; +import { VisuallyHidden } from "@radix-ui/react-visually-hidden"; import { ServerCog } from "lucide-react"; import { useTranslations } from "next-intl"; import { useState } from "react"; import { FormErrorBoundary } from "@/components/form-error-boundary"; import { Button } from "@/components/ui/button"; import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "@/components/ui/dialog"; -import { VisuallyHidden } from "@radix-ui/react-visually-hidden"; import { ProviderForm } from "./forms/provider-form"; interface AddProviderDialogProps { diff --git a/src/app/[locale]/settings/providers/_components/batch-edit/build-patch-draft.ts b/src/app/[locale]/settings/providers/_components/batch-edit/build-patch-draft.ts new file mode 100644 index 000000000..d9b2e2b7e --- /dev/null +++ b/src/app/[locale]/settings/providers/_components/batch-edit/build-patch-draft.ts @@ -0,0 +1,290 @@ +import type { ProviderBatchPatchDraft } from "@/types/provider"; +import type { ProviderFormState } from "../forms/provider-form/provider-form-types"; + +/** + * Builds a ProviderBatchPatchDraft from the current form state, + * including only fields that the user has actually modified (dirty fields). + * + * Unit conversions: + * - circuitBreaker.openDurationMinutes (minutes) -> circuit_breaker_open_duration (ms) + * - network.*Seconds (seconds) -> *_ms (ms) + */ +export function buildPatchDraftFromFormState( + state: ProviderFormState, + dirtyFields: Set +): ProviderBatchPatchDraft { + const draft: ProviderBatchPatchDraft = {}; + + // Batch-specific: isEnabled + if (dirtyFields.has("batch.isEnabled")) { + if (state.batch.isEnabled !== "no_change") { + draft.is_enabled = { set: state.batch.isEnabled === "true" }; + } + } + + // Routing fields + if (dirtyFields.has("routing.priority")) { + draft.priority = { set: state.routing.priority }; + } + if (dirtyFields.has("routing.weight")) { + draft.weight = { set: state.routing.weight }; + } + if (dirtyFields.has("routing.costMultiplier")) { + draft.cost_multiplier = { set: state.routing.costMultiplier }; + } + if (dirtyFields.has("routing.groupTag")) { + const joined = state.routing.groupTag.join(", "); + if (joined === "") { + draft.group_tag = { clear: true }; + } else { + draft.group_tag = { set: joined }; + } + } + if (dirtyFields.has("routing.preserveClientIp")) { + draft.preserve_client_ip = { set: state.routing.preserveClientIp }; + } + if (dirtyFields.has("routing.modelRedirects")) { + const entries = Object.keys(state.routing.modelRedirects); + if (entries.length === 0) { + draft.model_redirects = { clear: true }; + } else { + draft.model_redirects = { set: state.routing.modelRedirects }; + } + } + if (dirtyFields.has("routing.allowedModels")) { + if (state.routing.allowedModels.length === 0) { + draft.allowed_models = { clear: true }; + } else { + draft.allowed_models = { set: state.routing.allowedModels }; + } + } + if (dirtyFields.has("routing.groupPriorities")) { + const entries = Object.keys(state.routing.groupPriorities); + if (entries.length === 0) { + draft.group_priorities = { clear: true }; + } else { + draft.group_priorities = { set: state.routing.groupPriorities }; + } + } + if (dirtyFields.has("routing.cacheTtlPreference")) { + if (state.routing.cacheTtlPreference === "inherit") { + draft.cache_ttl_preference = { clear: true }; + } else { + draft.cache_ttl_preference = { set: state.routing.cacheTtlPreference }; + } + } + if (dirtyFields.has("routing.swapCacheTtlBilling")) { + draft.swap_cache_ttl_billing = { set: state.routing.swapCacheTtlBilling }; + } + if (dirtyFields.has("routing.context1mPreference")) { + if (state.routing.context1mPreference === "inherit") { + draft.context_1m_preference = { clear: true }; + } else { + draft.context_1m_preference = { set: state.routing.context1mPreference }; + } + } + + // Codex preferences + if (dirtyFields.has("routing.codexReasoningEffortPreference")) { + if (state.routing.codexReasoningEffortPreference === "inherit") { + draft.codex_reasoning_effort_preference = { clear: true }; + } else { + draft.codex_reasoning_effort_preference = { + set: state.routing.codexReasoningEffortPreference, + }; + } + } + if (dirtyFields.has("routing.codexReasoningSummaryPreference")) { + if (state.routing.codexReasoningSummaryPreference === "inherit") { + draft.codex_reasoning_summary_preference = { clear: true }; + } else { + draft.codex_reasoning_summary_preference = { + set: state.routing.codexReasoningSummaryPreference, + }; + } + } + if (dirtyFields.has("routing.codexTextVerbosityPreference")) { + if (state.routing.codexTextVerbosityPreference === "inherit") { + draft.codex_text_verbosity_preference = { clear: true }; + } else { + draft.codex_text_verbosity_preference = { set: state.routing.codexTextVerbosityPreference }; + } + } + if (dirtyFields.has("routing.codexParallelToolCallsPreference")) { + if (state.routing.codexParallelToolCallsPreference === "inherit") { + draft.codex_parallel_tool_calls_preference = { clear: true }; + } else { + draft.codex_parallel_tool_calls_preference = { + set: state.routing.codexParallelToolCallsPreference, + }; + } + } + + // Anthropic preferences + if (dirtyFields.has("routing.anthropicMaxTokensPreference")) { + if (state.routing.anthropicMaxTokensPreference === "inherit") { + draft.anthropic_max_tokens_preference = { clear: true }; + } else { + draft.anthropic_max_tokens_preference = { set: state.routing.anthropicMaxTokensPreference }; + } + } + if (dirtyFields.has("routing.anthropicThinkingBudgetPreference")) { + if (state.routing.anthropicThinkingBudgetPreference === "inherit") { + draft.anthropic_thinking_budget_preference = { clear: true }; + } else { + draft.anthropic_thinking_budget_preference = { + set: state.routing.anthropicThinkingBudgetPreference, + }; + } + } + if (dirtyFields.has("routing.anthropicAdaptiveThinking")) { + if (state.routing.anthropicAdaptiveThinking === null) { + draft.anthropic_adaptive_thinking = { clear: true }; + } else { + draft.anthropic_adaptive_thinking = { set: state.routing.anthropicAdaptiveThinking }; + } + } + + // Gemini preferences + if (dirtyFields.has("routing.geminiGoogleSearchPreference")) { + if (state.routing.geminiGoogleSearchPreference === "inherit") { + draft.gemini_google_search_preference = { clear: true }; + } else { + draft.gemini_google_search_preference = { set: state.routing.geminiGoogleSearchPreference }; + } + } + + // Rate limit fields + if (dirtyFields.has("rateLimit.limit5hUsd")) { + if (state.rateLimit.limit5hUsd === null) { + draft.limit_5h_usd = { clear: true }; + } else { + draft.limit_5h_usd = { set: state.rateLimit.limit5hUsd }; + } + } + if (dirtyFields.has("rateLimit.limitDailyUsd")) { + if (state.rateLimit.limitDailyUsd === null) { + draft.limit_daily_usd = { clear: true }; + } else { + draft.limit_daily_usd = { set: state.rateLimit.limitDailyUsd }; + } + } + if (dirtyFields.has("rateLimit.dailyResetMode")) { + draft.daily_reset_mode = { set: state.rateLimit.dailyResetMode }; + } + if (dirtyFields.has("rateLimit.dailyResetTime")) { + draft.daily_reset_time = { set: state.rateLimit.dailyResetTime }; + } + if (dirtyFields.has("rateLimit.limitWeeklyUsd")) { + if (state.rateLimit.limitWeeklyUsd === null) { + draft.limit_weekly_usd = { clear: true }; + } else { + draft.limit_weekly_usd = { set: state.rateLimit.limitWeeklyUsd }; + } + } + if (dirtyFields.has("rateLimit.limitMonthlyUsd")) { + if (state.rateLimit.limitMonthlyUsd === null) { + draft.limit_monthly_usd = { clear: true }; + } else { + draft.limit_monthly_usd = { set: state.rateLimit.limitMonthlyUsd }; + } + } + if (dirtyFields.has("rateLimit.limitTotalUsd")) { + if (state.rateLimit.limitTotalUsd === null) { + draft.limit_total_usd = { clear: true }; + } else { + draft.limit_total_usd = { set: state.rateLimit.limitTotalUsd }; + } + } + if (dirtyFields.has("rateLimit.limitConcurrentSessions")) { + if (state.rateLimit.limitConcurrentSessions === null) { + draft.limit_concurrent_sessions = { set: 0 }; + } else { + draft.limit_concurrent_sessions = { set: state.rateLimit.limitConcurrentSessions }; + } + } + + // Circuit breaker fields (minutes -> ms conversion for open duration) + if (dirtyFields.has("circuitBreaker.failureThreshold")) { + if (state.circuitBreaker.failureThreshold === undefined) { + draft.circuit_breaker_failure_threshold = { set: 0 }; + } else { + draft.circuit_breaker_failure_threshold = { set: state.circuitBreaker.failureThreshold }; + } + } + if (dirtyFields.has("circuitBreaker.openDurationMinutes")) { + if (state.circuitBreaker.openDurationMinutes === undefined) { + draft.circuit_breaker_open_duration = { set: 0 }; + } else { + // Convert minutes to milliseconds + draft.circuit_breaker_open_duration = { + set: state.circuitBreaker.openDurationMinutes * 60000, + }; + } + } + if (dirtyFields.has("circuitBreaker.halfOpenSuccessThreshold")) { + if (state.circuitBreaker.halfOpenSuccessThreshold === undefined) { + draft.circuit_breaker_half_open_success_threshold = { set: 0 }; + } else { + draft.circuit_breaker_half_open_success_threshold = { + set: state.circuitBreaker.halfOpenSuccessThreshold, + }; + } + } + if (dirtyFields.has("circuitBreaker.maxRetryAttempts")) { + if (state.circuitBreaker.maxRetryAttempts === null) { + draft.max_retry_attempts = { clear: true }; + } else { + draft.max_retry_attempts = { set: state.circuitBreaker.maxRetryAttempts }; + } + } + + // Network fields (seconds -> ms conversion) + if (dirtyFields.has("network.proxyUrl")) { + if (state.network.proxyUrl === "") { + draft.proxy_url = { clear: true }; + } else { + draft.proxy_url = { set: state.network.proxyUrl }; + } + } + if (dirtyFields.has("network.proxyFallbackToDirect")) { + draft.proxy_fallback_to_direct = { set: state.network.proxyFallbackToDirect }; + } + if (dirtyFields.has("network.firstByteTimeoutStreamingSeconds")) { + if (state.network.firstByteTimeoutStreamingSeconds !== undefined) { + draft.first_byte_timeout_streaming_ms = { + set: state.network.firstByteTimeoutStreamingSeconds * 1000, + }; + } + } + if (dirtyFields.has("network.streamingIdleTimeoutSeconds")) { + if (state.network.streamingIdleTimeoutSeconds !== undefined) { + draft.streaming_idle_timeout_ms = { set: state.network.streamingIdleTimeoutSeconds * 1000 }; + } + } + if (dirtyFields.has("network.requestTimeoutNonStreamingSeconds")) { + if (state.network.requestTimeoutNonStreamingSeconds !== undefined) { + draft.request_timeout_non_streaming_ms = { + set: state.network.requestTimeoutNonStreamingSeconds * 1000, + }; + } + } + + // MCP fields + if (dirtyFields.has("mcp.mcpPassthroughType")) { + if (state.mcp.mcpPassthroughType === "none") { + draft.mcp_passthrough_type = { set: "none" }; + } else { + draft.mcp_passthrough_type = { set: state.mcp.mcpPassthroughType }; + } + } + if (dirtyFields.has("mcp.mcpPassthroughUrl")) { + if (state.mcp.mcpPassthroughUrl === "") { + draft.mcp_passthrough_url = { clear: true }; + } else { + draft.mcp_passthrough_url = { set: state.mcp.mcpPassthroughUrl }; + } + } + + return draft; +} diff --git a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-dialog.tsx b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-dialog.tsx index 7dc7d2d5e..f55ac2ae7 100644 --- a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-dialog.tsx +++ b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-dialog.tsx @@ -6,10 +6,13 @@ import { useTranslations } from "next-intl"; import { useCallback, useMemo, useState } from "react"; import { toast } from "sonner"; import { - type BatchUpdateProvidersParams, + applyProviderBatchPatch, batchDeleteProviders, batchResetProviderCircuits, - batchUpdateProviders, + type PreviewProviderBatchPatchResult, + previewProviderBatchPatch, + undoProviderDelete, + undoProviderPatch, } from "@/actions/providers"; import { AlertDialog, @@ -30,184 +33,345 @@ import { DialogHeader, DialogTitle, } from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { Separator } from "@/components/ui/separator"; -import { Switch } from "@/components/ui/switch"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; +import type { ProviderDisplay } from "@/types/provider"; +import { FormTabNav } from "../forms/provider-form/components/form-tab-nav"; +import { + ProviderFormProvider, + useProviderForm, +} from "../forms/provider-form/provider-form-context"; +import { BasicInfoSection } from "../forms/provider-form/sections/basic-info-section"; +import { LimitsSection } from "../forms/provider-form/sections/limits-section"; +import { NetworkSection } from "../forms/provider-form/sections/network-section"; +import { RoutingSection } from "../forms/provider-form/sections/routing-section"; +import { TestingSection } from "../forms/provider-form/sections/testing-section"; +import { buildPatchDraftFromFormState } from "./build-patch-draft"; import type { BatchActionMode } from "./provider-batch-actions"; +import { ProviderBatchPreviewStep } from "./provider-batch-preview-step"; + +// --------------------------------------------------------------------------- +// Props +// --------------------------------------------------------------------------- export interface ProviderBatchDialogProps { open: boolean; mode: BatchActionMode; onOpenChange: (open: boolean) => void; selectedProviderIds: Set; + providers: ProviderDisplay[]; onSuccess?: () => void; } -interface EditFieldState { - isEnabledEnabled: boolean; - isEnabled: boolean; - priorityEnabled: boolean; - priority: string; - weightEnabled: boolean; - weight: string; - costMultiplierEnabled: boolean; - costMultiplier: string; - groupTagEnabled: boolean; - groupTag: string; -} - -const INITIAL_EDIT_STATE: EditFieldState = { - isEnabledEnabled: false, - isEnabled: true, - priorityEnabled: false, - priority: "", - weightEnabled: false, - weight: "", - costMultiplierEnabled: false, - costMultiplier: "", - groupTagEnabled: false, - groupTag: "", -}; +// --------------------------------------------------------------------------- +// Component +// --------------------------------------------------------------------------- export function ProviderBatchDialog({ open, mode, onOpenChange, selectedProviderIds, + providers, onSuccess, }: ProviderBatchDialogProps) { - const t = useTranslations("settings.providers.batchEdit"); - const queryClient = useQueryClient(); + // For edit mode: delegate to form-based dialog + if (mode === "edit") { + return ( + + ); + } - const [editState, setEditState] = useState(INITIAL_EDIT_STATE); - const [confirmOpen, setConfirmOpen] = useState(false); - const [isSubmitting, setIsSubmitting] = useState(false); + // For delete/resetCircuit: use AlertDialog + return ( + + ); +} + +// --------------------------------------------------------------------------- +// BatchEditDialog: Uses ProviderFormProvider mode="batch" +// --------------------------------------------------------------------------- +function BatchEditDialog({ + open, + onOpenChange, + selectedProviderIds, + providers, + onSuccess, +}: Omit) { const selectedCount = selectedProviderIds.size; - const hasEnabledFields = useMemo(() => { - if (mode !== "edit") return true; - return ( - editState.isEnabledEnabled || - editState.priorityEnabled || - editState.weightEnabled || - editState.costMultiplierEnabled || - editState.groupTagEnabled - ); - }, [mode, editState]); + const affectedProviders = useMemo(() => { + return providers.filter((p) => selectedProviderIds.has(p.id)); + }, [providers, selectedProviderIds]); - const resetState = useCallback(() => { - setEditState(INITIAL_EDIT_STATE); - setConfirmOpen(false); - setIsSubmitting(false); - }, []); + return ( + + + + + + + + ); +} + +// Inner component that can use useProviderForm() +type DialogStep = "edit" | "preview"; + +function BatchEditDialogContent({ + selectedProviderIds, + selectedCount, + onOpenChange, + onSuccess, +}: { + selectedProviderIds: Set; + selectedCount: number; + onOpenChange: (open: boolean) => void; + onSuccess?: () => void; +}) { + const t = useTranslations("settings.providers.batchEdit"); + const queryClient = useQueryClient(); + const { state, dispatch, dirtyFields } = useProviderForm(); - const handleOpenChange = useCallback( - (newOpen: boolean) => { - if (!newOpen) { - resetState(); + const [step, setStep] = useState("edit"); + const [isSubmitting, setIsSubmitting] = useState(false); + const [isLoadingPreview, setIsLoadingPreview] = useState(false); + const [previewResult, setPreviewResult] = useState(null); + const [excludedProviderIds, setExcludedProviderIds] = useState>(new Set()); + + const hasChanges = dirtyFields.size > 0; + + const handleExcludeToggle = useCallback((providerId: number) => { + setExcludedProviderIds((prev) => { + const next = new Set(prev); + if (next.has(providerId)) { + next.delete(providerId); + } else { + next.add(providerId); } - onOpenChange(newOpen); - }, - [onOpenChange, resetState] - ); + return next; + }); + }, []); - const handleNext = useCallback(() => { - if (!hasEnabledFields) { - toast.error(t("dialog.noFieldEnabled")); - return; - } - setConfirmOpen(true); - }, [hasEnabledFields, t]); + const handleNext = useCallback(async () => { + if (!hasChanges) return; - const handleConfirm = useCallback(async () => { - if (isSubmitting) return; - setIsSubmitting(true); + setIsLoadingPreview(true); + setStep("preview"); try { const providerIds = Array.from(selectedProviderIds); + const patch = buildPatchDraftFromFormState(state, dirtyFields); + const result = await previewProviderBatchPatch({ providerIds, patch }); + + if (result.ok) { + setPreviewResult(result.data); + } else { + toast.error(t("toast.previewFailed", { error: result.error })); + setStep("edit"); + } + } catch (error) { + const message = error instanceof Error ? error.message : t("toast.unknownError"); + toast.error(t("toast.previewFailed", { error: message })); + setStep("edit"); + } finally { + setIsLoadingPreview(false); + } + }, [hasChanges, selectedProviderIds, state, dirtyFields, t]); - if (mode === "edit") { - const updates: BatchUpdateProvidersParams["updates"] = {}; + const handleBackToEdit = useCallback(() => { + setStep("edit"); + setPreviewResult(null); + setExcludedProviderIds(new Set()); + }, []); - if (editState.isEnabledEnabled) { - updates.is_enabled = editState.isEnabled; - } - if (editState.priorityEnabled && editState.priority.trim()) { - const val = Number.parseInt(editState.priority, 10); - if (!Number.isNaN(val) && val >= 0) { - updates.priority = val; - } - } - if (editState.weightEnabled && editState.weight.trim()) { - const val = Number.parseInt(editState.weight, 10); - if (!Number.isNaN(val) && val >= 0) { - updates.weight = val; - } - } - if (editState.costMultiplierEnabled && editState.costMultiplier.trim()) { - const val = Number.parseFloat(editState.costMultiplier); - if (!Number.isNaN(val) && val >= 0) { - updates.cost_multiplier = val; - } - } - if (editState.groupTagEnabled) { - updates.group_tag = editState.groupTag.trim() || null; - } + const handleApply = useCallback(async () => { + if (isSubmitting || !previewResult) return; + setIsSubmitting(true); - const result = await batchUpdateProviders({ providerIds, updates }); - if (result.ok) { - toast.success(t("toast.updated", { count: result.data?.updatedCount ?? 0 })); - } else { - toast.error(t("toast.failed", { error: result.error })); - setIsSubmitting(false); - return; - } - } else if (mode === "delete") { - const result = await batchDeleteProviders({ providerIds }); - if (result.ok) { - toast.success(t("toast.deleted", { count: result.data?.deletedCount ?? 0 })); - } else { - toast.error(t("toast.failed", { error: result.error })); - setIsSubmitting(false); - return; - } - } else if (mode === "resetCircuit") { - const result = await batchResetProviderCircuits({ providerIds }); - if (result.ok) { - toast.success(t("toast.circuitReset", { count: result.data?.resetCount ?? 0 })); - } else { - toast.error(t("toast.failed", { error: result.error })); - setIsSubmitting(false); - return; - } + try { + const providerIds = Array.from(selectedProviderIds); + const patch = buildPatchDraftFromFormState(state, dirtyFields); + const result = await applyProviderBatchPatch({ + previewToken: previewResult.previewToken, + previewRevision: previewResult.previewRevision, + providerIds, + patch, + excludeProviderIds: Array.from(excludedProviderIds), + }); + + if (result.ok) { + await queryClient.invalidateQueries({ queryKey: ["providers"] }); + onOpenChange(false); + onSuccess?.(); + + const undoToken = result.data.undoToken; + const operationId = result.data.operationId; + toast.success(t("toast.updated", { count: result.data.updatedCount }), { + duration: 10000, + action: { + label: t("toast.undo"), + onClick: async () => { + try { + const undoResult = await undoProviderPatch({ undoToken, operationId }); + if (undoResult.ok) { + toast.success(t("toast.undoSuccess", { count: undoResult.data.revertedCount })); + queryClient.invalidateQueries({ queryKey: ["providers"] }); + } else { + toast.error(t("toast.undoFailed", { error: undoResult.error })); + } + } catch (err) { + const msg = err instanceof Error ? err.message : t("toast.unknownError"); + toast.error(t("toast.undoFailed", { error: msg })); + } + }, + }, + }); + } else { + toast.error(t("toast.failed", { error: result.error })); } - - await queryClient.invalidateQueries({ queryKey: ["providers"] }); - handleOpenChange(false); - onSuccess?.(); } catch (error) { - const message = error instanceof Error ? error.message : "Unknown error"; + const message = error instanceof Error ? error.message : t("toast.unknownError"); toast.error(t("toast.failed", { error: message })); } finally { setIsSubmitting(false); } }, [ isSubmitting, + previewResult, selectedProviderIds, - mode, - editState, + state, + dirtyFields, + excludedProviderIds, queryClient, - handleOpenChange, + onOpenChange, onSuccess, t, ]); + return ( + <> + + {step === "preview" ? t("preview.title") : t("dialog.editTitle")} + + {step === "preview" + ? t("preview.description", { count: selectedCount }) + : t("dialog.editDesc", { count: selectedCount })} + + + + {step === "edit" && ( +
+ dispatch({ type: "SET_ACTIVE_TAB", payload: tab })} + layout="horizontal" + /> +
+ {state.ui.activeTab === "basic" && } + {state.ui.activeTab === "routing" && } + {state.ui.activeTab === "limits" && } + {state.ui.activeTab === "network" && } + {state.ui.activeTab === "testing" && } +
+
+ )} + + {step === "preview" && ( +
+ +
+ )} + + + {step === "preview" ? ( + <> + + + + ) : ( + <> + + + + )} + + + ); +} + +// --------------------------------------------------------------------------- +// BatchConfirmDialog: Delete / Reset Circuit (unchanged) +// --------------------------------------------------------------------------- + +function BatchConfirmDialog({ + open, + mode, + onOpenChange, + selectedProviderIds, + providers: _providers, + onSuccess, +}: ProviderBatchDialogProps) { + const t = useTranslations("settings.providers.batchEdit"); + const queryClient = useQueryClient(); + const [isSubmitting, setIsSubmitting] = useState(false); + + const selectedCount = selectedProviderIds.size; + const dialogTitle = useMemo(() => { switch (mode) { - case "edit": - return t("dialog.editTitle"); case "delete": return t("dialog.deleteTitle"); case "resetCircuit": @@ -219,8 +383,6 @@ export function ProviderBatchDialog({ const dialogDescription = useMemo(() => { switch (mode) { - case "edit": - return t("dialog.editDesc", { count: selectedCount }); case "delete": return t("dialog.deleteDesc", { count: selectedCount }); case "resetCircuit": @@ -230,151 +392,93 @@ export function ProviderBatchDialog({ } }, [mode, selectedCount, t]); - return ( - <> - - - - {dialogTitle} - {dialogDescription} - - - {mode === "edit" && ( -
- setEditState((s) => ({ ...s, isEnabledEnabled: v }))} - > - setEditState((s) => ({ ...s, isEnabled: v }))} - /> - - - - - setEditState((s) => ({ ...s, priorityEnabled: v }))} - > - setEditState((s) => ({ ...s, priority: e.target.value }))} - placeholder="0" - className="w-24" - /> - - - setEditState((s) => ({ ...s, weightEnabled: v }))} - > - setEditState((s) => ({ ...s, weight: e.target.value }))} - placeholder="1" - className="w-24" - /> - - - setEditState((s) => ({ ...s, costMultiplierEnabled: v }))} - > - setEditState((s) => ({ ...s, costMultiplier: e.target.value }))} - placeholder="1.0" - className="w-24" - /> - - - - - setEditState((s) => ({ ...s, groupTagEnabled: v }))} - > - setEditState((s) => ({ ...s, groupTag: e.target.value }))} - placeholder="tag1, tag2" - className="w-40" - /> - -
- )} - - {(mode === "delete" || mode === "resetCircuit") && ( -
{dialogDescription}
- )} - - - - - -
-
- - - - - {t("confirm.title")} - {dialogDescription} - - - {t("confirm.goBack")} - - {isSubmitting ? ( - <> - - {t("confirm.processing")} - - ) : ( - t("confirm.confirm") - )} - - - - - - ); -} + const handleConfirm = useCallback(async () => { + if (isSubmitting) return; + setIsSubmitting(true); -interface FieldToggleProps { - label: string; - enabled: boolean; - onEnabledChange: (enabled: boolean) => void; - children: React.ReactNode; -} + try { + const providerIds = Array.from(selectedProviderIds); + + if (mode === "delete") { + const result = await batchDeleteProviders({ providerIds }); + if (result.ok) { + const deletedCount = result.data.deletedCount; + const undoToken = result.data.undoToken; + const operationId = result.data.operationId; + + toast.success(t("undo.batchDeleteSuccess", { count: deletedCount }), { + duration: 10000, + action: { + label: t("undo.button"), + onClick: async () => { + try { + const undoResult = await undoProviderDelete({ undoToken, operationId }); + if (undoResult.ok) { + toast.success( + t("undo.batchDeleteUndone", { count: undoResult.data.restoredCount }) + ); + await queryClient.invalidateQueries({ queryKey: ["providers"] }); + } else if ( + undoResult.errorCode === PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED + ) { + toast.error(t("undo.expired")); + } else { + toast.error(t("undo.failed")); + } + } catch { + toast.error(t("undo.failed")); + } + }, + }, + }); + } else { + toast.error(t("toast.failed", { error: result.error })); + setIsSubmitting(false); + return; + } + } else if (mode === "resetCircuit") { + const result = await batchResetProviderCircuits({ providerIds }); + if (result.ok) { + toast.success(t("toast.circuitReset", { count: result.data?.resetCount ?? 0 })); + } else { + toast.error(t("toast.failed", { error: result.error })); + setIsSubmitting(false); + return; + } + } + + await queryClient.invalidateQueries({ queryKey: ["providers"] }); + onOpenChange(false); + onSuccess?.(); + } catch (error) { + const message = error instanceof Error ? error.message : t("toast.unknownError"); + toast.error(t("toast.failed", { error: message })); + } finally { + setIsSubmitting(false); + } + }, [isSubmitting, selectedProviderIds, mode, queryClient, onOpenChange, onSuccess, t]); -function FieldToggle({ label, enabled, onEnabledChange, children }: FieldToggleProps) { return ( -
-
- - -
-
{children}
-
+ + + + {dialogTitle} + {dialogDescription} + + + {t("confirm.goBack")} + + {isSubmitting ? ( + <> + + {t("confirm.processing")} + + ) : ( + t("confirm.confirm") + )} + + + + ); } diff --git a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-preview-step.tsx b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-preview-step.tsx new file mode 100644 index 000000000..a1f88394e --- /dev/null +++ b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-preview-step.tsx @@ -0,0 +1,179 @@ +"use client"; + +import { Loader2 } from "lucide-react"; +import { useTranslations } from "next-intl"; +import { useCallback, useMemo } from "react"; + +// --------------------------------------------------------------------------- +// Field label lookup (uses existing translations with readable fallback) +// --------------------------------------------------------------------------- + +const FIELD_LABEL_KEYS: Record = { + is_enabled: "fields.isEnabled.label", + priority: "fields.priority", + weight: "fields.weight", + cost_multiplier: "fields.costMultiplier", + group_tag: "fields.groupTag.label", + model_redirects: "fields.modelRedirects", + allowed_models: "fields.allowedModels", + anthropic_thinking_budget_preference: "fields.thinkingBudget", + anthropic_adaptive_thinking: "fields.adaptiveThinking", +}; + +import type { ProviderBatchPreviewRow } from "@/actions/providers"; +import { Checkbox } from "@/components/ui/checkbox"; + +// --------------------------------------------------------------------------- +// Props +// --------------------------------------------------------------------------- + +export interface ProviderBatchPreviewStepProps { + rows: ProviderBatchPreviewRow[]; + summary: { providerCount: number; fieldCount: number; skipCount: number }; + excludedProviderIds: Set; + onExcludeToggle: (providerId: number) => void; + isLoading?: boolean; +} + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +interface ProviderGroup { + providerId: number; + providerName: string; + rows: ProviderBatchPreviewRow[]; +} + +// --------------------------------------------------------------------------- +// Component +// --------------------------------------------------------------------------- + +export function ProviderBatchPreviewStep({ + rows, + summary, + excludedProviderIds, + onExcludeToggle, + isLoading, +}: ProviderBatchPreviewStepProps) { + const t = useTranslations("settings.providers.batchEdit"); + + const grouped = useMemo(() => { + const map = new Map(); + for (const row of rows) { + let group = map.get(row.providerId); + if (!group) { + group = { providerId: row.providerId, providerName: row.providerName, rows: [] }; + map.set(row.providerId, group); + } + group.rows.push(row); + } + return Array.from(map.values()); + }, [rows]); + + const getFieldLabel = useCallback( + (field: string): string => { + const key = FIELD_LABEL_KEYS[field]; + if (key) return t(key); + return field.replace(/_/g, " "); + }, + [t] + ); + + if (isLoading) { + return ( +
+ + {t("preview.loading")} +
+ ); + } + + if (rows.length === 0) { + return ( +
+ {t("preview.noChanges")} +
+ ); + } + + return ( +
+ {/* Summary */} +

+ {t("preview.summary", { + providerCount: summary.providerCount, + fieldCount: summary.fieldCount, + skipCount: summary.skipCount, + })} +

+ + {/* Provider groups */} +
+ {grouped.map((group) => { + const excluded = excludedProviderIds.has(group.providerId); + return ( +
+ {/* Provider header with exclusion checkbox */} +
+ onExcludeToggle(group.providerId)} + aria-label={t("preview.excludeProvider")} + data-testid={`exclude-checkbox-${group.providerId}`} + /> + + {t("preview.providerHeader", { name: group.providerName })} + +
+ + {/* Field rows */} +
+ {group.rows.map((row) => ( +
+ {row.status === "changed" + ? t("preview.fieldChanged", { + field: getFieldLabel(row.field), + before: formatValue(row.before), + after: formatValue(row.after), + }) + : t("preview.fieldSkipped", { + field: getFieldLabel(row.field), + reason: row.skipReason ?? "", + })} +
+ ))} +
+
+ ); + })} +
+
+ ); +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function formatValue(value: unknown): string { + if (value === null || value === undefined) return "null"; + if (typeof value === "boolean") return String(value); + if (typeof value === "number") return String(value); + if (typeof value === "string") return value; + return JSON.stringify(value); +} diff --git a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx index 6225069fe..40ee6c928 100644 --- a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx +++ b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx @@ -1,10 +1,18 @@ "use client"; -import { Pencil, X } from "lucide-react"; +import { ChevronDown, Pencil, X } from "lucide-react"; import { useTranslations } from "next-intl"; +import { useMemo } from "react"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; import { cn } from "@/lib/utils"; +import type { ProviderDisplay, ProviderType } from "@/types/provider"; export interface ProviderBatchToolbarProps { isMultiSelectMode: boolean; @@ -16,6 +24,9 @@ export interface ProviderBatchToolbarProps { onSelectAll: (checked: boolean) => void; onInvertSelection: () => void; onOpenBatchEdit: () => void; + providers: ProviderDisplay[]; + onSelectByType: (type: ProviderType) => void; + onSelectByGroup: (group: string) => void; } export function ProviderBatchToolbar({ @@ -28,20 +39,58 @@ export function ProviderBatchToolbar({ onSelectAll, onInvertSelection, onOpenBatchEdit, + providers, + onSelectByType, + onSelectByGroup, }: ProviderBatchToolbarProps) { const t = useTranslations("settings.providers.batchEdit"); + const uniqueTypes = useMemo(() => { + const typeMap = new Map(); + for (const p of providers) { + typeMap.set(p.providerType, (typeMap.get(p.providerType) ?? 0) + 1); + } + return Array.from(typeMap.entries()) + .map(([type, count]) => ({ type, count })) + .sort((a, b) => a.type.localeCompare(b.type)); + }, [providers]); + + const uniqueGroups = useMemo(() => { + const groupMap = new Map(); + for (const p of providers) { + if (p.groupTag) { + const tags = p.groupTag + .split(",") + .map((tag) => tag.trim()) + .filter(Boolean); + for (const tag of tags) { + groupMap.set(tag, (groupMap.get(tag) ?? 0) + 1); + } + } + } + return Array.from(groupMap.entries()) + .map(([group, count]) => ({ group, count })) + .sort((a, b) => a.group.localeCompare(b.group)); + }, [providers]); + if (!isMultiSelectMode) { return ( - +
+ + {totalCount > 0 && ( + + {t("selectionHint")} + + )} +
); } @@ -65,6 +114,46 @@ export function ProviderBatchToolbar({ {t("invertSelection")} + {uniqueTypes.length > 1 && ( + + + + + + {uniqueTypes.map(({ type, count }) => ( + onSelectByType(type)}> + {t("selectByTypeItem", { type, count })} + + ))} + + + )} + + {uniqueGroups.length > 0 && ( + + + + + + {uniqueGroups.map(({ group, count }) => ( + onSelectByGroup(group)} + > + {t("selectByGroupItem", { group, count })} + + ))} + + + )} + + ); + })} +
+ + ); + } + return ( <> {/* Desktop: Vertical Sidebar */} diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form/index.tsx b/src/app/[locale]/settings/providers/_components/forms/provider-form/index.tsx index 2942d5be3..907ba3ff4 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form/index.tsx +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form/index.tsx @@ -5,7 +5,13 @@ import { useTranslations } from "next-intl"; import { useCallback, useEffect, useMemo, useRef, useState, useTransition } from "react"; import { toast } from "sonner"; import { getProviderEndpoints, getProviderVendors } from "@/actions/provider-endpoints"; -import { addProvider, editProvider, removeProvider } from "@/actions/providers"; +import { + addProvider, + editProvider, + removeProvider, + undoProviderDelete, + undoProviderPatch, +} from "@/actions/providers"; import { getDistinctProviderGroupsAction } from "@/actions/request-filters"; import { AlertDialog, @@ -19,6 +25,7 @@ import { AlertDialogTitle as AlertTitle, } from "@/components/ui/alert-dialog"; import { Button } from "@/components/ui/button"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; import { isValidUrl } from "@/lib/utils/validation"; import type { ProviderDisplay, @@ -89,6 +96,7 @@ function ProviderFormContent({ resolvedUrl?: string | null; }) { const t = useTranslations("settings.providers.form"); + const tBatchEdit = useTranslations("settings.providers.batchEdit"); const { state, dispatch, mode, provider, hideUrl } = useProviderForm(); const [isPending, startTransition] = useTransition(); const isEdit = mode === "edit"; @@ -363,7 +371,36 @@ function ProviderFormContent({ toast.error(res.error || t("errors.updateFailed")); return; } - toast.success(t("success.updated")); + + const undoToken = res.data.undoToken; + const operationId = res.data.operationId; + + toast.success(tBatchEdit("undo.singleEditSuccess"), { + duration: 10000, + action: { + label: tBatchEdit("undo.button"), + onClick: async () => { + try { + const undoResult = await undoProviderPatch({ undoToken, operationId }); + if (undoResult.ok) { + toast.success(tBatchEdit("undo.singleEditUndone")); + await queryClient.invalidateQueries({ queryKey: ["providers"] }); + await queryClient.invalidateQueries({ queryKey: ["providers-health"] }); + await queryClient.invalidateQueries({ queryKey: ["providers-statistics"] }); + await queryClient.invalidateQueries({ queryKey: ["provider-vendors"] }); + } else if ( + undoResult.errorCode === PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED + ) { + toast.error(tBatchEdit("undo.expired")); + } else { + toast.error(tBatchEdit("undo.failed")); + } + } catch { + toast.error(tBatchEdit("undo.failed")); + } + }, + }, + }); void queryClient.invalidateQueries({ queryKey: ["providers"] }); void queryClient.invalidateQueries({ queryKey: ["providers-health"] }); @@ -426,7 +463,39 @@ function ProviderFormContent({ toast.error(res.error || t("errors.deleteFailed")); return; } - toast.success(t("success.deleted")); + + const undoToken = res.data.undoToken; + const operationId = res.data.operationId; + + toast.success(tBatchEdit("undo.singleDeleteSuccess"), { + duration: 10000, + action: { + label: tBatchEdit("undo.button"), + onClick: async () => { + try { + const undoResult = await undoProviderDelete({ undoToken, operationId }); + if (undoResult.ok) { + toast.success(tBatchEdit("undo.singleDeleteUndone")); + await queryClient.invalidateQueries({ queryKey: ["providers"] }); + await queryClient.invalidateQueries({ queryKey: ["providers-health"] }); + await queryClient.invalidateQueries({ queryKey: ["providers-statistics"] }); + await queryClient.invalidateQueries({ queryKey: ["provider-vendors"] }); + } else if (undoResult.errorCode === PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED) { + toast.error(tBatchEdit("undo.expired")); + } else { + toast.error(tBatchEdit("undo.failed")); + } + } catch { + toast.error(tBatchEdit("undo.failed")); + } + }, + }, + }); + + void queryClient.invalidateQueries({ queryKey: ["providers"] }); + void queryClient.invalidateQueries({ queryKey: ["providers-health"] }); + void queryClient.invalidateQueries({ queryKey: ["providers-statistics"] }); + void queryClient.invalidateQueries({ queryKey: ["provider-vendors"] }); onSuccess?.(); } catch (e) { console.error("Delete error:", e); diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx b/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx index facc525c9..a8f79bc67 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx @@ -1,6 +1,15 @@ "use client"; -import { createContext, type ReactNode, useContext, useReducer } from "react"; +import { + createContext, + type Dispatch, + type ReactNode, + useCallback, + useContext, + useMemo, + useReducer, + useRef, +} from "react"; import type { ProviderDisplay, ProviderType } from "@/types/provider"; import type { FormMode, @@ -9,6 +18,52 @@ import type { ProviderFormState, } from "./provider-form-types"; +// Maps action types to dirty field paths for batch mode tracking +const ACTION_TO_FIELD_PATH: Partial> = { + SET_BATCH_IS_ENABLED: "batch.isEnabled", + SET_PRIORITY: "routing.priority", + SET_WEIGHT: "routing.weight", + SET_COST_MULTIPLIER: "routing.costMultiplier", + SET_GROUP_TAG: "routing.groupTag", + SET_PRESERVE_CLIENT_IP: "routing.preserveClientIp", + SET_MODEL_REDIRECTS: "routing.modelRedirects", + SET_ALLOWED_MODELS: "routing.allowedModels", + SET_GROUP_PRIORITIES: "routing.groupPriorities", + SET_CACHE_TTL_PREFERENCE: "routing.cacheTtlPreference", + SET_SWAP_CACHE_TTL_BILLING: "routing.swapCacheTtlBilling", + SET_CONTEXT_1M_PREFERENCE: "routing.context1mPreference", + SET_CODEX_REASONING_EFFORT: "routing.codexReasoningEffortPreference", + SET_CODEX_REASONING_SUMMARY: "routing.codexReasoningSummaryPreference", + SET_CODEX_TEXT_VERBOSITY: "routing.codexTextVerbosityPreference", + SET_CODEX_PARALLEL_TOOL_CALLS: "routing.codexParallelToolCallsPreference", + SET_ANTHROPIC_MAX_TOKENS: "routing.anthropicMaxTokensPreference", + SET_ANTHROPIC_THINKING_BUDGET: "routing.anthropicThinkingBudgetPreference", + SET_ADAPTIVE_THINKING_ENABLED: "routing.anthropicAdaptiveThinking", + SET_ADAPTIVE_THINKING_EFFORT: "routing.anthropicAdaptiveThinking", + SET_ADAPTIVE_THINKING_MODEL_MATCH_MODE: "routing.anthropicAdaptiveThinking", + SET_ADAPTIVE_THINKING_MODELS: "routing.anthropicAdaptiveThinking", + SET_GEMINI_GOOGLE_SEARCH: "routing.geminiGoogleSearchPreference", + SET_LIMIT_5H_USD: "rateLimit.limit5hUsd", + SET_LIMIT_DAILY_USD: "rateLimit.limitDailyUsd", + SET_DAILY_RESET_MODE: "rateLimit.dailyResetMode", + SET_DAILY_RESET_TIME: "rateLimit.dailyResetTime", + SET_LIMIT_WEEKLY_USD: "rateLimit.limitWeeklyUsd", + SET_LIMIT_MONTHLY_USD: "rateLimit.limitMonthlyUsd", + SET_LIMIT_TOTAL_USD: "rateLimit.limitTotalUsd", + SET_LIMIT_CONCURRENT_SESSIONS: "rateLimit.limitConcurrentSessions", + SET_FAILURE_THRESHOLD: "circuitBreaker.failureThreshold", + SET_OPEN_DURATION_MINUTES: "circuitBreaker.openDurationMinutes", + SET_HALF_OPEN_SUCCESS_THRESHOLD: "circuitBreaker.halfOpenSuccessThreshold", + SET_MAX_RETRY_ATTEMPTS: "circuitBreaker.maxRetryAttempts", + SET_PROXY_URL: "network.proxyUrl", + SET_PROXY_FALLBACK_TO_DIRECT: "network.proxyFallbackToDirect", + SET_FIRST_BYTE_TIMEOUT_STREAMING: "network.firstByteTimeoutStreamingSeconds", + SET_STREAMING_IDLE_TIMEOUT: "network.streamingIdleTimeoutSeconds", + SET_REQUEST_TIMEOUT_NON_STREAMING: "network.requestTimeoutNonStreamingSeconds", + SET_MCP_PASSTHROUGH_TYPE: "mcp.mcpPassthroughType", + SET_MCP_PASSTHROUGH_URL: "mcp.mcpPassthroughUrl", +}; + // Initial state factory export function createInitialState( mode: FormMode, @@ -22,9 +77,72 @@ export function createInitialState( } ): ProviderFormState { const isEdit = mode === "edit"; + const isBatch = mode === "batch"; const raw = isEdit ? provider : cloneProvider; const sourceProvider = raw ? structuredClone(raw) : undefined; + // Batch mode: all fields start at neutral defaults (no provider source) + if (isBatch) { + return { + basic: { name: "", url: "", key: "", websiteUrl: "" }, + routing: { + providerType: "claude", + groupTag: [], + preserveClientIp: false, + modelRedirects: {}, + allowedModels: [], + priority: 0, + groupPriorities: {}, + weight: 1, + costMultiplier: 1.0, + cacheTtlPreference: "inherit", + swapCacheTtlBilling: false, + context1mPreference: "inherit", + codexReasoningEffortPreference: "inherit", + codexReasoningSummaryPreference: "inherit", + codexTextVerbosityPreference: "inherit", + codexParallelToolCallsPreference: "inherit", + anthropicMaxTokensPreference: "inherit", + anthropicThinkingBudgetPreference: "inherit", + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: "inherit", + }, + rateLimit: { + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: null, + }, + circuitBreaker: { + failureThreshold: undefined, + openDurationMinutes: undefined, + halfOpenSuccessThreshold: undefined, + maxRetryAttempts: null, + }, + network: { + proxyUrl: "", + proxyFallbackToDirect: false, + firstByteTimeoutStreamingSeconds: undefined, + streamingIdleTimeoutSeconds: undefined, + requestTimeoutNonStreamingSeconds: undefined, + }, + mcp: { + mcpPassthroughType: "none", + mcpPassthroughUrl: "", + }, + batch: { isEnabled: "no_change" }, + ui: { + activeTab: "basic", + isPending: false, + showFailureThresholdConfirm: false, + }, + }; + } + return { basic: { name: isEdit @@ -105,6 +223,7 @@ export function createInitialState( mcpPassthroughType: sourceProvider?.mcpPassthroughType ?? "none", mcpPassthroughUrl: sourceProvider?.mcpPassthroughUrl ?? "", }, + batch: { isEnabled: "no_change" }, ui: { activeTab: "basic", isPending: false, @@ -317,6 +436,10 @@ export function providerFormReducer( case "SET_MCP_PASSTHROUGH_URL": return { ...state, mcp: { ...state.mcp, mcpPassthroughUrl: action.payload } }; + // Batch + case "SET_BATCH_IS_ENABLED": + return { ...state, batch: { ...state.batch, isEnabled: action.payload } }; + // UI case "SET_ACTIVE_TAB": return { ...state, ui: { ...state.ui, activeTab: action.payload } }; @@ -357,6 +480,7 @@ export function ProviderFormProvider({ hideWebsiteUrl = false, preset, groupSuggestions, + batchProviders, }: { children: ReactNode; mode: FormMode; @@ -372,27 +496,58 @@ export function ProviderFormProvider({ providerType?: ProviderType; }; groupSuggestions: string[]; + batchProviders?: ProviderDisplay[]; }) { - const [state, dispatch] = useReducer( + const [state, rawDispatch] = useReducer( providerFormReducer, createInitialState(mode, provider, cloneProvider, preset) ); + const dirtyFieldsRef = useRef(new Set()); + const isBatch = mode === "batch"; + + // Wrap dispatch for batch mode to auto-track dirty fields + const dispatch: Dispatch = useCallback( + (action: ProviderFormAction) => { + if (isBatch) { + const fieldPath = ACTION_TO_FIELD_PATH[action.type]; + if (fieldPath) { + dirtyFieldsRef.current.add(fieldPath); + } + } + rawDispatch(action); + }, + [isBatch] + ); + + const contextValue = useMemo( + () => ({ + state, + dispatch, + mode, + provider, + enableMultiProviderTypes, + hideUrl, + hideWebsiteUrl, + groupSuggestions, + batchProviders, + dirtyFields: dirtyFieldsRef.current, + }), + [ + state, + dispatch, + mode, + provider, + enableMultiProviderTypes, + hideUrl, + hideWebsiteUrl, + groupSuggestions, + batchProviders, + ] + ); + return ( - - {children} - + {children} ); } diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-types.ts b/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-types.ts index 60355dd9e..4bec44463 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-types.ts +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-types.ts @@ -16,7 +16,7 @@ import type { } from "@/types/provider"; // Form mode -export type FormMode = "create" | "edit"; +export type FormMode = "create" | "edit" | "batch"; // Tab identifiers export type TabId = "basic" | "routing" | "limits" | "network" | "testing"; @@ -93,6 +93,10 @@ export interface McpState { mcpPassthroughUrl: string; } +export interface BatchState { + isEnabled: "no_change" | "true" | "false"; +} + export interface UIState { activeTab: TabId; isPending: boolean; @@ -107,6 +111,7 @@ export interface ProviderFormState { circuitBreaker: CircuitBreakerState; network: NetworkState; mcp: McpState; + batch: BatchState; ui: UIState; } @@ -173,7 +178,9 @@ export type ProviderFormAction = | { type: "SET_SHOW_FAILURE_THRESHOLD_CONFIRM"; payload: boolean } // Bulk actions | { type: "RESET_FORM" } - | { type: "LOAD_PROVIDER"; payload: ProviderDisplay }; + | { type: "LOAD_PROVIDER"; payload: ProviderDisplay } + // Batch actions + | { type: "SET_BATCH_IS_ENABLED"; payload: "no_change" | "true" | "false" }; // Form props export interface ProviderFormProps { @@ -204,4 +211,6 @@ export interface ProviderFormContextValue { hideUrl: boolean; hideWebsiteUrl: boolean; groupSuggestions: string[]; + batchProviders?: ProviderDisplay[]; + dirtyFields: Set; } diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/basic-info-section.tsx b/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/basic-info-section.tsx index eb7258fd8..48c048042 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/basic-info-section.tsx +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/basic-info-section.tsx @@ -7,6 +7,13 @@ import { useEffect, useMemo, useRef, useState } from "react"; import { ProviderEndpointsSection } from "@/app/[locale]/settings/providers/_components/provider-endpoints-table"; import { InlineWarning } from "@/components/ui/inline-warning"; import { Input } from "@/components/ui/input"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; import { detectApiKeyWarnings } from "@/lib/utils/validation/api-key-warnings"; import type { ProviderType } from "@/types/provider"; import { UrlPreview } from "../../url-preview"; @@ -14,6 +21,8 @@ import { QuickPasteDialog } from "../components/quick-paste-dialog"; import { SectionCard, SmartInputWrapper } from "../components/section-card"; import { useProviderForm } from "../provider-form-context"; +const MAX_DISPLAYED_PROVIDERS = 5; + interface BasicInfoSectionProps { autoUrlPending?: boolean; endpointPool?: { @@ -25,21 +34,95 @@ interface BasicInfoSectionProps { export function BasicInfoSection({ autoUrlPending, endpointPool }: BasicInfoSectionProps) { const t = useTranslations("settings.providers.form"); + const tBatch = useTranslations("settings.providers.batchEdit"); const tProviders = useTranslations("settings.providers"); - const { state, dispatch, mode, provider, hideUrl, hideWebsiteUrl } = useProviderForm(); + const { state, dispatch, mode, provider, hideUrl, hideWebsiteUrl, batchProviders } = + useProviderForm(); const isEdit = mode === "edit"; + const isBatch = mode === "batch"; const nameInputRef = useRef(null); const [showKey, setShowKey] = useState(false); const apiKeyWarnings = useMemo(() => detectApiKeyWarnings(state.basic.key), [state.basic.key]); - // Auto-focus name input + // Auto-focus name input (skip in batch mode) useEffect(() => { + if (isBatch) return; const timer = setTimeout(() => { nameInputRef.current?.focus(); }, 100); return () => clearTimeout(timer); - }, []); + }, [isBatch]); + + // Batch mode: only isEnabled tri-state + provider summary + if (isBatch) { + const providers = batchProviders ?? []; + const displayed = providers.slice(0, MAX_DISPLAYED_PROVIDERS); + const remaining = providers.length - displayed.length; + + return ( + + +
+ + + + + {providers.length > 0 && ( +
+

+ {tBatch("affectedProviders.title")} ({providers.length}) +

+
+ {displayed.map((p) => ( +

+ {p.name} ({p.maskedKey}) +

+ ))} + {remaining > 0 && ( +

+ {tBatch("affectedProviders.more", { count: remaining })} +

+ )} +
+
+ )} +
+
+
+ ); + } return ( - {/* Proxy Test */} -
-
- -
-
{t("sections.proxy.test.label")}
-

{t("sections.proxy.test.desc")}

+ {/* Proxy Test - hidden in batch mode */} + {!isBatch && ( +
+
+ +
+
{t("sections.proxy.test.label")}
+

+ {t("sections.proxy.test.desc")} +

+
+
- -
+ )} )}
diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx b/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx index d9949900e..59c0f7c4a 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx @@ -18,8 +18,6 @@ import { TagInput } from "@/components/ui/tag-input"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { getProviderTypeConfig } from "@/lib/provider-type-utils"; import type { - AnthropicAdaptiveThinkingEffort, - AnthropicAdaptiveThinkingModelMatchMode, CodexParallelToolCallsPreference, CodexReasoningEffortPreference, CodexReasoningSummaryPreference, @@ -27,8 +25,10 @@ import type { GeminiGoogleSearchPreference, ProviderType, } from "@/types/provider"; +import { AdaptiveThinkingEditor } from "../../../adaptive-thinking-editor"; import { ModelMultiSelect } from "../../../model-multi-select"; import { ModelRedirectEditor } from "../../../model-redirect-editor"; +import { ThinkingBudgetEditor } from "../../../thinking-budget-editor"; import { FieldGroup, SectionCard, SmartInputWrapper, ToggleRow } from "../components/section-card"; import { useProviderForm } from "../provider-form-context"; @@ -36,10 +36,13 @@ const GROUP_TAG_MAX_TOTAL_LENGTH = 50; export function RoutingSection() { const t = useTranslations("settings.providers.form"); + const tBatch = useTranslations("settings.providers.batchEdit"); const tUI = useTranslations("ui.tagInput"); const { state, dispatch, mode, provider, enableMultiProviderTypes, groupSuggestions } = useProviderForm(); const isEdit = mode === "edit"; + const isBatch = mode === "batch"; + const { providerType } = state.routing; const renderProviderTypeLabel = (type: ProviderType) => { switch (type) { @@ -76,78 +79,81 @@ export function RoutingSection() { transition={{ duration: 0.2 }} className="space-y-6" > - {/* Provider Type & Group */} - -
- - - {!enableMultiProviderTypes && state.routing.providerType === "openai-compatible" && ( -

- {t("sections.routing.providerTypeDisabledNote")} -

- )} -
+ {/* Provider Type & Group - hidden in batch mode */} + {!isBatch && ( + +
+ + + {!enableMultiProviderTypes && + state.routing.providerType === "openai-compatible" && ( +

+ {t("sections.routing.providerTypeDisabledNote")} +

+ )} +
- - { - const messages: Record = { - empty: tUI("emptyTag"), - duplicate: tUI("duplicateTag"), - too_long: tUI("tooLong", { max: GROUP_TAG_MAX_TOTAL_LENGTH }), - invalid_format: tUI("invalidFormat"), - max_tags: tUI("maxTags"), - }; - toast.error(messages[reason] || reason); - }} - /> - -
-
+ + { + const messages: Record = { + empty: tUI("emptyTag"), + duplicate: tUI("duplicateTag"), + too_long: tUI("tooLong", { max: GROUP_TAG_MAX_TOTAL_LENGTH }), + invalid_format: tUI("invalidFormat"), + max_tags: tUI("maxTags"), + }; + toast.error(messages[reason] || reason); + }} + /> + +
+
+ )} {/* Model Configuration */} - {/* 1M Context Window - Claude type only */} - {state.routing.providerType === "claude" && ( + {/* 1M Context Window - Claude type only (or batch mode) */} + {(providerType === "claude" || providerType === "claude-auth" || isBatch) && ( - {/* Codex Overrides - Codex type only */} - {state.routing.providerType === "codex" && ( + {/* Codex Overrides - Codex type only (or batch mode) */} + {(providerType === "codex" || isBatch) && ( {tBatch("batchNotes.codexOnly")} + ) : undefined + } >
@@ -548,13 +559,17 @@ export function RoutingSection() { )} - {/* Anthropic Overrides - Claude type only */} - {(state.routing.providerType === "claude" || - state.routing.providerType === "claude-auth") && ( + {/* Anthropic Overrides - Claude type only (or batch mode) */} + {(providerType === "claude" || providerType === "claude-auth" || isBatch) && ( {tBatch("batchNotes.claudeOnly")} + ) : undefined + } >
@@ -615,243 +630,61 @@ export function RoutingSection() { - - -
- - {state.routing.anthropicThinkingBudgetPreference !== "inherit" && ( - <> - { - const val = e.target.value; - if (val === "") { - dispatch({ - type: "SET_ANTHROPIC_THINKING_BUDGET", - payload: "inherit", - }); - } else { - dispatch({ - type: "SET_ANTHROPIC_THINKING_BUDGET", - payload: val, - }); - } - }} - placeholder={t( - "sections.routing.anthropicOverrides.thinkingBudget.placeholder" - )} - disabled={state.ui.isPending} - min="1024" - max="32000" - className="flex-1" - /> - - - )} - -
-
- -

- {t("sections.routing.anthropicOverrides.thinkingBudget.help")} -

-
-
-
- - - - dispatch({ type: "SET_ADAPTIVE_THINKING_ENABLED", payload: checked }) + + dispatch({ + type: "SET_ANTHROPIC_THINKING_BUDGET", + payload: val, + }) } disabled={state.ui.isPending} /> - - - {state.routing.anthropicAdaptiveThinking && ( -
- - - -
- - -
-
- -

- {t("sections.routing.anthropicOverrides.adaptiveThinking.effort.help")} -

-
-
-
- - - - -
- - -
-
- -

- {t( - "sections.routing.anthropicOverrides.adaptiveThinking.modelMatchMode.help" - )} -

-
-
-
+ - {state.routing.anthropicAdaptiveThinking.modelMatchMode === "specific" && ( - - - -
- - dispatch({ - type: "SET_ADAPTIVE_THINKING_MODELS", - payload: models, - }) - } - placeholder={t( - "sections.routing.anthropicOverrides.adaptiveThinking.models.placeholder" - )} - disabled={state.ui.isPending} - /> - -
-
- -

- {t("sections.routing.anthropicOverrides.adaptiveThinking.models.help")} -

-
-
-
- )} -
- )} + + dispatch({ type: "SET_ADAPTIVE_THINKING_ENABLED", payload: enabled }) + } + onConfigChange={(newConfig) => { + dispatch({ + type: "SET_ADAPTIVE_THINKING_EFFORT", + payload: newConfig.effort, + }); + dispatch({ + type: "SET_ADAPTIVE_THINKING_MODEL_MATCH_MODE", + payload: newConfig.modelMatchMode, + }); + dispatch({ + type: "SET_ADAPTIVE_THINKING_MODELS", + payload: newConfig.models, + }); + }} + disabled={state.ui.isPending} + />
)} - {/* Gemini Overrides - Gemini type only */} - {(state.routing.providerType === "gemini" || - state.routing.providerType === "gemini-cli") && ( + {/* Gemini Overrides - Gemini type only (or batch mode) */} + {(providerType === "gemini" || providerType === "gemini-cli" || isBatch) && ( {tBatch("batchNotes.geminiOnly")} + ) : undefined + } > + + + + + {t(`${prefix}.options.inherit`)} + {t(`${prefix}.options.custom`)} + + + {mode !== "inherit" && ( + <> + + + + )} + +
+ + +

{t(`${prefix}.help`)}

+
+ + ); +} diff --git a/src/app/[locale]/settings/providers/_components/vendor-keys-compact-list.tsx b/src/app/[locale]/settings/providers/_components/vendor-keys-compact-list.tsx index acc70cfa4..1a679d623 100644 --- a/src/app/[locale]/settings/providers/_components/vendor-keys-compact-list.tsx +++ b/src/app/[locale]/settings/providers/_components/vendor-keys-compact-list.tsx @@ -1,13 +1,18 @@ "use client"; +import { VisuallyHidden } from "@radix-ui/react-visually-hidden"; import { useMutation, useQueryClient } from "@tanstack/react-query"; import { CheckCircle, Copy, Edit2, Loader2, Plus, Trash2 } from "lucide-react"; import { useTranslations } from "next-intl"; import { useEffect, useState } from "react"; import { toast } from "sonner"; -import { VisuallyHidden } from "@radix-ui/react-visually-hidden"; import { getProviderEndpoints } from "@/actions/provider-endpoints"; -import { editProvider, getUnmaskedProviderKey, removeProvider } from "@/actions/providers"; +import { + editProvider, + getUnmaskedProviderKey, + removeProvider, + undoProviderDelete, +} from "@/actions/providers"; import { FormErrorBoundary } from "@/components/form-error-boundary"; import { AlertDialog, @@ -39,6 +44,7 @@ import { TableRow, } from "@/components/ui/table"; import { PROVIDER_LIMITS } from "@/lib/constants/provider.constants"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; import { getProviderTypeConfig, getProviderTypeTranslationKey } from "@/lib/provider-type-utils"; import { copyToClipboard, isClipboardSupported } from "@/lib/utils/clipboard"; import { type CurrencyCode, formatCurrency } from "@/lib/utils/currency"; @@ -214,6 +220,7 @@ function VendorKeyRow(props: { }) { const t = useTranslations("settings.providers"); const tList = useTranslations("settings.providers.list"); + const tBatchEdit = useTranslations("settings.providers.batchEdit"); const tInline = useTranslations("settings.providers.inlineEdit"); const tTypes = useTranslations("settings.providers.types"); @@ -305,15 +312,41 @@ function VendorKeyRow(props: { mutationFn: async () => { const res = await removeProvider(props.provider.id); if (!res.ok) throw new Error(res.error); + return res.data; }, - onSuccess: () => { + onSuccess: (data) => { queryClient.invalidateQueries({ queryKey: ["providers"] }); queryClient.invalidateQueries({ queryKey: ["providers-health"] }); queryClient.invalidateQueries({ queryKey: ["providers-statistics"] }); queryClient.invalidateQueries({ queryKey: ["provider-vendors"] }); setDeleteDialogOpen(false); - toast.success(tList("deleteSuccess"), { - description: tList("deleteSuccessDesc", { name: props.provider.name }), + + toast.success(tBatchEdit("undo.singleDeleteSuccess"), { + duration: 10000, + action: { + label: tBatchEdit("undo.button"), + onClick: async () => { + try { + const undoResult = await undoProviderDelete({ + undoToken: data.undoToken, + operationId: data.operationId, + }); + if (undoResult.ok) { + toast.success(tBatchEdit("undo.singleDeleteUndone")); + await queryClient.invalidateQueries({ queryKey: ["providers"] }); + await queryClient.invalidateQueries({ queryKey: ["providers-health"] }); + await queryClient.invalidateQueries({ queryKey: ["providers-statistics"] }); + await queryClient.invalidateQueries({ queryKey: ["provider-vendors"] }); + } else if (undoResult.errorCode === PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED) { + toast.error(tBatchEdit("undo.expired")); + } else { + toast.error(tBatchEdit("undo.failed")); + } + } catch { + toast.error(tBatchEdit("undo.failed")); + } + }, + }, }); }, onError: () => { diff --git a/src/app/[locale]/usage-doc/_components/usage-doc-auth-context.tsx b/src/app/[locale]/usage-doc/_components/usage-doc-auth-context.tsx new file mode 100644 index 000000000..dcbb71022 --- /dev/null +++ b/src/app/[locale]/usage-doc/_components/usage-doc-auth-context.tsx @@ -0,0 +1,28 @@ +"use client"; + +import { createContext, type ReactNode, useContext } from "react"; + +interface UsageDocAuthContextValue { + isLoggedIn: boolean; +} + +const UsageDocAuthContext = createContext({ + isLoggedIn: false, +}); + +// Security: HttpOnly cookies are invisible to document.cookie; session state must come from server. +export function UsageDocAuthProvider({ + isLoggedIn, + children, +}: { + isLoggedIn: boolean; + children: ReactNode; +}) { + return ( + {children} + ); +} + +export function useUsageDocAuth(): UsageDocAuthContextValue { + return useContext(UsageDocAuthContext); +} diff --git a/src/app/[locale]/usage-doc/layout.tsx b/src/app/[locale]/usage-doc/layout.tsx index 20572674e..06b1b1044 100644 --- a/src/app/[locale]/usage-doc/layout.tsx +++ b/src/app/[locale]/usage-doc/layout.tsx @@ -5,6 +5,7 @@ import { cache } from "react"; import { Link } from "@/i18n/routing"; import { getSession } from "@/lib/auth"; import { DashboardHeader } from "../dashboard/_components/dashboard-header"; +import { UsageDocAuthProvider } from "./_components/usage-doc-auth-context"; type UsageDocParams = { locale: string }; @@ -63,10 +64,8 @@ export default async function UsageDocLayout({ )} - {/* 文档内容主体 */}
- {/* 文档容器 */} - {children} + {children}
); diff --git a/src/app/[locale]/usage-doc/page.tsx b/src/app/[locale]/usage-doc/page.tsx index ee6a2f6d0..ba25ee6a1 100644 --- a/src/app/[locale]/usage-doc/page.tsx +++ b/src/app/[locale]/usage-doc/page.tsx @@ -8,6 +8,7 @@ import { Sheet, SheetContent, SheetHeader, SheetTitle, SheetTrigger } from "@/co import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { QuickLinks } from "./_components/quick-links"; import { type TocItem, TocNav } from "./_components/toc-nav"; +import { useUsageDocAuth } from "./_components/usage-doc-auth-context"; const headingClasses = { h2: "scroll-m-20 text-2xl font-semibold leading-snug text-foreground", @@ -1774,19 +1775,17 @@ curl -I ${resolvedOrigin}`} */ export default function UsageDocPage() { const t = useTranslations("usage"); + const { isLoggedIn } = useUsageDocAuth(); const [activeId, setActiveId] = useState(""); const [tocItems, setTocItems] = useState([]); const [tocReady, setTocReady] = useState(false); const [serviceOrigin, setServiceOrigin] = useState( () => (typeof window !== "undefined" && window.location.origin) || "" ); - const [isLoggedIn, setIsLoggedIn] = useState(false); const [sheetOpen, setSheetOpen] = useState(false); useEffect(() => { setServiceOrigin(window.location.origin); - // 检查是否已登录(通过检查 auth-token cookie) - setIsLoggedIn(document.cookie.includes("auth-token=")); }, []); // 生成目录并监听滚动 diff --git a/src/app/api/auth/login/route.ts b/src/app/api/auth/login/route.ts index 00bdd886e..5d8f4da18 100644 --- a/src/app/api/auth/login/route.ts +++ b/src/app/api/auth/login/route.ts @@ -1,12 +1,31 @@ import { type NextRequest, NextResponse } from "next/server"; import { getTranslations } from "next-intl/server"; import { defaultLocale, type Locale, locales } from "@/i18n/config"; -import { getLoginRedirectTarget, setAuthCookie, validateKey } from "@/lib/auth"; +import { + type AuthSession, + getLoginRedirectTarget, + getSessionTokenMode, + setAuthCookie, + toKeyFingerprint, + validateKey, +} from "@/lib/auth"; +import { getEnvConfig } from "@/lib/config/env.schema"; import { logger } from "@/lib/logger"; +import { withAuthResponseHeaders } from "@/lib/security/auth-response-headers"; +import { createCsrfOriginGuard } from "@/lib/security/csrf-origin-guard"; +import { LoginAbusePolicy } from "@/lib/security/login-abuse-policy"; // 需要数据库连接 export const runtime = "nodejs"; +const csrfGuard = createCsrfOriginGuard({ + allowedOrigins: [], + allowSameOrigin: true, + enforceInDevelopment: process.env.VITEST === "true", +}); + +const loginPolicy = new LoginAbusePolicy(); + /** * Get locale from request (cookie or Accept-Language header) */ @@ -52,40 +71,239 @@ async function getAuthErrorTranslations(locale: Locale) { } } +async function getAuthSecurityTranslations(locale: Locale) { + try { + return await getTranslations({ locale, namespace: "auth.security" }); + } catch (error) { + logger.warn("Login route: failed to load auth.security translations", { + locale, + error: error instanceof Error ? error.message : String(error), + }); + + try { + return await getTranslations({ locale: defaultLocale, namespace: "auth.security" }); + } catch (fallbackError) { + logger.error("Login route: failed to load default auth.security translations", { + locale: defaultLocale, + error: fallbackError instanceof Error ? fallbackError.message : String(fallbackError), + }); + return null; + } + } +} + +function hasSecureCookieHttpMismatch(request: NextRequest): boolean { + const env = getEnvConfig(); + const forwardedProto = request.headers.get("x-forwarded-proto")?.split(",")[0]?.trim(); + return env.ENABLE_SECURE_COOKIES && forwardedProto === "http"; +} + +function shouldIncludeFailureTaxonomy(request: NextRequest): boolean { + return request.headers.has("x-forwarded-proto"); +} + +function getClientIp(request: NextRequest): string { + // 1. Next.js platform-provided IP (trusted in Vercel / managed deployments) + const platformIp = (request as unknown as { ip?: string }).ip; + if (platformIp) { + return platformIp; + } + + // 2. x-real-ip is typically set by the closest trusted reverse proxy + const realIp = request.headers.get("x-real-ip")?.trim(); + if (realIp) { + return realIp; + } + + // 3. x-forwarded-for: take the rightmost (last) entry, which is the IP + // appended by the closest trusted proxy. The leftmost entry is + // client-controlled and can be spoofed. + const forwarded = request.headers.get("x-forwarded-for"); + if (forwarded) { + const ips = forwarded + .split(",") + .map((s) => s.trim()) + .filter(Boolean); + if (ips.length > 0) { + return ips[ips.length - 1]; + } + } + + return "unknown"; +} + +let sessionStoreInstance: + | import("@/lib/auth-session-store/redis-session-store").RedisSessionStore + | null = null; + +async function getLoginSessionStore() { + if (!sessionStoreInstance) { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + sessionStoreInstance = new RedisSessionStore(); + } + return sessionStoreInstance; +} + +async function createOpaqueSession(key: string, session: AuthSession) { + const store = await getLoginSessionStore(); + return store.create({ + keyFingerprint: await toKeyFingerprint(key), + userId: session.user.id, + userRole: session.user.role, + }); +} + export async function POST(request: NextRequest) { + const csrfResult = csrfGuard.check(request); + if (!csrfResult.allowed) { + return withAuthResponseHeaders( + NextResponse.json({ errorCode: "CSRF_REJECTED" }, { status: 403 }) + ); + } + const locale = getLocaleFromRequest(request); const t = await getAuthErrorTranslations(locale); + const clientIp = getClientIp(request); + + const decision = loginPolicy.check(clientIp); + if (!decision.allowed) { + const response = withAuthResponseHeaders( + NextResponse.json( + { + error: t?.("loginFailed") ?? t?.("serverError") ?? "Too many attempts", + errorCode: "RATE_LIMITED", + }, + { status: 429 } + ) + ); + + if (decision.retryAfterSeconds != null) { + response.headers.set("Retry-After", String(decision.retryAfterSeconds)); + } + + return response; + } try { const { key } = await request.json(); - if (!key) { - return NextResponse.json({ error: t?.("apiKeyRequired") }, { status: 400 }); + if (!key || typeof key !== "string") { + loginPolicy.recordFailure(clientIp); + + if (!shouldIncludeFailureTaxonomy(request)) { + return withAuthResponseHeaders( + NextResponse.json( + { error: t?.("apiKeyRequired") ?? "API key is required" }, + { status: 400 } + ) + ); + } + + return withAuthResponseHeaders( + NextResponse.json( + { error: t?.("apiKeyRequired") ?? "API key is required", errorCode: "KEY_REQUIRED" }, + { status: 400 } + ) + ); } const session = await validateKey(key, { allowReadOnlyAccess: true }); if (!session) { - return NextResponse.json({ error: t?.("apiKeyInvalidOrExpired") }, { status: 401 }); + loginPolicy.recordFailure(clientIp); + + if (!shouldIncludeFailureTaxonomy(request)) { + return withAuthResponseHeaders( + NextResponse.json( + { error: t?.("apiKeyInvalidOrExpired") ?? "Authentication failed" }, + { status: 401 } + ) + ); + } + + const responseBody: { + error: string; + errorCode: "KEY_INVALID"; + httpMismatchGuidance?: string; + } = { + error: t?.("apiKeyInvalidOrExpired") ?? "Authentication failed", + errorCode: "KEY_INVALID", + }; + + if (hasSecureCookieHttpMismatch(request)) { + const securityT = await getAuthSecurityTranslations(locale); + responseBody.httpMismatchGuidance = + securityT?.("cookieWarningDescription") ?? + t?.("apiKeyInvalidOrExpired") ?? + t?.("serverError"); + } + + return withAuthResponseHeaders(NextResponse.json(responseBody, { status: 401 })); } - // 设置认证 cookie - await setAuthCookie(key); + const mode = getSessionTokenMode(); + if (mode === "legacy") { + await setAuthCookie(key); + } else if (mode === "dual") { + await setAuthCookie(key); + try { + await createOpaqueSession(key, session); + } catch (error) { + logger.warn("Failed to create opaque session in dual mode", { + error: error instanceof Error ? error.message : String(error), + }); + } + } else { + try { + const opaqueSession = await createOpaqueSession(key, session); + await setAuthCookie(opaqueSession.sessionId); + } catch (error) { + logger.error("Failed to create opaque session in opaque mode", { + error: error instanceof Error ? error.message : String(error), + }); + const serverError = t?.("serverError") ?? "Internal server error"; + return withAuthResponseHeaders( + NextResponse.json( + { error: serverError, errorCode: "SESSION_CREATE_FAILED" }, + { status: 503 } + ) + ); + } + } + + loginPolicy.recordSuccess(clientIp); const redirectTo = getLoginRedirectTarget(session); + const loginType = + session.user.role === "admin" + ? "admin" + : session.key.canLoginWebUi + ? "dashboard_user" + : "readonly_user"; - return NextResponse.json({ - ok: true, - user: { - id: session.user.id, - name: session.user.name, - description: session.user.description, - role: session.user.role, - }, - redirectTo, - }); + return withAuthResponseHeaders( + NextResponse.json({ + ok: true, + user: { + id: session.user.id, + name: session.user.name, + description: session.user.description, + role: session.user.role, + }, + redirectTo, + loginType, + }) + ); } catch (error) { logger.error("Login error:", error); - return NextResponse.json({ error: t?.("serverError") }, { status: 500 }); + const serverError = t?.("serverError") ?? "Internal server error"; + + if (!shouldIncludeFailureTaxonomy(request)) { + return withAuthResponseHeaders(NextResponse.json({ error: serverError }, { status: 500 })); + } + + return withAuthResponseHeaders( + NextResponse.json({ error: serverError, errorCode: "SERVER_ERROR" }, { status: 500 }) + ); } } diff --git a/src/app/api/auth/logout/route.ts b/src/app/api/auth/logout/route.ts index 3a435fc13..3233994e6 100644 --- a/src/app/api/auth/logout/route.ts +++ b/src/app/api/auth/logout/route.ts @@ -1,7 +1,78 @@ -import { NextResponse } from "next/server"; -import { clearAuthCookie } from "@/lib/auth"; +import { type NextRequest, NextResponse } from "next/server"; +import { + clearAuthCookie, + getAuthCookie, + getSessionTokenMode, + type SessionTokenMode, +} from "@/lib/auth"; +import { logger } from "@/lib/logger"; +import { withAuthResponseHeaders } from "@/lib/security/auth-response-headers"; +import { createCsrfOriginGuard } from "@/lib/security/csrf-origin-guard"; + +const csrfGuard = createCsrfOriginGuard({ + allowedOrigins: [], + allowSameOrigin: true, + enforceInDevelopment: process.env.VITEST === "true", +}); + +let sessionStoreInstance: + | import("@/lib/auth-session-store/redis-session-store").RedisSessionStore + | null = null; + +async function getLogoutSessionStore() { + if (!sessionStoreInstance) { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + sessionStoreInstance = new RedisSessionStore(); + } + return sessionStoreInstance; +} + +function resolveSessionTokenMode(): SessionTokenMode { + try { + return getSessionTokenMode(); + } catch (err) { + logger.warn("[AuthLogout] Failed to resolve session token mode, defaulting to legacy", { + error: err instanceof Error ? err.message : String(err), + }); + return "legacy"; + } +} + +async function resolveAuthCookieToken(): Promise { + try { + return await getAuthCookie(); + } catch (err) { + logger.warn("[AuthLogout] Failed to read auth cookie", { + error: err instanceof Error ? err.message : String(err), + }); + return undefined; + } +} + +export async function POST(request: NextRequest) { + const csrfResult = csrfGuard.check(request); + if (!csrfResult.allowed) { + return withAuthResponseHeaders( + NextResponse.json({ errorCode: "CSRF_REJECTED" }, { status: 403 }) + ); + } + + const mode = resolveSessionTokenMode(); + + if (mode !== "legacy") { + try { + const sessionId = await resolveAuthCookieToken(); + if (sessionId) { + const store = await getLogoutSessionStore(); + await store.revoke(sessionId); + } + } catch (error) { + logger.warn("[AuthLogout] Failed to revoke opaque session during logout", { + error: error instanceof Error ? error.message : String(error), + }); + } + } -export async function POST() { await clearAuthCookie(); - return NextResponse.json({ ok: true }); + return withAuthResponseHeaders(NextResponse.json({ ok: true })); } diff --git a/src/app/v1/_lib/cors.ts b/src/app/v1/_lib/cors.ts index 6fc3909d5..5756f376e 100644 --- a/src/app/v1/_lib/cors.ts +++ b/src/app/v1/_lib/cors.ts @@ -15,12 +15,21 @@ const DEFAULT_CORS_HEADERS: Record = { /** * 动态构建 CORS 响应头 */ -function buildCorsHeaders(options: { origin?: string | null; requestHeaders?: string | null }) { +function buildCorsHeaders(options: { + origin?: string | null; + requestHeaders?: string | null; + allowCredentials?: boolean; +}) { const headers = new Headers(DEFAULT_CORS_HEADERS); - if (options.origin) { + // Only reflect specific origin when credentials are explicitly opted-in. + // The proxy API uses Bearer tokens; reflecting arbitrary origins with + // credentials enabled would let any malicious site make credentialed + // cross-origin requests. + if (options.allowCredentials && options.origin) { headers.set("Access-Control-Allow-Origin", options.origin); headers.append("Vary", "Origin"); + headers.set("Access-Control-Allow-Credentials", "true"); } if (options.requestHeaders) { @@ -28,10 +37,6 @@ function buildCorsHeaders(options: { origin?: string | null; requestHeaders?: st headers.append("Vary", "Access-Control-Request-Headers"); } - if (headers.get("Access-Control-Allow-Origin") !== "*") { - headers.set("Access-Control-Allow-Credentials", "true"); - } - return headers; } @@ -75,7 +80,7 @@ function mergeVaryHeader(existing: string | null, newValue: string): string { */ export function applyCors( res: Response, - ctx: { origin?: string | null; requestHeaders?: string | null } + ctx: { origin?: string | null; requestHeaders?: string | null; allowCredentials?: boolean } ): Response { const corsHeaders = buildCorsHeaders(ctx); @@ -138,6 +143,7 @@ export function applyCors( export function buildPreflightResponse(options: { origin?: string | null; requestHeaders?: string | null; + allowCredentials?: boolean; }): Response { return new Response(null, { status: 204, headers: buildCorsHeaders(options) }); } diff --git a/src/app/v1/_lib/proxy/auth-guard.ts b/src/app/v1/_lib/proxy/auth-guard.ts index c652116c7..ef8e3a008 100644 --- a/src/app/v1/_lib/proxy/auth-guard.ts +++ b/src/app/v1/_lib/proxy/auth-guard.ts @@ -1,12 +1,67 @@ import { logger } from "@/lib/logger"; +import { LoginAbusePolicy } from "@/lib/security/login-abuse-policy"; import { validateApiKeyAndGetUser } from "@/repository/key"; import { markUserExpired } from "@/repository/user"; import { GEMINI_PROTOCOL } from "../gemini/protocol"; import { ProxyResponses } from "./responses"; import type { AuthState, ProxySession } from "./session"; +/** + * Pre-auth rate limiter: throttles repeated authentication failures per IP + * to prevent brute-force API key enumeration on /v1/* endpoints. + * + * Uses the same LoginAbusePolicy as the login route but with separate + * thresholds appropriate for programmatic API access. + */ +const proxyAuthPolicy = new LoginAbusePolicy({ + maxAttemptsPerIp: 20, + maxAttemptsPerKey: 20, + windowSeconds: 300, + lockoutSeconds: 600, +}); + +function extractClientIp(session: ProxySession): string { + // Prefer x-real-ip (set by trusted reverse proxy), then rightmost + // x-forwarded-for entry, avoiding the client-spoofable leftmost value. + const realIp = session.headers.get("x-real-ip")?.trim(); + if (realIp) return realIp; + + const forwarded = session.headers.get("x-forwarded-for"); + if (forwarded) { + const ips = forwarded + .split(",") + .map((s) => s.trim()) + .filter(Boolean); + if (ips.length > 0) return ips[ips.length - 1]; + } + + return "unknown"; +} + export class ProxyAuthenticator { static async ensure(session: ProxySession): Promise { + // Pre-auth rate limit: block IPs with too many recent auth failures + const clientIp = extractClientIp(session); + const rateLimitDecision = proxyAuthPolicy.check(clientIp); + if (!rateLimitDecision.allowed) { + const retryAfter = rateLimitDecision.retryAfterSeconds; + const response = ProxyResponses.buildError( + 429, + "Too many authentication failures. Please retry later.", + "rate_limit_error" + ); + if (retryAfter != null) { + const headers = new Headers(response.headers); + headers.set("Retry-After", String(retryAfter)); + return new Response(response.body, { + status: response.status, + statusText: response.statusText, + headers, + }); + } + return response; + } + const authHeader = session.headers.get("authorization") ?? undefined; const apiKeyHeader = session.headers.get("x-api-key") ?? undefined; // Gemini CLI 认证:支持 x-goog-api-key 头部和 key 查询参数 @@ -22,9 +77,13 @@ export class ProxyAuthenticator { session.setAuthState(authState); if (authState.success) { + proxyAuthPolicy.recordSuccess(clientIp); return null; } + // Record failure for rate limiting + proxyAuthPolicy.recordFailure(clientIp); + // 返回详细的错误信息,帮助用户快速定位问题 return authState.errorResponse ?? ProxyResponses.buildError(401, "认证失败"); } diff --git a/src/lib/api/action-adapter-openapi.ts b/src/lib/api/action-adapter-openapi.ts index 80f7950ac..338ec7047 100644 --- a/src/lib/api/action-adapter-openapi.ts +++ b/src/lib/api/action-adapter-openapi.ts @@ -12,7 +12,7 @@ import { createRoute, z } from "@hono/zod-openapi"; import type { Context } from "hono"; import { getCookie } from "hono/cookie"; import type { ActionResult } from "@/actions/types"; -import { runWithAuthSession, validateKey } from "@/lib/auth"; +import { AUTH_COOKIE_NAME, runWithAuthSession, validateAuthToken } from "@/lib/auth"; import { logger } from "@/lib/logger"; function getBearerTokenFromAuthHeader(raw: string | undefined): string | undefined { @@ -300,20 +300,21 @@ export function createActionRoute( const fullPath = `${module}.${actionName}`; try { - let authSession: Awaited> | null = null; + let authSession: Awaited> | null = null; // 0. 认证检查 (如果需要) if (requiresAuth) { const authToken = - getCookie(c, "auth-token") ?? getBearerTokenFromAuthHeader(c.req.header("authorization")); + getCookie(c, AUTH_COOKIE_NAME) ?? + getBearerTokenFromAuthHeader(c.req.header("authorization")); if (!authToken) { - logger.warn(`[ActionAPI] ${fullPath} 认证失败: 缺少 auth-token`); + logger.warn(`[ActionAPI] ${fullPath} 认证失败: 缺少 ${AUTH_COOKIE_NAME}`); return c.json({ ok: false, error: "未认证" }, 401); } - const session = await validateKey(authToken, { allowReadOnlyAccess }); + const session = await validateAuthToken(authToken, { allowReadOnlyAccess }); if (!session) { - logger.warn(`[ActionAPI] ${fullPath} 认证失败: 无效的 auth-token`); + logger.warn(`[ActionAPI] ${fullPath} 认证失败: 无效的 ${AUTH_COOKIE_NAME}`); return c.json({ ok: false, error: "认证无效或已过期" }, 401); } authSession = session; diff --git a/src/lib/auth-session-store/index.ts b/src/lib/auth-session-store/index.ts new file mode 100644 index 000000000..f6f75cc1a --- /dev/null +++ b/src/lib/auth-session-store/index.ts @@ -0,0 +1,20 @@ +export interface SessionData { + sessionId: string; + keyFingerprint: string; + userId: number; + userRole: string; + createdAt: number; + expiresAt: number; +} + +export interface SessionStore { + create( + data: Omit, + ttlSeconds?: number + ): Promise; + read(sessionId: string): Promise; + revoke(sessionId: string): Promise; + rotate(oldSessionId: string): Promise; +} + +export const DEFAULT_SESSION_TTL = 604800; diff --git a/src/lib/auth-session-store/redis-session-store.ts b/src/lib/auth-session-store/redis-session-store.ts new file mode 100644 index 000000000..904358f06 --- /dev/null +++ b/src/lib/auth-session-store/redis-session-store.ts @@ -0,0 +1,225 @@ +import "server-only"; + +import type Redis from "ioredis"; +import { logger } from "@/lib/logger"; +import { getRedisClient } from "@/lib/redis"; +import { DEFAULT_SESSION_TTL, type SessionData, type SessionStore } from "./index"; + +const SESSION_KEY_PREFIX = "cch:session:"; +const MIN_TTL_SECONDS = 1; + +type RedisSessionClient = Pick; + +export interface RedisSessionStoreOptions { + defaultTtlSeconds?: number; + redisClient?: RedisSessionClient | null; +} + +function toLogError(error: unknown): string { + return error instanceof Error ? error.message : String(error); +} + +function normalizeTtlSeconds(value: number | undefined): number { + if (!Number.isFinite(value) || typeof value !== "number" || value <= 0) { + return DEFAULT_SESSION_TTL; + } + + return Math.max(MIN_TTL_SECONDS, Math.floor(value)); +} + +function buildSessionKey(sessionId: string): string { + return `${SESSION_KEY_PREFIX}${sessionId}`; +} + +function parseSessionData(raw: string): SessionData | null { + try { + const parsed: unknown = JSON.parse(raw); + if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) { + return null; + } + + const obj = parsed as Record; + if (typeof obj.sessionId !== "string") return null; + if (typeof obj.keyFingerprint !== "string") return null; + if (typeof obj.userRole !== "string") return null; + if (typeof obj.userId !== "number" || !Number.isInteger(obj.userId)) return null; + if (!Number.isFinite(obj.createdAt) || typeof obj.createdAt !== "number") return null; + if (!Number.isFinite(obj.expiresAt) || typeof obj.expiresAt !== "number") return null; + + return { + sessionId: obj.sessionId, + keyFingerprint: obj.keyFingerprint, + userId: obj.userId as number, + userRole: obj.userRole, + createdAt: obj.createdAt, + expiresAt: obj.expiresAt, + }; + } catch { + return null; + } +} + +function resolveRotateTtlSeconds(expiresAt: number): number | null { + if (!Number.isFinite(expiresAt) || typeof expiresAt !== "number") { + return DEFAULT_SESSION_TTL; + } + + const remainingMs = expiresAt - Date.now(); + if (remainingMs <= 0) { + return null; + } + return Math.max(MIN_TTL_SECONDS, Math.ceil(remainingMs / 1000)); +} + +export class RedisSessionStore implements SessionStore { + private readonly defaultTtlSeconds: number; + private readonly redisClient?: RedisSessionClient | null; + + constructor(options: RedisSessionStoreOptions = {}) { + this.defaultTtlSeconds = normalizeTtlSeconds(options.defaultTtlSeconds); + this.redisClient = options.redisClient; + } + + private resolveRedisClient(): RedisSessionClient | null { + if (this.redisClient !== undefined) { + return this.redisClient; + } + + return getRedisClient({ allowWhenRateLimitDisabled: true }) as RedisSessionClient | null; + } + + private getReadyRedis(): RedisSessionClient | null { + const redis = this.resolveRedisClient(); + if (!redis || redis.status !== "ready") { + return null; + } + + return redis; + } + + async create( + data: Omit, + ttlSeconds = this.defaultTtlSeconds + ): Promise { + const ttl = normalizeTtlSeconds(ttlSeconds); + const createdAt = Date.now(); + const sessionData: SessionData = { + sessionId: `sid_${globalThis.crypto.randomUUID()}`, + keyFingerprint: data.keyFingerprint, + userId: data.userId, + userRole: data.userRole, + createdAt, + expiresAt: createdAt + ttl * 1000, + }; + + const redis = this.getReadyRedis(); + if (!redis) { + throw new Error("Redis not ready: session not persisted"); + } + + try { + await redis.setex(buildSessionKey(sessionData.sessionId), ttl, JSON.stringify(sessionData)); + } catch (error) { + logger.error("[AuthSessionStore] Failed to create session", { + error: toLogError(error), + sessionId: sessionData.sessionId, + }); + throw error; + } + + return sessionData; + } + + async read(sessionId: string): Promise { + const redis = this.getReadyRedis(); + if (!redis) { + return null; + } + + try { + const value = await redis.get(buildSessionKey(sessionId)); + if (!value) { + return null; + } + + const parsed = parseSessionData(value); + if (!parsed) { + logger.warn("[AuthSessionStore] Invalid session payload", { sessionId }); + return null; + } + + return parsed; + } catch (error) { + logger.error("[AuthSessionStore] Failed to read session", { + error: toLogError(error), + sessionId, + }); + return null; + } + } + + async revoke(sessionId: string): Promise { + const redis = this.getReadyRedis(); + if (!redis) { + logger.warn("[AuthSessionStore] Redis not ready during revoke", { sessionId }); + return false; + } + + try { + const deleted = await redis.del(buildSessionKey(sessionId)); + return deleted > 0; + } catch (error) { + logger.error("[AuthSessionStore] Failed to revoke session", { + error: toLogError(error), + sessionId, + }); + return false; + } + } + + async rotate(oldSessionId: string): Promise { + const oldSession = await this.read(oldSessionId); + if (!oldSession) { + return null; + } + + const ttlSeconds = resolveRotateTtlSeconds(oldSession.expiresAt); + if (ttlSeconds === null) { + logger.warn("[AuthSessionStore] Cannot rotate expired session", { + sessionId: oldSessionId, + expiresAt: oldSession.expiresAt, + }); + return null; + } + let nextSession: SessionData; + try { + nextSession = await this.create( + { + keyFingerprint: oldSession.keyFingerprint, + userId: oldSession.userId, + userRole: oldSession.userRole, + }, + ttlSeconds + ); + } catch (error) { + logger.error("[AuthSessionStore] Failed to create rotated session", { + error: toLogError(error), + oldSessionId, + }); + return null; + } + + const revoked = await this.revoke(oldSessionId); + if (!revoked) { + logger.warn( + "[AuthSessionStore] Failed to revoke old session during rotate; old session will expire naturally", + { + oldSessionId, + newSessionId: nextSession.sessionId, + } + ); + } + + return nextSession; + } +} diff --git a/src/lib/auth.ts b/src/lib/auth.ts index 62a2cac0f..4f6749282 100644 --- a/src/lib/auth.ts +++ b/src/lib/auth.ts @@ -1,10 +1,23 @@ import { cookies, headers } from "next/headers"; +import type { NextResponse } from "next/server"; import { config } from "@/lib/config/config"; import { getEnvConfig } from "@/lib/config/env.schema"; -import { validateApiKeyAndGetUser } from "@/repository/key"; +import { logger } from "@/lib/logger"; +import { constantTimeEqual } from "@/lib/security/constant-time-compare"; +import { findKeyList, validateApiKeyAndGetUser } from "@/repository/key"; import type { Key } from "@/types/key"; import type { User } from "@/types/user"; +/** + * Apply no-store / cache-busting headers to auth responses that mutate session state. + * Prevents browsers and intermediary caches from storing sensitive auth responses. + */ +export function withNoStoreHeaders(response: T): T { + response.headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + response.headers.set("Pragma", "no-cache"); + return response; +} + export type ScopedAuthContext = { session: AuthSession; /** @@ -25,7 +38,7 @@ declare global { var __cchAuthSessionStorage: AuthSessionStorage | undefined; } -const AUTH_COOKIE_NAME = "auth-token"; +export const AUTH_COOKIE_NAME = "auth-token"; const AUTH_COOKIE_MAX_AGE = 60 * 60 * 24 * 7; // 7 days export interface AuthSession { @@ -33,6 +46,95 @@ export interface AuthSession { key: Key; } +export type SessionTokenMode = "legacy" | "dual" | "opaque"; +export type SessionTokenKind = "legacy" | "opaque"; + +export function getSessionTokenMode(): SessionTokenMode { + return getEnvConfig().SESSION_TOKEN_MODE; +} + +// Session contract: opaque token is a random string, not the API key +export interface OpaqueSessionContract { + sessionId: string; // random opaque token + keyFingerprint: string; // hash of the API key (for audit, not auth) + createdAt: number; // unix timestamp + expiresAt: number; // unix timestamp + userId: number; + userRole: string; +} + +export interface SessionTokenMigrationFlags { + dualReadWindowEnabled: boolean; + hardCutoverEnabled: boolean; + emergencyRollbackEnabled: boolean; +} + +export const SESSION_TOKEN_SEMANTICS = { + expiry: "hard_expiry_at_expires_at", + rotation: "rotate_before_expiry_and_revoke_previous_session_id", + revocation: "server_side_revocation_invalidates_session_immediately", + compatibility: { + legacy: "accept_legacy_only", + dual: "accept_legacy_and_opaque", + opaque: "accept_opaque_only", + }, +} as const; + +export function getSessionTokenMigrationFlags( + mode: SessionTokenMode = getSessionTokenMode() +): SessionTokenMigrationFlags { + return { + dualReadWindowEnabled: mode === "dual", + hardCutoverEnabled: mode === "opaque", + emergencyRollbackEnabled: mode === "legacy", + }; +} + +export function isSessionTokenKindAccepted( + mode: SessionTokenMode, + kind: SessionTokenKind +): boolean { + if (mode === "dual") return true; + if (mode === "legacy") return kind === "legacy"; + return kind === "opaque"; +} + +export function isOpaqueSessionContract(value: unknown): value is OpaqueSessionContract { + if (!value || typeof value !== "object") return false; + + const candidate = value as Record; + return ( + typeof candidate.sessionId === "string" && + candidate.sessionId.length > 0 && + typeof candidate.keyFingerprint === "string" && + candidate.keyFingerprint.length > 0 && + typeof candidate.createdAt === "number" && + Number.isFinite(candidate.createdAt) && + typeof candidate.expiresAt === "number" && + Number.isFinite(candidate.expiresAt) && + candidate.expiresAt > candidate.createdAt && + typeof candidate.userId === "number" && + Number.isInteger(candidate.userId) && + typeof candidate.userRole === "string" && + candidate.userRole.length > 0 + ); +} + +const OPAQUE_SESSION_ID_PREFIX = "sid_"; + +export function detectSessionTokenKind(token: string): SessionTokenKind { + const trimmed = token.trim(); + if (!trimmed) return "legacy"; + return trimmed.startsWith(OPAQUE_SESSION_ID_PREFIX) ? "opaque" : "legacy"; +} + +export function isSessionTokenAccepted( + token: string, + mode: SessionTokenMode = getSessionTokenMode() +): boolean { + return isSessionTokenKindAccepted(mode, detectSessionTokenKind(token)); +} + export function runWithAuthSession( session: AuthSession, fn: () => T, @@ -65,7 +167,7 @@ export async function validateKey( const allowReadOnlyAccess = options?.allowReadOnlyAccess ?? false; const adminToken = config.auth.adminToken; - if (adminToken && keyString === adminToken) { + if (adminToken && constantTimeEqual(keyString, adminToken)) { const now = new Date(); const adminUser: User = { id: -1, @@ -158,6 +260,40 @@ export async function clearAuthCookie() { cookieStore.delete(AUTH_COOKIE_NAME); } +export async function validateAuthToken( + token: string, + options?: { allowReadOnlyAccess?: boolean } +): Promise { + const mode = getSessionTokenMode(); + + if (mode !== "legacy") { + try { + const sessionStore = await getSessionStore(); + const sessionData = await sessionStore.read(token); + if (sessionData) { + if (sessionData.expiresAt <= Date.now()) { + logger.warn("Opaque session expired (application-level check)", { + sessionId: sessionData.sessionId, + expiresAt: sessionData.expiresAt, + }); + return null; + } + return convertToAuthSession(sessionData, options); + } + } catch (error) { + logger.warn("Opaque session read failed", { + error: error instanceof Error ? error.message : String(error), + }); + } + } + + if (mode === "legacy" || mode === "dual") { + return validateKey(token, options); + } + + return null; +} + export async function getSession(options?: { /** * 允许仅访问只读页面(如 my-usage),跳过 canLoginWebUi 校验 @@ -181,7 +317,79 @@ export async function getSession(options?: { return null; } - return validateKey(keyString, options); + return validateAuthToken(keyString, options); +} + +type SessionStoreReader = { + read(sessionId: string): Promise; +}; + +let sessionStorePromise: Promise | null = null; + +async function getSessionStore(): Promise { + if (!sessionStorePromise) { + sessionStorePromise = import("@/lib/auth-session-store/redis-session-store") + .then(({ RedisSessionStore }) => new RedisSessionStore()) + .catch((error) => { + sessionStorePromise = null; + throw error; + }); + } + + return sessionStorePromise; +} + +export async function toKeyFingerprint(keyString: string): Promise { + const digest = await crypto.subtle.digest("SHA-256", new TextEncoder().encode(keyString)); + const hex = Array.from(new Uint8Array(digest), (byte) => byte.toString(16).padStart(2, "0")).join( + "" + ); + return `sha256:${hex}`; +} + +function normalizeKeyFingerprint(fingerprint: string): string { + return fingerprint.startsWith("sha256:") ? fingerprint : `sha256:${fingerprint}`; +} + +async function convertToAuthSession( + sessionData: OpaqueSessionContract, + options?: { allowReadOnlyAccess?: boolean } +): Promise { + const expectedFingerprint = normalizeKeyFingerprint(sessionData.keyFingerprint); + + // Admin token uses virtual user (id=-1) which has no DB keys; + // verify fingerprint against the configured admin token directly. + if (sessionData.userId === -1) { + const adminToken = config.auth.adminToken; + if (!adminToken) return null; + const adminFingerprint = await toKeyFingerprint(adminToken); + return constantTimeEqual(adminFingerprint, expectedFingerprint) + ? validateKey(adminToken, options) + : null; + } + + const keyList = await findKeyList(sessionData.userId); + + for (const key of keyList) { + const keyFingerprint = await toKeyFingerprint(key.key); + if (constantTimeEqual(keyFingerprint, expectedFingerprint)) { + return validateKey(key.key, options); + } + } + + return null; +} + +export async function getSessionWithDualRead(options?: { + allowReadOnlyAccess?: boolean; +}): Promise { + return getSession(options); +} + +export async function validateSession(options?: { + allowReadOnlyAccess?: boolean; +}): Promise { + return getSessionWithDualRead(options); } function parseBearerToken(raw: string | null | undefined): string | undefined { diff --git a/src/lib/config/env.schema.ts b/src/lib/config/env.schema.ts index b7dacd738..dcdf167ef 100644 --- a/src/lib/config/env.schema.ts +++ b/src/lib/config/env.schema.ts @@ -93,6 +93,7 @@ export const EnvSchema = z.object({ REDIS_TLS_REJECT_UNAUTHORIZED: z.string().default("true").transform(booleanTransform), ENABLE_RATE_LIMIT: z.string().default("true").transform(booleanTransform), ENABLE_SECURE_COOKIES: z.string().default("true").transform(booleanTransform), + SESSION_TOKEN_MODE: z.enum(["legacy", "dual", "opaque"]).default("opaque"), SESSION_TTL: z.coerce.number().default(300), // 会话消息存储控制 // - false (默认):存储请求/响应体但对 message 内容脱敏 [REDACTED] diff --git a/src/lib/provider-batch-patch-error-codes.ts b/src/lib/provider-batch-patch-error-codes.ts new file mode 100644 index 000000000..597b12306 --- /dev/null +++ b/src/lib/provider-batch-patch-error-codes.ts @@ -0,0 +1,11 @@ +export const PROVIDER_BATCH_PATCH_ERROR_CODES = { + INVALID_INPUT: "INVALID_INPUT", + NOTHING_TO_APPLY: "NOTHING_TO_APPLY", + PREVIEW_EXPIRED: "PREVIEW_EXPIRED", + PREVIEW_STALE: "PREVIEW_STALE", + UNDO_EXPIRED: "UNDO_EXPIRED", + UNDO_CONFLICT: "UNDO_CONFLICT", +} as const; + +export type ProviderBatchPatchErrorCode = + (typeof PROVIDER_BATCH_PATCH_ERROR_CODES)[keyof typeof PROVIDER_BATCH_PATCH_ERROR_CODES]; diff --git a/src/lib/provider-patch-contract.ts b/src/lib/provider-patch-contract.ts new file mode 100644 index 000000000..659176713 --- /dev/null +++ b/src/lib/provider-patch-contract.ts @@ -0,0 +1,974 @@ +import type { + ProviderBatchApplyUpdates, + ProviderBatchPatch, + ProviderBatchPatchDraft, + ProviderBatchPatchField, + ProviderPatchDraftInput, + ProviderPatchOperation, +} from "@/types/provider"; + +export const PROVIDER_PATCH_ERROR_CODES = { + INVALID_PATCH_SHAPE: "INVALID_PATCH_SHAPE", +} as const; + +export type ProviderPatchErrorCode = + (typeof PROVIDER_PATCH_ERROR_CODES)[keyof typeof PROVIDER_PATCH_ERROR_CODES]; + +interface ProviderPatchError { + code: ProviderPatchErrorCode; + field: ProviderBatchPatchField | "__root__"; + message: string; +} + +type ProviderPatchResult = { ok: true; data: T } | { ok: false; error: ProviderPatchError }; + +const PATCH_INPUT_KEYS = new Set(["set", "clear", "no_change"]); +const PATCH_FIELDS: ProviderBatchPatchField[] = [ + "is_enabled", + "priority", + "weight", + "cost_multiplier", + "group_tag", + "model_redirects", + "allowed_models", + "anthropic_thinking_budget_preference", + "anthropic_adaptive_thinking", + // Routing + "preserve_client_ip", + "group_priorities", + "cache_ttl_preference", + "swap_cache_ttl_billing", + "context_1m_preference", + "codex_reasoning_effort_preference", + "codex_reasoning_summary_preference", + "codex_text_verbosity_preference", + "codex_parallel_tool_calls_preference", + "anthropic_max_tokens_preference", + "gemini_google_search_preference", + // Rate Limit + "limit_5h_usd", + "limit_daily_usd", + "daily_reset_mode", + "daily_reset_time", + "limit_weekly_usd", + "limit_monthly_usd", + "limit_total_usd", + "limit_concurrent_sessions", + // Circuit Breaker + "circuit_breaker_failure_threshold", + "circuit_breaker_open_duration", + "circuit_breaker_half_open_success_threshold", + "max_retry_attempts", + // Network + "proxy_url", + "proxy_fallback_to_direct", + "first_byte_timeout_streaming_ms", + "streaming_idle_timeout_ms", + "request_timeout_non_streaming_ms", + // MCP + "mcp_passthrough_type", + "mcp_passthrough_url", +]; +const PATCH_FIELD_SET = new Set(PATCH_FIELDS); + +const CLEARABLE_FIELDS: Record = { + is_enabled: false, + priority: false, + weight: false, + cost_multiplier: false, + group_tag: true, + model_redirects: true, + allowed_models: true, + anthropic_thinking_budget_preference: true, + anthropic_adaptive_thinking: true, + // Routing + preserve_client_ip: false, + group_priorities: true, + cache_ttl_preference: true, + swap_cache_ttl_billing: false, + context_1m_preference: true, + codex_reasoning_effort_preference: true, + codex_reasoning_summary_preference: true, + codex_text_verbosity_preference: true, + codex_parallel_tool_calls_preference: true, + anthropic_max_tokens_preference: true, + gemini_google_search_preference: true, + // Rate Limit + limit_5h_usd: true, + limit_daily_usd: true, + daily_reset_mode: false, + daily_reset_time: false, + limit_weekly_usd: true, + limit_monthly_usd: true, + limit_total_usd: true, + limit_concurrent_sessions: false, + // Circuit Breaker + circuit_breaker_failure_threshold: false, + circuit_breaker_open_duration: false, + circuit_breaker_half_open_success_threshold: false, + max_retry_attempts: true, + // Network + proxy_url: true, + proxy_fallback_to_direct: false, + first_byte_timeout_streaming_ms: false, + streaming_idle_timeout_ms: false, + request_timeout_non_streaming_ms: false, + // MCP + mcp_passthrough_type: false, + mcp_passthrough_url: true, +}; + +function isStringRecord(value: unknown): value is Record { + if (!isRecord(value) || Array.isArray(value)) { + return false; + } + + return Object.entries(value).every( + ([key, entry]) => typeof key === "string" && typeof entry === "string" + ); +} + +function isNumberRecord(value: unknown): value is Record { + if (!isRecord(value) || Array.isArray(value)) { + return false; + } + + return Object.values(value).every((v) => typeof v === "number" && Number.isFinite(v)); +} + +function isAdaptiveThinkingConfig( + value: unknown +): value is NonNullable { + if (!isRecord(value)) { + return false; + } + + const effortValues = new Set(["low", "medium", "high", "max"]); + const modeValues = new Set(["specific", "all"]); + + if (typeof value.effort !== "string" || !effortValues.has(value.effort)) { + return false; + } + + if (typeof value.modelMatchMode !== "string" || !modeValues.has(value.modelMatchMode)) { + return false; + } + + if (!Array.isArray(value.models) || !value.models.every((model) => typeof model === "string")) { + return false; + } + + if (value.modelMatchMode === "specific" && value.models.length === 0) { + return false; + } + + return true; +} + +function isThinkingBudgetPreference(value: unknown): boolean { + if (value === "inherit") { + return true; + } + + if (typeof value !== "string") { + return false; + } + + if (!/^\d+$/.test(value)) { + return false; + } + + const parsed = Number.parseInt(value, 10); + return parsed >= 1024 && parsed <= 32000; +} + +function isMaxTokensPreference(value: unknown): boolean { + if (value === "inherit") { + return true; + } + + if (typeof value !== "string") { + return false; + } + + if (!/^\d+$/.test(value)) { + return false; + } + + const parsed = Number.parseInt(value, 10); + return parsed > 0; +} + +function isValidSetValue(field: ProviderBatchPatchField, value: unknown): boolean { + switch (field) { + case "is_enabled": + case "preserve_client_ip": + case "swap_cache_ttl_billing": + case "proxy_fallback_to_direct": + return typeof value === "boolean"; + case "priority": + case "weight": + case "cost_multiplier": + case "limit_5h_usd": + case "limit_daily_usd": + case "limit_weekly_usd": + case "limit_monthly_usd": + case "limit_total_usd": + case "limit_concurrent_sessions": + case "circuit_breaker_failure_threshold": + case "circuit_breaker_open_duration": + case "circuit_breaker_half_open_success_threshold": + case "max_retry_attempts": + case "first_byte_timeout_streaming_ms": + case "streaming_idle_timeout_ms": + case "request_timeout_non_streaming_ms": + return typeof value === "number" && Number.isFinite(value); + case "group_tag": + case "daily_reset_time": + case "proxy_url": + case "mcp_passthrough_url": + return typeof value === "string"; + case "group_priorities": + return isNumberRecord(value); + case "cache_ttl_preference": + return value === "inherit" || value === "5m" || value === "1h"; + case "context_1m_preference": + return value === "inherit" || value === "force_enable" || value === "disabled"; + case "daily_reset_mode": + return value === "fixed" || value === "rolling"; + case "codex_reasoning_effort_preference": + return ( + value === "inherit" || + value === "none" || + value === "minimal" || + value === "low" || + value === "medium" || + value === "high" || + value === "xhigh" + ); + case "codex_reasoning_summary_preference": + return value === "inherit" || value === "auto" || value === "detailed"; + case "codex_text_verbosity_preference": + return value === "inherit" || value === "low" || value === "medium" || value === "high"; + case "codex_parallel_tool_calls_preference": + return value === "inherit" || value === "true" || value === "false"; + case "anthropic_thinking_budget_preference": + return isThinkingBudgetPreference(value); + case "anthropic_max_tokens_preference": + return isMaxTokensPreference(value); + case "gemini_google_search_preference": + return value === "inherit" || value === "enabled" || value === "disabled"; + case "mcp_passthrough_type": + return value === "none" || value === "minimax" || value === "glm" || value === "custom"; + case "model_redirects": + return isStringRecord(value); + case "allowed_models": + return Array.isArray(value) && value.every((model) => typeof model === "string"); + case "anthropic_adaptive_thinking": + return isAdaptiveThinkingConfig(value); + default: + return false; + } +} + +function createNoChangePatch(): ProviderPatchOperation { + return { mode: "no_change" }; +} + +function createInvalidPatchShapeError( + field: ProviderBatchPatchField, + message: string +): ProviderPatchResult { + return { + ok: false, + error: { + code: PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE, + field, + message, + }, + }; +} + +function createInvalidRootPatchShapeError(message: string): ProviderPatchResult { + return { + ok: false, + error: { + code: PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE, + field: "__root__", + message, + }, + }; +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function normalizePatchField( + field: ProviderBatchPatchField, + input: ProviderPatchDraftInput +): ProviderPatchResult> { + if (input === undefined) { + return { ok: true, data: createNoChangePatch() }; + } + + if (!isRecord(input)) { + return createInvalidPatchShapeError(field, "Patch input must be an object"); + } + + const unknownKeys = Object.keys(input).filter((key) => !PATCH_INPUT_KEYS.has(key)); + if (unknownKeys.length > 0) { + return createInvalidPatchShapeError( + field, + `Patch input contains unknown keys: ${unknownKeys.join(",")}` + ); + } + + const hasSet = Object.hasOwn(input, "set"); + const hasClear = input.clear === true; + const hasNoChange = input.no_change === true; + const modeCount = [hasSet, hasClear, hasNoChange].filter(Boolean).length; + + if (modeCount !== 1) { + return createInvalidPatchShapeError(field, "Patch input must choose exactly one mode"); + } + + if (hasSet) { + if (input.set === undefined) { + return createInvalidPatchShapeError(field, "set mode requires a defined value"); + } + + if (!isValidSetValue(field, input.set)) { + return createInvalidPatchShapeError(field, "set mode value is invalid for this field"); + } + + return { ok: true, data: { mode: "set", value: input.set as T } }; + } + + if (hasNoChange) { + return { ok: true, data: createNoChangePatch() }; + } + + if (!CLEARABLE_FIELDS[field]) { + return createInvalidPatchShapeError(field, "clear mode is not supported for this field"); + } + + return { ok: true, data: { mode: "clear" } }; +} + +export function normalizeProviderBatchPatchDraft( + draft: unknown +): ProviderPatchResult { + if (!isRecord(draft) || Array.isArray(draft)) { + return createInvalidRootPatchShapeError("Patch draft must be an object"); + } + + const unknownFields = Object.keys(draft).filter( + (key) => !PATCH_FIELD_SET.has(key as ProviderBatchPatchField) + ); + if (unknownFields.length > 0) { + return createInvalidRootPatchShapeError( + `Patch draft contains unknown fields: ${unknownFields.join(",")}` + ); + } + + const typedDraft = draft as ProviderBatchPatchDraft; + + const isEnabled = normalizePatchField("is_enabled", typedDraft.is_enabled); + if (!isEnabled.ok) return isEnabled; + + const priority = normalizePatchField("priority", typedDraft.priority); + if (!priority.ok) return priority; + + const weight = normalizePatchField("weight", typedDraft.weight); + if (!weight.ok) return weight; + + const costMultiplier = normalizePatchField("cost_multiplier", typedDraft.cost_multiplier); + if (!costMultiplier.ok) return costMultiplier; + + const groupTag = normalizePatchField("group_tag", typedDraft.group_tag); + if (!groupTag.ok) return groupTag; + + const modelRedirects = normalizePatchField("model_redirects", typedDraft.model_redirects); + if (!modelRedirects.ok) return modelRedirects; + + const allowedModels = normalizePatchField("allowed_models", typedDraft.allowed_models); + if (!allowedModels.ok) return allowedModels; + + const thinkingBudget = normalizePatchField( + "anthropic_thinking_budget_preference", + typedDraft.anthropic_thinking_budget_preference + ); + if (!thinkingBudget.ok) return thinkingBudget; + + const adaptiveThinking = normalizePatchField( + "anthropic_adaptive_thinking", + typedDraft.anthropic_adaptive_thinking + ); + if (!adaptiveThinking.ok) return adaptiveThinking; + + // Routing + const preserveClientIp = normalizePatchField("preserve_client_ip", typedDraft.preserve_client_ip); + if (!preserveClientIp.ok) return preserveClientIp; + + const groupPriorities = normalizePatchField("group_priorities", typedDraft.group_priorities); + if (!groupPriorities.ok) return groupPriorities; + + const cacheTtlPref = normalizePatchField("cache_ttl_preference", typedDraft.cache_ttl_preference); + if (!cacheTtlPref.ok) return cacheTtlPref; + + const swapCacheTtlBilling = normalizePatchField( + "swap_cache_ttl_billing", + typedDraft.swap_cache_ttl_billing + ); + if (!swapCacheTtlBilling.ok) return swapCacheTtlBilling; + + const context1mPref = normalizePatchField( + "context_1m_preference", + typedDraft.context_1m_preference + ); + if (!context1mPref.ok) return context1mPref; + + const codexReasoningEffort = normalizePatchField( + "codex_reasoning_effort_preference", + typedDraft.codex_reasoning_effort_preference + ); + if (!codexReasoningEffort.ok) return codexReasoningEffort; + + const codexReasoningSummary = normalizePatchField( + "codex_reasoning_summary_preference", + typedDraft.codex_reasoning_summary_preference + ); + if (!codexReasoningSummary.ok) return codexReasoningSummary; + + const codexTextVerbosity = normalizePatchField( + "codex_text_verbosity_preference", + typedDraft.codex_text_verbosity_preference + ); + if (!codexTextVerbosity.ok) return codexTextVerbosity; + + const codexParallelToolCalls = normalizePatchField( + "codex_parallel_tool_calls_preference", + typedDraft.codex_parallel_tool_calls_preference + ); + if (!codexParallelToolCalls.ok) return codexParallelToolCalls; + + const anthropicMaxTokens = normalizePatchField( + "anthropic_max_tokens_preference", + typedDraft.anthropic_max_tokens_preference + ); + if (!anthropicMaxTokens.ok) return anthropicMaxTokens; + + const geminiGoogleSearch = normalizePatchField( + "gemini_google_search_preference", + typedDraft.gemini_google_search_preference + ); + if (!geminiGoogleSearch.ok) return geminiGoogleSearch; + + // Rate Limit + const limit5hUsd = normalizePatchField("limit_5h_usd", typedDraft.limit_5h_usd); + if (!limit5hUsd.ok) return limit5hUsd; + + const limitDailyUsd = normalizePatchField("limit_daily_usd", typedDraft.limit_daily_usd); + if (!limitDailyUsd.ok) return limitDailyUsd; + + const dailyResetMode = normalizePatchField("daily_reset_mode", typedDraft.daily_reset_mode); + if (!dailyResetMode.ok) return dailyResetMode; + + const dailyResetTime = normalizePatchField("daily_reset_time", typedDraft.daily_reset_time); + if (!dailyResetTime.ok) return dailyResetTime; + + const limitWeeklyUsd = normalizePatchField("limit_weekly_usd", typedDraft.limit_weekly_usd); + if (!limitWeeklyUsd.ok) return limitWeeklyUsd; + + const limitMonthlyUsd = normalizePatchField("limit_monthly_usd", typedDraft.limit_monthly_usd); + if (!limitMonthlyUsd.ok) return limitMonthlyUsd; + + const limitTotalUsd = normalizePatchField("limit_total_usd", typedDraft.limit_total_usd); + if (!limitTotalUsd.ok) return limitTotalUsd; + + const limitConcurrentSessions = normalizePatchField( + "limit_concurrent_sessions", + typedDraft.limit_concurrent_sessions + ); + if (!limitConcurrentSessions.ok) return limitConcurrentSessions; + + // Circuit Breaker + const cbFailureThreshold = normalizePatchField( + "circuit_breaker_failure_threshold", + typedDraft.circuit_breaker_failure_threshold + ); + if (!cbFailureThreshold.ok) return cbFailureThreshold; + + const cbOpenDuration = normalizePatchField( + "circuit_breaker_open_duration", + typedDraft.circuit_breaker_open_duration + ); + if (!cbOpenDuration.ok) return cbOpenDuration; + + const cbHalfOpenSuccess = normalizePatchField( + "circuit_breaker_half_open_success_threshold", + typedDraft.circuit_breaker_half_open_success_threshold + ); + if (!cbHalfOpenSuccess.ok) return cbHalfOpenSuccess; + + const maxRetryAttempts = normalizePatchField("max_retry_attempts", typedDraft.max_retry_attempts); + if (!maxRetryAttempts.ok) return maxRetryAttempts; + + // Network + const proxyUrl = normalizePatchField("proxy_url", typedDraft.proxy_url); + if (!proxyUrl.ok) return proxyUrl; + + const proxyFallbackToDirect = normalizePatchField( + "proxy_fallback_to_direct", + typedDraft.proxy_fallback_to_direct + ); + if (!proxyFallbackToDirect.ok) return proxyFallbackToDirect; + + const firstByteTimeout = normalizePatchField( + "first_byte_timeout_streaming_ms", + typedDraft.first_byte_timeout_streaming_ms + ); + if (!firstByteTimeout.ok) return firstByteTimeout; + + const streamingIdleTimeout = normalizePatchField( + "streaming_idle_timeout_ms", + typedDraft.streaming_idle_timeout_ms + ); + if (!streamingIdleTimeout.ok) return streamingIdleTimeout; + + const requestTimeoutNonStreaming = normalizePatchField( + "request_timeout_non_streaming_ms", + typedDraft.request_timeout_non_streaming_ms + ); + if (!requestTimeoutNonStreaming.ok) return requestTimeoutNonStreaming; + + // MCP + const mcpPassthroughType = normalizePatchField( + "mcp_passthrough_type", + typedDraft.mcp_passthrough_type + ); + if (!mcpPassthroughType.ok) return mcpPassthroughType; + + const mcpPassthroughUrl = normalizePatchField( + "mcp_passthrough_url", + typedDraft.mcp_passthrough_url + ); + if (!mcpPassthroughUrl.ok) return mcpPassthroughUrl; + + return { + ok: true, + data: { + is_enabled: isEnabled.data, + priority: priority.data, + weight: weight.data, + cost_multiplier: costMultiplier.data, + group_tag: groupTag.data, + model_redirects: modelRedirects.data, + allowed_models: allowedModels.data, + anthropic_thinking_budget_preference: thinkingBudget.data, + anthropic_adaptive_thinking: adaptiveThinking.data, + // Routing + preserve_client_ip: preserveClientIp.data, + group_priorities: groupPriorities.data, + cache_ttl_preference: cacheTtlPref.data, + swap_cache_ttl_billing: swapCacheTtlBilling.data, + context_1m_preference: context1mPref.data, + codex_reasoning_effort_preference: codexReasoningEffort.data, + codex_reasoning_summary_preference: codexReasoningSummary.data, + codex_text_verbosity_preference: codexTextVerbosity.data, + codex_parallel_tool_calls_preference: codexParallelToolCalls.data, + anthropic_max_tokens_preference: anthropicMaxTokens.data, + gemini_google_search_preference: geminiGoogleSearch.data, + // Rate Limit + limit_5h_usd: limit5hUsd.data, + limit_daily_usd: limitDailyUsd.data, + daily_reset_mode: dailyResetMode.data, + daily_reset_time: dailyResetTime.data, + limit_weekly_usd: limitWeeklyUsd.data, + limit_monthly_usd: limitMonthlyUsd.data, + limit_total_usd: limitTotalUsd.data, + limit_concurrent_sessions: limitConcurrentSessions.data, + // Circuit Breaker + circuit_breaker_failure_threshold: cbFailureThreshold.data, + circuit_breaker_open_duration: cbOpenDuration.data, + circuit_breaker_half_open_success_threshold: cbHalfOpenSuccess.data, + max_retry_attempts: maxRetryAttempts.data, + // Network + proxy_url: proxyUrl.data, + proxy_fallback_to_direct: proxyFallbackToDirect.data, + first_byte_timeout_streaming_ms: firstByteTimeout.data, + streaming_idle_timeout_ms: streamingIdleTimeout.data, + request_timeout_non_streaming_ms: requestTimeoutNonStreaming.data, + // MCP + mcp_passthrough_type: mcpPassthroughType.data, + mcp_passthrough_url: mcpPassthroughUrl.data, + }, + }; +} + +function applyPatchField( + updates: ProviderBatchApplyUpdates, + field: ProviderBatchPatchField, + patch: ProviderPatchOperation +): ProviderPatchResult { + if (patch.mode === "no_change") { + return { ok: true, data: undefined }; + } + + if (patch.mode === "set") { + switch (field) { + case "is_enabled": + updates.is_enabled = patch.value as ProviderBatchApplyUpdates["is_enabled"]; + return { ok: true, data: undefined }; + case "priority": + updates.priority = patch.value as ProviderBatchApplyUpdates["priority"]; + return { ok: true, data: undefined }; + case "weight": + updates.weight = patch.value as ProviderBatchApplyUpdates["weight"]; + return { ok: true, data: undefined }; + case "cost_multiplier": + updates.cost_multiplier = patch.value as ProviderBatchApplyUpdates["cost_multiplier"]; + return { ok: true, data: undefined }; + case "group_tag": + updates.group_tag = patch.value as ProviderBatchApplyUpdates["group_tag"]; + return { ok: true, data: undefined }; + case "model_redirects": + updates.model_redirects = patch.value as ProviderBatchApplyUpdates["model_redirects"]; + return { ok: true, data: undefined }; + case "allowed_models": + updates.allowed_models = + (patch.value as string[]).length > 0 + ? (patch.value as ProviderBatchApplyUpdates["allowed_models"]) + : null; + return { ok: true, data: undefined }; + case "anthropic_thinking_budget_preference": + updates.anthropic_thinking_budget_preference = + patch.value as ProviderBatchApplyUpdates["anthropic_thinking_budget_preference"]; + return { ok: true, data: undefined }; + case "anthropic_adaptive_thinking": + updates.anthropic_adaptive_thinking = + patch.value as ProviderBatchApplyUpdates["anthropic_adaptive_thinking"]; + return { ok: true, data: undefined }; + // Routing + case "preserve_client_ip": + updates.preserve_client_ip = patch.value as ProviderBatchApplyUpdates["preserve_client_ip"]; + return { ok: true, data: undefined }; + case "group_priorities": + updates.group_priorities = patch.value as ProviderBatchApplyUpdates["group_priorities"]; + return { ok: true, data: undefined }; + case "cache_ttl_preference": + updates.cache_ttl_preference = + patch.value as ProviderBatchApplyUpdates["cache_ttl_preference"]; + return { ok: true, data: undefined }; + case "swap_cache_ttl_billing": + updates.swap_cache_ttl_billing = + patch.value as ProviderBatchApplyUpdates["swap_cache_ttl_billing"]; + return { ok: true, data: undefined }; + case "context_1m_preference": + updates.context_1m_preference = + patch.value as ProviderBatchApplyUpdates["context_1m_preference"]; + return { ok: true, data: undefined }; + case "codex_reasoning_effort_preference": + updates.codex_reasoning_effort_preference = + patch.value as ProviderBatchApplyUpdates["codex_reasoning_effort_preference"]; + return { ok: true, data: undefined }; + case "codex_reasoning_summary_preference": + updates.codex_reasoning_summary_preference = + patch.value as ProviderBatchApplyUpdates["codex_reasoning_summary_preference"]; + return { ok: true, data: undefined }; + case "codex_text_verbosity_preference": + updates.codex_text_verbosity_preference = + patch.value as ProviderBatchApplyUpdates["codex_text_verbosity_preference"]; + return { ok: true, data: undefined }; + case "codex_parallel_tool_calls_preference": + updates.codex_parallel_tool_calls_preference = + patch.value as ProviderBatchApplyUpdates["codex_parallel_tool_calls_preference"]; + return { ok: true, data: undefined }; + case "anthropic_max_tokens_preference": + updates.anthropic_max_tokens_preference = + patch.value as ProviderBatchApplyUpdates["anthropic_max_tokens_preference"]; + return { ok: true, data: undefined }; + case "gemini_google_search_preference": + updates.gemini_google_search_preference = + patch.value as ProviderBatchApplyUpdates["gemini_google_search_preference"]; + return { ok: true, data: undefined }; + // Rate Limit + case "limit_5h_usd": + updates.limit_5h_usd = patch.value as ProviderBatchApplyUpdates["limit_5h_usd"]; + return { ok: true, data: undefined }; + case "limit_daily_usd": + updates.limit_daily_usd = patch.value as ProviderBatchApplyUpdates["limit_daily_usd"]; + return { ok: true, data: undefined }; + case "daily_reset_mode": + updates.daily_reset_mode = patch.value as ProviderBatchApplyUpdates["daily_reset_mode"]; + return { ok: true, data: undefined }; + case "daily_reset_time": + updates.daily_reset_time = patch.value as ProviderBatchApplyUpdates["daily_reset_time"]; + return { ok: true, data: undefined }; + case "limit_weekly_usd": + updates.limit_weekly_usd = patch.value as ProviderBatchApplyUpdates["limit_weekly_usd"]; + return { ok: true, data: undefined }; + case "limit_monthly_usd": + updates.limit_monthly_usd = patch.value as ProviderBatchApplyUpdates["limit_monthly_usd"]; + return { ok: true, data: undefined }; + case "limit_total_usd": + updates.limit_total_usd = patch.value as ProviderBatchApplyUpdates["limit_total_usd"]; + return { ok: true, data: undefined }; + case "limit_concurrent_sessions": + updates.limit_concurrent_sessions = + patch.value as ProviderBatchApplyUpdates["limit_concurrent_sessions"]; + return { ok: true, data: undefined }; + // Circuit Breaker + case "circuit_breaker_failure_threshold": + updates.circuit_breaker_failure_threshold = + patch.value as ProviderBatchApplyUpdates["circuit_breaker_failure_threshold"]; + return { ok: true, data: undefined }; + case "circuit_breaker_open_duration": + updates.circuit_breaker_open_duration = + patch.value as ProviderBatchApplyUpdates["circuit_breaker_open_duration"]; + return { ok: true, data: undefined }; + case "circuit_breaker_half_open_success_threshold": + updates.circuit_breaker_half_open_success_threshold = + patch.value as ProviderBatchApplyUpdates["circuit_breaker_half_open_success_threshold"]; + return { ok: true, data: undefined }; + case "max_retry_attempts": + updates.max_retry_attempts = patch.value as ProviderBatchApplyUpdates["max_retry_attempts"]; + return { ok: true, data: undefined }; + // Network + case "proxy_url": + updates.proxy_url = patch.value as ProviderBatchApplyUpdates["proxy_url"]; + return { ok: true, data: undefined }; + case "proxy_fallback_to_direct": + updates.proxy_fallback_to_direct = + patch.value as ProviderBatchApplyUpdates["proxy_fallback_to_direct"]; + return { ok: true, data: undefined }; + case "first_byte_timeout_streaming_ms": + updates.first_byte_timeout_streaming_ms = + patch.value as ProviderBatchApplyUpdates["first_byte_timeout_streaming_ms"]; + return { ok: true, data: undefined }; + case "streaming_idle_timeout_ms": + updates.streaming_idle_timeout_ms = + patch.value as ProviderBatchApplyUpdates["streaming_idle_timeout_ms"]; + return { ok: true, data: undefined }; + case "request_timeout_non_streaming_ms": + updates.request_timeout_non_streaming_ms = + patch.value as ProviderBatchApplyUpdates["request_timeout_non_streaming_ms"]; + return { ok: true, data: undefined }; + // MCP + case "mcp_passthrough_type": + updates.mcp_passthrough_type = + patch.value as ProviderBatchApplyUpdates["mcp_passthrough_type"]; + return { ok: true, data: undefined }; + case "mcp_passthrough_url": + updates.mcp_passthrough_url = + patch.value as ProviderBatchApplyUpdates["mcp_passthrough_url"]; + return { ok: true, data: undefined }; + default: + return createInvalidPatchShapeError(field, "Unsupported patch field"); + } + } + + // clear mode + switch (field) { + case "group_tag": + updates.group_tag = null; + return { ok: true, data: undefined }; + case "model_redirects": + updates.model_redirects = null; + return { ok: true, data: undefined }; + case "allowed_models": + updates.allowed_models = null; + return { ok: true, data: undefined }; + case "anthropic_thinking_budget_preference": + updates.anthropic_thinking_budget_preference = "inherit"; + return { ok: true, data: undefined }; + case "anthropic_adaptive_thinking": + updates.anthropic_adaptive_thinking = null; + return { ok: true, data: undefined }; + // Routing - preference fields clear to "inherit" + case "cache_ttl_preference": + updates.cache_ttl_preference = "inherit"; + return { ok: true, data: undefined }; + case "context_1m_preference": + updates.context_1m_preference = "inherit"; + return { ok: true, data: undefined }; + case "codex_reasoning_effort_preference": + updates.codex_reasoning_effort_preference = "inherit"; + return { ok: true, data: undefined }; + case "codex_reasoning_summary_preference": + updates.codex_reasoning_summary_preference = "inherit"; + return { ok: true, data: undefined }; + case "codex_text_verbosity_preference": + updates.codex_text_verbosity_preference = "inherit"; + return { ok: true, data: undefined }; + case "codex_parallel_tool_calls_preference": + updates.codex_parallel_tool_calls_preference = "inherit"; + return { ok: true, data: undefined }; + case "anthropic_max_tokens_preference": + updates.anthropic_max_tokens_preference = "inherit"; + return { ok: true, data: undefined }; + case "gemini_google_search_preference": + updates.gemini_google_search_preference = "inherit"; + return { ok: true, data: undefined }; + // Routing - nullable fields clear to null + case "group_priorities": + updates.group_priorities = null; + return { ok: true, data: undefined }; + // Rate Limit - nullable number fields clear to null + case "limit_5h_usd": + updates.limit_5h_usd = null; + return { ok: true, data: undefined }; + case "limit_daily_usd": + updates.limit_daily_usd = null; + return { ok: true, data: undefined }; + case "limit_weekly_usd": + updates.limit_weekly_usd = null; + return { ok: true, data: undefined }; + case "limit_monthly_usd": + updates.limit_monthly_usd = null; + return { ok: true, data: undefined }; + case "limit_total_usd": + updates.limit_total_usd = null; + return { ok: true, data: undefined }; + // Circuit Breaker + case "max_retry_attempts": + updates.max_retry_attempts = null; + return { ok: true, data: undefined }; + // Network + case "proxy_url": + updates.proxy_url = null; + return { ok: true, data: undefined }; + // MCP + case "mcp_passthrough_url": + updates.mcp_passthrough_url = null; + return { ok: true, data: undefined }; + default: + return createInvalidPatchShapeError(field, "clear mode is not supported for this field"); + } +} + +export function buildProviderBatchApplyUpdates( + patch: ProviderBatchPatch +): ProviderPatchResult { + const updates: ProviderBatchApplyUpdates = {}; + + const operations: Array<[ProviderBatchPatchField, ProviderPatchOperation]> = [ + ["is_enabled", patch.is_enabled], + ["priority", patch.priority], + ["weight", patch.weight], + ["cost_multiplier", patch.cost_multiplier], + ["group_tag", patch.group_tag], + ["model_redirects", patch.model_redirects], + ["allowed_models", patch.allowed_models], + ["anthropic_thinking_budget_preference", patch.anthropic_thinking_budget_preference], + ["anthropic_adaptive_thinking", patch.anthropic_adaptive_thinking], + // Routing + ["preserve_client_ip", patch.preserve_client_ip], + ["group_priorities", patch.group_priorities], + ["cache_ttl_preference", patch.cache_ttl_preference], + ["swap_cache_ttl_billing", patch.swap_cache_ttl_billing], + ["context_1m_preference", patch.context_1m_preference], + ["codex_reasoning_effort_preference", patch.codex_reasoning_effort_preference], + ["codex_reasoning_summary_preference", patch.codex_reasoning_summary_preference], + ["codex_text_verbosity_preference", patch.codex_text_verbosity_preference], + ["codex_parallel_tool_calls_preference", patch.codex_parallel_tool_calls_preference], + ["anthropic_max_tokens_preference", patch.anthropic_max_tokens_preference], + ["gemini_google_search_preference", patch.gemini_google_search_preference], + // Rate Limit + ["limit_5h_usd", patch.limit_5h_usd], + ["limit_daily_usd", patch.limit_daily_usd], + ["daily_reset_mode", patch.daily_reset_mode], + ["daily_reset_time", patch.daily_reset_time], + ["limit_weekly_usd", patch.limit_weekly_usd], + ["limit_monthly_usd", patch.limit_monthly_usd], + ["limit_total_usd", patch.limit_total_usd], + ["limit_concurrent_sessions", patch.limit_concurrent_sessions], + // Circuit Breaker + ["circuit_breaker_failure_threshold", patch.circuit_breaker_failure_threshold], + ["circuit_breaker_open_duration", patch.circuit_breaker_open_duration], + [ + "circuit_breaker_half_open_success_threshold", + patch.circuit_breaker_half_open_success_threshold, + ], + ["max_retry_attempts", patch.max_retry_attempts], + // Network + ["proxy_url", patch.proxy_url], + ["proxy_fallback_to_direct", patch.proxy_fallback_to_direct], + ["first_byte_timeout_streaming_ms", patch.first_byte_timeout_streaming_ms], + ["streaming_idle_timeout_ms", patch.streaming_idle_timeout_ms], + ["request_timeout_non_streaming_ms", patch.request_timeout_non_streaming_ms], + // MCP + ["mcp_passthrough_type", patch.mcp_passthrough_type], + ["mcp_passthrough_url", patch.mcp_passthrough_url], + ]; + + for (const [field, operation] of operations) { + const applyResult = applyPatchField(updates, field, operation); + if (!applyResult.ok) { + return applyResult; + } + } + + return { ok: true, data: updates }; +} + +export function hasProviderBatchPatchChanges(patch: ProviderBatchPatch): boolean { + return ( + patch.is_enabled.mode !== "no_change" || + patch.priority.mode !== "no_change" || + patch.weight.mode !== "no_change" || + patch.cost_multiplier.mode !== "no_change" || + patch.group_tag.mode !== "no_change" || + patch.model_redirects.mode !== "no_change" || + patch.allowed_models.mode !== "no_change" || + patch.anthropic_thinking_budget_preference.mode !== "no_change" || + patch.anthropic_adaptive_thinking.mode !== "no_change" || + // Routing + patch.preserve_client_ip.mode !== "no_change" || + patch.group_priorities.mode !== "no_change" || + patch.cache_ttl_preference.mode !== "no_change" || + patch.swap_cache_ttl_billing.mode !== "no_change" || + patch.context_1m_preference.mode !== "no_change" || + patch.codex_reasoning_effort_preference.mode !== "no_change" || + patch.codex_reasoning_summary_preference.mode !== "no_change" || + patch.codex_text_verbosity_preference.mode !== "no_change" || + patch.codex_parallel_tool_calls_preference.mode !== "no_change" || + patch.anthropic_max_tokens_preference.mode !== "no_change" || + patch.gemini_google_search_preference.mode !== "no_change" || + // Rate Limit + patch.limit_5h_usd.mode !== "no_change" || + patch.limit_daily_usd.mode !== "no_change" || + patch.daily_reset_mode.mode !== "no_change" || + patch.daily_reset_time.mode !== "no_change" || + patch.limit_weekly_usd.mode !== "no_change" || + patch.limit_monthly_usd.mode !== "no_change" || + patch.limit_total_usd.mode !== "no_change" || + patch.limit_concurrent_sessions.mode !== "no_change" || + // Circuit Breaker + patch.circuit_breaker_failure_threshold.mode !== "no_change" || + patch.circuit_breaker_open_duration.mode !== "no_change" || + patch.circuit_breaker_half_open_success_threshold.mode !== "no_change" || + patch.max_retry_attempts.mode !== "no_change" || + // Network + patch.proxy_url.mode !== "no_change" || + patch.proxy_fallback_to_direct.mode !== "no_change" || + patch.first_byte_timeout_streaming_ms.mode !== "no_change" || + patch.streaming_idle_timeout_ms.mode !== "no_change" || + patch.request_timeout_non_streaming_ms.mode !== "no_change" || + // MCP + patch.mcp_passthrough_type.mode !== "no_change" || + patch.mcp_passthrough_url.mode !== "no_change" + ); +} + +export function prepareProviderBatchApplyUpdates( + draft: unknown +): ProviderPatchResult { + const normalized = normalizeProviderBatchPatchDraft(draft); + if (!normalized.ok) { + return normalized; + } + + return buildProviderBatchApplyUpdates(normalized.data); +} diff --git a/src/lib/providers/undo-store.ts b/src/lib/providers/undo-store.ts new file mode 100644 index 000000000..db4261013 --- /dev/null +++ b/src/lib/providers/undo-store.ts @@ -0,0 +1,81 @@ +import "server-only"; + +import { logger } from "@/lib/logger"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; +import { RedisKVStore } from "@/lib/redis/redis-kv-store"; + +const UNDO_SNAPSHOT_TTL_SECONDS = 30; + +export interface UndoSnapshot { + operationId: string; + operationType: "batch_edit" | "single_edit" | "single_delete"; + preimage: unknown; + providerIds: number[]; + createdAt: string; +} + +export interface StoreUndoResult { + undoAvailable: boolean; + undoToken?: string; + expiresAt?: string; +} + +export type ConsumeUndoResult = + | { + ok: true; + snapshot: UndoSnapshot; + } + | { + ok: false; + code: "UNDO_EXPIRED" | "UNDO_CONFLICT"; + }; + +const store = new RedisKVStore({ + prefix: "cch:prov:undo:", + defaultTtlSeconds: UNDO_SNAPSHOT_TTL_SECONDS, +}); + +export async function storeUndoSnapshot(snapshot: UndoSnapshot): Promise { + try { + const undoToken = crypto.randomUUID(); + const expiresAtMs = Date.now() + UNDO_SNAPSHOT_TTL_SECONDS * 1000; + + const stored = await store.set(undoToken, snapshot); + if (!stored) { + logger.warn("[undo-store] Failed to persist undo snapshot; undo unavailable", { + operationId: snapshot.operationId, + }); + return { undoAvailable: false }; + } + + return { + undoAvailable: true, + undoToken, + expiresAt: new Date(expiresAtMs).toISOString(), + }; + } catch { + return { undoAvailable: false }; + } +} + +export async function consumeUndoToken(token: string): Promise { + try { + const snapshot = await store.getAndDelete(token); + if (!snapshot) { + return { + ok: false, + code: PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED, + }; + } + + return { + ok: true, + snapshot, + }; + } catch { + return { + ok: false, + code: PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED, + }; + } +} diff --git a/src/lib/redis/redis-kv-store.ts b/src/lib/redis/redis-kv-store.ts new file mode 100644 index 000000000..bd3787d80 --- /dev/null +++ b/src/lib/redis/redis-kv-store.ts @@ -0,0 +1,142 @@ +import "server-only"; + +import type Redis from "ioredis"; +import { logger } from "@/lib/logger"; +import { getRedisClient } from "./client"; + +type RedisKVClient = Pick & { + // Redis EVAL for Lua scripts (atomic getAndDelete) + eval(...args: [script: string, numkeys: number, ...keys: string[]]): Promise; +}; + +export interface RedisKVStoreOptions { + prefix: string; + defaultTtlSeconds: number; + redisClient?: RedisKVClient | null; +} + +function toLogError(error: unknown): string { + return error instanceof Error ? error.message : String(error); +} + +// Atomic GET + DEL via Lua script -- prevents TOCTOU race where two concurrent +// callers both GET the same single-use token before either DELetes it. +const LUA_GET_AND_DEL = ` +local val = redis.call('GET', KEYS[1]) +if val then redis.call('DEL', KEYS[1]) end +return val`; + +export class RedisKVStore { + private readonly prefix: string; + private readonly defaultTtlSeconds: number; + private readonly injectedClient?: RedisKVClient | null; + + constructor(options: RedisKVStoreOptions) { + this.prefix = options.prefix; + this.defaultTtlSeconds = options.defaultTtlSeconds; + this.injectedClient = options.redisClient; + } + + private resolveRedisClient(): RedisKVClient | null { + if (this.injectedClient !== undefined) { + return this.injectedClient; + } + return getRedisClient({ allowWhenRateLimitDisabled: true }) as RedisKVClient | null; + } + + private getReadyRedis(): RedisKVClient | null { + const redis = this.resolveRedisClient(); + if (!redis || redis.status !== "ready") { + return null; + } + return redis; + } + + private buildKey(key: string): string { + return `${this.prefix}${key}`; + } + + async set(key: string, value: T, ttlSeconds?: number): Promise { + const redis = this.getReadyRedis(); + if (!redis) { + return false; + } + + const ttl = ttlSeconds ?? this.defaultTtlSeconds; + try { + await redis.setex(this.buildKey(key), ttl, JSON.stringify(value)); + return true; + } catch (error) { + logger.error("[RedisKVStore] Failed to set", { + error: toLogError(error), + prefix: this.prefix, + key, + }); + return false; + } + } + + async get(key: string): Promise { + const redis = this.getReadyRedis(); + if (!redis) { + return null; + } + + try { + const raw = await redis.get(this.buildKey(key)); + if (!raw) { + return null; + } + return JSON.parse(raw) as T; + } catch (error) { + logger.error("[RedisKVStore] Failed to get", { + error: toLogError(error), + prefix: this.prefix, + key, + }); + return null; + } + } + + async getAndDelete(key: string): Promise { + const redis = this.getReadyRedis(); + if (!redis) { + return null; + } + + const fullKey = this.buildKey(key); + try { + const raw = (await redis.eval(LUA_GET_AND_DEL, 1, fullKey)) as string | null; + if (!raw) { + return null; + } + return JSON.parse(raw) as T; + } catch (error) { + logger.error("[RedisKVStore] Failed to getAndDelete", { + error: toLogError(error), + prefix: this.prefix, + key, + }); + return null; + } + } + + async delete(key: string): Promise { + const redis = this.getReadyRedis(); + if (!redis) { + return false; + } + + try { + const deleted = await redis.del(this.buildKey(key)); + return deleted > 0; + } catch (error) { + logger.error("[RedisKVStore] Failed to delete", { + error: toLogError(error), + prefix: this.prefix, + key, + }); + return false; + } + } +} diff --git a/src/lib/security/auth-response-headers.ts b/src/lib/security/auth-response-headers.ts new file mode 100644 index 000000000..a9a7ef615 --- /dev/null +++ b/src/lib/security/auth-response-headers.ts @@ -0,0 +1,22 @@ +import type { NextResponse } from "next/server"; +import { withNoStoreHeaders } from "@/lib/auth"; +import { getEnvConfig } from "@/lib/config/env.schema"; +import { buildSecurityHeaders } from "@/lib/security/security-headers"; + +export function applySecurityHeaders(response: NextResponse): NextResponse { + const env = getEnvConfig(); + const headers = buildSecurityHeaders({ + enableHsts: env.ENABLE_SECURE_COOKIES, + cspMode: "report-only", + }); + + for (const [key, value] of Object.entries(headers)) { + response.headers.set(key, value); + } + + return response; +} + +export function withAuthResponseHeaders(response: NextResponse): NextResponse { + return applySecurityHeaders(withNoStoreHeaders(response)); +} diff --git a/src/lib/security/constant-time-compare.ts b/src/lib/security/constant-time-compare.ts new file mode 100644 index 000000000..d35b9aa61 --- /dev/null +++ b/src/lib/security/constant-time-compare.ts @@ -0,0 +1,27 @@ +import { timingSafeEqual } from "node:crypto"; + +/** + * Constant-time string comparison to prevent timing attacks. + * + * Uses crypto.timingSafeEqual internally. When lengths differ, a dummy + * comparison is still performed so the total CPU time does not leak + * length information. + */ +export function constantTimeEqual(a: string, b: string): boolean { + const bufA = Buffer.from(a, "utf-8"); + const bufB = Buffer.from(b, "utf-8"); + + if (bufA.length !== bufB.length) { + // Pad both to the same length so the dummy comparison time does not + // leak which side is shorter (attacker may control either one). + const padLen = Math.max(bufA.length, bufB.length); + const padA = Buffer.alloc(padLen); + const padB = Buffer.alloc(padLen); + bufA.copy(padA); + bufB.copy(padB); + timingSafeEqual(padA, padB); + return false; + } + + return timingSafeEqual(bufA, bufB); +} diff --git a/src/lib/security/csrf-origin-guard.ts b/src/lib/security/csrf-origin-guard.ts new file mode 100644 index 000000000..90e1f9fe2 --- /dev/null +++ b/src/lib/security/csrf-origin-guard.ts @@ -0,0 +1,66 @@ +export interface CsrfGuardConfig { + allowedOrigins: string[]; + allowSameOrigin: boolean; + enforceInDevelopment: boolean; +} + +export interface CsrfGuardResult { + allowed: boolean; + reason?: string; +} + +export interface CsrfGuardRequest { + headers: { + get(name: string): string | null; + }; +} + +function normalizeOrigin(origin: string): string { + return origin.trim().toLowerCase(); +} + +function isDevelopmentRuntime(): boolean { + if (typeof process === "undefined") return false; + return process.env.NODE_ENV === "development"; +} + +export function createCsrfOriginGuard(config: CsrfGuardConfig) { + const allowSameOrigin = config.allowSameOrigin ?? true; + const enforceInDevelopment = config.enforceInDevelopment ?? false; + const allowedOrigins = new Set( + (config.allowedOrigins ?? []).map(normalizeOrigin).filter((origin) => origin.length > 0) + ); + + return { + check(request: CsrfGuardRequest): CsrfGuardResult { + if (isDevelopmentRuntime() && !enforceInDevelopment) { + return { allowed: true, reason: "csrf_guard_bypassed_in_development" }; + } + + const fetchSite = request.headers.get("sec-fetch-site")?.trim().toLowerCase() ?? null; + if (fetchSite === "same-origin" && allowSameOrigin) { + return { allowed: true }; + } + + const originValue = request.headers.get("origin"); + const origin = originValue ? normalizeOrigin(originValue) : null; + + if (!origin) { + if (fetchSite === "cross-site") { + return { + allowed: false, + reason: "Cross-site request blocked: missing Origin header", + }; + } + + return { allowed: true }; + } + + if (allowedOrigins.has(origin)) { + return { allowed: true }; + } + + return { allowed: false, reason: `Origin ${origin} not in allowlist` }; + }, + }; +} diff --git a/src/lib/security/login-abuse-policy.ts b/src/lib/security/login-abuse-policy.ts new file mode 100644 index 000000000..b0ea9bcc8 --- /dev/null +++ b/src/lib/security/login-abuse-policy.ts @@ -0,0 +1,249 @@ +export interface LoginAbuseConfig { + maxAttemptsPerIp: number; + maxAttemptsPerKey: number; + windowSeconds: number; + lockoutSeconds: number; +} + +export interface LoginAbuseDecision { + allowed: boolean; + retryAfterSeconds?: number; + reason?: string; +} + +export const DEFAULT_LOGIN_ABUSE_CONFIG: LoginAbuseConfig = { + maxAttemptsPerIp: 10, + maxAttemptsPerKey: 10, + windowSeconds: 300, + lockoutSeconds: 900, +}; + +type AttemptRecord = { + count: number; + firstAttempt: number; + lockedUntil?: number; +}; + +const MAX_TRACKED_ENTRIES = 10_000; +const SWEEP_INTERVAL_MS = 60_000; + +export class LoginAbusePolicy { + private attempts = new Map(); + private config: LoginAbuseConfig; + private lastSweepAt = 0; + + constructor(config?: Partial) { + this.config = { + ...DEFAULT_LOGIN_ABUSE_CONFIG, + ...config, + }; + } + + private sweepStaleEntries(now: number): void { + if (now - this.lastSweepAt < SWEEP_INTERVAL_MS) { + return; + } + this.lastSweepAt = now; + + for (const [key, record] of this.attempts) { + if (record.lockedUntil != null) { + if (record.lockedUntil <= now) { + this.attempts.delete(key); + } + } else if (this.isWindowExpired(record, now)) { + this.attempts.delete(key); + } + } + + if (this.attempts.size > MAX_TRACKED_ENTRIES) { + const excess = this.attempts.size - MAX_TRACKED_ENTRIES; + const iterator = this.attempts.keys(); + for (let i = 0; i < excess; i++) { + const next = iterator.next(); + if (next.done) break; + this.attempts.delete(next.value); + } + } + } + + check(ip: string, key?: string): LoginAbuseDecision { + const now = Date.now(); + this.sweepStaleEntries(now); + + const ipDecision = this.checkScope({ + scopeKey: this.toIpScope(ip), + threshold: this.config.maxAttemptsPerIp, + reason: "ip_rate_limited", + now, + }); + + if (!ipDecision.allowed || !key) { + return ipDecision; + } + + return this.checkScope({ + scopeKey: this.toKeyScope(key), + threshold: this.config.maxAttemptsPerKey, + reason: "key_rate_limited", + now, + }); + } + + recordFailure(ip: string, key?: string): void { + const now = Date.now(); + + this.recordFailureForScope({ + scopeKey: this.toIpScope(ip), + threshold: this.config.maxAttemptsPerIp, + now, + }); + + if (!key) { + return; + } + + this.recordFailureForScope({ + scopeKey: this.toKeyScope(key), + threshold: this.config.maxAttemptsPerKey, + now, + }); + } + + recordSuccess(ip: string, key?: string): void { + this.reset(ip, key); + } + + reset(ip: string, key?: string): void { + this.attempts.delete(this.toIpScope(ip)); + + if (!key) { + return; + } + + this.attempts.delete(this.toKeyScope(key)); + } + + private checkScope(params: { + scopeKey: string; + threshold: number; + reason: string; + now: number; + }): LoginAbuseDecision { + const { scopeKey, threshold, reason, now } = params; + const record = this.attempts.get(scopeKey); + + if (!record) { + return { allowed: true }; + } + + if (record.lockedUntil != null) { + if (record.lockedUntil > now) { + return { + allowed: false, + retryAfterSeconds: this.calculateRetryAfterSeconds(record.lockedUntil, now), + reason, + }; + } + + this.attempts.delete(scopeKey); + return { allowed: true }; + } + + if (this.isWindowExpired(record, now)) { + this.attempts.delete(scopeKey); + return { allowed: true }; + } + + if (record.count >= threshold) { + const lockedUntil = now + this.config.lockoutSeconds * 1000; + // LRU bump: delete + re-insert so locked entries survive eviction + this.attempts.delete(scopeKey); + this.attempts.set(scopeKey, { ...record, lockedUntil }); + return { + allowed: false, + retryAfterSeconds: this.calculateRetryAfterSeconds(lockedUntil, now), + reason, + }; + } + + // LRU bump: delete + re-insert moves entry to end of Map iteration order, + // so the eviction loop in sweepStaleEntries removes least-recently-used first + this.attempts.delete(scopeKey); + this.attempts.set(scopeKey, record); + + return { allowed: true }; + } + + private recordFailureForScope(params: { + scopeKey: string; + threshold: number; + now: number; + }): void { + const { scopeKey, threshold, now } = params; + const record = this.attempts.get(scopeKey); + + if (!record) { + this.attempts.set(scopeKey, this.createFirstRecord(now, threshold)); + return; + } + + if (record.lockedUntil != null) { + if (record.lockedUntil > now) { + return; + } + + this.attempts.delete(scopeKey); + this.attempts.set(scopeKey, this.createFirstRecord(now, threshold)); + return; + } + + if (this.isWindowExpired(record, now)) { + this.attempts.delete(scopeKey); + this.attempts.set(scopeKey, this.createFirstRecord(now, threshold)); + return; + } + + const nextCount = record.count + 1; + const nextRecord: AttemptRecord = { + count: nextCount, + firstAttempt: record.firstAttempt, + }; + + if (nextCount >= threshold) { + nextRecord.lockedUntil = now + this.config.lockoutSeconds * 1000; + } + + // LRU bump: delete + re-insert moves entry to end of iteration order + this.attempts.delete(scopeKey); + this.attempts.set(scopeKey, nextRecord); + } + + private isWindowExpired(record: AttemptRecord, now: number): boolean { + return now - record.firstAttempt >= this.config.windowSeconds * 1000; + } + + private calculateRetryAfterSeconds(lockedUntil: number, now: number): number { + return Math.max(0, Math.ceil((lockedUntil - now) / 1000)); + } + + private createFirstRecord(now: number, threshold: number): AttemptRecord { + const firstRecord: AttemptRecord = { + count: 1, + firstAttempt: now, + }; + + if (threshold <= 1) { + firstRecord.lockedUntil = now + this.config.lockoutSeconds * 1000; + } + + return firstRecord; + } + + private toIpScope(ip: string): string { + return `ip:${ip}`; + } + + private toKeyScope(key: string): string { + return `key:${key}`; + } +} diff --git a/src/lib/security/security-headers.ts b/src/lib/security/security-headers.ts new file mode 100644 index 000000000..93c3ec44b --- /dev/null +++ b/src/lib/security/security-headers.ts @@ -0,0 +1,63 @@ +export interface SecurityHeadersConfig { + enableHsts: boolean; + cspMode: "report-only" | "enforce" | "disabled"; + cspReportUri?: string; + hstsMaxAge: number; + frameOptions: "DENY" | "SAMEORIGIN"; +} + +export const DEFAULT_SECURITY_HEADERS_CONFIG: SecurityHeadersConfig = { + enableHsts: false, + cspMode: "report-only", + hstsMaxAge: 31536000, + frameOptions: "DENY", +}; + +function isValidCspReportUri(uri: string): boolean { + const trimmed = uri.trim(); + if (!trimmed || trimmed.includes(";") || trimmed.includes(",") || /\s/.test(trimmed)) { + return false; + } + try { + new URL(trimmed); + return true; + } catch { + return false; + } +} + +const DEFAULT_CSP_VALUE = + "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' " + + "'unsafe-inline'; img-src 'self' data: blob:; connect-src 'self'; font-src 'self' data:; " + + "frame-ancestors 'none'"; + +export function buildSecurityHeaders( + config?: Partial +): Record { + const merged = { ...DEFAULT_SECURITY_HEADERS_CONFIG, ...config }; + const headers: Record = {}; + + headers["X-Content-Type-Options"] = "nosniff"; + headers["X-Frame-Options"] = merged.frameOptions; + headers["Referrer-Policy"] = "strict-origin-when-cross-origin"; + headers["X-DNS-Prefetch-Control"] = "off"; + + if (merged.enableHsts) { + headers["Strict-Transport-Security"] = `max-age=${merged.hstsMaxAge}; includeSubDomains`; + } + + if (merged.cspMode !== "disabled") { + const headerName = + merged.cspMode === "report-only" + ? "Content-Security-Policy-Report-Only" + : "Content-Security-Policy"; + + if (merged.cspReportUri && isValidCspReportUri(merged.cspReportUri)) { + headers[headerName] = `${DEFAULT_CSP_VALUE}; report-uri ${merged.cspReportUri}`; + } else { + headers[headerName] = DEFAULT_CSP_VALUE; + } + } + + return headers; +} diff --git a/src/proxy.ts b/src/proxy.ts index 9157a1ceb..05cae00ac 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -2,7 +2,7 @@ import { type NextRequest, NextResponse } from "next/server"; import createMiddleware from "next-intl/middleware"; import type { Locale } from "@/i18n/config"; import { routing } from "@/i18n/routing"; -import { validateKey } from "@/lib/auth"; +import { AUTH_COOKIE_NAME } from "@/lib/auth"; import { isDevelopment } from "@/lib/config/env.schema"; import { logger } from "@/lib/logger"; @@ -10,16 +10,12 @@ import { logger } from "@/lib/logger"; // Note: These paths will be automatically prefixed with locale by next-intl middleware const PUBLIC_PATH_PATTERNS = ["/login", "/usage-doc", "/api/auth/login", "/api/auth/logout"]; -// Paths that allow read-only access (for canLoginWebUi=false keys) -// These paths bypass the canLoginWebUi check in validateKey -const READ_ONLY_PATH_PATTERNS = ["/my-usage"]; - const API_PROXY_PATH = "/v1"; // Create next-intl middleware for locale detection and routing const intlMiddleware = createMiddleware(routing); -async function proxyHandler(request: NextRequest) { +function proxyHandler(request: NextRequest) { const method = request.method; const pathname = request.nextUrl.pathname; @@ -61,13 +57,12 @@ async function proxyHandler(request: NextRequest) { return localeResponse; } - // Check if current path allows read-only access (for canLoginWebUi=false keys) - const isReadOnlyPath = READ_ONLY_PATH_PATTERNS.some( - (pattern) => pathWithoutLocale === pattern || pathWithoutLocale.startsWith(`${pattern}/`) - ); - - // Check authentication for protected routes - const authToken = request.cookies.get("auth-token"); + // Check authentication for protected routes (cookie existence only). + // Full session validation (Redis lookup, key permissions, expiry) is handled + // by downstream layouts (dashboard/layout.tsx, etc.) which run in Node.js + // runtime with guaranteed Redis/DB access. This avoids a death loop where + // the proxy deletes the cookie on transient validation failures. + const authToken = request.cookies.get(AUTH_COOKIE_NAME); if (!authToken) { // Not authenticated, redirect to login page @@ -79,21 +74,7 @@ async function proxyHandler(request: NextRequest) { return NextResponse.redirect(url); } - // Validate key permissions (canLoginWebUi, isEnabled, expiresAt, etc.) - const session = await validateKey(authToken.value, { allowReadOnlyAccess: isReadOnlyPath }); - if (!session) { - // Invalid key or insufficient permissions, clear cookie and redirect to login - const url = request.nextUrl.clone(); - // Preserve locale in redirect - const locale = isLocaleInPath ? potentialLocale : routing.defaultLocale; - url.pathname = `/${locale}/login`; - url.searchParams.set("from", pathWithoutLocale || "/dashboard"); - const response = NextResponse.redirect(url); - response.cookies.delete("auth-token"); - return response; - } - - // Authentication passed, return locale response + // Cookie exists - pass through to layout for full validation return localeResponse; } diff --git a/src/repository/index.ts b/src/repository/index.ts index e03e1b7e0..a4f28f4fa 100644 --- a/src/repository/index.ts +++ b/src/repository/index.ts @@ -42,6 +42,8 @@ export { findProviderById, findProviderList, getDistinctProviderGroups, + restoreProvider, + restoreProvidersBatch, updateProvider, } from "./provider"; export type { ProviderEndpointProbeTarget } from "./provider-endpoints"; diff --git a/src/repository/provider.ts b/src/repository/provider.ts index 4d6b24fad..5c24866e2 100644 --- a/src/repository/provider.ts +++ b/src/repository/provider.ts @@ -7,7 +7,12 @@ import { getCachedProviders } from "@/lib/cache/provider-cache"; import { resetEndpointCircuit } from "@/lib/endpoint-circuit-breaker"; import { logger } from "@/lib/logger"; import { resolveSystemTimezone } from "@/lib/utils/timezone"; -import type { CreateProviderData, Provider, UpdateProviderData } from "@/types/provider"; +import type { + AnthropicAdaptiveThinkingConfig, + CreateProviderData, + Provider, + UpdateProviderData, +} from "@/types/provider"; import { toProvider } from "./_shared/transformers"; import { ensureProviderEndpointExistsForUrl, @@ -16,6 +21,150 @@ import { tryDeleteProviderVendorIfEmpty, } from "./provider-endpoints"; +type ProviderTransaction = Parameters[0]>[0]; + +const PROVIDER_RESTORE_MAX_AGE_MS = 60_000; +const ENDPOINT_RESTORE_TIME_TOLERANCE_MS = 1_000; + +interface ProviderRestoreCandidate { + id: number; + providerVendorId: number | null; + providerType: Provider["providerType"]; + url: string; + deletedAt: Date | null; +} + +async function restoreSoftDeletedEndpointForProvider( + tx: ProviderTransaction, + provider: ProviderRestoreCandidate, + now: Date +): Promise { + if (provider.providerVendorId == null || !provider.url || !provider.deletedAt) { + return; + } + + const trimmedUrl = provider.url.trim(); + if (!trimmedUrl) { + return; + } + + const [activeReference] = await tx + .select({ id: providers.id }) + .from(providers) + .where( + and( + eq(providers.providerVendorId, provider.providerVendorId), + eq(providers.providerType, provider.providerType), + eq(providers.url, trimmedUrl), + eq(providers.isEnabled, true), + isNull(providers.deletedAt), + ne(providers.id, provider.id) + ) + ) + .limit(1); + + if (activeReference) { + return; + } + + const [activeEndpoint] = await tx + .select({ id: providerEndpoints.id }) + .from(providerEndpoints) + .where( + and( + eq(providerEndpoints.vendorId, provider.providerVendorId), + eq(providerEndpoints.providerType, provider.providerType), + eq(providerEndpoints.url, trimmedUrl), + isNull(providerEndpoints.deletedAt) + ) + ) + .limit(1); + + if (activeEndpoint) { + return; + } + + const lowerBound = new Date(provider.deletedAt.getTime() - ENDPOINT_RESTORE_TIME_TOLERANCE_MS); + const upperBound = new Date(provider.deletedAt.getTime() + ENDPOINT_RESTORE_TIME_TOLERANCE_MS); + + const [endpointToRestore] = await tx + .select({ id: providerEndpoints.id }) + .from(providerEndpoints) + .where( + and( + eq(providerEndpoints.vendorId, provider.providerVendorId), + eq(providerEndpoints.providerType, provider.providerType), + eq(providerEndpoints.url, trimmedUrl), + isNotNull(providerEndpoints.deletedAt), + sql`${providerEndpoints.deletedAt} >= ${lowerBound}`, + sql`${providerEndpoints.deletedAt} <= ${upperBound}` + ) + ) + .orderBy(desc(providerEndpoints.deletedAt), desc(providerEndpoints.id)) + .limit(1); + + if (!endpointToRestore) { + return; + } + + await tx + .update(providerEndpoints) + .set({ + deletedAt: null, + isEnabled: true, + updatedAt: now, + }) + .where( + and(eq(providerEndpoints.id, endpointToRestore.id), isNotNull(providerEndpoints.deletedAt)) + ); +} + +async function restoreProviderInTransaction( + tx: ProviderTransaction, + providerId: number, + now: Date +): Promise { + const [candidate] = await tx + .select({ + id: providers.id, + providerVendorId: providers.providerVendorId, + providerType: providers.providerType, + url: providers.url, + deletedAt: providers.deletedAt, + }) + .from(providers) + .where(and(eq(providers.id, providerId), isNotNull(providers.deletedAt))) + .limit(1); + + if (!candidate?.deletedAt) { + return false; + } + + if (now.getTime() - candidate.deletedAt.getTime() > PROVIDER_RESTORE_MAX_AGE_MS) { + return false; + } + + const restored = await tx + .update(providers) + .set({ deletedAt: null, updatedAt: now }) + .where( + and( + eq(providers.id, providerId), + isNotNull(providers.deletedAt), + eq(providers.deletedAt, candidate.deletedAt) + ) + ) + .returning({ id: providers.id }); + + if (restored.length === 0) { + return false; + } + + await restoreSoftDeletedEndpointForProvider(tx, candidate, now); + + return true; +} + export async function createProvider(providerData: CreateProviderData): Promise { const dbData = { name: providerData.name, @@ -803,12 +952,64 @@ export async function deleteProvider(id: number): Promise { return deleted; } +/** + * 恢复单个软删除供应商及其关联端点。 + * + * 安全策略:仅允许恢复 60 秒内删除的供应商。 + */ +export async function restoreProvider(id: number): Promise { + const now = new Date(); + + const restored = await db.transaction(async (tx) => restoreProviderInTransaction(tx, id, now)); + + return restored; +} + export interface BatchProviderUpdates { isEnabled?: boolean; priority?: number; weight?: number; costMultiplier?: string; groupTag?: string | null; + modelRedirects?: Record | null; + allowedModels?: string[] | null; + anthropicThinkingBudgetPreference?: string | null; + anthropicAdaptiveThinking?: AnthropicAdaptiveThinkingConfig | null; + // Routing + preserveClientIp?: boolean; + groupPriorities?: Record | null; + cacheTtlPreference?: string | null; + swapCacheTtlBilling?: boolean; + context1mPreference?: string | null; + codexReasoningEffortPreference?: string | null; + codexReasoningSummaryPreference?: string | null; + codexTextVerbosityPreference?: string | null; + codexParallelToolCallsPreference?: string | null; + anthropicMaxTokensPreference?: string | null; + geminiGoogleSearchPreference?: string | null; + // Rate Limit + limit5hUsd?: string | null; + limitDailyUsd?: string | null; + dailyResetMode?: string; + dailyResetTime?: string; + limitWeeklyUsd?: string | null; + limitMonthlyUsd?: string | null; + limitTotalUsd?: string | null; + limitConcurrentSessions?: number; + // Circuit Breaker + circuitBreakerFailureThreshold?: number; + circuitBreakerOpenDuration?: number; + circuitBreakerHalfOpenSuccessThreshold?: number; + maxRetryAttempts?: number | null; + // Network + proxyUrl?: string | null; + proxyFallbackToDirect?: boolean; + firstByteTimeoutStreamingMs?: number; + streamingIdleTimeoutMs?: number; + requestTimeoutNonStreamingMs?: number; + // MCP + mcpPassthroughType?: string; + mcpPassthroughUrl?: string | null; } export async function updateProvidersBatch( @@ -838,6 +1039,114 @@ export async function updateProvidersBatch( if (updates.groupTag !== undefined) { setClauses.groupTag = updates.groupTag; } + if (updates.modelRedirects !== undefined) { + setClauses.modelRedirects = updates.modelRedirects; + } + if (updates.allowedModels !== undefined) { + setClauses.allowedModels = updates.allowedModels; + } + if (updates.anthropicThinkingBudgetPreference !== undefined) { + setClauses.anthropicThinkingBudgetPreference = updates.anthropicThinkingBudgetPreference; + } + if (updates.anthropicAdaptiveThinking !== undefined) { + setClauses.anthropicAdaptiveThinking = updates.anthropicAdaptiveThinking; + } + // Routing + if (updates.preserveClientIp !== undefined) { + setClauses.preserveClientIp = updates.preserveClientIp; + } + if (updates.groupPriorities !== undefined) { + setClauses.groupPriorities = updates.groupPriorities; + } + if (updates.cacheTtlPreference !== undefined) { + setClauses.cacheTtlPreference = updates.cacheTtlPreference; + } + if (updates.swapCacheTtlBilling !== undefined) { + setClauses.swapCacheTtlBilling = updates.swapCacheTtlBilling; + } + if (updates.context1mPreference !== undefined) { + setClauses.context1mPreference = updates.context1mPreference; + } + if (updates.codexReasoningEffortPreference !== undefined) { + setClauses.codexReasoningEffortPreference = updates.codexReasoningEffortPreference; + } + if (updates.codexReasoningSummaryPreference !== undefined) { + setClauses.codexReasoningSummaryPreference = updates.codexReasoningSummaryPreference; + } + if (updates.codexTextVerbosityPreference !== undefined) { + setClauses.codexTextVerbosityPreference = updates.codexTextVerbosityPreference; + } + if (updates.codexParallelToolCallsPreference !== undefined) { + setClauses.codexParallelToolCallsPreference = updates.codexParallelToolCallsPreference; + } + if (updates.anthropicMaxTokensPreference !== undefined) { + setClauses.anthropicMaxTokensPreference = updates.anthropicMaxTokensPreference; + } + if (updates.geminiGoogleSearchPreference !== undefined) { + setClauses.geminiGoogleSearchPreference = updates.geminiGoogleSearchPreference; + } + // Rate Limit + if (updates.limit5hUsd !== undefined) { + setClauses.limit5hUsd = updates.limit5hUsd; + } + if (updates.limitDailyUsd !== undefined) { + setClauses.limitDailyUsd = updates.limitDailyUsd; + } + if (updates.dailyResetMode !== undefined) { + setClauses.dailyResetMode = updates.dailyResetMode; + } + if (updates.dailyResetTime !== undefined) { + setClauses.dailyResetTime = updates.dailyResetTime; + } + if (updates.limitWeeklyUsd !== undefined) { + setClauses.limitWeeklyUsd = updates.limitWeeklyUsd; + } + if (updates.limitMonthlyUsd !== undefined) { + setClauses.limitMonthlyUsd = updates.limitMonthlyUsd; + } + if (updates.limitTotalUsd !== undefined) { + setClauses.limitTotalUsd = updates.limitTotalUsd; + } + if (updates.limitConcurrentSessions !== undefined) { + setClauses.limitConcurrentSessions = updates.limitConcurrentSessions; + } + // Circuit Breaker + if (updates.circuitBreakerFailureThreshold !== undefined) { + setClauses.circuitBreakerFailureThreshold = updates.circuitBreakerFailureThreshold; + } + if (updates.circuitBreakerOpenDuration !== undefined) { + setClauses.circuitBreakerOpenDuration = updates.circuitBreakerOpenDuration; + } + if (updates.circuitBreakerHalfOpenSuccessThreshold !== undefined) { + setClauses.circuitBreakerHalfOpenSuccessThreshold = + updates.circuitBreakerHalfOpenSuccessThreshold; + } + if (updates.maxRetryAttempts !== undefined) { + setClauses.maxRetryAttempts = updates.maxRetryAttempts; + } + // Network + if (updates.proxyUrl !== undefined) { + setClauses.proxyUrl = updates.proxyUrl; + } + if (updates.proxyFallbackToDirect !== undefined) { + setClauses.proxyFallbackToDirect = updates.proxyFallbackToDirect; + } + if (updates.firstByteTimeoutStreamingMs !== undefined) { + setClauses.firstByteTimeoutStreamingMs = updates.firstByteTimeoutStreamingMs; + } + if (updates.streamingIdleTimeoutMs !== undefined) { + setClauses.streamingIdleTimeoutMs = updates.streamingIdleTimeoutMs; + } + if (updates.requestTimeoutNonStreamingMs !== undefined) { + setClauses.requestTimeoutNonStreamingMs = updates.requestTimeoutNonStreamingMs; + } + // MCP + if (updates.mcpPassthroughType !== undefined) { + setClauses.mcpPassthroughType = updates.mcpPassthroughType; + } + if (updates.mcpPassthroughUrl !== undefined) { + setClauses.mcpPassthroughUrl = updates.mcpPassthroughUrl; + } if (Object.keys(setClauses).length === 1) { return 0; @@ -1038,6 +1347,39 @@ export async function deleteProvidersBatch(ids: number[]): Promise { return deletedCount; } +/** + * 批量恢复软删除供应商及其关联端点(事务内逐个恢复)。 + * + * 安全策略:仅允许恢复 60 秒内删除的供应商。 + */ +export async function restoreProvidersBatch(ids: number[]): Promise { + if (ids.length === 0) { + return 0; + } + + const uniqueIds = [...new Set(ids)]; + const now = new Date(); + + const restoredCount = await db.transaction(async (tx) => { + let restored = 0; + + for (const id of uniqueIds) { + if (await restoreProviderInTransaction(tx, id, now)) { + restored += 1; + } + } + + return restored; + }); + + logger.debug("restoreProvidersBatch:completed", { + requestedIds: uniqueIds.length, + restoredCount, + }); + + return restoredCount; +} + /** * 手动重置供应商"总消费"统计起点 * diff --git a/src/types/provider.ts b/src/types/provider.ts index aed85a685..94480e6d0 100644 --- a/src/types/provider.ts +++ b/src/types/provider.ts @@ -45,6 +45,208 @@ export interface AnthropicAdaptiveThinkingConfig { models: string[]; } +export type ProviderPatchOperation = + | { mode: "no_change" } + | { mode: "set"; value: T } + | { mode: "clear" }; + +export type ProviderPatchDraftInput = + | { set: T; clear?: never; no_change?: never } + | { clear: true; set?: never; no_change?: never } + | { no_change: true; set?: never; clear?: never } + | undefined; + +export type ProviderBatchPatchField = + // Basic / existing + | "is_enabled" + | "priority" + | "weight" + | "cost_multiplier" + | "group_tag" + | "model_redirects" + | "allowed_models" + | "anthropic_thinking_budget_preference" + | "anthropic_adaptive_thinking" + // Routing + | "preserve_client_ip" + | "group_priorities" + | "cache_ttl_preference" + | "swap_cache_ttl_billing" + | "context_1m_preference" + | "codex_reasoning_effort_preference" + | "codex_reasoning_summary_preference" + | "codex_text_verbosity_preference" + | "codex_parallel_tool_calls_preference" + | "anthropic_max_tokens_preference" + | "gemini_google_search_preference" + // Rate Limit + | "limit_5h_usd" + | "limit_daily_usd" + | "daily_reset_mode" + | "daily_reset_time" + | "limit_weekly_usd" + | "limit_monthly_usd" + | "limit_total_usd" + | "limit_concurrent_sessions" + // Circuit Breaker + | "circuit_breaker_failure_threshold" + | "circuit_breaker_open_duration" + | "circuit_breaker_half_open_success_threshold" + | "max_retry_attempts" + // Network + | "proxy_url" + | "proxy_fallback_to_direct" + | "first_byte_timeout_streaming_ms" + | "streaming_idle_timeout_ms" + | "request_timeout_non_streaming_ms" + // MCP + | "mcp_passthrough_type" + | "mcp_passthrough_url"; + +export interface ProviderBatchPatchDraft { + // Basic / existing + is_enabled?: ProviderPatchDraftInput; + priority?: ProviderPatchDraftInput; + weight?: ProviderPatchDraftInput; + cost_multiplier?: ProviderPatchDraftInput; + group_tag?: ProviderPatchDraftInput; + model_redirects?: ProviderPatchDraftInput>; + allowed_models?: ProviderPatchDraftInput; + anthropic_thinking_budget_preference?: ProviderPatchDraftInput; + anthropic_adaptive_thinking?: ProviderPatchDraftInput; + // Routing + preserve_client_ip?: ProviderPatchDraftInput; + group_priorities?: ProviderPatchDraftInput>; + cache_ttl_preference?: ProviderPatchDraftInput; + swap_cache_ttl_billing?: ProviderPatchDraftInput; + context_1m_preference?: ProviderPatchDraftInput; + codex_reasoning_effort_preference?: ProviderPatchDraftInput; + codex_reasoning_summary_preference?: ProviderPatchDraftInput; + codex_text_verbosity_preference?: ProviderPatchDraftInput; + codex_parallel_tool_calls_preference?: ProviderPatchDraftInput; + anthropic_max_tokens_preference?: ProviderPatchDraftInput; + gemini_google_search_preference?: ProviderPatchDraftInput; + // Rate Limit + limit_5h_usd?: ProviderPatchDraftInput; + limit_daily_usd?: ProviderPatchDraftInput; + daily_reset_mode?: ProviderPatchDraftInput<"fixed" | "rolling">; + daily_reset_time?: ProviderPatchDraftInput; + limit_weekly_usd?: ProviderPatchDraftInput; + limit_monthly_usd?: ProviderPatchDraftInput; + limit_total_usd?: ProviderPatchDraftInput; + limit_concurrent_sessions?: ProviderPatchDraftInput; + // Circuit Breaker + circuit_breaker_failure_threshold?: ProviderPatchDraftInput; + circuit_breaker_open_duration?: ProviderPatchDraftInput; + circuit_breaker_half_open_success_threshold?: ProviderPatchDraftInput; + max_retry_attempts?: ProviderPatchDraftInput; + // Network + proxy_url?: ProviderPatchDraftInput; + proxy_fallback_to_direct?: ProviderPatchDraftInput; + first_byte_timeout_streaming_ms?: ProviderPatchDraftInput; + streaming_idle_timeout_ms?: ProviderPatchDraftInput; + request_timeout_non_streaming_ms?: ProviderPatchDraftInput; + // MCP + mcp_passthrough_type?: ProviderPatchDraftInput; + mcp_passthrough_url?: ProviderPatchDraftInput; +} + +export interface ProviderBatchPatch { + // Basic / existing + is_enabled: ProviderPatchOperation; + priority: ProviderPatchOperation; + weight: ProviderPatchOperation; + cost_multiplier: ProviderPatchOperation; + group_tag: ProviderPatchOperation; + model_redirects: ProviderPatchOperation>; + allowed_models: ProviderPatchOperation; + anthropic_thinking_budget_preference: ProviderPatchOperation; + anthropic_adaptive_thinking: ProviderPatchOperation; + // Routing + preserve_client_ip: ProviderPatchOperation; + group_priorities: ProviderPatchOperation>; + cache_ttl_preference: ProviderPatchOperation; + swap_cache_ttl_billing: ProviderPatchOperation; + context_1m_preference: ProviderPatchOperation; + codex_reasoning_effort_preference: ProviderPatchOperation; + codex_reasoning_summary_preference: ProviderPatchOperation; + codex_text_verbosity_preference: ProviderPatchOperation; + codex_parallel_tool_calls_preference: ProviderPatchOperation; + anthropic_max_tokens_preference: ProviderPatchOperation; + gemini_google_search_preference: ProviderPatchOperation; + // Rate Limit + limit_5h_usd: ProviderPatchOperation; + limit_daily_usd: ProviderPatchOperation; + daily_reset_mode: ProviderPatchOperation<"fixed" | "rolling">; + daily_reset_time: ProviderPatchOperation; + limit_weekly_usd: ProviderPatchOperation; + limit_monthly_usd: ProviderPatchOperation; + limit_total_usd: ProviderPatchOperation; + limit_concurrent_sessions: ProviderPatchOperation; + // Circuit Breaker + circuit_breaker_failure_threshold: ProviderPatchOperation; + circuit_breaker_open_duration: ProviderPatchOperation; + circuit_breaker_half_open_success_threshold: ProviderPatchOperation; + max_retry_attempts: ProviderPatchOperation; + // Network + proxy_url: ProviderPatchOperation; + proxy_fallback_to_direct: ProviderPatchOperation; + first_byte_timeout_streaming_ms: ProviderPatchOperation; + streaming_idle_timeout_ms: ProviderPatchOperation; + request_timeout_non_streaming_ms: ProviderPatchOperation; + // MCP + mcp_passthrough_type: ProviderPatchOperation; + mcp_passthrough_url: ProviderPatchOperation; +} + +export interface ProviderBatchApplyUpdates { + // Basic / existing + is_enabled?: boolean; + priority?: number; + weight?: number; + cost_multiplier?: number; + group_tag?: string | null; + model_redirects?: Record | null; + allowed_models?: string[] | null; + anthropic_thinking_budget_preference?: AnthropicThinkingBudgetPreference | null; + anthropic_adaptive_thinking?: AnthropicAdaptiveThinkingConfig | null; + // Routing + preserve_client_ip?: boolean; + group_priorities?: Record | null; + cache_ttl_preference?: CacheTtlPreference | null; + swap_cache_ttl_billing?: boolean; + context_1m_preference?: Context1mPreference | null; + codex_reasoning_effort_preference?: CodexReasoningEffortPreference | null; + codex_reasoning_summary_preference?: CodexReasoningSummaryPreference | null; + codex_text_verbosity_preference?: CodexTextVerbosityPreference | null; + codex_parallel_tool_calls_preference?: CodexParallelToolCallsPreference | null; + anthropic_max_tokens_preference?: AnthropicMaxTokensPreference | null; + gemini_google_search_preference?: GeminiGoogleSearchPreference | null; + // Rate Limit + limit_5h_usd?: number | null; + limit_daily_usd?: number | null; + daily_reset_mode?: "fixed" | "rolling"; + daily_reset_time?: string; + limit_weekly_usd?: number | null; + limit_monthly_usd?: number | null; + limit_total_usd?: number | null; + limit_concurrent_sessions?: number; + // Circuit Breaker + circuit_breaker_failure_threshold?: number; + circuit_breaker_open_duration?: number; + circuit_breaker_half_open_success_threshold?: number; + max_retry_attempts?: number | null; + // Network + proxy_url?: string | null; + proxy_fallback_to_direct?: boolean; + first_byte_timeout_streaming_ms?: number; + streaming_idle_timeout_ms?: number; + request_timeout_non_streaming_ms?: number; + // MCP + mcp_passthrough_type?: McpPassthroughType; + mcp_passthrough_url?: string | null; +} + // Gemini (generateContent API) parameter overrides // - "inherit": follow client request (default) // - "enabled": force inject googleSearch tool diff --git a/tests/api/action-adapter-auth-session.unit.test.ts b/tests/api/action-adapter-auth-session.unit.test.ts index 16eace9ce..e54025d84 100644 --- a/tests/api/action-adapter-auth-session.unit.test.ts +++ b/tests/api/action-adapter-auth-session.unit.test.ts @@ -76,11 +76,12 @@ describe("Action Adapter:会话透传", () => { return { ...actual, validateKey: vi.fn(async () => mockSession), + validateAuthToken: vi.fn(async () => mockSession), }; }); const { createActionRoute } = await import("@/lib/api/action-adapter-openapi"); - const { getSession, validateKey } = await import("@/lib/auth"); + const { getSession, validateAuthToken } = await import("@/lib/auth"); const action = vi.fn(async () => { const session = await getSession(); @@ -115,7 +116,7 @@ describe("Action Adapter:会话透传", () => { }), } as any)) as Response; - expect(validateKey).toHaveBeenCalledTimes(1); + expect(validateAuthToken).toHaveBeenCalledTimes(1); expect(action).toHaveBeenCalledTimes(1); expect(response.status).toBe(200); await expect(response.json()).resolves.toEqual({ diff --git a/tests/security/auth-bruteforce-integration.test.ts b/tests/security/auth-bruteforce-integration.test.ts new file mode 100644 index 000000000..57eb09186 --- /dev/null +++ b/tests/security/auth-bruteforce-integration.test.ts @@ -0,0 +1,172 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest } from "next/server"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + getSessionTokenMode: mockGetSessionTokenMode, + withNoStoreHeaders: (res: T): T => { + (res as Response).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as Response).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: () => ({ ENABLE_SECURE_COOKIES: false, SESSION_TOKEN_MODE: "legacy" }), +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: (res: T): T => { + (res as Response).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as Response).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +function makeRequest(body: unknown, ip: string): NextRequest { + return new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers: { + "Content-Type": "application/json", + "x-forwarded-for": ip, + "x-forwarded-proto": "https", + }, + body: JSON.stringify(body), + }); +} + +const fakeSession = { + user: { + id: 1, + name: "Test User", + description: "desc", + role: "user" as const, + }, + key: { canLoginWebUi: true }, +}; + +async function exhaustFailures( + POST: (request: NextRequest) => Promise, + ip: string, + count = 10 +) { + for (let i = 0; i < count; i++) { + const res = await POST(makeRequest({ key: `bad-${i}` }, ip)); + expect(res.status).toBe(401); + } +} + +describe("auth login anti-bruteforce integration", () => { + let POST: (request: NextRequest) => Promise; + + beforeEach(async () => { + vi.resetModules(); + vi.clearAllMocks(); + + const mockT = vi.fn((key: string) => `translated:${key}`); + mockGetTranslations.mockResolvedValue(mockT); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + mockGetSessionTokenMode.mockReturnValue("legacy"); + + const mod = await import("../../src/app/api/auth/login/route"); + POST = mod.POST; + }); + + it("normal request passes rate-limit check", async () => { + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "bad-key" }, "198.51.100.10")); + + expect(res.status).toBe(401); + expect(res.headers.get("Retry-After")).toBeNull(); + expect(mockValidateKey).toHaveBeenCalledWith("bad-key", { allowReadOnlyAccess: true }); + }); + + it("returns 429 with Retry-After after max failures", async () => { + const ip = "198.51.100.20"; + mockValidateKey.mockResolvedValue(null); + + await exhaustFailures(POST, ip); + + const blockedRes = await POST(makeRequest({ key: "blocked-now" }, ip)); + + expect(blockedRes.status).toBe(429); + expect(blockedRes.headers.get("Retry-After")).not.toBeNull(); + expect(Number.parseInt(blockedRes.headers.get("Retry-After") ?? "0", 10)).toBeGreaterThan(0); + expect(mockValidateKey).toHaveBeenCalledTimes(10); + }); + + it("successful login resets failure counter", async () => { + const ip = "198.51.100.30"; + mockValidateKey.mockImplementation(async (key: string) => { + return key === "valid-key" ? fakeSession : null; + }); + + for (let i = 0; i < 9; i++) { + const res = await POST(makeRequest({ key: `bad-before-success-${i}` }, ip)); + expect(res.status).toBe(401); + } + + const successRes = await POST(makeRequest({ key: "valid-key" }, ip)); + expect(successRes.status).toBe(200); + + const firstAfterSuccess = await POST(makeRequest({ key: "bad-after-success-1" }, ip)); + const secondAfterSuccess = await POST(makeRequest({ key: "bad-after-success-2" }, ip)); + + expect(firstAfterSuccess.status).toBe(401); + expect(secondAfterSuccess.status).toBe(401); + expect(secondAfterSuccess.headers.get("Retry-After")).toBeNull(); + expect(mockSetAuthCookie).toHaveBeenCalledWith("valid-key"); + }); + + it("429 response includes errorCode RATE_LIMITED", async () => { + const ip = "198.51.100.40"; + mockValidateKey.mockResolvedValue(null); + + await exhaustFailures(POST, ip); + + const blockedRes = await POST(makeRequest({ key: "blocked-key" }, ip)); + + expect(blockedRes.status).toBe(429); + await expect(blockedRes.json()).resolves.toMatchObject({ + errorCode: "RATE_LIMITED", + }); + }); + + it("tracks different IPs independently", async () => { + const blockedIp = "198.51.100.50"; + const freshIp = "198.51.100.51"; + mockValidateKey.mockResolvedValue(null); + + await exhaustFailures(POST, blockedIp); + + const blockedRes = await POST(makeRequest({ key: "blocked-key" }, blockedIp)); + const freshRes = await POST(makeRequest({ key: "fresh-ip-key" }, freshIp)); + + expect(blockedRes.status).toBe(429); + expect(freshRes.status).toBe(401); + }); +}); diff --git a/tests/security/auth-csrf-route-integration.test.ts b/tests/security/auth-csrf-route-integration.test.ts new file mode 100644 index 000000000..867f80a42 --- /dev/null +++ b/tests/security/auth-csrf-route-integration.test.ts @@ -0,0 +1,175 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { NextRequest } from "next/server"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockClearAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + clearAuthCookie: mockClearAuthCookie, + getAuthCookie: mockGetAuthCookie, + toKeyFingerprint: vi.fn().mockResolvedValue("sha256:mock"), + withNoStoreHeaders: (res: T): T => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: (res: T): T => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +type LoginPostHandler = (request: NextRequest) => Promise; +type LogoutPostHandler = (request: NextRequest) => Promise; + +function makeLoginRequest(headers: Record = {}, key = "valid-key"): NextRequest { + const requestHeaders = new Headers({ + "content-type": "application/json", + ...headers, + }); + + return { + headers: requestHeaders, + cookies: { + get: () => undefined, + }, + json: async () => ({ key }), + } as unknown as NextRequest; +} + +function makeLogoutRequest(headers: Record = {}): NextRequest { + return { + headers: new Headers(headers), + } as unknown as NextRequest; +} + +describe("auth route csrf guard integration", () => { + const originalNodeEnv = process.env.NODE_ENV; + let loginPost: LoginPostHandler; + let logoutPost: LogoutPostHandler; + + afterEach(() => { + process.env.NODE_ENV = originalNodeEnv; + }); + + beforeEach(async () => { + vi.resetModules(); + vi.clearAllMocks(); + process.env.NODE_ENV = "test"; + + mockGetTranslations.mockResolvedValue( + vi.fn((messageKey: string) => `translated:${messageKey}`) + ); + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + mockValidateKey.mockResolvedValue({ + user: { + id: 1, + name: "Test User", + description: "desc", + role: "user", + }, + key: { + canLoginWebUi: true, + }, + }); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + mockClearAuthCookie.mockResolvedValue(undefined); + mockGetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + + const loginRoute = await import("@/app/api/auth/login/route"); + loginPost = loginRoute.POST; + + const logoutRoute = await import("@/app/api/auth/logout/route"); + logoutPost = logoutRoute.POST; + }); + + it("allows same-origin login request to pass through", async () => { + const res = await loginPost(makeLoginRequest({ "sec-fetch-site": "same-origin" })); + + expect(res.status).toBe(200); + expect(mockValidateKey).toHaveBeenCalledWith("valid-key", { allowReadOnlyAccess: true }); + }); + + it("blocks cross-origin login request with csrf rejected error", async () => { + const request = makeLoginRequest({ + "sec-fetch-site": "cross-site", + origin: "https://evil.example.com", + }); + + const res = await loginPost(request); + + expect(res.status).toBe(403); + expect(await res.json()).toEqual({ errorCode: "CSRF_REJECTED" }); + expect(mockValidateKey).not.toHaveBeenCalled(); + }); + + it("allows login request without origin header for non-browser clients", async () => { + const res = await loginPost(makeLoginRequest()); + + expect(res.status).toBe(200); + expect(mockValidateKey).toHaveBeenCalledTimes(1); + }); + + it("allows same-origin logout request to pass through", async () => { + const res = await logoutPost(makeLogoutRequest({ "sec-fetch-site": "same-origin" })); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ ok: true }); + expect(mockClearAuthCookie).toHaveBeenCalledTimes(1); + }); + + it("blocks cross-origin logout request with csrf rejected error", async () => { + const request = makeLogoutRequest({ + "sec-fetch-site": "cross-site", + origin: "https://evil.example.com", + }); + + const res = await logoutPost(request); + + expect(res.status).toBe(403); + expect(await res.json()).toEqual({ errorCode: "CSRF_REJECTED" }); + expect(mockClearAuthCookie).not.toHaveBeenCalled(); + }); + + it("allows logout request without origin header for non-browser clients", async () => { + const res = await logoutPost(makeLogoutRequest()); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ ok: true }); + expect(mockClearAuthCookie).toHaveBeenCalledTimes(1); + }); +}); diff --git a/tests/security/auth-dual-read.test.ts b/tests/security/auth-dual-read.test.ts new file mode 100644 index 000000000..a843a7885 --- /dev/null +++ b/tests/security/auth-dual-read.test.ts @@ -0,0 +1,264 @@ +import crypto from "node:crypto"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Key } from "@/types/key"; +import type { User } from "@/types/user"; + +const mockCookies = vi.hoisted(() => vi.fn()); +const mockHeaders = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockValidateApiKeyAndGetUser = vi.hoisted(() => vi.fn()); +const mockFindKeyList = vi.hoisted(() => vi.fn()); +const mockReadSession = vi.hoisted(() => vi.fn()); +const mockCookieStore = vi.hoisted(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), +})); +const mockHeadersStore = vi.hoisted(() => ({ + get: vi.fn(), +})); +const loggerMock = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), + trace: vi.fn(), +})); + +vi.mock("next/headers", () => ({ + cookies: mockCookies, + headers: mockHeaders, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/repository/key", () => ({ + validateApiKeyAndGetUser: mockValidateApiKeyAndGetUser, + findKeyList: mockFindKeyList, +})); + +vi.mock("@/lib/auth-session-store/redis-session-store", () => ({ + RedisSessionStore: class { + read = mockReadSession; + create = vi.fn(); + revoke = vi.fn(); + rotate = vi.fn(); + }, +})); + +vi.mock("@/lib/logger", () => ({ + logger: loggerMock, +})); + +vi.mock("@/lib/config/config", () => ({ + config: { auth: { adminToken: "" } }, +})); + +function setSessionMode(mode: "legacy" | "dual" | "opaque") { + mockGetEnvConfig.mockReturnValue({ + SESSION_TOKEN_MODE: mode, + ENABLE_SECURE_COOKIES: false, + }); +} + +function setAuthToken(token?: string) { + mockCookieStore.get.mockReturnValue(token ? { value: token } : undefined); +} + +function toFingerprint(keyString: string): string { + return `sha256:${crypto.createHash("sha256").update(keyString, "utf8").digest("hex")}`; +} + +function buildUser(id: number): User { + const now = new Date("2026-02-18T10:00:00.000Z"); + return { + id, + name: `user-${id}`, + description: "test user", + role: "user", + rpm: 100, + dailyQuota: 100, + providerGroup: null, + tags: [], + createdAt: now, + updatedAt: now, + limit5hUsd: 0, + limitWeeklyUsd: 0, + limitMonthlyUsd: 0, + limitTotalUsd: null, + limitConcurrentSessions: 0, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + isEnabled: true, + expiresAt: null, + allowedClients: [], + allowedModels: [], + }; +} + +function buildKey(id: number, userId: number, keyString: string, canLoginWebUi = true): Key { + const now = new Date("2026-02-18T10:00:00.000Z"); + return { + id, + userId, + name: `key-${id}`, + key: keyString, + isEnabled: true, + canLoginWebUi, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 0, + providerGroup: null, + cacheTtlPreference: null, + createdAt: now, + updatedAt: now, + }; +} + +function buildAuthResult(keyString: string, userId = 1) { + return { + user: buildUser(userId), + key: buildKey(userId, userId, keyString), + }; +} + +describe("auth dual-read session resolver", () => { + beforeEach(() => { + vi.resetModules(); + vi.clearAllMocks(); + + mockCookies.mockResolvedValue(mockCookieStore); + mockHeaders.mockResolvedValue(mockHeadersStore); + mockHeadersStore.get.mockReturnValue(null); + mockCookieStore.get.mockReturnValue(undefined); + + setSessionMode("legacy"); + mockReadSession.mockResolvedValue(null); + mockFindKeyList.mockResolvedValue([]); + mockValidateApiKeyAndGetUser.mockResolvedValue(null); + }); + + it("legacy mode keeps legacy key validation path unchanged", async () => { + setSessionMode("legacy"); + setAuthToken("sk-legacy"); + const authResult = buildAuthResult("sk-legacy", 11); + mockValidateApiKeyAndGetUser.mockResolvedValue(authResult); + + const { getSessionWithDualRead } = await import("@/lib/auth"); + const session = await getSessionWithDualRead(); + + expect(session).toEqual(authResult); + expect(mockReadSession).not.toHaveBeenCalled(); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledTimes(1); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledWith("sk-legacy"); + }); + + it("dual mode tries opaque read first and then falls back to legacy cookie", async () => { + setSessionMode("dual"); + setAuthToken("sk-dual"); + const authResult = buildAuthResult("sk-dual", 12); + mockReadSession.mockResolvedValue(null); + mockValidateApiKeyAndGetUser.mockResolvedValue(authResult); + + const { getSessionWithDualRead } = await import("@/lib/auth"); + const session = await getSessionWithDualRead(); + + expect(session).toEqual(authResult); + expect(mockReadSession).toHaveBeenCalledTimes(1); + expect(mockReadSession).toHaveBeenCalledWith("sk-dual"); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledWith("sk-dual"); + expect(mockReadSession.mock.invocationCallOrder[0]).toBeLessThan( + mockValidateApiKeyAndGetUser.mock.invocationCallOrder[0] + ); + }); + + it("opaque mode only reads opaque session and never falls back to legacy", async () => { + setSessionMode("opaque"); + setAuthToken("sk-legacy-in-opaque"); + mockReadSession.mockResolvedValue(null); + mockValidateApiKeyAndGetUser.mockResolvedValue(buildAuthResult("sk-legacy-in-opaque", 13)); + + const { getSessionWithDualRead } = await import("@/lib/auth"); + const session = await getSessionWithDualRead(); + + expect(session).toBeNull(); + expect(mockReadSession).toHaveBeenCalledTimes(1); + expect(mockReadSession).toHaveBeenCalledWith("sk-legacy-in-opaque"); + expect(mockValidateApiKeyAndGetUser).not.toHaveBeenCalled(); + }); + + it("returns a valid auth session when opaque session is found", async () => { + setSessionMode("dual"); + setAuthToken("sid_opaque_found"); + + const keyString = "sk-opaque-source"; + const authResult = buildAuthResult(keyString, 21); + mockReadSession.mockResolvedValue({ + sessionId: "sid_opaque_found", + keyFingerprint: toFingerprint(keyString), + userId: 21, + userRole: "user", + createdAt: Date.now(), + expiresAt: Date.now() + 3_600_000, + }); + mockFindKeyList.mockResolvedValue([ + buildKey(1, 21, "sk-not-match"), + buildKey(2, 21, keyString), + ]); + mockValidateApiKeyAndGetUser.mockResolvedValue(authResult); + + const { getSessionWithDualRead } = await import("@/lib/auth"); + const session = await getSessionWithDualRead({ allowReadOnlyAccess: true }); + + expect(session).toEqual(authResult); + expect(mockReadSession).toHaveBeenCalledWith("sid_opaque_found"); + expect(mockFindKeyList).toHaveBeenCalledWith(21); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledTimes(1); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledWith(keyString); + }); + + it("validateSession falls back to legacy path when opaque session is missing in dual mode", async () => { + setSessionMode("dual"); + setAuthToken("sk-dual-fallback"); + const authResult = buildAuthResult("sk-dual-fallback", 22); + mockReadSession.mockResolvedValue(null); + mockValidateApiKeyAndGetUser.mockResolvedValue(authResult); + + const { validateSession } = await import("@/lib/auth"); + const session = await validateSession(); + + expect(session).toEqual(authResult); + expect(mockReadSession).toHaveBeenCalledTimes(1); + expect(mockReadSession).toHaveBeenCalledWith("sk-dual-fallback"); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledTimes(1); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledWith("sk-dual-fallback"); + }); + + it("dual mode gracefully falls back to legacy when opaque session store read fails", async () => { + setSessionMode("dual"); + setAuthToken("sk-store-error"); + const authResult = buildAuthResult("sk-store-error", 23); + mockReadSession.mockRejectedValue(new Error("redis unavailable")); + mockValidateApiKeyAndGetUser.mockResolvedValue(authResult); + + const { getSessionWithDualRead } = await import("@/lib/auth"); + const session = await getSessionWithDualRead(); + + expect(session).toEqual(authResult); + expect(mockReadSession).toHaveBeenCalledTimes(1); + expect(mockValidateApiKeyAndGetUser).toHaveBeenCalledTimes(1); + expect(loggerMock.warn).toHaveBeenCalledWith( + "Opaque session read failed", + expect.objectContaining({ + error: expect.stringContaining("redis unavailable"), + }) + ); + }); +}); diff --git a/tests/security/constant-time-compare.test.ts b/tests/security/constant-time-compare.test.ts new file mode 100644 index 000000000..7177b2b4c --- /dev/null +++ b/tests/security/constant-time-compare.test.ts @@ -0,0 +1,43 @@ +import { describe, expect, it } from "vitest"; +import { constantTimeEqual } from "@/lib/security/constant-time-compare"; + +describe("constantTimeEqual", () => { + it("returns true for equal strings", () => { + expect(constantTimeEqual("hello", "hello")).toBe(true); + }); + + it("returns false for different strings of same length", () => { + expect(constantTimeEqual("hello", "world")).toBe(false); + }); + + it("returns false for strings of different lengths", () => { + expect(constantTimeEqual("short", "a-much-longer-string")).toBe(false); + }); + + it("returns true for empty strings", () => { + expect(constantTimeEqual("", "")).toBe(true); + }); + + it("returns false when one string is empty and the other is not", () => { + expect(constantTimeEqual("", "nonempty")).toBe(false); + expect(constantTimeEqual("nonempty", "")).toBe(false); + }); + + it("handles unicode correctly", () => { + expect(constantTimeEqual("\u00e9", "\u00e9")).toBe(true); + expect(constantTimeEqual("\u00e9", "e")).toBe(false); + }); + + it("handles long token-like strings", () => { + const tokenA = "sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; + const tokenB = "sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; + const tokenC = "sk-ant-api03-BBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"; + expect(constantTimeEqual(tokenA, tokenB)).toBe(true); + expect(constantTimeEqual(tokenA, tokenC)).toBe(false); + }); + + it("is reflexive", () => { + const s = "test-token-value"; + expect(constantTimeEqual(s, s)).toBe(true); + }); +}); diff --git a/tests/security/csrf-origin-guard.test.ts b/tests/security/csrf-origin-guard.test.ts new file mode 100644 index 000000000..3382caf95 --- /dev/null +++ b/tests/security/csrf-origin-guard.test.ts @@ -0,0 +1,133 @@ +import { afterEach, describe, expect, it } from "vitest"; +import { createCsrfOriginGuard } from "@/lib/security/csrf-origin-guard"; + +function createRequest(headers: Record) { + return { + headers: new Headers(headers), + }; +} + +describe("createCsrfOriginGuard", () => { + const originalNodeEnv = process.env.NODE_ENV; + + afterEach(() => { + process.env.NODE_ENV = originalNodeEnv; + }); + + it("allows same-origin request when allowSameOrigin is enabled", () => { + const guard = createCsrfOriginGuard({ + allowedOrigins: [], + allowSameOrigin: true, + enforceInDevelopment: true, + }); + + const result = guard.check( + createRequest({ + "sec-fetch-site": "same-origin", + }) + ); + + expect(result).toEqual({ allowed: true }); + }); + + it("allows request when Origin is in allowlist", () => { + const origin = "https://example.com"; + const guard = createCsrfOriginGuard({ + allowedOrigins: [origin], + allowSameOrigin: false, + enforceInDevelopment: true, + }); + + const result = guard.check( + createRequest({ + "sec-fetch-site": "cross-site", + origin, + }) + ); + + expect(result).toEqual({ allowed: true }); + }); + + it("blocks request when Origin is not in allowlist", () => { + const guard = createCsrfOriginGuard({ + allowedOrigins: ["https://allowed.example.com"], + allowSameOrigin: false, + enforceInDevelopment: true, + }); + + const result = guard.check( + createRequest({ + origin: "https://evil.example.com", + }) + ); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Origin https://evil.example.com not in allowlist"); + }); + + it("allows request without Origin header", () => { + const guard = createCsrfOriginGuard({ + allowedOrigins: [], + allowSameOrigin: true, + enforceInDevelopment: true, + }); + + const result = guard.check(createRequest({})); + + expect(result).toEqual({ allowed: true }); + }); + + it("blocks cross-site request when Origin header is missing", () => { + const guard = createCsrfOriginGuard({ + allowedOrigins: ["https://example.com"], + allowSameOrigin: true, + enforceInDevelopment: true, + }); + + const result = guard.check( + createRequest({ + "sec-fetch-site": "cross-site", + }) + ); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Cross-site request blocked: missing Origin header"); + }); + + it("matches allowedOrigins case-insensitively", () => { + const guard = createCsrfOriginGuard({ + allowedOrigins: ["https://Example.COM"], + allowSameOrigin: false, + enforceInDevelopment: true, + }); + + const result = guard.check( + createRequest({ + "sec-fetch-site": "cross-site", + origin: "https://example.com", + }) + ); + + expect(result).toEqual({ allowed: true }); + }); + + it("bypasses guard in development when enforceInDevelopment is disabled", () => { + process.env.NODE_ENV = "development"; + + const guard = createCsrfOriginGuard({ + allowedOrigins: ["https://allowed.example.com"], + allowSameOrigin: false, + enforceInDevelopment: false, + }); + + const result = guard.check( + createRequest({ + "sec-fetch-site": "cross-site", + origin: "https://evil.example.com", + }) + ); + + expect(result.allowed).toBe(true); + expect(result.reason).toBe("csrf_guard_bypassed_in_development"); + }); +}); diff --git a/tests/security/full-security-regression.test.ts b/tests/security/full-security-regression.test.ts new file mode 100644 index 000000000..26d0c0dd7 --- /dev/null +++ b/tests/security/full-security-regression.test.ts @@ -0,0 +1,283 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { createCsrfOriginGuard } from "../../src/lib/security/csrf-origin-guard"; +import { LoginAbusePolicy } from "../../src/lib/security/login-abuse-policy"; +import { + buildSecurityHeaders, + DEFAULT_SECURITY_HEADERS_CONFIG, +} from "../../src/lib/security/security-headers"; + +const mockCookieSet = vi.hoisted(() => vi.fn()); +const mockCookies = vi.hoisted(() => vi.fn()); +const mockGetRedisClient = vi.hoisted(() => vi.fn()); + +vi.mock("next/headers", () => ({ + cookies: mockCookies, + headers: vi.fn().mockResolvedValue(new Headers()), +})); + +vi.mock("@/lib/config/config", () => ({ + config: { + auth: { + adminToken: "test-admin-token", + }, + }, +})); + +vi.mock("@/repository/key", () => ({ + findKeyList: vi.fn(), + validateApiKeyAndGetUser: vi.fn(), +})); + +vi.mock("@/lib/redis", () => ({ + getRedisClient: mockGetRedisClient, +})); + +const ORIGINAL_SESSION_TOKEN_MODE = process.env.SESSION_TOKEN_MODE; +const ORIGINAL_ENABLE_SECURE_COOKIES = process.env.ENABLE_SECURE_COOKIES; + +function restoreAuthEnv() { + if (ORIGINAL_SESSION_TOKEN_MODE === undefined) { + delete process.env.SESSION_TOKEN_MODE; + } else { + process.env.SESSION_TOKEN_MODE = ORIGINAL_SESSION_TOKEN_MODE; + } + + if (ORIGINAL_ENABLE_SECURE_COOKIES === undefined) { + delete process.env.ENABLE_SECURE_COOKIES; + } else { + process.env.ENABLE_SECURE_COOKIES = ORIGINAL_ENABLE_SECURE_COOKIES; + } +} + +function setupCookieStoreMock() { + mockCookieSet.mockClear(); + mockCookies.mockResolvedValue({ + set: mockCookieSet, + get: vi.fn(), + delete: vi.fn(), + }); +} + +class FakeRedisClient { + status: "ready" = "ready"; + private readonly values = new Map(); + + async setex(key: string, _ttl: number, value: string): Promise<"OK"> { + this.values.set(key, value); + return "OK"; + } + + async get(key: string): Promise { + return this.values.get(key) ?? null; + } + + async del(key: string): Promise { + return this.values.delete(key) ? 1 : 0; + } +} + +describe("Full Security Regression Suite", () => { + beforeEach(() => { + setupCookieStoreMock(); + }); + + afterEach(() => { + restoreAuthEnv(); + vi.useRealTimers(); + vi.clearAllMocks(); + vi.resetModules(); + }); + + describe("Session Contract", () => { + it("SESSION_TOKEN_MODE defaults to opaque", async () => { + delete process.env.SESSION_TOKEN_MODE; + + vi.resetModules(); + const { getSessionTokenMode } = await import("../../src/lib/auth"); + + expect(getSessionTokenMode()).toBe("opaque"); + }); + + it("OpaqueSessionContract has required fields", async () => { + vi.resetModules(); + const { isOpaqueSessionContract } = await import("../../src/lib/auth"); + + const contract = { + sessionId: "sid_opaque_session_123", + keyFingerprint: "sha256:abc123", + createdAt: 1_700_000_000, + expiresAt: 1_700_000_300, + userId: 42, + userRole: "admin", + }; + + expect(isOpaqueSessionContract(contract)).toBe(true); + + const missingUserRole = { ...contract } as Partial; + delete missingUserRole.userRole; + expect(isOpaqueSessionContract(missingUserRole)).toBe(false); + }); + }); + + describe("Session Store", () => { + it("create returns valid session data", async () => { + const redis = new FakeRedisClient(); + mockGetRedisClient.mockReturnValue(redis); + const { RedisSessionStore } = await import( + "../../src/lib/auth-session-store/redis-session-store" + ); + + const store = new RedisSessionStore(); + + const created = await store.create({ + keyFingerprint: "sha256:fp-1", + userId: 101, + userRole: "user", + }); + + expect(created.sessionId).toMatch(/^sid_[0-9a-f-]{36}$/i); + expect(created.keyFingerprint).toBe("sha256:fp-1"); + expect(created.userId).toBe(101); + expect(created.userRole).toBe("user"); + expect(created.expiresAt).toBeGreaterThan(created.createdAt); + await expect(store.read(created.sessionId)).resolves.toEqual(created); + }); + + it("read returns null for non-existent session", async () => { + const redis = new FakeRedisClient(); + mockGetRedisClient.mockReturnValue(redis); + const { RedisSessionStore } = await import( + "../../src/lib/auth-session-store/redis-session-store" + ); + + const store = new RedisSessionStore(); + + await expect(store.read("missing-session")).resolves.toBeNull(); + }); + }); + + describe("Cookie Hardening", () => { + it("auth cookie is HttpOnly", async () => { + process.env.ENABLE_SECURE_COOKIES = "true"; + + vi.resetModules(); + const { AUTH_COOKIE_NAME, setAuthCookie } = await import("../../src/lib/auth"); + + await setAuthCookie("test-key"); + + expect(mockCookieSet).toHaveBeenCalledTimes(1); + const [name, value, options] = mockCookieSet.mock.calls[0]; + expect(name).toBe(AUTH_COOKIE_NAME); + expect(value).toBe("test-key"); + expect(options.httpOnly).toBe(true); + }); + + it("auth cookie secure flag matches env", async () => { + const cases = [ + { envValue: "true", expected: true }, + { envValue: "false", expected: false }, + ] as const; + + for (const testCase of cases) { + mockCookieSet.mockClear(); + process.env.ENABLE_SECURE_COOKIES = testCase.envValue; + + vi.resetModules(); + const { setAuthCookie } = await import("../../src/lib/auth"); + await setAuthCookie("env-test"); + + const [, , options] = mockCookieSet.mock.calls[0]; + expect(options.secure).toBe(testCase.expected); + } + }); + }); + + describe("Anti-Bruteforce", () => { + it("blocks after threshold", () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-02-18T10:00:00.000Z")); + + const policy = new LoginAbusePolicy({ maxAttemptsPerIp: 2, lockoutSeconds: 60 }); + const ip = "198.51.100.10"; + + policy.recordFailure(ip); + policy.recordFailure(ip); + + const decision = policy.check(ip); + expect(decision.allowed).toBe(false); + expect(decision.reason).toBe("ip_rate_limited"); + expect(decision.retryAfterSeconds).toBeGreaterThan(0); + }); + + it("resets on success", () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-02-18T10:00:00.000Z")); + + const policy = new LoginAbusePolicy({ maxAttemptsPerIp: 2, lockoutSeconds: 60 }); + const ip = "198.51.100.11"; + + policy.recordFailure(ip); + policy.recordFailure(ip); + expect(policy.check(ip).allowed).toBe(false); + + policy.recordSuccess(ip); + expect(policy.check(ip)).toEqual({ allowed: true }); + }); + }); + + describe("CSRF Guard", () => { + it("allows same-origin", () => { + const guard = createCsrfOriginGuard({ + allowedOrigins: ["https://safe.example.com"], + allowSameOrigin: true, + enforceInDevelopment: true, + }); + + const result = guard.check({ + headers: new Headers({ + "sec-fetch-site": "same-origin", + }), + }); + + expect(result).toEqual({ allowed: true }); + }); + + it("blocks cross-origin", () => { + const guard = createCsrfOriginGuard({ + allowedOrigins: ["https://safe.example.com"], + allowSameOrigin: true, + enforceInDevelopment: true, + }); + + const result = guard.check({ + headers: new Headers({ + "sec-fetch-site": "cross-site", + origin: "https://evil.example.com", + }), + }); + + expect(result.allowed).toBe(false); + expect(result.reason).toBe("Origin https://evil.example.com not in allowlist"); + }); + }); + + describe("Security Headers", () => { + it("includes all required headers", () => { + const headers = buildSecurityHeaders(); + + expect(headers["X-Content-Type-Options"]).toBe("nosniff"); + expect(headers["X-Frame-Options"]).toBe(DEFAULT_SECURITY_HEADERS_CONFIG.frameOptions); + expect(headers["Referrer-Policy"]).toBe("strict-origin-when-cross-origin"); + expect(headers["X-DNS-Prefetch-Control"]).toBe("off"); + expect(headers["Content-Security-Policy-Report-Only"]).toContain("default-src 'self'"); + }); + + it("CSP report-only by default", () => { + expect(DEFAULT_SECURITY_HEADERS_CONFIG.cspMode).toBe("report-only"); + + const headers = buildSecurityHeaders(); + expect(headers["Content-Security-Policy-Report-Only"]).toContain("default-src 'self'"); + expect(headers["Content-Security-Policy"]).toBeUndefined(); + }); + }); +}); diff --git a/tests/security/login-abuse-policy.test.ts b/tests/security/login-abuse-policy.test.ts new file mode 100644 index 000000000..90cbf62c5 --- /dev/null +++ b/tests/security/login-abuse-policy.test.ts @@ -0,0 +1,234 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { LoginAbusePolicy } from "@/lib/security/login-abuse-policy"; + +describe("LoginAbusePolicy", () => { + const nowMs = 1_700_000_000_000; + + beforeEach(() => { + vi.useFakeTimers(); + vi.setSystemTime(new Date(nowMs)); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("allows requests under threshold", () => { + const policy = new LoginAbusePolicy({ maxAttemptsPerIp: 3 }); + const ip = "192.168.0.1"; + + expect(policy.check(ip)).toEqual({ allowed: true }); + policy.recordFailure(ip); + expect(policy.check(ip)).toEqual({ allowed: true }); + policy.recordFailure(ip); + expect(policy.check(ip)).toEqual({ allowed: true }); + }); + + it("blocks after maxAttemptsPerIp failures", () => { + const policy = new LoginAbusePolicy({ maxAttemptsPerIp: 3, lockoutSeconds: 60 }); + const ip = "192.168.0.2"; + + policy.recordFailure(ip); + policy.recordFailure(ip); + policy.recordFailure(ip); + + const decision = policy.check(ip); + expect(decision.allowed).toBe(false); + expect(decision.reason).toBe("ip_rate_limited"); + }); + + it("returns retryAfterSeconds when blocked", () => { + const policy = new LoginAbusePolicy({ maxAttemptsPerIp: 1, lockoutSeconds: 90 }); + const ip = "192.168.0.3"; + + policy.recordFailure(ip); + + const decision = policy.check(ip); + expect(decision.allowed).toBe(false); + expect(decision.retryAfterSeconds).toBe(90); + }); + + it("lockout remains active even after window expires", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 1, + windowSeconds: 5, + lockoutSeconds: 20, + }); + const ip = "192.168.0.33"; + + policy.recordFailure(ip); + vi.advanceTimersByTime(6_000); + + const decision = policy.check(ip); + expect(decision.allowed).toBe(false); + expect(decision.reason).toBe("ip_rate_limited"); + expect(decision.retryAfterSeconds).toBe(14); + }); + + it("recordSuccess resets the counter", () => { + const policy = new LoginAbusePolicy({ maxAttemptsPerIp: 2, lockoutSeconds: 60 }); + const ip = "192.168.0.4"; + + policy.recordFailure(ip); + policy.recordFailure(ip); + expect(policy.check(ip).allowed).toBe(false); + + policy.recordSuccess(ip); + + expect(policy.check(ip)).toEqual({ allowed: true }); + }); + + it("expired window resets automatically", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 2, + windowSeconds: 10, + lockoutSeconds: 60, + }); + const ip = "192.168.0.5"; + + policy.recordFailure(ip); + vi.advanceTimersByTime(11_000); + + policy.recordFailure(ip); + expect(policy.check(ip)).toEqual({ allowed: true }); + }); + + it("custom config overrides defaults", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 1, + maxAttemptsPerKey: 2, + windowSeconds: 30, + lockoutSeconds: 120, + }); + const ip = "192.168.0.6"; + + policy.recordFailure(ip); + + const decision = policy.check(ip); + expect(decision.allowed).toBe(false); + expect(decision.retryAfterSeconds).toBe(120); + }); + + it("tracks different IPs independently", () => { + const policy = new LoginAbusePolicy({ maxAttemptsPerIp: 1, lockoutSeconds: 60 }); + const blockedIp = "10.0.0.1"; + const allowedIp = "10.0.0.2"; + + policy.recordFailure(blockedIp); + + expect(policy.check(blockedIp).allowed).toBe(false); + expect(policy.check(allowedIp)).toEqual({ allowed: true }); + }); + + it("supports key-based throttling with separate threshold", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 10, + maxAttemptsPerKey: 2, + lockoutSeconds: 60, + }); + + policy.recordFailure("10.0.0.10", "user@example.com"); + policy.recordFailure("10.0.0.11", "user@example.com"); + + const blockedByKey = policy.check("10.0.0.12", "user@example.com"); + expect(blockedByKey.allowed).toBe(false); + expect(blockedByKey.reason).toBe("key_rate_limited"); + + expect(policy.check("10.0.0.10", "other@example.com")).toEqual({ allowed: true }); + }); + + it("sweeps stale entries to prevent unbounded memory growth", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 2, + windowSeconds: 5, + lockoutSeconds: 10, + }); + + for (let i = 0; i < 100; i++) { + policy.recordFailure(`10.0.${Math.floor(i / 256)}.${i % 256}`); + } + + vi.advanceTimersByTime(61_000); + + policy.check("10.0.99.99"); + + for (let i = 0; i < 100; i++) { + const ip = `10.0.${Math.floor(i / 256)}.${i % 256}`; + expect(policy.check(ip)).toEqual({ allowed: true }); + } + }); + + it("uses LRU eviction: recently accessed entries survive over stale ones", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 5, + windowSeconds: 600, + lockoutSeconds: 900, + }); + + // Fill 10_050 entries via recordFailure (does NOT trigger sweep). + const totalEntries = 10_050; + for (let i = 0; i < totalEntries; i++) { + const ip = `${Math.floor(i / 65536) % 256}.${Math.floor(i / 256) % 256}.${i % 256}.1`; + policy.recordFailure(ip); + } + + // "Touch" an early IP via recordFailure - LRU bump moves it to the end. + // Position 10 (i=10) is inside the eviction range [0..49], so without + // the LRU bump this entry WOULD be evicted. + const touchedIp = "0.0.10.1"; + policy.recordFailure(touchedIp); + + // Pick an un-bumped IP also inside the eviction range to verify it IS evicted. + const evictedIp = "0.0.5.1"; + + // Trigger a sweep by calling check (lastSweepAt=0, so sweep interval met). + // Sweep finds size 10_050 > 10_000, evicts 50 from the start. + // The touchedIp was bumped to end, so it survives eviction. + vi.advanceTimersByTime(61_000); + policy.check("99.99.99.99"); + + // Negative assertion: un-bumped early entry was evicted (starts fresh). + expect(policy.check(evictedIp)).toEqual({ allowed: true }); + + // touchedIp had 1 (initial) + 1 (bump) = 2 failures. + // Record 3 more to hit threshold of 5. + policy.recordFailure(touchedIp); + policy.recordFailure(touchedIp); + policy.recordFailure(touchedIp); + + const decision = policy.check(touchedIp); + expect(decision.allowed).toBe(false); + expect(decision.reason).toBe("ip_rate_limited"); + }); + + it("LRU bump in recordFailureForScope preserves active entries", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 10, + windowSeconds: 600, + lockoutSeconds: 900, + }); + + // Fill with stale entries + for (let i = 0; i < 10_050; i++) { + const ip = `${Math.floor(i / 65536) % 256}.${Math.floor(i / 256) % 256}.${i % 256}.2`; + policy.recordFailure(ip); + } + + // Record additional failures on an early entry (LRU bump via recordFailure) + const activeIp = "0.0.10.2"; + policy.recordFailure(activeIp); + + // Trigger sweep + vi.advanceTimersByTime(61_000); + policy.check("99.99.99.99"); + + // The actively-failed IP should still be tracked + // Record enough total failures to trigger lockout (it had 1 initial + 1 bump = 2) + for (let j = 0; j < 8; j++) { + policy.recordFailure(activeIp); + } + const decision = policy.check(activeIp); + expect(decision.allowed).toBe(false); + expect(decision.reason).toBe("ip_rate_limited"); + }); +}); diff --git a/tests/security/proxy-auth-rate-limit.test.ts b/tests/security/proxy-auth-rate-limit.test.ts new file mode 100644 index 000000000..debd6870a --- /dev/null +++ b/tests/security/proxy-auth-rate-limit.test.ts @@ -0,0 +1,160 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +/** + * Tests for the proxy auth pre-auth rate limiter. + * + * The rate limiter is a module-level LoginAbusePolicy instance inside + * auth-guard.ts. Since it relies on ProxySession (which depends on Hono + * Context), we test the underlying LoginAbusePolicy behaviour that the + * guard delegates to, plus the IP extraction helper logic. + */ + +// We test the LoginAbusePolicy directly with proxy-specific config +import { LoginAbusePolicy } from "@/lib/security/login-abuse-policy"; + +describe("Proxy pre-auth rate limiter (LoginAbusePolicy with proxy config)", () => { + const nowMs = 1_700_000_000_000; + + beforeEach(() => { + vi.useFakeTimers(); + vi.setSystemTime(new Date(nowMs)); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("allows requests below the proxy threshold (20)", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 20, + maxAttemptsPerKey: 20, + windowSeconds: 300, + lockoutSeconds: 600, + }); + const ip = "10.0.0.1"; + + for (let i = 0; i < 19; i++) { + policy.recordFailure(ip); + } + + expect(policy.check(ip)).toEqual({ allowed: true }); + }); + + it("blocks after 20 consecutive failures", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 20, + maxAttemptsPerKey: 20, + windowSeconds: 300, + lockoutSeconds: 600, + }); + const ip = "10.0.0.2"; + + for (let i = 0; i < 20; i++) { + policy.recordFailure(ip); + } + + const decision = policy.check(ip); + expect(decision.allowed).toBe(false); + expect(decision.retryAfterSeconds).toBe(600); + }); + + it("resets failure count after success", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 20, + maxAttemptsPerKey: 20, + windowSeconds: 300, + lockoutSeconds: 600, + }); + const ip = "10.0.0.3"; + + for (let i = 0; i < 15; i++) { + policy.recordFailure(ip); + } + + policy.recordSuccess(ip); + + // After success, counter is reset — 5 more failures should be allowed + for (let i = 0; i < 5; i++) { + policy.recordFailure(ip); + } + expect(policy.check(ip)).toEqual({ allowed: true }); + }); + + it("unlocks after lockout period expires", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 20, + maxAttemptsPerKey: 20, + windowSeconds: 300, + lockoutSeconds: 600, + }); + const ip = "10.0.0.4"; + + for (let i = 0; i < 20; i++) { + policy.recordFailure(ip); + } + + expect(policy.check(ip).allowed).toBe(false); + + // Advance past lockout + vi.advanceTimersByTime(601_000); + expect(policy.check(ip).allowed).toBe(true); + }); + + it("tracks different IPs independently", () => { + const policy = new LoginAbusePolicy({ + maxAttemptsPerIp: 3, + maxAttemptsPerKey: 3, + windowSeconds: 300, + lockoutSeconds: 600, + }); + + const ipA = "10.0.0.10"; + const ipB = "10.0.0.11"; + + for (let i = 0; i < 3; i++) { + policy.recordFailure(ipA); + } + + expect(policy.check(ipA).allowed).toBe(false); + expect(policy.check(ipB).allowed).toBe(true); + }); +}); + +describe("extractClientIp logic (rightmost x-forwarded-for)", () => { + it("takes rightmost IP from x-forwarded-for", () => { + // Simulates: client spoofs leftmost, proxy appends real IP + const forwarded = "spoofed-ip, real-client-ip"; + const ips = forwarded + .split(",") + .map((s) => s.trim()) + .filter(Boolean); + expect(ips[ips.length - 1]).toBe("real-client-ip"); + }); + + it("handles single IP in x-forwarded-for", () => { + const forwarded = "192.168.1.1"; + const ips = forwarded + .split(",") + .map((s) => s.trim()) + .filter(Boolean); + expect(ips[ips.length - 1]).toBe("192.168.1.1"); + }); + + it("prefers x-real-ip over x-forwarded-for", () => { + // The implementation checks x-real-ip first + const realIp = "10.0.0.1"; + const forwarded = "spoofed, 10.0.0.2"; + + // x-real-ip is present and non-empty → use it + const result = realIp.trim() || undefined; + expect(result).toBe("10.0.0.1"); + }); + + it("returns 'unknown' when no headers present", () => { + const realIp: string | null = null; + const forwarded: string | null = null; + + const result = realIp?.trim() || forwarded || "unknown"; + expect(result).toBe("unknown"); + }); +}); diff --git a/tests/security/security-headers-integration.test.ts b/tests/security/security-headers-integration.test.ts new file mode 100644 index 000000000..ae97746b9 --- /dev/null +++ b/tests/security/security-headers-integration.test.ts @@ -0,0 +1,196 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest } from "next/server"; +import { applyCors } from "../../src/app/v1/_lib/cors"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockClearAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + clearAuthCookie: mockClearAuthCookie, + getAuthCookie: mockGetAuthCookie, + withNoStoreHeaders: (response: T): T => { + (response as Response).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (response as Response).headers.set("Pragma", "no-cache"); + return response; + }, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +type LoginPostHandler = (request: NextRequest) => Promise; +type LogoutPostHandler = (request: NextRequest) => Promise; + +function makeLoginRequest(body: unknown): NextRequest { + return new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); +} + +function makeLogoutRequest(): NextRequest { + return new NextRequest("http://localhost/api/auth/logout", { + method: "POST", + }); +} + +function expectSharedSecurityHeaders(response: Response) { + expect(response.headers.get("X-Frame-Options")).toBe("DENY"); + expect(response.headers.get("Referrer-Policy")).toBe("strict-origin-when-cross-origin"); + expect(response.headers.get("X-DNS-Prefetch-Control")).toBe("off"); +} + +const fakeSession = { + user: { + id: 1, + name: "Test User", + description: "desc", + role: "user" as const, + }, + key: { + canLoginWebUi: true, + }, +}; + +describe("security headers auth route integration", () => { + let loginPost: LoginPostHandler; + let logoutPost: LogoutPostHandler; + + beforeEach(async () => { + vi.resetModules(); + vi.clearAllMocks(); + + const t = vi.fn((messageKey: string) => `translated:${messageKey}`); + mockGetTranslations.mockResolvedValue(t); + mockValidateKey.mockResolvedValue(fakeSession); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + mockClearAuthCookie.mockResolvedValue(undefined); + mockGetAuthCookie.mockResolvedValue(undefined); + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + + const loginRoute = await import("../../src/app/api/auth/login/route"); + loginPost = loginRoute.POST; + + const logoutRoute = await import("../../src/app/api/auth/logout/route"); + logoutPost = logoutRoute.POST; + }); + + it("login success response includes security headers", async () => { + const res = await loginPost(makeLoginRequest({ key: "valid-key" })); + + expect(res.status).toBe(200); + expectSharedSecurityHeaders(res); + expect(res.headers.get("X-Content-Type-Options")).toBe("nosniff"); + }); + + it("login error response includes security headers", async () => { + const res = await loginPost(makeLoginRequest({})); + + expect(res.status).toBe(400); + expectSharedSecurityHeaders(res); + expect(res.headers.get("X-Content-Type-Options")).toBe("nosniff"); + }); + + it("logout response includes security headers", async () => { + const res = await logoutPost(makeLogoutRequest()); + + expect(res.status).toBe(200); + expectSharedSecurityHeaders(res); + expect(res.headers.get("X-Content-Type-Options")).toBe("nosniff"); + }); + + it("CSP is applied in report-only mode by default", async () => { + const res = await loginPost(makeLoginRequest({ key: "valid-key" })); + + expect(res.headers.get("Content-Security-Policy-Report-Only")).toContain("default-src 'self'"); + expect(res.headers.get("Content-Security-Policy")).toBeNull(); + }); + + it("HSTS is present when ENABLE_SECURE_COOKIES=true", async () => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + + const res = await loginPost(makeLoginRequest({ key: "valid-key" })); + + expect(res.headers.get("Strict-Transport-Security")).toBe( + "max-age=31536000; includeSubDomains" + ); + }); + + it("HSTS is absent when ENABLE_SECURE_COOKIES=false", async () => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + + const res = await logoutPost(makeLogoutRequest()); + + expect(res.headers.get("Strict-Transport-Security")).toBeNull(); + }); + + it("X-Content-Type-Options is always nosniff", async () => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + const secureRes = await loginPost(makeLoginRequest({ key: "valid-key" })); + + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + const errorRes = await loginPost(makeLoginRequest({})); + const logoutRes = await logoutPost(makeLogoutRequest()); + + expect(secureRes.headers.get("X-Content-Type-Options")).toBe("nosniff"); + expect(errorRes.headers.get("X-Content-Type-Options")).toBe("nosniff"); + expect(logoutRes.headers.get("X-Content-Type-Options")).toBe("nosniff"); + }); + + it("security headers remain compatible with existing CORS headers", async () => { + const res = await loginPost(makeLoginRequest({ key: "valid-key" })); + const corsRes = applyCors(res, { + origin: "https://client.example.com", + requestHeaders: "content-type,x-api-key", + }); + + // Without allowCredentials, origin is NOT reflected — stays as wildcard + expect(corsRes.headers.get("Access-Control-Allow-Origin")).toBe("*"); + expect(corsRes.headers.get("Access-Control-Allow-Credentials")).toBeNull(); + expect(corsRes.headers.get("Access-Control-Allow-Headers")).toBe("content-type,x-api-key"); + expect(corsRes.headers.get("Content-Security-Policy-Report-Only")).toContain( + "default-src 'self'" + ); + expect(corsRes.headers.get("X-Content-Type-Options")).toBe("nosniff"); + }); + + it("CORS reflects origin only when allowCredentials is explicitly set", async () => { + const res = await loginPost(makeLoginRequest({ key: "valid-key" })); + const corsRes = applyCors(res, { + origin: "https://trusted.example.com", + requestHeaders: "content-type", + allowCredentials: true, + }); + + expect(corsRes.headers.get("Access-Control-Allow-Origin")).toBe("https://trusted.example.com"); + expect(corsRes.headers.get("Access-Control-Allow-Credentials")).toBe("true"); + }); +}); diff --git a/tests/security/security-headers.test.ts b/tests/security/security-headers.test.ts new file mode 100644 index 000000000..7647a7294 --- /dev/null +++ b/tests/security/security-headers.test.ts @@ -0,0 +1,111 @@ +import { describe, expect, test } from "vitest"; +import { + buildSecurityHeaders, + DEFAULT_SECURITY_HEADERS_CONFIG, +} from "../../src/lib/security/security-headers"; + +describe("buildSecurityHeaders", () => { + test("默认配置应生成预期安全头", () => { + const headers = buildSecurityHeaders(); + + expect(headers["X-Content-Type-Options"]).toBe("nosniff"); + expect(headers["X-Frame-Options"]).toBe(DEFAULT_SECURITY_HEADERS_CONFIG.frameOptions); + expect(headers["Referrer-Policy"]).toBe("strict-origin-when-cross-origin"); + expect(headers["X-DNS-Prefetch-Control"]).toBe("off"); + expect(headers["Strict-Transport-Security"]).toBeUndefined(); + expect(headers["Content-Security-Policy"]).toBeUndefined(); + expect(headers["Content-Security-Policy-Report-Only"]).toContain("default-src 'self'"); + }); + + test("enableHsts=true 时应包含 HSTS 头", () => { + const headers = buildSecurityHeaders({ enableHsts: true }); + + expect(headers["Strict-Transport-Security"]).toBe( + `max-age=${DEFAULT_SECURITY_HEADERS_CONFIG.hstsMaxAge}; includeSubDomains` + ); + }); + + test("enableHsts=false 时不应包含 HSTS 头", () => { + const headers = buildSecurityHeaders({ enableHsts: false }); + + expect(headers["Strict-Transport-Security"]).toBeUndefined(); + }); + + test("CSP report-only 模式应使用 Report-Only 头", () => { + const headers = buildSecurityHeaders({ cspMode: "report-only" }); + + expect(headers["Content-Security-Policy-Report-Only"]).toContain("default-src 'self'"); + expect(headers["Content-Security-Policy"]).toBeUndefined(); + }); + + test("CSP enforce 模式应使用强制策略头", () => { + const headers = buildSecurityHeaders({ cspMode: "enforce" }); + + expect(headers["Content-Security-Policy"]).toContain("default-src 'self'"); + expect(headers["Content-Security-Policy-Report-Only"]).toBeUndefined(); + }); + + test("CSP disabled 模式不应输出任何 CSP 头", () => { + const headers = buildSecurityHeaders({ cspMode: "disabled" }); + + expect(headers["Content-Security-Policy"]).toBeUndefined(); + expect(headers["Content-Security-Policy-Report-Only"]).toBeUndefined(); + }); + + test("X-Content-Type-Options 始终为 nosniff", () => { + const defaultHeaders = buildSecurityHeaders(); + const disabledCspHeaders = buildSecurityHeaders({ cspMode: "disabled" }); + const enforceCspHeaders = buildSecurityHeaders({ cspMode: "enforce", enableHsts: true }); + + expect(defaultHeaders["X-Content-Type-Options"]).toBe("nosniff"); + expect(disabledCspHeaders["X-Content-Type-Options"]).toBe("nosniff"); + expect(enforceCspHeaders["X-Content-Type-Options"]).toBe("nosniff"); + }); + + test("X-Frame-Options 应与配置一致", () => { + const denyHeaders = buildSecurityHeaders({ frameOptions: "DENY" }); + const sameOriginHeaders = buildSecurityHeaders({ frameOptions: "SAMEORIGIN" }); + + expect(denyHeaders["X-Frame-Options"]).toBe("DENY"); + expect(sameOriginHeaders["X-Frame-Options"]).toBe("SAMEORIGIN"); + }); + + test("cspReportUri with valid URL appends report-uri directive", () => { + const headers = buildSecurityHeaders({ + cspMode: "report-only", + cspReportUri: "https://csp.example.com/report", + }); + + expect(headers["Content-Security-Policy-Report-Only"]).toContain( + "; report-uri https://csp.example.com/report" + ); + }); + + test("cspReportUri with semicolons is rejected to prevent directive injection", () => { + const headers = buildSecurityHeaders({ + cspMode: "enforce", + cspReportUri: "https://evil.com; script-src 'unsafe-eval'", + }); + + expect(headers["Content-Security-Policy"]).not.toContain("report-uri"); + expect(headers["Content-Security-Policy"]).not.toContain("evil.com"); + }); + + test("cspReportUri with non-URL value is rejected", () => { + const headers = buildSecurityHeaders({ + cspMode: "enforce", + cspReportUri: "not a url", + }); + + expect(headers["Content-Security-Policy"]).not.toContain("report-uri"); + }); + + test("cspReportUri with empty string is rejected", () => { + const headers = buildSecurityHeaders({ + cspMode: "enforce", + cspReportUri: "", + }); + + expect(headers["Content-Security-Policy"]).not.toContain("report-uri"); + }); +}); diff --git a/tests/security/session-contract.test.ts b/tests/security/session-contract.test.ts new file mode 100644 index 000000000..f94929736 --- /dev/null +++ b/tests/security/session-contract.test.ts @@ -0,0 +1,112 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +const ORIGINAL_SESSION_TOKEN_MODE = process.env.SESSION_TOKEN_MODE; + +function restoreSessionTokenModeEnv() { + if (ORIGINAL_SESSION_TOKEN_MODE === undefined) { + delete process.env.SESSION_TOKEN_MODE; + return; + } + process.env.SESSION_TOKEN_MODE = ORIGINAL_SESSION_TOKEN_MODE; +} + +describe("session token contract and migration flags", () => { + afterEach(() => { + restoreSessionTokenModeEnv(); + vi.resetModules(); + }); + + it("SESSION_TOKEN_MODE defaults to opaque", async () => { + delete process.env.SESSION_TOKEN_MODE; + + vi.resetModules(); + const { getSessionTokenMode } = await import("@/lib/auth"); + + expect(getSessionTokenMode()).toBe("opaque"); + }); + + it("getSessionTokenMode returns configured mode values", async () => { + const modes = ["legacy", "dual", "opaque"] as const; + + for (const mode of modes) { + process.env.SESSION_TOKEN_MODE = mode; + + vi.resetModules(); + const { getSessionTokenMode } = await import("@/lib/auth"); + + expect(getSessionTokenMode()).toBe(mode); + } + }); + + it("validates OpaqueSessionContract runtime shape strictly", async () => { + vi.resetModules(); + const { isOpaqueSessionContract } = await import("@/lib/auth"); + + const validContract = { + sessionId: "sid_opaque_session_123", + keyFingerprint: "sha256:abc123", + createdAt: 1_700_000_000, + expiresAt: 1_700_000_300, + userId: 42, + userRole: "admin", + }; + + expect(isOpaqueSessionContract(validContract)).toBe(true); + expect( + isOpaqueSessionContract({ + ...validContract, + keyFingerprint: "", + }) + ).toBe(false); + expect( + isOpaqueSessionContract({ + ...validContract, + expiresAt: validContract.createdAt, + }) + ).toBe(false); + expect( + isOpaqueSessionContract({ + ...validContract, + userId: 3.14, + }) + ).toBe(false); + }); + + it("accepts both legacy cookie and opaque session in dual mode", async () => { + process.env.SESSION_TOKEN_MODE = "dual"; + + vi.resetModules(); + const { getSessionTokenMode, getSessionTokenMigrationFlags, isSessionTokenAccepted } = + await import("@/lib/auth"); + + const mode = getSessionTokenMode(); + expect(mode).toBe("dual"); + expect(getSessionTokenMigrationFlags(mode)).toEqual({ + dualReadWindowEnabled: true, + hardCutoverEnabled: false, + emergencyRollbackEnabled: false, + }); + + expect(isSessionTokenAccepted("sk-legacy-cookie", mode)).toBe(true); + expect(isSessionTokenAccepted("sid_opaque_session_cookie", mode)).toBe(true); + }); + + it("accepts only legacy cookie in legacy mode", async () => { + process.env.SESSION_TOKEN_MODE = "legacy"; + + vi.resetModules(); + const { getSessionTokenMode, getSessionTokenMigrationFlags, isSessionTokenAccepted } = + await import("@/lib/auth"); + + const mode = getSessionTokenMode(); + expect(mode).toBe("legacy"); + expect(getSessionTokenMigrationFlags(mode)).toEqual({ + dualReadWindowEnabled: false, + hardCutoverEnabled: false, + emergencyRollbackEnabled: true, + }); + + expect(isSessionTokenAccepted("sk-legacy-cookie", mode)).toBe(true); + expect(isSessionTokenAccepted("sid_opaque_session_cookie", mode)).toBe(false); + }); +}); diff --git a/tests/security/session-cookie-hardening.test.ts b/tests/security/session-cookie-hardening.test.ts new file mode 100644 index 000000000..45dd85149 --- /dev/null +++ b/tests/security/session-cookie-hardening.test.ts @@ -0,0 +1,205 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); +const mockCookieSet = vi.hoisted(() => vi.fn()); +const mockCookies = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockClearAuthCookie = vi.hoisted(() => vi.fn()); + +const realWithNoStoreHeaders = vi.hoisted(() => { + return >(response: T): T => { + response.headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + response.headers.set("Pragma", "no-cache"); + return response; + }; +}); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + clearAuthCookie: mockClearAuthCookie, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + toKeyFingerprint: vi.fn().mockResolvedValue("sha256:mock"), + withNoStoreHeaders: realWithNoStoreHeaders, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: realWithNoStoreHeaders, +})); + +vi.mock("@/lib/config/config", () => ({ config: { auth: { adminToken: "test" } } })); +vi.mock("@/repository/key", () => ({ validateApiKeyAndGetUser: vi.fn() })); + +vi.mock("next/headers", () => ({ + cookies: mockCookies, + headers: vi.fn().mockResolvedValue(new Headers()), +})); + +const EXPECTED_CACHE_CONTROL = "no-store, no-cache, must-revalidate"; +const EXPECTED_PRAGMA = "no-cache"; + +function makeLoginRequest(body: unknown): NextRequest { + return new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); +} + +function makeLogoutRequest(): NextRequest { + return new NextRequest("http://localhost/api/auth/logout", { + method: "POST", + }); +} + +const fakeSession = { + user: { id: 1, name: "Test User", description: "desc", role: "user" as const }, + key: { canLoginWebUi: true }, +}; + +describe("session cookie hardening", () => { + describe("withNoStoreHeaders utility", () => { + it("sets Cache-Control header", () => { + const res = NextResponse.json({ ok: true }); + const result = realWithNoStoreHeaders(res); + expect(result.headers.get("Cache-Control")).toBe(EXPECTED_CACHE_CONTROL); + }); + + it("sets Pragma header", () => { + const res = NextResponse.json({ ok: true }); + const result = realWithNoStoreHeaders(res); + expect(result.headers.get("Pragma")).toBe(EXPECTED_PRAGMA); + }); + + it("returns the same response object", () => { + const res = NextResponse.json({ ok: true }); + const result = realWithNoStoreHeaders(res); + expect(result).toBe(res); + }); + }); + + describe("login route no-store headers", () => { + let POST: (request: NextRequest) => Promise; + + beforeEach(async () => { + vi.clearAllMocks(); + const mockT = vi.fn((key: string) => `translated:${key}`); + mockGetTranslations.mockResolvedValue(mockT); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + + const mod = await import("@/app/api/auth/login/route"); + POST = mod.POST; + }); + + it("success response includes Cache-Control: no-store", async () => { + mockValidateKey.mockResolvedValue(fakeSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + const res = await POST(makeLoginRequest({ key: "valid" })); + + expect(res.status).toBe(200); + expect(res.headers.get("Cache-Control")).toBe(EXPECTED_CACHE_CONTROL); + }); + + it("success response includes Pragma: no-cache", async () => { + mockValidateKey.mockResolvedValue(fakeSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + const res = await POST(makeLoginRequest({ key: "valid" })); + + expect(res.headers.get("Pragma")).toBe(EXPECTED_PRAGMA); + }); + + it("400 error response includes Cache-Control: no-store", async () => { + const res = await POST(makeLoginRequest({})); + + expect(res.status).toBe(400); + expect(res.headers.get("Cache-Control")).toBe(EXPECTED_CACHE_CONTROL); + }); + + it("400 error response includes Pragma: no-cache", async () => { + const res = await POST(makeLoginRequest({})); + + expect(res.headers.get("Pragma")).toBe(EXPECTED_PRAGMA); + }); + + it("401 error response includes Cache-Control: no-store", async () => { + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeLoginRequest({ key: "bad" })); + + expect(res.status).toBe(401); + expect(res.headers.get("Cache-Control")).toBe(EXPECTED_CACHE_CONTROL); + }); + + it("401 error response includes Pragma: no-cache", async () => { + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeLoginRequest({ key: "bad" })); + + expect(res.headers.get("Pragma")).toBe(EXPECTED_PRAGMA); + }); + + it("500 error response includes no-store headers", async () => { + mockValidateKey.mockRejectedValue(new Error("db down")); + + const res = await POST(makeLoginRequest({ key: "any" })); + + expect(res.status).toBe(500); + expect(res.headers.get("Cache-Control")).toBe(EXPECTED_CACHE_CONTROL); + expect(res.headers.get("Pragma")).toBe(EXPECTED_PRAGMA); + }); + }); + + describe("logout route no-store headers", () => { + let POST: (request: NextRequest) => Promise; + + beforeEach(async () => { + vi.clearAllMocks(); + mockClearAuthCookie.mockResolvedValue(undefined); + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + + const mod = await import("@/app/api/auth/logout/route"); + POST = mod.POST; + }); + + it("response includes Cache-Control: no-store", async () => { + const res = await POST(makeLogoutRequest()); + + expect(res.status).toBe(200); + expect(res.headers.get("Cache-Control")).toBe(EXPECTED_CACHE_CONTROL); + }); + + it("response includes Pragma: no-cache", async () => { + const res = await POST(makeLogoutRequest()); + + expect(res.headers.get("Pragma")).toBe(EXPECTED_PRAGMA); + }); + }); +}); diff --git a/tests/security/session-fixation-rotation.test.ts b/tests/security/session-fixation-rotation.test.ts new file mode 100644 index 000000000..a43ceec68 --- /dev/null +++ b/tests/security/session-fixation-rotation.test.ts @@ -0,0 +1,178 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest } from "next/server"; +import type { NextResponse } from "next/server"; + +const { + mockClearAuthCookie, + mockGetAuthCookie, + mockGetSessionTokenMode, + mockRevoke, + mockRotate, + mockRedisSessionStoreCtor, + mockLogger, +} = vi.hoisted(() => { + const mockRevoke = vi.fn(); + const mockRotate = vi.fn(); + + return { + mockClearAuthCookie: vi.fn(), + mockGetAuthCookie: vi.fn(), + mockGetSessionTokenMode: vi.fn(), + mockRevoke, + mockRotate, + mockRedisSessionStoreCtor: vi.fn().mockImplementation(function RedisSessionStoreMock() { + return { + revoke: mockRevoke, + rotate: mockRotate, + }; + }), + mockLogger: { + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), + trace: vi.fn(), + }, + }; +}); + +const realWithNoStoreHeaders = vi.hoisted(() => { + return >(response: T): T => { + response.headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + response.headers.set("Pragma", "no-cache"); + return response; + }; +}); + +vi.mock("@/lib/auth", () => ({ + clearAuthCookie: mockClearAuthCookie, + getAuthCookie: mockGetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + withNoStoreHeaders: realWithNoStoreHeaders, +})); + +vi.mock("@/lib/auth-session-store/redis-session-store", () => ({ + RedisSessionStore: mockRedisSessionStoreCtor, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: vi.fn().mockReturnValue({ ENABLE_SECURE_COOKIES: false }), +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: realWithNoStoreHeaders, +})); + +function makeLogoutRequest(): NextRequest { + return new NextRequest("http://localhost/api/auth/logout", { + method: "POST", + headers: { + "sec-fetch-site": "same-origin", + }, + }); +} + +async function loadLogoutPost(): Promise<(request: NextRequest) => Promise> { + const mod = await import("@/app/api/auth/logout/route"); + return mod.POST; +} + +async function simulatePostLoginSessionRotation( + oldSessionId: string, + rotate: (sessionId: string) => Promise<{ sessionId: string } | null> +): Promise { + const rotated = await rotate(oldSessionId); + return rotated?.sessionId ?? null; +} + +describe("session fixation rotation and logout revocation", () => { + beforeEach(() => { + vi.resetModules(); + vi.clearAllMocks(); + mockRedisSessionStoreCtor.mockImplementation(function RedisSessionStoreMock() { + return { + revoke: mockRevoke, + rotate: mockRotate, + }; + }); + mockClearAuthCookie.mockResolvedValue(undefined); + mockGetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + mockRevoke.mockResolvedValue(true); + mockRotate.mockResolvedValue(null); + }); + + it("legacy mode logout only clears cookie without session store revocation", async () => { + mockGetSessionTokenMode.mockReturnValue("legacy"); + const POST = await loadLogoutPost(); + + const response = await POST(makeLogoutRequest()); + + expect(response.status).toBe(200); + expect(mockRedisSessionStoreCtor).not.toHaveBeenCalled(); + expect(mockRevoke).not.toHaveBeenCalled(); + expect(mockClearAuthCookie).toHaveBeenCalledTimes(1); + }); + + it("dual mode logout revokes session and clears cookie", async () => { + mockGetSessionTokenMode.mockReturnValue("dual"); + mockGetAuthCookie.mockResolvedValue("sid_dual_session"); + const POST = await loadLogoutPost(); + + const response = await POST(makeLogoutRequest()); + + expect(response.status).toBe(200); + expect(mockRedisSessionStoreCtor).toHaveBeenCalledTimes(1); + expect(mockRevoke).toHaveBeenCalledWith("sid_dual_session"); + expect(mockClearAuthCookie).toHaveBeenCalledTimes(1); + }); + + it("opaque mode logout revokes session and clears cookie", async () => { + mockGetSessionTokenMode.mockReturnValue("opaque"); + mockGetAuthCookie.mockResolvedValue("sid_opaque_session"); + const POST = await loadLogoutPost(); + + const response = await POST(makeLogoutRequest()); + + expect(response.status).toBe(200); + expect(mockRedisSessionStoreCtor).toHaveBeenCalledTimes(1); + expect(mockRevoke).toHaveBeenCalledWith("sid_opaque_session"); + expect(mockClearAuthCookie).toHaveBeenCalledTimes(1); + }); + + it("logout still clears cookie when session revocation fails", async () => { + mockGetSessionTokenMode.mockReturnValue("opaque"); + mockGetAuthCookie.mockResolvedValue("sid_revocation_failure"); + mockRevoke.mockRejectedValue(new Error("redis down")); + const POST = await loadLogoutPost(); + + const response = await POST(makeLogoutRequest()); + + expect(response.status).toBe(200); + expect(mockRevoke).toHaveBeenCalledWith("sid_revocation_failure"); + expect(mockClearAuthCookie).toHaveBeenCalledTimes(1); + expect(mockLogger.warn).toHaveBeenCalledTimes(1); + }); + + it("post-login rotation returns a different session id", async () => { + const oldSessionId = "sid_existing_session"; + mockRotate.mockResolvedValue({ + sessionId: "sid_rotated_session", + keyFingerprint: "fp-login", + userId: 7, + userRole: "user", + createdAt: 1_700_000_000_000, + expiresAt: 1_700_000_300_000, + }); + + const rotatedSessionId = await simulatePostLoginSessionRotation(oldSessionId, mockRotate); + + expect(mockRotate).toHaveBeenCalledWith(oldSessionId); + expect(rotatedSessionId).toBe("sid_rotated_session"); + expect(rotatedSessionId).not.toBe(oldSessionId); + }); +}); diff --git a/tests/security/session-login-integration.test.ts b/tests/security/session-login-integration.test.ts new file mode 100644 index 000000000..4c825e248 --- /dev/null +++ b/tests/security/session-login-integration.test.ts @@ -0,0 +1,237 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest } from "next/server"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockToKeyFingerprint = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockCreateSession = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); + +const realWithNoStoreHeaders = vi.hoisted(() => { + return (response: any) => { + response.headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + response.headers.set("Pragma", "no-cache"); + return response; + }; +}); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + toKeyFingerprint: mockToKeyFingerprint, + withNoStoreHeaders: realWithNoStoreHeaders, +})); + +vi.mock("@/lib/auth-session-store/redis-session-store", () => ({ + RedisSessionStore: class { + create = mockCreateSession; + }, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: realWithNoStoreHeaders, +})); + +function makeRequest(body: unknown): NextRequest { + return new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); +} + +const dashboardSession = { + user: { + id: 1, + name: "Test User", + description: "desc", + role: "user" as const, + }, + key: { canLoginWebUi: true }, +}; + +const readonlySession = { + user: { + id: 2, + name: "Readonly User", + description: "readonly", + role: "user" as const, + }, + key: { canLoginWebUi: false }, +}; + +describe("POST /api/auth/login session token mode integration", () => { + let POST: (request: NextRequest) => Promise; + + beforeEach(async () => { + vi.clearAllMocks(); + const mockT = vi.fn((key: string) => `translated:${key}`); + mockGetTranslations.mockResolvedValue(mockT); + + mockValidateKey.mockResolvedValue(dashboardSession); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + mockToKeyFingerprint.mockResolvedValue( + "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + ); + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + mockCreateSession.mockResolvedValue({ + sessionId: "sid_opaque_session_123", + keyFingerprint: "sha256:abcdef", + userId: 1, + userRole: "user", + createdAt: 100, + expiresAt: 200, + }); + + const mod = await import("../../src/app/api/auth/login/route"); + POST = mod.POST; + }); + + it("legacy mode keeps raw key cookie and does not create opaque session", async () => { + mockGetSessionTokenMode.mockReturnValue("legacy"); + + const res = await POST(makeRequest({ key: "legacy-key" })); + const json = await res.json(); + + expect(res.status).toBe(200); + expect(mockSetAuthCookie).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledWith("legacy-key"); + expect(mockCreateSession).not.toHaveBeenCalled(); + expect(json.redirectTo).toBe("/dashboard"); + expect(json.loginType).toBe("dashboard_user"); + }); + + it("dual mode sets legacy cookie and creates opaque session in store", async () => { + mockGetSessionTokenMode.mockReturnValue("dual"); + + const res = await POST(makeRequest({ key: "dual-key" })); + const json = await res.json(); + + expect(res.status).toBe(200); + expect(mockSetAuthCookie).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledWith("dual-key"); + expect(mockCreateSession).toHaveBeenCalledTimes(1); + expect(mockCreateSession).toHaveBeenCalledWith( + expect.objectContaining({ + userId: 1, + userRole: "user", + keyFingerprint: expect.stringMatching(/^sha256:[a-f0-9]{64}$/), + }) + ); + expect(json.redirectTo).toBe("/dashboard"); + expect(json.loginType).toBe("dashboard_user"); + }); + + it("opaque mode writes sessionId cookie instead of raw key", async () => { + mockGetSessionTokenMode.mockReturnValue("opaque"); + mockCreateSession.mockResolvedValue({ + sessionId: "sid_opaque_session_cookie", + keyFingerprint: "sha256:abcdef", + userId: 1, + userRole: "user", + createdAt: 100, + expiresAt: 200, + }); + + const res = await POST(makeRequest({ key: "opaque-key" })); + const json = await res.json(); + + expect(res.status).toBe(200); + expect(mockCreateSession).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledWith("sid_opaque_session_cookie"); + expect(mockSetAuthCookie).not.toHaveBeenCalledWith("opaque-key"); + expect(json.redirectTo).toBe("/dashboard"); + expect(json.loginType).toBe("dashboard_user"); + }); + + it("dual mode remains successful when opaque session creation fails", async () => { + mockGetSessionTokenMode.mockReturnValue("dual"); + mockCreateSession.mockRejectedValue(new Error("redis unavailable")); + + const res = await POST(makeRequest({ key: "dual-fallback-key" })); + const json = await res.json(); + + expect(res.status).toBe(200); + expect(json.ok).toBe(true); + expect(mockSetAuthCookie).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledWith("dual-fallback-key"); + expect(mockCreateSession).toHaveBeenCalledTimes(1); + expect(mockLogger.warn).toHaveBeenCalledWith( + "Failed to create opaque session in dual mode", + expect.objectContaining({ + error: expect.stringContaining("redis unavailable"), + }) + ); + }); + + it("all modes preserve readonly redirect semantics", async () => { + mockValidateKey.mockResolvedValue(readonlySession); + mockGetLoginRedirectTarget.mockReturnValue("/my-usage"); + + const modes = ["legacy", "dual", "opaque"] as const; + + for (const mode of modes) { + vi.clearAllMocks(); + mockGetSessionTokenMode.mockReturnValue(mode); + mockValidateKey.mockResolvedValue(readonlySession); + mockGetLoginRedirectTarget.mockReturnValue("/my-usage"); + mockSetAuthCookie.mockResolvedValue(undefined); + mockCreateSession.mockResolvedValue({ + sessionId: `sid_${mode}_session`, + keyFingerprint: "sha256:abcdef", + userId: 2, + userRole: "user", + createdAt: 100, + expiresAt: 200, + }); + + const res = await POST(makeRequest({ key: `${mode}-readonly-key` })); + const json = await res.json(); + + expect(res.status).toBe(200); + expect(json.redirectTo).toBe("/my-usage"); + expect(json.loginType).toBe("readonly_user"); + + if (mode === "legacy") { + expect(mockCreateSession).not.toHaveBeenCalled(); + expect(mockSetAuthCookie).toHaveBeenCalledWith("legacy-readonly-key"); + } + + if (mode === "dual") { + expect(mockCreateSession).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledWith("dual-readonly-key"); + } + + if (mode === "opaque") { + expect(mockCreateSession).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledWith("sid_opaque_session"); + } + } + }); +}); diff --git a/tests/security/session-store.test.ts b/tests/security/session-store.test.ts new file mode 100644 index 000000000..bba336877 --- /dev/null +++ b/tests/security/session-store.test.ts @@ -0,0 +1,262 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +const { getRedisClientMock, loggerMock } = vi.hoisted(() => ({ + getRedisClientMock: vi.fn(), + loggerMock: { + error: vi.fn(), + warn: vi.fn(), + info: vi.fn(), + debug: vi.fn(), + trace: vi.fn(), + }, +})); + +vi.mock("@/lib/redis", () => ({ + getRedisClient: getRedisClientMock, +})); + +vi.mock("@/lib/logger", () => ({ + logger: loggerMock, +})); + +class FakeRedis { + status: "ready" | "end" = "ready"; + readonly store = new Map(); + readonly ttlByKey = new Map(); + + throwOnGet = false; + throwOnSetex = false; + throwOnDel = false; + + readonly get = vi.fn(async (key: string) => { + if (this.throwOnGet) throw new Error("redis get failed"); + return this.store.get(key) ?? null; + }); + + readonly setex = vi.fn(async (key: string, ttlSeconds: number, value: string) => { + if (this.throwOnSetex) throw new Error("redis setex failed"); + this.store.set(key, value); + this.ttlByKey.set(key, ttlSeconds); + return "OK"; + }); + + readonly del = vi.fn(async (key: string) => { + if (this.throwOnDel) throw new Error("redis del failed"); + const existed = this.store.delete(key); + this.ttlByKey.delete(key); + return existed ? 1 : 0; + }); +} + +describe("RedisSessionStore", () => { + let redis: FakeRedis; + + beforeEach(() => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-02-18T10:00:00.000Z")); + vi.clearAllMocks(); + + redis = new FakeRedis(); + getRedisClientMock.mockReturnValue(redis); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("create() returns session data with generated sessionId", async () => { + const { DEFAULT_SESSION_TTL } = await import("@/lib/auth-session-store"); + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const store = new RedisSessionStore(); + const created = await store.create({ keyFingerprint: "fp-1", userId: 101, userRole: "user" }); + + expect(created.sessionId).toMatch(/^sid_[0-9a-f-]{36}$/i); + expect(created.keyFingerprint).toBe("fp-1"); + expect(created.userId).toBe(101); + expect(created.userRole).toBe("user"); + expect(created.createdAt).toBe(new Date("2026-02-18T10:00:00.000Z").getTime()); + expect(created.expiresAt).toBe(created.createdAt + DEFAULT_SESSION_TTL * 1000); + }); + + it("read() returns data for existing session", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const session = { + sessionId: "6b5097ff-a11e-4425-aad0-f57f7d2206fc", + keyFingerprint: "fp-existing", + userId: 7, + userRole: "admin", + createdAt: 1_700_000_000_000, + expiresAt: 1_700_000_360_000, + }; + redis.store.set(`cch:session:${session.sessionId}`, JSON.stringify(session)); + + const store = new RedisSessionStore(); + const found = await store.read(session.sessionId); + + expect(found).toEqual(session); + }); + + it("read() returns null for non-existent session", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const store = new RedisSessionStore(); + const found = await store.read("missing-session"); + + expect(found).toBeNull(); + }); + + it("read() returns null when Redis read fails", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + redis.throwOnGet = true; + const store = new RedisSessionStore(); + const found = await store.read("any-session"); + + expect(found).toBeNull(); + expect(loggerMock.error).toHaveBeenCalled(); + }); + + it("revoke() deletes session", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const sessionId = "f327f4f4-c95f-40ab-a017-af714df7a3f8"; + redis.store.set(`cch:session:${sessionId}`, JSON.stringify({ sessionId })); + + const store = new RedisSessionStore(); + const revoked = await store.revoke(sessionId); + + expect(revoked).toBe(true); + expect(redis.store.has(`cch:session:${sessionId}`)).toBe(false); + }); + + it("rotate() creates new session and revokes old session", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const oldSession = { + sessionId: "e7f7bf87-c3b9-4525-ac0c-c2cf7cd5006b", + keyFingerprint: "fp-rotate", + userId: 18, + userRole: "user", + createdAt: Date.now() - 10_000, + expiresAt: Date.now() + 120_000, + }; + redis.store.set(`cch:session:${oldSession.sessionId}`, JSON.stringify(oldSession)); + + const store = new RedisSessionStore(); + const rotated = await store.rotate(oldSession.sessionId); + + expect(rotated).not.toBeNull(); + expect(rotated?.sessionId).not.toBe(oldSession.sessionId); + expect(rotated?.keyFingerprint).toBe(oldSession.keyFingerprint); + expect(rotated?.userId).toBe(oldSession.userId); + expect(rotated?.userRole).toBe(oldSession.userRole); + expect(redis.store.has(`cch:session:${oldSession.sessionId}`)).toBe(false); + expect(rotated ? redis.store.has(`cch:session:${rotated.sessionId}`) : false).toBe(true); + }); + + it("create() applies TTL and stores expiresAt deterministically", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const store = new RedisSessionStore(); + const created = await store.create( + { keyFingerprint: "fp-ttl", userId: 9, userRole: "user" }, + 120 + ); + + const key = `cch:session:${created.sessionId}`; + expect(redis.ttlByKey.get(key)).toBe(120); + expect(created.expiresAt - created.createdAt).toBe(120_000); + }); + + it("create() throws when Redis setex fails", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + redis.throwOnSetex = true; + const store = new RedisSessionStore(); + + await expect( + store.create({ keyFingerprint: "fp-fail", userId: 3, userRole: "user" }) + ).rejects.toThrow("redis setex failed"); + expect(loggerMock.error).toHaveBeenCalled(); + }); + + it("create() throws when Redis is not ready", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + redis.status = "end"; + const store = new RedisSessionStore(); + + await expect( + store.create({ keyFingerprint: "fp-noredis", userId: 4, userRole: "user" }) + ).rejects.toThrow("Redis not ready"); + }); + + it("rotate() returns null when Redis setex fails during create", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const oldSession = { + sessionId: "2a036ab4-902a-4f31-a782-ec18344e17b9", + keyFingerprint: "fp-failure", + userId: 3, + userRole: "user", + createdAt: Date.now(), + expiresAt: Date.now() + 60_000, + }; + redis.store.set(`cch:session:${oldSession.sessionId}`, JSON.stringify(oldSession)); + redis.throwOnSetex = true; + + const store = new RedisSessionStore(); + const rotated = await store.rotate(oldSession.sessionId); + + expect(rotated).toBeNull(); + expect(redis.store.has(`cch:session:${oldSession.sessionId}`)).toBe(true); + expect(loggerMock.error).toHaveBeenCalled(); + }); + + it("rotate() keeps new session when old session revocation fails", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const oldSession = { + sessionId: "aaa-old-session", + keyFingerprint: "fp-revoke-fail", + userId: 5, + userRole: "user", + createdAt: Date.now() - 10_000, + expiresAt: Date.now() + 120_000, + }; + redis.store.set(`cch:session:${oldSession.sessionId}`, JSON.stringify(oldSession)); + redis.throwOnDel = true; + + const store = new RedisSessionStore(); + const rotated = await store.rotate(oldSession.sessionId); + + expect(rotated).not.toBeNull(); + expect(rotated?.keyFingerprint).toBe(oldSession.keyFingerprint); + expect(loggerMock.warn).toHaveBeenCalled(); + }); + + it("rotate() returns null for already-expired session", async () => { + const { RedisSessionStore } = await import("@/lib/auth-session-store/redis-session-store"); + + const expiredSession = { + sessionId: "bbb-expired-session", + keyFingerprint: "fp-expired", + userId: 6, + userRole: "user", + createdAt: Date.now() - 120_000, + expiresAt: Date.now() - 1_000, + }; + redis.store.set(`cch:session:${expiredSession.sessionId}`, JSON.stringify(expiredSession)); + + const store = new RedisSessionStore(); + const rotated = await store.rotate(expiredSession.sessionId); + + expect(rotated).toBeNull(); + expect(loggerMock.warn).toHaveBeenCalledWith( + "[AuthSessionStore] Cannot rotate expired session", + expect.objectContaining({ sessionId: expiredSession.sessionId }) + ); + }); +}); diff --git a/tests/unit/actions/provider-undo-delete.test.ts b/tests/unit/actions/provider-undo-delete.test.ts new file mode 100644 index 000000000..6ab0f21d5 --- /dev/null +++ b/tests/unit/actions/provider-undo-delete.test.ts @@ -0,0 +1,253 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "../../../src/lib/provider-batch-patch-error-codes"; + +const getSessionMock = vi.fn(); +const deleteProvidersBatchMock = vi.fn(); +const restoreProvidersBatchMock = vi.fn(); +const publishCacheInvalidationMock = vi.fn(); +const clearProviderStateMock = vi.fn(); +const clearConfigCacheMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + deleteProvidersBatch: deleteProvidersBatchMock, + findAllProvidersFresh: vi.fn(), + updateProvidersBatch: vi.fn(), +})); + +vi.mock("@/repository", () => ({ + restoreProvidersBatch: restoreProvidersBatchMock, +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: publishCacheInvalidationMock, +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + clearProviderState: clearProviderStateMock, + clearConfigCache: clearConfigCacheMock, + resetCircuit: vi.fn(), + getAllHealthStatusAsync: vi.fn(), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +describe("Provider Delete Undo Actions", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + deleteProvidersBatchMock.mockResolvedValue(2); + restoreProvidersBatchMock.mockResolvedValue(2); + publishCacheInvalidationMock.mockResolvedValue(undefined); + clearProviderStateMock.mockReturnValue(undefined); + clearConfigCacheMock.mockReturnValue(undefined); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("batchDeleteProviders should return undoToken and operationId", async () => { + const { batchDeleteProviders } = await import("../../../src/actions/providers"); + const result = await batchDeleteProviders({ providerIds: [3, 1, 3] }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(deleteProvidersBatchMock).toHaveBeenCalledWith([1, 3]); + expect(result.data.deletedCount).toBe(2); + expect(result.data.undoToken).toMatch(/^provider_patch_undo_/); + expect(result.data.operationId).toMatch(/^provider_patch_apply_/); + }); + + it("batchDeleteProviders should return repository errors", async () => { + deleteProvidersBatchMock.mockRejectedValueOnce(new Error("delete failed")); + + const { batchDeleteProviders } = await import("../../../src/actions/providers"); + const result = await batchDeleteProviders({ providerIds: [7] }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error).toBe("delete failed"); + }); + + it("batchDeleteProviders should reject non-admin session", async () => { + getSessionMock.mockResolvedValueOnce({ user: { id: 3, role: "user" } }); + + const { batchDeleteProviders } = await import("../../../src/actions/providers"); + const result = await batchDeleteProviders({ providerIds: [1] }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error).toBe("无权限执行此操作"); + expect(deleteProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("batchDeleteProviders should reject empty provider list", async () => { + const { batchDeleteProviders } = await import("../../../src/actions/providers"); + const result = await batchDeleteProviders({ providerIds: [] }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error).toBe("请选择要删除的供应商"); + expect(deleteProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("batchDeleteProviders should reject provider lists over max size", async () => { + const { batchDeleteProviders } = await import("../../../src/actions/providers"); + const result = await batchDeleteProviders({ + providerIds: Array.from({ length: 501 }, (_, index) => index + 1), + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error).toContain("单次批量操作最多支持"); + expect(deleteProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderDelete should restore providers by snapshot", async () => { + const { batchDeleteProviders, undoProviderDelete } = await import( + "../../../src/actions/providers" + ); + + const deleted = await batchDeleteProviders({ providerIds: [2, 4] }); + if (!deleted.ok) throw new Error(`Delete should succeed: ${deleted.error}`); + + restoreProvidersBatchMock.mockClear(); + publishCacheInvalidationMock.mockClear(); + clearProviderStateMock.mockClear(); + clearConfigCacheMock.mockClear(); + + const undone = await undoProviderDelete({ + undoToken: deleted.data.undoToken, + operationId: deleted.data.operationId, + }); + + expect(undone.ok).toBe(true); + if (!undone.ok) return; + + expect(restoreProvidersBatchMock).toHaveBeenCalledWith([2, 4]); + expect(undone.data.operationId).toBe(deleted.data.operationId); + expect(undone.data.restoredCount).toBe(2); + expect(clearProviderStateMock).toHaveBeenCalledTimes(2); + expect(clearConfigCacheMock).toHaveBeenCalledTimes(2); + expect(publishCacheInvalidationMock).toHaveBeenCalledTimes(1); + }); + + it("undoProviderDelete should expire after 61 seconds", async () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-02-19T00:00:00.000Z")); + + const { batchDeleteProviders, undoProviderDelete } = await import( + "../../../src/actions/providers" + ); + + const deleted = await batchDeleteProviders({ providerIds: [9] }); + if (!deleted.ok) throw new Error(`Delete should succeed: ${deleted.error}`); + + restoreProvidersBatchMock.mockClear(); + vi.advanceTimersByTime(61_000); + + const undone = await undoProviderDelete({ + undoToken: deleted.data.undoToken, + operationId: deleted.data.operationId, + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED); + expect(restoreProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderDelete should reject mismatched operation id", async () => { + const { batchDeleteProviders, undoProviderDelete } = await import( + "../../../src/actions/providers" + ); + + const deleted = await batchDeleteProviders({ providerIds: [10, 11] }); + if (!deleted.ok) throw new Error(`Delete should succeed: ${deleted.error}`); + + restoreProvidersBatchMock.mockClear(); + + const undone = await undoProviderDelete({ + undoToken: deleted.data.undoToken, + operationId: `${deleted.data.operationId}-mismatch`, + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_CONFLICT); + expect(restoreProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderDelete should reject invalid payload", async () => { + const { undoProviderDelete } = await import("../../../src/actions/providers"); + + const undone = await undoProviderDelete({ + undoToken: "", + operationId: "provider_patch_apply_x", + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.errorCode).toBeDefined(); + expect(restoreProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderDelete should reject non-admin session", async () => { + getSessionMock.mockResolvedValueOnce({ user: { id: 2, role: "user" } }); + + const { undoProviderDelete } = await import("../../../src/actions/providers"); + + const undone = await undoProviderDelete({ + undoToken: "provider_patch_undo_x", + operationId: "provider_patch_apply_x", + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.error).toBe("无权限执行此操作"); + expect(restoreProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderDelete should return repository errors when restore fails", async () => { + const { batchDeleteProviders, undoProviderDelete } = await import( + "../../../src/actions/providers" + ); + + const deleted = await batchDeleteProviders({ providerIds: [12] }); + if (!deleted.ok) throw new Error(`Delete should succeed: ${deleted.error}`); + + restoreProvidersBatchMock.mockRejectedValueOnce(new Error("restore failed")); + + const undone = await undoProviderDelete({ + undoToken: deleted.data.undoToken, + operationId: deleted.data.operationId, + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.error).toBe("restore failed"); + }); +}); diff --git a/tests/unit/actions/provider-undo-edit.test.ts b/tests/unit/actions/provider-undo-edit.test.ts new file mode 100644 index 000000000..4a0466346 --- /dev/null +++ b/tests/unit/actions/provider-undo-edit.test.ts @@ -0,0 +1,396 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "../../../src/lib/provider-batch-patch-error-codes"; + +const getSessionMock = vi.fn(); +const findProviderByIdMock = vi.fn(); +const updateProviderMock = vi.fn(); +const updateProvidersBatchMock = vi.fn(); +const publishCacheInvalidationMock = vi.fn(); +const clearProviderStateMock = vi.fn(); +const clearConfigCacheMock = vi.fn(); +const saveProviderCircuitConfigMock = vi.fn(); +const deleteProviderCircuitConfigMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + findProviderById: findProviderByIdMock, + findAllProvidersFresh: vi.fn(), + updateProvider: updateProviderMock, + updateProvidersBatch: updateProvidersBatchMock, + deleteProvidersBatch: vi.fn(), +})); + +vi.mock("@/repository", () => ({ + restoreProvidersBatch: vi.fn(), +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: publishCacheInvalidationMock, +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + clearProviderState: clearProviderStateMock, + clearConfigCache: clearConfigCacheMock, + resetCircuit: vi.fn(), + getAllHealthStatusAsync: vi.fn(), +})); + +vi.mock("@/lib/redis/circuit-breaker-config", () => ({ + saveProviderCircuitConfig: saveProviderCircuitConfigMock, + deleteProviderCircuitConfig: deleteProviderCircuitConfigMock, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +function makeProvider(id: number, overrides: Record = {}) { + return { + id, + name: `Provider-${id}`, + url: "https://api.example.com/v1", + key: "sk-test", + providerVendorId: null, + isEnabled: true, + weight: 100, + priority: 1, + groupPriorities: null, + costMultiplier: 1.0, + groupTag: null, + providerType: "claude", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: null, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 10000, + requestTimeoutNonStreamingMs: 600000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: new Date("2025-01-01"), + updatedAt: new Date("2025-01-01"), + deletedAt: null, + ...overrides, + }; +} + +describe("Provider Single Edit Undo Actions", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + findProviderByIdMock.mockResolvedValue(makeProvider(1, { name: "Before Name", key: "sk-old" })); + updateProviderMock.mockResolvedValue(makeProvider(1, { name: "After Name", key: "sk-new" })); + updateProvidersBatchMock.mockResolvedValue(1); + publishCacheInvalidationMock.mockResolvedValue(undefined); + clearProviderStateMock.mockReturnValue(undefined); + clearConfigCacheMock.mockReturnValue(undefined); + saveProviderCircuitConfigMock.mockResolvedValue(undefined); + deleteProviderCircuitConfigMock.mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("editProvider should return undoToken and operationId", async () => { + const { editProvider } = await import("../../../src/actions/providers"); + + const result = await editProvider(1, { name: "After Name" }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.undoToken).toMatch(/^provider_patch_undo_/); + expect(result.data.operationId).toMatch(/^provider_patch_apply_/); + expect(findProviderByIdMock).toHaveBeenCalledWith(1); + expect(updateProviderMock).toHaveBeenCalledWith( + 1, + expect.objectContaining({ + name: "After Name", + }) + ); + }); + + it("editProvider should reject when provider is missing before update", async () => { + findProviderByIdMock.mockResolvedValueOnce(null); + + const { editProvider } = await import("../../../src/actions/providers"); + const result = await editProvider(999, { name: "After Name" }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error).toBe("供应商不存在"); + expect(updateProviderMock).not.toHaveBeenCalled(); + }); + + it("editProvider should reject when repository update returns null", async () => { + updateProviderMock.mockResolvedValueOnce(null); + + const { editProvider } = await import("../../../src/actions/providers"); + const result = await editProvider(1, { name: "After Name" }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error).toBe("供应商不存在"); + }); + + it("editProvider should continue when circuit config sync fails", async () => { + updateProviderMock.mockResolvedValueOnce( + makeProvider(1, { + circuitBreakerFailureThreshold: 8, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + }) + ); + saveProviderCircuitConfigMock.mockRejectedValueOnce(new Error("redis down")); + + const { editProvider } = await import("../../../src/actions/providers"); + const result = await editProvider(1, { + name: "After Name", + circuit_breaker_failure_threshold: 8, + }); + + expect(result.ok).toBe(true); + expect(saveProviderCircuitConfigMock).toHaveBeenCalledWith( + 1, + expect.objectContaining({ + failureThreshold: 8, + }) + ); + expect(clearConfigCacheMock).not.toHaveBeenCalled(); + }); + + it("undoProviderPatch should revert a single edit", async () => { + const { editProvider, undoProviderPatch } = await import("../../../src/actions/providers"); + + const edited = await editProvider(1, { name: "After Name" }); + if (!edited.ok) throw new Error(`Edit should succeed: ${edited.error}`); + + updateProvidersBatchMock.mockClear(); + publishCacheInvalidationMock.mockClear(); + + const undone = await undoProviderPatch({ + undoToken: edited.data.undoToken, + operationId: edited.data.operationId, + }); + + expect(undone.ok).toBe(true); + if (!undone.ok) return; + + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ + name: "Before Name", + }) + ); + expect(undone.data.revertedCount).toBe(1); + expect(publishCacheInvalidationMock).toHaveBeenCalledTimes(1); + }); + + it("undoProviderPatch should not include key field in preimage", async () => { + findProviderByIdMock.mockResolvedValueOnce(makeProvider(1, { key: "sk-before" })); + updateProviderMock.mockResolvedValueOnce(makeProvider(1, { key: "sk-after" })); + + const { editProvider, undoProviderPatch } = await import("../../../src/actions/providers"); + + const edited = await editProvider(1, { key: "sk-after" }); + if (!edited.ok) throw new Error(`Edit should succeed: ${edited.error}`); + + updateProvidersBatchMock.mockClear(); + + const undone = await undoProviderPatch({ + undoToken: edited.data.undoToken, + operationId: edited.data.operationId, + }); + + expect(undone.ok).toBe(true); + if (!undone.ok) return; + + expect(undone.data.revertedCount).toBe(0); + expect(updateProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderPatch should skip unchanged values in single-edit preimage", async () => { + findProviderByIdMock.mockResolvedValueOnce(makeProvider(1, { name: "Stable Name" })); + updateProviderMock.mockResolvedValueOnce(makeProvider(1, { name: "Stable Name" })); + + const { editProvider, undoProviderPatch } = await import("../../../src/actions/providers"); + + const edited = await editProvider(1, { name: "Stable Name" }); + if (!edited.ok) throw new Error(`Edit should succeed: ${edited.error}`); + + updateProvidersBatchMock.mockClear(); + publishCacheInvalidationMock.mockClear(); + + const undone = await undoProviderPatch({ + undoToken: edited.data.undoToken, + operationId: edited.data.operationId, + }); + + expect(undone.ok).toBe(true); + if (!undone.ok) return; + + expect(undone.data.revertedCount).toBe(0); + expect(updateProvidersBatchMock).not.toHaveBeenCalled(); + expect(publishCacheInvalidationMock).not.toHaveBeenCalled(); + }); + + it("undoProviderPatch should stringify numeric costMultiplier on revert", async () => { + findProviderByIdMock.mockResolvedValueOnce(makeProvider(1, { costMultiplier: 1.25 })); + updateProviderMock.mockResolvedValueOnce(makeProvider(1, { costMultiplier: 2.5 })); + + const { editProvider, undoProviderPatch } = await import("../../../src/actions/providers"); + + const edited = await editProvider(1, { cost_multiplier: 2.5 }); + if (!edited.ok) throw new Error(`Edit should succeed: ${edited.error}`); + + updateProvidersBatchMock.mockClear(); + + const undone = await undoProviderPatch({ + undoToken: edited.data.undoToken, + operationId: edited.data.operationId, + }); + + expect(undone.ok).toBe(true); + if (!undone.ok) return; + + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ costMultiplier: "1.25" }) + ); + }); + + it("undoProviderPatch should expire after patch undo TTL", async () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-02-19T00:00:00.000Z")); + + const { editProvider, undoProviderPatch } = await import("../../../src/actions/providers"); + + const edited = await editProvider(1, { name: "After Name" }); + if (!edited.ok) throw new Error(`Edit should succeed: ${edited.error}`); + + vi.advanceTimersByTime(10_001); + + const undone = await undoProviderPatch({ + undoToken: edited.data.undoToken, + operationId: edited.data.operationId, + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED); + }); + + it("undoProviderPatch should reject mismatched operation id", async () => { + const { editProvider, undoProviderPatch } = await import("../../../src/actions/providers"); + + const edited = await editProvider(1, { name: "After Name" }); + if (!edited.ok) throw new Error(`Edit should succeed: ${edited.error}`); + + const undone = await undoProviderPatch({ + undoToken: edited.data.undoToken, + operationId: `${edited.data.operationId}-mismatch`, + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_CONFLICT); + expect(updateProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderPatch should reject invalid payload", async () => { + const { undoProviderPatch } = await import("../../../src/actions/providers"); + + const undone = await undoProviderPatch({ + undoToken: "", + operationId: "provider_patch_apply_x", + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.errorCode).toBeDefined(); + expect(updateProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderPatch should reject non-admin session", async () => { + getSessionMock.mockResolvedValueOnce({ user: { id: 2, role: "user" } }); + + const { undoProviderPatch } = await import("../../../src/actions/providers"); + + const undone = await undoProviderPatch({ + undoToken: "provider_patch_undo_x", + operationId: "provider_patch_apply_x", + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.error).toBe("无权限执行此操作"); + expect(updateProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("undoProviderPatch should return repository errors when revert update fails", async () => { + const { editProvider, undoProviderPatch } = await import("../../../src/actions/providers"); + + const edited = await editProvider(1, { name: "After Name" }); + if (!edited.ok) throw new Error(`Edit should succeed: ${edited.error}`); + + updateProvidersBatchMock.mockRejectedValueOnce(new Error("undo write failed")); + + const undone = await undoProviderPatch({ + undoToken: edited.data.undoToken, + operationId: edited.data.operationId, + }); + + expect(undone.ok).toBe(false); + if (undone.ok) return; + + expect(undone.error).toBe("undo write failed"); + }); +}); diff --git a/tests/unit/actions/providers-apply-engine.test.ts b/tests/unit/actions/providers-apply-engine.test.ts new file mode 100644 index 000000000..559f250c9 --- /dev/null +++ b/tests/unit/actions/providers-apply-engine.test.ts @@ -0,0 +1,425 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; + +const getSessionMock = vi.fn(); +const findAllProvidersFreshMock = vi.fn(); +const updateProvidersBatchMock = vi.fn(); +const publishCacheInvalidationMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + findAllProvidersFresh: findAllProvidersFreshMock, + updateProvidersBatch: updateProvidersBatchMock, + deleteProvidersBatch: vi.fn(), +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: publishCacheInvalidationMock, +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + clearProviderState: vi.fn(), + clearConfigCache: vi.fn(), + resetCircuit: vi.fn(), + getAllHealthStatusAsync: vi.fn(), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +function makeProvider(id: number, overrides: Record = {}) { + return { + id, + name: `Provider-${id}`, + url: "https://api.example.com/v1", + key: "sk-test", + providerVendorId: null, + isEnabled: true, + weight: 100, + priority: 1, + groupPriorities: null, + costMultiplier: 1.0, + groupTag: null, + providerType: "claude", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: null, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 10000, + requestTimeoutNonStreamingMs: 600000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: new Date("2025-01-01"), + updatedAt: new Date("2025-01-01"), + deletedAt: null, + ...overrides, + }; +} + +describe("Apply Provider Batch Patch Engine", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + findAllProvidersFreshMock.mockResolvedValue([]); + updateProvidersBatchMock.mockResolvedValue(0); + publishCacheInvalidationMock.mockResolvedValue(undefined); + }); + + /** Helper: create preview then apply with optional overrides */ + async function setupPreviewAndApply( + providerIds: number[], + patch: Record, + applyOverrides: Record = {} + ) { + const { previewProviderBatchPatch, applyProviderBatchPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ providerIds, patch }); + if (!preview.ok) throw new Error(`Preview failed: ${preview.error}`); + + const applyInput = { + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds, + patch, + ...applyOverrides, + }; + + const apply = await applyProviderBatchPatch(applyInput); + return { preview, apply, applyProviderBatchPatch }; + } + + it("should call updateProvidersBatch with correct IDs and updates", async () => { + const providers = [makeProvider(1, { groupTag: "old" }), makeProvider(2, { groupTag: "old" })]; + findAllProvidersFreshMock.mockResolvedValue(providers); + updateProvidersBatchMock.mockResolvedValue(2); + + const { apply } = await setupPreviewAndApply([1, 2], { group_tag: { set: "new-group" } }); + + expect(apply.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledOnce(); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1, 2], + expect.objectContaining({ groupTag: "new-group" }) + ); + }); + + it("should publish cache invalidation after successful write", async () => { + findAllProvidersFreshMock.mockResolvedValue([makeProvider(1)]); + updateProvidersBatchMock.mockResolvedValue(1); + + const { apply } = await setupPreviewAndApply([1], { is_enabled: { set: false } }); + + expect(apply.ok).toBe(true); + expect(publishCacheInvalidationMock).toHaveBeenCalledOnce(); + }); + + it("should fetch providers for preimage during apply", async () => { + const providers = [ + makeProvider(1, { groupTag: "alpha", priority: 5 }), + makeProvider(2, { groupTag: "beta", priority: 10 }), + ]; + findAllProvidersFreshMock.mockResolvedValue(providers); + updateProvidersBatchMock.mockResolvedValue(2); + + const { apply } = await setupPreviewAndApply([1, 2], { group_tag: { set: "gamma" } }); + + expect(apply.ok).toBe(true); + // preview calls findAllProvidersFresh once, apply calls it once more + expect(findAllProvidersFreshMock).toHaveBeenCalledTimes(2); + }); + + it("should only apply to non-excluded providers with excludeProviderIds", async () => { + const providers = [ + makeProvider(1, { groupTag: "a" }), + makeProvider(2, { groupTag: "b" }), + makeProvider(3, { groupTag: "c" }), + ]; + findAllProvidersFreshMock.mockResolvedValue(providers); + updateProvidersBatchMock.mockResolvedValue(2); + + const { apply } = await setupPreviewAndApply( + [1, 2, 3], + { group_tag: { set: "unified" } }, + { excludeProviderIds: [2] } + ); + + expect(apply.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1, 3], + expect.objectContaining({ groupTag: "unified" }) + ); + }); + + it("should return NOTHING_TO_APPLY when all providers are excluded", async () => { + findAllProvidersFreshMock.mockResolvedValue([makeProvider(1), makeProvider(2)]); + + const { apply } = await setupPreviewAndApply( + [1, 2], + { group_tag: { set: "x" } }, + { excludeProviderIds: [1, 2] } + ); + + expect(apply.ok).toBe(false); + if (apply.ok) return; + expect(apply.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.NOTHING_TO_APPLY); + expect(updateProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("should set updatedCount from updateProvidersBatch return value", async () => { + findAllProvidersFreshMock.mockResolvedValue([ + makeProvider(1), + makeProvider(2), + makeProvider(3), + ]); + updateProvidersBatchMock.mockResolvedValue(3); + + const { apply } = await setupPreviewAndApply([1, 2, 3], { weight: { set: 50 } }); + + expect(apply.ok).toBe(true); + if (!apply.ok) return; + expect(apply.data.updatedCount).toBe(3); + }); + + it("should reflect exclusions in updatedCount", async () => { + findAllProvidersFreshMock.mockResolvedValue([ + makeProvider(1), + makeProvider(2), + makeProvider(3), + ]); + updateProvidersBatchMock.mockResolvedValue(2); + + const { apply } = await setupPreviewAndApply( + [1, 2, 3], + { weight: { set: 50 } }, + { excludeProviderIds: [3] } + ); + + expect(apply.ok).toBe(true); + if (!apply.ok) return; + expect(apply.data.updatedCount).toBe(2); + }); + + it("should return PREVIEW_EXPIRED for unknown preview token", async () => { + const { applyProviderBatchPatch } = await import("@/actions/providers"); + + const result = await applyProviderBatchPatch({ + previewToken: "provider_patch_preview_nonexistent", + previewRevision: "rev", + providerIds: [1], + patch: { group_tag: { set: "x" } }, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.PREVIEW_EXPIRED); + }); + + it("should return PREVIEW_STALE for mismatched patch", async () => { + findAllProvidersFreshMock.mockResolvedValue([makeProvider(1)]); + + const { previewProviderBatchPatch, applyProviderBatchPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ + providerIds: [1], + patch: { group_tag: { set: "original" } }, + }); + if (!preview.ok) throw new Error("Preview should succeed"); + + const result = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [1], + patch: { group_tag: { set: "different" } }, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.PREVIEW_STALE); + }); + + it("should return cached result for same idempotencyKey without re-writing to DB", async () => { + findAllProvidersFreshMock.mockResolvedValue([makeProvider(1), makeProvider(2)]); + updateProvidersBatchMock.mockResolvedValue(2); + + const { previewProviderBatchPatch, applyProviderBatchPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ + providerIds: [1, 2], + patch: { group_tag: { set: "idem" } }, + }); + if (!preview.ok) throw new Error("Preview should succeed"); + + const applyInput = { + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [1, 2], + patch: { group_tag: { set: "idem" } }, + idempotencyKey: "idem-key-1", + }; + + const first = await applyProviderBatchPatch(applyInput); + const second = await applyProviderBatchPatch(applyInput); + + expect(first.ok).toBe(true); + expect(second.ok).toBe(true); + if (!first.ok || !second.ok) return; + + expect(second.data.operationId).toBe(first.data.operationId); + expect(updateProvidersBatchMock).toHaveBeenCalledOnce(); + }); + + it("should prevent double-apply by marking snapshot as applied", async () => { + findAllProvidersFreshMock.mockResolvedValue([makeProvider(1)]); + updateProvidersBatchMock.mockResolvedValue(1); + + const { previewProviderBatchPatch, applyProviderBatchPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ + providerIds: [1], + patch: { group_tag: { set: "x" } }, + }); + if (!preview.ok) throw new Error("Preview should succeed"); + + const applyInput = { + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [1], + patch: { group_tag: { set: "x" } }, + }; + + const first = await applyProviderBatchPatch(applyInput); + const second = await applyProviderBatchPatch(applyInput); + + expect(first.ok).toBe(true); + expect(second.ok).toBe(false); + if (second.ok) return; + expect(second.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.PREVIEW_STALE); + }); + + it("should map cost_multiplier to string for repository", async () => { + findAllProvidersFreshMock.mockResolvedValue([makeProvider(1, { costMultiplier: 1.0 })]); + updateProvidersBatchMock.mockResolvedValue(1); + + const { apply } = await setupPreviewAndApply([1], { cost_multiplier: { set: 2.5 } }); + + expect(apply.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ costMultiplier: "2.5" }) + ); + }); + + it("should map multiple fields correctly to repository format", async () => { + findAllProvidersFreshMock.mockResolvedValue([ + makeProvider(1, { groupTag: "old", weight: 100, priority: 1 }), + ]); + updateProvidersBatchMock.mockResolvedValue(1); + + const { apply } = await setupPreviewAndApply([1], { + group_tag: { set: "new" }, + weight: { set: 80 }, + priority: { set: 5 }, + }); + + expect(apply.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ + groupTag: "new", + weight: 80, + priority: 5, + }) + ); + }); + + it("should map clear mode to null for clearable fields", async () => { + findAllProvidersFreshMock.mockResolvedValue([ + makeProvider(1, { groupTag: "has-tag", modelRedirects: { a: "b" } }), + ]); + updateProvidersBatchMock.mockResolvedValue(1); + + const { apply } = await setupPreviewAndApply([1], { + group_tag: { clear: true }, + model_redirects: { clear: true }, + }); + + expect(apply.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ + groupTag: null, + modelRedirects: null, + }) + ); + }); + + it("should map anthropic_thinking_budget_preference clear to inherit", async () => { + findAllProvidersFreshMock.mockResolvedValue([ + makeProvider(1, { anthropicThinkingBudgetPreference: "8192" }), + ]); + updateProvidersBatchMock.mockResolvedValue(1); + + const { apply } = await setupPreviewAndApply([1], { + anthropic_thinking_budget_preference: { clear: true }, + }); + + expect(apply.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ + anthropicThinkingBudgetPreference: "inherit", + }) + ); + }); +}); diff --git a/tests/unit/actions/providers-batch-field-mapping.test.ts b/tests/unit/actions/providers-batch-field-mapping.test.ts new file mode 100644 index 000000000..d304ef4e2 --- /dev/null +++ b/tests/unit/actions/providers-batch-field-mapping.test.ts @@ -0,0 +1,256 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const getSessionMock = vi.fn(); + +const updateProvidersBatchMock = vi.fn(); + +const publishProviderCacheInvalidationMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + updateProvidersBatch: updateProvidersBatchMock, +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: publishProviderCacheInvalidationMock, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +describe("batchUpdateProviders - advanced field mapping", () => { + beforeEach(() => { + vi.clearAllMocks(); + + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + updateProvidersBatchMock.mockResolvedValue(2); + publishProviderCacheInvalidationMock.mockResolvedValue(undefined); + }); + + it("should still map basic fields correctly (backward compat)", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [1, 2], + updates: { + is_enabled: true, + priority: 3, + weight: 5, + cost_multiplier: 1.2, + group_tag: "legacy", + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.updatedCount).toBe(2); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([1, 2], { + isEnabled: true, + priority: 3, + weight: 5, + costMultiplier: "1.2", + groupTag: "legacy", + }); + }); + + it("should map model_redirects to repository modelRedirects", async () => { + const redirects = { "claude-3-opus": "claude-3.5-sonnet", "gpt-4": "gpt-4o" }; + + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [10, 20], + updates: { model_redirects: redirects }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([10, 20], { + modelRedirects: redirects, + }); + }); + + it("should map model_redirects=null to repository modelRedirects=null", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [5], + updates: { model_redirects: null }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([5], { + modelRedirects: null, + }); + }); + + it("should map allowed_models with values correctly", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [1, 2], + updates: { allowed_models: ["model-a", "model-b"] }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([1, 2], { + allowedModels: ["model-a", "model-b"], + }); + }); + + it("should normalize allowed_models=[] to null (allow-all)", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [1], + updates: { allowed_models: [] }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([1], { + allowedModels: null, + }); + }); + + it("should map allowed_models=null to repository allowedModels=null", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [3], + updates: { allowed_models: null }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([3], { + allowedModels: null, + }); + }); + + it("should map anthropic_thinking_budget_preference correctly", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [7, 8], + updates: { anthropic_thinking_budget_preference: "10000" }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([7, 8], { + anthropicThinkingBudgetPreference: "10000", + }); + }); + + it("should map anthropic_thinking_budget_preference=inherit correctly", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [1], + updates: { anthropic_thinking_budget_preference: "inherit" }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([1], { + anthropicThinkingBudgetPreference: "inherit", + }); + }); + + it("should map anthropic_thinking_budget_preference=null correctly", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [1], + updates: { anthropic_thinking_budget_preference: null }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([1], { + anthropicThinkingBudgetPreference: null, + }); + }); + + it("should map anthropic_adaptive_thinking config correctly", async () => { + const config = { + effort: "high" as const, + modelMatchMode: "all" as const, + models: [], + }; + + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [4, 5], + updates: { anthropic_adaptive_thinking: config }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([4, 5], { + anthropicAdaptiveThinking: config, + }); + }); + + it("should map anthropic_adaptive_thinking=null correctly", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [6], + updates: { anthropic_adaptive_thinking: null }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([6], { + anthropicAdaptiveThinking: null, + }); + }); + + it("should handle mix of old and new fields together", async () => { + const adaptiveConfig = { + effort: "medium" as const, + modelMatchMode: "specific" as const, + models: ["claude-3-opus", "claude-3.5-sonnet"], + }; + + const { batchUpdateProviders } = await import("@/actions/providers"); + const result = await batchUpdateProviders({ + providerIds: [1, 2, 3], + updates: { + is_enabled: true, + priority: 10, + weight: 3, + cost_multiplier: 0.8, + group_tag: "mixed-batch", + model_redirects: { "old-model": "new-model" }, + allowed_models: ["claude-3-opus"], + anthropic_thinking_budget_preference: "5000", + anthropic_adaptive_thinking: adaptiveConfig, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.updatedCount).toBe(2); + expect(updateProvidersBatchMock).toHaveBeenCalledWith([1, 2, 3], { + isEnabled: true, + priority: 10, + weight: 3, + costMultiplier: "0.8", + groupTag: "mixed-batch", + modelRedirects: { "old-model": "new-model" }, + allowedModels: ["claude-3-opus"], + anthropicThinkingBudgetPreference: "5000", + anthropicAdaptiveThinking: adaptiveConfig, + }); + }); + + it("should detect new fields as valid updates (not reject as empty)", async () => { + const { batchUpdateProviders } = await import("@/actions/providers"); + + // Only new fields, no old fields -- must still be treated as having updates + const result = await batchUpdateProviders({ + providerIds: [1], + updates: { anthropic_thinking_budget_preference: "inherit" }, + }); + + expect(result.ok).toBe(true); + expect(updateProvidersBatchMock).toHaveBeenCalledTimes(1); + }); +}); diff --git a/tests/unit/actions/providers-patch-actions-contract.test.ts b/tests/unit/actions/providers-patch-actions-contract.test.ts new file mode 100644 index 000000000..a760b3513 --- /dev/null +++ b/tests/unit/actions/providers-patch-actions-contract.test.ts @@ -0,0 +1,305 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; + +const getSessionMock = vi.fn(); +const findAllProvidersFreshMock = vi.fn(); +const updateProvidersBatchMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + findAllProvidersFresh: findAllProvidersFreshMock, + updateProvidersBatch: updateProvidersBatchMock, + deleteProvidersBatch: vi.fn(), +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: vi.fn(), +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + clearProviderState: vi.fn(), + clearConfigCache: vi.fn(), + resetCircuit: vi.fn(), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +function makeProvider(id: number, overrides: Record = {}) { + return { + id, + name: `Provider-${id}`, + url: "https://api.example.com/v1", + key: "sk-test", + providerVendorId: null, + isEnabled: true, + weight: 100, + priority: 1, + groupPriorities: null, + costMultiplier: 1.0, + groupTag: null, + providerType: "claude", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: null, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 10000, + requestTimeoutNonStreamingMs: 600000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: new Date("2025-01-01"), + updatedAt: new Date("2025-01-01"), + deletedAt: null, + ...overrides, + }; +} + +describe("Provider Batch Patch Action Contracts", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + findAllProvidersFreshMock.mockResolvedValue([]); + updateProvidersBatchMock.mockResolvedValue(0); + }); + + it("previewProviderBatchPatch should require admin role", async () => { + getSessionMock.mockResolvedValueOnce({ user: { id: 2, role: "user" } }); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [1, 2], + patch: { group_tag: { set: "ops" } }, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error).toBe("无权限执行此操作"); + }); + + it("previewProviderBatchPatch should return structured preview payload", async () => { + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [3, 1, 3, 2], + patch: { + group_tag: { set: "blue" }, + allowed_models: { clear: true }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.providerIds).toEqual([1, 2, 3]); + expect(result.data.summary.providerCount).toBe(3); + expect(result.data.summary.fieldCount).toBe(2); + expect(result.data.changedFields).toEqual(["group_tag", "allowed_models"]); + expect(result.data.previewToken).toMatch(/^provider_patch_preview_/); + expect(result.data.previewRevision.length).toBeGreaterThan(0); + expect(result.data.previewExpiresAt.length).toBeGreaterThan(0); + }); + + it("previewProviderBatchPatch should return NOTHING_TO_APPLY when patch has no changes", async () => { + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [1], + patch: { group_tag: { no_change: true } }, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.NOTHING_TO_APPLY); + }); + + it("applyProviderBatchPatch should reject unknown preview token", async () => { + const { applyProviderBatchPatch } = await import("@/actions/providers"); + const result = await applyProviderBatchPatch({ + previewToken: "provider_patch_preview_missing", + previewRevision: "rev", + providerIds: [1], + patch: { group_tag: { set: "x" } }, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.PREVIEW_EXPIRED); + }); + + it("applyProviderBatchPatch should reject stale revision", async () => { + const { previewProviderBatchPatch, applyProviderBatchPatch } = await import( + "@/actions/providers" + ); + const preview = await previewProviderBatchPatch({ + providerIds: [1], + patch: { group_tag: { set: "x" } }, + }); + if (!preview.ok) throw new Error("Preview should be ok in test setup"); + + const apply = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: `${preview.data.previewRevision}-stale`, + providerIds: [1], + patch: { group_tag: { set: "x" } }, + }); + + expect(apply.ok).toBe(false); + if (apply.ok) return; + + expect(apply.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.PREVIEW_STALE); + }); + + it("applyProviderBatchPatch should return idempotent result for same idempotency key", async () => { + const { previewProviderBatchPatch, applyProviderBatchPatch } = await import( + "@/actions/providers" + ); + const preview = await previewProviderBatchPatch({ + providerIds: [1, 2], + patch: { group_tag: { set: "x" } }, + }); + if (!preview.ok) throw new Error("Preview should be ok in test setup"); + + const firstApply = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [1, 2], + patch: { group_tag: { set: "x" } }, + idempotencyKey: "idempotency-key-1", + }); + const secondApply = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [1, 2], + patch: { group_tag: { set: "x" } }, + idempotencyKey: "idempotency-key-1", + }); + + expect(firstApply.ok).toBe(true); + expect(secondApply.ok).toBe(true); + if (!firstApply.ok || !secondApply.ok) return; + + expect(secondApply.data.operationId).toBe(firstApply.data.operationId); + expect(secondApply.data.undoToken).toBe(firstApply.data.undoToken); + }); + + it("undoProviderPatch should reject mismatched operation id", async () => { + const { previewProviderBatchPatch, applyProviderBatchPatch, undoProviderPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ + providerIds: [10], + patch: { group_tag: { set: "undo-test" } }, + }); + if (!preview.ok) throw new Error("Preview should be ok in test setup"); + + const apply = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [10], + patch: { group_tag: { set: "undo-test" } }, + idempotencyKey: "undo-case", + }); + if (!apply.ok) throw new Error("Apply should be ok in test setup"); + + const undo = await undoProviderPatch({ + undoToken: apply.data.undoToken, + operationId: `${apply.data.operationId}-invalid`, + }); + + expect(undo.ok).toBe(false); + if (undo.ok) return; + + expect(undo.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_CONFLICT); + }); + + it("undoProviderPatch should consume token on success", async () => { + findAllProvidersFreshMock.mockResolvedValue([ + makeProvider(12, { groupTag: "before-12" }), + makeProvider(13, { groupTag: "before-13" }), + ]); + updateProvidersBatchMock.mockResolvedValue(1); + + const { previewProviderBatchPatch, applyProviderBatchPatch, undoProviderPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ + providerIds: [12, 13], + patch: { group_tag: { set: "rollback" } }, + }); + if (!preview.ok) throw new Error("Preview should be ok in test setup"); + + const apply = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [12, 13], + patch: { group_tag: { set: "rollback" } }, + idempotencyKey: "undo-consume", + }); + if (!apply.ok) throw new Error("Apply should be ok in test setup"); + + const firstUndo = await undoProviderPatch({ + undoToken: apply.data.undoToken, + operationId: apply.data.operationId, + }); + const secondUndo = await undoProviderPatch({ + undoToken: apply.data.undoToken, + operationId: apply.data.operationId, + }); + + expect(firstUndo.ok).toBe(true); + if (firstUndo.ok) { + expect(firstUndo.data.revertedCount).toBe(2); + } + + expect(secondUndo.ok).toBe(false); + if (secondUndo.ok) return; + + expect(secondUndo.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED); + }); +}); diff --git a/tests/unit/actions/providers-patch-contract.test.ts b/tests/unit/actions/providers-patch-contract.test.ts new file mode 100644 index 000000000..6e93a065a --- /dev/null +++ b/tests/unit/actions/providers-patch-contract.test.ts @@ -0,0 +1,922 @@ +import { describe, expect, it } from "vitest"; +import { + buildProviderBatchApplyUpdates, + hasProviderBatchPatchChanges, + normalizeProviderBatchPatchDraft, + prepareProviderBatchApplyUpdates, + PROVIDER_PATCH_ERROR_CODES, +} from "@/lib/provider-patch-contract"; + +describe("provider patch contract", () => { + it("normalizes undefined fields as no_change and omits them from apply payload", () => { + const normalized = normalizeProviderBatchPatchDraft({}); + + expect(normalized.ok).toBe(true); + if (!normalized.ok) return; + + expect(normalized.data.group_tag.mode).toBe("no_change"); + expect(hasProviderBatchPatchChanges(normalized.data)).toBe(false); + + const applyPayload = buildProviderBatchApplyUpdates(normalized.data); + expect(applyPayload.ok).toBe(true); + if (!applyPayload.ok) return; + + expect(applyPayload.data).toEqual({}); + }); + + it("serializes set and clear with distinct payload shapes", () => { + const setResult = prepareProviderBatchApplyUpdates({ + group_tag: { set: "primary" }, + allowed_models: { set: ["claude-3-7-sonnet"] }, + }); + const clearResult = prepareProviderBatchApplyUpdates({ + group_tag: { clear: true }, + allowed_models: { clear: true }, + }); + + expect(setResult.ok).toBe(true); + if (!setResult.ok) return; + + expect(clearResult.ok).toBe(true); + if (!clearResult.ok) return; + + expect(setResult.data.group_tag).toBe("primary"); + expect(clearResult.data.group_tag).toBeNull(); + expect(setResult.data.allowed_models).toEqual(["claude-3-7-sonnet"]); + expect(clearResult.data.allowed_models).toBeNull(); + }); + + it("maps empty allowed_models set payload to null", () => { + const result = prepareProviderBatchApplyUpdates({ + allowed_models: { set: [] }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.allowed_models).toBeNull(); + }); + + it("maps thinking budget clear to inherit", () => { + const result = prepareProviderBatchApplyUpdates({ + anthropic_thinking_budget_preference: { clear: true }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.anthropic_thinking_budget_preference).toBe("inherit"); + }); + + it("rejects conflicting set and clear modes", () => { + const result = normalizeProviderBatchPatchDraft({ + group_tag: { + set: "ops", + clear: true, + } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("group_tag"); + }); + + it("rejects clear on non-clearable fields", () => { + const result = normalizeProviderBatchPatchDraft({ + priority: { + clear: true, + } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("priority"); + }); + + it("rejects invalid set runtime shape", () => { + const result = normalizeProviderBatchPatchDraft({ + weight: { + set: null, + } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("weight"); + }); + + it("rejects model_redirects arrays", () => { + const result = normalizeProviderBatchPatchDraft({ + model_redirects: { + set: ["not-a-record"], + } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("model_redirects"); + }); + + it("rejects invalid thinking budget string values", () => { + const result = normalizeProviderBatchPatchDraft({ + anthropic_thinking_budget_preference: { + set: "abc", + } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("anthropic_thinking_budget_preference"); + }); + + it("rejects adaptive thinking specific mode with empty models", () => { + const result = normalizeProviderBatchPatchDraft({ + anthropic_adaptive_thinking: { + set: { + effort: "high", + modelMatchMode: "specific", + models: [], + }, + }, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("anthropic_adaptive_thinking"); + }); + + it("supports explicit no_change mode", () => { + const result = normalizeProviderBatchPatchDraft({ + model_redirects: { no_change: true }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.model_redirects.mode).toBe("no_change"); + }); + + it("rejects unknown top-level fields", () => { + const result = normalizeProviderBatchPatchDraft({ + unknown_field: { set: 1 }, + } as never); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("__root__"); + }); + + it("rejects non-object draft payloads", () => { + const result = normalizeProviderBatchPatchDraft(null as never); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe("__root__"); + }); + + describe("routing fields", () => { + it("accepts boolean set for preserve_client_ip and swap_cache_ttl_billing", () => { + const result = prepareProviderBatchApplyUpdates({ + preserve_client_ip: { set: true }, + swap_cache_ttl_billing: { set: false }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.preserve_client_ip).toBe(true); + expect(result.data.swap_cache_ttl_billing).toBe(false); + }); + + it("accepts group_priorities as Record", () => { + const result = prepareProviderBatchApplyUpdates({ + group_priorities: { set: { us: 10, eu: 5 } }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.group_priorities).toEqual({ us: 10, eu: 5 }); + }); + + it("rejects group_priorities with non-number values", () => { + const result = normalizeProviderBatchPatchDraft({ + group_priorities: { set: { us: "high" } } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("group_priorities"); + }); + + it("rejects group_priorities when array", () => { + const result = normalizeProviderBatchPatchDraft({ + group_priorities: { set: [1, 2, 3] } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("group_priorities"); + }); + + it("clears group_priorities to null", () => { + const result = prepareProviderBatchApplyUpdates({ + group_priorities: { clear: true }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.group_priorities).toBeNull(); + }); + + it.each([ + ["cache_ttl_preference", "inherit"], + ["cache_ttl_preference", "5m"], + ["cache_ttl_preference", "1h"], + ] as const)("accepts valid %s value: %s", (field, value) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(value); + }); + + it("rejects invalid cache_ttl_preference value", () => { + const result = normalizeProviderBatchPatchDraft({ + cache_ttl_preference: { set: "30m" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("cache_ttl_preference"); + }); + + it.each([ + ["context_1m_preference", "inherit"], + ["context_1m_preference", "force_enable"], + ["context_1m_preference", "disabled"], + ] as const)("accepts valid %s value: %s", (field, value) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(value); + }); + + it.each([ + ["codex_reasoning_effort_preference", "inherit"], + ["codex_reasoning_effort_preference", "none"], + ["codex_reasoning_effort_preference", "minimal"], + ["codex_reasoning_effort_preference", "low"], + ["codex_reasoning_effort_preference", "medium"], + ["codex_reasoning_effort_preference", "high"], + ["codex_reasoning_effort_preference", "xhigh"], + ] as const)("accepts valid %s value: %s", (field, value) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(value); + }); + + it("rejects invalid codex_reasoning_effort_preference value", () => { + const result = normalizeProviderBatchPatchDraft({ + codex_reasoning_effort_preference: { set: "ultra" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("codex_reasoning_effort_preference"); + }); + + it.each([ + ["codex_reasoning_summary_preference", "inherit"], + ["codex_reasoning_summary_preference", "auto"], + ["codex_reasoning_summary_preference", "detailed"], + ] as const)("accepts valid %s value: %s", (field, value) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(value); + }); + + it.each([ + ["codex_text_verbosity_preference", "inherit"], + ["codex_text_verbosity_preference", "low"], + ["codex_text_verbosity_preference", "medium"], + ["codex_text_verbosity_preference", "high"], + ] as const)("accepts valid %s value: %s", (field, value) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(value); + }); + + it.each([ + ["codex_parallel_tool_calls_preference", "inherit"], + ["codex_parallel_tool_calls_preference", "true"], + ["codex_parallel_tool_calls_preference", "false"], + ] as const)("accepts valid %s value: %s", (field, value) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(value); + }); + + it.each([ + ["gemini_google_search_preference", "inherit"], + ["gemini_google_search_preference", "enabled"], + ["gemini_google_search_preference", "disabled"], + ] as const)("accepts valid %s value: %s", (field, value) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(value); + }); + + it("rejects invalid gemini_google_search_preference value", () => { + const result = normalizeProviderBatchPatchDraft({ + gemini_google_search_preference: { set: "auto" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("gemini_google_search_preference"); + }); + }); + + describe("anthropic_max_tokens_preference", () => { + it("accepts inherit", () => { + const result = prepareProviderBatchApplyUpdates({ + anthropic_max_tokens_preference: { set: "inherit" }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.anthropic_max_tokens_preference).toBe("inherit"); + }); + + it("accepts positive numeric string", () => { + const result = prepareProviderBatchApplyUpdates({ + anthropic_max_tokens_preference: { set: "8192" }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.anthropic_max_tokens_preference).toBe("8192"); + }); + + it("accepts small positive numeric string (no range restriction)", () => { + const result = prepareProviderBatchApplyUpdates({ + anthropic_max_tokens_preference: { set: "1" }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.anthropic_max_tokens_preference).toBe("1"); + }); + + it("rejects non-numeric string", () => { + const result = normalizeProviderBatchPatchDraft({ + anthropic_max_tokens_preference: { set: "abc" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("anthropic_max_tokens_preference"); + }); + + it("rejects zero", () => { + const result = normalizeProviderBatchPatchDraft({ + anthropic_max_tokens_preference: { set: "0" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("anthropic_max_tokens_preference"); + }); + + it("clears to inherit", () => { + const result = prepareProviderBatchApplyUpdates({ + anthropic_max_tokens_preference: { clear: true }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.anthropic_max_tokens_preference).toBe("inherit"); + }); + }); + + describe("rate limit fields", () => { + it.each([ + "limit_5h_usd", + "limit_daily_usd", + "limit_weekly_usd", + "limit_monthly_usd", + "limit_total_usd", + ] as const)("accepts number set and clears to null for %s", (field) => { + const setResult = prepareProviderBatchApplyUpdates({ + [field]: { set: 100.5 }, + }); + + expect(setResult.ok).toBe(true); + if (!setResult.ok) return; + + expect(setResult.data[field]).toBe(100.5); + + const clearResult = prepareProviderBatchApplyUpdates({ + [field]: { clear: true }, + }); + + expect(clearResult.ok).toBe(true); + if (!clearResult.ok) return; + + expect(clearResult.data[field]).toBeNull(); + }); + + it("rejects non-number for limit_5h_usd", () => { + const result = normalizeProviderBatchPatchDraft({ + limit_5h_usd: { set: "100" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("limit_5h_usd"); + }); + + it("rejects NaN for number fields", () => { + const result = normalizeProviderBatchPatchDraft({ + limit_daily_usd: { set: Number.NaN } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("limit_daily_usd"); + }); + + it("rejects Infinity for number fields", () => { + const result = normalizeProviderBatchPatchDraft({ + limit_weekly_usd: { set: Number.POSITIVE_INFINITY } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("limit_weekly_usd"); + }); + + it("accepts limit_concurrent_sessions as number (non-clearable)", () => { + const result = prepareProviderBatchApplyUpdates({ + limit_concurrent_sessions: { set: 5 }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.limit_concurrent_sessions).toBe(5); + }); + + it("rejects clear on limit_concurrent_sessions", () => { + const result = normalizeProviderBatchPatchDraft({ + limit_concurrent_sessions: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("limit_concurrent_sessions"); + }); + + it.each(["fixed", "rolling"] as const)("accepts daily_reset_mode value: %s", (value) => { + const result = prepareProviderBatchApplyUpdates({ + daily_reset_mode: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.daily_reset_mode).toBe(value); + }); + + it("rejects invalid daily_reset_mode value", () => { + const result = normalizeProviderBatchPatchDraft({ + daily_reset_mode: { set: "hourly" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("daily_reset_mode"); + }); + + it("rejects clear on daily_reset_mode", () => { + const result = normalizeProviderBatchPatchDraft({ + daily_reset_mode: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("daily_reset_mode"); + }); + + it("accepts daily_reset_time as string (non-clearable)", () => { + const result = prepareProviderBatchApplyUpdates({ + daily_reset_time: { set: "00:00" }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.daily_reset_time).toBe("00:00"); + }); + + it("rejects clear on daily_reset_time", () => { + const result = normalizeProviderBatchPatchDraft({ + daily_reset_time: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("daily_reset_time"); + }); + }); + + describe("circuit breaker fields", () => { + it.each([ + "circuit_breaker_failure_threshold", + "circuit_breaker_open_duration", + "circuit_breaker_half_open_success_threshold", + ] as const)("accepts number set for %s (non-clearable)", (field) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: 10 }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(10); + }); + + it.each([ + "circuit_breaker_failure_threshold", + "circuit_breaker_open_duration", + "circuit_breaker_half_open_success_threshold", + ] as const)("rejects clear on %s", (field) => { + const result = normalizeProviderBatchPatchDraft({ + [field]: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe(field); + }); + + it("accepts max_retry_attempts and clears to null", () => { + const setResult = prepareProviderBatchApplyUpdates({ + max_retry_attempts: { set: 3 }, + }); + + expect(setResult.ok).toBe(true); + if (!setResult.ok) return; + + expect(setResult.data.max_retry_attempts).toBe(3); + + const clearResult = prepareProviderBatchApplyUpdates({ + max_retry_attempts: { clear: true }, + }); + + expect(clearResult.ok).toBe(true); + if (!clearResult.ok) return; + + expect(clearResult.data.max_retry_attempts).toBeNull(); + }); + }); + + describe("network fields", () => { + it("accepts proxy_url as string and clears to null", () => { + const setResult = prepareProviderBatchApplyUpdates({ + proxy_url: { set: "socks5://proxy.example.com:1080" }, + }); + + expect(setResult.ok).toBe(true); + if (!setResult.ok) return; + + expect(setResult.data.proxy_url).toBe("socks5://proxy.example.com:1080"); + + const clearResult = prepareProviderBatchApplyUpdates({ + proxy_url: { clear: true }, + }); + + expect(clearResult.ok).toBe(true); + if (!clearResult.ok) return; + + expect(clearResult.data.proxy_url).toBeNull(); + }); + + it("accepts boolean set for proxy_fallback_to_direct (non-clearable)", () => { + const result = prepareProviderBatchApplyUpdates({ + proxy_fallback_to_direct: { set: true }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.proxy_fallback_to_direct).toBe(true); + }); + + it("rejects clear on proxy_fallback_to_direct", () => { + const result = normalizeProviderBatchPatchDraft({ + proxy_fallback_to_direct: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("proxy_fallback_to_direct"); + }); + + it.each([ + "first_byte_timeout_streaming_ms", + "streaming_idle_timeout_ms", + "request_timeout_non_streaming_ms", + ] as const)("accepts number set for %s (non-clearable)", (field) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { set: 30000 }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe(30000); + }); + + it.each([ + "first_byte_timeout_streaming_ms", + "streaming_idle_timeout_ms", + "request_timeout_non_streaming_ms", + ] as const)("rejects clear on %s", (field) => { + const result = normalizeProviderBatchPatchDraft({ + [field]: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe(field); + }); + }); + + describe("MCP fields", () => { + it.each([ + "none", + "minimax", + "glm", + "custom", + ] as const)("accepts mcp_passthrough_type value: %s", (value) => { + const result = prepareProviderBatchApplyUpdates({ + mcp_passthrough_type: { set: value }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.mcp_passthrough_type).toBe(value); + }); + + it("rejects invalid mcp_passthrough_type value", () => { + const result = normalizeProviderBatchPatchDraft({ + mcp_passthrough_type: { set: "openai" } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("mcp_passthrough_type"); + }); + + it("rejects clear on mcp_passthrough_type", () => { + const result = normalizeProviderBatchPatchDraft({ + mcp_passthrough_type: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.field).toBe("mcp_passthrough_type"); + }); + + it("accepts mcp_passthrough_url as string and clears to null", () => { + const setResult = prepareProviderBatchApplyUpdates({ + mcp_passthrough_url: { set: "https://api.minimaxi.com" }, + }); + + expect(setResult.ok).toBe(true); + if (!setResult.ok) return; + + expect(setResult.data.mcp_passthrough_url).toBe("https://api.minimaxi.com"); + + const clearResult = prepareProviderBatchApplyUpdates({ + mcp_passthrough_url: { clear: true }, + }); + + expect(clearResult.ok).toBe(true); + if (!clearResult.ok) return; + + expect(clearResult.data.mcp_passthrough_url).toBeNull(); + }); + }); + + describe("preference fields clear to inherit", () => { + it.each([ + "cache_ttl_preference", + "context_1m_preference", + "codex_reasoning_effort_preference", + "codex_reasoning_summary_preference", + "codex_text_verbosity_preference", + "codex_parallel_tool_calls_preference", + "anthropic_max_tokens_preference", + "gemini_google_search_preference", + ] as const)("clears %s to inherit", (field) => { + const result = prepareProviderBatchApplyUpdates({ + [field]: { clear: true }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data[field]).toBe("inherit"); + }); + }); + + describe("non-clearable field rejection", () => { + it.each([ + "preserve_client_ip", + "swap_cache_ttl_billing", + "daily_reset_mode", + "daily_reset_time", + "limit_concurrent_sessions", + "circuit_breaker_failure_threshold", + "circuit_breaker_open_duration", + "circuit_breaker_half_open_success_threshold", + "proxy_fallback_to_direct", + "first_byte_timeout_streaming_ms", + "streaming_idle_timeout_ms", + "request_timeout_non_streaming_ms", + "mcp_passthrough_type", + ] as const)("rejects clear on non-clearable field: %s", (field) => { + const result = normalizeProviderBatchPatchDraft({ + [field]: { clear: true } as never, + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + + expect(result.error.code).toBe(PROVIDER_PATCH_ERROR_CODES.INVALID_PATCH_SHAPE); + expect(result.error.field).toBe(field); + }); + }); + + describe("hasProviderBatchPatchChanges for new fields", () => { + it("detects change on a single new field", () => { + const normalized = normalizeProviderBatchPatchDraft({ + preserve_client_ip: { set: true }, + }); + + expect(normalized.ok).toBe(true); + if (!normalized.ok) return; + + expect(hasProviderBatchPatchChanges(normalized.data)).toBe(true); + }); + + it("detects change on mcp_passthrough_url (last field)", () => { + const normalized = normalizeProviderBatchPatchDraft({ + mcp_passthrough_url: { set: "https://example.com" }, + }); + + expect(normalized.ok).toBe(true); + if (!normalized.ok) return; + + expect(hasProviderBatchPatchChanges(normalized.data)).toBe(true); + }); + + it("reports no change when all new fields are no_change", () => { + const normalized = normalizeProviderBatchPatchDraft({ + preserve_client_ip: { no_change: true }, + limit_5h_usd: { no_change: true }, + proxy_url: { no_change: true }, + }); + + expect(normalized.ok).toBe(true); + if (!normalized.ok) return; + + expect(hasProviderBatchPatchChanges(normalized.data)).toBe(false); + }); + }); + + describe("combined set across all categories", () => { + it("handles a batch patch touching all field categories at once", () => { + const result = prepareProviderBatchApplyUpdates({ + // existing + is_enabled: { set: true }, + group_tag: { set: "batch-test" }, + // routing + preserve_client_ip: { set: false }, + cache_ttl_preference: { set: "1h" }, + codex_reasoning_effort_preference: { set: "high" }, + anthropic_max_tokens_preference: { set: "16384" }, + // rate limit + limit_5h_usd: { set: 50 }, + daily_reset_mode: { set: "rolling" }, + daily_reset_time: { set: "08:00" }, + // circuit breaker + circuit_breaker_failure_threshold: { set: 5 }, + max_retry_attempts: { set: 2 }, + // network + proxy_url: { set: "https://proxy.local" }, + proxy_fallback_to_direct: { set: true }, + first_byte_timeout_streaming_ms: { set: 15000 }, + // mcp + mcp_passthrough_type: { set: "minimax" }, + mcp_passthrough_url: { set: "https://api.minimaxi.com" }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.is_enabled).toBe(true); + expect(result.data.group_tag).toBe("batch-test"); + expect(result.data.preserve_client_ip).toBe(false); + expect(result.data.cache_ttl_preference).toBe("1h"); + expect(result.data.codex_reasoning_effort_preference).toBe("high"); + expect(result.data.anthropic_max_tokens_preference).toBe("16384"); + expect(result.data.limit_5h_usd).toBe(50); + expect(result.data.daily_reset_mode).toBe("rolling"); + expect(result.data.daily_reset_time).toBe("08:00"); + expect(result.data.circuit_breaker_failure_threshold).toBe(5); + expect(result.data.max_retry_attempts).toBe(2); + expect(result.data.proxy_url).toBe("https://proxy.local"); + expect(result.data.proxy_fallback_to_direct).toBe(true); + expect(result.data.first_byte_timeout_streaming_ms).toBe(15000); + expect(result.data.mcp_passthrough_type).toBe("minimax"); + expect(result.data.mcp_passthrough_url).toBe("https://api.minimaxi.com"); + }); + }); +}); diff --git a/tests/unit/actions/providers-preview-engine.test.ts b/tests/unit/actions/providers-preview-engine.test.ts new file mode 100644 index 000000000..744365abe --- /dev/null +++ b/tests/unit/actions/providers-preview-engine.test.ts @@ -0,0 +1,563 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Provider } from "@/types/provider"; + +const getSessionMock = vi.fn(); +const findAllProvidersFreshMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + findAllProvidersFresh: findAllProvidersFreshMock, + updateProvidersBatch: vi.fn(), + deleteProvidersBatch: vi.fn(), +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: vi.fn(), +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + clearProviderState: vi.fn(), + clearConfigCache: vi.fn(), + resetCircuit: vi.fn(), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +function buildTestProvider(overrides: Partial = {}): Provider { + return { + id: 1, + name: "Test Provider", + url: "https://api.example.com", + key: "test-key", + providerVendorId: null, + isEnabled: true, + weight: 10, + priority: 1, + groupPriorities: null, + costMultiplier: 1.0, + groupTag: null, + providerType: "claude", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: 10, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 10000, + requestTimeoutNonStreamingMs: 600000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: new Date("2026-01-01"), + updatedAt: new Date("2026-01-01"), + ...overrides, + }; +} + +describe("Provider Batch Preview Engine - Row Generation", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + }); + + it("generates correct before/after row for single provider single field change", async () => { + const provider = buildTestProvider({ + id: 5, + name: "Claude One", + groupTag: "old-group", + }); + findAllProvidersFreshMock.mockResolvedValue([provider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [5], + patch: { group_tag: { set: "new-group" } }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(1); + expect(result.data.rows[0]).toEqual({ + providerId: 5, + providerName: "Claude One", + field: "group_tag", + status: "changed", + before: "old-group", + after: "new-group", + }); + }); + + it("generates rows for each provider-field combination", async () => { + const providerA = buildTestProvider({ + id: 1, + name: "Provider A", + priority: 5, + weight: 10, + }); + const providerB = buildTestProvider({ + id: 2, + name: "Provider B", + priority: 3, + weight: 20, + }); + findAllProvidersFreshMock.mockResolvedValue([providerA, providerB]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [1, 2], + patch: { + priority: { set: 10 }, + weight: { set: 50 }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(4); + + expect(result.data.rows).toContainEqual({ + providerId: 1, + providerName: "Provider A", + field: "priority", + status: "changed", + before: 5, + after: 10, + }); + expect(result.data.rows).toContainEqual({ + providerId: 1, + providerName: "Provider A", + field: "weight", + status: "changed", + before: 10, + after: 50, + }); + expect(result.data.rows).toContainEqual({ + providerId: 2, + providerName: "Provider B", + field: "priority", + status: "changed", + before: 3, + after: 10, + }); + expect(result.data.rows).toContainEqual({ + providerId: 2, + providerName: "Provider B", + field: "weight", + status: "changed", + before: 20, + after: 50, + }); + }); + + it("marks anthropic fields as skipped for non-claude providers", async () => { + const provider = buildTestProvider({ + id: 10, + name: "OpenAI Compat", + providerType: "openai-compatible", + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + }); + findAllProvidersFreshMock.mockResolvedValue([provider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [10], + patch: { + anthropic_thinking_budget_preference: { set: "8192" }, + anthropic_adaptive_thinking: { + set: { effort: "high", modelMatchMode: "all", models: [] }, + }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(2); + + const budgetRow = result.data.rows.find( + (r: { field: string }) => r.field === "anthropic_thinking_budget_preference" + ); + expect(budgetRow).toEqual({ + providerId: 10, + providerName: "OpenAI Compat", + field: "anthropic_thinking_budget_preference", + status: "skipped", + before: null, + after: "8192", + skipReason: expect.any(String), + }); + + const adaptiveRow = result.data.rows.find( + (r: { field: string }) => r.field === "anthropic_adaptive_thinking" + ); + expect(adaptiveRow).toEqual({ + providerId: 10, + providerName: "OpenAI Compat", + field: "anthropic_adaptive_thinking", + status: "skipped", + before: null, + after: { effort: "high", modelMatchMode: "all", models: [] }, + skipReason: expect.any(String), + }); + }); + + it("marks anthropic fields as changed for claude providers", async () => { + const provider = buildTestProvider({ + id: 20, + name: "Claude Main", + providerType: "claude", + anthropicThinkingBudgetPreference: "inherit", + }); + findAllProvidersFreshMock.mockResolvedValue([provider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [20], + patch: { anthropic_thinking_budget_preference: { set: "16000" } }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(1); + expect(result.data.rows[0]).toEqual({ + providerId: 20, + providerName: "Claude Main", + field: "anthropic_thinking_budget_preference", + status: "changed", + before: "inherit", + after: "16000", + }); + }); + + it("marks anthropic fields as changed for claude-auth providers", async () => { + const provider = buildTestProvider({ + id: 21, + name: "Claude Auth", + providerType: "claude-auth", + anthropicAdaptiveThinking: null, + }); + findAllProvidersFreshMock.mockResolvedValue([provider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [21], + patch: { + anthropic_adaptive_thinking: { + set: { effort: "medium", modelMatchMode: "all", models: [] }, + }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(1); + expect(result.data.rows[0].status).toBe("changed"); + expect(result.data.rows[0].providerId).toBe(21); + }); + + it("computes correct after values for clear mode", async () => { + const provider = buildTestProvider({ + id: 30, + name: "Clear Test", + providerType: "claude", + groupTag: "old-tag", + modelRedirects: { "model-a": "model-b" }, + allowedModels: ["claude-3"], + anthropicThinkingBudgetPreference: "8192", + anthropicAdaptiveThinking: { + effort: "high", + modelMatchMode: "all", + models: [], + }, + }); + findAllProvidersFreshMock.mockResolvedValue([provider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [30], + patch: { + group_tag: { clear: true }, + model_redirects: { clear: true }, + allowed_models: { clear: true }, + anthropic_thinking_budget_preference: { clear: true }, + anthropic_adaptive_thinking: { clear: true }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(5); + + const groupTagRow = result.data.rows.find((r: { field: string }) => r.field === "group_tag"); + expect(groupTagRow?.before).toBe("old-tag"); + expect(groupTagRow?.after).toBeNull(); + + const modelRedirectsRow = result.data.rows.find( + (r: { field: string }) => r.field === "model_redirects" + ); + expect(modelRedirectsRow?.before).toEqual({ "model-a": "model-b" }); + expect(modelRedirectsRow?.after).toBeNull(); + + const allowedModelsRow = result.data.rows.find( + (r: { field: string }) => r.field === "allowed_models" + ); + expect(allowedModelsRow?.before).toEqual(["claude-3"]); + expect(allowedModelsRow?.after).toBeNull(); + + // anthropic_thinking_budget_preference clears to "inherit" + const budgetRow = result.data.rows.find( + (r: { field: string }) => r.field === "anthropic_thinking_budget_preference" + ); + expect(budgetRow?.before).toBe("8192"); + expect(budgetRow?.after).toBe("inherit"); + + const adaptiveRow = result.data.rows.find( + (r: { field: string }) => r.field === "anthropic_adaptive_thinking" + ); + expect(adaptiveRow?.before).toEqual({ + effort: "high", + modelMatchMode: "all", + models: [], + }); + expect(adaptiveRow?.after).toBeNull(); + }); + + it("normalizes empty allowed_models array to null in after value", async () => { + const provider = buildTestProvider({ + id: 40, + name: "Models Test", + allowedModels: ["claude-3"], + }); + findAllProvidersFreshMock.mockResolvedValue([provider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [40], + patch: { allowed_models: { set: [] } }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(1); + expect(result.data.rows[0].before).toEqual(["claude-3"]); + expect(result.data.rows[0].after).toBeNull(); + }); + + it("includes correct skipCount in summary", async () => { + const claudeProvider = buildTestProvider({ + id: 50, + name: "Claude", + providerType: "claude", + }); + const openaiProvider = buildTestProvider({ + id: 51, + name: "OpenAI", + providerType: "openai-compatible", + }); + const geminiProvider = buildTestProvider({ + id: 52, + name: "Gemini", + providerType: "gemini", + }); + findAllProvidersFreshMock.mockResolvedValue([claudeProvider, openaiProvider, geminiProvider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [50, 51, 52], + patch: { + anthropic_thinking_budget_preference: { set: "8192" }, + group_tag: { set: "new-tag" }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + // 3 providers x 2 fields = 6 rows + expect(result.data.rows).toHaveLength(6); + // 2 non-claude providers x 1 anthropic field = 2 skipped + expect(result.data.summary.skipCount).toBe(2); + expect(result.data.summary.providerCount).toBe(3); + expect(result.data.summary.fieldCount).toBe(2); + }); + + it("returns rows in the preview result for snapshot storage", async () => { + const provider = buildTestProvider({ + id: 60, + name: "Snapshot Test", + isEnabled: true, + }); + findAllProvidersFreshMock.mockResolvedValue([provider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [60], + patch: { is_enabled: { set: false } }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toBeDefined(); + expect(Array.isArray(result.data.rows)).toBe(true); + expect(result.data.rows).toHaveLength(1); + expect(result.data.rows[0]).toEqual({ + providerId: 60, + providerName: "Snapshot Test", + field: "is_enabled", + status: "changed", + before: true, + after: false, + }); + }); + + it("only generates rows for providers matching requested IDs", async () => { + const providerA = buildTestProvider({ id: 100, name: "Match" }); + const providerB = buildTestProvider({ id: 200, name: "No Match" }); + findAllProvidersFreshMock.mockResolvedValue([providerA, providerB]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [100], + patch: { priority: { set: 99 } }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(1); + expect(result.data.rows[0].providerId).toBe(100); + }); + + it("skips anthropic fields for all non-claude provider types", async () => { + const codexProvider = buildTestProvider({ + id: 70, + name: "Codex", + providerType: "codex", + }); + const geminiCliProvider = buildTestProvider({ + id: 71, + name: "Gemini CLI", + providerType: "gemini-cli", + }); + findAllProvidersFreshMock.mockResolvedValue([codexProvider, geminiCliProvider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [70, 71], + patch: { + anthropic_adaptive_thinking: { + set: { effort: "low", modelMatchMode: "all", models: [] }, + }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + expect(result.data.rows).toHaveLength(2); + expect(result.data.rows.every((r: { status: string }) => r.status === "skipped")).toBe(true); + expect(result.data.summary.skipCount).toBe(2); + }); + + it("handles mixed changed and skipped rows across providers", async () => { + const claudeProvider = buildTestProvider({ + id: 80, + name: "Claude", + providerType: "claude", + groupTag: "alpha", + anthropicThinkingBudgetPreference: null, + }); + const openaiProvider = buildTestProvider({ + id: 81, + name: "OpenAI", + providerType: "openai-compatible", + groupTag: "beta", + anthropicThinkingBudgetPreference: null, + }); + findAllProvidersFreshMock.mockResolvedValue([claudeProvider, openaiProvider]); + + const { previewProviderBatchPatch } = await import("@/actions/providers"); + const result = await previewProviderBatchPatch({ + providerIds: [80, 81], + patch: { + group_tag: { set: "gamma" }, + anthropic_thinking_budget_preference: { set: "4096" }, + }, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + + // 2 providers x 2 fields = 4 rows + expect(result.data.rows).toHaveLength(4); + + // group_tag: both changed (universal field) + const groupTagRows = result.data.rows.filter((r: { field: string }) => r.field === "group_tag"); + expect(groupTagRows).toHaveLength(2); + expect(groupTagRows.every((r: { status: string }) => r.status === "changed")).toBe(true); + + // anthropic_thinking_budget_preference: claude changed, openai skipped + const budgetRows = result.data.rows.filter( + (r: { field: string }) => r.field === "anthropic_thinking_budget_preference" + ); + expect(budgetRows).toHaveLength(2); + + const claudeBudget = budgetRows.find((r: { providerId: number }) => r.providerId === 80); + expect(claudeBudget?.status).toBe("changed"); + + const openaiBudget = budgetRows.find((r: { providerId: number }) => r.providerId === 81); + expect(openaiBudget?.status).toBe("skipped"); + expect(openaiBudget?.skipReason).toBeTruthy(); + + expect(result.data.summary.skipCount).toBe(1); + }); +}); diff --git a/tests/unit/actions/providers-undo-engine.test.ts b/tests/unit/actions/providers-undo-engine.test.ts new file mode 100644 index 000000000..b7f094da8 --- /dev/null +++ b/tests/unit/actions/providers-undo-engine.test.ts @@ -0,0 +1,391 @@ +// @vitest-environment node +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { PROVIDER_BATCH_PATCH_ERROR_CODES } from "@/lib/provider-batch-patch-error-codes"; + +const getSessionMock = vi.fn(); +const findAllProvidersFreshMock = vi.fn(); +const updateProvidersBatchMock = vi.fn(); +const publishCacheInvalidationMock = vi.fn(); + +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("@/repository/provider", () => ({ + findAllProvidersFresh: findAllProvidersFreshMock, + updateProvidersBatch: updateProvidersBatchMock, + deleteProvidersBatch: vi.fn(), +})); + +vi.mock("@/lib/cache/provider-cache", () => ({ + publishProviderCacheInvalidation: publishCacheInvalidationMock, +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + clearProviderState: vi.fn(), + clearConfigCache: vi.fn(), + resetCircuit: vi.fn(), + getAllHealthStatusAsync: vi.fn(), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})); + +function makeProvider(id: number, overrides: Record = {}) { + return { + id, + name: `Provider-${id}`, + url: "https://api.example.com/v1", + key: "sk-test", + providerVendorId: null, + isEnabled: true, + weight: 100, + priority: 1, + groupPriorities: null, + costMultiplier: 1.0, + groupTag: null, + providerType: "claude", + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + totalCostResetAt: null, + limitConcurrentSessions: null, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 1800000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 10000, + requestTimeoutNonStreamingMs: 600000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: new Date("2025-01-01"), + updatedAt: new Date("2025-01-01"), + deletedAt: null, + ...overrides, + }; +} + +describe("Undo Provider Batch Patch Engine", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + findAllProvidersFreshMock.mockResolvedValue([]); + updateProvidersBatchMock.mockResolvedValue(0); + publishCacheInvalidationMock.mockResolvedValue(undefined); + }); + + /** Helper: preview -> apply -> return undo token + operationId + undoProviderPatch */ + async function setupPreviewApplyAndGetUndo( + providers: ReturnType[], + providerIds: number[], + patch: Record, + applyOverrides: Record = {} + ) { + findAllProvidersFreshMock.mockResolvedValue(providers); + updateProvidersBatchMock.mockResolvedValue(providers.length); + + const { previewProviderBatchPatch, applyProviderBatchPatch, undoProviderPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ providerIds, patch }); + if (!preview.ok) throw new Error(`Preview failed: ${preview.error}`); + + const apply = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds, + patch, + ...applyOverrides, + }); + if (!apply.ok) throw new Error(`Apply failed: ${apply.error}`); + + // Reset mocks after apply so undo assertions are clean + updateProvidersBatchMock.mockClear(); + publishCacheInvalidationMock.mockClear(); + + return { + undoToken: apply.data.undoToken, + operationId: apply.data.operationId, + undoProviderPatch, + }; + } + + it("should revert each provider's fields to preimage values", async () => { + const providers = [ + makeProvider(1, { groupTag: "alpha" }), + makeProvider(2, { groupTag: "beta" }), + ]; + + const { undoToken, operationId, undoProviderPatch } = await setupPreviewApplyAndGetUndo( + providers, + [1, 2], + { group_tag: { set: "gamma" } } + ); + + updateProvidersBatchMock.mockResolvedValue(1); + + const result = await undoProviderPatch({ undoToken, operationId }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + // Provider 1 had groupTag "alpha", provider 2 had "beta" -- different preimages + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ groupTag: "alpha" }) + ); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [2], + expect.objectContaining({ groupTag: "beta" }) + ); + }); + + it("should call updateProvidersBatch per unique preimage group", async () => { + const providers = [ + makeProvider(1, { groupTag: "same" }), + makeProvider(2, { groupTag: "same" }), + makeProvider(3, { groupTag: "different" }), + ]; + + const { undoToken, operationId, undoProviderPatch } = await setupPreviewApplyAndGetUndo( + providers, + [1, 2, 3], + { group_tag: { set: "new-value" } } + ); + + updateProvidersBatchMock.mockResolvedValue(1); + + await undoProviderPatch({ undoToken, operationId }); + + // 2 groups: [1,2] with "same" and [3] with "different" + expect(updateProvidersBatchMock).toHaveBeenCalledTimes(2); + // One call should batch providers 1 and 2 together + const calls = updateProvidersBatchMock.mock.calls as Array<[number[], Record]>; + const groupedCall = calls.find((c) => c[0].length === 2); + expect(groupedCall).toBeDefined(); + expect(groupedCall![0]).toEqual(expect.arrayContaining([1, 2])); + }); + + it("should publish cache invalidation after undo", async () => { + const providers = [makeProvider(1, { groupTag: "old" })]; + + const { undoToken, operationId, undoProviderPatch } = await setupPreviewApplyAndGetUndo( + providers, + [1], + { group_tag: { set: "new" } } + ); + + updateProvidersBatchMock.mockResolvedValue(1); + + const result = await undoProviderPatch({ undoToken, operationId }); + + expect(result.ok).toBe(true); + expect(publishCacheInvalidationMock).toHaveBeenCalledOnce(); + }); + + it("should return correct revertedCount from actual DB writes", async () => { + const providers = [ + makeProvider(1, { groupTag: "a" }), + makeProvider(2, { groupTag: "b" }), + makeProvider(3, { groupTag: "c" }), + ]; + + const { undoToken, operationId, undoProviderPatch } = await setupPreviewApplyAndGetUndo( + providers, + [1, 2, 3], + { group_tag: { set: "unified" } } + ); + + // Each per-group call returns 1 + updateProvidersBatchMock.mockResolvedValue(1); + + const result = await undoProviderPatch({ undoToken, operationId }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + // 3 different preimages -> 3 calls, each returning 1 + expect(result.data.revertedCount).toBe(3); + }); + + it("should return UNDO_EXPIRED for missing token", async () => { + const { undoProviderPatch } = await import("@/actions/providers"); + + const result = await undoProviderPatch({ + undoToken: "nonexistent_token", + operationId: "op_123", + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED); + }); + + it("should return UNDO_CONFLICT for mismatched operationId", async () => { + const providers = [makeProvider(1, { groupTag: "old" })]; + + const { undoToken, undoProviderPatch } = await setupPreviewApplyAndGetUndo(providers, [1], { + group_tag: { set: "new" }, + }); + + const result = await undoProviderPatch({ + undoToken, + operationId: "wrong_operation_id", + }); + + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_CONFLICT); + expect(updateProvidersBatchMock).not.toHaveBeenCalled(); + }); + + it("should consume undo token after successful undo", async () => { + const providers = [makeProvider(1, { groupTag: "old" })]; + + const { undoToken, operationId, undoProviderPatch } = await setupPreviewApplyAndGetUndo( + providers, + [1], + { group_tag: { set: "new" } } + ); + + updateProvidersBatchMock.mockResolvedValue(1); + + const first = await undoProviderPatch({ undoToken, operationId }); + expect(first.ok).toBe(true); + + // Second undo with same token should fail -- token was consumed + const second = await undoProviderPatch({ undoToken, operationId }); + expect(second.ok).toBe(false); + if (second.ok) return; + expect(second.errorCode).toBe(PROVIDER_BATCH_PATCH_ERROR_CODES.UNDO_EXPIRED); + }); + + it("should handle costMultiplier number-to-string conversion", async () => { + const providers = [makeProvider(1, { costMultiplier: 1.5 })]; + + const { undoToken, operationId, undoProviderPatch } = await setupPreviewApplyAndGetUndo( + providers, + [1], + { cost_multiplier: { set: 2.5 } } + ); + + updateProvidersBatchMock.mockResolvedValue(1); + + const result = await undoProviderPatch({ undoToken, operationId }); + + expect(result.ok).toBe(true); + // The preimage stored costMultiplier as number 1.5; undo must convert to string "1.5" + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ costMultiplier: "1.5" }) + ); + }); + + it("should handle providers with different preimage values individually", async () => { + const providers = [ + makeProvider(1, { priority: 5, weight: 80 }), + makeProvider(2, { priority: 10, weight: 60 }), + ]; + + const { undoToken, operationId, undoProviderPatch } = await setupPreviewApplyAndGetUndo( + providers, + [1, 2], + { priority: { set: 1 }, weight: { set: 100 } } + ); + + updateProvidersBatchMock.mockResolvedValue(1); + + const result = await undoProviderPatch({ undoToken, operationId }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + // Each provider should be reverted with its own original values + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ priority: 5, weight: 80 }) + ); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [2], + expect.objectContaining({ priority: 10, weight: 60 }) + ); + expect(result.data.revertedCount).toBe(2); + }); + + it("should handle providerIds without preimage entries gracefully", async () => { + // Only provider 1 exists in DB; provider 999 has no preimage + const providers = [makeProvider(1, { groupTag: "old" })]; + findAllProvidersFreshMock.mockResolvedValue(providers); + updateProvidersBatchMock.mockResolvedValue(1); + + const { previewProviderBatchPatch, applyProviderBatchPatch, undoProviderPatch } = await import( + "@/actions/providers" + ); + + const preview = await previewProviderBatchPatch({ + providerIds: [1, 999], + patch: { group_tag: { set: "new" } }, + }); + if (!preview.ok) throw new Error(`Preview failed: ${preview.error}`); + + const apply = await applyProviderBatchPatch({ + previewToken: preview.data.previewToken, + previewRevision: preview.data.previewRevision, + providerIds: [1, 999], + patch: { group_tag: { set: "new" } }, + }); + if (!apply.ok) throw new Error(`Apply failed: ${apply.error}`); + + updateProvidersBatchMock.mockClear(); + publishCacheInvalidationMock.mockClear(); + updateProvidersBatchMock.mockResolvedValue(1); + + const result = await undoProviderPatch({ + undoToken: apply.data.undoToken, + operationId: apply.data.operationId, + }); + + expect(result.ok).toBe(true); + if (!result.ok) return; + // Only provider 1 has preimage, provider 999 is skipped + expect(updateProvidersBatchMock).toHaveBeenCalledTimes(1); + expect(updateProvidersBatchMock).toHaveBeenCalledWith( + [1], + expect.objectContaining({ groupTag: "old" }) + ); + expect(result.data.revertedCount).toBe(1); + }); +}); diff --git a/tests/unit/actions/providers-undo-store.test.ts b/tests/unit/actions/providers-undo-store.test.ts new file mode 100644 index 000000000..dbc495f6d --- /dev/null +++ b/tests/unit/actions/providers-undo-store.test.ts @@ -0,0 +1,180 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +const setexMock = vi.fn(); +const getMock = vi.fn(); +const delMock = vi.fn(); +const evalMock = vi.fn(); + +vi.mock("@/lib/redis/client", () => ({ + getRedisClient: () => ({ + status: "ready", + setex: setexMock, + get: getMock, + del: delMock, + eval: evalMock, + }), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + error: vi.fn(), + warn: vi.fn(), + info: vi.fn(), + debug: vi.fn(), + }, +})); + +vi.mock("server-only", () => ({})); + +function buildSnapshot(overrides: Partial> = {}) { + return { + operationId: "op-1", + operationType: "batch_edit" as const, + preimage: { before: "state" }, + providerIds: [1, 2], + createdAt: new Date().toISOString(), + ...overrides, + }; +} + +describe("providers undo store", () => { + beforeEach(() => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-02-18T00:00:00.000Z")); + vi.resetModules(); + vi.clearAllMocks(); + setexMock.mockResolvedValue("OK"); + delMock.mockResolvedValue(1); + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.useRealTimers(); + }); + + it("stores snapshot and consumes token within TTL", async () => { + const token = "11111111-1111-1111-1111-111111111111"; + vi.spyOn(crypto, "randomUUID").mockReturnValue(token); + + const snapshot = buildSnapshot(); + evalMock.mockResolvedValue(JSON.stringify(snapshot)); + + const { storeUndoSnapshot, consumeUndoToken } = await import("@/lib/providers/undo-store"); + + const storeResult = await storeUndoSnapshot(snapshot); + + expect(storeResult).toEqual({ + undoAvailable: true, + undoToken: token, + expiresAt: "2026-02-18T00:00:30.000Z", + }); + expect(setexMock).toHaveBeenCalledWith(`cch:prov:undo:${token}`, 30, JSON.stringify(snapshot)); + + const consumeResult = await consumeUndoToken(token); + expect(consumeResult).toEqual({ + ok: true, + snapshot, + }); + expect(evalMock).toHaveBeenCalledWith(expect.any(String), 1, `cch:prov:undo:${token}`); + }); + + it("returns UNDO_EXPIRED when Redis returns null (TTL passed)", async () => { + const token = "22222222-2222-2222-2222-222222222222"; + evalMock.mockResolvedValue(null); + + const { consumeUndoToken } = await import("@/lib/providers/undo-store"); + + const consumeResult = await consumeUndoToken(token); + expect(consumeResult).toEqual({ + ok: false, + code: "UNDO_EXPIRED", + }); + }); + + it("consumes a token only once (getAndDelete)", async () => { + const token = "33333333-3333-3333-3333-333333333333"; + vi.spyOn(crypto, "randomUUID").mockReturnValue(token); + + const snapshot = buildSnapshot({ operationId: "op-3" }); + + const { storeUndoSnapshot, consumeUndoToken } = await import("@/lib/providers/undo-store"); + + await storeUndoSnapshot(snapshot); + + evalMock.mockResolvedValueOnce(JSON.stringify(snapshot)).mockResolvedValueOnce(null); + + const first = await consumeUndoToken(token); + const second = await consumeUndoToken(token); + + expect(first).toEqual({ ok: true, snapshot }); + expect(second).toEqual({ ok: false, code: "UNDO_EXPIRED" }); + }); + + it("returns UNDO_EXPIRED for unknown token", async () => { + evalMock.mockResolvedValue(null); + + const { consumeUndoToken } = await import("@/lib/providers/undo-store"); + const result = await consumeUndoToken("undo-token-missing"); + + expect(result).toEqual({ + ok: false, + code: "UNDO_EXPIRED", + }); + }); + + it("stores multiple snapshots with independent tokens", async () => { + const tokenA = "44444444-4444-4444-4444-444444444444"; + const tokenB = "55555555-5555-5555-5555-555555555555"; + vi.spyOn(crypto, "randomUUID").mockReturnValueOnce(tokenA).mockReturnValueOnce(tokenB); + + const { storeUndoSnapshot, consumeUndoToken } = await import("@/lib/providers/undo-store"); + + const snapshotA = buildSnapshot({ operationId: "op-4", providerIds: [11] }); + const snapshotB = buildSnapshot({ + operationId: "op-5", + operationType: "single_edit", + providerIds: [22, 23], + }); + + const storeA = await storeUndoSnapshot(snapshotA); + const storeB = await storeUndoSnapshot(snapshotB); + + expect(storeA.undoToken).toBe(tokenA); + expect(storeB.undoToken).toBe(tokenB); + + evalMock + .mockResolvedValueOnce(JSON.stringify(snapshotA)) + .mockResolvedValueOnce(JSON.stringify(snapshotB)); + + await expect(consumeUndoToken(tokenA)).resolves.toEqual({ + ok: true, + snapshot: snapshotA, + }); + await expect(consumeUndoToken(tokenB)).resolves.toEqual({ + ok: true, + snapshot: snapshotB, + }); + }); + + it("fails open when storage backend throws", async () => { + vi.spyOn(crypto, "randomUUID").mockImplementation(() => { + throw new Error("uuid failed"); + }); + + const { storeUndoSnapshot } = await import("@/lib/providers/undo-store"); + const result = await storeUndoSnapshot(buildSnapshot({ operationId: "op-6" })); + + expect(result).toEqual({ undoAvailable: false }); + }); + + it("returns undoAvailable false when Redis set fails", async () => { + const token = "66666666-6666-6666-6666-666666666666"; + vi.spyOn(crypto, "randomUUID").mockReturnValue(token); + setexMock.mockRejectedValue(new Error("Redis write error")); + + const { storeUndoSnapshot } = await import("@/lib/providers/undo-store"); + const result = await storeUndoSnapshot(buildSnapshot({ operationId: "op-7" })); + + expect(result).toEqual({ undoAvailable: false }); + }); +}); diff --git a/tests/unit/actions/providers.test.ts b/tests/unit/actions/providers.test.ts index 30bc0b4c6..b219ff4c0 100644 --- a/tests/unit/actions/providers.test.ts +++ b/tests/unit/actions/providers.test.ts @@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; const getSessionMock = vi.fn(); +const findProviderByIdMock = vi.fn(); const findAllProvidersFreshMock = vi.fn(); const getProviderStatisticsMock = vi.fn(); const createProviderMock = vi.fn(); @@ -26,7 +27,7 @@ vi.mock("@/repository/provider", () => ({ deleteProvider: deleteProviderMock, findAllProviders: vi.fn(async () => []), findAllProvidersFresh: findAllProvidersFreshMock, - findProviderById: vi.fn(async () => null), + findProviderById: findProviderByIdMock, getProviderStatistics: getProviderStatisticsMock, resetProviderTotalCostResetAt: vi.fn(async () => {}), updateProvider: updateProviderMock, @@ -142,6 +143,11 @@ describe("Provider Actions - Async Optimization", () => { getProviderStatisticsMock.mockResolvedValue([]); + findProviderByIdMock.mockImplementation(async (id: number) => { + const providers = await findAllProvidersFreshMock(); + return providers.find((p: { id: number }) => p.id === id) ?? null; + }); + createProviderMock.mockResolvedValue({ id: 123, circuitBreakerFailureThreshold: 5, diff --git a/tests/unit/api/auth-login-failure-taxonomy.test.ts b/tests/unit/api/auth-login-failure-taxonomy.test.ts new file mode 100644 index 000000000..b3f5bbd2e --- /dev/null +++ b/tests/unit/api/auth-login-failure-taxonomy.test.ts @@ -0,0 +1,163 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest } from "next/server"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + withNoStoreHeaders: (res: T): T => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: (res: T): T => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +function makeRequest( + body: unknown, + opts?: { locale?: string; acceptLanguage?: string; xForwardedProto?: string } +): NextRequest { + const headers: Record = { "Content-Type": "application/json" }; + + if (opts?.acceptLanguage) { + headers["accept-language"] = opts.acceptLanguage; + } + + headers["x-forwarded-proto"] = opts?.xForwardedProto ?? "https"; + + const req = new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers, + body: JSON.stringify(body), + }); + + if (opts?.locale) { + req.cookies.set("NEXT_LOCALE", opts.locale); + } + + return req; +} + +describe("POST /api/auth/login failure taxonomy", () => { + let POST: (request: NextRequest) => Promise; + + beforeEach(async () => { + const mockT = vi.fn((key: string) => `translated:${key}`); + mockGetTranslations.mockResolvedValue(mockT); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + + const mod = await import("../../../src/app/api/auth/login/route"); + POST = mod.POST; + }); + + it("returns KEY_REQUIRED taxonomy for missing key", async () => { + const res = await POST(makeRequest({})); + + expect(res.status).toBe(400); + const json = await res.json(); + expect(json).toEqual({ + error: "translated:apiKeyRequired", + errorCode: "KEY_REQUIRED", + }); + expect(mockValidateKey).not.toHaveBeenCalled(); + }); + + it("returns KEY_INVALID taxonomy for invalid key", async () => { + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "bad-key" })); + + expect(res.status).toBe(401); + const json = await res.json(); + expect(json).toEqual({ + error: "translated:apiKeyInvalidOrExpired", + errorCode: "KEY_INVALID", + }); + }); + + it("returns SERVER_ERROR taxonomy when validation throws", async () => { + mockValidateKey.mockRejectedValue(new Error("DB connection failed")); + + const res = await POST(makeRequest({ key: "some-key" })); + + expect(res.status).toBe(500); + const json = await res.json(); + expect(json).toEqual({ + error: "translated:serverError", + errorCode: "SERVER_ERROR", + }); + expect(mockLogger.error).toHaveBeenCalled(); + }); + + it("adds httpMismatchGuidance on invalid key when secure cookies require HTTPS", async () => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "bad-key" }, { xForwardedProto: "http" })); + + expect(res.status).toBe(401); + const json = await res.json(); + expect(json.error).toBe("translated:apiKeyInvalidOrExpired"); + expect(json.errorCode).toBe("KEY_INVALID"); + expect(typeof json.httpMismatchGuidance).toBe("string"); + expect(json.httpMismatchGuidance.length).toBeGreaterThan(0); + }); + + it("does not add httpMismatchGuidance when no HTTPS mismatch", async () => { + mockValidateKey.mockResolvedValue(null); + + const noSecureCookieRes = await POST( + makeRequest({ key: "bad-key" }, { xForwardedProto: "http" }) + ); + + expect(noSecureCookieRes.status).toBe(401); + expect(await noSecureCookieRes.json()).toEqual({ + error: "translated:apiKeyInvalidOrExpired", + errorCode: "KEY_INVALID", + }); + + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + const httpsRes = await POST(makeRequest({ key: "bad-key" }, { xForwardedProto: "https" })); + + expect(httpsRes.status).toBe(401); + expect(await httpsRes.json()).toEqual({ + error: "translated:apiKeyInvalidOrExpired", + errorCode: "KEY_INVALID", + }); + }); +}); diff --git a/tests/unit/api/auth-login-route.test.ts b/tests/unit/api/auth-login-route.test.ts new file mode 100644 index 000000000..e37ae27c2 --- /dev/null +++ b/tests/unit/api/auth-login-route.test.ts @@ -0,0 +1,316 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest } from "next/server"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + toKeyFingerprint: vi.fn().mockResolvedValue("sha256:fake"), + withNoStoreHeaders: (res: T): T => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: vi.fn().mockReturnValue({ ENABLE_SECURE_COOKIES: false }), +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: (res: T): T => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +function makeRequest( + body: unknown, + opts?: { locale?: string; acceptLanguage?: string } +): NextRequest { + const headers: Record = { "Content-Type": "application/json" }; + + if (opts?.acceptLanguage) { + headers["accept-language"] = opts.acceptLanguage; + } + + const req = new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers, + body: JSON.stringify(body), + }); + + if (opts?.locale) { + req.cookies.set("NEXT_LOCALE", opts.locale); + } + + return req; +} + +const fakeSession = { + user: { + id: 1, + name: "Test User", + description: "desc", + role: "user" as const, + }, + key: { canLoginWebUi: true }, +}; + +const adminSession = { + user: { + id: -1, + name: "Admin Token", + description: "Environment admin session", + role: "admin" as const, + }, + key: { canLoginWebUi: true }, +}; + +const readonlySession = { + user: { + id: 2, + name: "Readonly User", + description: "readonly", + role: "user" as const, + }, + key: { canLoginWebUi: false }, +}; + +describe("POST /api/auth/login", () => { + let POST: (request: NextRequest) => Promise; + + beforeEach(async () => { + vi.resetModules(); + const mockT = vi.fn((key: string) => `translated:${key}`); + mockGetTranslations.mockResolvedValue(mockT); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + + const mod = await import("@/app/api/auth/login/route"); + POST = mod.POST; + }); + + it("returns 400 when key is missing from body", async () => { + const res = await POST(makeRequest({})); + + expect(res.status).toBe(400); + const json = await res.json(); + expect(json).toEqual({ error: "translated:apiKeyRequired" }); + expect(mockValidateKey).not.toHaveBeenCalled(); + }); + + it("returns 400 when key is empty string", async () => { + const res = await POST(makeRequest({ key: "" })); + + expect(res.status).toBe(400); + const json = await res.json(); + expect(json).toEqual({ error: "translated:apiKeyRequired" }); + }); + + it("returns 401 when validateKey returns null", async () => { + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "bad-key" })); + + expect(res.status).toBe(401); + const json = await res.json(); + expect(json).toEqual({ error: "translated:apiKeyInvalidOrExpired" }); + expect(mockValidateKey).toHaveBeenCalledWith("bad-key", { + allowReadOnlyAccess: true, + }); + }); + + it("returns 200 with correct body shape on valid key", async () => { + mockValidateKey.mockResolvedValue(fakeSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + const res = await POST(makeRequest({ key: "valid-key" })); + + expect(res.status).toBe(200); + const json = await res.json(); + expect(json).toEqual({ + ok: true, + user: { + id: 1, + name: "Test User", + description: "desc", + role: "user", + }, + redirectTo: "/dashboard", + loginType: "dashboard_user", + }); + }); + + it("calls setAuthCookie exactly once on success", async () => { + mockValidateKey.mockResolvedValue(fakeSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + await POST(makeRequest({ key: "valid-key" })); + + expect(mockSetAuthCookie).toHaveBeenCalledTimes(1); + expect(mockSetAuthCookie).toHaveBeenCalledWith("valid-key"); + }); + + it("returns redirectTo from getLoginRedirectTarget", async () => { + mockValidateKey.mockResolvedValue(fakeSession); + mockGetLoginRedirectTarget.mockReturnValue("/my-usage"); + + const res = await POST(makeRequest({ key: "readonly-key" })); + const json = await res.json(); + + expect(json.redirectTo).toBe("/my-usage"); + expect(mockGetLoginRedirectTarget).toHaveBeenCalledWith(fakeSession); + }); + + it("returns loginType admin for admin session", async () => { + mockValidateKey.mockResolvedValue(adminSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + const res = await POST(makeRequest({ key: "admin-key" })); + const json = await res.json(); + + expect(json.loginType).toBe("admin"); + expect(json.redirectTo).toBe("/dashboard"); + }); + + it("returns loginType dashboard_user for canLoginWebUi user session", async () => { + mockValidateKey.mockResolvedValue(fakeSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + const res = await POST(makeRequest({ key: "dashboard-key" })); + const json = await res.json(); + + expect(json.loginType).toBe("dashboard_user"); + expect(json.redirectTo).toBe("/dashboard"); + }); + + it("returns loginType readonly_user for readonly session", async () => { + mockValidateKey.mockResolvedValue(readonlySession); + mockGetLoginRedirectTarget.mockReturnValue("/my-usage"); + + const res = await POST(makeRequest({ key: "readonly-key" })); + const json = await res.json(); + + expect(json.loginType).toBe("readonly_user"); + expect(json.redirectTo).toBe("/my-usage"); + }); + + it("returns 500 when validateKey throws", async () => { + mockValidateKey.mockRejectedValue(new Error("DB connection failed")); + + const res = await POST(makeRequest({ key: "some-key" })); + + expect(res.status).toBe(500); + const json = await res.json(); + expect(json).toEqual({ error: "translated:serverError" }); + expect(mockLogger.error).toHaveBeenCalled(); + }); + + it("returns 500 when request.json() throws (malformed body)", async () => { + const req = new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: "not-valid-json{{{", + }); + + const res = await POST(req); + + expect(res.status).toBe(500); + const json = await res.json(); + expect(json).toEqual({ error: "translated:serverError" }); + }); + + it("uses NEXT_LOCALE cookie for translations", async () => { + mockValidateKey.mockResolvedValue(null); + + await POST(makeRequest({ key: "x" }, { locale: "ja" })); + + expect(mockGetTranslations).toHaveBeenCalledWith({ + locale: "ja", + namespace: "auth.errors", + }); + }); + + it("detects locale from accept-language header", async () => { + mockValidateKey.mockResolvedValue(null); + + await POST(makeRequest({ key: "x" }, { acceptLanguage: "ru;q=1.0" })); + + expect(mockGetTranslations).toHaveBeenCalledWith({ + locale: "ru", + namespace: "auth.errors", + }); + }); + + it("falls back to defaultLocale when getTranslations fails for requested locale", async () => { + const mockT = vi.fn((key: string) => `fallback:${key}`); + mockGetTranslations + .mockRejectedValueOnce(new Error("locale not found")) + .mockResolvedValueOnce(mockT); + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "x" }, { locale: "ja" })); + + expect(mockGetTranslations).toHaveBeenCalledTimes(2); + expect(mockGetTranslations).toHaveBeenNthCalledWith(1, { + locale: "ja", + namespace: "auth.errors", + }); + expect(mockGetTranslations).toHaveBeenNthCalledWith(2, { + locale: "zh-CN", + namespace: "auth.errors", + }); + + const json = await res.json(); + expect(json.error).toBe("fallback:apiKeyInvalidOrExpired"); + }); + + it("returns null translation when both locale and fallback fail", async () => { + mockGetTranslations + .mockRejectedValueOnce(new Error("fail")) + .mockRejectedValueOnce(new Error("fallback fail")); + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "x" })); + + expect(res.status).toBe(401); + const json = await res.json(); + expect(json).toEqual({ error: "Authentication failed" }); + expect(mockLogger.warn).toHaveBeenCalled(); + expect(mockLogger.error).toHaveBeenCalled(); + }); + + it("falls back to defaultLocale when no locale cookie or accept-language", async () => { + mockValidateKey.mockResolvedValue(null); + + await POST(makeRequest({ key: "x" })); + + expect(mockGetTranslations).toHaveBeenCalledWith({ + locale: "zh-CN", + namespace: "auth.errors", + }); + }); +}); diff --git a/tests/unit/auth/auth-cookie-constant-sync.test.ts b/tests/unit/auth/auth-cookie-constant-sync.test.ts new file mode 100644 index 000000000..ed672e8cc --- /dev/null +++ b/tests/unit/auth/auth-cookie-constant-sync.test.ts @@ -0,0 +1,23 @@ +import { readFileSync } from "node:fs"; +import { join } from "node:path"; +import { describe, expect, it } from "vitest"; +import { AUTH_COOKIE_NAME } from "@/lib/auth"; + +const readSource = (relativePath: string) => + readFileSync(join(process.cwd(), relativePath), "utf8"); + +describe("auth cookie constant sync", () => { + it("keeps AUTH_COOKIE_NAME stable", () => { + expect(AUTH_COOKIE_NAME).toBe("auth-token"); + }); + + it("removes hardcoded auth-token cookie literals from core auth layers", () => { + const proxySource = readSource("src/proxy.ts"); + const actionAdapterSource = readSource("src/lib/api/action-adapter-openapi.ts"); + + expect(proxySource).not.toMatch(/["']auth-token["']/); + expect(actionAdapterSource).not.toMatch(/["']auth-token["']/); + expect(proxySource).toContain("AUTH_COOKIE_NAME"); + expect(actionAdapterSource).toContain("AUTH_COOKIE_NAME"); + }); +}); diff --git a/tests/unit/auth/login-redirect-safety.test.ts b/tests/unit/auth/login-redirect-safety.test.ts new file mode 100644 index 000000000..2496f441f --- /dev/null +++ b/tests/unit/auth/login-redirect-safety.test.ts @@ -0,0 +1,77 @@ +import { describe, expect, it } from "vitest"; +import { + resolveLoginRedirectTarget, + sanitizeRedirectPath, +} from "@/app/[locale]/login/redirect-safety"; +import { getLoginRedirectTarget } from "@/lib/auth"; + +describe("sanitizeRedirectPath", () => { + it("keeps safe relative path /settings", () => { + expect(sanitizeRedirectPath("/settings")).toBe("/settings"); + }); + + it("keeps safe nested path /dashboard/users", () => { + expect(sanitizeRedirectPath("/dashboard/users")).toBe("/dashboard/users"); + }); + + it("rejects absolute external URL", () => { + expect(sanitizeRedirectPath("https://evil.example/phish")).toBe("/dashboard"); + }); + + it("rejects protocol-relative URL", () => { + expect(sanitizeRedirectPath("//evil.example")).toBe("/dashboard"); + }); + + it("rejects empty string", () => { + expect(sanitizeRedirectPath("")).toBe("/dashboard"); + }); + + it("keeps relative path with query string", () => { + expect(sanitizeRedirectPath("/settings?tab=general")).toBe("/settings?tab=general"); + }); + + it("rejects protocol-like path payload", () => { + expect(sanitizeRedirectPath("/https://evil.example/path")).toBe("/dashboard"); + }); +}); + +describe("resolveLoginRedirectTarget", () => { + it("always prioritizes server redirectTo over from", () => { + expect(resolveLoginRedirectTarget("/my-usage", "/settings")).toBe("/my-usage"); + expect(resolveLoginRedirectTarget("/my-usage", "https://evil.example/phish")).toBe("/my-usage"); + }); + + it("uses sanitized from when server redirectTo is empty", () => { + expect(resolveLoginRedirectTarget(undefined, "/settings")).toBe("/settings"); + expect(resolveLoginRedirectTarget("", "https://evil.example/phish")).toBe("/dashboard"); + }); +}); + +describe("getLoginRedirectTarget invariants", () => { + it("routes admin user to /dashboard", () => { + expect( + getLoginRedirectTarget({ + user: { role: "admin" } as any, + key: { canLoginWebUi: false } as any, + }) + ).toBe("/dashboard"); + }); + + it("routes canLoginWebUi user to /dashboard", () => { + expect( + getLoginRedirectTarget({ + user: { role: "user" } as any, + key: { canLoginWebUi: true } as any, + }) + ).toBe("/dashboard"); + }); + + it("routes readonly user to /my-usage", () => { + expect( + getLoginRedirectTarget({ + user: { role: "user" } as any, + key: { canLoginWebUi: false } as any, + }) + ).toBe("/my-usage"); + }); +}); diff --git a/tests/unit/auth/opaque-admin-session.test.ts b/tests/unit/auth/opaque-admin-session.test.ts new file mode 100644 index 000000000..fc7e8ad5a --- /dev/null +++ b/tests/unit/auth/opaque-admin-session.test.ts @@ -0,0 +1,137 @@ +import crypto from "node:crypto"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +// Hoisted mocks +const mockCookies = vi.hoisted(() => vi.fn()); +const mockHeaders = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockValidateApiKeyAndGetUser = vi.hoisted(() => vi.fn()); +const mockFindKeyList = vi.hoisted(() => vi.fn()); +const mockReadSession = vi.hoisted(() => vi.fn()); +const mockCookieStore = vi.hoisted(() => ({ + get: vi.fn(), + set: vi.fn(), + delete: vi.fn(), +})); +const mockHeadersStore = vi.hoisted(() => ({ + get: vi.fn(), +})); +const mockConfig = vi.hoisted(() => ({ + auth: { adminToken: "test-admin-token-secret" }, +})); + +vi.mock("next/headers", () => ({ + cookies: mockCookies, + headers: mockHeaders, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/repository/key", () => ({ + validateApiKeyAndGetUser: mockValidateApiKeyAndGetUser, + findKeyList: mockFindKeyList, +})); + +vi.mock("@/lib/auth-session-store/redis-session-store", () => ({ + RedisSessionStore: class { + read = mockReadSession; + create = vi.fn(); + revoke = vi.fn(); + rotate = vi.fn(); + }, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { warn: vi.fn(), error: vi.fn(), info: vi.fn(), debug: vi.fn() }, +})); + +vi.mock("@/lib/config/config", () => ({ + config: mockConfig, +})); + +function toFingerprint(keyString: string): string { + return `sha256:${crypto.createHash("sha256").update(keyString, "utf8").digest("hex")}`; +} + +describe("opaque session with admin token (userId=-1)", () => { + beforeEach(() => { + vi.resetModules(); + vi.clearAllMocks(); + + mockCookies.mockResolvedValue(mockCookieStore); + mockHeaders.mockResolvedValue(mockHeadersStore); + mockHeadersStore.get.mockReturnValue(null); + mockCookieStore.get.mockReturnValue(undefined); + + mockGetEnvConfig.mockReturnValue({ + SESSION_TOKEN_MODE: "opaque", + ENABLE_SECURE_COOKIES: false, + }); + mockReadSession.mockResolvedValue(null); + mockFindKeyList.mockResolvedValue([]); + mockValidateApiKeyAndGetUser.mockResolvedValue(null); + mockConfig.auth.adminToken = "test-admin-token-secret"; + }); + + it("resolves admin session from opaque token with userId=-1", async () => { + const adminToken = "test-admin-token-secret"; + mockCookieStore.get.mockReturnValue({ value: "sid_admin_test" }); + mockReadSession.mockResolvedValue({ + sessionId: "sid_admin_test", + keyFingerprint: toFingerprint(adminToken), + userId: -1, + userRole: "admin", + createdAt: Date.now() - 1000, + expiresAt: Date.now() + 86400_000, + }); + + const { getSession } = await import("@/lib/auth"); + const session = await getSession(); + + expect(session).not.toBeNull(); + expect(session!.user.id).toBe(-1); + expect(session!.user.role).toBe("admin"); + expect(session!.key.name).toBe("ADMIN_TOKEN"); + // Must NOT call findKeyList -- virtual admin user has no DB keys + expect(mockFindKeyList).not.toHaveBeenCalled(); + }); + + it("returns null when admin token is not configured but session has userId=-1", async () => { + mockConfig.auth.adminToken = ""; + mockCookieStore.get.mockReturnValue({ value: "sid_admin_test" }); + mockReadSession.mockResolvedValue({ + sessionId: "sid_admin_test", + keyFingerprint: toFingerprint("test-admin-token-secret"), + userId: -1, + userRole: "admin", + createdAt: Date.now() - 1000, + expiresAt: Date.now() + 86400_000, + }); + + const { getSession } = await import("@/lib/auth"); + const session = await getSession(); + + expect(session).toBeNull(); + expect(mockFindKeyList).not.toHaveBeenCalled(); + }); + + it("returns null when fingerprint does not match admin token", async () => { + mockCookieStore.get.mockReturnValue({ value: "sid_admin_test" }); + mockReadSession.mockResolvedValue({ + sessionId: "sid_admin_test", + keyFingerprint: toFingerprint("wrong-token"), + userId: -1, + userRole: "admin", + createdAt: Date.now() - 1000, + expiresAt: Date.now() + 86400_000, + }); + + const { getSession } = await import("@/lib/auth"); + const session = await getSession(); + + expect(session).toBeNull(); + expect(mockFindKeyList).not.toHaveBeenCalled(); + }); +}); diff --git a/tests/unit/auth/set-auth-cookie-options.test.ts b/tests/unit/auth/set-auth-cookie-options.test.ts new file mode 100644 index 000000000..0e31c813c --- /dev/null +++ b/tests/unit/auth/set-auth-cookie-options.test.ts @@ -0,0 +1,110 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const mockCookieSet = vi.hoisted(() => vi.fn()); +const mockCookies = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockIsDevelopment = vi.hoisted(() => vi.fn(() => false)); + +vi.mock("next/headers", () => ({ + cookies: mockCookies, + headers: vi.fn().mockResolvedValue(new Headers()), +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, + isDevelopment: mockIsDevelopment, +})); + +vi.mock("@/lib/config/config", () => ({ config: { auth: { adminToken: "test" } } })); +vi.mock("@/repository/key", () => ({ validateApiKeyAndGetUser: vi.fn() })); + +import { setAuthCookie } from "@/lib/auth"; + +describe("setAuthCookie options", () => { + beforeEach(() => { + mockCookieSet.mockClear(); + mockCookies.mockResolvedValue({ set: mockCookieSet, get: vi.fn(), delete: vi.fn() }); + }); + + describe("when ENABLE_SECURE_COOKIES is true", () => { + beforeEach(() => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + }); + + it("sets secure=true", async () => { + await setAuthCookie("test-key-123"); + + expect(mockCookieSet).toHaveBeenCalledTimes(1); + const [, , options] = mockCookieSet.mock.calls[0]; + expect(options.secure).toBe(true); + }); + }); + + describe("when ENABLE_SECURE_COOKIES is false", () => { + beforeEach(() => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + }); + + it("sets secure=false", async () => { + await setAuthCookie("test-key-456"); + + expect(mockCookieSet).toHaveBeenCalledTimes(1); + const [, , options] = mockCookieSet.mock.calls[0]; + expect(options.secure).toBe(false); + }); + }); + + describe("invariant cookie options", () => { + beforeEach(() => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + }); + + it("always sets httpOnly to true", async () => { + await setAuthCookie("any-key"); + + const [, , options] = mockCookieSet.mock.calls[0]; + expect(options.httpOnly).toBe(true); + }); + + it("always sets sameSite to lax", async () => { + await setAuthCookie("any-key"); + + const [, , options] = mockCookieSet.mock.calls[0]; + expect(options.sameSite).toBe("lax"); + }); + + it("always sets maxAge to 7 days (604800 seconds)", async () => { + await setAuthCookie("any-key"); + + const [, , options] = mockCookieSet.mock.calls[0]; + expect(options.maxAge).toBe(604800); + }); + + it("always sets path to /", async () => { + await setAuthCookie("any-key"); + + const [, , options] = mockCookieSet.mock.calls[0]; + expect(options.path).toBe("/"); + }); + }); + + describe("cookie name and value", () => { + beforeEach(() => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + }); + + it("sets cookie name to auth-token", async () => { + await setAuthCookie("my-secret-key"); + + const [name] = mockCookieSet.mock.calls[0]; + expect(name).toBe("auth-token"); + }); + + it("sets cookie value to the provided keyString", async () => { + await setAuthCookie("my-secret-key"); + + const [, value] = mockCookieSet.mock.calls[0]; + expect(value).toBe("my-secret-key"); + }); + }); +}); diff --git a/tests/unit/i18n/auth-login-keys.test.ts b/tests/unit/i18n/auth-login-keys.test.ts new file mode 100644 index 000000000..146f1018d --- /dev/null +++ b/tests/unit/i18n/auth-login-keys.test.ts @@ -0,0 +1,67 @@ +import { describe, expect, it } from "vitest"; + +import enAuth from "../../../messages/en/auth.json"; +import jaAuth from "../../../messages/ja/auth.json"; +import ruAuth from "../../../messages/ru/auth.json"; +import zhCNAuth from "../../../messages/zh-CN/auth.json"; +import zhTWAuth from "../../../messages/zh-TW/auth.json"; + +/** + * Recursively extract all dot-separated key paths from a nested object. + * e.g. { a: { b: 1, c: 2 } } -> ["a.b", "a.c"] + */ +function extractKeys(obj: Record, prefix = ""): string[] { + const keys: string[] = []; + for (const key of Object.keys(obj)) { + const fullKey = prefix ? `${prefix}.${key}` : key; + const value = obj[key]; + if (value !== null && typeof value === "object" && !Array.isArray(value)) { + keys.push(...extractKeys(value as Record, fullKey)); + } else { + keys.push(fullKey); + } + } + return keys.sort(); +} + +const locales: Record> = { + en: enAuth, + "zh-CN": zhCNAuth, + "zh-TW": zhTWAuth, + ja: jaAuth, + ru: ruAuth, +}; + +const baselineKeys = extractKeys(locales.en); + +describe("auth.json locale key parity", () => { + it("English baseline has expected top-level sections", () => { + const topLevel = Object.keys(enAuth).sort(); + expect(topLevel).toEqual( + ["actions", "brand", "errors", "form", "login", "logout", "placeholders", "security"].sort() + ); + }); + + for (const [locale, data] of Object.entries(locales)) { + if (locale === "en") continue; + + it(`${locale} has all keys present in English baseline`, () => { + const localeKeys = extractKeys(data); + const missing = baselineKeys.filter((k) => !localeKeys.includes(k)); + expect(missing, `${locale} is missing keys: ${missing.join(", ")}`).toEqual([]); + }); + + it(`${locale} has no extra keys beyond English baseline`, () => { + const localeKeys = extractKeys(data); + const extra = localeKeys.filter((k) => !baselineKeys.includes(k)); + expect(extra, `${locale} has extra keys: ${extra.join(", ")}`).toEqual([]); + }); + } + + it("all 5 locales have identical key sets", () => { + for (const [locale, data] of Object.entries(locales)) { + const localeKeys = extractKeys(data); + expect(localeKeys, `${locale} key mismatch`).toEqual(baselineKeys); + } + }); +}); diff --git a/tests/unit/lib/redis/redis-kv-store.test.ts b/tests/unit/lib/redis/redis-kv-store.test.ts new file mode 100644 index 000000000..ce5debf0b --- /dev/null +++ b/tests/unit/lib/redis/redis-kv-store.test.ts @@ -0,0 +1,259 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +const setexMock = vi.fn(); +const getMock = vi.fn(); +const delMock = vi.fn(); +const evalMock = vi.fn(); + +function createMockRedis(status = "ready") { + return { + status, + setex: setexMock, + get: getMock, + del: delMock, + eval: evalMock, + }; +} + +vi.mock("@/lib/logger", () => ({ + logger: { + error: vi.fn(), + warn: vi.fn(), + info: vi.fn(), + debug: vi.fn(), + }, +})); + +vi.mock("@/lib/redis/client", () => ({ + getRedisClient: vi.fn(), +})); + +vi.mock("server-only", () => ({})); + +describe("RedisKVStore", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + async function createStore(options?: { status?: string }) { + const { RedisKVStore } = await import("@/lib/redis/redis-kv-store"); + const redis = createMockRedis(options?.status); + return { + store: new RedisKVStore({ + prefix: "test:", + defaultTtlSeconds: 60, + redisClient: redis, + }), + redis, + }; + } + + describe("set", () => { + it("stores value with SETEX and default TTL", async () => { + const { store } = await createStore<{ name: string }>(); + setexMock.mockResolvedValue("OK"); + + const result = await store.set("key1", { name: "alice" }); + + expect(result).toBe(true); + expect(setexMock).toHaveBeenCalledWith("test:key1", 60, JSON.stringify({ name: "alice" })); + }); + + it("uses custom TTL when provided", async () => { + const { store } = await createStore(); + setexMock.mockResolvedValue("OK"); + + await store.set("key2", "value", 30); + + expect(setexMock).toHaveBeenCalledWith("test:key2", 30, JSON.stringify("value")); + }); + + it("returns false when Redis is not ready", async () => { + const { store } = await createStore({ status: "connecting" }); + + const result = await store.set("key3", "value"); + + expect(result).toBe(false); + expect(setexMock).not.toHaveBeenCalled(); + }); + + it("returns false when SETEX throws", async () => { + const { store } = await createStore(); + setexMock.mockRejectedValue(new Error("Redis write error")); + + const result = await store.set("key4", "value"); + + expect(result).toBe(false); + }); + }); + + describe("get", () => { + it("retrieves and deserializes stored value", async () => { + const { store } = await createStore<{ count: number }>(); + getMock.mockResolvedValue(JSON.stringify({ count: 42 })); + + const result = await store.get("key1"); + + expect(result).toEqual({ count: 42 }); + expect(getMock).toHaveBeenCalledWith("test:key1"); + }); + + it("returns null for missing key", async () => { + const { store } = await createStore(); + getMock.mockResolvedValue(null); + + const result = await store.get("missing"); + + expect(result).toBeNull(); + }); + + it("returns null when Redis is not ready", async () => { + const { store } = await createStore({ status: "connecting" }); + + const result = await store.get("key1"); + + expect(result).toBeNull(); + expect(getMock).not.toHaveBeenCalled(); + }); + + it("returns null when GET throws", async () => { + const { store } = await createStore(); + getMock.mockRejectedValue(new Error("Redis read error")); + + const result = await store.get("key1"); + + expect(result).toBeNull(); + }); + + it("returns null when stored value is malformed JSON", async () => { + const { store } = await createStore<{ count: number }>(); + getMock.mockResolvedValue("not-valid-json"); + + const result = await store.get("corrupted"); + + expect(result).toBeNull(); + }); + }); + + describe("getAndDelete", () => { + it("atomically retrieves and deletes key via Lua script", async () => { + const { store } = await createStore<{ id: string }>(); + evalMock.mockResolvedValue(JSON.stringify({ id: "abc" })); + + const result = await store.getAndDelete("key1"); + + expect(result).toEqual({ id: "abc" }); + expect(evalMock).toHaveBeenCalledWith(expect.any(String), 1, "test:key1"); + }); + + it("returns null for missing key", async () => { + const { store } = await createStore(); + evalMock.mockResolvedValue(null); + + const result = await store.getAndDelete("missing"); + + expect(result).toBeNull(); + }); + + it("returns null when Redis is not ready", async () => { + const { store } = await createStore({ status: "end" }); + + const result = await store.getAndDelete("key1"); + + expect(result).toBeNull(); + }); + + it("returns null when eval throws", async () => { + const { store } = await createStore(); + evalMock.mockRejectedValue(new Error("Redis eval error")); + + const result = await store.getAndDelete("key1"); + + expect(result).toBeNull(); + }); + + it("returns null when stored value is malformed JSON", async () => { + const { store } = await createStore<{ count: number }>(); + evalMock.mockResolvedValue("{invalid json..."); + + const result = await store.getAndDelete("corrupted-key"); + + expect(result).toBeNull(); + }); + }); + + describe("delete", () => { + it("deletes key and returns true when key existed", async () => { + const { store } = await createStore(); + delMock.mockResolvedValue(1); + + const result = await store.delete("key1"); + + expect(result).toBe(true); + expect(delMock).toHaveBeenCalledWith("test:key1"); + }); + + it("returns false when key did not exist", async () => { + const { store } = await createStore(); + delMock.mockResolvedValue(0); + + const result = await store.delete("missing"); + + expect(result).toBe(false); + }); + + it("returns false when Redis is not ready", async () => { + const { store } = await createStore({ status: "connecting" }); + + const result = await store.delete("key1"); + + expect(result).toBe(false); + }); + + it("returns false when DEL throws", async () => { + const { store } = await createStore(); + delMock.mockRejectedValue(new Error("Redis delete error")); + + const result = await store.delete("key1"); + + expect(result).toBe(false); + }); + }); + + describe("key prefixing", () => { + it("prepends prefix to all operations", async () => { + const { store } = await createStore(); + setexMock.mockResolvedValue("OK"); + getMock.mockResolvedValue(null); + delMock.mockResolvedValue(0); + + await store.set("mykey", "val"); + await store.get("mykey"); + await store.delete("mykey"); + + expect(setexMock).toHaveBeenCalledWith("test:mykey", expect.any(Number), expect.any(String)); + expect(getMock).toHaveBeenCalledWith("test:mykey"); + expect(delMock).toHaveBeenCalledWith("test:mykey"); + }); + }); + + describe("injected client", () => { + it("returns null for all ops when injected client is null", async () => { + const { RedisKVStore } = await import("@/lib/redis/redis-kv-store"); + const store = new RedisKVStore({ + prefix: "test:", + defaultTtlSeconds: 60, + redisClient: null, + }); + + expect(await store.set("k", "v")).toBe(false); + expect(await store.get("k")).toBeNull(); + expect(await store.getAndDelete("k")).toBeNull(); + expect(await store.delete("k")).toBe(false); + }); + }); +}); diff --git a/tests/unit/login/login-footer-system-name.test.tsx b/tests/unit/login/login-footer-system-name.test.tsx new file mode 100644 index 000000000..a20473278 --- /dev/null +++ b/tests/unit/login/login-footer-system-name.test.tsx @@ -0,0 +1,151 @@ +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import LoginPage from "../../../src/app/[locale]/login/page"; + +const mockPush = vi.hoisted(() => vi.fn()); +const mockRefresh = vi.hoisted(() => vi.fn()); +const mockUseRouter = vi.hoisted(() => vi.fn(() => ({ push: mockPush, refresh: mockRefresh }))); +const mockUseSearchParams = vi.hoisted(() => vi.fn(() => ({ get: vi.fn(() => null) }))); +const mockUseTranslations = vi.hoisted(() => vi.fn(() => (key: string) => `t:${key}`)); +const mockUseLocale = vi.hoisted(() => vi.fn(() => "en")); +const mockUsePathname = vi.hoisted(() => vi.fn(() => "/login")); + +vi.mock("next/navigation", () => ({ + useSearchParams: mockUseSearchParams, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-intl", () => ({ + useTranslations: mockUseTranslations, + useLocale: mockUseLocale, +})); + +vi.mock("@/i18n/routing", () => ({ + Link: ({ children, ...props }: { children: React.ReactNode }) => {children}, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-themes", () => ({ + useTheme: vi.fn(() => ({ theme: "system", setTheme: vi.fn() })), +})); + +const globalFetch = global.fetch; +const DEFAULT_SITE_TITLE = "Claude Code Hub"; + +function getRequestPath(input: string | URL | Request): string { + if (typeof input === "string") { + return input; + } + + if (input instanceof URL) { + return input.pathname; + } + + return input.url; +} + +function mockJsonResponse(payload: unknown, ok = true): Response { + return { + ok, + json: async () => payload, + } as Response; +} + +describe("LoginPage footer system name", () => { + let container: HTMLDivElement; + let root: ReturnType; + + beforeEach(() => { + container = document.createElement("div"); + document.body.appendChild(container); + root = createRoot(container); + vi.clearAllMocks(); + global.fetch = vi.fn(); + }); + + afterEach(() => { + act(() => { + root.unmount(); + }); + container.remove(); + global.fetch = globalFetch; + }); + + const render = async () => { + await act(async () => { + root.render(); + }); + }; + + const flushMicrotasks = async () => { + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); + }; + + const getSiteTitleFooter = () => + container.querySelector('[data-testid="login-site-title-footer"]'); + + it("renders configured site title when API returns it", async () => { + (global.fetch as ReturnType).mockImplementation( + (input: string | URL | Request) => { + const path = getRequestPath(input); + + if (path === "/api/system-settings") { + return Promise.resolve(mockJsonResponse({ siteTitle: "My Custom Hub" })); + } + + return Promise.resolve(mockJsonResponse({ current: "1.0.0", hasUpdate: false })); + } + ); + + await render(); + await flushMicrotasks(); + + expect(getSiteTitleFooter()).not.toBeNull(); + expect(getSiteTitleFooter()?.textContent).toBe("My Custom Hub"); + }); + + it("falls back to default title when API fails", async () => { + (global.fetch as ReturnType).mockImplementation( + (input: string | URL | Request) => { + const path = getRequestPath(input); + + if (path === "/api/system-settings") { + return Promise.resolve(mockJsonResponse({ error: "Unauthorized" }, false)); + } + + return Promise.resolve(mockJsonResponse({ current: "1.0.0", hasUpdate: false })); + } + ); + + await render(); + await flushMicrotasks(); + + expect(getSiteTitleFooter()).not.toBeNull(); + expect(getSiteTitleFooter()?.textContent).toBe(DEFAULT_SITE_TITLE); + }); + + it("shows default title while loading", async () => { + (global.fetch as ReturnType).mockImplementation( + (input: string | URL | Request) => { + const path = getRequestPath(input); + + if (path === "/api/system-settings") { + return new Promise(() => {}); + } + + return Promise.resolve(mockJsonResponse({ current: "1.0.0", hasUpdate: false })); + } + ); + + await render(); + + expect(getSiteTitleFooter()).not.toBeNull(); + expect(getSiteTitleFooter()?.textContent).toBe(DEFAULT_SITE_TITLE); + }); +}); diff --git a/tests/unit/login/login-footer-version.test.tsx b/tests/unit/login/login-footer-version.test.tsx new file mode 100644 index 000000000..349a57a32 --- /dev/null +++ b/tests/unit/login/login-footer-version.test.tsx @@ -0,0 +1,101 @@ +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import LoginPage from "@/app/[locale]/login/page"; + +const mockPush = vi.hoisted(() => vi.fn()); +const mockRefresh = vi.hoisted(() => vi.fn()); +const mockUseRouter = vi.hoisted(() => vi.fn(() => ({ push: mockPush, refresh: mockRefresh }))); +const mockUseSearchParams = vi.hoisted(() => vi.fn(() => ({ get: vi.fn(() => null) }))); +const mockUseTranslations = vi.hoisted(() => vi.fn(() => (key: string) => `t:${key}`)); +const mockUseLocale = vi.hoisted(() => vi.fn(() => "en")); +const mockUsePathname = vi.hoisted(() => vi.fn(() => "/login")); + +vi.mock("next/navigation", () => ({ + useSearchParams: mockUseSearchParams, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-intl", () => ({ + useTranslations: mockUseTranslations, + useLocale: mockUseLocale, +})); + +vi.mock("@/i18n/routing", () => ({ + Link: ({ children, ...props }: any) => {children}, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-themes", () => ({ + useTheme: vi.fn(() => ({ theme: "system", setTheme: vi.fn() })), +})); + +const globalFetch = global.fetch; + +describe("LoginPage Footer Version", () => { + let container: HTMLDivElement; + let root: ReturnType; + + beforeEach(() => { + container = document.createElement("div"); + document.body.appendChild(container); + root = createRoot(container); + vi.clearAllMocks(); + global.fetch = vi.fn(); + }); + + afterEach(() => { + act(() => { + root.unmount(); + }); + document.body.removeChild(container); + global.fetch = globalFetch; + }); + + const render = async () => { + await act(async () => { + root.render(); + }); + + await act(async () => { + await Promise.resolve(); + }); + }; + + it("shows version and update hint when hasUpdate=true", async () => { + (global.fetch as any).mockResolvedValue({ + ok: true, + json: async () => ({ current: "0.5.0", latest: "0.6.0", hasUpdate: true }), + }); + + await render(); + + expect((global.fetch as any).mock.calls[0]?.[0]).toBe("/api/version"); + const footer = container.querySelector('[data-testid="login-footer-version"]'); + expect(footer?.textContent).toContain("v0.5.0"); + expect(footer?.textContent).toContain("t:version.updateAvailable"); + }); + + it("shows version without update hint when hasUpdate=false", async () => { + (global.fetch as any).mockResolvedValue({ + ok: true, + json: async () => ({ current: "0.5.0", latest: "0.5.0", hasUpdate: false }), + }); + + await render(); + + const footer = container.querySelector('[data-testid="login-footer-version"]'); + expect(footer?.textContent).toContain("v0.5.0"); + expect(footer?.textContent).not.toContain("t:version.updateAvailable"); + }); + + it("gracefully handles version fetch error without rendering version", async () => { + (global.fetch as any).mockRejectedValue(new Error("network fail")); + + await render(); + + expect(container.querySelector('[data-testid="login-footer-version"]')).toBeNull(); + }); +}); diff --git a/tests/unit/login/login-loading-state.test.tsx b/tests/unit/login/login-loading-state.test.tsx new file mode 100644 index 000000000..00d7314e5 --- /dev/null +++ b/tests/unit/login/login-loading-state.test.tsx @@ -0,0 +1,191 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { createRoot } from "react-dom/client"; +import { act } from "react"; +import LoginPage from "@/app/[locale]/login/page"; + +const mockPush = vi.hoisted(() => vi.fn()); +const mockRefresh = vi.hoisted(() => vi.fn()); +const mockUseRouter = vi.hoisted(() => vi.fn(() => ({ push: mockPush, refresh: mockRefresh }))); +const mockUseSearchParams = vi.hoisted(() => vi.fn(() => ({ get: vi.fn(() => null) }))); +const mockUseTranslations = vi.hoisted(() => vi.fn(() => (key: string) => `t:${key}`)); +const mockUseLocale = vi.hoisted(() => vi.fn(() => "en")); +const mockUsePathname = vi.hoisted(() => vi.fn(() => "/login")); + +vi.mock("next/navigation", () => ({ + useSearchParams: mockUseSearchParams, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-intl", () => ({ + useTranslations: mockUseTranslations, + useLocale: mockUseLocale, +})); + +vi.mock("@/i18n/routing", () => ({ + Link: ({ children, ...props }: any) => {children}, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-themes", () => ({ + useTheme: vi.fn(() => ({ theme: "system", setTheme: vi.fn() })), +})); + +const globalFetch = global.fetch; + +describe("LoginPage Loading State", () => { + let container: HTMLDivElement; + let root: ReturnType; + + beforeEach(() => { + container = document.createElement("div"); + document.body.appendChild(container); + root = createRoot(container); + vi.clearAllMocks(); + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({}), + }); + }); + + afterEach(() => { + act(() => { + root.unmount(); + }); + document.body.removeChild(container); + global.fetch = globalFetch; + }); + + const render = async () => { + await act(async () => { + root.render(); + }); + }; + + const setInputValue = (input: HTMLInputElement, value: string) => { + const nativeInputValueSetter = Object.getOwnPropertyDescriptor( + window.HTMLInputElement.prototype, + "value" + )?.set; + if (nativeInputValueSetter) { + nativeInputValueSetter.call(input, value); + } else { + input.value = value; + } + input.dispatchEvent(new Event("input", { bubbles: true })); + }; + + const getSubmitButton = () => + container.querySelector('button[type="submit"]') as HTMLButtonElement; + const getApiKeyInput = () => container.querySelector("input#apiKey") as HTMLInputElement; + const getOverlay = () => container.querySelector('[data-testid="loading-overlay"]'); + + it("starts in idle state with no overlay", async () => { + await render(); + + expect(getOverlay()).toBeNull(); + expect(getSubmitButton().disabled).toBe(true); + expect(getApiKeyInput().disabled).toBe(false); + }); + + it("shows fullscreen overlay during submission", async () => { + let resolveFetch: (value: any) => void; + const fetchPromise = new Promise((resolve) => { + resolveFetch = resolve; + }); + + (global.fetch as any).mockReturnValue(fetchPromise); + + await render(); + + const input = getApiKeyInput(); + await act(async () => { + setInputValue(input, "test-api-key"); + }); + + const button = getSubmitButton(); + await act(async () => { + button.click(); + }); + + const overlay = getOverlay(); + expect(overlay).not.toBeNull(); + expect(overlay?.textContent).toContain("t:login.loggingIn"); + expect(getSubmitButton().disabled).toBe(true); + expect(getApiKeyInput().disabled).toBe(true); + + await act(async () => { + resolveFetch!({ + ok: true, + json: async () => ({ redirectTo: "/dashboard" }), + }); + }); + }); + + it("keeps overlay on success until redirect", async () => { + (global.fetch as any).mockResolvedValue({ + ok: true, + json: async () => ({ redirectTo: "/dashboard" }), + }); + + await render(); + + const input = getApiKeyInput(); + await act(async () => { + setInputValue(input, "test-api-key"); + }); + + await act(async () => { + getSubmitButton().click(); + }); + + const overlay = getOverlay(); + expect(overlay).not.toBeNull(); + + expect(mockPush).toHaveBeenCalledWith("/dashboard"); + expect(mockRefresh).toHaveBeenCalled(); + }); + + it("removes overlay and shows error on failure", async () => { + (global.fetch as any).mockResolvedValue({ + ok: false, + json: async () => ({ error: "Invalid key" }), + }); + + await render(); + + const input = getApiKeyInput(); + await act(async () => { + setInputValue(input, "test-api-key"); + }); + + await act(async () => { + getSubmitButton().click(); + }); + + expect(getOverlay()).toBeNull(); + expect(container.textContent).toContain("Invalid key"); + expect(getSubmitButton().disabled).toBe(false); + expect(getApiKeyInput().disabled).toBe(false); + }); + + it("removes overlay and shows error on network exception", async () => { + (global.fetch as any).mockRejectedValue(new Error("Network error")); + + await render(); + + const input = getApiKeyInput(); + await act(async () => { + setInputValue(input, "test-api-key"); + }); + + await act(async () => { + getSubmitButton().click(); + }); + + expect(getOverlay()).toBeNull(); + expect(container.textContent).toContain("t:errors.networkError"); + expect(getSubmitButton().disabled).toBe(false); + }); +}); diff --git a/tests/unit/login/login-overlay-a11y.test.tsx b/tests/unit/login/login-overlay-a11y.test.tsx new file mode 100644 index 000000000..8e9311a4d --- /dev/null +++ b/tests/unit/login/login-overlay-a11y.test.tsx @@ -0,0 +1,147 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { createRoot } from "react-dom/client"; +import { act } from "react"; +import LoginPage from "@/app/[locale]/login/page"; + +const mockPush = vi.hoisted(() => vi.fn()); +const mockRefresh = vi.hoisted(() => vi.fn()); +const mockUseRouter = vi.hoisted(() => vi.fn(() => ({ push: mockPush, refresh: mockRefresh }))); +const mockUseSearchParams = vi.hoisted(() => vi.fn(() => ({ get: vi.fn(() => null) }))); +const mockUseTranslations = vi.hoisted(() => vi.fn(() => (key: string) => `t:${key}`)); +const mockUseLocale = vi.hoisted(() => vi.fn(() => "en")); +const mockUsePathname = vi.hoisted(() => vi.fn(() => "/login")); + +vi.mock("next/navigation", () => ({ + useSearchParams: mockUseSearchParams, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-intl", () => ({ + useTranslations: mockUseTranslations, + useLocale: mockUseLocale, +})); + +vi.mock("@/i18n/routing", () => ({ + Link: ({ children, ...props }: any) => {children}, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-themes", () => ({ + useTheme: vi.fn(() => ({ theme: "system", setTheme: vi.fn() })), +})); + +const globalFetch = global.fetch; + +describe("LoginPage Accessibility", () => { + let container: HTMLDivElement; + let root: ReturnType; + + beforeEach(() => { + container = document.createElement("div"); + document.body.appendChild(container); + root = createRoot(container); + vi.clearAllMocks(); + global.fetch = vi.fn(); + }); + + afterEach(() => { + act(() => { + root.unmount(); + }); + document.body.removeChild(container); + global.fetch = globalFetch; + }); + + const render = async () => { + await act(async () => { + root.render(); + }); + }; + + const setInputValue = (input: HTMLInputElement, value: string) => { + const nativeInputValueSetter = Object.getOwnPropertyDescriptor( + window.HTMLInputElement.prototype, + "value" + )?.set; + if (nativeInputValueSetter) { + nativeInputValueSetter.call(input, value); + } else { + input.value = value; + } + input.dispatchEvent(new Event("input", { bubbles: true })); + }; + + const getSubmitButton = () => + container.querySelector('button[type="submit"]') as HTMLButtonElement; + const getApiKeyInput = () => container.querySelector("input#apiKey") as HTMLInputElement; + const getOverlay = () => container.querySelector('[data-testid="loading-overlay"]'); + + it("loading overlay has correct ARIA attributes", async () => { + let resolveFetch: (value: any) => void; + const fetchPromise = new Promise((resolve) => { + resolveFetch = resolve; + }); + (global.fetch as any).mockReturnValue(fetchPromise); + + await render(); + + const input = getApiKeyInput(); + await act(async () => { + setInputValue(input, "test-api-key"); + }); + + const button = getSubmitButton(); + await act(async () => { + button.click(); + }); + + const overlay = getOverlay(); + expect(overlay).not.toBeNull(); + + expect(overlay?.getAttribute("role")).toBe("dialog"); + expect(overlay?.getAttribute("aria-modal")).toBe("true"); + expect(overlay?.getAttribute("aria-label")).toBe("t:login.loggingIn"); + + const statusText = overlay?.querySelector('p[role="status"]'); + expect(statusText).not.toBeNull(); + expect(statusText?.getAttribute("aria-live")).toBe("polite"); + + const spinner = overlay?.querySelector(".animate-spin"); + expect(spinner?.classList.contains("motion-reduce:animate-none")).toBe(true); + + await act(async () => { + resolveFetch!({ + ok: true, + json: async () => ({ redirectTo: "/dashboard" }), + }); + }); + }); + + it("error state manages focus and announces alert", async () => { + (global.fetch as any).mockResolvedValue({ + ok: false, + json: async () => ({ error: "Invalid key" }), + }); + + await render(); + + const input = getApiKeyInput(); + const focusSpy = vi.spyOn(input, "focus"); + + await act(async () => { + setInputValue(input, "test-api-key"); + }); + + await act(async () => { + getSubmitButton().click(); + }); + + const alert = container.querySelector('[role="alert"]'); + expect(alert).not.toBeNull(); + expect(alert?.textContent).toContain("Invalid key"); + + expect(focusSpy).toHaveBeenCalled(); + }); +}); diff --git a/tests/unit/login/login-regression-matrix.test.tsx b/tests/unit/login/login-regression-matrix.test.tsx new file mode 100644 index 000000000..e46569cd3 --- /dev/null +++ b/tests/unit/login/login-regression-matrix.test.tsx @@ -0,0 +1,230 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { NextRequest } from "next/server"; + +const mockValidateKey = vi.hoisted(() => vi.fn()); +const mockSetAuthCookie = vi.hoisted(() => vi.fn()); +const mockGetSessionTokenMode = vi.hoisted(() => vi.fn()); +const mockGetLoginRedirectTarget = vi.hoisted(() => vi.fn()); +const mockGetTranslations = vi.hoisted(() => vi.fn()); +const mockGetEnvConfig = vi.hoisted(() => vi.fn()); +const mockLogger = vi.hoisted(() => ({ + warn: vi.fn(), + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), +})); + +vi.mock("@/lib/auth", () => ({ + validateKey: mockValidateKey, + setAuthCookie: mockSetAuthCookie, + getSessionTokenMode: mockGetSessionTokenMode, + getLoginRedirectTarget: mockGetLoginRedirectTarget, + toKeyFingerprint: vi.fn().mockResolvedValue("sha256:mock"), + withNoStoreHeaders: (res: any) => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +vi.mock("next-intl/server", () => ({ + getTranslations: mockGetTranslations, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + getEnvConfig: mockGetEnvConfig, +})); + +vi.mock("@/lib/logger", () => ({ + logger: mockLogger, +})); + +vi.mock("@/lib/security/auth-response-headers", () => ({ + withAuthResponseHeaders: (res: any) => { + (res as any).headers.set("Cache-Control", "no-store, no-cache, must-revalidate"); + (res as any).headers.set("Pragma", "no-cache"); + return res; + }, +})); + +function makeRequest(body: unknown, xForwardedProto = "https"): NextRequest { + return new NextRequest("http://localhost/api/auth/login", { + method: "POST", + headers: { + "Content-Type": "application/json", + "x-forwarded-proto": xForwardedProto, + }, + body: JSON.stringify(body), + }); +} + +const adminSession = { + user: { + id: -1, + name: "Admin Token", + description: "Environment admin session", + role: "admin" as const, + }, + key: { canLoginWebUi: true }, +}; + +const dashboardUserSession = { + user: { + id: 1, + name: "Dashboard User", + description: "dashboard", + role: "user" as const, + }, + key: { canLoginWebUi: true }, +}; + +const readonlyUserSession = { + user: { + id: 2, + name: "Readonly User", + description: "readonly", + role: "user" as const, + }, + key: { canLoginWebUi: false }, +}; + +describe("Login Regression Matrix", () => { + let POST: (request: NextRequest) => Promise; + + beforeEach(async () => { + vi.clearAllMocks(); + + const mockT = vi.fn((key: string) => `translated:${key}`); + mockGetTranslations.mockResolvedValue(mockT); + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: false }); + mockSetAuthCookie.mockResolvedValue(undefined); + mockGetSessionTokenMode.mockReturnValue("legacy"); + + const mod = await import("@/app/api/auth/login/route"); + POST = mod.POST; + }); + + describe("Success Paths", () => { + it("admin user: redirectTo=/dashboard, loginType=admin", async () => { + mockValidateKey.mockResolvedValue(adminSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + const res = await POST(makeRequest({ key: "admin-key" })); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ + ok: true, + user: { + id: -1, + name: "Admin Token", + description: "Environment admin session", + role: "admin", + }, + redirectTo: "/dashboard", + loginType: "admin", + }); + expect(mockSetAuthCookie).toHaveBeenCalledWith("admin-key"); + expect(mockGetLoginRedirectTarget).toHaveBeenCalledWith(adminSession); + }); + + it("dashboard user: redirectTo=/dashboard, loginType=dashboard_user", async () => { + mockValidateKey.mockResolvedValue(dashboardUserSession); + mockGetLoginRedirectTarget.mockReturnValue("/dashboard"); + + const res = await POST(makeRequest({ key: "dashboard-user-key" })); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ + ok: true, + user: { + id: 1, + name: "Dashboard User", + description: "dashboard", + role: "user", + }, + redirectTo: "/dashboard", + loginType: "dashboard_user", + }); + expect(mockSetAuthCookie).toHaveBeenCalledWith("dashboard-user-key"); + expect(mockGetLoginRedirectTarget).toHaveBeenCalledWith(dashboardUserSession); + }); + + it("readonly user: redirectTo=/my-usage, loginType=readonly_user", async () => { + mockValidateKey.mockResolvedValue(readonlyUserSession); + mockGetLoginRedirectTarget.mockReturnValue("/my-usage"); + + const res = await POST(makeRequest({ key: "readonly-user-key" })); + + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ + ok: true, + user: { + id: 2, + name: "Readonly User", + description: "readonly", + role: "user", + }, + redirectTo: "/my-usage", + loginType: "readonly_user", + }); + expect(mockSetAuthCookie).toHaveBeenCalledWith("readonly-user-key"); + expect(mockGetLoginRedirectTarget).toHaveBeenCalledWith(readonlyUserSession); + }); + }); + + describe("Failure Paths", () => { + it("missing key: 400 + KEY_REQUIRED", async () => { + const res = await POST(makeRequest({})); + + expect(res.status).toBe(400); + expect(await res.json()).toEqual({ + error: "translated:apiKeyRequired", + errorCode: "KEY_REQUIRED", + }); + expect(mockValidateKey).not.toHaveBeenCalled(); + expect(mockSetAuthCookie).not.toHaveBeenCalled(); + }); + + it("invalid key: 401 + KEY_INVALID", async () => { + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "invalid-key" })); + + expect(res.status).toBe(401); + expect(await res.json()).toEqual({ + error: "translated:apiKeyInvalidOrExpired", + errorCode: "KEY_INVALID", + }); + expect(mockSetAuthCookie).not.toHaveBeenCalled(); + }); + + it("HTTP mismatch: 401 + httpMismatchGuidance", async () => { + mockGetEnvConfig.mockReturnValue({ ENABLE_SECURE_COOKIES: true }); + mockValidateKey.mockResolvedValue(null); + + const res = await POST(makeRequest({ key: "mismatch-key" }, "http")); + + expect(res.status).toBe(401); + expect(await res.json()).toEqual({ + error: "translated:apiKeyInvalidOrExpired", + errorCode: "KEY_INVALID", + httpMismatchGuidance: "translated:cookieWarningDescription", + }); + expect(mockSetAuthCookie).not.toHaveBeenCalled(); + }); + + it("server error: 500 + SERVER_ERROR", async () => { + mockValidateKey.mockRejectedValue(new Error("DB connection failed")); + + const res = await POST(makeRequest({ key: "trigger-server-error" })); + + expect(res.status).toBe(500); + expect(await res.json()).toEqual({ + error: "translated:serverError", + errorCode: "SERVER_ERROR", + }); + expect(mockSetAuthCookie).not.toHaveBeenCalled(); + expect(mockLogger.error).toHaveBeenCalled(); + }); + }); +}); diff --git a/tests/unit/login/login-ui-redesign.test.tsx b/tests/unit/login/login-ui-redesign.test.tsx new file mode 100644 index 000000000..d374a2aaf --- /dev/null +++ b/tests/unit/login/login-ui-redesign.test.tsx @@ -0,0 +1,147 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { createRoot } from "react-dom/client"; +import { act } from "react"; +import LoginPage from "@/app/[locale]/login/page"; + +const mockPush = vi.hoisted(() => vi.fn()); +const mockRefresh = vi.hoisted(() => vi.fn()); +const mockUseRouter = vi.hoisted(() => vi.fn(() => ({ push: mockPush, refresh: mockRefresh }))); +const mockUseSearchParams = vi.hoisted(() => vi.fn(() => ({ get: vi.fn(() => null) }))); +const mockUseTranslations = vi.hoisted(() => vi.fn(() => (key: string) => `t:${key}`)); +const mockUseLocale = vi.hoisted(() => vi.fn(() => "en")); +const mockUsePathname = vi.hoisted(() => vi.fn(() => "/login")); + +vi.mock("next/navigation", () => ({ + useSearchParams: mockUseSearchParams, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-intl", () => ({ + useTranslations: mockUseTranslations, + useLocale: mockUseLocale, +})); + +vi.mock("@/i18n/routing", () => ({ + Link: ({ children, ...props }: any) => {children}, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-themes", () => ({ + useTheme: vi.fn(() => ({ theme: "system", setTheme: vi.fn() })), +})); + +describe("LoginPage UI Redesign", () => { + let container: HTMLDivElement; + let root: ReturnType; + + beforeEach(() => { + container = document.createElement("div"); + document.body.appendChild(container); + root = createRoot(container); + vi.clearAllMocks(); + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({}), + }); + }); + + afterEach(() => { + act(() => { + root.unmount(); + }); + document.body.removeChild(container); + }); + + const render = async () => { + await act(async () => { + root.render(); + }); + }; + + it("password toggle changes input type between password and text", async () => { + await render(); + + const input = container.querySelector("input#apiKey") as HTMLInputElement; + expect(input).not.toBeNull(); + expect(input.type).toBe("password"); + + const toggleButton = container.querySelector( + 'button[aria-label="t:form.showPassword"]' + ) as HTMLButtonElement; + expect(toggleButton).not.toBeNull(); + + await act(async () => { + toggleButton.click(); + }); + + expect(input.type).toBe("text"); + + const hideButton = container.querySelector( + 'button[aria-label="t:form.hidePassword"]' + ) as HTMLButtonElement; + expect(hideButton).not.toBeNull(); + + await act(async () => { + hideButton.click(); + }); + + expect(input.type).toBe("password"); + }); + + it("ThemeSwitcher renders in the top-right control area", async () => { + await render(); + + const topRightArea = container.querySelector(".fixed.top-4.right-4"); + expect(topRightArea).not.toBeNull(); + + const buttons = topRightArea?.querySelectorAll("button"); + expect(buttons?.length).toBeGreaterThanOrEqual(2); + }); + + it("brand panel has data-testid login-brand-panel", async () => { + await render(); + + const brandPanel = container.querySelector('[data-testid="login-brand-panel"]'); + expect(brandPanel).not.toBeNull(); + }); + + it("brand panel is hidden on mobile (has hidden class without lg:flex)", async () => { + await render(); + + const brandPanel = container.querySelector('[data-testid="login-brand-panel"]'); + expect(brandPanel).not.toBeNull(); + expect(brandPanel?.className).toContain("hidden"); + expect(brandPanel?.className).toContain("lg:flex"); + }); + + it("mobile brand header is visible on mobile (has lg:hidden class)", async () => { + await render(); + + const formPanel = container.querySelector(".lg\\:w-\\[55\\%\\]"); + expect(formPanel).not.toBeNull(); + + const mobileHeader = formPanel?.querySelector(".lg\\:hidden"); + expect(mobileHeader).not.toBeNull(); + }); + + it("card header icon is hidden on desktop (has lg:hidden class)", async () => { + await render(); + + const card = container.querySelector('[data-slot="card"]'); + expect(card).not.toBeNull(); + + const headerIcon = card?.querySelector(".lg\\:hidden"); + expect(headerIcon).not.toBeNull(); + }); + + it("input has padding for both key icon and toggle button", async () => { + await render(); + + const input = container.querySelector("input#apiKey") as HTMLInputElement; + expect(input).not.toBeNull(); + expect(input.className).toContain("pl-9"); + expect(input.className).toContain("pr-10"); + }); +}); diff --git a/tests/unit/login/login-visual-regression.test.tsx b/tests/unit/login/login-visual-regression.test.tsx new file mode 100644 index 000000000..f9b3abff5 --- /dev/null +++ b/tests/unit/login/login-visual-regression.test.tsx @@ -0,0 +1,98 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from "vitest"; +import { createRoot } from "react-dom/client"; +import { act } from "react"; +import LoginPage from "@/app/[locale]/login/page"; + +// Mocks +const mockPush = vi.hoisted(() => vi.fn()); +const mockRefresh = vi.hoisted(() => vi.fn()); +const mockUseRouter = vi.hoisted(() => vi.fn(() => ({ push: mockPush, refresh: mockRefresh }))); +const mockUseSearchParams = vi.hoisted(() => vi.fn(() => ({ get: vi.fn(() => null) }))); +const mockUseTranslations = vi.hoisted(() => vi.fn(() => (key: string) => `t:${key}`)); +const mockUseLocale = vi.hoisted(() => vi.fn(() => "en")); +const mockUsePathname = vi.hoisted(() => vi.fn(() => "/login")); + +vi.mock("next/navigation", () => ({ + useSearchParams: mockUseSearchParams, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-intl", () => ({ + useTranslations: mockUseTranslations, + useLocale: mockUseLocale, +})); + +vi.mock("@/i18n/routing", () => ({ + Link: ({ children, ...props }: any) => {children}, + useRouter: mockUseRouter, + usePathname: mockUsePathname, +})); + +vi.mock("next-themes", () => ({ + useTheme: vi.fn(() => ({ theme: "system", setTheme: vi.fn() })), +})); + +describe("LoginPage Visual Regression", () => { + let container: HTMLDivElement; + let root: ReturnType; + + beforeEach(() => { + container = document.createElement("div"); + document.body.appendChild(container); + root = createRoot(container); + vi.clearAllMocks(); + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({}), + }); + }); + + afterEach(() => { + act(() => { + root.unmount(); + }); + document.body.removeChild(container); + }); + + const render = async () => { + await act(async () => { + root.render(); + }); + }; + + it("renders key structural elements", async () => { + await render(); + + const mainContainer = container.querySelector("div.min-h-screen"); + expect(mainContainer).not.toBeNull(); + const className = mainContainer?.className || ""; + expect(className).toContain("bg-gradient-to"); + + const langSwitcher = container.querySelector(".fixed.top-4.right-4"); + expect(langSwitcher).not.toBeNull(); + + const card = container.querySelector('[data-slot="card"]'); + expect(card).not.toBeNull(); + + const form = container.querySelector("form"); + expect(form).not.toBeNull(); + + const input = container.querySelector("input#apiKey"); + expect(input).not.toBeNull(); + + const button = container.querySelector('button[type="submit"]'); + expect(button).not.toBeNull(); + }); + + it("has mobile responsive classes", async () => { + await render(); + + const wrapper = container.querySelector(".max-w-lg"); + expect(wrapper).not.toBeNull(); + + const card = wrapper?.querySelector('[data-slot="card"]'); + expect(card).not.toBeNull(); + expect(card?.className).toContain("w-full"); + }); +}); diff --git a/tests/unit/proxy/proxy-auth-cookie-passthrough.test.ts b/tests/unit/proxy/proxy-auth-cookie-passthrough.test.ts new file mode 100644 index 000000000..6c3b5475b --- /dev/null +++ b/tests/unit/proxy/proxy-auth-cookie-passthrough.test.ts @@ -0,0 +1,83 @@ +import { describe, expect, it, vi } from "vitest"; + +// Hoist mocks before imports -- mock transitive dependencies to avoid +// next-intl pulling in next/navigation (not resolvable in vitest) +const mockIntlMiddleware = vi.hoisted(() => vi.fn()); +vi.mock("next-intl/middleware", () => ({ + default: () => mockIntlMiddleware, +})); + +vi.mock("@/i18n/routing", () => ({ + routing: { + locales: ["zh-CN", "en"], + defaultLocale: "zh-CN", + }, +})); + +vi.mock("@/lib/config/env.schema", () => ({ + isDevelopment: () => false, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +function makeRequest(pathname: string, cookies: Record = {}) { + const url = new URL(`http://localhost:13500${pathname}`); + return { + method: "GET", + nextUrl: { pathname, clone: () => url }, + cookies: { + get: (name: string) => (name in cookies ? { name, value: cookies[name] } : undefined), + }, + headers: new Headers(), + } as unknown as import("next/server").NextRequest; +} + +describe("proxy auth cookie passthrough", () => { + it("redirects to login when no auth cookie is present", async () => { + const localeResponse = new Response(null, { status: 200 }); + mockIntlMiddleware.mockReturnValue(localeResponse); + + const { default: proxyHandler } = await import("@/proxy"); + const response = proxyHandler(makeRequest("/zh-CN/dashboard")); + + expect(response.status).toBeGreaterThanOrEqual(300); + expect(response.status).toBeLessThan(400); + const location = response.headers.get("location"); + expect(location).toContain("/login"); + expect(location).toContain("from="); + }); + + it("passes through when auth cookie exists without deleting it", async () => { + const localeResponse = new Response(null, { + status: 200, + headers: { "x-test": "locale-response" }, + }); + mockIntlMiddleware.mockReturnValue(localeResponse); + + const { default: proxyHandler } = await import("@/proxy"); + const response = proxyHandler( + makeRequest("/zh-CN/dashboard", { "auth-token": "sid_test-session-id" }) + ); + + // Should return the locale response, not a redirect + expect(response.headers.get("x-test")).toBe("locale-response"); + // Should NOT have a Set-Cookie header that deletes the auth cookie + const setCookie = response.headers.get("set-cookie"); + expect(setCookie).toBeNull(); + }); + + it("allows public paths without any cookie", async () => { + const localeResponse = new Response(null, { + status: 200, + headers: { "x-test": "public-ok" }, + }); + mockIntlMiddleware.mockReturnValue(localeResponse); + + const { default: proxyHandler } = await import("@/proxy"); + const response = proxyHandler(makeRequest("/zh-CN/login")); + + expect(response.headers.get("x-test")).toBe("public-ok"); + }); +}); diff --git a/tests/unit/repository/provider-batch-update-advanced-fields.test.ts b/tests/unit/repository/provider-batch-update-advanced-fields.test.ts new file mode 100644 index 000000000..21e3e1def --- /dev/null +++ b/tests/unit/repository/provider-batch-update-advanced-fields.test.ts @@ -0,0 +1,196 @@ +import { describe, expect, test, vi } from "vitest"; + +type BatchUpdateRow = { + id: number; + providerVendorId: number | null; + providerType: string; + url: string; +}; + +function createDbMock(updatedRows: BatchUpdateRow[]) { + const updateSetPayloads: Array> = []; + + const updateReturningMock = vi.fn(async () => updatedRows); + const updateWhereMock = vi.fn(() => ({ returning: updateReturningMock })); + const updateSetMock = vi.fn((payload: Record) => { + updateSetPayloads.push(payload); + return { where: updateWhereMock }; + }); + const updateMock = vi.fn(() => ({ set: updateSetMock })); + + const insertReturningMock = vi.fn(async () => []); + const insertOnConflictDoNothingMock = vi.fn(() => ({ returning: insertReturningMock })); + const insertValuesMock = vi.fn(() => ({ onConflictDoNothing: insertOnConflictDoNothingMock })); + const insertMock = vi.fn(() => ({ values: insertValuesMock })); + + return { + db: { + update: updateMock, + insert: insertMock, + }, + mocks: { + updateMock, + updateSetPayloads, + insertMock, + }, + }; +} + +async function arrange(updatedRows: BatchUpdateRow[] = []) { + vi.resetModules(); + + const dbMock = createDbMock(updatedRows); + + vi.doMock("@/drizzle/db", () => ({ db: dbMock.db })); + vi.doMock("@/lib/logger", () => ({ + logger: { + trace: vi.fn(), + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, + })); + + const { updateProvidersBatch } = await import("@/repository/provider"); + + return { + updateProvidersBatch, + ...dbMock.mocks, + }; +} + +describe("provider repository - updateProvidersBatch advanced fields", () => { + const updatedRows: BatchUpdateRow[] = [ + { + id: 11, + providerVendorId: 100, + providerType: "claude", + url: "https://api-one.example.com/v1/messages", + }, + { + id: 22, + providerVendorId: 100, + providerType: "claude", + url: "https://api-two.example.com/v1/messages", + }, + ]; + + test("updates modelRedirects for multiple providers", async () => { + const { updateProvidersBatch, updateSetPayloads, updateMock, insertMock } = + await arrange(updatedRows); + const modelRedirects = { + "claude-sonnet-4-5-20250929": "glm-4.6", + }; + + const result = await updateProvidersBatch([11, 22], { modelRedirects }); + + expect(result).toBe(2); + expect(updateMock).toHaveBeenCalledTimes(1); + expect(updateSetPayloads[0]).toEqual( + expect.objectContaining({ + updatedAt: expect.any(Date), + modelRedirects, + }) + ); + expect(insertMock).not.toHaveBeenCalled(); + }); + + test("updates allowedModels for multiple providers", async () => { + const { updateProvidersBatch, updateSetPayloads } = await arrange(updatedRows); + const allowedModels = ["claude-sonnet-4-5-20250929", "claude-opus-4-1-20250805"]; + + const result = await updateProvidersBatch([11, 22], { allowedModels }); + + expect(result).toBe(2); + expect(updateSetPayloads[0]).toEqual( + expect.objectContaining({ + updatedAt: expect.any(Date), + allowedModels, + }) + ); + }); + + test("updates anthropicThinkingBudgetPreference for multiple providers", async () => { + const { updateProvidersBatch, updateSetPayloads } = await arrange(updatedRows); + + const result = await updateProvidersBatch([11, 22], { + anthropicThinkingBudgetPreference: "4096", + }); + + expect(result).toBe(2); + expect(updateSetPayloads[0]).toEqual( + expect.objectContaining({ + updatedAt: expect.any(Date), + anthropicThinkingBudgetPreference: "4096", + }) + ); + }); + + test("updates anthropicAdaptiveThinking for multiple providers", async () => { + const { updateProvidersBatch, updateSetPayloads } = await arrange(updatedRows); + const anthropicAdaptiveThinking = { + effort: "high", + modelMatchMode: "specific", + models: ["claude-sonnet-4-5-20250929"], + }; + + const result = await updateProvidersBatch([11, 22], { + anthropicAdaptiveThinking, + }); + + expect(result).toBe(2); + expect(updateSetPayloads[0]).toEqual( + expect.objectContaining({ + updatedAt: expect.any(Date), + anthropicAdaptiveThinking, + }) + ); + }); + + test("does not include undefined advanced fields in set payload", async () => { + const { updateProvidersBatch, updateSetPayloads } = await arrange(updatedRows); + + const result = await updateProvidersBatch([11, 22], { + priority: 3, + modelRedirects: undefined, + allowedModels: undefined, + anthropicThinkingBudgetPreference: undefined, + anthropicAdaptiveThinking: undefined, + }); + + expect(result).toBe(2); + expect(updateSetPayloads[0]).toEqual( + expect.objectContaining({ + updatedAt: expect.any(Date), + priority: 3, + }) + ); + expect(updateSetPayloads[0]).not.toHaveProperty("modelRedirects"); + expect(updateSetPayloads[0]).not.toHaveProperty("allowedModels"); + expect(updateSetPayloads[0]).not.toHaveProperty("anthropicThinkingBudgetPreference"); + expect(updateSetPayloads[0]).not.toHaveProperty("anthropicAdaptiveThinking"); + }); + + test("writes null advanced values to clear fields", async () => { + const { updateProvidersBatch, updateSetPayloads } = await arrange(updatedRows); + + const result = await updateProvidersBatch([11, 22], { + modelRedirects: null, + allowedModels: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + }); + + expect(result).toBe(2); + expect(updateSetPayloads[0]).toEqual( + expect.objectContaining({ + updatedAt: expect.any(Date), + modelRedirects: null, + allowedModels: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + }) + ); + }); +}); diff --git a/tests/unit/repository/provider-restore.test.ts b/tests/unit/repository/provider-restore.test.ts new file mode 100644 index 000000000..d672b0dd4 --- /dev/null +++ b/tests/unit/repository/provider-restore.test.ts @@ -0,0 +1,300 @@ +import { describe, expect, test, vi } from "vitest"; + +type SelectRow = Record; + +function createRestoreDbHarness(options: { + selectQueue: SelectRow[][]; + updateReturningQueue?: SelectRow[][]; +}) { + const selectQueue = [...options.selectQueue]; + const updateReturningQueue = [...(options.updateReturningQueue ?? [])]; + + const selectLimitMock = vi.fn(async () => selectQueue.shift() ?? []); + const selectOrderByMock = vi.fn(() => ({ limit: selectLimitMock })); + const selectWhereMock = vi.fn(() => ({ limit: selectLimitMock, orderBy: selectOrderByMock })); + const selectFromMock = vi.fn(() => ({ where: selectWhereMock })); + const selectMock = vi.fn(() => ({ from: selectFromMock })); + + const updateReturningMock = vi.fn(async () => updateReturningQueue.shift() ?? []); + const updateWhereMock = vi.fn(() => ({ returning: updateReturningMock })); + const updateSetMock = vi.fn(() => ({ where: updateWhereMock })); + const updateMock = vi.fn(() => ({ set: updateSetMock })); + + const tx = { + select: selectMock, + update: updateMock, + }; + + const transactionMock = vi.fn(async (runInTx: (trx: typeof tx) => Promise) => { + return runInTx(tx); + }); + + return { + db: { + transaction: transactionMock, + select: selectMock, + update: updateMock, + }, + mocks: { + transactionMock, + selectLimitMock, + updateMock, + updateSetMock, + }, + }; +} + +async function setupProviderRepository(options: { + selectQueue: SelectRow[][]; + updateReturningQueue?: SelectRow[][]; +}) { + vi.resetModules(); + + const harness = createRestoreDbHarness(options); + + vi.doMock("@/drizzle/db", () => ({ + db: harness.db, + })); + + vi.doMock("@/repository/provider-endpoints", () => ({ + ensureProviderEndpointExistsForUrl: vi.fn(), + getOrCreateProviderVendorIdFromUrls: vi.fn(), + syncProviderEndpointOnProviderEdit: vi.fn(), + tryDeleteProviderVendorIfEmpty: vi.fn(), + })); + + const repository = await import("../../../src/repository/provider"); + + return { + ...repository, + harness, + }; +} + +describe("provider repository restore", () => { + test("restoreProvider restores recent soft-deleted provider and clears deletedAt", async () => { + const deletedAt = new Date(Date.now() - 15_000); + const { restoreProvider, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 1, + providerVendorId: null, + providerType: "claude", + url: "https://api.example.com/v1", + deletedAt, + }, + ], + ], + updateReturningQueue: [[{ id: 1 }]], + }); + + const restored = await restoreProvider(1); + + expect(restored).toBe(true); + expect(harness.mocks.transactionMock).toHaveBeenCalledTimes(1); + expect(harness.mocks.updateMock).toHaveBeenCalledTimes(1); + expect(harness.mocks.updateSetMock).toHaveBeenCalledWith( + expect.objectContaining({ + deletedAt: null, + updatedAt: expect.any(Date), + }) + ); + }); + + test("restoreProvider returns false when provider row is already restored concurrently", async () => { + const deletedAt = new Date(Date.now() - 5_000); + const { restoreProvider, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 31, + providerVendorId: null, + providerType: "claude", + url: "https://api.example.com/v1", + deletedAt, + }, + ], + ], + updateReturningQueue: [[]], + }); + + const restored = await restoreProvider(31); + + expect(restored).toBe(false); + expect(harness.mocks.updateMock).toHaveBeenCalledTimes(1); + expect(harness.mocks.selectLimitMock).toHaveBeenCalledTimes(1); + }); + + test("restoreProvider rejects provider deleted more than 60 seconds ago", async () => { + const deletedAt = new Date(Date.now() - 61_000); + const { restoreProvider, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 2, + providerVendorId: null, + providerType: "claude", + url: "https://api.example.com/v1", + deletedAt, + }, + ], + ], + updateReturningQueue: [[{ id: 2 }]], + }); + + const restored = await restoreProvider(2); + + expect(restored).toBe(false); + expect(harness.mocks.updateMock).not.toHaveBeenCalled(); + }); + + test("restoreProvidersBatch restores multiple providers in a single transaction", async () => { + const recent = new Date(Date.now() - 10_000); + const { restoreProvidersBatch, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 11, + providerVendorId: null, + providerType: "claude", + url: "https://api.example.com/v1", + deletedAt: recent, + }, + ], + [ + { + id: 12, + providerVendorId: null, + providerType: "claude", + url: "https://api.example.com/v1", + deletedAt: recent, + }, + ], + [], + ], + updateReturningQueue: [[{ id: 11 }], [{ id: 12 }]], + }); + + const restoredCount = await restoreProvidersBatch([11, 12, 11, 13]); + + expect(restoredCount).toBe(2); + expect(harness.mocks.transactionMock).toHaveBeenCalledTimes(1); + expect(harness.mocks.selectLimitMock).toHaveBeenCalledTimes(3); + expect(harness.mocks.updateMock).toHaveBeenCalledTimes(2); + }); + + test("restoreProvidersBatch should short-circuit for empty id list", async () => { + const { restoreProvidersBatch, harness } = await setupProviderRepository({ + selectQueue: [], + updateReturningQueue: [], + }); + + const restoredCount = await restoreProvidersBatch([]); + + expect(restoredCount).toBe(0); + expect(harness.mocks.transactionMock).not.toHaveBeenCalled(); + }); + + test("restoreProvider skips endpoint restoration when provider url is blank", async () => { + const deletedAt = new Date(Date.now() - 8_000); + const { restoreProvider, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 55, + providerVendorId: 5, + providerType: "claude", + url: " ", + deletedAt, + }, + ], + ], + updateReturningQueue: [[{ id: 55 }]], + }); + + const restored = await restoreProvider(55); + + expect(restored).toBe(true); + expect(harness.mocks.selectLimitMock).toHaveBeenCalledTimes(1); + expect(harness.mocks.updateMock).toHaveBeenCalledTimes(1); + }); + + test("restoreProvider skips endpoint restoration when active provider reference exists", async () => { + const deletedAt = new Date(Date.now() - 8_000); + const { restoreProvider, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 66, + providerVendorId: 8, + providerType: "claude", + url: "https://api.example.com/v1/messages", + deletedAt, + }, + ], + [{ id: 999 }], + ], + updateReturningQueue: [[{ id: 66 }]], + }); + + const restored = await restoreProvider(66); + + expect(restored).toBe(true); + expect(harness.mocks.selectLimitMock).toHaveBeenCalledTimes(2); + expect(harness.mocks.updateMock).toHaveBeenCalledTimes(1); + }); + + test("restoreProvider skips endpoint restoration when no deleted endpoint can be matched", async () => { + const deletedAt = new Date(Date.now() - 8_000); + const { restoreProvider, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 67, + providerVendorId: 8, + providerType: "claude", + url: "https://api.example.com/v1/messages", + deletedAt, + }, + ], + [], + [], + [], + ], + updateReturningQueue: [[{ id: 67 }]], + }); + + const restored = await restoreProvider(67); + + expect(restored).toBe(true); + expect(harness.mocks.selectLimitMock).toHaveBeenCalledTimes(4); + expect(harness.mocks.updateMock).toHaveBeenCalledTimes(1); + }); + + test("restoreProvider skips endpoint restoration when active endpoint already exists", async () => { + const deletedAt = new Date(Date.now() - 10_000); + const { restoreProvider, harness } = await setupProviderRepository({ + selectQueue: [ + [ + { + id: 77, + providerVendorId: 9, + providerType: "claude", + url: "https://api.example.com/v1/messages", + deletedAt, + }, + ], + [], + [{ id: 9001 }], + ], + updateReturningQueue: [[{ id: 77 }]], + }); + + const restored = await restoreProvider(77); + + expect(restored).toBe(true); + expect(harness.mocks.selectLimitMock).toHaveBeenCalledTimes(3); + expect(harness.mocks.updateMock).toHaveBeenCalledTimes(1); + }); +}); diff --git a/tests/unit/settings/providers/adaptive-thinking-editor.test.tsx b/tests/unit/settings/providers/adaptive-thinking-editor.test.tsx new file mode 100644 index 000000000..dba66edf0 --- /dev/null +++ b/tests/unit/settings/providers/adaptive-thinking-editor.test.tsx @@ -0,0 +1,336 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { describe, expect, it, vi } from "vitest"; +import { AdaptiveThinkingEditor } from "@/app/[locale]/settings/providers/_components/adaptive-thinking-editor"; +import type { AnthropicAdaptiveThinkingConfig } from "@/types/provider"; + +// Mock next-intl +vi.mock("next-intl", () => ({ + useTranslations: () => (key: string) => key, +})); + +// Mock UI components +vi.mock("@/components/ui/select", () => ({ + Select: ({ + children, + value, + onValueChange, + disabled, + }: { + children: React.ReactNode; + value: string; + onValueChange: (val: string) => void; + disabled?: boolean; + }) => ( +
+ +
+ ), + SelectTrigger: ({ children }: { children: React.ReactNode }) =>
{children}
, + SelectValue: () => null, + SelectContent: ({ children }: { children: React.ReactNode }) => <>{children}, + SelectItem: ({ value, children }: { value: string; children: React.ReactNode }) => ( + + ), +})); + +vi.mock("@/components/ui/switch", () => ({ + Switch: ({ + checked, + onCheckedChange, + disabled, + }: { + checked: boolean; + onCheckedChange: (checked: boolean) => void; + disabled?: boolean; + }) => ( + + ), +})); + +vi.mock("@/components/ui/tag-input", () => ({ + TagInput: ({ + value, + onChange, + disabled, + placeholder, + }: { + value: string[]; + onChange: (tags: string[]) => void; + disabled?: boolean; + placeholder?: string; + }) => ( + onChange(e.target.value.split(",").filter(Boolean))} + disabled={disabled} + placeholder={placeholder} + /> + ), +})); + +vi.mock("@/components/ui/tooltip", () => ({ + Tooltip: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipTrigger: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipContent: ({ children }: { children: React.ReactNode }) =>
{children}
, +})); + +vi.mock("./forms/provider-form/components/section-card", () => ({ + SmartInputWrapper: ({ label, children }: { label: string; children: React.ReactNode }) => ( +
+ + {children} +
+ ), + ToggleRow: ({ label, children }: { label: string; children: React.ReactNode }) => ( +
+ + {children} +
+ ), +})); + +// Mock lucide-react +vi.mock("lucide-react", () => ({ + Info: () =>
, +})); + +function render(node: React.ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +describe("AdaptiveThinkingEditor", () => { + const defaultConfig: AnthropicAdaptiveThinkingConfig = { + effort: "medium", + modelMatchMode: "all", + models: [], + }; + + const mockOnEnabledChange = vi.fn(); + const mockOnConfigChange = vi.fn(); + + it("renders correctly in disabled state (switch off)", () => { + const { container, unmount } = render( + + ); + + const switchBtn = container.querySelector('[data-testid="switch"]'); + expect(switchBtn).toBeTruthy(); + expect(switchBtn?.textContent).toBe("Off"); + expect(container.querySelector('[data-testid="select-trigger"]')).toBeNull(); + + unmount(); + }); + + it("calls onEnabledChange when switch is clicked", () => { + const { container, unmount } = render( + + ); + + const switchBtn = container.querySelector('[data-testid="switch"]') as HTMLButtonElement; + act(() => { + switchBtn.click(); + }); + + expect(mockOnEnabledChange).toHaveBeenCalledWith(true); + + unmount(); + }); + + it("renders configuration options when enabled", () => { + const { container, unmount } = render( + + ); + + const switchBtn = container.querySelector('[data-testid="switch"]'); + expect(switchBtn?.textContent).toBe("On"); + + // Should have 2 selects: effort and mode (since mode is 'all') + const selects = container.querySelectorAll('[data-testid="select-trigger"]'); + expect(selects.length).toBe(2); + + unmount(); + }); + + it("calls onConfigChange when effort is changed", () => { + const { container, unmount } = render( + + ); + + const selects = container.querySelectorAll("select"); + // First select is effort + const effortSelect = selects[0]; + + act(() => { + effortSelect.value = "high"; + effortSelect.dispatchEvent(new Event("change", { bubbles: true })); + }); + + expect(mockOnConfigChange).toHaveBeenCalledWith({ + ...defaultConfig, + effort: "high", + }); + + unmount(); + }); + + it("calls onConfigChange when model match mode is changed", () => { + const { container, unmount } = render( + + ); + + const selects = container.querySelectorAll("select"); + // Second select is model match mode + const modeSelect = selects[1]; + + act(() => { + modeSelect.value = "specific"; + modeSelect.dispatchEvent(new Event("change", { bubbles: true })); + }); + + expect(mockOnConfigChange).toHaveBeenCalledWith({ + ...defaultConfig, + modelMatchMode: "specific", + }); + + unmount(); + }); + + it("renders model input when mode is specific", () => { + const specificConfig: AnthropicAdaptiveThinkingConfig = { + ...defaultConfig, + modelMatchMode: "specific", + }; + + const { container, unmount } = render( + + ); + + expect(container.querySelector('[data-testid="tag-input"]')).toBeTruthy(); + + unmount(); + }); + + it("calls onConfigChange when models are changed", () => { + const specificConfig: AnthropicAdaptiveThinkingConfig = { + ...defaultConfig, + modelMatchMode: "specific", + }; + + const { container, unmount } = render( + + ); + + const input = container.querySelector('[data-testid="tag-input"]') as HTMLInputElement; + + act(() => { + // Simulate typing a tag + // For standard HTML inputs, simply setting value and dispatching event works + // The Object.getOwnPropertyDescriptor trick is needed for React controlled inputs + // but here we are using a mocked input which might just need the event + const nativeInputValueSetter = Object.getOwnPropertyDescriptor( + window.HTMLInputElement.prototype, + "value" + )?.set; + nativeInputValueSetter?.call(input, "claude-3-5-sonnet"); + input.dispatchEvent(new Event("change", { bubbles: true })); + }); + + expect(mockOnConfigChange).toHaveBeenCalledWith({ + ...specificConfig, + models: ["claude-3-5-sonnet"], + }); + + unmount(); + }); + + it("passes disabled prop to children", () => { + const { container, unmount } = render( + + ); + + const switchBtn = container.querySelector('[data-testid="switch"]') as HTMLButtonElement; + expect(switchBtn.disabled).toBe(true); + + const selects = container.querySelectorAll("select"); + selects.forEach((select) => { + expect(select.disabled).toBe(true); + }); + + unmount(); + }); +}); diff --git a/tests/unit/settings/providers/build-patch-draft.test.ts b/tests/unit/settings/providers/build-patch-draft.test.ts new file mode 100644 index 000000000..c1421e6c3 --- /dev/null +++ b/tests/unit/settings/providers/build-patch-draft.test.ts @@ -0,0 +1,647 @@ +import { describe, expect, it } from "vitest"; +import { buildPatchDraftFromFormState } from "@/app/[locale]/settings/providers/_components/batch-edit/build-patch-draft"; +import type { ProviderFormState } from "@/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-types"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function createBatchState(): ProviderFormState { + return { + basic: { name: "", url: "", key: "", websiteUrl: "" }, + routing: { + providerType: "claude", + groupTag: [], + preserveClientIp: false, + modelRedirects: {}, + allowedModels: [], + priority: 0, + groupPriorities: {}, + weight: 1, + costMultiplier: 1.0, + cacheTtlPreference: "inherit", + swapCacheTtlBilling: false, + context1mPreference: "inherit", + codexReasoningEffortPreference: "inherit", + codexReasoningSummaryPreference: "inherit", + codexTextVerbosityPreference: "inherit", + codexParallelToolCallsPreference: "inherit", + anthropicMaxTokensPreference: "inherit", + anthropicThinkingBudgetPreference: "inherit", + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: "inherit", + }, + rateLimit: { + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: null, + }, + circuitBreaker: { + failureThreshold: undefined, + openDurationMinutes: undefined, + halfOpenSuccessThreshold: undefined, + maxRetryAttempts: null, + }, + network: { + proxyUrl: "", + proxyFallbackToDirect: false, + firstByteTimeoutStreamingSeconds: undefined, + streamingIdleTimeoutSeconds: undefined, + requestTimeoutNonStreamingSeconds: undefined, + }, + mcp: { + mcpPassthroughType: "none", + mcpPassthroughUrl: "", + }, + batch: { isEnabled: "no_change" }, + ui: { + activeTab: "basic", + isPending: false, + showFailureThresholdConfirm: false, + }, + }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("buildPatchDraftFromFormState", () => { + it("returns empty draft when no fields are dirty", () => { + const state = createBatchState(); + const dirty = new Set(); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft).toEqual({}); + }); + + it("includes isEnabled=true when dirty and set to true", () => { + const state = createBatchState(); + state.batch.isEnabled = "true"; + const dirty = new Set(["batch.isEnabled"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.is_enabled).toEqual({ set: true }); + }); + + it("includes isEnabled=false when dirty and set to false", () => { + const state = createBatchState(); + state.batch.isEnabled = "false"; + const dirty = new Set(["batch.isEnabled"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.is_enabled).toEqual({ set: false }); + }); + + it("skips isEnabled when dirty but value is no_change", () => { + const state = createBatchState(); + state.batch.isEnabled = "no_change"; + const dirty = new Set(["batch.isEnabled"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.is_enabled).toBeUndefined(); + }); + + it("sets priority when dirty", () => { + const state = createBatchState(); + state.routing.priority = 10; + const dirty = new Set(["routing.priority"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.priority).toEqual({ set: 10 }); + }); + + it("sets weight when dirty", () => { + const state = createBatchState(); + state.routing.weight = 5; + const dirty = new Set(["routing.weight"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.weight).toEqual({ set: 5 }); + }); + + it("sets costMultiplier when dirty", () => { + const state = createBatchState(); + state.routing.costMultiplier = 2.5; + const dirty = new Set(["routing.costMultiplier"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.cost_multiplier).toEqual({ set: 2.5 }); + }); + + it("clears groupTag when dirty and empty array", () => { + const state = createBatchState(); + state.routing.groupTag = []; + const dirty = new Set(["routing.groupTag"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.group_tag).toEqual({ clear: true }); + }); + + it("sets groupTag with joined value when dirty and non-empty", () => { + const state = createBatchState(); + state.routing.groupTag = ["tagA", "tagB"]; + const dirty = new Set(["routing.groupTag"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.group_tag).toEqual({ set: "tagA, tagB" }); + }); + + it("clears modelRedirects when dirty and empty object", () => { + const state = createBatchState(); + const dirty = new Set(["routing.modelRedirects"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.model_redirects).toEqual({ clear: true }); + }); + + it("sets modelRedirects when dirty and has entries", () => { + const state = createBatchState(); + state.routing.modelRedirects = { "model-a": "model-b" }; + const dirty = new Set(["routing.modelRedirects"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.model_redirects).toEqual({ set: { "model-a": "model-b" } }); + }); + + it("clears allowedModels when dirty and empty array", () => { + const state = createBatchState(); + const dirty = new Set(["routing.allowedModels"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.allowed_models).toEqual({ clear: true }); + }); + + it("sets allowedModels when dirty and non-empty", () => { + const state = createBatchState(); + state.routing.allowedModels = ["claude-opus-4-6"]; + const dirty = new Set(["routing.allowedModels"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.allowed_models).toEqual({ set: ["claude-opus-4-6"] }); + }); + + // --- inherit/clear pattern fields --- + + it("clears cacheTtlPreference when dirty and inherit", () => { + const state = createBatchState(); + const dirty = new Set(["routing.cacheTtlPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.cache_ttl_preference).toEqual({ clear: true }); + }); + + it("sets cacheTtlPreference when dirty and not inherit", () => { + const state = createBatchState(); + state.routing.cacheTtlPreference = "5m"; + const dirty = new Set(["routing.cacheTtlPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.cache_ttl_preference).toEqual({ set: "5m" }); + }); + + it("sets preserveClientIp when dirty", () => { + const state = createBatchState(); + state.routing.preserveClientIp = true; + const dirty = new Set(["routing.preserveClientIp"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.preserve_client_ip).toEqual({ set: true }); + }); + + it("sets swapCacheTtlBilling when dirty", () => { + const state = createBatchState(); + state.routing.swapCacheTtlBilling = true; + const dirty = new Set(["routing.swapCacheTtlBilling"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.swap_cache_ttl_billing).toEqual({ set: true }); + }); + + it("clears context1mPreference when dirty and inherit", () => { + const state = createBatchState(); + const dirty = new Set(["routing.context1mPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.context_1m_preference).toEqual({ clear: true }); + }); + + it("sets context1mPreference when dirty and not inherit", () => { + const state = createBatchState(); + state.routing.context1mPreference = "force_enable"; + const dirty = new Set(["routing.context1mPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.context_1m_preference).toEqual({ set: "force_enable" }); + }); + + it("clears codexReasoningEffortPreference when dirty and inherit", () => { + const state = createBatchState(); + const dirty = new Set(["routing.codexReasoningEffortPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.codex_reasoning_effort_preference).toEqual({ clear: true }); + }); + + it("sets codexReasoningEffortPreference when dirty and not inherit", () => { + const state = createBatchState(); + state.routing.codexReasoningEffortPreference = "high"; + const dirty = new Set(["routing.codexReasoningEffortPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.codex_reasoning_effort_preference).toEqual({ set: "high" }); + }); + + it("clears anthropicThinkingBudgetPreference when dirty and inherit", () => { + const state = createBatchState(); + const dirty = new Set(["routing.anthropicThinkingBudgetPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.anthropic_thinking_budget_preference).toEqual({ clear: true }); + }); + + it("sets anthropicThinkingBudgetPreference when dirty and not inherit", () => { + const state = createBatchState(); + state.routing.anthropicThinkingBudgetPreference = "32000"; + const dirty = new Set(["routing.anthropicThinkingBudgetPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.anthropic_thinking_budget_preference).toEqual({ set: "32000" }); + }); + + it("clears anthropicAdaptiveThinking when dirty and null", () => { + const state = createBatchState(); + const dirty = new Set(["routing.anthropicAdaptiveThinking"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.anthropic_adaptive_thinking).toEqual({ clear: true }); + }); + + it("sets anthropicAdaptiveThinking when dirty and configured", () => { + const state = createBatchState(); + state.routing.anthropicAdaptiveThinking = { + effort: "high", + modelMatchMode: "specific", + models: ["claude-opus-4-6"], + }; + const dirty = new Set(["routing.anthropicAdaptiveThinking"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.anthropic_adaptive_thinking).toEqual({ + set: { + effort: "high", + modelMatchMode: "specific", + models: ["claude-opus-4-6"], + }, + }); + }); + + it("clears geminiGoogleSearchPreference when dirty and inherit", () => { + const state = createBatchState(); + const dirty = new Set(["routing.geminiGoogleSearchPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.gemini_google_search_preference).toEqual({ clear: true }); + }); + + it("sets geminiGoogleSearchPreference when dirty and not inherit", () => { + const state = createBatchState(); + state.routing.geminiGoogleSearchPreference = "enabled"; + const dirty = new Set(["routing.geminiGoogleSearchPreference"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.gemini_google_search_preference).toEqual({ set: "enabled" }); + }); + + // --- Rate limit fields --- + + it("clears limit5hUsd when dirty and null", () => { + const state = createBatchState(); + const dirty = new Set(["rateLimit.limit5hUsd"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.limit_5h_usd).toEqual({ clear: true }); + }); + + it("sets limit5hUsd when dirty and has value", () => { + const state = createBatchState(); + state.rateLimit.limit5hUsd = 50; + const dirty = new Set(["rateLimit.limit5hUsd"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.limit_5h_usd).toEqual({ set: 50 }); + }); + + it("sets dailyResetMode when dirty", () => { + const state = createBatchState(); + state.rateLimit.dailyResetMode = "rolling"; + const dirty = new Set(["rateLimit.dailyResetMode"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.daily_reset_mode).toEqual({ set: "rolling" }); + }); + + it("sets dailyResetTime when dirty", () => { + const state = createBatchState(); + state.rateLimit.dailyResetTime = "12:00"; + const dirty = new Set(["rateLimit.dailyResetTime"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.daily_reset_time).toEqual({ set: "12:00" }); + }); + + it("clears maxRetryAttempts when dirty and null", () => { + const state = createBatchState(); + const dirty = new Set(["circuitBreaker.maxRetryAttempts"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.max_retry_attempts).toEqual({ clear: true }); + }); + + it("sets maxRetryAttempts when dirty and has value", () => { + const state = createBatchState(); + state.circuitBreaker.maxRetryAttempts = 3; + const dirty = new Set(["circuitBreaker.maxRetryAttempts"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.max_retry_attempts).toEqual({ set: 3 }); + }); + + // --- Unit conversion: circuit breaker minutes -> ms --- + + it("converts openDurationMinutes to ms", () => { + const state = createBatchState(); + state.circuitBreaker.openDurationMinutes = 5; + const dirty = new Set(["circuitBreaker.openDurationMinutes"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.circuit_breaker_open_duration).toEqual({ set: 300000 }); + }); + + it("sets openDuration to 0 when dirty and undefined", () => { + const state = createBatchState(); + const dirty = new Set(["circuitBreaker.openDurationMinutes"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.circuit_breaker_open_duration).toEqual({ set: 0 }); + }); + + it("sets failureThreshold to 0 when dirty and undefined", () => { + const state = createBatchState(); + const dirty = new Set(["circuitBreaker.failureThreshold"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.circuit_breaker_failure_threshold).toEqual({ set: 0 }); + }); + + it("sets failureThreshold when dirty and has value", () => { + const state = createBatchState(); + state.circuitBreaker.failureThreshold = 10; + const dirty = new Set(["circuitBreaker.failureThreshold"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.circuit_breaker_failure_threshold).toEqual({ set: 10 }); + }); + + // --- Unit conversion: network seconds -> ms --- + + it("converts firstByteTimeoutStreamingSeconds to ms", () => { + const state = createBatchState(); + state.network.firstByteTimeoutStreamingSeconds = 30; + const dirty = new Set(["network.firstByteTimeoutStreamingSeconds"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.first_byte_timeout_streaming_ms).toEqual({ set: 30000 }); + }); + + it("skips firstByteTimeoutStreamingMs when dirty and undefined", () => { + const state = createBatchState(); + const dirty = new Set(["network.firstByteTimeoutStreamingSeconds"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.first_byte_timeout_streaming_ms).toBeUndefined(); + }); + + it("converts streamingIdleTimeoutSeconds to ms", () => { + const state = createBatchState(); + state.network.streamingIdleTimeoutSeconds = 120; + const dirty = new Set(["network.streamingIdleTimeoutSeconds"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.streaming_idle_timeout_ms).toEqual({ set: 120000 }); + }); + + it("converts requestTimeoutNonStreamingSeconds to ms", () => { + const state = createBatchState(); + state.network.requestTimeoutNonStreamingSeconds = 60; + const dirty = new Set(["network.requestTimeoutNonStreamingSeconds"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.request_timeout_non_streaming_ms).toEqual({ set: 60000 }); + }); + + // --- Network fields --- + + it("clears proxyUrl when dirty and empty string", () => { + const state = createBatchState(); + const dirty = new Set(["network.proxyUrl"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.proxy_url).toEqual({ clear: true }); + }); + + it("sets proxyUrl when dirty and has value", () => { + const state = createBatchState(); + state.network.proxyUrl = "socks5://proxy.example.com:1080"; + const dirty = new Set(["network.proxyUrl"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.proxy_url).toEqual({ set: "socks5://proxy.example.com:1080" }); + }); + + it("sets proxyFallbackToDirect when dirty", () => { + const state = createBatchState(); + state.network.proxyFallbackToDirect = true; + const dirty = new Set(["network.proxyFallbackToDirect"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.proxy_fallback_to_direct).toEqual({ set: true }); + }); + + // --- MCP fields --- + + it("sets mcpPassthroughType when dirty", () => { + const state = createBatchState(); + state.mcp.mcpPassthroughType = "minimax"; + const dirty = new Set(["mcp.mcpPassthroughType"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.mcp_passthrough_type).toEqual({ set: "minimax" }); + }); + + it("sets mcpPassthroughType to none when dirty", () => { + const state = createBatchState(); + const dirty = new Set(["mcp.mcpPassthroughType"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.mcp_passthrough_type).toEqual({ set: "none" }); + }); + + it("clears mcpPassthroughUrl when dirty and empty", () => { + const state = createBatchState(); + const dirty = new Set(["mcp.mcpPassthroughUrl"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.mcp_passthrough_url).toEqual({ clear: true }); + }); + + it("sets mcpPassthroughUrl when dirty and has value", () => { + const state = createBatchState(); + state.mcp.mcpPassthroughUrl = "https://mcp.example.com"; + const dirty = new Set(["mcp.mcpPassthroughUrl"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.mcp_passthrough_url).toEqual({ set: "https://mcp.example.com" }); + }); + + // --- Multi-field scenario --- + + it("only includes dirty fields in draft, ignoring non-dirty", () => { + const state = createBatchState(); + state.routing.priority = 10; + state.routing.weight = 5; + state.routing.costMultiplier = 2.0; + + // Only mark priority as dirty + const dirty = new Set(["routing.priority"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.priority).toEqual({ set: 10 }); + expect(draft.weight).toBeUndefined(); + expect(draft.cost_multiplier).toBeUndefined(); + }); + + it("handles multiple dirty fields correctly", () => { + const state = createBatchState(); + state.batch.isEnabled = "true"; + state.routing.priority = 5; + state.routing.weight = 3; + state.rateLimit.limit5hUsd = 100; + state.network.proxyUrl = "http://proxy:8080"; + + const dirty = new Set([ + "batch.isEnabled", + "routing.priority", + "routing.weight", + "rateLimit.limit5hUsd", + "network.proxyUrl", + ]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.is_enabled).toEqual({ set: true }); + expect(draft.priority).toEqual({ set: 5 }); + expect(draft.weight).toEqual({ set: 3 }); + expect(draft.limit_5h_usd).toEqual({ set: 100 }); + expect(draft.proxy_url).toEqual({ set: "http://proxy:8080" }); + // Non-dirty fields should be absent + expect(draft.cost_multiplier).toBeUndefined(); + expect(draft.group_tag).toBeUndefined(); + }); + + // --- groupPriorities --- + + it("clears groupPriorities when dirty and empty object", () => { + const state = createBatchState(); + const dirty = new Set(["routing.groupPriorities"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.group_priorities).toEqual({ clear: true }); + }); + + it("sets groupPriorities when dirty and has entries", () => { + const state = createBatchState(); + state.routing.groupPriorities = { groupA: 1, groupB: 2 }; + const dirty = new Set(["routing.groupPriorities"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.group_priorities).toEqual({ set: { groupA: 1, groupB: 2 } }); + }); + + // --- limitConcurrentSessions null -> 0 edge case --- + + it("sets limitConcurrentSessions to 0 when dirty and null", () => { + const state = createBatchState(); + const dirty = new Set(["rateLimit.limitConcurrentSessions"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.limit_concurrent_sessions).toEqual({ set: 0 }); + }); + + it("sets limitConcurrentSessions when dirty and has value", () => { + const state = createBatchState(); + state.rateLimit.limitConcurrentSessions = 20; + const dirty = new Set(["rateLimit.limitConcurrentSessions"]); + + const draft = buildPatchDraftFromFormState(state, dirty); + + expect(draft.limit_concurrent_sessions).toEqual({ set: 20 }); + }); +}); diff --git a/tests/unit/settings/providers/form-tab-nav.test.tsx b/tests/unit/settings/providers/form-tab-nav.test.tsx new file mode 100644 index 000000000..8dfeec4af --- /dev/null +++ b/tests/unit/settings/providers/form-tab-nav.test.tsx @@ -0,0 +1,213 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { describe, expect, it, vi } from "vitest"; + +// Mock next-intl +vi.mock("next-intl", () => ({ + useTranslations: () => (key: string) => key, +})); + +// Mock framer-motion -- render motion.div as a plain div +vi.mock("framer-motion", () => ({ + motion: { + div: ({ children, layoutId, ...rest }: any) => ( +
+ {children} +
+ ), + }, +})); + +// Mock lucide-react icons used by FormTabNav +vi.mock("lucide-react", () => { + const stub = ({ className }: any) => ; + return { + FileText: stub, + Route: stub, + Gauge: stub, + Network: stub, + FlaskConical: stub, + }; +}); + +import { FormTabNav } from "@/app/[locale]/settings/providers/_components/forms/provider-form/components/form-tab-nav"; + +// --------------------------------------------------------------------------- +// Render helper (matches project convention) +// --------------------------------------------------------------------------- + +function render(node: React.ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("FormTabNav", () => { + const defaultProps = { + activeTab: "basic" as const, + onTabChange: vi.fn(), + }; + + // -- Default (vertical) layout ------------------------------------------- + + describe("default vertical layout", () => { + it("renders all 5 tabs across 3 responsive breakpoints (15 total)", () => { + const { container, unmount } = render(); + + // Desktop (5) + Tablet (5) + Mobile (5) = 15 + const buttons = container.querySelectorAll("button"); + expect(buttons.length).toBe(15); + + unmount(); + }); + + it("renders vertical sidebar nav with hidden lg:flex classes", () => { + const { container, unmount } = render(); + + const nav = container.querySelector("nav"); + expect(nav).toBeTruthy(); + expect(nav!.className).toContain("lg:flex"); + expect(nav!.className).toContain("flex-col"); + + unmount(); + }); + }); + + // -- Horizontal layout --------------------------------------------------- + + describe('layout="horizontal"', () => { + it("renders a horizontal nav bar", () => { + const { container, unmount } = render(); + + const nav = container.querySelector("nav"); + expect(nav).toBeTruthy(); + // Horizontal mode uses sticky top-0 nav with border-b + expect(nav!.className).toContain("sticky"); + expect(nav!.className).toContain("border-b"); + + unmount(); + }); + + it("has overflow-x-auto for horizontal scrolling", () => { + const { container, unmount } = render(); + + const scrollContainer = container.querySelector("nav > div"); + expect(scrollContainer).toBeTruthy(); + expect(scrollContainer!.className).toContain("overflow-x-auto"); + + unmount(); + }); + + it("highlights the active tab with text-primary", () => { + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + // "routing" is the second tab (index 1) + const routingBtn = buttons[1]; + expect(routingBtn.className).toContain("text-primary"); + + // Other tabs should have text-muted-foreground + const basicBtn = buttons[0]; + expect(basicBtn.className).toContain("text-muted-foreground"); + + unmount(); + }); + + it("renders motion indicator for active tab with horizontal layoutId", () => { + const { container, unmount } = render( + + ); + + const indicator = container.querySelector('[data-layout-id="activeTabIndicatorHorizontal"]'); + expect(indicator).toBeTruthy(); + + unmount(); + }); + + it("calls onTabChange when a tab is clicked", () => { + const onTabChange = vi.fn(); + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + // Click the "network" tab (index 3) + act(() => { + buttons[3].click(); + }); + + expect(onTabChange).toHaveBeenCalledWith("network"); + + unmount(); + }); + + it("disables all tabs when disabled prop is true", () => { + const onTabChange = vi.fn(); + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + for (const btn of buttons) { + expect(btn.disabled).toBe(true); + expect(btn.className).toContain("opacity-50"); + expect(btn.className).toContain("cursor-not-allowed"); + } + + // Click should not fire because button is disabled + act(() => { + buttons[2].click(); + }); + expect(onTabChange).not.toHaveBeenCalled(); + + unmount(); + }); + + it("shows status dot for tabs with warning or configured status", () => { + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + // routing (index 1) should have a yellow dot + const routingDot = buttons[1].querySelector(".bg-yellow-500"); + expect(routingDot).toBeTruthy(); + + // limits (index 2) should have a primary dot + const limitsDot = buttons[2].querySelector(".bg-primary"); + expect(limitsDot).toBeTruthy(); + + // basic (index 0) should have no status dot + const basicDot = buttons[0].querySelector(".rounded-full"); + expect(basicDot).toBeNull(); + + unmount(); + }); + }); +}); diff --git a/tests/unit/settings/providers/provider-batch-dialog-step1.test.tsx b/tests/unit/settings/providers/provider-batch-dialog-step1.test.tsx new file mode 100644 index 000000000..62dd1483d --- /dev/null +++ b/tests/unit/settings/providers/provider-batch-dialog-step1.test.tsx @@ -0,0 +1,482 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { ProviderBatchDialog } from "@/app/[locale]/settings/providers/_components/batch-edit/provider-batch-dialog"; +import type { ProviderDisplay } from "@/types/provider"; + +// --------------------------------------------------------------------------- +// Mutable mock state for useProviderForm +// --------------------------------------------------------------------------- + +let mockDirtyFields = new Set(); +const mockDispatch = vi.fn(); +let mockActiveTab = "basic"; +const mockState = { + ui: { activeTab: mockActiveTab, isPending: false, showFailureThresholdConfirm: false }, + basic: { name: "", url: "", key: "", websiteUrl: "" }, + routing: { + providerType: "claude" as const, + groupTag: [], + preserveClientIp: false, + modelRedirects: {}, + allowedModels: [], + priority: 0, + groupPriorities: {}, + weight: 1, + costMultiplier: 1, + cacheTtlPreference: "inherit" as const, + swapCacheTtlBilling: false, + context1mPreference: "inherit" as const, + codexReasoningEffortPreference: "inherit", + codexReasoningSummaryPreference: "inherit", + codexTextVerbosityPreference: "inherit", + codexParallelToolCallsPreference: "inherit", + anthropicMaxTokensPreference: "inherit", + anthropicThinkingBudgetPreference: "inherit", + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: "inherit", + }, + rateLimit: { + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed" as const, + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: null, + }, + circuitBreaker: { + failureThreshold: undefined, + openDurationMinutes: undefined, + halfOpenSuccessThreshold: undefined, + maxRetryAttempts: null, + }, + network: { + proxyUrl: "", + proxyFallbackToDirect: false, + firstByteTimeoutStreamingSeconds: undefined, + streamingIdleTimeoutSeconds: undefined, + requestTimeoutNonStreamingSeconds: undefined, + }, + mcp: { mcpPassthroughType: "none" as const, mcpPassthroughUrl: "" }, + batch: { isEnabled: "no_change" as const }, +}; + +// --------------------------------------------------------------------------- +// Mocks +// --------------------------------------------------------------------------- + +vi.mock("next-intl", () => ({ + useTranslations: () => { + const t = (key: string, params?: Record) => { + if (params) { + let result = key; + for (const [k, v] of Object.entries(params)) { + result = result.replace(`{${k}}`, String(v)); + } + return result; + } + return key; + }; + return t; + }, +})); + +vi.mock("@tanstack/react-query", () => ({ + useQueryClient: () => ({ + invalidateQueries: vi.fn().mockResolvedValue(undefined), + }), +})); + +vi.mock("sonner", () => ({ + toast: { + success: vi.fn(), + error: vi.fn(), + }, +})); + +vi.mock("@/actions/providers", () => ({ + previewProviderBatchPatch: vi.fn().mockResolvedValue({ + ok: true, + data: { + previewToken: "tok-1", + previewRevision: "rev-1", + rows: [], + summary: { providerCount: 0, fieldCount: 0, skipCount: 0 }, + }, + }), + applyProviderBatchPatch: vi.fn().mockResolvedValue({ ok: true, data: { updatedCount: 2 } }), + undoProviderPatch: vi.fn().mockResolvedValue({ ok: true, data: { revertedCount: 2 } }), + batchDeleteProviders: vi.fn().mockResolvedValue({ ok: true, data: { deletedCount: 2 } }), + batchResetProviderCircuits: vi.fn().mockResolvedValue({ ok: true, data: { resetCount: 2 } }), +})); + +// Mock ProviderFormProvider + useProviderForm +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context", + () => ({ + ProviderFormProvider: ({ children }: { children: React.ReactNode }) => <>{children}, + useProviderForm: () => ({ + state: mockState, + dispatch: mockDispatch, + dirtyFields: mockDirtyFields, + mode: "batch", + }), + }) +); + +// Mock all form section components as stubs +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/basic-info-section", + () => ({ + BasicInfoSection: () =>
BasicInfoSection
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section", + () => ({ + RoutingSection: () =>
RoutingSection
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/limits-section", + () => ({ + LimitsSection: () =>
LimitsSection
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/network-section", + () => ({ + NetworkSection: () =>
NetworkSection
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/testing-section", + () => ({ + TestingSection: () =>
TestingSection
, + }) +); + +// Mock FormTabNav +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/components/form-tab-nav", + () => ({ + FormTabNav: ({ activeTab }: { activeTab: string }) => ( +
+ FormTabNav +
+ ), + }) +); + +// Mock ProviderBatchPreviewStep +vi.mock( + "@/app/[locale]/settings/providers/_components/batch-edit/provider-batch-preview-step", + () => ({ + ProviderBatchPreviewStep: () =>
PreviewStep
, + }) +); + +// Mock buildPatchDraftFromFormState +vi.mock("@/app/[locale]/settings/providers/_components/batch-edit/build-patch-draft", () => ({ + buildPatchDraftFromFormState: vi.fn().mockReturnValue({}), +})); + +// UI component mocks +vi.mock("@/components/ui/dialog", () => ({ + Dialog: ({ open, children }: { open: boolean; children: React.ReactNode }) => + open ?
{children}
: null, + DialogContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogDescription: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogFooter: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogHeader: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogTitle: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})); + +vi.mock("@/components/ui/alert-dialog", () => ({ + AlertDialog: ({ open, children }: { open: boolean; children: React.ReactNode }) => + open ?
{children}
: null, + AlertDialogAction: ({ children, ...props }: any) => , + AlertDialogCancel: ({ children, ...props }: any) => , + AlertDialogContent: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogDescription: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogFooter: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogHeader: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogTitle: ({ children }: { children: React.ReactNode }) =>
{children}
, +})); + +vi.mock("@/components/ui/button", () => ({ + Button: ({ children, ...props }: any) => , +})); + +vi.mock("lucide-react", () => ({ + Loader2: () =>
, +})); + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function createMockProvider(id: number, name: string, maskedKey: string): ProviderDisplay { + return { + id, + name, + url: "https://api.example.com", + maskedKey, + isEnabled: true, + weight: 1, + priority: 0, + groupPriorities: null, + costMultiplier: 1, + groupTag: null, + providerType: "claude", + providerVendorId: null, + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 10, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 30000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 120000, + requestTimeoutNonStreamingMs: 120000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: "2024-01-01T00:00:00Z", + updatedAt: "2024-01-01T00:00:00Z", + }; +} + +function render(node: React.ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + let root: ReturnType; + act(() => { + root = createRoot(container); + root.render(node); + }); + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +// --------------------------------------------------------------------------- +// Fixtures +// --------------------------------------------------------------------------- + +const twoProviders = [ + createMockProvider(1, "Provider1", "aaaa****1111"), + createMockProvider(2, "Provider2", "bbbb****2222"), +]; + +const eightProviders = Array.from({ length: 8 }, (_, i) => + createMockProvider(i + 1, `Provider${i + 1}`, `key${i + 1}****tail${i + 1}`) +); + +function defaultProps(overrides: Record = {}) { + return { + open: true, + mode: "edit" as const, + onOpenChange: vi.fn(), + selectedProviderIds: new Set([1, 2]), + providers: twoProviders, + onSuccess: vi.fn(), + ...overrides, + }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("ProviderBatchDialog - Edit Mode Structure", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockDirtyFields = new Set(); + mockActiveTab = "basic"; + mockState.ui.activeTab = "basic"; + }); + + it("renders edit mode with FormTabNav and basic section", () => { + const { container, unmount } = render(); + + expect(container.querySelector('[data-testid="dialog"]')).toBeTruthy(); + expect(container.querySelector('[data-testid="form-tab-nav"]')).toBeTruthy(); + expect(container.querySelector('[data-testid="basic-info-section"]')).toBeTruthy(); + + unmount(); + }); + + it("renders dialog title and description in edit step", () => { + const { container, unmount } = render(); + + const titleEl = container.querySelector('[data-testid="dialog-title"]'); + expect(titleEl?.textContent).toContain("dialog.editTitle"); + + const descEl = container.querySelector('[data-testid="dialog-description"]'); + expect(descEl?.textContent).toContain("dialog.editDesc"); + + unmount(); + }); + + it("next button is disabled when no dirty fields", () => { + const { container, unmount } = render(); + + const footer = container.querySelector('[data-testid="dialog-footer"]'); + const buttons = footer?.querySelectorAll("button") ?? []; + // Second button in footer is "Next" (first is "Cancel") + const nextButton = buttons[1] as HTMLButtonElement; + + expect(nextButton).toBeTruthy(); + expect(nextButton.disabled).toBe(true); + + unmount(); + }); + + it("next button is enabled when dirty fields exist", () => { + mockDirtyFields = new Set(["routing.priority"]); + + const { container, unmount } = render(); + + const footer = container.querySelector('[data-testid="dialog-footer"]'); + const buttons = footer?.querySelectorAll("button") ?? []; + const nextButton = buttons[1] as HTMLButtonElement; + + expect(nextButton).toBeTruthy(); + expect(nextButton.disabled).toBe(false); + + unmount(); + }); + + it("cancel button calls onOpenChange(false)", () => { + const onOpenChange = vi.fn(); + const { container, unmount } = render( + + ); + + const footer = container.querySelector('[data-testid="dialog-footer"]'); + const buttons = footer?.querySelectorAll("button") ?? []; + const cancelButton = buttons[0] as HTMLButtonElement; + + act(() => { + cancelButton.click(); + }); + + expect(onOpenChange).toHaveBeenCalledWith(false); + + unmount(); + }); + + it("next button calls preview when dirty fields exist", async () => { + mockDirtyFields = new Set(["routing.priority"]); + const { previewProviderBatchPatch } = await import("@/actions/providers"); + + const { container, unmount } = render(); + + const footer = container.querySelector('[data-testid="dialog-footer"]'); + const nextButton = (footer?.querySelectorAll("button") ?? [])[1] as HTMLButtonElement; + + await act(async () => { + nextButton.click(); + }); + await act(async () => { + await new Promise((r) => setTimeout(r, 10)); + }); + + expect(previewProviderBatchPatch).toHaveBeenCalledTimes(1); + + unmount(); + }); +}); + +describe("ProviderBatchDialog - Delete Mode", () => { + it("renders AlertDialog for delete mode", () => { + const { container, unmount } = render( + + ); + + expect(container.querySelector('[data-testid="alert-dialog"]')).toBeTruthy(); + expect(container.querySelector('[data-testid="dialog"]')).toBeFalsy(); + + const text = container.textContent ?? ""; + expect(text).toContain("dialog.deleteTitle"); + + unmount(); + }); +}); + +describe("ProviderBatchDialog - Reset Circuit Mode", () => { + it("renders AlertDialog for resetCircuit mode", () => { + const { container, unmount } = render( + + ); + + expect(container.querySelector('[data-testid="alert-dialog"]')).toBeTruthy(); + expect(container.querySelector('[data-testid="dialog"]')).toBeFalsy(); + + const text = container.textContent ?? ""; + expect(text).toContain("dialog.resetCircuitTitle"); + + unmount(); + }); +}); + +describe("ProviderBatchDialog - Closed State", () => { + it("renders nothing when open is false", () => { + const { container, unmount } = render( + + ); + + expect(container.querySelector('[data-testid="dialog"]')).toBeFalsy(); + expect(container.querySelector('[data-testid="alert-dialog"]')).toBeFalsy(); + + unmount(); + }); +}); diff --git a/tests/unit/settings/providers/provider-batch-preview-step.test.tsx b/tests/unit/settings/providers/provider-batch-preview-step.test.tsx new file mode 100644 index 000000000..5bdd09c5c --- /dev/null +++ b/tests/unit/settings/providers/provider-batch-preview-step.test.tsx @@ -0,0 +1,296 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { ProviderBatchPreviewRow } from "@/actions/providers"; +import { ProviderBatchPreviewStep } from "@/app/[locale]/settings/providers/_components/batch-edit/provider-batch-preview-step"; + +// --------------------------------------------------------------------------- +// Mocks +// --------------------------------------------------------------------------- + +vi.mock("next-intl", () => ({ + useTranslations: () => { + const t = (key: string, params?: Record) => { + if (params) { + let result = key; + for (const [k, v] of Object.entries(params)) { + result = result.replace(`{${k}}`, String(v)); + } + return result; + } + return key; + }; + return t; + }, +})); + +vi.mock("@/components/ui/checkbox", () => ({ + Checkbox: ({ checked, onCheckedChange, ...props }: any) => ( + onCheckedChange?.(!checked)} + {...props} + /> + ), +})); + +vi.mock("lucide-react", () => ({ + Loader2: () =>
, +})); + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function render(ui: React.ReactElement) { + const container = document.createElement("div"); + document.body.appendChild(container); + let root: ReturnType; + act(() => { + root = createRoot(container); + root.render(ui); + }); + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +function makeRow(overrides: Partial = {}): ProviderBatchPreviewRow { + return { + providerId: 1, + providerName: "TestProvider", + field: "priority", + status: "changed", + before: 0, + after: 10, + ...overrides, + }; +} + +const defaultSummary = { providerCount: 2, fieldCount: 3, skipCount: 1 }; + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("ProviderBatchPreviewStep", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("renders changed rows with before/after values", () => { + const rows: ProviderBatchPreviewRow[] = [ + makeRow({ providerId: 1, providerName: "Alpha", field: "priority", before: 0, after: 5 }), + makeRow({ providerId: 1, providerName: "Alpha", field: "weight", before: 1, after: 10 }), + ]; + + const { container, unmount } = render( + {}} + /> + ); + + const changedRow1 = container.querySelector('[data-testid="preview-row-1-priority"]'); + expect(changedRow1).toBeTruthy(); + expect(changedRow1?.getAttribute("data-status")).toBe("changed"); + // Mock t() returns key with params substituted where {param} appears in key + // "preview.fieldChanged" does not contain {field} etc, so text is key with params inserted + expect(changedRow1?.textContent).toContain("preview.fieldChanged"); + + const changedRow2 = container.querySelector('[data-testid="preview-row-1-weight"]'); + expect(changedRow2).toBeTruthy(); + expect(changedRow2?.getAttribute("data-status")).toBe("changed"); + + unmount(); + }); + + it("renders skipped rows with skip reason", () => { + const rows: ProviderBatchPreviewRow[] = [ + makeRow({ + providerId: 2, + providerName: "Beta", + field: "anthropic_thinking_budget_preference", + status: "skipped", + before: null, + after: null, + skipReason: "not_applicable", + }), + ]; + + const { container, unmount } = render( + {}} + /> + ); + + const skippedRow = container.querySelector( + '[data-testid="preview-row-2-anthropic_thinking_budget_preference"]' + ); + expect(skippedRow).toBeTruthy(); + expect(skippedRow?.getAttribute("data-status")).toBe("skipped"); + expect(skippedRow?.textContent).toContain("preview.fieldSkipped"); + + unmount(); + }); + + it("groups rows by provider", () => { + const rows: ProviderBatchPreviewRow[] = [ + makeRow({ providerId: 1, providerName: "Alpha", field: "priority" }), + makeRow({ providerId: 2, providerName: "Beta", field: "weight" }), + makeRow({ providerId: 1, providerName: "Alpha", field: "is_enabled" }), + ]; + + const { container, unmount } = render( + {}} + /> + ); + + const provider1 = container.querySelector('[data-testid="preview-provider-1"]'); + const provider2 = container.querySelector('[data-testid="preview-provider-2"]'); + expect(provider1).toBeTruthy(); + expect(provider2).toBeTruthy(); + + // Provider 1 should have 2 rows + const p1Rows = provider1?.querySelectorAll("[data-status]"); + expect(p1Rows?.length).toBe(2); + + // Provider 2 should have 1 row + const p2Rows = provider2?.querySelectorAll("[data-status]"); + expect(p2Rows?.length).toBe(1); + + unmount(); + }); + + it("shows summary counts", () => { + const rows: ProviderBatchPreviewRow[] = [ + makeRow({ providerId: 1, providerName: "Alpha", field: "priority" }), + ]; + + const { container, unmount } = render( + {}} + /> + ); + + const summary = container.querySelector('[data-testid="preview-summary"]'); + expect(summary).toBeTruthy(); + // The mock t() substitutes {providerCount} -> 5, {fieldCount} -> 8, {skipCount} -> 2 + // into the key "preview.summary" which becomes "preview.summary" with params replaced + const text = summary?.textContent ?? ""; + expect(text).toContain("preview.summary"); + + unmount(); + }); + + it("exclusion checkbox toggles provider", () => { + const onToggle = vi.fn(); + const rows: ProviderBatchPreviewRow[] = [ + makeRow({ providerId: 3, providerName: "Gamma", field: "priority" }), + ]; + + const { container, unmount } = render( + + ); + + const checkbox = container.querySelector( + '[data-testid="exclude-checkbox-3"]' + ) as HTMLInputElement; + expect(checkbox).toBeTruthy(); + expect(checkbox.checked).toBe(true); // not excluded = checked + + act(() => { + checkbox.click(); + }); + + expect(onToggle).toHaveBeenCalledWith(3); + + unmount(); + }); + + it("loading state shows spinner", () => { + const { container, unmount } = render( + {}} + isLoading={true} + /> + ); + + const loading = container.querySelector('[data-testid="preview-loading"]'); + expect(loading).toBeTruthy(); + + // Should not show the empty state + const empty = container.querySelector('[data-testid="preview-empty"]'); + expect(empty).toBeNull(); + + unmount(); + }); + + it("shows empty state when no rows and not loading", () => { + const { container, unmount } = render( + {}} + /> + ); + + const empty = container.querySelector('[data-testid="preview-empty"]'); + expect(empty).toBeTruthy(); + + unmount(); + }); + + it("excluded provider checkbox shows unchecked", () => { + const rows: ProviderBatchPreviewRow[] = [ + makeRow({ providerId: 7, providerName: "Excluded", field: "weight" }), + ]; + + const { container, unmount } = render( + {}} + /> + ); + + const checkbox = container.querySelector( + '[data-testid="exclude-checkbox-7"]' + ) as HTMLInputElement; + expect(checkbox).toBeTruthy(); + expect(checkbox.checked).toBe(false); // excluded = unchecked + + unmount(); + }); +}); diff --git a/tests/unit/settings/providers/provider-batch-toolbar-selection.test.tsx b/tests/unit/settings/providers/provider-batch-toolbar-selection.test.tsx new file mode 100644 index 000000000..d6ba27294 --- /dev/null +++ b/tests/unit/settings/providers/provider-batch-toolbar-selection.test.tsx @@ -0,0 +1,246 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { describe, expect, it, vi } from "vitest"; +import type { ProviderDisplay, ProviderType } from "@/types/provider"; + +// Mock next-intl +vi.mock("next-intl", () => ({ + useTranslations: () => (key: string) => key, +})); + +// Mock UI components +vi.mock("@/components/ui/button", () => ({ + Button: ({ children, ...props }: any) => , +})); + +vi.mock("@/components/ui/checkbox", () => ({ + Checkbox: ({ checked, onCheckedChange, ...props }: any) => ( + onCheckedChange?.(e.target.checked)} + {...props} + /> + ), +})); + +vi.mock("@/components/ui/dropdown-menu", () => ({ + DropdownMenu: ({ children }: any) =>
{children}
, + DropdownMenuTrigger: ({ children }: any) => ( +
{children}
+ ), + DropdownMenuContent: ({ children }: any) => ( +
{children}
+ ), + DropdownMenuItem: ({ children, onClick, ...props }: any) => ( +
+ {children} +
+ ), +})); + +// Mock lucide-react +vi.mock("lucide-react", () => ({ + ChevronDown: () => , + Pencil: () => , + X: () => , +})); + +function createProvider( + id: number, + providerType: ProviderType, + groupTag: string | null = null +): ProviderDisplay { + return { id, providerType, groupTag } as ProviderDisplay; +} + +function render(node: React.ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +// Import after mocks +import { ProviderBatchToolbar } from "@/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar"; + +const defaultProps = { + isMultiSelectMode: false, + allSelected: false, + selectedCount: 0, + totalCount: 5, + onEnterMode: vi.fn(), + onExitMode: vi.fn(), + onSelectAll: vi.fn(), + onInvertSelection: vi.fn(), + onOpenBatchEdit: vi.fn(), + providers: [] as ProviderDisplay[], + onSelectByType: vi.fn(), + onSelectByGroup: vi.fn(), +}; + +describe("ProviderBatchToolbar - Selection enhancements", () => { + it("does NOT render type/group dropdowns when NOT in multi-select mode", () => { + const providers = [createProvider(1, "claude"), createProvider(2, "openai-compatible")]; + + const { container, unmount } = render( + + ); + + const dropdowns = container.querySelectorAll('[data-testid="dropdown-menu"]'); + expect(dropdowns.length).toBe(0); + + unmount(); + }); + + it("renders Select by Type dropdown in multi-select mode when providers have multiple types", () => { + const providers = [ + createProvider(1, "claude"), + createProvider(2, "claude"), + createProvider(3, "openai-compatible"), + ]; + + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + const typeButton = Array.from(buttons).find((b) => b.textContent?.includes("selectByType")); + expect(typeButton).toBeTruthy(); + + const items = container.querySelectorAll('[data-testid="dropdown-menu-item"]'); + const typeItems = Array.from(items).filter( + (item) => + item.getAttribute("data-value") === "claude" || + item.getAttribute("data-value") === "openai-compatible" + ); + expect(typeItems.length).toBe(2); + + unmount(); + }); + + it("calls onSelectByType with correct type when clicking a type option", () => { + const onSelectByType = vi.fn(); + const providers = [createProvider(1, "claude"), createProvider(2, "openai-compatible")]; + + const { container, unmount } = render( + + ); + + const claudeItem = container.querySelector('[data-value="claude"]'); + expect(claudeItem).toBeTruthy(); + + act(() => { + claudeItem!.dispatchEvent(new MouseEvent("click", { bubbles: true })); + }); + + expect(onSelectByType).toHaveBeenCalledWith("claude"); + + unmount(); + }); + + it("renders Select by Group dropdown when providers have groups", () => { + const providers = [ + createProvider(1, "claude", "production"), + createProvider(2, "claude", "staging"), + createProvider(3, "claude", "production"), + ]; + + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + const groupButton = Array.from(buttons).find((b) => b.textContent?.includes("selectByGroup")); + expect(groupButton).toBeTruthy(); + + const items = container.querySelectorAll('[data-testid="dropdown-menu-item"]'); + const groupItems = Array.from(items).filter( + (item) => + item.getAttribute("data-value") === "production" || + item.getAttribute("data-value") === "staging" + ); + expect(groupItems.length).toBe(2); + + unmount(); + }); + + it("calls onSelectByGroup with correct group when clicking a group option", () => { + const onSelectByGroup = vi.fn(); + const providers = [ + createProvider(1, "claude", "production"), + createProvider(2, "claude", "staging"), + ]; + + const { container, unmount } = render( + + ); + + const productionItem = container.querySelector('[data-value="production"]'); + expect(productionItem).toBeTruthy(); + + act(() => { + productionItem!.dispatchEvent(new MouseEvent("click", { bubbles: true })); + }); + + expect(onSelectByGroup).toHaveBeenCalledWith("production"); + + unmount(); + }); + + it("does NOT render type dropdown when all filtered providers have same type", () => { + const providers = [createProvider(1, "claude"), createProvider(2, "claude")]; + + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + const typeButton = Array.from(buttons).find((b) => b.textContent?.includes("selectByType")); + expect(typeButton).toBeFalsy(); + + unmount(); + }); + + it("does NOT render group dropdown when no groups exist", () => { + const providers = [ + createProvider(1, "claude", null), + createProvider(2, "openai-compatible", null), + ]; + + const { container, unmount } = render( + + ); + + const buttons = container.querySelectorAll("button"); + const groupButton = Array.from(buttons).find((b) => b.textContent?.includes("selectByGroup")); + expect(groupButton).toBeFalsy(); + + unmount(); + }); +}); diff --git a/tests/unit/settings/providers/provider-batch-toolbar.test.tsx b/tests/unit/settings/providers/provider-batch-toolbar.test.tsx new file mode 100644 index 000000000..c0967f6bd --- /dev/null +++ b/tests/unit/settings/providers/provider-batch-toolbar.test.tsx @@ -0,0 +1,215 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { describe, expect, it, vi } from "vitest"; +import type { ProviderDisplay, ProviderType } from "@/types/provider"; + +vi.mock("next-intl", () => ({ + useTranslations: () => (key: string, params?: Record) => { + if (params) { + let result = key; + for (const [k, v] of Object.entries(params)) { + result = result.replace(`{${k}}`, String(v)); + } + return result; + } + return key; + }, +})); + +vi.mock("@/components/ui/button", () => ({ + Button: ({ children, ...props }: any) => , +})); + +vi.mock("@/components/ui/checkbox", () => ({ + Checkbox: ({ checked, onCheckedChange, ...props }: any) => ( + onCheckedChange?.(e.target.checked)} + {...props} + /> + ), +})); + +vi.mock("@/components/ui/dropdown-menu", () => ({ + DropdownMenu: ({ children }: any) =>
{children}
, + DropdownMenuTrigger: ({ children }: any) =>
{children}
, + DropdownMenuContent: ({ children }: any) =>
{children}
, + DropdownMenuItem: ({ children, onClick }: any) => ( +
+ {children} +
+ ), +})); + +vi.mock("lucide-react", () => ({ + ChevronDown: () => , + Pencil: () => , + X: () => , +})); + +import { + ProviderBatchToolbar, + type ProviderBatchToolbarProps, +} from "@/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar"; + +function createProvider( + id: number, + providerType: ProviderType, + groupTag: string | null = null +): ProviderDisplay { + return { id, providerType, groupTag } as ProviderDisplay; +} + +function render(node: React.ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +function defaultProps( + overrides: Partial = {} +): ProviderBatchToolbarProps { + return { + isMultiSelectMode: false, + allSelected: false, + selectedCount: 0, + totalCount: 3, + onEnterMode: vi.fn(), + onExitMode: vi.fn(), + onSelectAll: vi.fn(), + onInvertSelection: vi.fn(), + onOpenBatchEdit: vi.fn(), + providers: [ + createProvider(1, "claude"), + createProvider(2, "openai"), + createProvider(3, "claude"), + ], + onSelectByType: vi.fn(), + onSelectByGroup: vi.fn(), + ...overrides, + }; +} + +describe("ProviderBatchToolbar - discoverability hint", () => { + describe("not in multi-select mode", () => { + it("shows enter-mode button and hint text when totalCount > 1", () => { + const props = defaultProps({ totalCount: 3 }); + const { container, unmount } = render(); + + const buttons = container.querySelectorAll("button"); + const enterBtn = Array.from(buttons).find((b) => b.textContent?.includes("enterMode")); + expect(enterBtn).toBeTruthy(); + + const hint = container.querySelector("span.text-xs"); + expect(hint).toBeTruthy(); + expect(hint!.textContent).toBe("selectionHint"); + + unmount(); + }); + + it("shows hint when totalCount is exactly 1 (totalCount > 0 condition)", () => { + const props = defaultProps({ + totalCount: 1, + providers: [createProvider(1, "claude")], + }); + const { container, unmount } = render(); + + const hint = container.querySelector("span.text-xs"); + expect(hint).toBeTruthy(); + + unmount(); + }); + + it("does NOT show hint when totalCount is 0", () => { + const props = defaultProps({ totalCount: 0, providers: [] }); + const { container, unmount } = render(); + + const hint = container.querySelector("span.text-xs"); + expect(hint).toBeNull(); + + unmount(); + }); + + it("hint uses i18n key selectionHint", () => { + const props = defaultProps({ totalCount: 5 }); + const { container, unmount } = render(); + + const hint = container.querySelector("span.text-xs"); + expect(hint).toBeTruthy(); + expect(hint!.textContent).toBe("selectionHint"); + + unmount(); + }); + + it("enter-mode button is disabled when totalCount is 0", () => { + const props = defaultProps({ totalCount: 0, providers: [] }); + const { container, unmount } = render(); + + const buttons = container.querySelectorAll("button"); + const enterBtn = Array.from(buttons).find((b) => b.textContent?.includes("enterMode")); + expect(enterBtn).toBeTruthy(); + expect(enterBtn!.disabled).toBe(true); + + unmount(); + }); + }); + + describe("in multi-select mode", () => { + it("does NOT show hint text", () => { + const props = defaultProps({ isMultiSelectMode: true, selectedCount: 1 }); + const { container, unmount } = render(); + + const allSpans = container.querySelectorAll("span"); + const hintSpan = Array.from(allSpans).find((s) => s.textContent === "selectionHint"); + expect(hintSpan).toBeFalsy(); + + unmount(); + }); + + it("renders select-all checkbox and selected count", () => { + const props = defaultProps({ isMultiSelectMode: true, selectedCount: 2 }); + const { container, unmount } = render(); + + const checkbox = container.querySelector('input[type="checkbox"]'); + expect(checkbox).toBeTruthy(); + + const countText = Array.from(container.querySelectorAll("span")).find((s) => + s.textContent?.includes("selectedCount") + ); + expect(countText).toBeTruthy(); + + unmount(); + }); + + it("renders invert, edit, and exit buttons", () => { + const props = defaultProps({ isMultiSelectMode: true, selectedCount: 1 }); + const { container, unmount } = render(); + + const buttons = container.querySelectorAll("button"); + const texts = Array.from(buttons).map((b) => b.textContent); + + expect(texts.some((t) => t?.includes("invertSelection"))).toBe(true); + expect(texts.some((t) => t?.includes("editSelected"))).toBe(true); + expect(texts.some((t) => t?.includes("exitMode"))).toBe(true); + + unmount(); + }); + }); +}); diff --git a/tests/unit/settings/providers/provider-form-batch-context.test.ts b/tests/unit/settings/providers/provider-form-batch-context.test.ts new file mode 100644 index 000000000..de0bd281d --- /dev/null +++ b/tests/unit/settings/providers/provider-form-batch-context.test.ts @@ -0,0 +1,190 @@ +import { describe, expect, it } from "vitest"; +import { + createInitialState, + providerFormReducer, +} from "@/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context"; + +// --------------------------------------------------------------------------- +// createInitialState("batch") +// --------------------------------------------------------------------------- + +describe("createInitialState - batch mode", () => { + it("returns batch state with isEnabled set to no_change", () => { + const state = createInitialState("batch"); + + expect(state.batch.isEnabled).toBe("no_change"); + }); + + it("returns neutral routing defaults (no provider source)", () => { + const state = createInitialState("batch"); + + expect(state.routing.priority).toBe(0); + expect(state.routing.weight).toBe(1); + expect(state.routing.costMultiplier).toBe(1.0); + expect(state.routing.groupTag).toEqual([]); + expect(state.routing.preserveClientIp).toBe(false); + expect(state.routing.modelRedirects).toEqual({}); + expect(state.routing.allowedModels).toEqual([]); + expect(state.routing.cacheTtlPreference).toBe("inherit"); + expect(state.routing.swapCacheTtlBilling).toBe(false); + expect(state.routing.anthropicAdaptiveThinking).toBeNull(); + }); + + it("returns neutral rate limit defaults", () => { + const state = createInitialState("batch"); + + expect(state.rateLimit.limit5hUsd).toBeNull(); + expect(state.rateLimit.limitDailyUsd).toBeNull(); + expect(state.rateLimit.dailyResetMode).toBe("fixed"); + expect(state.rateLimit.dailyResetTime).toBe("00:00"); + expect(state.rateLimit.limitWeeklyUsd).toBeNull(); + expect(state.rateLimit.limitMonthlyUsd).toBeNull(); + expect(state.rateLimit.limitTotalUsd).toBeNull(); + expect(state.rateLimit.limitConcurrentSessions).toBeNull(); + }); + + it("returns neutral circuit breaker defaults", () => { + const state = createInitialState("batch"); + + expect(state.circuitBreaker.failureThreshold).toBeUndefined(); + expect(state.circuitBreaker.openDurationMinutes).toBeUndefined(); + expect(state.circuitBreaker.halfOpenSuccessThreshold).toBeUndefined(); + expect(state.circuitBreaker.maxRetryAttempts).toBeNull(); + }); + + it("returns neutral network defaults", () => { + const state = createInitialState("batch"); + + expect(state.network.proxyUrl).toBe(""); + expect(state.network.proxyFallbackToDirect).toBe(false); + expect(state.network.firstByteTimeoutStreamingSeconds).toBeUndefined(); + expect(state.network.streamingIdleTimeoutSeconds).toBeUndefined(); + expect(state.network.requestTimeoutNonStreamingSeconds).toBeUndefined(); + }); + + it("returns neutral MCP defaults", () => { + const state = createInitialState("batch"); + + expect(state.mcp.mcpPassthroughType).toBe("none"); + expect(state.mcp.mcpPassthroughUrl).toBe(""); + }); + + it("ignores provider and cloneProvider arguments in batch mode", () => { + const fakeProvider = { + id: 99, + name: "Ignored", + url: "https://ignored.example.com", + maskedKey: "xxxx****xxxx", + isEnabled: false, + weight: 50, + priority: 99, + groupPriorities: null, + costMultiplier: 3.0, + groupTag: "prod", + providerType: "claude" as const, + providerVendorId: null, + preserveClientIp: true, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none" as const, + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed" as const, + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 10, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 30000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 120000, + requestTimeoutNonStreamingMs: 120000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: "2024-01-01T00:00:00Z", + updatedAt: "2024-01-01T00:00:00Z", + }; + + const state = createInitialState("batch", fakeProvider, fakeProvider); + + // Should still be batch defaults, not the provider values + expect(state.routing.priority).toBe(0); + expect(state.routing.weight).toBe(1); + expect(state.routing.costMultiplier).toBe(1.0); + expect(state.batch.isEnabled).toBe("no_change"); + }); +}); + +// --------------------------------------------------------------------------- +// providerFormReducer - SET_BATCH_IS_ENABLED +// --------------------------------------------------------------------------- + +describe("providerFormReducer - SET_BATCH_IS_ENABLED", () => { + const baseState = createInitialState("batch"); + + it("sets isEnabled to true", () => { + const next = providerFormReducer(baseState, { + type: "SET_BATCH_IS_ENABLED", + payload: "true", + }); + + expect(next.batch.isEnabled).toBe("true"); + }); + + it("sets isEnabled to false", () => { + const next = providerFormReducer(baseState, { + type: "SET_BATCH_IS_ENABLED", + payload: "false", + }); + + expect(next.batch.isEnabled).toBe("false"); + }); + + it("sets isEnabled back to no_change", () => { + const modified = providerFormReducer(baseState, { + type: "SET_BATCH_IS_ENABLED", + payload: "true", + }); + const reverted = providerFormReducer(modified, { + type: "SET_BATCH_IS_ENABLED", + payload: "no_change", + }); + + expect(reverted.batch.isEnabled).toBe("no_change"); + }); + + it("does not mutate other state sections", () => { + const next = providerFormReducer(baseState, { + type: "SET_BATCH_IS_ENABLED", + payload: "true", + }); + + expect(next.routing).toEqual(baseState.routing); + expect(next.rateLimit).toEqual(baseState.rateLimit); + expect(next.circuitBreaker).toEqual(baseState.circuitBreaker); + expect(next.network).toEqual(baseState.network); + expect(next.mcp).toEqual(baseState.mcp); + expect(next.ui).toEqual(baseState.ui); + }); +}); diff --git a/tests/unit/settings/providers/provider-undo-toast.test.tsx b/tests/unit/settings/providers/provider-undo-toast.test.tsx new file mode 100644 index 000000000..239562e0c --- /dev/null +++ b/tests/unit/settings/providers/provider-undo-toast.test.tsx @@ -0,0 +1,595 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { ProviderBatchDialog } from "@/app/[locale]/settings/providers/_components/batch-edit/provider-batch-dialog"; +import type { ProviderDisplay } from "@/types/provider"; + +// --------------------------------------------------------------------------- +// Mutable mock state for useProviderForm +// --------------------------------------------------------------------------- + +let mockDirtyFields = new Set(); +const mockDispatch = vi.fn(); +const mockState = { + ui: { activeTab: "basic" as const, isPending: false, showFailureThresholdConfirm: false }, + basic: { name: "", url: "", key: "", websiteUrl: "" }, + routing: { + providerType: "claude" as const, + groupTag: [], + preserveClientIp: false, + modelRedirects: {}, + allowedModels: [], + priority: 5, + groupPriorities: {}, + weight: 1, + costMultiplier: 1, + cacheTtlPreference: "inherit" as const, + swapCacheTtlBilling: false, + context1mPreference: "inherit" as const, + codexReasoningEffortPreference: "inherit", + codexReasoningSummaryPreference: "inherit", + codexTextVerbosityPreference: "inherit", + codexParallelToolCallsPreference: "inherit", + anthropicMaxTokensPreference: "inherit", + anthropicThinkingBudgetPreference: "inherit", + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: "inherit", + }, + rateLimit: { + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed" as const, + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: null, + }, + circuitBreaker: { + failureThreshold: undefined, + openDurationMinutes: undefined, + halfOpenSuccessThreshold: undefined, + maxRetryAttempts: null, + }, + network: { + proxyUrl: "", + proxyFallbackToDirect: false, + firstByteTimeoutStreamingSeconds: undefined, + streamingIdleTimeoutSeconds: undefined, + requestTimeoutNonStreamingSeconds: undefined, + }, + mcp: { mcpPassthroughType: "none" as const, mcpPassthroughUrl: "" }, + batch: { isEnabled: "no_change" as const }, +}; + +// --------------------------------------------------------------------------- +// Mocks +// --------------------------------------------------------------------------- + +vi.mock("next-intl", () => ({ + useTranslations: () => { + const t = (key: string, params?: Record) => { + if (params) { + let result = key; + for (const [k, v] of Object.entries(params)) { + result = result.replace(`{${k}}`, String(v)); + } + return result; + } + return key; + }; + return t; + }, +})); + +const mockInvalidateQueries = vi.fn().mockResolvedValue(undefined); +vi.mock("@tanstack/react-query", () => ({ + useQueryClient: () => ({ + invalidateQueries: mockInvalidateQueries, + }), +})); + +const mockToastSuccess = vi.fn(); +const mockToastError = vi.fn(); +vi.mock("sonner", () => ({ + toast: { + success: (...args: unknown[]) => mockToastSuccess(...args), + error: (...args: unknown[]) => mockToastError(...args), + }, +})); + +const mockPreview = vi.fn(); +const mockApply = vi.fn(); +const mockUndo = vi.fn(); +vi.mock("@/actions/providers", () => ({ + previewProviderBatchPatch: (...args: unknown[]) => mockPreview(...args), + applyProviderBatchPatch: (...args: unknown[]) => mockApply(...args), + undoProviderPatch: (...args: unknown[]) => mockUndo(...args), + batchDeleteProviders: vi.fn().mockResolvedValue({ ok: true, data: { deletedCount: 1 } }), + batchResetProviderCircuits: vi.fn().mockResolvedValue({ ok: true, data: { resetCount: 1 } }), +})); + +// Mock ProviderFormProvider + useProviderForm +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context", + () => ({ + ProviderFormProvider: ({ children }: { children: React.ReactNode }) => <>{children}, + useProviderForm: () => ({ + state: mockState, + dispatch: mockDispatch, + dirtyFields: mockDirtyFields, + mode: "batch", + }), + }) +); + +// Mock all form section components as stubs +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/basic-info-section", + () => ({ + BasicInfoSection: () =>
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section", + () => ({ + RoutingSection: () =>
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/limits-section", + () => ({ + LimitsSection: () =>
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/network-section", + () => ({ + NetworkSection: () =>
, + }) +); +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/sections/testing-section", + () => ({ + TestingSection: () =>
, + }) +); + +// Mock FormTabNav +vi.mock( + "@/app/[locale]/settings/providers/_components/forms/provider-form/components/form-tab-nav", + () => ({ + FormTabNav: () =>
, + }) +); + +// Mock ProviderBatchPreviewStep +vi.mock( + "@/app/[locale]/settings/providers/_components/batch-edit/provider-batch-preview-step", + () => ({ + ProviderBatchPreviewStep: () =>
, + }) +); + +// Mock buildPatchDraftFromFormState +vi.mock("@/app/[locale]/settings/providers/_components/batch-edit/build-patch-draft", () => ({ + buildPatchDraftFromFormState: vi.fn().mockReturnValue({ priority: { set: 5 } }), +})); + +// UI component mocks +vi.mock("@/components/ui/dialog", () => ({ + Dialog: ({ open, children }: { open: boolean; children: React.ReactNode }) => + open ?
{children}
: null, + DialogContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogDescription: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogFooter: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogHeader: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + DialogTitle: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})); + +vi.mock("@/components/ui/alert-dialog", () => ({ + AlertDialog: ({ open, children }: { open: boolean; children: React.ReactNode }) => + open ?
{children}
: null, + AlertDialogAction: ({ children, ...props }: any) => , + AlertDialogCancel: ({ children, ...props }: any) => , + AlertDialogContent: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogDescription: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogFooter: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogHeader: ({ children }: { children: React.ReactNode }) =>
{children}
, + AlertDialogTitle: ({ children }: { children: React.ReactNode }) =>
{children}
, +})); + +vi.mock("@/components/ui/button", () => ({ + Button: ({ children, ...props }: any) => , +})); + +vi.mock("lucide-react", () => ({ + Loader2: () =>
, +})); + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function render(ui: React.ReactElement) { + const container = document.createElement("div"); + document.body.appendChild(container); + let root: ReturnType; + act(() => { + root = createRoot(container); + root.render(ui); + }); + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +function createMockProvider(id: number, name: string): ProviderDisplay { + return { + id, + name, + url: "https://api.example.com", + maskedKey: "xxxx****1234", + isEnabled: true, + weight: 1, + priority: 0, + groupPriorities: null, + costMultiplier: 1, + groupTag: null, + providerType: "claude", + providerVendorId: null, + preserveClientIp: false, + modelRedirects: null, + allowedModels: null, + mcpPassthroughType: "none", + mcpPassthroughUrl: null, + limit5hUsd: null, + limitDailyUsd: null, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + limitWeeklyUsd: null, + limitMonthlyUsd: null, + limitTotalUsd: null, + limitConcurrentSessions: 10, + maxRetryAttempts: null, + circuitBreakerFailureThreshold: 5, + circuitBreakerOpenDuration: 30000, + circuitBreakerHalfOpenSuccessThreshold: 2, + proxyUrl: null, + proxyFallbackToDirect: false, + firstByteTimeoutStreamingMs: 30000, + streamingIdleTimeoutMs: 120000, + requestTimeoutNonStreamingMs: 120000, + websiteUrl: null, + faviconUrl: null, + cacheTtlPreference: null, + swapCacheTtlBilling: false, + context1mPreference: null, + codexReasoningEffortPreference: null, + codexReasoningSummaryPreference: null, + codexTextVerbosityPreference: null, + codexParallelToolCallsPreference: null, + anthropicMaxTokensPreference: null, + anthropicThinkingBudgetPreference: null, + anthropicAdaptiveThinking: null, + geminiGoogleSearchPreference: null, + tpm: null, + rpm: null, + rpd: null, + cc: null, + createdAt: "2024-01-01T00:00:00Z", + updatedAt: "2024-01-01T00:00:00Z", + }; +} + +function defaultProps(overrides: Partial> = {}) { + return { + open: true, + mode: "edit" as const, + onOpenChange: vi.fn(), + selectedProviderIds: new Set([1, 2]), + providers: [createMockProvider(1, "Provider1"), createMockProvider(2, "Provider2")], + onSuccess: vi.fn(), + ...overrides, + }; +} + +/** + * Drives the dialog from "edit" step through "preview" step to "apply": + * 1. Click "Next" (second button in edit-step footer) + * 2. Wait for preview to resolve + * 3. Click "Apply" (second button in preview-step footer) + * 4. Wait for apply to resolve + */ +async function driveToApply(container: HTMLElement) { + // Click Next (second button in footer) + const footer = container.querySelector('[data-testid="dialog-footer"]'); + const buttons = footer?.querySelectorAll("button") ?? []; + const nextButton = buttons[1] as HTMLButtonElement; + + await act(async () => { + nextButton.click(); + }); + await act(async () => { + await new Promise((r) => setTimeout(r, 10)); + }); + + // Click Apply (second button in preview-step footer) + const applyFooter = container.querySelector('[data-testid="dialog-footer"]'); + const applyButtons = applyFooter?.querySelectorAll("button") ?? []; + const applyButton = applyButtons[1] as HTMLButtonElement; + + await act(async () => { + applyButton.click(); + }); + await act(async () => { + await new Promise((r) => setTimeout(r, 10)); + }); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("Provider Undo Toast", () => { + beforeEach(() => { + vi.clearAllMocks(); + // Make hasChanges true so the "Next" button is enabled + mockDirtyFields = new Set(["routing.priority"]); + }); + + it("shows undo toast after successful apply", async () => { + mockPreview.mockResolvedValue({ + ok: true, + data: { + previewToken: "tok-1", + previewRevision: "rev-1", + previewExpiresAt: new Date(Date.now() + 60000).toISOString(), + providerIds: [1, 2], + changedFields: ["priority"], + rows: [ + { + providerId: 1, + providerName: "Provider1", + field: "priority", + status: "changed", + before: 0, + after: 5, + }, + { + providerId: 2, + providerName: "Provider2", + field: "priority", + status: "changed", + before: 0, + after: 5, + }, + ], + summary: { providerCount: 2, fieldCount: 2, skipCount: 0 }, + }, + }); + + mockApply.mockResolvedValue({ + ok: true, + data: { + operationId: "op-1", + appliedAt: new Date().toISOString(), + updatedCount: 2, + undoToken: "undo-tok-1", + undoExpiresAt: new Date(Date.now() + 10000).toISOString(), + }, + }); + + const props = defaultProps(); + const { container, unmount } = render(); + + await driveToApply(container); + + expect(mockPreview).toHaveBeenCalledTimes(1); + expect(mockApply).toHaveBeenCalledTimes(1); + + // Verify toast.success was called with undo action + expect(mockToastSuccess).toHaveBeenCalledWith( + "toast.updated", + expect.objectContaining({ + duration: 10000, + action: expect.objectContaining({ + label: expect.any(String), + onClick: expect.any(Function), + }), + }) + ); + + unmount(); + }); + + it("undo action calls undoProviderPatch on success", async () => { + mockPreview.mockResolvedValue({ + ok: true, + data: { + previewToken: "tok-2", + previewRevision: "rev-2", + previewExpiresAt: new Date(Date.now() + 60000).toISOString(), + providerIds: [1], + changedFields: ["priority"], + rows: [ + { + providerId: 1, + providerName: "Provider1", + field: "priority", + status: "changed", + before: 0, + after: 5, + }, + ], + summary: { providerCount: 1, fieldCount: 1, skipCount: 0 }, + }, + }); + + mockApply.mockResolvedValue({ + ok: true, + data: { + operationId: "op-2", + appliedAt: new Date().toISOString(), + updatedCount: 1, + undoToken: "undo-tok-2", + undoExpiresAt: new Date(Date.now() + 10000).toISOString(), + }, + }); + + mockUndo.mockResolvedValue({ + ok: true, + data: { + operationId: "op-2", + revertedAt: new Date().toISOString(), + revertedCount: 1, + }, + }); + + const props = defaultProps({ selectedProviderIds: new Set([1]) }); + const { container, unmount } = render(); + + await driveToApply(container); + + // Extract the undo onClick from the toast call + const toastCall = mockToastSuccess.mock.calls[0]; + const toastOptions = toastCall[1] as { action: { onClick: () => Promise } }; + + // Call the undo action + await act(async () => { + await toastOptions.action.onClick(); + }); + + expect(mockUndo).toHaveBeenCalledWith({ + undoToken: "undo-tok-2", + operationId: "op-2", + }); + + // Should show success toast for undo + expect(mockToastSuccess).toHaveBeenCalledTimes(2); + expect(mockToastSuccess.mock.calls[1][0]).toBe("toast.undoSuccess"); + + unmount(); + }); + + it("undo failure shows error toast", async () => { + mockPreview.mockResolvedValue({ + ok: true, + data: { + previewToken: "tok-3", + previewRevision: "rev-3", + previewExpiresAt: new Date(Date.now() + 60000).toISOString(), + providerIds: [1], + changedFields: ["priority"], + rows: [ + { + providerId: 1, + providerName: "Provider1", + field: "priority", + status: "changed", + before: 0, + after: 5, + }, + ], + summary: { providerCount: 1, fieldCount: 1, skipCount: 0 }, + }, + }); + + mockApply.mockResolvedValue({ + ok: true, + data: { + operationId: "op-3", + appliedAt: new Date().toISOString(), + updatedCount: 1, + undoToken: "undo-tok-3", + undoExpiresAt: new Date(Date.now() + 10000).toISOString(), + }, + }); + + mockUndo.mockResolvedValue({ + ok: false, + error: "Undo window expired", + errorCode: "UNDO_EXPIRED", + }); + + const props = defaultProps({ selectedProviderIds: new Set([1]) }); + const { container, unmount } = render(); + + await driveToApply(container); + + // Extract undo onClick + const toastCall = mockToastSuccess.mock.calls[0]; + const toastOptions = toastCall[1] as { action: { onClick: () => Promise } }; + + // Call undo - should fail + await act(async () => { + await toastOptions.action.onClick(); + }); + + expect(mockUndo).toHaveBeenCalledTimes(1); + // After undo failure, error toast is shown via toast.error + expect(mockToastError).toHaveBeenCalled(); + + unmount(); + }); + + it("apply shows error toast on failure", async () => { + mockPreview.mockResolvedValue({ + ok: true, + data: { + previewToken: "tok-4", + previewRevision: "rev-4", + previewExpiresAt: new Date(Date.now() + 60000).toISOString(), + providerIds: [1], + changedFields: ["priority"], + rows: [ + { + providerId: 1, + providerName: "Provider1", + field: "priority", + status: "changed", + before: 0, + after: 5, + }, + ], + summary: { providerCount: 1, fieldCount: 1, skipCount: 0 }, + }, + }); + + mockApply.mockResolvedValue({ + ok: false, + error: "Preview expired", + errorCode: "PREVIEW_EXPIRED", + }); + + const props = defaultProps({ selectedProviderIds: new Set([1]) }); + const { container, unmount } = render(); + + await driveToApply(container); + + expect(mockApply).toHaveBeenCalledTimes(1); + // After apply failure, error toast is shown via toast.error + expect(mockToastError).toHaveBeenCalled(); + expect(mockToastSuccess).not.toHaveBeenCalled(); + + unmount(); + }); +}); diff --git a/tests/unit/settings/providers/thinking-budget-editor.test.tsx b/tests/unit/settings/providers/thinking-budget-editor.test.tsx new file mode 100644 index 000000000..3965822b4 --- /dev/null +++ b/tests/unit/settings/providers/thinking-budget-editor.test.tsx @@ -0,0 +1,233 @@ +/** + * @vitest-environment happy-dom + */ + +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { describe, expect, it, vi, beforeEach } from "vitest"; +import { ThinkingBudgetEditor } from "@/app/[locale]/settings/providers/_components/thinking-budget-editor"; + +// Mock next-intl +vi.mock("next-intl", () => ({ + useTranslations: () => (key: string) => key, +})); + +// Mock Select as native onValueChange(e.target.value)} + disabled={disabled} + > + {children} + +
+ ), + SelectTrigger: ({ children }: { children: React.ReactNode }) =>
{children}
, + SelectValue: () => null, + SelectContent: ({ children }: { children: React.ReactNode }) => <>{children}, + SelectItem: ({ value, children }: { value: string; children: React.ReactNode }) => ( + + ), +})); + +// Mock Tooltip as passthrough +vi.mock("@/components/ui/tooltip", () => ({ + Tooltip: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipTrigger: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipContent: ({ children }: { children: React.ReactNode }) =>
{children}
, +})); + +// Mock lucide-react +vi.mock("lucide-react", () => ({ + Info: () =>
, +})); + +function render(node: React.ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +describe("ThinkingBudgetEditor", () => { + const defaultProps = { + value: "inherit", + onChange: vi.fn(), + disabled: false, + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("renders with inherit value - no numeric input or max button", () => { + const { container, unmount } = render(); + + const select = container.querySelector('[data-testid="select-trigger"]') as HTMLSelectElement; + expect(select).toBeTruthy(); + expect(select.value).toBe("inherit"); + + // No number input when inherit + expect(container.querySelector('input[type="number"]')).toBeNull(); + // No max-out button when inherit + expect(container.querySelector("button")).toBeNull(); + + unmount(); + }); + + it("renders with numeric value - shows custom select, input, and max button", () => { + const { container, unmount } = render(); + + const select = container.querySelector('[data-testid="select-trigger"]') as HTMLSelectElement; + expect(select.value).toBe("custom"); + + const input = container.querySelector('input[type="number"]') as HTMLInputElement; + expect(input).toBeTruthy(); + expect(input.value).toBe("15000"); + + const maxButton = container.querySelector("button"); + expect(maxButton).toBeTruthy(); + expect(maxButton?.textContent).toContain("maxOutButton"); + + unmount(); + }); + + it("switches from inherit to custom - calls onChange with 10240", () => { + const onChange = vi.fn(); + const { container, unmount } = render( + + ); + + const select = container.querySelector('[data-testid="select-trigger"]') as HTMLSelectElement; + + act(() => { + select.value = "custom"; + select.dispatchEvent(new Event("change", { bubbles: true })); + }); + + expect(onChange).toHaveBeenCalledWith("10240"); + + unmount(); + }); + + it("switches from custom to inherit - calls onChange with inherit", () => { + const onChange = vi.fn(); + const { container, unmount } = render( + + ); + + const select = container.querySelector('[data-testid="select-trigger"]') as HTMLSelectElement; + + act(() => { + select.value = "inherit"; + select.dispatchEvent(new Event("change", { bubbles: true })); + }); + + expect(onChange).toHaveBeenCalledWith("inherit"); + + unmount(); + }); + + it("clicking max-out button calls onChange with 32000", () => { + const onChange = vi.fn(); + const { container, unmount } = render( + + ); + + const maxButton = container.querySelector("button") as HTMLButtonElement; + + act(() => { + maxButton.click(); + }); + + expect(onChange).toHaveBeenCalledWith("32000"); + + unmount(); + }); + + it("typing a number calls onChange with that value", () => { + const onChange = vi.fn(); + const { container, unmount } = render( + + ); + + const input = container.querySelector('input[type="number"]') as HTMLInputElement; + + act(() => { + const nativeInputValueSetter = Object.getOwnPropertyDescriptor( + window.HTMLInputElement.prototype, + "value" + )?.set; + nativeInputValueSetter?.call(input, "12345"); + input.dispatchEvent(new Event("change", { bubbles: true })); + }); + + expect(onChange).toHaveBeenCalledWith("12345"); + + unmount(); + }); + + it("clearing input calls onChange with inherit", () => { + const onChange = vi.fn(); + const { container, unmount } = render( + + ); + + const input = container.querySelector('input[type="number"]') as HTMLInputElement; + + act(() => { + const nativeInputValueSetter = Object.getOwnPropertyDescriptor( + window.HTMLInputElement.prototype, + "value" + )?.set; + nativeInputValueSetter?.call(input, ""); + input.dispatchEvent(new Event("change", { bubbles: true })); + }); + + expect(onChange).toHaveBeenCalledWith(""); + + unmount(); + }); + + it("disabled prop disables all controls", () => { + const { container, unmount } = render( + + ); + + const select = container.querySelector('[data-testid="select-trigger"]') as HTMLSelectElement; + expect(select.disabled).toBe(true); + + const input = container.querySelector('input[type="number"]') as HTMLInputElement; + expect(input.disabled).toBe(true); + + const maxButton = container.querySelector("button") as HTMLButtonElement; + expect(maxButton.disabled).toBe(true); + + unmount(); + }); +}); diff --git a/tests/unit/usage-doc/usage-doc-auth-state.test.tsx b/tests/unit/usage-doc/usage-doc-auth-state.test.tsx new file mode 100644 index 000000000..e22fa99c0 --- /dev/null +++ b/tests/unit/usage-doc/usage-doc-auth-state.test.tsx @@ -0,0 +1,126 @@ +/** + * @vitest-environment happy-dom + */ + +import fs from "node:fs"; +import path from "node:path"; +import type { ReactNode } from "react"; +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { NextIntlClientProvider } from "next-intl"; +import { describe, expect, test, vi } from "vitest"; +import { UsageDocAuthProvider } from "@/app/[locale]/usage-doc/_components/usage-doc-auth-context"; +import { QuickLinks } from "@/app/[locale]/usage-doc/_components/quick-links"; + +vi.mock("@/i18n/routing", () => ({ + Link: ({ + href, + children, + ...rest + }: { + href: string; + children: ReactNode; + className?: string; + }) => ( + + {children} + + ), +})); + +function loadUsageMessages() { + return JSON.parse( + fs.readFileSync(path.join(process.cwd(), "messages", "en", "usage.json"), "utf8") + ); +} + +function renderWithAuth(node: ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + const usageMessages = loadUsageMessages(); + + act(() => { + root.render( + + {node} + + ); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +describe("usage-doc auth state - HttpOnly cookie alignment", () => { + test("logged-in: QuickLinks renders dashboard link when isLoggedIn=true", () => { + Object.defineProperty(window, "scrollTo", { value: vi.fn(), writable: true }); + + const { container, unmount } = renderWithAuth( + + + + ); + + const dashboardLink = container.querySelector('a[href="/dashboard"]'); + expect(dashboardLink).not.toBeNull(); + + unmount(); + }); + + test("logged-out: QuickLinks does NOT render dashboard link when isLoggedIn=false", () => { + Object.defineProperty(window, "scrollTo", { value: vi.fn(), writable: true }); + + const { container, unmount } = renderWithAuth( + + + + ); + + const dashboardLink = container.querySelector('a[href="/dashboard"]'); + expect(dashboardLink).toBeNull(); + + unmount(); + }); + + test("default context value is isLoggedIn=false (no provider ancestor)", () => { + Object.defineProperty(window, "scrollTo", { value: vi.fn(), writable: true }); + + const { container, unmount } = renderWithAuth(); + + const dashboardLink = container.querySelector('a[href="/dashboard"]'); + expect(dashboardLink).toBeNull(); + + unmount(); + }); + + test("page.tsx no longer reads document.cookie for auth state", async () => { + const srcContent = fs.readFileSync( + path.join(process.cwd(), "src", "app", "[locale]", "usage-doc", "page.tsx"), + "utf8" + ); + expect(srcContent).not.toContain("document.cookie"); + }); + + test("page.tsx uses useUsageDocAuth hook for session state", async () => { + const srcContent = fs.readFileSync( + path.join(process.cwd(), "src", "app", "[locale]", "usage-doc", "page.tsx"), + "utf8" + ); + expect(srcContent).toContain("useUsageDocAuth"); + }); + + test("layout.tsx wraps children with UsageDocAuthProvider", async () => { + const srcContent = fs.readFileSync( + path.join(process.cwd(), "src", "app", "[locale]", "usage-doc", "layout.tsx"), + "utf8" + ); + expect(srcContent).toContain("UsageDocAuthProvider"); + expect(srcContent).toContain("isLoggedIn={!!session}"); + }); +}); diff --git a/tests/unit/usage-doc/usage-doc-page.test.tsx b/tests/unit/usage-doc/usage-doc-page.test.tsx index 284637e53..9801bda2c 100644 --- a/tests/unit/usage-doc/usage-doc-page.test.tsx +++ b/tests/unit/usage-doc/usage-doc-page.test.tsx @@ -10,6 +10,7 @@ import { createRoot } from "react-dom/client"; import { NextIntlClientProvider } from "next-intl"; import { describe, expect, test, vi } from "vitest"; import UsageDocPage from "@/app/[locale]/usage-doc/page"; +import { UsageDocAuthProvider } from "@/app/[locale]/usage-doc/_components/usage-doc-auth-context"; vi.mock("@/i18n/routing", () => ({ Link: ({ @@ -56,18 +57,18 @@ async function renderWithIntl(locale: string, node: ReactNode) { } describe("UsageDocPage - 目录/快速链接交互", () => { - test("应渲染 skip links,且登录态显示返回仪表盘链接", async () => { + test("should render skip links and show dashboard link when logged in", async () => { Object.defineProperty(window, "scrollTo", { value: vi.fn(), writable: true, }); - Object.defineProperty(document, "cookie", { - configurable: true, - get: () => "auth-token=test-token", - }); - - const { unmount } = await renderWithIntl("en", ); + const { unmount } = await renderWithIntl( + "en", + + + + ); expect(document.querySelector('a[href="#main-content"]')).not.toBeNull(); expect(document.querySelector('a[href="#toc-navigation"]')).not.toBeNull(); @@ -76,8 +77,6 @@ describe("UsageDocPage - 目录/快速链接交互", () => { expect(dashboardLink).not.toBeNull(); await unmount(); - - Reflect.deleteProperty(document, "cookie"); }); test("ru 语言不应显示中文占位符与代码块注释", async () => { diff --git a/vitest.config.ts b/vitest.config.ts index ec86290f7..0bee6eaad 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -12,6 +12,7 @@ export default defineConfig({ environment: "happy-dom", include: [ "tests/unit/**/*.{test,spec}.tsx", + "tests/security/**/*.{test,spec}.{ts,tsx}", "tests/api/**/*.{test,spec}.tsx", "src/**/*.{test,spec}.tsx", ], @@ -89,6 +90,7 @@ export default defineConfig({ // ==================== 文件匹配 ==================== include: [ "tests/unit/**/*.{test,spec}.ts", // 单元测试 + "tests/security/**/*.{test,spec}.ts", "tests/api/**/*.{test,spec}.ts", // API 测试 "src/**/*.{test,spec}.ts", // 支持源码中的测试 ],