diff --git a/server/mcp_server_supabase/README.md b/server/mcp_server_supabase/README.md index 3c0e95ac..6bd52dfa 100644 --- a/server/mcp_server_supabase/README.md +++ b/server/mcp_server_supabase/README.md @@ -2,34 +2,31 @@ English | [简体中文](README_zh.md) -> Supabase MCP server for AIDAP workspaces. It exposes workspace, branch, database, Edge Functions, storage, and TypeScript type generation capabilities through MCP. +> MCP server for Volcengine Supabase workspaces. It exposes workspace, branch, database, Edge Functions, storage, and TypeScript type generation capabilities through MCP. | Item | Details | | ---- | ---- | | Version | v0.1.0 | -| Description | Supabase MCP server built on top of AIDAP workspaces | +| Description | MCP server built on top of Volcengine Supabase workspaces | | Category | Database / Developer Tools | -| Tags | Supabase, PostgreSQL, Edge Functions, Storage, AIDAP | +| Tags | Supabase, PostgreSQL, Edge Functions, Storage, Volcengine | ## Tools -### Workspace and Branch +### `account` | Tool | Description | | ---- | ---- | | `list_workspaces` | List all available Supabase workspaces in the current account | -| `get_workspace` | Get workspace details; branch IDs are also accepted | +| `get_workspace` | Get workspace details | | `create_workspace` | Create a new Supabase workspace | | `pause_workspace` | Pause a workspace | | `restore_workspace` | Resume a paused workspace | -| `get_workspace_url` | Get the API endpoint for a workspace or branch | -| `get_publishable_keys` | Get publishable, anon, and service role keys | -| `list_branches` | List branches under a workspace | -| `create_branch` | Create a development branch | -| `delete_branch` | Delete a development branch | -| `reset_branch` | Reset a branch to its baseline state | +### `docs` + +No tools are currently exposed. -### Database +### `database` | Tool | Description | | ---- | ---- | @@ -38,18 +35,38 @@ English | [简体中文](README_zh.md) | `list_migrations` | List records from `supabase_migrations.schema_migrations` | | `list_extensions` | List installed PostgreSQL extensions | | `apply_migration` | Run migration SQL and record it in `supabase_migrations.schema_migrations` | + +### `debugging` + +No tools are currently exposed. + +### `development` + +| Tool | Description | +| ---- | ---- | +| `get_workspace_url` | Get the API endpoint for a workspace | +| `get_publishable_keys` | Get publishable, anon, and service role keys | | `generate_typescript_types` | Generate TypeScript definitions from schema metadata | -### Edge Functions +### `functions` | Tool | Description | | ---- | ---- | -| `list_edge_functions` | List Edge Functions in a workspace or branch | +| `list_edge_functions` | List Edge Functions in a workspace | | `get_edge_function` | Get the source code and configuration of an Edge Function | | `deploy_edge_function` | Create or update an Edge Function | | `delete_edge_function` | Delete an Edge Function | -### Storage +### `branching` + +| Tool | Description | +| ---- | ---- | +| `list_branches` | List branches under a workspace | +| `create_branch` | Create a development branch | +| `delete_branch` | Delete a development branch | +| `restore_branch` | Restore branch data to a specified point in time and return the restored branch ID | + +### `storage` | Tool | Description | | ---- | ---- | @@ -60,20 +77,31 @@ English | [简体中文](README_zh.md) ## Authentication -Use Volcengine AK/SK authentication. Obtain your credentials from the [Volcengine API Access Key console](https://console.volcengine.com/iam/keymanage/). +- Local deployment: use `VOLCENGINE_ACCESS_KEY`, `VOLCENGINE_SECRET_KEY`, and optional `VOLCENGINE_SESSION_TOKEN` + +Static AK/SK can be obtained from the [Volcengine API Access Key console](https://console.volcengine.com/iam/keymanage/). ## Environment Variables | Name | Required | Default | Description | | ---- | ---- | ---- | ---- | -| `VOLCENGINE_ACCESS_KEY` | Yes | - | Volcengine access key | -| `VOLCENGINE_SECRET_KEY` | Yes | - | Volcengine secret key | -| `VOLCENGINE_REGION` | No | `cn-beijing` | Region used for the AIDAP API | -| `DEFAULT_WORKSPACE_ID` | No | - | Default target used when `workspace_id` is omitted | -| `READ_ONLY` | No | `false` | Set to `true` to block all mutating tools | +| `VOLCENGINE_ACCESS_KEY` | Yes | - | Volcengine access key for local static authentication | +| `VOLCENGINE_SECRET_KEY` | Yes | - | Volcengine secret key for local static authentication | +| `VOLCENGINE_SESSION_TOKEN` | No | - | Optional session token used with temporary local credentials | +| `VOLCENGINE_REGION` | No | `cn-beijing` | Region used for the Volcengine API | +| `WORKSPACE_REF` | No | - | Startup-level hard scope. When set, `account` tools are hidden and workspace-scoped calls are forced to this target | +| `FEATURES` | No | `account,database,debugging,development,docs,functions,branching` | Official feature groups. `storage` is disabled by default | +| `DISABLED_TOOLS` | No | - | Comma-separated denylist applied after all other policy filters | +| `READ_ONLY` | No | `false` | Startup-level read-only switch; when enabled, mutating tools are hidden | | `SUPABASE_WORKSPACE_SLUG` | No | `default` | Project slug used by Edge Functions APIs | | `SUPABASE_ENDPOINT_SCHEME` | No | `http` | Endpoint scheme used when building workspace URLs | -| `PORT` | No | `8000` | Port used when running the server directly | +| `MCP_SERVER_HOST` | No | `0.0.0.0` | Host used by `sse` and `streamable-http` transports | +| `MCP_SERVER_PORT` | No | `8000` | Preferred port variable for network transports | +| `PORT` | No | `8000` | Backward-compatible port variable | +| `MCP_MOUNT_PATH` | No | `/` | Base mount path for HTTP transports | +| `MCP_SSE_PATH` | No | `/sse` | SSE endpoint path | +| `MCP_MESSAGE_PATH` | No | `/messages/` | SSE message POST path | +| `STREAMABLE_HTTP_PATH` | No | `/mcp` | Streamable HTTP endpoint path | ## Deployment @@ -83,7 +111,22 @@ Use Volcengine AK/SK authentication. Obtain your credentials from the [Volcengin uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase ``` -### MCP client config with local source +### Run with an explicit transport + +```bash +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase --transport stdio +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase --transport sse --host 0.0.0.0 --port 8000 +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase --transport streamable-http --host 0.0.0.0 --port 8000 +``` + +### Dedicated network entrypoints + +```bash +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase-sse +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase-streamable +``` + +### AI tool integration with local source ```json { @@ -100,14 +143,15 @@ uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-s "VOLCENGINE_ACCESS_KEY": "", "VOLCENGINE_SECRET_KEY": "", "VOLCENGINE_REGION": "cn-beijing", - "DEFAULT_WORKSPACE_ID": "ws-xxxxxxxx" + "WORKSPACE_REF": "ws-xxxxxxxx", + "FEATURES": "database,functions" } } } } ``` -### MCP client config with `uvx` +### AI tool integration with `uvx` ```json { @@ -123,7 +167,8 @@ uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-s "VOLCENGINE_ACCESS_KEY": "", "VOLCENGINE_SECRET_KEY": "", "VOLCENGINE_REGION": "cn-beijing", - "DEFAULT_WORKSPACE_ID": "ws-xxxxxxxx" + "WORKSPACE_REF": "ws-xxxxxxxx", + "FEATURES": "database,functions" } } } @@ -134,25 +179,55 @@ uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-s ```bash python3 -m mcp_server_supabase.server --port 8000 +python3 -m mcp_server_supabase.server --transport sse --host 0.0.0.0 --port 8000 ``` -The package exposes both `mcp-server-supabase` and `supabase-aidap`. The examples above use `mcp-server-supabase`. +The package exposes `mcp-server-supabase`, `mcp-server-supabase-sse`, and `mcp-server-supabase-streamable`. The examples above use `mcp-server-supabase`. ## Usage Notes -- If `workspace_id` is omitted, the server falls back to `DEFAULT_WORKSPACE_ID` when configured. -- If a branch ID such as `br-xxxx` is provided, the server resolves the corresponding workspace automatically. +- `WORKSPACE_REF` applies a hard workspace scope for the server instance and removes `workspace_id` from visible tool schemas. +- When `WORKSPACE_REF` is active, `account` tools are hidden and any explicit `workspace_id` outside the scope is rejected. +- `FEATURES` accepts only the official groups: `account`, `docs`, `database`, `debugging`, `development`, `functions`, `storage`, and `branching`. +- If `FEATURES` is not set, the default enabled groups are `account`, `database`, `debugging`, `development`, `docs`, `functions`, and `branching`. `storage` stays disabled by default. +- `READ_ONLY=true` hides all mutating tools for the server instance. +- `DISABLED_TOOLS` takes tool names such as `execute_sql,deploy_edge_function` and removes them after the rest of the policy has been resolved. +- `workspace_id` and `workspace_ref` accept workspace IDs only. Branch IDs such as `br-xxxx` are rejected. - `get_publishable_keys` resolves the default branch automatically when needed. -- `reset_branch` accepts `migration_version`, but the current AIDAP API ignores that value and performs a branch reset only. +- `restore_branch` supports optional `time` and `source_branch_id` arguments and returns `backup_branch_id`. - `deploy_edge_function` currently supports `native-node20/v1`, `native-python3.9/v1`, `native-python3.10/v1`, and `native-python3.12/v1`. +- `--transport sse` serves the MCP SSE endpoint at `MCP_SSE_PATH` and the message endpoint at `MCP_MESSAGE_PATH`. +- `--transport streamable-http` serves the MCP HTTP endpoint at `STREAMABLE_HTTP_PATH`. +- For remote deployments, `streamable-http` is usually the better default; `sse` remains available for clients that still require it. + +## Policy Precedence + +### Tool filtering order at startup + +1. `features` selects the base tool set +2. `workspace_ref` removes `account` tools and scopes the server to one workspace +3. `read_only` removes all mutating tools +4. `disabled_tools` removes specific tool names last + +## Integration Modes + +### AI tools + +This server works with Cursor, Claude Desktop, Cline, Trae, and any other MCP client that supports `stdio`, `sse`, or `streamable-http`. + +- Local integrations usually use `stdio` +- Configure `command`, `args`, and `env` in the client +- Local source mode usually injects static AK/SK through `env` +- The two `mcpServers` JSON examples above follow this pattern + +### Custom AI agents -## Compatible Clients +If your agent runtime can spawn a local MCP process, you can keep using `stdio`. If your agent runs on a server, in containers, or in a multi-instance environment, `streamable-http` or `sse` is usually the better integration path. -- Cursor -- Claude Desktop -- Cline -- Trae -- Any MCP client that supports `stdio` +- `stdio`: have the agent spawn `mcp-server-supabase` as a child process +- `streamable-http`: connect to `http://:/mcp` +- `sse`: connect to `http://:/sse` and post messages to `http://:/messages/` +- Tool visibility and workspace scope are fixed when the server starts through env vars or CLI flags ## License diff --git a/server/mcp_server_supabase/README_zh.md b/server/mcp_server_supabase/README_zh.md index 8079b28a..1359837f 100644 --- a/server/mcp_server_supabase/README_zh.md +++ b/server/mcp_server_supabase/README_zh.md @@ -2,34 +2,32 @@ [English](README.md) | 简体中文 -> 面向 AIDAP workspace 的 Supabase MCP Server,通过 MCP 暴露工作区、分支、数据库、Edge Functions、Storage 和 TypeScript 类型生成能力。 +> 面向火山引擎 Supabase 的 MCP Server,通过 MCP 暴露工作区、分支、数据库、Edge Functions、Storage 和 TypeScript 类型生成能力。 | 项目 | 详情 | | ---- | ---- | | 版本 | v0.1.0 | -| 描述 | 基于 AIDAP workspace 的 Supabase MCP Server | +| 描述 | 基于火山引擎 Supabase workspace 的 MCP Server | | 分类 | 数据库 / 开发工具 | -| 标签 | Supabase, PostgreSQL, Edge Functions, Storage, AIDAP | +| 标签 | Supabase, PostgreSQL, Edge Functions, Storage, Volcengine | ## 工具列表 -### 工作区与分支 +### `account` | 工具 | 说明 | | ---- | ---- | | `list_workspaces` | 列出当前账号下可访问的 Supabase workspace | -| `get_workspace` | 查询 workspace 详情,也支持直接传 branch ID | +| `get_workspace` | 查询 workspace 详情 | | `create_workspace` | 创建新的 Supabase workspace | | `pause_workspace` | 暂停 workspace | | `restore_workspace` | 恢复已暂停的 workspace | -| `get_workspace_url` | 获取 workspace 或 branch 的 API 地址 | -| `get_publishable_keys` | 获取 publishable、anon、service_role 等密钥 | -| `list_branches` | 列出 workspace 下的分支 | -| `create_branch` | 创建开发分支 | -| `delete_branch` | 删除开发分支 | -| `reset_branch` | 将分支重置到初始状态 | -### 数据库 +### `docs` + +当前没有暴露工具。 + +### `database` | 工具 | 说明 | | ---- | ---- | @@ -38,18 +36,38 @@ | `list_migrations` | 查询 `supabase_migrations.schema_migrations` 中的迁移记录 | | `list_extensions` | 列出已安装的 PostgreSQL 扩展 | | `apply_migration` | 执行迁移 SQL,并写入 `supabase_migrations.schema_migrations` | + +### `debugging` + +当前没有暴露工具。 + +### `development` + +| 工具 | 说明 | +| ---- | ---- | +| `get_workspace_url` | 获取 workspace 的 API 地址 | +| `get_publishable_keys` | 获取 publishable、anon、service_role 等密钥 | | `generate_typescript_types` | 根据 schema 元数据生成 TypeScript 类型定义 | -### Edge Functions +### `functions` | 工具 | 说明 | | ---- | ---- | -| `list_edge_functions` | 列出 workspace 或 branch 下的 Edge Functions | +| `list_edge_functions` | 列出 workspace 下的 Edge Functions | | `get_edge_function` | 获取 Edge Function 的代码和配置 | | `deploy_edge_function` | 创建或更新 Edge Function | | `delete_edge_function` | 删除 Edge Function | -### Storage +### `branching` + +| 工具 | 说明 | +| ---- | ---- | +| `list_branches` | 列出 workspace 下的分支 | +| `create_branch` | 创建开发分支 | +| `delete_branch` | 删除开发分支 | +| `restore_branch` | 将分支数据恢复到指定时间点,并返回恢复出的新分支 ID | + +### `storage` | 工具 | 说明 | | ---- | ---- | @@ -60,20 +78,31 @@ ## 鉴权方式 -使用火山引擎 AK/SK 鉴权。可在[火山引擎 API 访问密钥控制台](https://console.volcengine.com/iam/keymanage/)获取凭证。 +- 本地部署:使用 `VOLCENGINE_ACCESS_KEY`、`VOLCENGINE_SECRET_KEY` 和可选的 `VOLCENGINE_SESSION_TOKEN` + +静态 AK/SK 可在[火山引擎 API 访问密钥控制台](https://console.volcengine.com/iam/keymanage/)获取。 ## 环境变量 | 变量名 | 必需 | 默认值 | 说明 | | ---- | ---- | ---- | ---- | -| `VOLCENGINE_ACCESS_KEY` | 是 | - | 火山引擎 Access Key | -| `VOLCENGINE_SECRET_KEY` | 是 | - | 火山引擎 Secret Key | -| `VOLCENGINE_REGION` | 否 | `cn-beijing` | AIDAP API 所在地域 | -| `DEFAULT_WORKSPACE_ID` | 否 | - | 未传 `workspace_id` 时使用的默认目标 | -| `READ_ONLY` | 否 | `false` | 设为 `true` 后会禁止所有写操作工具 | +| `VOLCENGINE_ACCESS_KEY` | 是 | - | 本地静态鉴权使用的火山引擎 Access Key | +| `VOLCENGINE_SECRET_KEY` | 是 | - | 本地静态鉴权使用的火山引擎 Secret Key | +| `VOLCENGINE_SESSION_TOKEN` | 否 | - | 临时本地凭证使用的 Session Token | +| `VOLCENGINE_REGION` | 否 | `cn-beijing` | 火山引擎 API 所在地域 | +| `WORKSPACE_REF` | 否 | - | 服务启动级 workspace scope,设置后会隐藏 `account` 组工具,并强制所有 workspace-scoped 调用只能访问这个目标 | +| `FEATURES` | 否 | `account,database,debugging,development,docs,functions,branching` | 官方 feature groups,`storage` 默认关闭 | +| `DISABLED_TOOLS` | 否 | - | 逗号分隔的工具黑名单,在其他策略之后做最终剔除 | +| `READ_ONLY` | 否 | `false` | 服务启动级只读开关;启用后会隐藏所有写工具 | | `SUPABASE_WORKSPACE_SLUG` | 否 | `default` | Edge Functions API 使用的项目 slug | | `SUPABASE_ENDPOINT_SCHEME` | 否 | `http` | 生成 workspace URL 时使用的协议 | -| `PORT` | 否 | `8000` | 直接启动服务时监听的端口 | +| `MCP_SERVER_HOST` | 否 | `0.0.0.0` | `sse` 和 `streamable-http` 使用的监听地址 | +| `MCP_SERVER_PORT` | 否 | `8000` | 网络传输优先使用的端口变量 | +| `PORT` | 否 | `8000` | 兼容保留的端口变量 | +| `MCP_MOUNT_PATH` | 否 | `/` | HTTP 传输的基础挂载路径 | +| `MCP_SSE_PATH` | 否 | `/sse` | SSE 连接路径 | +| `MCP_MESSAGE_PATH` | 否 | `/messages/` | SSE 消息投递路径 | +| `STREAMABLE_HTTP_PATH` | 否 | `/mcp` | Streamable HTTP 路径 | ## 部署 @@ -83,7 +112,22 @@ uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase ``` -### 使用本地源码配置 MCP Client +### 显式指定 transport 启动 + +```bash +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase --transport stdio +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase --transport sse --host 0.0.0.0 --port 8000 +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase --transport streamable-http --host 0.0.0.0 --port 8000 +``` + +### 独立网络启动入口 + +```bash +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase-sse +uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-server-supabase-streamable +``` + +### AI 工具使用本地源码接入 ```json { @@ -100,14 +144,15 @@ uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-s "VOLCENGINE_ACCESS_KEY": "", "VOLCENGINE_SECRET_KEY": "", "VOLCENGINE_REGION": "cn-beijing", - "DEFAULT_WORKSPACE_ID": "ws-xxxxxxxx" + "WORKSPACE_REF": "ws-xxxxxxxx", + "FEATURES": "database,functions" } } } } ``` -### 使用 `uvx` 配置 MCP Client +### AI 工具使用 `uvx` 接入 ```json { @@ -123,7 +168,8 @@ uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-s "VOLCENGINE_ACCESS_KEY": "", "VOLCENGINE_SECRET_KEY": "", "VOLCENGINE_REGION": "cn-beijing", - "DEFAULT_WORKSPACE_ID": "ws-xxxxxxxx" + "WORKSPACE_REF": "ws-xxxxxxxx", + "FEATURES": "database,functions" } } } @@ -134,25 +180,55 @@ uv --directory /ABSOLUTE/PATH/TO/mcp-server/server/mcp_server_supabase run mcp-s ```bash python3 -m mcp_server_supabase.server --port 8000 +python3 -m mcp_server_supabase.server --transport sse --host 0.0.0.0 --port 8000 ``` -这个包同时暴露了 `mcp-server-supabase` 和 `supabase-aidap` 两个入口,示例统一使用 `mcp-server-supabase`。 +这个包同时暴露了 `mcp-server-supabase`、`mcp-server-supabase-sse` 和 `mcp-server-supabase-streamable` 三个入口,示例统一使用 `mcp-server-supabase`。 ## 使用说明 -- 如果没有显式传入 `workspace_id`,且配置了 `DEFAULT_WORKSPACE_ID`,服务会自动使用这个默认目标。 -- 如果传入的是 `br-xxxx` 这样的 branch ID,服务会自动解析所属 workspace。 +- `WORKSPACE_REF` 会把服务实例 hard-scope 到单个目标,并在 tool schema 中移除 `workspace_id`。 +- `WORKSPACE_REF` 生效时,`account` 组工具不会暴露,且显式传入其他 `workspace_id` 会被拒绝。 +- `FEATURES` 只接受官方 8 个分组:`account`、`docs`、`database`、`debugging`、`development`、`functions`、`storage`、`branching`。 +- 如果没有设置 `FEATURES`,默认启用 `account`、`database`、`debugging`、`development`、`docs`、`functions`、`branching`,`storage` 默认关闭。 +- `READ_ONLY=true` 会让整个服务实例进入只读模式,并隐藏所有写工具。 +- `DISABLED_TOOLS` 填工具名,例如 `execute_sql,deploy_edge_function`,会在其他策略计算完成后做最终剔除。 +- `workspace_id` 和 `workspace_ref` 只接受 workspace ID,`br-xxxx` 这样的 branch ID 会被直接拒绝。 - `get_publishable_keys` 在需要时会自动解析默认分支。 -- `reset_branch` 虽然接收 `migration_version` 参数,但当前 AIDAP API 会忽略这个值,只执行分支重置。 +- `restore_branch` 支持可选的 `time` 和 `source_branch_id` 参数,并返回 `backup_branch_id`。 - `deploy_edge_function` 当前支持 `native-node20/v1`、`native-python3.9/v1`、`native-python3.10/v1`、`native-python3.12/v1`。 +- `--transport sse` 会在 `MCP_SSE_PATH` 暴露 SSE 连接地址,并在 `MCP_MESSAGE_PATH` 暴露消息投递地址。 +- `--transport streamable-http` 会在 `STREAMABLE_HTTP_PATH` 暴露 MCP HTTP 地址。 +- 远程部署通常更推荐 `streamable-http`,但为了兼容仍保留 `sse`。 + +## 配置优先级 + +### 启动时的工具过滤顺序 + +1. `features` 先决定基础工具集合 +2. `workspace_ref` 再移除 `account` 工具,并把服务限制到单个 workspace +3. `read_only` 再移除所有写工具 +4. `disabled_tools` 最后按工具名做剔除 + +## 接入方式 + +### AI 工具 + +适用于 Cursor、Claude Desktop、Cline、Trae 等带 MCP 配置界面的 AI 工具,也适用于其他支持 `stdio`、`sse` 或 `streamable-http` 的 MCP Client。 + +- 本地集成通常使用 `stdio` +- 直接在客户端配置 `command`、`args` 和 `env` +- 本地源码接入通常通过 `env` 传静态 AK/SK +- 上面的两个 `mcpServers` JSON 示例就是这类接入方式 + +### 自研 AI Agent -## 可适配客户端 +如果你的 Agent Runtime 可以直接拉起本地 MCP 进程,可以继续使用 `stdio`。如果你的 Agent 部署在服务端、容器或多实例环境,更推荐用 `streamable-http` 或 `sse` 暴露远程地址再接入。 -- Cursor -- Claude Desktop -- Cline -- Trae -- 所有支持 `stdio` 的 MCP Client +- `stdio`:Agent 进程直接拉起 `mcp-server-supabase` +- `streamable-http`:连接 `http://:/mcp` +- `sse`:连接 `http://:/sse`,并向 `http://:/messages/` 投递消息 +- 工具可见性和 workspace scope 在服务启动时通过环境变量或 CLI 参数固定下来 ## License diff --git a/server/mcp_server_supabase/pyproject.toml b/server/mcp_server_supabase/pyproject.toml index b1348b7c..746a3170 100644 --- a/server/mcp_server_supabase/pyproject.toml +++ b/server/mcp_server_supabase/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "mcp-server-supabase" version = "0.1.0" -description = "MCP server for Supabase/AIDAP" +description = "MCP server for Volcengine Supabase" readme = "README.md" requires-python = ">=3.10" license = { text = "Apache-2.0" } @@ -28,7 +28,8 @@ legacy = [ [project.scripts] mcp-server-supabase = "mcp_server_supabase.server:main" -supabase-aidap = "mcp_server_supabase.server:main" +mcp-server-supabase-sse = "mcp_server_supabase.sse:main" +mcp-server-supabase-streamable = "mcp_server_supabase.streamable_http:main" [build-system] requires = ["hatchling"] diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/access_policy.py b/server/mcp_server_supabase/src/mcp_server_supabase/access_policy.py new file mode 100644 index 00000000..978faa16 --- /dev/null +++ b/server/mcp_server_supabase/src/mcp_server_supabase/access_policy.py @@ -0,0 +1,168 @@ +import json +from dataclasses import dataclass +from typing import Any + +from .tool_registry import TOOL_DEFINITIONS + +OFFICIAL_FEATURE_GROUPS = ( + "account", + "docs", + "database", + "debugging", + "development", + "functions", + "storage", + "branching", +) +DEFAULT_FEATURE_GROUPS = frozenset({ + "account", + "database", + "debugging", + "development", + "docs", + "functions", + "branching", +}) + +ALL_TOOL_NAMES = frozenset(tool.name for tool in TOOL_DEFINITIONS) +FEATURE_TOOLS = { + feature: frozenset(tool.name for tool in TOOL_DEFINITIONS if tool.feature == feature) + for feature in OFFICIAL_FEATURE_GROUPS +} +SCOPED_TOOL_NAMES = frozenset(tool.name for tool in TOOL_DEFINITIONS if tool.scoped) +MUTATING_TOOL_NAMES = frozenset(tool.name for tool in TOOL_DEFINITIONS if tool.mutating) + + +@dataclass(frozen=True) +class AccessPolicy: + workspace_ref: str | None = None + features: frozenset[str] = DEFAULT_FEATURE_GROUPS + read_only: bool = False + disabled_tools: frozenset[str] = frozenset() + + +def _normalize_name(value: Any) -> str: + if not isinstance(value, str): + raise ValueError("Expected string value") + normalized = value.strip() + if not normalized: + raise ValueError("Value cannot be empty") + return normalized + + +def _expand_names(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + text = value.strip() + if not text: + return [] + if text.startswith("["): + parsed = json.loads(text) + if not isinstance(parsed, list): + raise ValueError("Expected a JSON array") + return [_normalize_name(item) for item in parsed] + return [_normalize_name(item) for item in text.split(",") if item.strip()] + if isinstance(value, (list, tuple, set, frozenset)): + return [_normalize_name(item) for item in value] + raise ValueError("Unsupported value type") + + +def _parse_name_set(value: Any) -> frozenset[str] | None: + names = _expand_names(value) + if not names: + return None + return frozenset(names) + + +def _parse_workspace_ref(value: Any) -> str | None: + if value is None: + return None + if not isinstance(value, str): + raise ValueError("workspace_ref must be a string") + normalized = value.strip() + if not normalized: + return None + if normalized.startswith("br-"): + raise ValueError("workspace_ref must be a workspace ID; branch IDs are not supported") + return normalized + + +def _parse_read_only(value: Any) -> bool | None: + if value is None: + return None + if isinstance(value, bool): + return value + if not isinstance(value, str): + raise ValueError("read_only must be a boolean") + normalized = value.strip().lower() + if not normalized: + return None + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + raise ValueError("read_only must be true or false") + + +def _validate_features(features: frozenset[str] | None) -> frozenset[str] | None: + if features is None: + return None + invalid = sorted(features - set(OFFICIAL_FEATURE_GROUPS)) + if invalid: + raise ValueError(f"Unsupported features: {', '.join(invalid)}") + return features + + +def _validate_tools(tools: frozenset[str] | None, field_name: str) -> frozenset[str] | None: + if tools is None: + return None + invalid = sorted(tools - ALL_TOOL_NAMES) + if invalid: + raise ValueError(f"Unsupported {field_name}: {', '.join(invalid)}") + return tools + + +def build_access_policy( + workspace_ref: Any = None, + features: Any = None, + read_only: Any = None, + disabled_tools: Any = None, +) -> AccessPolicy: + return AccessPolicy( + workspace_ref=_parse_workspace_ref(workspace_ref), + features=_validate_features(_parse_name_set(features)) or DEFAULT_FEATURE_GROUPS, + read_only=bool(_parse_read_only(read_only)), + disabled_tools=_validate_tools(_parse_name_set(disabled_tools), "disabled_tools") or frozenset(), + ) + + +def resolve_allowed_tools(policy: AccessPolicy) -> frozenset[str]: + allowed = frozenset().union(*(FEATURE_TOOLS[feature] for feature in policy.features)) + if policy.workspace_ref: + allowed -= FEATURE_TOOLS["account"] + if policy.read_only: + allowed -= MUTATING_TOOL_NAMES + allowed -= policy.disabled_tools + return allowed + + +def workspace_scope_schema(tool_name: str, input_schema: dict[str, Any], workspace_ref: str | None) -> dict[str, Any]: + if tool_name not in SCOPED_TOOL_NAMES: + return input_schema + result = dict(input_schema) + properties = dict(input_schema.get("properties", {})) + result["properties"] = properties + required = [name for name in result.get("required", []) if name != "workspace_id"] + if workspace_ref: + properties.pop("workspace_id", None) + if required: + result["required"] = required + elif "required" in result: + result.pop("required", None) + return result + if "workspace_id" in properties and "workspace_id" not in required: + required.append("workspace_id") + if required: + result["required"] = required + return result diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/config.py b/server/mcp_server_supabase/src/mcp_server_supabase/config.py index c9520ea3..c7f3532c 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/config.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/config.py @@ -1,87 +1,3 @@ import os -import logging -logger = logging.getLogger(__name__) - -READ_ONLY = os.getenv("READ_ONLY", "false").lower() == "true" - -VOLCENGINE_ACCESS_KEY = os.getenv("VOLCENGINE_ACCESS_KEY") -VOLCENGINE_SECRET_KEY = os.getenv("VOLCENGINE_SECRET_KEY") VOLCENGINE_REGION = os.getenv("VOLCENGINE_REGION", "cn-beijing") - -# 验证必需的环境变量 -if not VOLCENGINE_ACCESS_KEY: - logger.warning("VOLCENGINE_ACCESS_KEY not set") -if not VOLCENGINE_SECRET_KEY: - logger.warning("VOLCENGINE_SECRET_KEY not set") - -_default_branch_cache = {} -_endpoint_cache = {} -_api_key_cache = {} -_branch_workspace_cache = {} - - -def get_branch_cache(): - return _default_branch_cache - - -def get_endpoint_cache(): - return _endpoint_cache - - -def get_api_key_cache(): - return _api_key_cache - - -def get_branch_workspace_cache(): - return _branch_workspace_cache - - -def clear_branch_cache(workspace_id: str = None): - if workspace_id: - _default_branch_cache.pop(workspace_id, None) - else: - _default_branch_cache.clear() - - -def clear_endpoint_cache(workspace_id: str = None, branch_id: str = None): - if workspace_id and branch_id: - _endpoint_cache.pop(f"{workspace_id}:{branch_id}", None) - elif workspace_id: - _endpoint_cache.pop(workspace_id, None) - keys_to_delete = [key for key in _endpoint_cache if key.startswith(f"{workspace_id}:")] - for key in keys_to_delete: - _endpoint_cache.pop(key, None) - else: - _endpoint_cache.clear() - - -def clear_api_key_cache(workspace_id: str = None, branch_id: str = None): - if workspace_id and branch_id: - keys_to_delete = [key for key in _api_key_cache if key.startswith(f"{workspace_id}:") and key.endswith(f":{branch_id}")] - for key in keys_to_delete: - _api_key_cache.pop(key, None) - elif workspace_id: - keys_to_delete = [key for key in _api_key_cache if key == workspace_id or key.startswith(f"{workspace_id}:")] - for key in keys_to_delete: - _api_key_cache.pop(key, None) - else: - _api_key_cache.clear() - - -def clear_branch_workspace_cache(workspace_id: str = None, branch_id: str = None): - if branch_id: - _branch_workspace_cache.pop(branch_id, None) - elif workspace_id: - branch_ids = [key for key, value in _branch_workspace_cache.items() if value == workspace_id] - for key in branch_ids: - _branch_workspace_cache.pop(key, None) - else: - _branch_workspace_cache.clear() - - -def clear_all_caches(workspace_id: str = None, branch_id: str = None): - clear_branch_cache(workspace_id) - clear_endpoint_cache(workspace_id, branch_id) - clear_api_key_cache(workspace_id, branch_id) - clear_branch_workspace_cache(workspace_id, branch_id) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/credentials.py b/server/mcp_server_supabase/src/mcp_server_supabase/credentials.py new file mode 100644 index 00000000..17473e43 --- /dev/null +++ b/server/mcp_server_supabase/src/mcp_server_supabase/credentials.py @@ -0,0 +1,123 @@ +import base64 +import json +import os +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Callable + + +VEFAAS_IAM_CREDENTIAL_PATH = "/var/run/secrets/iam/credential" +AUTHORIZATION_ENV_NAMES = ("authorization", "AUTHORIZATION") +STATIC_ACCESS_KEY_ENV_NAMES = ("VOLCENGINE_ACCESS_KEY", "VOLC_ACCESSKEY") +STATIC_SECRET_KEY_ENV_NAMES = ("VOLCENGINE_SECRET_KEY", "VOLC_SECRETKEY") +STATIC_SESSION_TOKEN_ENV_NAMES = ("VOLCENGINE_SESSION_TOKEN",) + + +@dataclass(frozen=True, slots=True) +class VolcengineCredentials: + access_key: str + secret_key: str + session_token: str + + +def _get_env_value(*names: str) -> str: + for name in names: + value = os.getenv(name) + if value: + return value + return "" + + +def _normalize_iso8601(value: str) -> str: + return value.replace("Z", "+00:00") if value.endswith("Z") else value + + +def _validate_sts_time_window(payload: dict[str, Any]) -> None: + current_time = payload.get("CurrentTime") + expired_time = payload.get("ExpiredTime") + if not current_time or not expired_time: + return + current_dt = datetime.fromisoformat(_normalize_iso8601(str(current_time))) + expired_dt = datetime.fromisoformat(_normalize_iso8601(str(expired_time))) + if current_dt > expired_dt: + raise ValueError("STS token is expired") + + +def _parse_authorization_payload(raw_value: str) -> VolcengineCredentials: + token = raw_value.split(" ", 1)[1] if " " in raw_value else raw_value + decoded_bytes = base64.b64decode(token) + payload = json.loads(decoded_bytes.decode("utf-8")) + _validate_sts_time_window(payload) + access_key = str(payload.get("AccessKeyId") or "").strip() + secret_key = str(payload.get("SecretAccessKey") or "").strip() + session_token = str(payload.get("SessionToken") or "").strip() + if not access_key or not secret_key: + raise ValueError("AccessKeyId or SecretAccessKey missing in authorization payload") + return VolcengineCredentials( + access_key=access_key, + secret_key=secret_key, + session_token=session_token, + ) + + +def _get_request_authorization(context_getter: Callable[[], Any] | None) -> str: + if context_getter is None: + return "" + try: + context = context_getter() + except Exception: + return "" + request_context = getattr(context, "request_context", None) + if request_context is None: + request_context = getattr(context, "_request_context", None) + request = getattr(request_context, "request", None) + if request is None: + return "" + return str(request.headers.get("authorization") or "").strip() + + +def _get_vefaas_iam_credentials() -> VolcengineCredentials | None: + path = Path(VEFAAS_IAM_CREDENTIAL_PATH) + if not path.exists(): + return None + payload = json.loads(path.read_text()) + access_key = str(payload.get("access_key_id") or "").strip() + secret_key = str(payload.get("secret_access_key") or "").strip() + session_token = str(payload.get("session_token") or "").strip() + if not access_key or not secret_key: + return None + return VolcengineCredentials( + access_key=access_key, + secret_key=secret_key, + session_token=session_token, + ) + + +def resolve_volcengine_credentials(context_getter: Callable[[], Any] | None = None) -> VolcengineCredentials: + static_access_key = _get_env_value(*STATIC_ACCESS_KEY_ENV_NAMES) + static_secret_key = _get_env_value(*STATIC_SECRET_KEY_ENV_NAMES) + static_session_token = _get_env_value(*STATIC_SESSION_TOKEN_ENV_NAMES) + if static_access_key and static_secret_key: + return VolcengineCredentials( + access_key=static_access_key, + secret_key=static_secret_key, + session_token=static_session_token, + ) + + request_authorization = _get_request_authorization(context_getter) + if request_authorization: + return _parse_authorization_payload(request_authorization) + + env_authorization = _get_env_value(*AUTHORIZATION_ENV_NAMES) + if env_authorization: + return _parse_authorization_payload(env_authorization) + + vefaas_credentials = _get_vefaas_iam_credentials() + if vefaas_credentials is not None: + return vefaas_credentials + + raise ValueError( + "Volcengine credentials are not configured. " + "Set VOLCENGINE_ACCESS_KEY/VOLCENGINE_SECRET_KEY, provide authorization, or mount VeFaaS IAM credentials." + ) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/platform/aidap_client.py b/server/mcp_server_supabase/src/mcp_server_supabase/platform/aidap_client.py index de038c82..5beae88c 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/platform/aidap_client.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/platform/aidap_client.py @@ -2,17 +2,10 @@ import asyncio import os import random +from collections.abc import Callable from typing import Any, Optional -from ..config import ( - VOLCENGINE_ACCESS_KEY, - VOLCENGINE_SECRET_KEY, - VOLCENGINE_REGION, - get_branch_cache, - get_branch_workspace_cache, - get_endpoint_cache, - get_api_key_cache, - clear_all_caches, -) +from ..config import VOLCENGINE_REGION +from ..credentials import resolve_volcengine_credentials from ..utils import pick_value logger = logging.getLogger(__name__) @@ -23,10 +16,10 @@ from volcenginesdkaidap import AIDAPApi from volcenginesdkaidap.models import ( DescribeBranchesRequest, - DescribeWorkspacesRequest, DescribeWorkspaceEndpointRequest, DescribeAPIKeysRequest, - ResetBranchRequest, + BranchRestoreRequest, + RestoreSettingsForBranchRestoreInput, CreateBranchRequest, DeleteBranchRequest, BranchSettingsForCreateBranchInput, @@ -43,14 +36,26 @@ class AidapClient: - def __init__(self) -> None: + def __init__(self, context_getter: Callable[[], Any] | None = None) -> None: + self._context_getter = context_getter + + def _get_credentials(self): + return resolve_volcengine_credentials(self._context_getter) + + def _create_client(self) -> AIDAPApi: + credentials = self._get_credentials() configuration = volcenginesdkcore.Configuration() - configuration.ak = VOLCENGINE_ACCESS_KEY - configuration.sk = VOLCENGINE_SECRET_KEY + configuration.ak = credentials.access_key + configuration.sk = credentials.secret_key configuration.region = VOLCENGINE_REGION - + if credentials.session_token: + configuration.session_token = credentials.session_token api_client = volcenginesdkcore.ApiClient(configuration) - self.client = AIDAPApi(api_client) + return AIDAPApi(api_client) + + @property + def client(self) -> AIDAPApi: + return self._create_client() def _branch_error_code(self, error_text: str) -> str: if "OperationDenied_BranchNotReady" in error_text: @@ -64,21 +69,6 @@ def _branch_error_code(self, error_text: str) -> str: def _pick_value(self, source: Any, *field_names: str) -> Any: return pick_value(source, *field_names) - def _looks_like_branch_id(self, value: Optional[str]) -> bool: - return bool(value and value.strip().startswith("br-")) - - def _cache_branch_workspace(self, workspace_id: Optional[str], branch_id: Optional[str]) -> None: - if workspace_id and branch_id: - get_branch_workspace_cache()[branch_id] = workspace_id - - def _workspace_ids_from_response(self, response: Any) -> list[str]: - workspace_ids = [] - for workspace in list(getattr(response, "workspaces", []) or []): - workspace_id = self._pick_value(workspace, "workspace_id") - if workspace_id: - workspace_ids.append(workspace_id) - return workspace_ids - def _branch_payload(self, branch: Any, fallback_name: Optional[str] = None) -> dict: parent_branch = self._pick_value(branch, "parent_branch") parent_id = self._pick_value(parent_branch, "branch_id", "parent_id") @@ -94,13 +84,7 @@ def _branch_payload(self, branch: Any, fallback_name: Optional[str] = None) -> d "created_at": self._pick_value(branch, "create_time", "created_at"), "updated_at": self._pick_value(branch, "update_time", "updated_at"), } - result = {key: value for key, value in payload.items() if value is not None} - self._cache_branch_workspace(result.get("workspace_id"), result.get("branch_id")) - return result - - def _describe_supabase_workspaces_response(self): - request = DescribeWorkspacesRequest() - return self.client.describe_workspaces(request) + return {key: value for key, value in payload.items() if value is not None} async def _find_branch( self, @@ -120,27 +104,6 @@ async def _find_branch( await self._sleep_backoff(attempt, base_seconds=0.5, max_seconds=3.0) return None - async def _find_workspace_id_for_branch(self, branch_id: str) -> Optional[str]: - cached_workspace_id = get_branch_workspace_cache().get(branch_id) - if cached_workspace_id: - return cached_workspace_id - response = self._describe_supabase_workspaces_response() - for workspace_id in self._workspace_ids_from_response(response): - branch = await self._find_branch(workspace_id, branch_id=branch_id, max_attempts=1) - if branch: - self._cache_branch_workspace(workspace_id, branch_id) - return workspace_id - return None - - async def resolve_workspace_and_branch(self, workspace_or_branch_id: str) -> tuple[str, Optional[str]]: - normalized_id = workspace_or_branch_id.strip() - if not self._looks_like_branch_id(normalized_id): - return normalized_id, None - workspace_id = await self._find_workspace_id_for_branch(normalized_id) - if not workspace_id: - raise ValueError(f"Could not resolve workspace for branch {normalized_id}") - return workspace_id, normalized_id - async def get_branch(self, workspace_id: str, branch_id: str) -> Optional[dict]: return await self._find_branch(workspace_id, branch_id=branch_id, max_attempts=1) @@ -154,11 +117,7 @@ async def _sleep_backoff( jitter = random.uniform(0.0, delay * 0.2) await asyncio.sleep(delay + jitter) - async def get_default_branch_id(self, workspace_id: str, use_cache: bool = True) -> Optional[str]: - cache = get_branch_cache() - if use_cache and workspace_id in cache: - return cache[workspace_id] - + async def get_default_branch_id(self, workspace_id: str) -> Optional[str]: try: request = DescribeBranchesRequest(workspace_id=workspace_id) response = self.client.describe_branches(request) @@ -166,16 +125,10 @@ async def get_default_branch_id(self, workspace_id: str, use_cache: bool = True) if hasattr(response, 'branches') and response.branches: for branch in response.branches: if getattr(branch, 'default', False): - branch_id = branch.branch_id - cache[workspace_id] = branch_id - self._cache_branch_workspace(workspace_id, branch_id) - return branch_id + return branch.branch_id first_branch = response.branches[0] - branch_id = first_branch.branch_id - cache[workspace_id] = branch_id - self._cache_branch_workspace(workspace_id, branch_id) - return branch_id + return first_branch.branch_id return None except Exception as e: @@ -300,7 +253,6 @@ async def delete_branch(self, workspace_id: str, branch_id: str) -> dict: branch_id=branch_id, ) self.client.delete_branch(request) - clear_all_caches(workspace_id, branch_id) return {"success": True} except Exception as e: error_text = str(e) @@ -323,14 +275,7 @@ async def delete_branch(self, workspace_id: str, branch_id: str) -> dict: "retriable": True, } - async def get_endpoint(self, workspace_id: str, branch_id: Optional[str] = None, use_cache: bool = True) -> Optional[str]: - # 检查缓存 - cache_key = f"{workspace_id}:{branch_id}" if branch_id else workspace_id - endpoint_cache = get_endpoint_cache() - - if use_cache and cache_key in endpoint_cache: - return endpoint_cache[cache_key] - + async def get_endpoint(self, workspace_id: str, branch_id: Optional[str] = None) -> Optional[str]: if not branch_id: branch_id = await self.get_default_branch_id(workspace_id) if not branch_id: @@ -353,36 +298,43 @@ async def get_endpoint(self, workspace_id: str, branch_id: Optional[str] = None, for domain in domains: if 'volces.com' in domain and 'ivolces.com' not in domain: - if ENDPOINT_SCHEME == "https": - result = f"https://{domain}" - else: - result = f"http://{domain}:80" - endpoint_cache[cache_key] = result - return result + return f"https://{domain}" if ENDPOINT_SCHEME == "https" else f"http://{domain}:80" if domains: - if ENDPOINT_SCHEME == "https": - result = f"https://{domains[0]}" - else: - result = f"http://{domains[0]}:80" - endpoint_cache[cache_key] = result - return result + return f"https://{domains[0]}" if ENDPOINT_SCHEME == "https" else f"http://{domains[0]}:80" return None except Exception as e: logger.error(f"Error getting endpoint: {e}") return None - async def reset_branch(self, workspace_id: str, branch_id: str) -> dict: + async def restore_branch( + self, + workspace_id: str, + branch_id: str, + source_branch_id: Optional[str] = None, + time: Optional[str] = None, + ) -> dict: max_attempts = 8 for attempt in range(1, max_attempts + 1): try: - request = ResetBranchRequest( + request = BranchRestoreRequest( workspace_id=workspace_id, branch_id=branch_id, + restore_settings=RestoreSettingsForBranchRestoreInput( + source_branch_id=source_branch_id or branch_id, + time=time, + ), ) - self.client.reset_branch(request) - return {"success": True} + response = self.client.branch_restore(request) + return { + "success": True, + "workspace_id": workspace_id, + "branch_id": branch_id, + "source_branch_id": source_branch_id or branch_id, + "time": time, + "backup_branch_id": self._pick_value(response, "backup_branch_id", "BackupBranchID"), + } except Exception as e: error_text = str(e) code = self._branch_error_code(error_text) @@ -390,7 +342,7 @@ async def reset_branch(self, workspace_id: str, branch_id: str) -> dict: if retriable and attempt < max_attempts: await self._sleep_backoff(attempt) continue - logger.error(f"Error resetting branch: {e}") + logger.error(f"Error restoring branch: {e}") return { "success": False, "error": error_text, @@ -399,20 +351,13 @@ async def reset_branch(self, workspace_id: str, branch_id: str) -> dict: } return { "success": False, - "error": "reset_branch failed after retries", + "error": "restore_branch failed after retries", "code": "OperationDenied_BranchNotReady", "retriable": True, } async def get_api_key(self, workspace_id: str, key_type: str = "service_role", - branch_id: Optional[str] = None, use_cache: bool = True) -> Optional[str]: - # 检查缓存 - cache_key = f"{workspace_id}:{key_type}:{branch_id}" if branch_id else f"{workspace_id}:{key_type}" - api_key_cache = get_api_key_cache() - - if use_cache and cache_key in api_key_cache: - return api_key_cache[cache_key] - + branch_id: Optional[str] = None) -> Optional[str]: if not branch_id: branch_id = await self.get_default_branch_id(workspace_id) if not branch_id: @@ -434,10 +379,7 @@ async def get_api_key(self, workspace_id: str, key_type: str = "service_role", for key in response.api_keys: if hasattr(key, 'type') and key.type == target_type: - result = key.key if hasattr(key, 'key') else None - if result: - api_key_cache[cache_key] = result - return result + return key.key if hasattr(key, 'key') else None return None except Exception as e: diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/platform/supabase_client.py b/server/mcp_server_supabase/src/mcp_server_supabase/platform/supabase_client.py index f28387a5..6efe7d02 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/platform/supabase_client.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/platform/supabase_client.py @@ -2,7 +2,7 @@ import httpx import logging import json -from typing import Optional, Dict, Any +from typing import Dict, Any, Optional logger = logging.getLogger(__name__) @@ -30,21 +30,6 @@ class SupabaseClient: def __init__(self, endpoint: str, api_key: str): self.endpoint = endpoint self.api_key = api_key - self._client: Optional[httpx.AsyncClient] = None - - async def _get_client(self) -> httpx.AsyncClient: - """Get or create HTTP client with connection pooling""" - if self._client is None or self._client.is_closed: - self._client = httpx.AsyncClient( - timeout=30.0, - limits=httpx.Limits(max_keepalive_connections=5, max_connections=10) - ) - return self._client - - async def close(self): - """Close HTTP client""" - if self._client and not self._client.is_closed: - await self._client.aclose() async def call_api( self, @@ -57,7 +42,7 @@ async def call_api( timeout: float = 30.0 ) -> Any: url = f"{self.endpoint}{path}" - logger.info(f"[DEBUG] Calling API: method={method}, url={url}, path={path}") + logger.debug("Calling API method=%s url=%s path=%s", method, url, path) default_headers = { "apikey": self.api_key, @@ -66,19 +51,22 @@ async def call_api( if headers: default_headers.update(headers) - client = await self._get_client() for attempt in range(3): try: - if content: - response = await client.request( - method, url, content=content, headers=default_headers, - params=params, timeout=timeout - ) - else: - response = await client.request( - method, url, json=json_data, headers=default_headers, - params=params, timeout=timeout - ) + async with httpx.AsyncClient( + timeout=timeout, + limits=httpx.Limits(max_keepalive_connections=5, max_connections=10), + ) as client: + if content: + response = await client.request( + method, url, content=content, headers=default_headers, + params=params, timeout=timeout + ) + else: + response = await client.request( + method, url, json=json_data, headers=default_headers, + params=params, timeout=timeout + ) response.raise_for_status() if response.status_code == 204 or not response.content: diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/runtime.py b/server/mcp_server_supabase/src/mcp_server_supabase/runtime.py index d4f103c4..ed414bab 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/runtime.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/runtime.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import Any, Callable from .platform import AidapClient from .tools import DatabaseTools, EdgeFunctionTools, StorageTools, WorkspaceTools @@ -8,7 +8,6 @@ @dataclass(slots=True) class SupabaseRuntime: aidap_client: AidapClient - default_workspace_id: Optional[str] edge_tools: EdgeFunctionTools storage_tools: StorageTools database_tools: DatabaseTools @@ -16,15 +15,14 @@ class SupabaseRuntime: def create_runtime( - default_workspace_id: Optional[str] = None, - aidap_client: Optional[AidapClient] = None, + aidap_client: AidapClient | None = None, + context_getter: Callable[[], Any] | None = None, ) -> SupabaseRuntime: - client = aidap_client or AidapClient() + client = aidap_client or AidapClient(context_getter=context_getter) return SupabaseRuntime( aidap_client=client, - default_workspace_id=default_workspace_id, - edge_tools=EdgeFunctionTools(client, default_workspace_id), - storage_tools=StorageTools(client, default_workspace_id), - database_tools=DatabaseTools(client, default_workspace_id), - workspace_tools=WorkspaceTools(client, default_workspace_id), + edge_tools=EdgeFunctionTools(client), + storage_tools=StorageTools(client), + database_tools=DatabaseTools(client), + workspace_tools=WorkspaceTools(client), ) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/scoped_mcp.py b/server/mcp_server_supabase/src/mcp_server_supabase/scoped_mcp.py new file mode 100644 index 00000000..f15b165a --- /dev/null +++ b/server/mcp_server_supabase/src/mcp_server_supabase/scoped_mcp.py @@ -0,0 +1,45 @@ +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.exceptions import ToolError +from mcp.types import Tool as MCPTool + +from .access_policy import ( + AccessPolicy, + SCOPED_TOOL_NAMES, + resolve_allowed_tools, + workspace_scope_schema, +) + + +class ScopedFastMCP(FastMCP): + def __init__(self, *args, access_policy: AccessPolicy | None = None, **kwargs): + super().__init__(*args, **kwargs) + self._access_policy = access_policy or AccessPolicy() + self._allowed_tools = resolve_allowed_tools(self._access_policy) + + async def list_tools(self): + tools = await super().list_tools() + visible_tools = [] + for tool in tools: + if tool.name not in self._allowed_tools: + continue + scoped_schema = workspace_scope_schema(tool.name, tool.inputSchema, self._access_policy.workspace_ref) + if scoped_schema is tool.inputSchema: + visible_tools.append(tool) + continue + payload = tool.model_dump(exclude_none=False) + payload["inputSchema"] = scoped_schema + visible_tools.append(MCPTool(**payload)) + return visible_tools + + async def call_tool(self, name: str, arguments: dict[str, object]): + if name not in self._allowed_tools: + raise ToolError(f"Tool '{name}' is not available for the current connection") + + effective_arguments = dict(arguments or {}) + if self._access_policy.workspace_ref and name in SCOPED_TOOL_NAMES: + provided_workspace_id = effective_arguments.get("workspace_id") + if provided_workspace_id not in {None, "", self._access_policy.workspace_ref}: + raise ToolError("workspace_id is outside the current workspace_ref scope") + effective_arguments["workspace_id"] = self._access_policy.workspace_ref + + return await super().call_tool(name, effective_arguments) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/server.py b/server/mcp_server_supabase/src/mcp_server_supabase/server.py index 554b2b82..38422bc5 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/server.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/server.py @@ -1,12 +1,12 @@ import argparse import logging import os +from dataclasses import dataclass -from mcp.server.fastmcp import FastMCP - -from .config import READ_ONLY from .runtime import create_runtime from .tool_registry import register_tools +from .access_policy import AccessPolicy, build_access_policy +from .scoped_mcp import ScopedFastMCP logger = logging.getLogger(__name__) logging.basicConfig( @@ -14,35 +14,149 @@ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) -default_workspace_id = os.getenv("DEFAULT_WORKSPACE_ID") +DEFAULT_HOST = "0.0.0.0" +DEFAULT_PORT = 8000 +DEFAULT_MOUNT_PATH = "/" +DEFAULT_SSE_PATH = "/sse" +DEFAULT_MESSAGE_PATH = "/messages/" +DEFAULT_STREAMABLE_HTTP_PATH = "/mcp" + + +@dataclass(frozen=True, slots=True) +class ServerConfig: + host: str + port: int + access_policy: AccessPolicy + mount_path: str + sse_path: str + message_path: str + streamable_http_path: str + + +def _resolve_string(value: str | None, env_name: str, default: str | None = None) -> str | None: + if value is not None: + return value + if default is None: + return os.getenv(env_name) + return os.getenv(env_name, default) + +def _resolve_read_only(read_only: str | bool | None) -> str | bool | None: + if read_only is not None: + return read_only + return os.getenv("READ_ONLY") -def create_mcp( + +def build_server_config( port: int | None = None, - default_target_id: str | None = None, -) -> FastMCP: - resolved_port = port if port is not None else int(os.getenv("PORT", "8000")) - resolved_default_target_id = default_target_id if default_target_id is not None else default_workspace_id - runtime = create_runtime(resolved_default_target_id) - mcp = FastMCP("Supabase MCP Server (AIDAP)", port=resolved_port) + host: str | None = None, + workspace_ref: str | None = None, + features: str | None = None, + read_only: str | bool | None = None, + disabled_tools: str | None = None, + mount_path: str | None = None, + sse_path: str | None = None, + message_path: str | None = None, + streamable_http_path: str | None = None, +) -> ServerConfig: + resolved_port = port if port is not None else int(os.getenv("MCP_SERVER_PORT", os.getenv("PORT", str(DEFAULT_PORT)))) + resolved_host = _resolve_string(host, "MCP_SERVER_HOST", DEFAULT_HOST) or DEFAULT_HOST + return ServerConfig( + host=resolved_host, + port=resolved_port, + access_policy=build_access_policy( + workspace_ref=_resolve_string(workspace_ref, "WORKSPACE_REF"), + features=_resolve_string(features, "FEATURES"), + read_only=_resolve_read_only(read_only), + disabled_tools=_resolve_string(disabled_tools, "DISABLED_TOOLS"), + ), + mount_path=_resolve_string(mount_path, "MCP_MOUNT_PATH", DEFAULT_MOUNT_PATH) or DEFAULT_MOUNT_PATH, + sse_path=_resolve_string(sse_path, "MCP_SSE_PATH", DEFAULT_SSE_PATH) or DEFAULT_SSE_PATH, + message_path=_resolve_string(message_path, "MCP_MESSAGE_PATH", DEFAULT_MESSAGE_PATH) or DEFAULT_MESSAGE_PATH, + streamable_http_path=_resolve_string(streamable_http_path, "STREAMABLE_HTTP_PATH", DEFAULT_STREAMABLE_HTTP_PATH) or DEFAULT_STREAMABLE_HTTP_PATH, + ) + + +def create_mcp(config: ServerConfig) -> ScopedFastMCP: + mcp = ScopedFastMCP( + "Supabase MCP Server (Volcengine)", + access_policy=config.access_policy, + host=config.host, + port=config.port, + mount_path=config.mount_path, + sse_path=config.sse_path, + message_path=config.message_path, + streamable_http_path=config.streamable_http_path, + ) + runtime = create_runtime(context_getter=mcp.get_context) register_tools(mcp, runtime) return mcp -mcp = create_mcp() +def run_server( + transport: str = "stdio", + port: int | None = None, + host: str | None = None, + workspace_ref: str | None = None, + features: str | None = None, + read_only: str | bool | None = None, + disabled_tools: str | None = None, +) -> None: + config = build_server_config( + port=port, + host=host, + workspace_ref=workspace_ref, + features=features, + read_only=read_only, + disabled_tools=disabled_tools, + ) + create_mcp(config).run(transport=transport) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Supabase MCP Server") - parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") + parser.add_argument( + "--transport", + "-t", + choices=["sse", "stdio", "streamable-http"], + default="stdio", + help="Transport protocol to use", + ) + parser.add_argument("--host", type=str, default=None, help="Host to bind for network transports") + parser.add_argument("--port", type=int, default=None, help="Port to run the server on") + parser.add_argument("--workspace-ref", type=str, default=None, help="Hard-scope the server to a single workspace") + parser.add_argument("--features", type=str, default=None, help="Comma-separated official feature groups") + parser.add_argument("--read-only", nargs="?", const="true", default=None, help="Hide all mutating tools for the server") + parser.add_argument("--disabled-tools", type=str, default=None, help="Comma-separated blacklist of tool names") args = parser.parse_args() - logger.info(f"Starting Supabase MCP Server on port {args.port}") - logger.info(f"Read-only mode: {READ_ONLY}") - if default_workspace_id: - logger.info(f"Default workspace ID: {default_workspace_id}") + config = build_server_config( + port=args.port, + host=args.host, + workspace_ref=args.workspace_ref, + features=args.features, + read_only=args.read_only, + disabled_tools=args.disabled_tools, + ) + + logger.info("Starting Supabase MCP Server with %s transport", args.transport) + logger.info("Read-only mode: %s", config.access_policy.read_only) + if config.access_policy.workspace_ref: + logger.info("Workspace scope: %s", config.access_policy.workspace_ref) + logger.info("Feature groups: %s", ",".join(sorted(config.access_policy.features))) + if config.access_policy.disabled_tools: + logger.info("Disabled tools: %s", ",".join(sorted(config.access_policy.disabled_tools))) + if args.transport != "stdio": + logger.info( + "Server binding: host=%s port=%s sse_path=%s message_path=%s streamable_http_path=%s", + config.host, + config.port, + config.sse_path, + config.message_path, + config.streamable_http_path, + ) - create_mcp(port=args.port).run() + create_mcp(config).run(transport=args.transport) if __name__ == "__main__": diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/sse.py b/server/mcp_server_supabase/src/mcp_server_supabase/sse.py new file mode 100644 index 00000000..554e9ad6 --- /dev/null +++ b/server/mcp_server_supabase/src/mcp_server_supabase/sse.py @@ -0,0 +1,9 @@ +from .server import run_server + + +def main() -> None: + run_server(transport="sse") + + +if __name__ == "__main__": + main() diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/streamable_http.py b/server/mcp_server_supabase/src/mcp_server_supabase/streamable_http.py new file mode 100644 index 00000000..2b7d8ff4 --- /dev/null +++ b/server/mcp_server_supabase/src/mcp_server_supabase/streamable_http.py @@ -0,0 +1,9 @@ +from .server import run_server + + +def main() -> None: + run_server(transport="streamable-http") + + +if __name__ == "__main__": + main() diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/tool_registry.py b/server/mcp_server_supabase/src/mcp_server_supabase/tool_registry.py index 6df7b161..1ffb95e1 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/tool_registry.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/tool_registry.py @@ -1,29 +1,46 @@ +from dataclasses import dataclass +from typing import Awaitable, Callable + from mcp.server.fastmcp import FastMCP from .runtime import SupabaseRuntime -def register_tools(mcp: FastMCP, runtime: SupabaseRuntime) -> None: - _register_edge_tools(mcp, runtime) - _register_storage_tools(mcp, runtime) - _register_database_tools(mcp, runtime) - _register_workspace_tools(mcp, runtime) +ToolBuilder = Callable[[SupabaseRuntime], Callable[..., Awaitable[str]]] + +@dataclass(frozen=True) +class ToolDefinition: + name: str + feature: str + scoped: bool + mutating: bool + build: ToolBuilder -def _register_edge_tools(mcp: FastMCP, runtime: SupabaseRuntime) -> None: + +def _build_list_edge_functions(runtime: SupabaseRuntime): edge_tools = runtime.edge_tools - @mcp.tool() async def list_edge_functions(workspace_id: str = None) -> str: - """Lists all Edge Functions in a workspace or branch.""" + """Lists all Edge Functions in a workspace.""" return await edge_tools.list_edge_functions(workspace_id) - @mcp.tool() + return list_edge_functions + + +def _build_get_edge_function(runtime: SupabaseRuntime): + edge_tools = runtime.edge_tools + async def get_edge_function(function_name: str, workspace_id: str = None) -> str: """Retrieves the source code and configuration for an Edge Function.""" return await edge_tools.get_edge_function(function_name, workspace_id) - @mcp.tool() + return get_edge_function + + +def _build_deploy_edge_function(runtime: SupabaseRuntime): + edge_tools = runtime.edge_tools + async def deploy_edge_function( function_name: str, source_code: str, @@ -40,7 +57,7 @@ async def deploy_edge_function( verify_jwt: Whether to verify JWT tokens runtime: Runtime environment import_map: Optional import map JSON for dependencies - workspace_id: The workspace ID or branch ID + workspace_id: The workspace ID """ return await edge_tools.deploy_edge_function( function_name, @@ -51,21 +68,32 @@ async def deploy_edge_function( workspace_id, ) - @mcp.tool() + return deploy_edge_function + + +def _build_delete_edge_function(runtime: SupabaseRuntime): + edge_tools = runtime.edge_tools + async def delete_edge_function(function_name: str, workspace_id: str = None) -> str: """Deletes an Edge Function.""" return await edge_tools.delete_edge_function(function_name, workspace_id) + return delete_edge_function + -def _register_storage_tools(mcp: FastMCP, runtime: SupabaseRuntime) -> None: +def _build_list_storage_buckets(runtime: SupabaseRuntime): storage_tools = runtime.storage_tools - @mcp.tool() async def list_storage_buckets(workspace_id: str = None) -> str: - """Lists all storage buckets in a workspace or branch.""" + """Lists all storage buckets in a workspace.""" return await storage_tools.list_storage_buckets(workspace_id) - @mcp.tool() + return list_storage_buckets + + +def _build_create_storage_bucket(runtime: SupabaseRuntime): + storage_tools = runtime.storage_tools + async def create_storage_bucket( bucket_name: str, public: bool = False, @@ -82,67 +110,114 @@ async def create_storage_bucket( workspace_id, ) - @mcp.tool() + return create_storage_bucket + + +def _build_delete_storage_bucket(runtime: SupabaseRuntime): + storage_tools = runtime.storage_tools + async def delete_storage_bucket(bucket_name: str, workspace_id: str = None) -> str: """Deletes a storage bucket.""" return await storage_tools.delete_storage_bucket(bucket_name, workspace_id) - @mcp.tool() + return delete_storage_bucket + + +def _build_get_storage_config(runtime: SupabaseRuntime): + storage_tools = runtime.storage_tools + async def get_storage_config(workspace_id: str = None) -> str: - """Gets the storage configuration for a workspace or branch.""" + """Gets the storage configuration for a workspace.""" return await storage_tools.get_storage_config(workspace_id) + return get_storage_config + -def _register_database_tools(mcp: FastMCP, runtime: SupabaseRuntime) -> None: +def _build_execute_sql(runtime: SupabaseRuntime): database_tools = runtime.database_tools - @mcp.tool() async def execute_sql(query: str, workspace_id: str = None) -> str: """Executes raw SQL in the Postgres database.""" return await database_tools.execute_sql(query, workspace_id) - @mcp.tool() + return execute_sql + + +def _build_list_tables(runtime: SupabaseRuntime): + database_tools = runtime.database_tools + async def list_tables(schemas: str = "public", workspace_id: str = None) -> str: """Lists all tables in one or more schemas.""" schema_list = [schema.strip() for schema in schemas.split(",")] return await database_tools.list_tables(schema_list, workspace_id) - @mcp.tool() + return list_tables + + +def _build_list_migrations(runtime: SupabaseRuntime): + database_tools = runtime.database_tools + async def list_migrations(workspace_id: str = None) -> str: """Lists all migrations in the database.""" return await database_tools.list_migrations(workspace_id) - @mcp.tool() + return list_migrations + + +def _build_list_extensions(runtime: SupabaseRuntime): + database_tools = runtime.database_tools + async def list_extensions(workspace_id: str = None) -> str: """Lists all PostgreSQL extensions in the database.""" return await database_tools.list_extensions(workspace_id) - @mcp.tool() + return list_extensions + + +def _build_apply_migration(runtime: SupabaseRuntime): + database_tools = runtime.database_tools + async def apply_migration(name: str, query: str, workspace_id: str = None) -> str: """Applies a migration to the database.""" return await database_tools.apply_migration(name, query, workspace_id) - @mcp.tool() + return apply_migration + + +def _build_generate_typescript_types(runtime: SupabaseRuntime): + database_tools = runtime.database_tools + async def generate_typescript_types(schemas: str = "public", workspace_id: str = None) -> str: """Generates TypeScript definitions from database schema.""" schema_list = [schema.strip() for schema in schemas.split(",") if schema.strip()] return await database_tools.generate_typescript_types(schema_list, workspace_id) + return generate_typescript_types + -def _register_workspace_tools(mcp: FastMCP, runtime: SupabaseRuntime) -> None: +def _build_list_workspaces(runtime: SupabaseRuntime): workspace_tools = runtime.workspace_tools - @mcp.tool() async def list_workspaces() -> str: """Lists all available workspaces.""" return await workspace_tools.list_workspaces() - @mcp.tool() + return list_workspaces + + +def _build_get_workspace(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def get_workspace(workspace_id: str) -> str: - """Gets details for a specific workspace or branch target.""" + """Gets details for a specific workspace.""" return await workspace_tools.get_workspace(workspace_id) - @mcp.tool() + return get_workspace + + +def _build_create_workspace(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def create_workspace( workspace_name: str, engine_version: str = "Supabase_1_24", @@ -151,42 +226,123 @@ async def create_workspace( """Creates a new workspace.""" return await workspace_tools.create_workspace(workspace_name, engine_version, engine_type) - @mcp.tool() + return create_workspace + + +def _build_pause_workspace(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def pause_workspace(workspace_id: str = None) -> str: """Pauses a workspace.""" return await workspace_tools.pause_workspace(workspace_id) - @mcp.tool() + return pause_workspace + + +def _build_restore_workspace(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def restore_workspace(workspace_id: str = None) -> str: """Restores a workspace.""" return await workspace_tools.restore_workspace(workspace_id) - @mcp.tool() + return restore_workspace + + +def _build_get_workspace_url(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def get_workspace_url(workspace_id: str = None) -> str: - """Gets API endpoint URL for a workspace or branch.""" + """Gets API endpoint URL for a workspace.""" return await workspace_tools.get_workspace_url(workspace_id) - @mcp.tool() + return get_workspace_url + + +def _build_get_publishable_keys(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def get_publishable_keys(workspace_id: str = None, reveal: bool = False) -> str: - """Gets API keys for a workspace or branch.""" + """Gets API keys for a workspace.""" return await workspace_tools.get_publishable_keys(workspace_id, reveal) - @mcp.tool() + return get_publishable_keys + + +def _build_list_branches(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def list_branches(workspace_id: str = None) -> str: """Lists all development branches of a workspace.""" return await workspace_tools.list_branches(workspace_id) - @mcp.tool() + return list_branches + + +def _build_create_branch(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def create_branch(name: str = "develop", workspace_id: str = None) -> str: """Creates a development branch.""" return await workspace_tools.create_branch(name, workspace_id) - @mcp.tool() + return create_branch + + +def _build_delete_branch(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + async def delete_branch(branch_id: str, workspace_id: str = None) -> str: """Deletes a development branch.""" return await workspace_tools.delete_branch(branch_id, workspace_id) - @mcp.tool() - async def reset_branch(branch_id: str, migration_version: str = None, workspace_id: str = None) -> str: - """Resets a development branch. Any untracked data or schema changes will be lost.""" - return await workspace_tools.reset_branch(branch_id, migration_version, workspace_id) + return delete_branch + + +def _build_restore_branch(runtime: SupabaseRuntime): + workspace_tools = runtime.workspace_tools + + async def restore_branch( + branch_id: str, + source_branch_id: str = None, + time: str = None, + workspace_id: str = None, + ) -> str: + """Restores branch data to a specified point in time and returns the restored branch ID.""" + return await workspace_tools.restore_branch(branch_id, source_branch_id, time, workspace_id) + + return restore_branch + + +TOOL_DEFINITIONS = ( + ToolDefinition("list_workspaces", "account", False, False, _build_list_workspaces), + ToolDefinition("get_workspace", "account", True, False, _build_get_workspace), + ToolDefinition("create_workspace", "account", False, True, _build_create_workspace), + ToolDefinition("pause_workspace", "account", True, True, _build_pause_workspace), + ToolDefinition("restore_workspace", "account", True, True, _build_restore_workspace), + ToolDefinition("execute_sql", "database", True, True, _build_execute_sql), + ToolDefinition("list_tables", "database", True, False, _build_list_tables), + ToolDefinition("list_migrations", "database", True, False, _build_list_migrations), + ToolDefinition("list_extensions", "database", True, False, _build_list_extensions), + ToolDefinition("apply_migration", "database", True, True, _build_apply_migration), + ToolDefinition("get_workspace_url", "development", True, False, _build_get_workspace_url), + ToolDefinition("get_publishable_keys", "development", True, False, _build_get_publishable_keys), + ToolDefinition("generate_typescript_types", "development", True, False, _build_generate_typescript_types), + ToolDefinition("list_edge_functions", "functions", True, False, _build_list_edge_functions), + ToolDefinition("get_edge_function", "functions", True, False, _build_get_edge_function), + ToolDefinition("deploy_edge_function", "functions", True, True, _build_deploy_edge_function), + ToolDefinition("delete_edge_function", "functions", True, True, _build_delete_edge_function), + ToolDefinition("list_storage_buckets", "storage", True, False, _build_list_storage_buckets), + ToolDefinition("create_storage_bucket", "storage", True, True, _build_create_storage_bucket), + ToolDefinition("delete_storage_bucket", "storage", True, True, _build_delete_storage_bucket), + ToolDefinition("get_storage_config", "storage", True, False, _build_get_storage_config), + ToolDefinition("list_branches", "branching", True, False, _build_list_branches), + ToolDefinition("create_branch", "branching", True, True, _build_create_branch), + ToolDefinition("delete_branch", "branching", True, True, _build_delete_branch), + ToolDefinition("restore_branch", "branching", True, True, _build_restore_branch), +) + + +def register_tools(mcp: FastMCP, runtime: SupabaseRuntime) -> None: + for tool_definition in TOOL_DEFINITIONS: + mcp.tool()(tool_definition.build(runtime)) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/tools/base.py b/server/mcp_server_supabase/src/mcp_server_supabase/tools/base.py index 677e5fc2..3c85c718 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/tools/base.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/tools/base.py @@ -1,50 +1,25 @@ from typing import Optional from ..platform import AidapClient, SupabaseClient -from ..utils import resolve_target, select_target_id +from ..utils import resolve_workspace_id class BaseTools: - """Base class for all tool classes""" - - def __init__(self, aidap_client: AidapClient, workspace_id: Optional[str] = None): + def __init__(self, aidap_client: AidapClient): self.aidap = aidap_client - self.default_workspace_id = workspace_id - - def _get_workspace_id(self, workspace_id: Optional[str]) -> str: - """Get workspace ID from parameter or default""" - result = select_target_id(workspace_id, self.default_workspace_id) - if not result: - raise ValueError( - "workspace_id is required: not provided as parameter and no default workspace_id configured. " - "Please provide workspace_id or set DEFAULT_WORKSPACE_ID environment variable." - ) - return result - async def _resolve_target(self, workspace_id: Optional[str]) -> tuple[str, Optional[str]]: - target = self._get_workspace_id(workspace_id) - resolved_workspace_id, branch_id = await resolve_target(self.aidap, target, None) + def _resolve_workspace_id(self, workspace_id: Optional[str]) -> str: + resolved_workspace_id = resolve_workspace_id(workspace_id) if not resolved_workspace_id: - raise ValueError( - "workspace_id is required: not provided as parameter and no default workspace_id configured. " - "Please provide workspace_id or set DEFAULT_WORKSPACE_ID environment variable." - ) - return resolved_workspace_id, branch_id - - async def _get_client(self, workspace_id: str, branch_id: Optional[str] = None) -> SupabaseClient: - """Get Supabase client for workspace""" - import logging - logger = logging.getLogger(__name__) + raise ValueError("workspace_id is required") + return resolved_workspace_id - endpoint = await self.aidap.get_endpoint(workspace_id, branch_id=branch_id) - logger.info(f"[DEBUG] Got endpoint for {workspace_id} branch={branch_id}: {endpoint}") + async def _get_client(self, workspace_id: str) -> SupabaseClient: + endpoint = await self.aidap.get_endpoint(workspace_id) if not endpoint: - target = branch_id or workspace_id - raise ValueError(f"Could not get endpoint for target {target}") + raise ValueError(f"Could not get endpoint for workspace {workspace_id}") - api_key = await self.aidap.get_api_key(workspace_id, "service_role", branch_id=branch_id) - logger.info(f"[DEBUG] Got API key for {workspace_id} branch={branch_id}: {'yes' if api_key else 'no'}") + api_key = await self.aidap.get_api_key(workspace_id, "service_role") if not api_key: - target = branch_id or workspace_id - raise ValueError(f"Could not get API key for target {target}") + raise ValueError(f"Could not get API key for workspace {workspace_id}") return SupabaseClient(endpoint, api_key) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/tools/database_tools.py b/server/mcp_server_supabase/src/mcp_server_supabase/tools/database_tools.py index f6ba117e..501758b0 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/tools/database_tools.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/tools/database_tools.py @@ -2,24 +2,23 @@ import logging from datetime import datetime, timezone from .base import BaseTools -from ..utils import handle_errors, read_only_check +from ..utils import handle_errors logger = logging.getLogger(__name__) class DatabaseTools(BaseTools): - """使用 REST API 方式执行 SQL""" async def _execute_sql_raw(self, query: str, workspace_id: Optional[str] = None) -> List[dict]: if not query or not query.strip(): raise ValueError("SQL query cannot be empty") - ws_id, branch_id = await self._resolve_target(workspace_id) - logger.info( + ws_id = self._resolve_workspace_id(workspace_id) + logger.debug( "Executing SQL query", - extra={"workspace_id": ws_id, "branch_id": branch_id, "query_length": len(query)} + extra={"workspace_id": ws_id, "query_length": len(query)} ) - client = await self._get_client(ws_id, branch_id) + client = await self._get_client(ws_id) result = await client.call_api("/pg/query", method="POST", json_data={"query": query}) if isinstance(result, dict) and isinstance(result.get("data"), list): @@ -30,21 +29,22 @@ async def _execute_sql_raw(self, query: str, workspace_id: Optional[str] = None) logger.debug(f"SQL query returned {len(result)} rows") return result + def _normalize_schemas(self, schemas: Optional[List[str]] = None) -> List[str]: + normalized = [schema.strip() for schema in (schemas or ["public"]) if schema and schema.strip()] + if not normalized: + raise ValueError("At least one schema is required") + for schema in normalized: + if not schema.replace('_', '').isalnum(): + raise ValueError(f"Invalid schema name: {schema}") + return normalized + @handle_errors async def execute_sql(self, query: str, workspace_id: Optional[str] = None) -> List[dict]: return await self._execute_sql_raw(query, workspace_id) @handle_errors async def list_tables(self, schemas: List[str] = None, workspace_id: Optional[str] = None) -> List[dict]: - if schemas is None: - schemas = ["public"] - - # 验证 schema 名称,防止 SQL 注入 - for schema in schemas: - if not schema.replace('_', '').isalnum(): - raise ValueError(f"Invalid schema name: {schema}") - - schema_list = "', '".join(schemas) + schema_list = "', '".join(self._normalize_schemas(schemas)) query = f""" SELECT schemaname as schema, @@ -58,18 +58,28 @@ async def list_tables(self, schemas: List[str] = None, workspace_id: Optional[st @handle_errors async def list_migrations(self, workspace_id: Optional[str] = None) -> List[dict]: - query = """ - CREATE SCHEMA IF NOT EXISTS supabase_migrations; - CREATE TABLE IF NOT EXISTS supabase_migrations.schema_migrations ( - version text PRIMARY KEY, - name text NOT NULL, - inserted_at timestamptz NOT NULL DEFAULT now() - ); - SELECT version, name - FROM supabase_migrations.schema_migrations - ORDER BY version DESC - """ - return await self._execute_sql_raw(query, workspace_id) + existence_rows = await self._execute_sql_raw( + """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = 'supabase_migrations' + AND table_name = 'schema_migrations' + ) AS exists + """, + workspace_id, + ) + exists = bool(existence_rows and existence_rows[0].get("exists")) + if not exists: + return [] + return await self._execute_sql_raw( + """ + SELECT version, name + FROM supabase_migrations.schema_migrations + ORDER BY version DESC + """, + workspace_id, + ) @handle_errors async def list_extensions(self, workspace_id: Optional[str] = None) -> List[dict]: @@ -85,7 +95,6 @@ async def list_extensions(self, workspace_id: Optional[str] = None) -> List[dict return await self._execute_sql_raw(query, workspace_id) @handle_errors - @read_only_check async def apply_migration(self, name: str, query: str, workspace_id: Optional[str] = None) -> dict: if not name or not name.strip(): raise ValueError("Migration name cannot be empty") @@ -155,13 +164,7 @@ async def generate_typescript_types( schemas: List[str] = None, workspace_id: Optional[str] = None ) -> str: - if schemas is None: - schemas = ["public"] - for schema in schemas: - if not schema.replace('_', '').isalnum(): - raise ValueError(f"Invalid schema name: {schema}") - - schema_list = "', '".join(schemas) + schema_list = "', '".join(self._normalize_schemas(schemas)) query = f""" SELECT table_schema, diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/tools/edge_function_tools.py b/server/mcp_server_supabase/src/mcp_server_supabase/tools/edge_function_tools.py index b274c15c..9b1912c5 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/tools/edge_function_tools.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/tools/edge_function_tools.py @@ -6,13 +6,12 @@ import re from urllib.parse import quote from .base import BaseTools -from ..utils import handle_errors, read_only_check +from ..utils import handle_errors from ..models import EdgeFunction from ..platform.supabase_client import SupabaseApiError logger = logging.getLogger(__name__) -# 运行时配置 RUNTIME_CONFIG = { "native-node20/v1": { "entrypoint": "index.ts", @@ -36,10 +35,9 @@ } } -# 保留的函数名 RESERVED_SLUGS = {"deploy", "body", "health", "metrics"} MAX_SLUG_LENGTH = 127 -MAX_CODE_SIZE = 10 * 1024 * 1024 # 10MB +MAX_CODE_SIZE = 10 * 1024 * 1024 WORKSPACE_SLUG = os.getenv("SUPABASE_WORKSPACE_SLUG", "default").strip() or "default" @@ -99,7 +97,15 @@ def _normalize_function_payload(self, payload: object) -> object: return result def _validate_function_name(self, function_name: str) -> None: - return + normalized = (function_name or "").strip() + if not normalized: + raise ValueError("Function name cannot be empty") + if len(normalized) > MAX_SLUG_LENGTH: + raise ValueError(f"Function name too long: {len(normalized)} characters (max {MAX_SLUG_LENGTH})") + if normalized.lower() in RESERVED_SLUGS: + raise ValueError(f"Function name '{normalized}' is reserved") + if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9_-]*", normalized): + raise ValueError("Function name must start with a letter or digit and contain only letters, digits, hyphens, or underscores") def _validate_runtime(self, runtime: str) -> None: """验证运行时""" @@ -131,23 +137,23 @@ def _extract_error_text(self, payload: object) -> str: @handle_errors async def list_edge_functions(self, workspace_id: Optional[str] = None) -> List[EdgeFunction]: - ws_id, branch_id = await self._resolve_target(workspace_id) - logger.info(f"Listing edge functions for workspace {ws_id}") + ws_id = self._resolve_workspace_id(workspace_id) + logger.debug("Listing edge functions for workspace %s", ws_id) - client = await self._get_client(ws_id, branch_id) + client = await self._get_client(ws_id) result = await client.call_api(f"/v1/projects/{WORKSPACE_SLUG}/functions") functions = [EdgeFunction(**func) for func in result] - logger.info(f"Found {len(functions)} edge functions") + logger.debug("Found %s edge functions", len(functions)) return functions @handle_errors async def get_edge_function(self, function_name: str, workspace_id: Optional[str] = None) -> dict: self._validate_function_name(function_name) - ws_id, branch_id = await self._resolve_target(workspace_id) - logger.info(f"Getting edge function '{function_name}' from workspace {ws_id}") + ws_id = self._resolve_workspace_id(workspace_id) + logger.debug("Getting edge function '%s' from workspace %s", function_name, ws_id) - client = await self._get_client(ws_id, branch_id) + client = await self._get_client(ws_id) encoded_name = quote(function_name, safe="") try: result = await client.call_api(f"/v1/projects/{WORKSPACE_SLUG}/functions/{encoded_name}") @@ -162,7 +168,6 @@ async def get_edge_function(self, function_name: str, workspace_id: Optional[str return EdgeFunction(**result).model_dump() @handle_errors - @read_only_check async def deploy_edge_function( self, function_name: str, @@ -172,37 +177,18 @@ async def deploy_edge_function( import_map: Optional[str] = None, workspace_id: Optional[str] = None ) -> dict: - """ - 部署边缘函数 - - Args: - function_name: 函数名称 - source_code: 源代码 - verify_jwt: 是否验证 JWT - runtime: 运行时环境 (native-node20/v1, native-python3.9/v1, etc.) - import_map: 可选的 import map JSON - workspace_id: 工作空间 ID - - Returns: - 部署结果字典 - - Raises: - ValueError: 参数验证失败 - """ - # 验证输入 self._validate_function_name(function_name) self._validate_runtime(runtime) if not source_code or not source_code.strip(): raise ValueError("Source code cannot be empty") - # HTML 反转义,防止代码中的特殊字符被转义 source_code = html.unescape(source_code) self._validate_code_size(source_code) self._validate_runtime_compatibility(runtime, source_code) - ws_id, branch_id = await self._resolve_target(workspace_id) + ws_id = self._resolve_workspace_id(workspace_id) entrypoint = self._get_entrypoint(runtime) logger.info( @@ -210,7 +196,6 @@ async def deploy_edge_function( extra={ "function_name": function_name, "workspace_id": ws_id, - "branch_id": branch_id, "runtime": runtime, "verify_jwt": verify_jwt, "entrypoint": entrypoint, @@ -218,7 +203,7 @@ async def deploy_edge_function( } ) - client = await self._get_client(ws_id, branch_id) + client = await self._get_client(ws_id) encoded_name = quote(function_name, safe="") @@ -236,7 +221,6 @@ async def deploy_edge_function( except json.JSONDecodeError as e: raise ValueError(f"Invalid import map JSON: {e}") - # AIDAP 部署 API 路径 result = await client.call_api( f"/v1/projects/{WORKSPACE_SLUG}/functions/deploy?slug={encoded_name}", method="POST", @@ -253,13 +237,12 @@ async def deploy_edge_function( return result @handle_errors - @read_only_check async def delete_edge_function(self, function_name: str, workspace_id: Optional[str] = None) -> dict: self._validate_function_name(function_name) - ws_id, branch_id = await self._resolve_target(workspace_id) + ws_id = self._resolve_workspace_id(workspace_id) logger.info(f"Deleting edge function '{function_name}' from workspace {ws_id}") - client = await self._get_client(ws_id, branch_id) + client = await self._get_client(ws_id) encoded_name = quote(function_name, safe="") await client.call_api(f"/v1/projects/{WORKSPACE_SLUG}/functions/{encoded_name}", method="DELETE") diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/tools/storage_tools.py b/server/mcp_server_supabase/src/mcp_server_supabase/tools/storage_tools.py index 18b6f10e..52b5d62e 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/tools/storage_tools.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/tools/storage_tools.py @@ -2,7 +2,7 @@ import logging import json from .base import BaseTools -from ..utils import handle_errors, read_only_check +from ..utils import handle_errors from ..models import StorageConfig logger = logging.getLogger(__name__) @@ -35,17 +35,16 @@ def _normalize_allowed_mime_types(self, allowed_mime_types: Optional[str | list[ @handle_errors async def list_storage_buckets(self, workspace_id: Optional[str] = None) -> List[dict]: - ws_id, branch_id = await self._resolve_target(workspace_id) - logger.info(f"Listing storage buckets for workspace {ws_id}") + ws_id = self._resolve_workspace_id(workspace_id) + logger.debug("Listing storage buckets for workspace %s", ws_id) - client = await self._get_client(ws_id, branch_id) + client = await self._get_client(ws_id) result = await client.call_api("/storage/v1/bucket") - logger.info(f"Found {len(result)} storage buckets") + logger.debug("Found %s storage buckets", len(result)) return result @handle_errors - @read_only_check async def create_storage_bucket( self, bucket_name: str, @@ -57,13 +56,13 @@ async def create_storage_bucket( if not bucket_name or not bucket_name.strip(): raise ValueError("Bucket name cannot be empty") - ws_id, branch_id = await self._resolve_target(workspace_id) + ws_id = self._resolve_workspace_id(workspace_id) logger.info( f"Creating storage bucket '{bucket_name}'", - extra={"workspace_id": ws_id, "branch_id": branch_id, "public": public} + extra={"workspace_id": ws_id, "public": public} ) - client = await self._get_client(ws_id, branch_id) + client = await self._get_client(ws_id) data = { "name": bucket_name, @@ -78,12 +77,11 @@ async def create_storage_bucket( return await client.call_api("/storage/v1/bucket", method="POST", json_data=data) @handle_errors - @read_only_check async def delete_storage_bucket(self, bucket_name: str, workspace_id: Optional[str] = None) -> dict: if not bucket_name or not bucket_name.strip(): raise ValueError("Bucket name cannot be empty") - ws_id, branch_id = await self._resolve_target(workspace_id) - client = await self._get_client(ws_id, branch_id) + ws_id = self._resolve_workspace_id(workspace_id) + client = await self._get_client(ws_id) response = await client.call_api(f"/storage/v1/bucket/{bucket_name}", method="DELETE") if isinstance(response, dict) and "error" in response: raise ValueError(response["error"]) @@ -91,7 +89,7 @@ async def delete_storage_bucket(self, bucket_name: str, workspace_id: Optional[s @handle_errors async def get_storage_config(self, workspace_id: Optional[str] = None) -> StorageConfig: - ws_id, branch_id = await self._resolve_target(workspace_id) - client = await self._get_client(ws_id, branch_id) + ws_id = self._resolve_workspace_id(workspace_id) + client = await self._get_client(ws_id) result = await client.call_api("/storage/v1/config") return StorageConfig(**result) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/tools/workspace_tools.py b/server/mcp_server_supabase/src/mcp_server_supabase/tools/workspace_tools.py index e2630ed3..9bffac99 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/tools/workspace_tools.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/tools/workspace_tools.py @@ -3,41 +3,55 @@ import logging from typing import Any, Optional -from ..utils import compact_dict, pick_value, read_only_check, resolve_target, to_json +from .base import BaseTools +from ..utils import compact_dict, pick_value, to_json logger = logging.getLogger(__name__) -class WorkspaceTools: - def __init__(self, aidap_client, default_workspace_id: Optional[str] = None): - self.aidap_client = aidap_client - self.default_workspace_id = default_workspace_id +class WorkspaceTools(BaseTools): + _filter_supports_mode: bool | None = None - def _to_json(self, payload: dict) -> str: - return to_json(payload) + @classmethod + def _supports_workspace_filter_mode(cls) -> bool: + if cls._filter_supports_mode is None: + from volcenginesdkaidap.models import FilterForDescribeWorkspacesInput - def _compact(self, payload: dict) -> dict: - return compact_dict(payload) + cls._filter_supports_mode = "mode" in inspect.signature(FilterForDescribeWorkspacesInput).parameters + return cls._filter_supports_mode - def _pick(self, source: Any, *field_names: str) -> Any: - return pick_value(source, *field_names) + def _resolve_workspace_or_response( + self, + workspace_id: Optional[str], + detailed: bool = False, + ) -> tuple[str | None, str | None]: + try: + return self._resolve_workspace_id(workspace_id), None + except ValueError: + return None, self._workspace_required_response(detailed) - async def _resolve_target(self, target_id: Optional[str]) -> tuple[Optional[str], Optional[str]]: - return await resolve_target(self.aidap_client, target_id, self.default_workspace_id) + def _workspace_required_response(self, detailed: bool = False) -> str: + payload = { + "success": False, + "error": "workspace_id is required", + } + if detailed: + payload["error_detail"] = self._error_detail("MissingWorkspaceId", "workspace_id is required", False) + return to_json(payload) def _workspace_view(self, source: Any) -> dict: payload = { - "workspace_id": self._pick(source, "workspace_id"), - "workspace_name": self._pick(source, "workspace_name"), - "status": self._pick(source, "workspace_status", "status"), - "region": self._pick(source, "region_id", "region"), - "created_at": self._pick(source, "create_time", "created_at"), - "updated_at": self._pick(source, "update_time", "updated_at"), - "engine_type": self._pick(source, "engine_type"), - "engine_version": self._pick(source, "engine_version"), - "deletion_protection_status": self._pick(source, "deletion_protection_status"), + "workspace_id": pick_value(source, "workspace_id"), + "workspace_name": pick_value(source, "workspace_name"), + "status": pick_value(source, "workspace_status", "status"), + "region": pick_value(source, "region_id", "region"), + "created_at": pick_value(source, "create_time", "created_at"), + "updated_at": pick_value(source, "update_time", "updated_at"), + "engine_type": pick_value(source, "engine_type"), + "engine_version": pick_value(source, "engine_version"), + "deletion_protection_status": pick_value(source, "deletion_protection_status"), } - return self._compact(payload) + return compact_dict(payload) def _branch_view(self, branch: dict, workspace_payload: Optional[dict] = None) -> dict: workspace_payload = workspace_payload or {} @@ -56,26 +70,25 @@ def _branch_view(self, branch: dict, workspace_payload: Optional[dict] = None) - "deletion_protection_status": workspace_payload.get("deletion_protection_status"), "target_type": "branch", } - return self._compact(payload) + return compact_dict(payload) def _describe_workspaces_response(self): from volcenginesdkaidap.models import DescribeWorkspacesRequest, FilterForDescribeWorkspacesInput - parameters = inspect.signature(FilterForDescribeWorkspacesInput).parameters filter_kwargs = { "name": "DBEngineVersion", "value": "Supabase_1_24", } - if "mode" in parameters: + if self._supports_workspace_filter_mode(): filter_kwargs["mode"] = "Exact" filters = [FilterForDescribeWorkspacesInput(**filter_kwargs)] request = DescribeWorkspacesRequest(filters=filters) - return self.aidap_client.client.describe_workspaces(request) + return self.aidap.client.describe_workspaces(request) def _find_workspace_source(self, workspace_id: str) -> Optional[Any]: response = self._describe_workspaces_response() for workspace in list(getattr(response, "workspaces", []) or []): - if self._pick(workspace, "workspace_id") == workspace_id: + if pick_value(workspace, "workspace_id") == workspace_id: return workspace return None @@ -100,49 +113,39 @@ async def list_workspaces(self) -> str: response = self._describe_workspaces_response() raw_workspaces = list(getattr(response, "workspaces", []) or []) workspaces = [self._workspace_view(workspace) for workspace in raw_workspaces] - return self._to_json({ + return to_json({ "success": True, "workspaces": workspaces, "count": len(workspaces), }) except Exception as e: logger.error(f"Error listing workspaces: {e}") - return self._to_json({ + return to_json({ "success": False, "error": str(e), }) async def get_workspace(self, workspace_id: str) -> str: try: - ws_id, branch_id = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({ - "success": False, - "error": "workspace_id is required", - }) + ws_id = self._resolve_workspace_id(workspace_id) workspace_source = self._find_workspace_source(ws_id) if workspace_source is None: - return self._to_json({ + return to_json({ "success": False, "error": "Workspace not found", }) workspace_info = self._workspace_view(workspace_source) - if branch_id: - branch = await self.aidap_client.get_branch(ws_id, branch_id) - if branch: - workspace_info.update(self._branch_view(branch, workspace_info)) - return self._to_json({ + return to_json({ "success": True, "workspace": workspace_info, }) except Exception as e: logger.error(f"Error getting workspace: {e}") - return self._to_json({ + return to_json({ "success": False, "error": str(e), }) - @read_only_check async def create_workspace( self, workspace_name: str, @@ -150,14 +153,14 @@ async def create_workspace( engine_type: str = "Supabase", ) -> str: if not workspace_name or not workspace_name.strip(): - return self._to_json({"success": False, "error": "workspace_name is required"}) - result = await self.aidap_client.create_workspace( + return to_json({"success": False, "error": "workspace_name is required"}) + result = await self.aidap.create_workspace( workspace_name=workspace_name.strip(), engine_type=engine_type, engine_version=engine_version, ) if not isinstance(result, dict): - return self._to_json({"success": False, "error": "Unexpected create workspace response"}) + return to_json({"success": False, "error": "Unexpected create workspace response"}) if result.get("success"): mapped = { "success": True, @@ -166,36 +169,33 @@ async def create_workspace( "engine_type": result.get("engine_type"), "engine_version": result.get("engine_version"), } - return self._to_json(self._compact(mapped)) - return self._to_json(result) + return to_json(compact_dict(mapped)) + return to_json(result) - @read_only_check async def restore_workspace(self, workspace_id: Optional[str] = None) -> str: - ws_id, _ = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({"success": False, "error": "workspace_id is required"}) - result = await self.aidap_client.start_workspace(ws_id) - return self._to_json(result if isinstance(result, dict) else {"success": bool(result), "workspace_id": ws_id}) + ws_id, error_response = self._resolve_workspace_or_response(workspace_id) + if error_response: + return error_response + result = await self.aidap.start_workspace(ws_id) + return to_json(result if isinstance(result, dict) else {"success": bool(result), "workspace_id": ws_id}) - @read_only_check async def pause_workspace(self, workspace_id: Optional[str] = None) -> str: - ws_id, _ = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({"success": False, "error": "workspace_id is required"}) - result = await self.aidap_client.stop_workspace(ws_id) - return self._to_json(result if isinstance(result, dict) else {"success": bool(result), "workspace_id": ws_id}) + ws_id, error_response = self._resolve_workspace_or_response(workspace_id) + if error_response: + return error_response + result = await self.aidap.stop_workspace(ws_id) + return to_json(result if isinstance(result, dict) else {"success": bool(result), "workspace_id": ws_id}) - @read_only_check async def create_branch( self, name: str = "develop", workspace_id: Optional[str] = None, ) -> str: - ws_id, _ = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({"success": False, "error": "workspace_id is required"}) + ws_id, error_response = self._resolve_workspace_or_response(workspace_id) + if error_response: + return error_response - result = await self.aidap_client.create_branch(ws_id, name) + result = await self.aidap.create_branch(ws_id, name) if result.get("success") and result.get("branch_id"): branch_payload = self._branch_view(result, {"workspace_id": ws_id}) branch_payload["branch_name"] = branch_payload.get("branch_name") or name @@ -203,38 +203,31 @@ async def create_branch( "success": True, **branch_payload, } - endpoint = await self.aidap_client.get_endpoint(ws_id, branch_id=result["branch_id"], use_cache=False) + endpoint = await self.aidap.get_endpoint(ws_id, branch_id=result["branch_id"]) if endpoint: response_payload["workspace_url"] = endpoint response_payload["api_url"] = endpoint - return self._to_json(self._compact(response_payload)) - return self._to_json(result) + return to_json(compact_dict(response_payload)) + return to_json(result) async def list_branches(self, workspace_id: Optional[str] = None) -> str: - ws_id, _ = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({"success": False, "error": "workspace_id is required"}) try: + ws_id = self._resolve_workspace_id(workspace_id) workspace_source = self._find_workspace_source(ws_id) workspace_payload = self._workspace_view(workspace_source) if workspace_source is not None else {"workspace_id": ws_id} - branches = await self.aidap_client.list_branches(ws_id) + branches = await self.aidap.list_branches(ws_id) normalized_branches = [self._branch_view(branch, workspace_payload) for branch in branches] - return self._to_json({"success": True, "branches": normalized_branches}) + return to_json({"success": True, "branches": normalized_branches}) except Exception as e: logger.error(f"Error listing branches: {e}") - return self._to_json({"success": False, "error": str(e)}) + return to_json({"success": False, "error": str(e)}) - @read_only_check async def delete_branch(self, branch_id: str, workspace_id: Optional[str] = None) -> str: - ws_id, _ = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({ - "success": False, - "error": "workspace_id is required", - "error_detail": self._error_detail("MissingWorkspaceId", "workspace_id is required", False), - }) + ws_id, error_response = self._resolve_workspace_or_response(workspace_id, detailed=True) + if error_response: + return error_response if not branch_id or not branch_id.strip(): - return self._to_json({ + return to_json({ "success": False, "error": "branch_id is required", "error_detail": self._error_detail("MissingBranchId", "branch_id is required", False), @@ -242,10 +235,10 @@ async def delete_branch(self, branch_id: str, workspace_id: Optional[str] = None normalized_branch_id = branch_id.strip() try: - branches = await self.aidap_client.list_branches(ws_id) + branches = await self.aidap.list_branches(ws_id) exists = any(branch.get("branch_id") == normalized_branch_id for branch in branches) if not exists: - return self._to_json({ + return to_json({ "success": False, "error": f"Branch '{normalized_branch_id}' not found in workspace '{ws_id}'", "error_detail": self._error_detail( @@ -256,16 +249,16 @@ async def delete_branch(self, branch_id: str, workspace_id: Optional[str] = None }) except Exception as e: logger.error(f"Error checking branch before delete: {e}") - return self._to_json({ + return to_json({ "success": False, "error": str(e), "error_detail": self._error_detail("ListBranchesFailed", str(e), True), }) - result = await self.aidap_client.delete_branch(ws_id, normalized_branch_id) + result = await self.aidap.delete_branch(ws_id, normalized_branch_id) if not result.get("success"): error_text = result.get("error", "delete branch failed") - return self._to_json({ + return to_json({ "success": False, "error": error_text, "error_detail": self._error_detail( @@ -280,15 +273,15 @@ async def delete_branch(self, branch_id: str, workspace_id: Optional[str] = None for _ in range(max_confirm_attempts): await asyncio.sleep(1) try: - branches = await self.aidap_client.list_branches(ws_id) + branches = await self.aidap.list_branches(ws_id) exists = any(branch.get("branch_id") == normalized_branch_id for branch in branches) if not exists: - return self._to_json({"success": True, "branch_id": normalized_branch_id, "workspace_id": ws_id}) + return to_json({"success": True, "branch_id": normalized_branch_id, "workspace_id": ws_id}) except Exception as e: last_list_error = str(e) if last_list_error: - return self._to_json({ + return to_json({ "success": False, "error": f"Delete requested for branch '{normalized_branch_id}' but verification failed: {last_list_error}", "error_detail": self._error_detail( @@ -297,7 +290,7 @@ async def delete_branch(self, branch_id: str, workspace_id: Optional[str] = None True, ), }) - return self._to_json({ + return to_json({ "success": False, "error": f"Delete requested for branch '{normalized_branch_id}' but branch still exists", "error_detail": self._error_detail( @@ -308,16 +301,15 @@ async def delete_branch(self, branch_id: str, workspace_id: Optional[str] = None }) async def get_workspace_url(self, workspace_id: Optional[str] = None) -> str: - ws_id, branch_id = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({"success": False, "error": "workspace_id is required"}) + ws_id, error_response = self._resolve_workspace_or_response(workspace_id) + if error_response: + return error_response - endpoint = await self.aidap_client.get_endpoint(ws_id, branch_id=branch_id) + endpoint = await self.aidap.get_endpoint(ws_id) if not endpoint: - target_id = branch_id or ws_id - return self._to_json({ + return to_json({ "success": False, - "error": f"Could not get endpoint for workspace {target_id}", + "error": f"Could not get endpoint for workspace {ws_id}", }) payload = { @@ -326,16 +318,13 @@ async def get_workspace_url(self, workspace_id: Optional[str] = None) -> str: "workspace_url": endpoint, "api_url": endpoint, } - if branch_id: - payload["branch_id"] = branch_id - payload["target_type"] = "branch" - return self._to_json(payload) + return to_json(payload) - async def _get_api_keys_payload(self, workspace_id: str, branch_id: Optional[str] = None, reveal: bool = False) -> dict: - resolved_branch_id = branch_id or await self.aidap_client.get_default_branch_id(workspace_id) + async def _get_api_keys_payload(self, workspace_id: str, reveal: bool = False) -> dict: + resolved_branch_id = await self.aidap.get_default_branch_id(workspace_id) if not resolved_branch_id: raise RuntimeError(f"Could not resolve default branch for workspace {workspace_id}") - keys = await self.aidap_client.get_api_keys(workspace_id, branch_id=resolved_branch_id) + keys = await self.aidap.get_api_keys(workspace_id, branch_id=resolved_branch_id) publishable_key = None anon_key = None service_role_key = None @@ -367,44 +356,38 @@ async def _get_api_keys_payload(self, workspace_id: str, branch_id: Optional[str return payload async def get_publishable_keys(self, workspace_id: Optional[str] = None, reveal: bool = False) -> str: - ws_id, branch_id = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({"success": False, "error": "workspace_id is required"}) - try: - payload = await self._get_api_keys_payload(ws_id, branch_id=branch_id, reveal=reveal) - return self._to_json(payload) + ws_id = self._resolve_workspace_id(workspace_id) + payload = await self._get_api_keys_payload(ws_id, reveal=reveal) + return to_json(payload) except Exception as e: logger.error(f"Error getting publishable keys: {e}") - return self._to_json({"success": False, "error": str(e)}) + return to_json({"success": False, "error": str(e)}) - @read_only_check - async def reset_branch( + async def restore_branch( self, branch_id: str, - migration_version: Optional[str] = None, + source_branch_id: Optional[str] = None, + time: Optional[str] = None, workspace_id: Optional[str] = None, ) -> str: - ws_id, _ = await self._resolve_target(workspace_id) - if not ws_id: - return self._to_json({ - "success": False, - "error": "workspace_id is required", - }) - try: - result = await self.aidap_client.reset_branch(ws_id, branch_id) + ws_id = self._resolve_workspace_id(workspace_id) + result = await self.aidap.restore_branch( + ws_id, + branch_id, + source_branch_id=source_branch_id, + time=time, + ) if not isinstance(result, dict): result = {"success": bool(result)} if result.get("success"): result.setdefault("workspace_id", ws_id) result.setdefault("branch_id", branch_id) - if migration_version: - result["warning"] = "migration_version is ignored because current AIDAP reset_branch API does not support version-targeted reset" - return self._to_json(result) + return to_json(compact_dict(result)) except Exception as e: - logger.error(f"Error resetting branch: {e}") - return self._to_json({ + logger.error(f"Error restoring branch: {e}") + return to_json({ "success": False, "error": str(e), }) diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/utils/__init__.py b/server/mcp_server_supabase/src/mcp_server_supabase/utils/__init__.py index e43d67e1..d1a3bd03 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/utils/__init__.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/utils/__init__.py @@ -1,14 +1,11 @@ from .common import compact_dict, pick_value, to_json -from .decorators import format_error, handle_errors, read_only_check -from .targets import resolve_target, select_target_id +from .decorators import handle_errors +from .targets import resolve_workspace_id __all__ = [ 'compact_dict', - 'format_error', 'handle_errors', 'pick_value', - 'read_only_check', - 'resolve_target', - 'select_target_id', + 'resolve_workspace_id', 'to_json', ] diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/utils/decorators.py b/server/mcp_server_supabase/src/mcp_server_supabase/utils/decorators.py index aee00dcc..cd77bc31 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/utils/decorators.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/utils/decorators.py @@ -1,14 +1,13 @@ -import json import logging from functools import wraps -from typing import Any, Callable +from typing import Callable from .common import to_json logger = logging.getLogger(__name__) -def format_error(e: Exception) -> str: +def _format_error(e: Exception) -> str: error_msg = str(e) if str(e) else f"{type(e).__name__}" return error_msg @@ -27,17 +26,7 @@ async def wrapper(*args, **kwargs) -> str: result = result.model_dump() return to_json(result) except Exception as e: - error_msg = format_error(e) + error_msg = _format_error(e) logger.error(f"Error in {func.__name__}: {error_msg}") return to_json({"error": error_msg}) return wrapper - - -def read_only_check(func: Callable) -> Callable: - @wraps(func) - async def wrapper(*args, **kwargs) -> Any: - from ..config import READ_ONLY - if READ_ONLY: - return to_json({"error": f"Cannot execute {func.__name__} in read-only mode"}) - return await func(*args, **kwargs) - return wrapper diff --git a/server/mcp_server_supabase/src/mcp_server_supabase/utils/targets.py b/server/mcp_server_supabase/src/mcp_server_supabase/utils/targets.py index a5f1126e..150fb538 100644 --- a/server/mcp_server_supabase/src/mcp_server_supabase/utils/targets.py +++ b/server/mcp_server_supabase/src/mcp_server_supabase/utils/targets.py @@ -1,12 +1,12 @@ from typing import Optional -def select_target_id(target_id: Optional[str], default_target_id: Optional[str]) -> Optional[str]: - return target_id or default_target_id - - -async def resolve_target(aidap_client, target_id: Optional[str], default_target_id: Optional[str]) -> tuple[Optional[str], Optional[str]]: - resolved_id = select_target_id(target_id, default_target_id) - if not resolved_id: - return None, None - return await aidap_client.resolve_workspace_and_branch(resolved_id) +def resolve_workspace_id(workspace_id: Optional[str]) -> Optional[str]: + if not workspace_id: + return None + normalized_id = workspace_id.strip() + if not normalized_id: + return None + if normalized_id.startswith("br-"): + raise ValueError("workspace_id must be a workspace ID; branch IDs are not supported") + return normalized_id