From 89cafbb07603b0050d57bde86138265a692450bc Mon Sep 17 00:00:00 2001 From: Eric Gustin Date: Wed, 25 Feb 2026 21:59:27 -0800 Subject: [PATCH 1/3] Remove toolkits --- .github/actions/setup-uv-env/action.yml | 16 +- .../workflows/release-on-version-change.yml | 3 - .github/workflows/test-toolkits.yml | 132 ---- CLAUDE.md | 3 +- Makefile | 115 +--- toolkits/brightdata/.pre-commit-config.yaml | 18 - toolkits/brightdata/.ruff.toml | 44 -- toolkits/brightdata/LICENSE | 21 - toolkits/brightdata/Makefile | 55 -- .../brightdata/arcade_brightdata/__init__.py | 3 - .../brightdata/arcade_brightdata/__main__.py | 28 - .../arcade_brightdata/bright_data_client.py | 63 -- .../arcade_brightdata/tools/__init__.py | 7 - .../tools/bright_data_tools.py | 361 ---------- toolkits/brightdata/pyproject.toml | 62 -- toolkits/brightdata/tests/__init__.py | 0 toolkits/brightdata/tests/test_brightdata.py | 418 ------------ toolkits/clickhouse/Makefile | 53 -- .../clickhouse/arcade_clickhouse/__init__.py | 0 .../clickhouse/arcade_clickhouse/__main__.py | 29 - .../arcade_clickhouse/database_engine.py | 209 ------ .../arcade_clickhouse/tools/__init__.py | 0 .../arcade_clickhouse/tools/clickhouse.py | 347 ---------- toolkits/clickhouse/pyproject.toml | 66 -- toolkits/clickhouse/tests/__init__.py | 0 toolkits/clickhouse/tests/dump.sql | 369 ---------- toolkits/clickhouse/tests/test_clickhouse.py | 195 ------ toolkits/clickhouse/tests/test_setup.sh | 3 - toolkits/linkedin/.pre-commit-config.yaml | 18 - toolkits/linkedin/.ruff.toml | 46 -- toolkits/linkedin/LICENSE | 21 - toolkits/linkedin/Makefile | 55 -- toolkits/linkedin/arcade_linkedin/__init__.py | 0 toolkits/linkedin/arcade_linkedin/__main__.py | 28 - .../arcade_linkedin/tools/__init__.py | 0 .../arcade_linkedin/tools/constants.py | 1 - .../linkedin/arcade_linkedin/tools/share.py | 76 --- .../linkedin/arcade_linkedin/tools/utils.py | 68 -- toolkits/linkedin/conftest.py | 24 - toolkits/linkedin/evals/eval_linkedin.py | 48 -- toolkits/linkedin/pyproject.toml | 59 -- toolkits/linkedin/tests/__init__.py | 0 toolkits/linkedin/tests/test_share.py | 35 - toolkits/math/.pre-commit-config.yaml | 18 - toolkits/math/.ruff.toml | 47 -- toolkits/math/LICENSE | 21 - toolkits/math/Makefile | 55 -- toolkits/math/arcade_math/__init__.py | 0 toolkits/math/arcade_math/__main__.py | 29 - toolkits/math/arcade_math/tools/__init__.py | 65 -- toolkits/math/arcade_math/tools/arithmetic.py | 161 ----- toolkits/math/arcade_math/tools/exponents.py | 51 -- .../math/arcade_math/tools/miscellaneous.py | 70 -- toolkits/math/arcade_math/tools/random.py | 57 -- toolkits/math/arcade_math/tools/rational.py | 49 -- toolkits/math/arcade_math/tools/rounding.py | 75 --- toolkits/math/arcade_math/tools/statistics.py | 53 -- .../math/arcade_math/tools/trigonometry.py | 49 -- toolkits/math/evals/eval_math_tools.py | 137 ---- toolkits/math/pyproject.toml | 58 -- toolkits/math/tests/__init__.py | 0 toolkits/math/tests/test_arithmetic.py | 147 ---- toolkits/math/tests/test_exponents.py | 41 -- toolkits/math/tests/test_miscellaneous.py | 81 --- toolkits/math/tests/test_rational.py | 31 - toolkits/math/tests/test_rounding.py | 54 -- toolkits/math/tests/test_statistics.py | 18 - toolkits/math/tests/test_trigonometry.py | 45 -- toolkits/mongodb/Makefile | 53 -- toolkits/mongodb/arcade_mongodb/__init__.py | 0 toolkits/mongodb/arcade_mongodb/__main__.py | 29 - .../mongodb/arcade_mongodb/database_engine.py | 118 ---- .../mongodb/arcade_mongodb/tools/__init__.py | 0 .../mongodb/arcade_mongodb/tools/mongodb.py | 434 ------------ .../mongodb/arcade_mongodb/tools/utils.py | 281 -------- toolkits/mongodb/evals/eval_mongodb.py | 190 ------ toolkits/mongodb/pyproject.toml | 62 -- toolkits/mongodb/tests/__init__.py | 0 toolkits/mongodb/tests/conftest.py | 45 -- toolkits/mongodb/tests/dump.js | 378 ----------- .../mongodb/tests/test_json_validation.py | 221 ------ toolkits/mongodb/tests/test_mongodb.py | 292 -------- toolkits/mongodb/tests/test_setup.sh | 12 - .../mongodb/tests/test_write_validation.py | 248 ------- toolkits/postgres/Makefile | 53 -- toolkits/postgres/arcade_postgres/__init__.py | 0 toolkits/postgres/arcade_postgres/__main__.py | 29 - .../arcade_postgres/database_engine.py | 180 ----- .../arcade_postgres/tools/__init__.py | 0 .../arcade_postgres/tools/postgres.py | 300 --------- toolkits/postgres/evals/eval_postgres.py | 94 --- toolkits/postgres/pyproject.toml | 65 -- toolkits/postgres/tests/__init__.py | 0 toolkits/postgres/tests/dump.sql | 399 ----------- toolkits/postgres/tests/test_postgres.py | 188 ------ toolkits/postgres/tests/test_setup.sh | 15 - toolkits/zendesk/.pre-commit-config.yaml | 18 - toolkits/zendesk/.ruff.toml | 46 -- toolkits/zendesk/Makefile | 55 -- toolkits/zendesk/arcade_zendesk/__init__.py | 15 - toolkits/zendesk/arcade_zendesk/__main__.py | 28 - toolkits/zendesk/arcade_zendesk/enums.py | 27 - .../zendesk/arcade_zendesk/tools/__init__.py | 17 - .../arcade_zendesk/tools/search_articles.py | 219 ------ .../arcade_zendesk/tools/system_context.py | 45 -- .../zendesk/arcade_zendesk/tools/tickets.py | 367 ---------- toolkits/zendesk/arcade_zendesk/utils.py | 216 ------ .../zendesk/arcade_zendesk/who_am_i_util.py | 118 ---- toolkits/zendesk/evals/eval_articles.py | 360 ---------- toolkits/zendesk/evals/eval_tickets.py | 631 ------------------ toolkits/zendesk/pyproject.toml | 60 -- toolkits/zendesk/tests/__init__.py | 0 toolkits/zendesk/tests/conftest.py | 84 --- .../zendesk/tests/test_search_articles.py | 360 ---------- toolkits/zendesk/tests/test_tickets.py | 526 --------------- toolkits/zendesk/tests/test_utils.py | 291 -------- toolkits/zendesk/tests/test_who_am_i_util.py | 330 --------- 117 files changed, 9 insertions(+), 12001 deletions(-) delete mode 100644 .github/workflows/test-toolkits.yml delete mode 100644 toolkits/brightdata/.pre-commit-config.yaml delete mode 100644 toolkits/brightdata/.ruff.toml delete mode 100644 toolkits/brightdata/LICENSE delete mode 100644 toolkits/brightdata/Makefile delete mode 100644 toolkits/brightdata/arcade_brightdata/__init__.py delete mode 100644 toolkits/brightdata/arcade_brightdata/__main__.py delete mode 100644 toolkits/brightdata/arcade_brightdata/bright_data_client.py delete mode 100644 toolkits/brightdata/arcade_brightdata/tools/__init__.py delete mode 100644 toolkits/brightdata/arcade_brightdata/tools/bright_data_tools.py delete mode 100644 toolkits/brightdata/pyproject.toml delete mode 100644 toolkits/brightdata/tests/__init__.py delete mode 100644 toolkits/brightdata/tests/test_brightdata.py delete mode 100644 toolkits/clickhouse/Makefile delete mode 100644 toolkits/clickhouse/arcade_clickhouse/__init__.py delete mode 100644 toolkits/clickhouse/arcade_clickhouse/__main__.py delete mode 100644 toolkits/clickhouse/arcade_clickhouse/database_engine.py delete mode 100644 toolkits/clickhouse/arcade_clickhouse/tools/__init__.py delete mode 100644 toolkits/clickhouse/arcade_clickhouse/tools/clickhouse.py delete mode 100644 toolkits/clickhouse/pyproject.toml delete mode 100644 toolkits/clickhouse/tests/__init__.py delete mode 100644 toolkits/clickhouse/tests/dump.sql delete mode 100644 toolkits/clickhouse/tests/test_clickhouse.py delete mode 100755 toolkits/clickhouse/tests/test_setup.sh delete mode 100644 toolkits/linkedin/.pre-commit-config.yaml delete mode 100644 toolkits/linkedin/.ruff.toml delete mode 100644 toolkits/linkedin/LICENSE delete mode 100644 toolkits/linkedin/Makefile delete mode 100644 toolkits/linkedin/arcade_linkedin/__init__.py delete mode 100644 toolkits/linkedin/arcade_linkedin/__main__.py delete mode 100644 toolkits/linkedin/arcade_linkedin/tools/__init__.py delete mode 100644 toolkits/linkedin/arcade_linkedin/tools/constants.py delete mode 100644 toolkits/linkedin/arcade_linkedin/tools/share.py delete mode 100644 toolkits/linkedin/arcade_linkedin/tools/utils.py delete mode 100644 toolkits/linkedin/conftest.py delete mode 100644 toolkits/linkedin/evals/eval_linkedin.py delete mode 100644 toolkits/linkedin/pyproject.toml delete mode 100644 toolkits/linkedin/tests/__init__.py delete mode 100644 toolkits/linkedin/tests/test_share.py delete mode 100644 toolkits/math/.pre-commit-config.yaml delete mode 100644 toolkits/math/.ruff.toml delete mode 100644 toolkits/math/LICENSE delete mode 100644 toolkits/math/Makefile delete mode 100644 toolkits/math/arcade_math/__init__.py delete mode 100644 toolkits/math/arcade_math/__main__.py delete mode 100644 toolkits/math/arcade_math/tools/__init__.py delete mode 100644 toolkits/math/arcade_math/tools/arithmetic.py delete mode 100644 toolkits/math/arcade_math/tools/exponents.py delete mode 100644 toolkits/math/arcade_math/tools/miscellaneous.py delete mode 100644 toolkits/math/arcade_math/tools/random.py delete mode 100644 toolkits/math/arcade_math/tools/rational.py delete mode 100644 toolkits/math/arcade_math/tools/rounding.py delete mode 100644 toolkits/math/arcade_math/tools/statistics.py delete mode 100644 toolkits/math/arcade_math/tools/trigonometry.py delete mode 100644 toolkits/math/evals/eval_math_tools.py delete mode 100644 toolkits/math/pyproject.toml delete mode 100644 toolkits/math/tests/__init__.py delete mode 100644 toolkits/math/tests/test_arithmetic.py delete mode 100644 toolkits/math/tests/test_exponents.py delete mode 100644 toolkits/math/tests/test_miscellaneous.py delete mode 100644 toolkits/math/tests/test_rational.py delete mode 100644 toolkits/math/tests/test_rounding.py delete mode 100644 toolkits/math/tests/test_statistics.py delete mode 100644 toolkits/math/tests/test_trigonometry.py delete mode 100644 toolkits/mongodb/Makefile delete mode 100644 toolkits/mongodb/arcade_mongodb/__init__.py delete mode 100644 toolkits/mongodb/arcade_mongodb/__main__.py delete mode 100644 toolkits/mongodb/arcade_mongodb/database_engine.py delete mode 100644 toolkits/mongodb/arcade_mongodb/tools/__init__.py delete mode 100644 toolkits/mongodb/arcade_mongodb/tools/mongodb.py delete mode 100644 toolkits/mongodb/arcade_mongodb/tools/utils.py delete mode 100644 toolkits/mongodb/evals/eval_mongodb.py delete mode 100644 toolkits/mongodb/pyproject.toml delete mode 100644 toolkits/mongodb/tests/__init__.py delete mode 100644 toolkits/mongodb/tests/conftest.py delete mode 100644 toolkits/mongodb/tests/dump.js delete mode 100644 toolkits/mongodb/tests/test_json_validation.py delete mode 100644 toolkits/mongodb/tests/test_mongodb.py delete mode 100755 toolkits/mongodb/tests/test_setup.sh delete mode 100644 toolkits/mongodb/tests/test_write_validation.py delete mode 100644 toolkits/postgres/Makefile delete mode 100644 toolkits/postgres/arcade_postgres/__init__.py delete mode 100644 toolkits/postgres/arcade_postgres/__main__.py delete mode 100644 toolkits/postgres/arcade_postgres/database_engine.py delete mode 100644 toolkits/postgres/arcade_postgres/tools/__init__.py delete mode 100644 toolkits/postgres/arcade_postgres/tools/postgres.py delete mode 100644 toolkits/postgres/evals/eval_postgres.py delete mode 100644 toolkits/postgres/pyproject.toml delete mode 100644 toolkits/postgres/tests/__init__.py delete mode 100644 toolkits/postgres/tests/dump.sql delete mode 100644 toolkits/postgres/tests/test_postgres.py delete mode 100755 toolkits/postgres/tests/test_setup.sh delete mode 100644 toolkits/zendesk/.pre-commit-config.yaml delete mode 100644 toolkits/zendesk/.ruff.toml delete mode 100644 toolkits/zendesk/Makefile delete mode 100644 toolkits/zendesk/arcade_zendesk/__init__.py delete mode 100644 toolkits/zendesk/arcade_zendesk/__main__.py delete mode 100644 toolkits/zendesk/arcade_zendesk/enums.py delete mode 100644 toolkits/zendesk/arcade_zendesk/tools/__init__.py delete mode 100644 toolkits/zendesk/arcade_zendesk/tools/search_articles.py delete mode 100644 toolkits/zendesk/arcade_zendesk/tools/system_context.py delete mode 100644 toolkits/zendesk/arcade_zendesk/tools/tickets.py delete mode 100644 toolkits/zendesk/arcade_zendesk/utils.py delete mode 100644 toolkits/zendesk/arcade_zendesk/who_am_i_util.py delete mode 100644 toolkits/zendesk/evals/eval_articles.py delete mode 100644 toolkits/zendesk/evals/eval_tickets.py delete mode 100644 toolkits/zendesk/pyproject.toml delete mode 100644 toolkits/zendesk/tests/__init__.py delete mode 100644 toolkits/zendesk/tests/conftest.py delete mode 100644 toolkits/zendesk/tests/test_search_articles.py delete mode 100644 toolkits/zendesk/tests/test_tickets.py delete mode 100644 toolkits/zendesk/tests/test_utils.py delete mode 100644 toolkits/zendesk/tests/test_who_am_i_util.py diff --git a/.github/actions/setup-uv-env/action.yml b/.github/actions/setup-uv-env/action.yml index abccad063..ee514ffe7 100644 --- a/.github/actions/setup-uv-env/action.yml +++ b/.github/actions/setup-uv-env/action.yml @@ -6,10 +6,6 @@ inputs: required: false description: "The python version to use" default: "3.11" - is-toolkit: - required: false - description: "Whether this is a toolkit package" - default: "false" is-contrib: required: false description: "Whether this is a contrib package" @@ -20,7 +16,7 @@ inputs: default: "false" working-directory: required: false - description: "Working directory for the installation (used for toolkits)" + description: "Working directory for the installation" default: "." runs: @@ -32,14 +28,6 @@ runs: working-directory: ${{ inputs.working-directory }} python-version: ${{ inputs.python-version }} - - name: Install toolkit dependencies - if: inputs.is-toolkit == 'true' - working-directory: ${{ inputs.working-directory }} - run: | - echo "Installing dependencies for ${{ inputs.working-directory }}" - make install-local - shell: bash - - name: Install contrib dependencies if: inputs.is-contrib == 'true' working-directory: ${{ inputs.working-directory }} @@ -49,6 +37,6 @@ runs: shell: bash - name: Install libs dependencies - if: inputs.is-toolkit != 'true' + if: inputs.is-contrib != 'true' run: uv sync --extra all --extra dev shell: bash diff --git a/.github/workflows/release-on-version-change.yml b/.github/workflows/release-on-version-change.yml index 68129f3bd..986b972d1 100644 --- a/.github/workflows/release-on-version-change.yml +++ b/.github/workflows/release-on-version-change.yml @@ -83,14 +83,11 @@ jobs: uses: ./.github/actions/setup-uv-env with: python-version: "3.10" - is-toolkit: ${{ startsWith(matrix.package, 'toolkits/') }} is-contrib: ${{ startsWith(matrix.package, 'contrib/') }} is-lib: ${{ startsWith(matrix.package, 'libs/') }} working-directory: ${{ matrix.package }} - name: Run tests - # Skip tests for toolkits - tests are run on every PR commit for toolkits - if: ${{ !startsWith(matrix.package, 'toolkits/') }} working-directory: ${{ matrix.package }} run: | # Run tests if they exist diff --git a/.github/workflows/test-toolkits.yml b/.github/workflows/test-toolkits.yml deleted file mode 100644 index 0195197be..000000000 --- a/.github/workflows/test-toolkits.yml +++ /dev/null @@ -1,132 +0,0 @@ -name: Test Toolkits - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened, ready_for_review] - -env: - ARCADE_USAGE_TRACKING: "0" - -jobs: - setup: - runs-on: ubuntu-latest - outputs: - toolkits_with_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_with_gha_secrets }} - toolkits_without_gha_secrets: ${{ steps.load_toolkits.outputs.toolkits_without_gha_secrets }} - steps: - - name: Check out - uses: actions/checkout@v4 - - - name: determine toolkits with and without GHA secrets - id: load_toolkits - run: | - # Find all directories in toolkits/ that have a pyproject.toml - TOOLKITS=$(find toolkits -maxdepth 1 -type d -not -name "toolkits" -exec test -f {}/pyproject.toml \; -exec basename {} \; | jq -R -s -c 'split("\n")[:-1]') - TOOLKITS_WITH_GHA_SECRETS='["postgres", "clickhouse", "mongodb"]' - TOOLKITS_WITHOUT_GHA_SECRETS=$(echo "$TOOLKITS" | jq -c --argjson with "$TOOLKITS_WITH_GHA_SECRETS" '[.[] | select(. as $t | $with | index($t) | not)]') - echo "Found toolkits: $TOOLKITS" - echo "Found toolkits without GHA secrets: $TOOLKITS_WITHOUT_GHA_SECRETS" - echo "Found toolkits with GHA secrets: $TOOLKITS_WITH_GHA_SECRETS" - echo "toolkits_without_gha_secrets=$TOOLKITS_WITHOUT_GHA_SECRETS" >> $GITHUB_OUTPUT - echo "toolkits_with_gha_secrets=$TOOLKITS_WITH_GHA_SECRETS" >> $GITHUB_OUTPUT - - test-toolkits: - needs: setup - name: test-toolkits (${{ matrix.toolkit }}, ${{ matrix.os }}) - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - toolkit: ${{ fromJson(needs.setup.outputs.toolkits_without_gha_secrets) }} - fail-fast: false - steps: - - name: Check out - uses: actions/checkout@v4 - - - name: Set up the environment - uses: ./.github/actions/setup-uv-env - with: - is-toolkit: "true" - working-directory: toolkits/${{ matrix.toolkit }} - - - name: Install toolkit dependencies - working-directory: toolkits/${{ matrix.toolkit }} - shell: bash - run: uv pip install -e ".[dev]" - - - name: Check toolkit - working-directory: toolkits/${{ matrix.toolkit }} - shell: bash - run: | - uv run --active pre-commit run -a - uv run --active mypy --config-file=pyproject.toml - - - name: Test stand-alone toolkits (no secrets) - working-directory: toolkits/${{ matrix.toolkit }} - shell: bash - run: | - # Run pytest and capture exit code - uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$? - - if [ "${EXIT_CODE:-0}" -eq 5 ]; then - echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..." - exit 0 - elif [ "${EXIT_CODE:-0}" -ne 0 ]; then - exit ${EXIT_CODE} - fi - - test-toolkits-with-gha-secrets: - needs: setup - # Linux-only: these toolkits bootstrap local DBs via docker/apt in tests/test_setup.sh. - runs-on: ubuntu-latest - strategy: - matrix: - toolkit: ${{ fromJson(needs.setup.outputs.toolkits_with_gha_secrets) }} - fail-fast: true - steps: - - name: Check out - uses: actions/checkout@v4 - - - name: Set up the environment - uses: ./.github/actions/setup-uv-env - with: - is-toolkit: "true" - working-directory: toolkits/${{ matrix.toolkit }} - - - name: Install toolkit dependencies - working-directory: toolkits/${{ matrix.toolkit }} - run: uv pip install -e ".[dev]" - - - name: Check toolkit - working-directory: toolkits/${{ matrix.toolkit }} - run: | - uv run --active pre-commit run -a - uv run --active mypy --config-file=pyproject.toml - - - name: Test stand-alone toolkits (with secrets) - if: | - !github.event.pull_request.head.repo.fork - working-directory: toolkits/${{ matrix.toolkit }} - env: - TEST_POSTGRES_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_POSTGRES_DATABASE_CONNECTION_STRING }} # TODO: dynamically only load the `TEST_${{ matrix.toolkit }}_DATABASE_CONNECTION_STRING secret` - TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING: ${{ secrets.TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING }} - TEST_MONGODB_CONNECTION_STRING: ${{ secrets.TEST_MONGODB_CONNECTION_STRING }} - run: | - # If there's a custom test_setup.sh file, run it - if [ -f tests/test_setup.sh ]; then - echo "Running custom test setup for ${{ matrix.toolkit }}..." - ./tests/test_setup.sh - fi - - # Run pytest and capture exit code - uv run --active pytest -W ignore -v --cov=arcade_${{ matrix.toolkit }} --cov-report=xml || EXIT_CODE=$? - - if [ "${EXIT_CODE:-0}" -eq 5 ]; then - echo "No tests found for toolkit ${{ matrix.toolkit }}, skipping..." - exit 0 - elif [ "${EXIT_CODE:-0}" -ne 0 ]; then - exit ${EXIT_CODE} - fi diff --git a/CLAUDE.md b/CLAUDE.md index edbd06146..d723ffbdd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## What This Is -Arcade MCP is a Python tool-calling platform for building MCP (Model Context Protocol) servers. It's a monorepo containing 5 interdependent libraries, 30+ prebuilt toolkit integrations, and a CLI. +Arcade MCP is a Python tool-calling platform for building MCP (Model Context Protocol) servers. It's a monorepo containing 5 interdependent libraries and a CLI. ## Commands @@ -15,7 +15,6 @@ Arcade MCP is a Python tool-calling platform for building MCP (Model Context Pro | Run a single test | `uv run pytest libs/tests/core/test_toolkit.py::TestClass::test_method` | | Lint + type check | `make check` (pre-commit + mypy) | | Build all wheels | `make build` | -| Run toolkit tests | `make test-toolkits` | Package manager is **uv** โ€” always use `uv run` to execute Python commands, never bare `pip` or `python`. Python 3.10+. Build system is Hatchling. diff --git a/Makefile b/Makefile index 4845274c8..329a6e4bb 100644 --- a/Makefile +++ b/Makefile @@ -6,37 +6,6 @@ install: ## Install the uv environment and all packages with dependencies @uv run pre-commit install @echo "โœ… All packages and dependencies installed via uv workspace" -.PHONY: install-toolkits -install-toolkits: ## Install dependencies for all toolkits - @echo "๐Ÿš€ Installing dependencies for all toolkits" - @failed=0; \ - successful=0; \ - for dir in toolkits/*/ ; do \ - if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \ - echo "๐Ÿ“ฆ Installing dependencies for $$dir"; \ - if (cd $$dir && uv pip install -e ".[dev]"); then \ - successful=$$((successful + 1)); \ - else \ - echo "โŒ Failed to install dependencies for $$dir"; \ - failed=$$((failed + 1)); \ - fi; \ - else \ - echo "โš ๏ธ Skipping $$dir (no pyproject.toml found)"; \ - fi; \ - done; \ - echo ""; \ - echo "๐Ÿ“Š Installation Summary:"; \ - echo " โœ… Successful: $$successful toolkits"; \ - echo " โŒ Failed: $$failed toolkits"; \ - if [ $$failed -gt 0 ]; then \ - echo ""; \ - echo "โš ๏ธ Some toolkit installations failed. Check the output above for details."; \ - exit 1; \ - else \ - echo ""; \ - echo "๐ŸŽ‰ All toolkit dependencies installed successfully!"; \ - fi - .PHONY: check check: ## Run code quality tools. @echo "๐Ÿš€ Linting code: Running pre-commit" @@ -56,18 +25,6 @@ check-libs: ## Run code quality tools for each lib package (cd $$lib && uv run mypy . || true); \ done -.PHONY: check-toolkits -check-toolkits: ## Run code quality tools for each toolkit that has a Makefile - @echo "๐Ÿš€ Running 'make check' in each toolkit with a Makefile" - @for dir in toolkits/*/ ; do \ - if [ -f "$$dir/Makefile" ]; then \ - echo "๐Ÿ› ๏ธ Checking toolkit $$dir"; \ - (cd "$$dir" && uv run pre-commit run -a && uv run mypy --config-file=pyproject.toml); \ - else \ - echo "๐Ÿ› ๏ธ Skipping toolkit $$dir (no Makefile found)"; \ - fi; \ - done - .PHONY: test test: install ## Test the code with pytest @echo "๐Ÿš€ Testing libs: Running pytest" @@ -81,15 +38,6 @@ test-libs: ## Test each lib package individually (cd $$lib && uv run pytest -W ignore -v || true); \ done -.PHONY: test-toolkits -test-toolkits: ## Iterate over all toolkits and run pytest on each one - @echo "๐Ÿš€ Testing code in toolkits: Running pytest" - @for dir in toolkits/*/ ; do \ - toolkit_name=$$(basename "$$dir"); \ - echo "๐Ÿงช Testing $$toolkit_name toolkit"; \ - (cd $$dir && uv run pytest -W ignore -v --cov=arcade_$$toolkit_name --cov-report=xml || exit 1); \ - done - .PHONY: coverage coverage: ## Generate coverage report @echo "coverage report" @@ -107,38 +55,6 @@ build: clean-build ## Build wheel files using uv fi; \ done -.PHONY: build-toolkits -build-toolkits: ## Build wheel files for all toolkits - @echo "๐Ÿš€ Creating wheel files for all toolkits" - @failed=0; \ - successful=0; \ - for dir in toolkits/*/ ; do \ - if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \ - toolkit_name=$$(basename "$$dir"); \ - echo "๐Ÿ› ๏ธ Building toolkit $$toolkit_name"; \ - if (cd $$dir && uv build); then \ - successful=$$((successful + 1)); \ - else \ - echo "โŒ Failed to build toolkit $$toolkit_name"; \ - failed=$$((failed + 1)); \ - fi; \ - else \ - echo "โš ๏ธ Skipping $$dir (no pyproject.toml found)"; \ - fi; \ - done; \ - echo ""; \ - echo "๐Ÿ“Š Build Summary:"; \ - echo " โœ… Successful: $$successful toolkits"; \ - echo " โŒ Failed: $$failed toolkits"; \ - if [ $$failed -gt 0 ]; then \ - echo ""; \ - echo "โš ๏ธ Some toolkit builds failed. Check the output above for details."; \ - exit 1; \ - else \ - echo ""; \ - echo "๐ŸŽ‰ All toolkit wheels built successfully!"; \ - fi - .PHONY: clean-build clean-build: ## clean build artifacts @echo "๐Ÿ—‘๏ธ Cleaning build artifacts" @@ -161,7 +77,7 @@ build-and-publish: build publish ## Build and publish. .PHONY: docker docker: ## Build and run the Docker container - @echo "๐Ÿš€ Building lib packages and toolkit wheels..." + @echo "๐Ÿš€ Building lib packages..." @make full-dist @echo "๐Ÿš€ Building Docker image" @cd docker && make docker-build @@ -169,22 +85,19 @@ docker: ## Build and run the Docker container .PHONY: docker-base docker-base: ## Build and run the Docker container - @echo "๐Ÿš€ Building lib packages and toolkit wheels..." + @echo "๐Ÿš€ Building lib packages..." @make full-dist @echo "๐Ÿš€ Building Docker image" - @cd docker && INSTALL_TOOLKITS=false make docker-build - @cd docker && INSTALL_TOOLKITS=false make docker-run + @cd docker && make docker-build + @cd docker && make docker-run .PHONY: publish-ghcr publish-ghcr: ## Publish to the GHCR - # Publish the base image - ghcr.io/arcadeai/worker-base - @cd docker && INSTALL_TOOLKITS=false make publish-ghcr - # Publish the image with toolkits - ghcr.io/arcadeai/worker - @cd docker && INSTALL_TOOLKITS=true make publish-ghcr + @cd docker && make publish-ghcr .PHONY: full-dist full-dist: clean-dist ## Build all projects and copy wheels to ./dist - @echo "๐Ÿ› ๏ธ Building a full distribution with lib packages and toolkits" + @echo "๐Ÿ› ๏ธ Building a full distribution with lib packages" @echo "๐Ÿ› ๏ธ Building all lib packages and copying wheels to ./dist" @mkdir -p dist @@ -198,16 +111,6 @@ full-dist: clean-dist ## Build all projects and copy wheels to ./dist @uv build @rm -f dist/*.tar.gz - @echo "๐Ÿ› ๏ธ Building all toolkit packages and copying wheels to ./dist" - @for dir in toolkits/*/ ; do \ - if [ -d "$$dir" ] && [ -f "$$dir/pyproject.toml" ]; then \ - toolkit_name=$$(basename "$$dir"); \ - echo "๐Ÿ› ๏ธ Building toolkit $$toolkit_name wheel..."; \ - (cd $$dir && uv build); \ - cp $$dir/dist/*.whl dist/; \ - fi; \ - done - .PHONY: clean-dist clean-dist: ## Clean all built distributions @echo "๐Ÿ—‘๏ธ Cleaning dist directory" @@ -216,12 +119,6 @@ clean-dist: ## Clean all built distributions @for lib in libs/arcade*/ ; do \ rm -rf "$$lib"/dist; \ done - @echo "๐Ÿ—‘๏ธ Cleaning toolkits/*/dist directory" - @for toolkit_dir in toolkits/*; do \ - if [ -d "$$toolkit_dir" ]; then \ - rm -rf "$$toolkit_dir"/dist; \ - fi; \ - done .PHONY: setup setup: ## Run uv environment setup script diff --git a/toolkits/brightdata/.pre-commit-config.yaml b/toolkits/brightdata/.pre-commit-config.yaml deleted file mode 100644 index e9fa67332..000000000 --- a/toolkits/brightdata/.pre-commit-config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -files: ^arcade_brightdata/.* -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.4.0" - hooks: - - id: check-case-conflict - - id: check-merge-conflict - - id: check-toml - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 - hooks: - - id: ruff - args: [--fix] - - id: ruff-format diff --git a/toolkits/brightdata/.ruff.toml b/toolkits/brightdata/.ruff.toml deleted file mode 100644 index 9519fe6c3..000000000 --- a/toolkits/brightdata/.ruff.toml +++ /dev/null @@ -1,44 +0,0 @@ -target-version = "py310" -line-length = 100 -fix = true - -[lint] -select = [ - # flake8-2020 - "YTT", - # flake8-bandit - "S", - # flake8-bugbear - "B", - # flake8-builtins - "A", - # flake8-comprehensions - "C4", - # flake8-debugger - "T10", - # flake8-simplify - "SIM", - # isort - "I", - # mccabe - "C90", - # pycodestyle - "E", "W", - # pyflakes - "F", - # pygrep-hooks - "PGH", - # pyupgrade - "UP", - # ruff - "RUF", - # tryceratops - "TRY", -] - -[lint.per-file-ignores] -"**/tests/*" = ["S101"] - -[format] -preview = true -skip-magic-trailing-comma = false diff --git a/toolkits/brightdata/LICENSE b/toolkits/brightdata/LICENSE deleted file mode 100644 index dfbb8b76d..000000000 --- a/toolkits/brightdata/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025, Arcade AI - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/toolkits/brightdata/Makefile b/toolkits/brightdata/Makefile deleted file mode 100644 index 0a8969beb..000000000 --- a/toolkits/brightdata/Makefile +++ /dev/null @@ -1,55 +0,0 @@ -.PHONY: help - -help: - @echo "๐Ÿ› ๏ธ github Commands:\n" - @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - -.PHONY: install -install: ## Install the uv environment and install all packages with dependencies - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras --no-sources - @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: install-local -install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras - @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: build -build: clean-build ## Build wheel file using poetry - @echo "๐Ÿš€ Creating wheel file" - uv build - -.PHONY: clean-build -clean-build: ## clean build artifacts - @echo "๐Ÿ—‘๏ธ Cleaning dist directory" - rm -rf dist - -.PHONY: test -test: ## Test the code with pytest - @echo "๐Ÿš€ Testing code: Running pytest" - @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml - -.PHONY: coverage -coverage: ## Generate coverage report - @echo "coverage report" - @uv run --no-sources coverage report - @echo "Generating coverage report" - @uv run --no-sources coverage html - -.PHONY: bump-version -bump-version: ## Bump the version in the pyproject.toml file by a patch version - @echo "๐Ÿš€ Bumping version in pyproject.toml" - uv version --no-sources --bump patch - -.PHONY: check -check: ## Run code quality tools. - @if [ -f .pre-commit-config.yaml ]; then\ - echo "๐Ÿš€ Linting code: Running pre-commit";\ - uv run --no-sources pre-commit run -a;\ - fi - @echo "๐Ÿš€ Static type checking: Running mypy" - @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/brightdata/arcade_brightdata/__init__.py b/toolkits/brightdata/arcade_brightdata/__init__.py deleted file mode 100644 index b5983dd11..000000000 --- a/toolkits/brightdata/arcade_brightdata/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from arcade_brightdata.tools import scrape_as_markdown, search_engine, web_data_feed - -__all__ = ["scrape_as_markdown", "search_engine", "web_data_feed"] diff --git a/toolkits/brightdata/arcade_brightdata/__main__.py b/toolkits/brightdata/arcade_brightdata/__main__.py deleted file mode 100644 index 7c433a074..000000000 --- a/toolkits/brightdata/arcade_brightdata/__main__.py +++ /dev/null @@ -1,28 +0,0 @@ -import sys -from typing import cast - -from arcade_mcp_server import MCPApp -from arcade_mcp_server.mcp_app import TransportType - -import arcade_brightdata - -app = MCPApp( - name="BrightData", - instructions=( - "Use this server when you need to interact with Bright Data to help users " - "scrape web pages, search the web, and extract structured data from websites." - ), -) - -app.add_tools_from_module(arcade_brightdata) - - -def main() -> None: - transport = sys.argv[1] if len(sys.argv) > 1 else "stdio" - host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1" - port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000 - app.run(transport=cast(TransportType, transport), host=host, port=port) - - -if __name__ == "__main__": - main() diff --git a/toolkits/brightdata/arcade_brightdata/bright_data_client.py b/toolkits/brightdata/arcade_brightdata/bright_data_client.py deleted file mode 100644 index 94cb32a4c..000000000 --- a/toolkits/brightdata/arcade_brightdata/bright_data_client.py +++ /dev/null @@ -1,63 +0,0 @@ -import json -from typing import ClassVar -from urllib.parse import quote - -import requests - - -class BrightDataClient: - """Engine for interacting with Bright Data API with connection management.""" - - _clients: ClassVar[dict[str, "BrightDataClient"]] = {} - - def __init__(self, api_key: str, zone: str = "web_unlocker1") -> None: - """ - Initialize with API token and default zone. - Args: - api_key (str): Your Bright Data API token - zone (str): Bright Data zone name - """ - self.api_key = api_key - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - self.zone = zone - self.endpoint = "https://api.brightdata.com/request" - - @classmethod - def create_client(cls, api_key: str, zone: str = "web_unlocker1") -> "BrightDataClient": - """Create or get cached client instance using API key only.""" - if api_key not in cls._clients: - cls._clients[api_key] = cls(api_key, zone) - - # Update zone for this request (user controls zone per request) - client = cls._clients[api_key] - client.zone = zone - return client - - @classmethod - def clear_cache(cls) -> None: - """Clear the client cache.""" - cls._clients.clear() - - def make_request(self, payload: dict) -> str: - """ - Make a request to Bright Data API. - Args: - payload (Dict): Request payload - Returns: - str: Response text - """ - response = requests.post( - self.endpoint, headers=self.headers, data=json.dumps(payload), timeout=30 - ) - - response.raise_for_status() - result: str = response.text - return result - - @staticmethod - def encode_query(query: str) -> str: - """URL encode a search query.""" - return quote(query) diff --git a/toolkits/brightdata/arcade_brightdata/tools/__init__.py b/toolkits/brightdata/arcade_brightdata/tools/__init__.py deleted file mode 100644 index d52374092..000000000 --- a/toolkits/brightdata/arcade_brightdata/tools/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from arcade_brightdata.tools.bright_data_tools import ( - scrape_as_markdown, - search_engine, - web_data_feed, -) - -__all__ = ["scrape_as_markdown", "search_engine", "web_data_feed"] diff --git a/toolkits/brightdata/arcade_brightdata/tools/bright_data_tools.py b/toolkits/brightdata/arcade_brightdata/tools/bright_data_tools.py deleted file mode 100644 index 3d51739c8..000000000 --- a/toolkits/brightdata/arcade_brightdata/tools/bright_data_tools.py +++ /dev/null @@ -1,361 +0,0 @@ -import json -import time -from enum import Enum -from typing import Annotated, Any, cast - -import requests -from arcade_mcp_server import Context, tool -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mcp_server.metadata import ( - Behavior, - Classification, - Operation, - ServiceDomain, - ToolMetadata, -) - -from arcade_brightdata.bright_data_client import BrightDataClient - - -class DeviceType(str, Enum): - MOBILE = "mobile" - IOS = "ios" - IPHONE = "iphone" - IPAD = "ipad" - ANDROID = "android" - ANDROID_TABLET = "android_tablet" - - -class SearchEngine(str, Enum): - GOOGLE = "google" - BING = "bing" - YANDEX = "yandex" - - -class SearchType(str, Enum): - IMAGES = "images" - SHOPPING = "shopping" - NEWS = "news" - JOBS = "jobs" - - -class SourceType(str, Enum): - AMAZON_PRODUCT = "amazon_product" - AMAZON_PRODUCT_REVIEWS = "amazon_product_reviews" - LINKEDIN_PERSON_PROFILE = "linkedin_person_profile" - LINKEDIN_COMPANY_PROFILE = "linkedin_company_profile" - ZOOMINFO_COMPANY_PROFILE = "zoominfo_company_profile" - INSTAGRAM_PROFILES = "instagram_profiles" - INSTAGRAM_POSTS = "instagram_posts" - INSTAGRAM_REELS = "instagram_reels" - INSTAGRAM_COMMENTS = "instagram_comments" - FACEBOOK_POSTS = "facebook_posts" - FACEBOOK_MARKETPLACE_LISTINGS = "facebook_marketplace_listings" - FACEBOOK_COMPANY_REVIEWS = "facebook_company_reviews" - X_POSTS = "x_posts" - ZILLOW_PROPERTIES_LISTING = "zillow_properties_listing" - BOOKING_HOTEL_LISTINGS = "booking_hotel_listings" - YOUTUBE_VIDEOS = "youtube_videos" - - -@tool( - requires_secrets=["BRIGHTDATA_API_KEY", "BRIGHTDATA_ZONE"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.WEB_SCRAPING], - ), - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -def scrape_as_markdown( - context: Context, - url: Annotated[str, "URL to scrape"], -) -> Annotated[str, "Scraped webpage content as Markdown"]: - """ - Scrape a webpage and return content in Markdown format using Bright Data. - - Examples: - scrape_as_markdown("https://example.com") -> "# Example Page\n\nContent..." - scrape_as_markdown("https://news.ycombinator.com") -> "# Hacker News\n..." - """ - api_key = context.get_secret("BRIGHTDATA_API_KEY") - zone = context.get_secret("BRIGHTDATA_ZONE") - client = BrightDataClient.create_client(api_key=api_key, zone=zone) - - payload = {"url": url, "zone": zone, "format": "raw", "data_format": "markdown"} - return client.make_request(payload) - - -@tool( - requires_secrets=["BRIGHTDATA_API_KEY", "BRIGHTDATA_ZONE"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.WEB_SCRAPING], - ), - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -def search_engine( # noqa: C901 - context: Context, - query: Annotated[str, "Search query"], - engine: Annotated[SearchEngine, "Search engine to use"] = SearchEngine.GOOGLE, - language: Annotated[str | None, "Two-letter language code"] = None, - country_code: Annotated[str | None, "Two-letter country code"] = None, - search_type: Annotated[SearchType | None, "Type of search"] = None, - start: Annotated[int | None, "Results pagination offset"] = None, - num_results: Annotated[int, "Number of results to return. The default is 10"] = 10, - location: Annotated[str | None, "Location for search results"] = None, - device: Annotated[DeviceType | None, "Device type"] = None, - return_json: Annotated[bool, "Return JSON instead of Markdown"] = False, -) -> Annotated[str, "Search results as Markdown or JSON"]: - """ - Search using Google, Bing, or Yandex with advanced parameters using Bright Data. - - Examples: - search_engine("climate change") -> "# Search Results\n\n## Climate Change - Wikipedia\n..." - search_engine("Python tutorials", engine="bing", num_results=5) -> "# Bing Results\n..." - search_engine("cats", search_type="images", country_code="us") -> "# Image Results\n..." - """ - api_key = context.get_secret("BRIGHTDATA_API_KEY") - zone = context.get_secret("BRIGHTDATA_ZONE") - client = BrightDataClient.create_client(api_key=api_key, zone=zone) - - encoded_query = BrightDataClient.encode_query(query) - - base_urls = { - SearchEngine.GOOGLE: f"https://www.google.com/search?q={encoded_query}", - SearchEngine.BING: f"https://www.bing.com/search?q={encoded_query}", - SearchEngine.YANDEX: f"https://yandex.com/search/?text={encoded_query}", - } - - search_url = base_urls[engine] - - if engine == SearchEngine.GOOGLE: - params = [] - - if language: - params.append(f"hl={language}") - - if country_code: - params.append(f"gl={country_code}") - - if search_type: - if search_type == SearchType.JOBS: - params.append("ibp=htl;jobs") - else: - search_types = { - SearchType.IMAGES: "isch", - SearchType.SHOPPING: "shop", - SearchType.NEWS: "nws", - } - tbm_value = search_types.get(search_type, search_type) - params.append(f"tbm={tbm_value}") - - if start is not None: - params.append(f"start={start}") - - if num_results: - params.append(f"num={num_results}") - - if location: - params.append(f"uule={BrightDataClient.encode_query(location)}") - - if device: - device_value = "1" - - if device.value in ["ios", "iphone"]: - device_value = "ios" - elif device.value == "ipad": - device_value = "ios_tablet" - elif device.value == "android": - device_value = "android" - elif device.value == "android_tablet": - device_value = "android_tablet" - - params.append(f"brd_mobile={device_value}") - - if return_json: - params.append("brd_json=1") - - if params: - search_url += "&" + "&".join(params) - - payload = { - "url": search_url, - "zone": zone, - "format": "raw", - "data_format": "markdown" if not return_json else "raw", - } - - return client.make_request(payload) - - -@tool( - requires_secrets=["BRIGHTDATA_API_KEY"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.WEB_SCRAPING], - ), - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=False, - open_world=True, - ), - ), -) -def web_data_feed( - context: Context, - source_type: Annotated[SourceType, "Type of data source"], - url: Annotated[str, "URL of the web resource to extract data from"], - num_of_reviews: Annotated[ - int | None, - ( - "Number of reviews to retrieve. Only applicable for " - "facebook_company_reviews. Default is None" - ), - ] = None, - timeout: Annotated[int, "Maximum time in seconds to wait for data retrieval"] = 600, - polling_interval: Annotated[int, "Time in seconds between polling attempts"] = 1, -) -> Annotated[str, "Structured data from the requested source as JSON"]: - """ - Extract structured data from various websites like LinkedIn, Amazon, Instagram, etc. - NEVER MADE UP LINKS - IF LINKS ARE NEEDED, EXECUTE search_engine FIRST. - Supported source types: - - amazon_product, amazon_product_reviews - - linkedin_person_profile, linkedin_company_profile - - zoominfo_company_profile - - instagram_profiles, instagram_posts, instagram_reels, instagram_comments - - facebook_posts, facebook_marketplace_listings, facebook_company_reviews - - x_posts - - zillow_properties_listing - - booking_hotel_listings - - youtube_videos - - Examples: - web_data_feed("amazon_product", "https://amazon.com/dp/B08N5WRWNW") - -> "{\"title\": \"Product Name\", ...}" - web_data_feed("linkedin_person_profile", "https://linkedin.com/in/johndoe") - -> "{\"name\": \"John Doe\", ...}" - web_data_feed( - "facebook_company_reviews", "https://facebook.com/company", num_of_reviews=50 - ) -> "[{\"review\": \"...\", ...}]" - """ - api_key = context.get_secret("BRIGHTDATA_API_KEY") - client = BrightDataClient.create_client(api_key=api_key) - if num_of_reviews is not None and source_type != SourceType.FACEBOOK_COMPANY_REVIEWS: - msg = ( - f"num_of_reviews parameter is only applicable for facebook_company_reviews, " - f"not for {source_type.value}" - ) - prompt = ( - "The num_of_reviews parameter should only be used with " - "facebook_company_reviews source type." - ) - raise RetryableToolError(msg, additional_prompt_content=prompt) - data = _extract_structured_data( - client=client, - source_type=source_type, - url=url, - num_of_reviews=num_of_reviews, - timeout=timeout, - polling_interval=polling_interval, - ) - return json.dumps(data, indent=2) - - -def _extract_structured_data( - client: BrightDataClient, - source_type: SourceType, - url: str, - num_of_reviews: int | None = None, - timeout: int = 600, - polling_interval: int = 1, -) -> dict[str, Any]: - """ - Extract structured data from various sources. - """ - datasets = { - SourceType.AMAZON_PRODUCT: "gd_l7q7dkf244hwjntr0", - SourceType.AMAZON_PRODUCT_REVIEWS: "gd_le8e811kzy4ggddlq", - SourceType.LINKEDIN_PERSON_PROFILE: "gd_l1viktl72bvl7bjuj0", - SourceType.LINKEDIN_COMPANY_PROFILE: "gd_l1vikfnt1wgvvqz95w", - SourceType.ZOOMINFO_COMPANY_PROFILE: "gd_m0ci4a4ivx3j5l6nx", - SourceType.INSTAGRAM_PROFILES: "gd_l1vikfch901nx3by4", - SourceType.INSTAGRAM_POSTS: "gd_lk5ns7kz21pck8jpis", - SourceType.INSTAGRAM_REELS: "gd_lyclm20il4r5helnj", - SourceType.INSTAGRAM_COMMENTS: "gd_ltppn085pokosxh13", - SourceType.FACEBOOK_POSTS: "gd_lyclm1571iy3mv57zw", - SourceType.FACEBOOK_MARKETPLACE_LISTINGS: "gd_lvt9iwuh6fbcwmx1a", - SourceType.FACEBOOK_COMPANY_REVIEWS: "gd_m0dtqpiu1mbcyc2g86", - SourceType.X_POSTS: "gd_lwxkxvnf1cynvib9co", - SourceType.ZILLOW_PROPERTIES_LISTING: "gd_lfqkr8wm13ixtbd8f5", - SourceType.BOOKING_HOTEL_LISTINGS: "gd_m5mbdl081229ln6t4a", - SourceType.YOUTUBE_VIDEOS: "gd_m5mbdl081229ln6t4a", - } - - dataset_id = datasets[source_type] - - request_data = {"url": url} - if source_type == SourceType.FACEBOOK_COMPANY_REVIEWS and num_of_reviews is not None: - request_data["num_of_reviews"] = str(num_of_reviews) - - trigger_response = requests.post( - "https://api.brightdata.com/datasets/v3/trigger", - params={"dataset_id": dataset_id, "include_errors": "true"}, - headers=client.headers, - json=[request_data], - timeout=30, - ) - - trigger_data = trigger_response.json() - if not trigger_data.get("snapshot_id"): - msg = "No snapshot ID returned from trigger request" - prompt = "Invalid input provided, use search_engine to get the relevant data first" - raise RetryableToolError(msg, additional_prompt_content=prompt) - - snapshot_id = trigger_data["snapshot_id"] - - attempts = 0 - max_attempts = timeout - - while attempts < max_attempts: - try: - snapshot_response = requests.get( - f"https://api.brightdata.com/datasets/v3/snapshot/{snapshot_id}", - params={"format": "json"}, - headers=client.headers, - timeout=30, - ) - - snapshot_data = cast(dict[str, Any], snapshot_response.json()) - - if isinstance(snapshot_data, dict) and snapshot_data.get("status") in ( - "running", - "building", - ): - attempts += 1 - time.sleep(polling_interval) - continue - else: - return snapshot_data - - except Exception: - attempts += 1 - time.sleep(polling_interval) - - msg = f"Timeout after {max_attempts} seconds waiting for {source_type.value} data" - raise TimeoutError(msg) diff --git a/toolkits/brightdata/pyproject.toml b/toolkits/brightdata/pyproject.toml deleted file mode 100644 index 73aaf356b..000000000 --- a/toolkits/brightdata/pyproject.toml +++ /dev/null @@ -1,62 +0,0 @@ -[build-system] -requires = [ "hatchling",] -build-backend = "hatchling.build" - -[project] -name = "arcade_brightdata" -version = "0.4.0" -description = "Search, Crawl and Scrape any site, at scale, without getting blocked" -requires-python = ">=3.10" -dependencies = [ - "arcade-mcp-server>=1.17.0,<2.0.0", - "requests>=2.32.5", -] -[[project.authors]] -name = "meirk-brd" -email = "meirk@brightdata.com" - -[project.scripts] -arcade-brightdata = "arcade_brightdata.__main__:main" -arcade_brightdata = "arcade_brightdata.__main__:main" - -[project.optional-dependencies] -dev = [ - "arcade-mcp[all]>=1.2.0,<2.0.0", - "pytest>=8.3.0,<8.4.0", - "pytest-cov>=4.0.0,<4.1.0", - "pytest-mock>=3.11.1,<3.12.0", - "pytest-asyncio>=0.24.0,<0.25.0", - "mypy>=1.5.1,<1.6.0", - "pre-commit>=3.4.0,<3.5.0", - "tox>=4.11.1,<4.12.0", - "ruff>=0.7.4,<0.8.0", - "types-requests>=2.32.0", -] -# Tell Arcade.dev that this package is a toolkit -[project.entry-points.arcade_toolkits] -toolkit_name = "arcade_brightdata" - -[tool.mypy] -files = [ "arcade_brightdata/**/*.py",] -python_version = "3.10" -disallow_untyped_defs = "True" -disallow_any_unimported = "True" -no_implicit_optional = "True" -check_untyped_defs = "True" -warn_return_any = "True" -warn_unused_ignores = "True" -show_error_codes = "True" -ignore_missing_imports = "True" - -[tool.uv.sources] -arcade-mcp = { path = "../../", editable = true } -arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true } - -[tool.pytest.ini_options] -testpaths = [ "tests",] - -[tool.coverage.report] -skip_empty = true - -[tool.hatch.build.targets.wheel] -packages = [ "arcade_brightdata",] diff --git a/toolkits/brightdata/tests/__init__.py b/toolkits/brightdata/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/brightdata/tests/test_brightdata.py b/toolkits/brightdata/tests/test_brightdata.py deleted file mode 100644 index 98555af08..000000000 --- a/toolkits/brightdata/tests/test_brightdata.py +++ /dev/null @@ -1,418 +0,0 @@ -from os import environ -from unittest.mock import MagicMock as _MagicMock -from unittest.mock import Mock, patch - -import pytest -import requests -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import ToolExecutionError - -from arcade_brightdata.bright_data_client import BrightDataClient -from arcade_brightdata.tools.bright_data_tools import ( - DeviceType, - SourceType, - scrape_as_markdown, - search_engine, - web_data_feed, -) - -BRIGHTDATA_API_KEY = environ.get("TEST_BRIGHTDATA_API_KEY") or "api-key" -BRIGHTDATA_ZONE = environ.get("TEST_BRIGHTDATA_ZONE") or "unblocker" - - -@pytest.fixture -def mock_context(): - context = _MagicMock(spec=Context) - context.get_secret = _MagicMock( - side_effect=lambda key: { - "BRIGHTDATA_API_KEY": BRIGHTDATA_API_KEY, - "BRIGHTDATA_ZONE": BRIGHTDATA_ZONE, - }[key] - ) - return context - - -@pytest.fixture(autouse=True) -def cleanup_engines(): - """Clean up bright data clients after each test to prevent connection leaks.""" - yield - BrightDataClient.clear_cache() - - -class TestBrightDataClient: - def test_get_instance_creates_new_client(self): - client1 = BrightDataClient.create_client("test_key_1", "zone1") - client2 = BrightDataClient.create_client("test_key_2", "zone2") - - assert client1 != client2 - assert client1.api_key == "test_key_1" - assert client1.zone == "zone1" - assert client2.api_key == "test_key_2" - assert client2.zone == "zone2" - - def test_get_instance_returns_cached_client(self): - client1 = BrightDataClient.create_client("test_key", "zone1") - client2 = BrightDataClient.create_client("test_key", "zone1") - - assert client1 is client2 - - def test_clear_cache(self): - client1 = BrightDataClient.create_client("test_key", "zone1") - BrightDataClient.clear_cache() - client2 = BrightDataClient.create_client("test_key", "zone1") - - assert client1 is not client2 - - def test_encode_query(self): - result = BrightDataClient.encode_query("hello world test") - assert result == "hello%20world%20test" - - @patch("requests.post") - def test_make_request_success(self, mock_post): - mock_response = Mock() - mock_response.status_code = 200 - mock_response.text = "Success response" - mock_post.return_value = mock_response - - client = BrightDataClient("test_key", "test_zone") - result = client.make_request({"url": "https://example.com"}) - - assert result == "Success response" - mock_post.assert_called_once() - - @patch("requests.post") - def test_make_request_failure(self, mock_post): - mock_response = Mock() - mock_response.status_code = 400 - mock_response.text = "Bad Request" - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( - "400 Client Error" - ) - mock_post.return_value = mock_response - - client = BrightDataClient("test_key", "test_zone") - - with pytest.raises(requests.exceptions.HTTPError): - client.make_request({"url": "https://example.com"}) - - -class TestScrapeAsMarkdown: - @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient") - def test_scrape_as_markdown_success(self, mock_engine_class, mock_context): - mock_client = Mock() - mock_client.make_request.return_value = "# Test Page\n\nContent here" - mock_engine_class.create_client.return_value = mock_client - - result = scrape_as_markdown(mock_context, "https://example.com") - - assert result == "# Test Page\n\nContent here" - mock_engine_class.create_client.assert_called_once_with( - api_key=BRIGHTDATA_API_KEY, zone=BRIGHTDATA_ZONE - ) - mock_client.make_request.assert_called_once_with({ - "url": "https://example.com", - "zone": BRIGHTDATA_ZONE, - "format": "raw", - "data_format": "markdown", - }) - - -class TestSearchEngine: - @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient") - def test_search_engine_google_basic(self, mock_engine_class, mock_context): - mock_client = Mock() - mock_client.make_request.return_value = "# Search Results\n\nResult 1\nResult 2" - mock_engine_class.create_client.return_value = mock_client - mock_engine_class.encode_query.return_value = "test%20query" - - result = search_engine(mock_context, "test query") - - assert result == "# Search Results\n\nResult 1\nResult 2" - mock_engine_class.create_client.assert_called_once_with( - api_key=BRIGHTDATA_API_KEY, zone=BRIGHTDATA_ZONE - ) - - @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient") - def test_search_engine_bing(self, mock_engine_class, mock_context): - mock_client = Mock() - mock_client.make_request.return_value = "# Bing Results" - mock_engine_class.create_client.return_value = mock_client - mock_engine_class.encode_query.return_value = "test%20query" - - result = search_engine(mock_context, "test query", engine="bing") - - assert result == "# Bing Results" - expected_payload = { - "url": "https://www.bing.com/search?q=test%20query", - "zone": BRIGHTDATA_ZONE, - "format": "raw", - "data_format": "markdown", - } - mock_client.make_request.assert_called_once_with(expected_payload) - - @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient") - def test_search_engine_google_with_parameters(self, mock_engine_class, mock_context): - mock_client = Mock() - mock_client.make_request.return_value = "# Google Results with params" - mock_engine_class.create_client.return_value = mock_client - mock_engine_class.encode_query.side_effect = lambda x: x.replace(" ", "%20") - - result = search_engine( - mock_context, - "test query", - language="en", - country_code="us", - search_type="images", - start=10, - num_results=20, - location="New York", - device=DeviceType.MOBILE, - return_json=True, - ) - - assert result == "# Google Results with params" - call_args = mock_client.make_request.call_args[0][0] - - assert "hl=en" in call_args["url"] - assert "gl=us" in call_args["url"] - assert "tbm=isch" in call_args["url"] - assert "start=10" in call_args["url"] - assert "num=20" in call_args["url"] - assert "brd_mobile=1" in call_args["url"] - assert "brd_json=1" in call_args["url"] - assert call_args["data_format"] == "raw" - - def test_search_engine_invalid_engine(self, mock_context): - with pytest.raises(ToolExecutionError): - search_engine(mock_context, "test query", engine="invalid_engine") - - @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient") - def test_search_engine_google_jobs(self, mock_engine_class, mock_context): - mock_client = Mock() - mock_client.make_request.return_value = "# Job Results" - mock_engine_class.create_client.return_value = mock_client - mock_engine_class.encode_query.return_value = "python%20developer" - - result = search_engine(mock_context, "python developer", search_type="jobs") - - assert result == "# Job Results" - call_args = mock_client.make_request.call_args[0][0] - assert "ibp=htl;jobs" in call_args["url"] - - -class TestWebDataFeed: - @patch("arcade_brightdata.tools.bright_data_tools._extract_structured_data") - @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient") - def test_web_data_feed_success(self, mock_engine_class, mock_extract, mock_context): - mock_client = Mock() - mock_engine_class.create_client.return_value = mock_client - mock_extract.return_value = {"title": "Test Product", "price": "$19.99"} - - result = web_data_feed(mock_context, "amazon_product", "https://amazon.com/dp/B08N5WRWNW") - - expected_json = '{\n "title": "Test Product",\n "price": "$19.99"\n}' - assert result == expected_json - - mock_engine_class.create_client.assert_called_once_with(api_key=BRIGHTDATA_API_KEY) - mock_extract.assert_called_once_with( - client=mock_client, - source_type=SourceType.AMAZON_PRODUCT, - url="https://amazon.com/dp/B08N5WRWNW", - num_of_reviews=None, - timeout=600, - polling_interval=1, - ) - - @patch("arcade_brightdata.tools.bright_data_tools._extract_structured_data") - @patch("arcade_brightdata.tools.bright_data_tools.BrightDataClient") - def test_web_data_feed_with_reviews(self, mock_engine_class, mock_extract, mock_context): - mock_client = Mock() - mock_engine_class.create_client.return_value = mock_client - mock_extract.return_value = [{"review": "Great product!", "rating": 5}] - - result = web_data_feed( - mock_context, - "facebook_company_reviews", - "https://facebook.com/company", - num_of_reviews=50, - timeout=300, - polling_interval=2, - ) - - expected_json = '[\n {\n "review": "Great product!",\n "rating": 5\n }\n]' - assert result == expected_json - - mock_extract.assert_called_once_with( - client=mock_client, - source_type=SourceType.FACEBOOK_COMPANY_REVIEWS, - url="https://facebook.com/company", - num_of_reviews=50, - timeout=300, - polling_interval=2, - ) - - -class TestExtractStructuredData: - @patch("requests.get") - @patch("requests.post") - def test_extract_structured_data_success(self, mock_post, mock_get): - from arcade_brightdata.tools.bright_data_tools import _extract_structured_data - - client = BrightDataClient("test_key", "test_zone") - - mock_trigger_response = Mock() - mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"} - mock_post.return_value = mock_trigger_response - - mock_snapshot_response = Mock() - mock_snapshot_response.json.return_value = {"data": "extracted_data"} - mock_get.return_value = mock_snapshot_response - - result = _extract_structured_data( - client=client, - source_type=SourceType.AMAZON_PRODUCT, - url="https://amazon.com/dp/TEST", - timeout=10, - polling_interval=0.1, - ) - - assert result == {"data": "extracted_data"} - - mock_post.assert_called_once() - trigger_call = mock_post.call_args - assert "gd_l7q7dkf244hwjntr0" in str(trigger_call) # Amazon product dataset ID - - mock_get.assert_called_once() - snapshot_call = mock_get.call_args - assert "snap_123" in str(snapshot_call) - - @patch("requests.get") - @patch("requests.post") - def test_extract_structured_data_with_polling(self, mock_post, mock_get): - from arcade_brightdata.tools.bright_data_tools import _extract_structured_data - - client = BrightDataClient("test_key", "test_zone") - - mock_trigger_response = Mock() - mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"} - mock_post.return_value = mock_trigger_response - - running_response = Mock() - running_response.json.return_value = {"status": "running"} - - complete_response = Mock() - complete_response.json.return_value = {"data": "final_data"} - - mock_get.side_effect = [running_response, complete_response] - - result = _extract_structured_data( - client=client, - source_type=SourceType.LINKEDIN_PERSON_PROFILE, - url="https://linkedin.com/in/test", - timeout=10, - polling_interval=0.1, - ) - - assert result == {"data": "final_data"} - assert mock_get.call_count == 2 - - @patch("requests.post") - def test_extract_structured_data_invalid_source_type(self, mock_post): - from arcade_brightdata.tools.bright_data_tools import _extract_structured_data - - client = BrightDataClient("test_key", "test_zone") - - # Create a mock SourceType that doesn't exist in the datasets dict - class InvalidSourceType: - value = "invalid_source" - - with pytest.raises(KeyError): - _extract_structured_data( - client=client, source_type=InvalidSourceType(), url="https://example.com" - ) - - @patch("requests.get") - @patch("requests.post") - def test_extract_structured_data_no_snapshot_id(self, mock_post, mock_get): - from arcade_brightdata.tools.bright_data_tools import _extract_structured_data - - client = BrightDataClient("test_key", "test_zone") - - # Mock trigger response without snapshot_id - mock_trigger_response = Mock() - mock_trigger_response.json.return_value = {} - mock_post.return_value = mock_trigger_response - - with pytest.raises(Exception) as exc_info: - _extract_structured_data( - client=client, - source_type=SourceType.AMAZON_PRODUCT, - url="https://amazon.com/dp/TEST", - ) - - assert "No snapshot ID returned from trigger request" in str(exc_info.value) - - @patch("requests.get") - @patch("requests.post") - @patch("time.sleep") - def test_extract_structured_data_timeout(self, mock_sleep, mock_post, mock_get): - from arcade_brightdata.tools.bright_data_tools import _extract_structured_data - - client = BrightDataClient("test_key", "test_zone") - - # Mock trigger response - mock_trigger_response = Mock() - mock_trigger_response.json.return_value = {"snapshot_id": "snap_123"} - mock_post.return_value = mock_trigger_response - - # Mock snapshot response that always returns running - mock_snapshot_response = Mock() - mock_snapshot_response.json.return_value = {"status": "running"} - mock_get.return_value = mock_snapshot_response - - with pytest.raises(TimeoutError) as exc_info: - _extract_structured_data( - client=client, - source_type=SourceType.AMAZON_PRODUCT, - url="https://amazon.com/dp/TEST", - timeout=2, - polling_interval=0.1, - ) - - assert "Timeout after 2 seconds waiting for amazon_product data" in str(exc_info.value) - - -class TestIntegration: - """Integration tests that test the full flow without mocking internal components.""" - - @patch("requests.post") - def test_scrape_as_markdown_integration(self, mock_post, mock_context): - mock_response = Mock() - mock_response.status_code = 200 - mock_response.text = "# Integration Test\n\nThis is a test page" - mock_post.return_value = mock_response - - result = scrape_as_markdown(mock_context, "https://example.com") - - assert result == "# Integration Test\n\nThis is a test page" - - # Verify the request was made correctly - call_args = mock_post.call_args - assert call_args[1]["headers"]["Authorization"] == f"Bearer {BRIGHTDATA_API_KEY}" - assert "https://api.brightdata.com/request" in str(call_args) - - @patch("requests.post") - def test_search_engine_integration(self, mock_post, mock_context): - mock_response = Mock() - mock_response.status_code = 200 - mock_response.text = "# Search Results\n\n1. First result\n2. Second result" - mock_post.return_value = mock_response - - result = search_engine(mock_context, "test query", engine="google") - - assert result == "# Search Results\n\n1. First result\n2. Second result" - - call_args = mock_post.call_args - payload = call_args[1]["data"] - assert '"url": "https://www.google.com/search?q=test%20query' in payload - assert '"data_format": "markdown"' in payload diff --git a/toolkits/clickhouse/Makefile b/toolkits/clickhouse/Makefile deleted file mode 100644 index 7102fe5ea..000000000 --- a/toolkits/clickhouse/Makefile +++ /dev/null @@ -1,53 +0,0 @@ -.PHONY: help - -help: - @echo "๐Ÿ› ๏ธ clickhouse Commands:\n" - @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - -.PHONY: install -install: ## Install the uv environment and install all packages with dependencies - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras --no-sources - @uv run pre-commit install - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: install-local -install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras - @uv run pre-commit install - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: build -build: clean-build ## Build wheel file using poetry - @echo "๐Ÿš€ Creating wheel file" - uv build - -.PHONY: clean-build -clean-build: ## clean build artifacts - @echo "๐Ÿ—‘๏ธ Cleaning dist directory" - rm -rf dist - -.PHONY: test -test: ## Test the code with pytest - @echo "๐Ÿš€ Testing code: Running pytest" - @uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml - -.PHONY: coverage -coverage: ## Generate coverage report - @echo "coverage report" - coverage report - @echo "Generating coverage report" - coverage html - -.PHONY: bump-version -bump-version: ## Bump the version in the pyproject.toml file by a patch version - @echo "๐Ÿš€ Bumping version in pyproject.toml" - uv version --bump patch - -.PHONY: check -check: ## Run code quality tools. - @echo "๐Ÿš€ Linting code: Running pre-commit" - @uv run pre-commit run -a - @echo "๐Ÿš€ Static type checking: Running mypy" - @uv run mypy --config-file=pyproject.toml diff --git a/toolkits/clickhouse/arcade_clickhouse/__init__.py b/toolkits/clickhouse/arcade_clickhouse/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/clickhouse/arcade_clickhouse/__main__.py b/toolkits/clickhouse/arcade_clickhouse/__main__.py deleted file mode 100644 index 39c445e1c..000000000 --- a/toolkits/clickhouse/arcade_clickhouse/__main__.py +++ /dev/null @@ -1,29 +0,0 @@ -import sys -from typing import cast - -from arcade_mcp_server import MCPApp -from arcade_mcp_server.mcp_app import TransportType - -import arcade_clickhouse - -app = MCPApp( - name="ClickHouse", - instructions=( - "Use this server when you need to interact with ClickHouse to help users " - "query, explore, and manage their ClickHouse databases." - ), -) - -app.add_tools_from_module(arcade_clickhouse) - - -def main() -> None: - transport = sys.argv[1] if len(sys.argv) > 1 else "stdio" - host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1" - port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000 - - app.run(transport=cast(TransportType, transport), host=host, port=port) - - -if __name__ == "__main__": - main() diff --git a/toolkits/clickhouse/arcade_clickhouse/database_engine.py b/toolkits/clickhouse/arcade_clickhouse/database_engine.py deleted file mode 100644 index 75e5a2408..000000000 --- a/toolkits/clickhouse/arcade_clickhouse/database_engine.py +++ /dev/null @@ -1,209 +0,0 @@ -import contextlib -from typing import Any, ClassVar -from urllib.parse import urlparse - -import clickhouse_connect -from arcade_mcp_server.exceptions import RetryableToolError - -MAX_ROWS_RETURNED = 1000 -TEST_QUERY = "SELECT 1" - - -class DatabaseEngine: - _instance: ClassVar[None] = None - _clients: ClassVar[dict[str, Any]] = {} - - @classmethod - async def get_instance(cls, connection_string: str) -> Any: - parsed_url = urlparse(connection_string) - - # Extract connection parameters from the URL - host = parsed_url.hostname or "localhost" - port = parsed_url.port - database = parsed_url.path.lstrip("/") or "default" - username = parsed_url.username - password = parsed_url.password - - # Handle different ClickHouse protocols - # clickhouse-connect only supports HTTP and HTTPS interfaces - if parsed_url.scheme in ["clickhouse+native"]: - # Convert native protocol to HTTP for clickhouse-connect compatibility - # Convert native port 9000 to HTTP port 8123 - port = 8123 if port == 9000 else port or 8123 - interface = "http" - elif parsed_url.scheme in ["clickhouse+https"]: - # For HTTPS protocol - port = port or 8443 - interface = "https" - else: - # For HTTP or unspecified, use port 8123 by default - port = port or 8123 - interface = "http" - - key = f"{interface}://{host}:{port}/{database}" - - if key not in cls._clients: - try: - # Create ClickHouse client - client_args: dict[str, Any] = { - "host": host, - "port": port, - "database": database, - "interface": interface, - } - - if username: - client_args["username"] = username - if password: - client_args["password"] = password - - client = clickhouse_connect.get_client(**client_args) - cls._clients[key] = client - - # Test the connection - client.command(TEST_QUERY) - - except Exception as e: - # Remove failed client from cache - cls._clients.pop(key, None) - raise RetryableToolError( - f"Connection failed: {e}", - developer_message="Connection to ClickHouse failed.", - additional_prompt_content="Check the connection string and try again.", - ) from e - - return cls._clients[key] - - @classmethod - async def get_engine(cls, connection_string: str) -> Any: - client = await cls.get_instance(connection_string) - - class ConnectionContextManager: - def __init__(self, client: Any) -> None: - self.client = client - - async def __aenter__(self) -> Any: - return self.client - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - # Connection cleanup is handled by clickhouse-connect - pass - - return ConnectionContextManager(client) - - @classmethod - async def cleanup(cls) -> None: - """Clean up all cached clients. Call this when shutting down.""" - for client in cls._clients.values(): - with contextlib.suppress(Exception): - client.close() - cls._clients.clear() - - @classmethod - def clear_cache(cls) -> None: - """Clear the client cache without disposing clients. Use with caution.""" - cls._clients.clear() - - @classmethod - def sanitize_query( # noqa: C901 - cls, - select_clause: str, - from_clause: str, - limit: int, - offset: int, - join_clause: str | None, - where_clause: str | None, - having_clause: str | None, - group_by_clause: str | None, - order_by_clause: str | None, - with_clause: str | None, - ) -> tuple[str, dict[str, Any]]: - # Remove the leading keywords from the clauses if they are present - if select_clause.strip().split(" ")[0].upper() == "SELECT": - select_clause = select_clause.strip()[6:] - - if from_clause.strip().split(" ")[0].upper() == "FROM": - from_clause = from_clause.strip()[4:] - - if join_clause and join_clause.strip().split(" ")[0].upper() == "JOIN": - join_clause = join_clause.strip()[4:] - - if where_clause and where_clause.strip().split(" ")[0].upper() == "WHERE": - where_clause = where_clause.strip()[5:] - - if group_by_clause and group_by_clause.strip().split(" ")[0].upper() == "GROUP BY": - group_by_clause = group_by_clause.strip()[8:] - - if order_by_clause and order_by_clause.strip().split(" ")[0].upper() == "ORDER BY": - order_by_clause = order_by_clause.strip()[8:] - - if having_clause and having_clause.strip().split(" ")[0].upper() == "HAVING": - having_clause = having_clause.strip()[6:] - - first_select_word = select_clause.strip().split(" ")[0].upper() - if first_select_word in [ - "INSERT", - "UPDATE", - "DELETE", - "CREATE", - "ALTER", - "DROP", - "TRUNCATE", - "REINDEX", - "VACUUM", - "ANALYZE", - "COMMENT", - "OPTIMIZE", # ClickHouse-specific - "SYSTEM", # ClickHouse-specific - ]: - raise RetryableToolError( - "Only SELECT queries are allowed.", - ) - - if select_clause.strip() == "*": - raise RetryableToolError( - "Do not use * in the select clause. Use a comma separated list of columns you wish to return.", - ) - - if limit > MAX_ROWS_RETURNED: - raise RetryableToolError( - f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.", - ) - - if offset < 0: - raise RetryableToolError( - "Offset must be greater than or equal to 0.", - developer_message="Offset must be greater than or equal to 0.", - ) - - if limit <= 0: - raise RetryableToolError( - "Limit must be greater than 0.", - developer_message="Limit must be greater than 0.", - ) - - # Build query with identifiers directly interpolated, but use parameters for values - parts = [] - if with_clause: - parts.append(f"WITH {with_clause}") - parts.append(f"SELECT {select_clause} FROM {from_clause}") # noqa: S608 - if join_clause: - parts.append(f"JOIN {join_clause}") - if where_clause: - parts.append(f"WHERE {where_clause}") - if group_by_clause: - parts.append(f"GROUP BY {group_by_clause}") - if having_clause: - parts.append(f"HAVING {having_clause}") - if order_by_clause: - parts.append(f"ORDER BY {order_by_clause}") - parts.append("LIMIT :limit OFFSET :offset") - query = " ".join(parts) - - # Only use parameters for values, not identifiers - parameters = { - "limit": limit, - "offset": offset, - } - - return query, parameters diff --git a/toolkits/clickhouse/arcade_clickhouse/tools/__init__.py b/toolkits/clickhouse/arcade_clickhouse/tools/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/clickhouse/arcade_clickhouse/tools/clickhouse.py b/toolkits/clickhouse/arcade_clickhouse/tools/clickhouse.py deleted file mode 100644 index 4c7965596..000000000 --- a/toolkits/clickhouse/arcade_clickhouse/tools/clickhouse.py +++ /dev/null @@ -1,347 +0,0 @@ -from typing import Annotated, Any - -from arcade_mcp_server import Context, tool -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata - -from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine - - -@tool( - requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def discover_schemas( - context: Context, -) -> list[str]: - """Discover all the schemas in the ClickHouse database. - - Note: ClickHouse doesn't have schemas like PostgreSQL, so this returns a default schema name. - """ - # ClickHouse doesn't have schemas like PostgreSQL, but we return a default for compatibility - return ["default"] - - -@tool( - requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def discover_databases( - context: Context, -) -> list[str]: - """Discover all the databases in the ClickHouse database.""" - async with await DatabaseEngine.get_engine( - context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING") - ) as client: - databases = await _get_databases(client) - return databases - - -@tool( - requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def discover_tables( - context: Context, -) -> list[str]: - """Discover all the tables in the ClickHouse database when the list of tables is not known. - - ALWAYS use this tool before any other tool that requires a table name. - """ - async with await DatabaseEngine.get_engine( - context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING") - ) as client: - tables = await _get_tables(client, "default") - return tables - - -@tool( - requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def get_table_schema( - context: Context, - schema_name: Annotated[str, "The schema to get the table schema of"], - table_name: Annotated[str, "The table to get the schema of"], -) -> list[str]: - """ - Get the schema/structure of a ClickHouse table in the ClickHouse database when the schema is not known, and the name of the table is provided. - - This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the tool. - """ - async with await DatabaseEngine.get_engine( - context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING") - ) as client: - return await _get_table_schema(client, "default", table_name) - - -@tool( - requires_secrets=["CLICKHOUSE_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def execute_select_query( - context: Context, - select_clause: Annotated[ - str, - "This is the part of the SQL query that comes after the SELECT keyword wish a comma separated list of columns you wish to return. Do not include the SELECT keyword.", - ], - from_clause: Annotated[ - str, - "This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.", - ], - limit: Annotated[ - int, - "The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.", - ] = 100, - offset: Annotated[ - int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0." - ] = 0, - join_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.", - ] = None, - where_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.", - ] = None, - having_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.", - ] = None, - group_by_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.", - ] = None, - order_by_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.", - ] = None, - with_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.", - ] = None, -) -> list[str]: - """ - You have a connection to a ClickHouse database. - Execute a SELECT query and return the results against the ClickHouse database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed. - - ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the tool to load the schema if not already known. - - The final query will be constructed as follows: - SELECT {select_query_part} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset} - - When running queries, follow these rules which will help avoid errors: - * Never "select *" from a table. Always select the columns you need. - * Always order your results by the most important columns first. If you aren't sure, order by the primary key. - * Always use case-insensitive queries to match strings in the query. - * Always trim strings in the query. - * Prefer LIKE queries over direct string matches or regex queries. - * Only join on columns that are indexed or the primary key. Do not join on arbitrary columns. - * ClickHouse is case-sensitive, so be careful with table and column names. - """ - async with await DatabaseEngine.get_engine( - context.get_secret("CLICKHOUSE_DATABASE_CONNECTION_STRING") - ) as client: - try: - return await _execute_query( - client, - select_clause=select_clause, - from_clause=from_clause, - limit=limit, - offset=offset, - join_clause=join_clause, - where_clause=where_clause, - having_clause=having_clause, - group_by_clause=group_by_clause, - order_by_clause=order_by_clause, - with_clause=with_clause, - ) - except Exception as e: - raise RetryableToolError( - f"Query failed: {e}", - developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.", - additional_prompt_content="Load the database schema or use the tool to discover the tables and try again.", - retry_after_ms=10, - ) from e - - -async def _get_databases(client: Any) -> list[str]: - """Get all the databases in ClickHouse""" - # ClickHouse uses SHOW DATABASES instead of information_schema - result = client.query("SHOW DATABASES") - databases = [row[0] for row in result.result_rows] - - # Filter out system databases - system_databases = { - "system", - "information_schema", - "INFORMATION_SCHEMA", - "default", - "temporary_tables", - "temporary_tables_metadata", - } - databases = [db for db in databases if db not in system_databases] - databases.sort() - - return databases - - -async def _get_tables(client: Any, database_name: str) -> list[str]: - """Get all the tables in the specified ClickHouse database""" - # ClickHouse uses SHOW TABLES FROM database_name - result = client.query(f"SHOW TABLES FROM {database_name}") - tables = [row[0] for row in result.result_rows] - tables.sort() - - return tables - - -async def _get_table_schema(client: Any, database_name: str, table_name: str) -> list[str]: - """Get the schema of a ClickHouse table""" - # ClickHouse uses DESCRIBE TABLE database_name.table_name - result = client.query(f"DESCRIBE TABLE {database_name}.{table_name}") - columns = result.result_rows - - # Get primary key information - # ClickHouse doesn't have traditional primary keys like PostgreSQL - # Instead, it has sorting keys and primary keys that are part of the table engine - try: - pk_result = client.query(f"SHOW CREATE TABLE {database_name}.{table_name}") - if pk_result.result_rows: - create_statement = pk_result.result_rows[0][0] - # Parse the CREATE statement to extract primary key information - primary_keys = _extract_primary_keys_from_create_statement(create_statement) - else: - primary_keys = set() - except Exception: - primary_keys = set() - - results = [] - for column in columns: - column_name = column[ - 0 - ] # ClickHouse DESCRIBE returns: name, type, default_type, default_expression, comment, codec_expression, ttl_expression - column_type = column[1] - - # Build column description - description = f"{column_name}: {column_type}" - - # Add primary key indicator - if column_name in primary_keys: - description += " (PRIMARY KEY)" - - # Add default value if present - if len(column) > 3 and column[3]: # default_expression - description += f" DEFAULT {column[3]}" - - # Add comment if present - if len(column) > 4 and column[4]: # comment - description += f" COMMENT '{column[4]}'" - - results.append(description) - - return results[:MAX_ROWS_RETURNED] - - -def _extract_primary_keys_from_create_statement(create_statement: str) -> set[str]: - """Extract primary key columns from ClickHouse CREATE TABLE statement""" - primary_keys = set() - - # Look for PRIMARY KEY clause - import re - - pk_match = re.search(r"PRIMARY KEY\s*\(([^)]+)\)", create_statement, re.IGNORECASE) - if pk_match: - pk_columns = pk_match.group(1).split(",") - for col in pk_columns: - primary_keys.add(col.strip().strip("`")) - - # Look for ORDER BY clause (which can also indicate primary key) - order_match = re.search(r"ORDER BY\s*\(([^)]+)\)", create_statement, re.IGNORECASE) - if order_match: - order_columns = order_match.group(1).split(",") - for col in order_columns: - primary_keys.add(col.strip().strip("`")) - - return primary_keys - - -async def _execute_query( - client: Any, - select_clause: str, - from_clause: str, - limit: int, - offset: int, - join_clause: str | None, - where_clause: str | None, - having_clause: str | None, - group_by_clause: str | None, - order_by_clause: str | None, - with_clause: str | None, -) -> list[str]: - """Execute a query and return the results.""" - query, parameters = DatabaseEngine.sanitize_query( - select_clause=select_clause, - from_clause=from_clause, - limit=limit, - offset=offset, - join_clause=join_clause, - where_clause=where_clause, - having_clause=having_clause, - group_by_clause=group_by_clause, - order_by_clause=order_by_clause, - with_clause=with_clause, - ) - print(f"Query: {query}") - print(f"Parameters: {parameters}") - - # For clickhouse-connect, we need to substitute parameters manually - # since it doesn't use SQLAlchemy-style parameter binding - formatted_query = query - for param_name, param_value in parameters.items(): - formatted_query = formatted_query.replace(f":{param_name}", str(param_value)) - - result = client.query(formatted_query) - rows = result.result_rows - results = [str(row) for row in rows] - return results[:MAX_ROWS_RETURNED] diff --git a/toolkits/clickhouse/pyproject.toml b/toolkits/clickhouse/pyproject.toml deleted file mode 100644 index 8442433a9..000000000 --- a/toolkits/clickhouse/pyproject.toml +++ /dev/null @@ -1,66 +0,0 @@ -[build-system] -requires = [ "hatchling",] -build-backend = "hatchling.build" - -[project] -name = "arcade_clickhouse" -version = "0.3.0" -description = "Tools to query and explore a ClickHouse database" -requires-python = ">=3.10" -dependencies = [ - "arcade-mcp-server>=1.17.0,<2.0.0", - "clickhouse-connect>=0.7.0", - "pydantic>=2.11.7", - "sqlalchemy>=2.0.41", - "clickhouse-sqlalchemy>=0.2.0", - "greenlet>=3.2.3", - "aiochsa>=0.1.0", - "setuptools>=80.9.0", -] -[[project.authors]] -name = "evantahler" -email = "support@arcade.dev" - -[project.optional-dependencies] -dev = [ - "arcade-mcp[all]>=1.2.0,<2.0.0", - "pytest>=8.3.0,<8.4.0", - "pytest-cov>=4.0.0,<4.1.0", - "pytest-mock>=3.11.1,<3.12.0", - "pytest-asyncio>=0.24.0,<0.25.0", - "mypy>=1.5.1,<1.6.0", - "pre-commit>=3.4.0,<3.5.0", - "tox>=4.11.1,<4.12.0", - "ruff>=0.7.4,<0.8.0", -] - -[project.scripts] -arcade-clickhouse = "arcade_clickhouse.__main__:main" -arcade_clickhouse = "arcade_clickhouse.__main__:main" - -# Use local path sources for arcade libs when working locally -[tool.uv.sources] -arcade-mcp = { path = "../../", editable = true } -arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true } - -[tool.mypy] -files = [ "arcade_clickhouse/**/*.py",] -python_version = "3.10" -disallow_untyped_defs = "True" -disallow_any_unimported = "True" -no_implicit_optional = "True" -check_untyped_defs = "True" -warn_return_any = "True" -warn_unused_ignores = "True" -show_error_codes = "True" -ignore_missing_imports = "True" - -[tool.pytest.ini_options] -testpaths = [ "tests",] -asyncio_default_fixture_loop_scope = "function" - -[tool.coverage.report] -skip_empty = true - -[tool.hatch.build.targets.wheel] -packages = [ "arcade_clickhouse",] diff --git a/toolkits/clickhouse/tests/__init__.py b/toolkits/clickhouse/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/clickhouse/tests/dump.sql b/toolkits/clickhouse/tests/dump.sql deleted file mode 100644 index edb77bcc5..000000000 --- a/toolkits/clickhouse/tests/dump.sql +++ /dev/null @@ -1,369 +0,0 @@ --- ClickHouse test database setup --- This file contains sample data for testing the ClickHouse toolkit --- Create users table -CREATE TABLE IF NOT EXISTS default.users ( - id UInt32, - name String, - email String, - password_hash String, - created_at DateTime, - updated_at DateTime, - status String -) ENGINE = MergeTree() -ORDER BY (id, created_at); --- Create messages table -CREATE TABLE IF NOT EXISTS default.messages ( - id UInt32, - body String, - user_id UInt32, - created_at DateTime, - updated_at DateTime -) ENGINE = MergeTree() -ORDER BY (id, created_at); --- Insert sample data into users table -INSERT INTO default.users ( - id, - name, - email, - password_hash, - created_at, - updated_at, - status - ) -VALUES ( - 1, - 'Alice', - 'alice@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E', - '2024-09-01 20:49:38', - '2024-09-02 03:49:39', - 'active' - ), - ( - 2, - 'Bob', - 'bob@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY', - '2024-09-02 17:49:23', - '2024-09-02 17:49:23', - 'active' - ), - ( - 3, - 'Charlie', - 'charlie@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo', - '2024-09-03 10:30:15', - '2024-09-03 10:30:15', - 'active' - ), - ( - 4, - 'Diana', - 'diana@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123', - '2024-09-04 14:20:30', - '2024-09-04 14:20:30', - 'active' - ), - ( - 5, - 'Evan', - 'evan@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456', - '2024-09-05 09:15:45', - '2024-09-05 09:15:45', - 'active' - ), - ( - 6, - 'Fiona', - 'fiona@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789', - '2024-09-06 16:45:12', - '2024-09-06 16:45:12', - 'active' - ), - ( - 7, - 'George', - 'george@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012', - '2024-09-07 11:30:25', - '2024-09-07 11:30:25', - 'active' - ), - ( - 8, - 'Helen', - 'helen@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345', - '2024-09-08 13:25:40', - '2024-09-08 13:25:40', - 'active' - ), - ( - 9, - 'Ian', - 'ian@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678', - '2024-09-09 08:40:55', - '2024-09-09 08:40:55', - 'active' - ), - ( - 10, - 'Julia', - 'julia@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901', - '2024-09-10 15:55:18', - '2024-09-10 15:55:18', - 'active' - ); --- Insert sample data into messages table -INSERT INTO default.messages (id, body, user_id, created_at, updated_at) -VALUES ( - 1, - 'Hello everyone!', - 1, - '2025-01-10 10:00:00', - '2025-01-10 10:00:00' - ), - ( - 2, - 'How is everyone doing today?', - 1, - '2025-01-10 11:30:00', - '2025-01-10 11:30:00' - ), - ( - 3, - 'Great to see you all here!', - 1, - '2025-01-10 14:15:00', - '2025-01-10 14:15:00' - ), - ( - 4, - 'Hi Alice! Doing well, thanks for asking.', - 2, - '2025-01-10 11:35:00', - '2025-01-10 11:35:00' - ), - ( - 5, - 'Anyone up for a game later?', - 2, - '2025-01-10 16:20:00', - '2025-01-10 16:20:00' - ), - ( - 6, - 'Count me in for the game!', - 3, - '2025-01-10 16:25:00', - '2025-01-10 16:25:00' - ), - ( - 7, - 'What time works for everyone?', - 3, - '2025-01-10 16:30:00', - '2025-01-10 16:30:00' - ), - ( - 8, - 'I can play around 8 PM', - 3, - '2025-01-10 17:00:00', - '2025-01-10 17:00:00' - ), - ( - 9, - '8 PM works for me too!', - 4, - '2025-01-10 17:05:00', - '2025-01-10 17:05:00' - ), - ( - 10, - 'What game should we play?', - 4, - '2025-01-10 17:10:00', - '2025-01-10 17:10:00' - ), - ( - 11, - 'I suggest we try the new arcade game!', - 5, - '2025-01-10 17:15:00', - '2025-01-10 17:15:00' - ), - ( - 12, - 'It has great multiplayer features', - 5, - '2025-01-10 17:20:00', - '2025-01-10 17:20:00' - ), - ( - 13, - 'Perfect timing for a weekend session', - 5, - '2025-01-10 18:00:00', - '2025-01-10 18:00:00' - ), - ( - 26, - 'Just finished setting up the game server!', - 5, - '2025-01-10 20:00:00', - '2025-01-10 20:00:00' - ), - ( - 27, - 'Everyone should be able to connect now', - 5, - '2025-01-10 20:05:00', - '2025-01-10 20:05:00' - ), - ( - 28, - 'I added some custom maps too', - 5, - '2025-01-10 20:10:00', - '2025-01-10 20:10:00' - ), - ( - 29, - 'The graphics look amazing on this new version', - 5, - '2025-01-10 20:15:00', - '2025-01-10 20:15:00' - ), - ( - 30, - 'Hope you all enjoy the new features', - 5, - '2025-01-10 20:20:00', - '2025-01-10 20:20:00' - ), - ( - 31, - 'I also set up a leaderboard system', - 5, - '2025-01-10 20:25:00', - '2025-01-10 20:25:00' - ), - ( - 32, - 'We can track high scores now', - 5, - '2025-01-10 20:30:00', - '2025-01-10 20:30:00' - ), - ( - 33, - 'The game supports up to 8 players simultaneously', - 5, - '2025-01-10 20:35:00', - '2025-01-10 20:35:00' - ), - ( - 34, - 'I tested it earlier and it runs smoothly', - 5, - '2025-01-10 20:40:00', - '2025-01-10 20:40:00' - ), - ( - 35, - 'Cannot wait to see everyone online tonight!', - 5, - '2025-01-10 20:45:00', - '2025-01-10 20:45:00' - ), - ( - 14, - 'Sounds like fun! I love arcade games.', - 6, - '2025-01-10 18:05:00', - '2025-01-10 18:05:00' - ), - ( - 15, - 'Should I bring snacks?', - 6, - '2025-01-10 18:10:00', - '2025-01-10 18:10:00' - ), - ( - 16, - 'Snacks are always welcome!', - 7, - '2025-01-10 18:15:00', - '2025-01-10 18:15:00' - ), - ( - 17, - 'I can bring some drinks', - 7, - '2025-01-10 18:20:00', - '2025-01-10 18:20:00' - ), - ( - 18, - 'This is going to be awesome', - 7, - '2025-01-10 19:00:00', - '2025-01-10 19:00:00' - ), - ( - 19, - 'I agree! Cannot wait for the game night.', - 8, - '2025-01-10 19:05:00', - '2025-01-10 19:05:00' - ), - ( - 20, - 'Should we set up a Discord call?', - 8, - '2025-01-10 19:10:00', - '2025-01-10 19:10:00' - ), - ( - 21, - 'Discord would be perfect for voice chat', - 9, - '2025-01-10 19:15:00', - '2025-01-10 19:15:00' - ), - ( - 22, - 'I will create a server for us', - 9, - '2025-01-10 19:20:00', - '2025-01-10 19:20:00' - ), - ( - 23, - 'Link will be shared in a few minutes', - 9, - '2025-01-10 19:25:00', - '2025-01-10 19:25:00' - ), - ( - 24, - 'Thanks Ian! You are the best.', - 10, - '2025-01-10 19:30:00', - '2025-01-10 19:30:00' - ), - ( - 25, - 'See you all at 8 PM!', - 10, - '2025-01-10 19:35:00', - '2025-01-10 19:35:00' - ); diff --git a/toolkits/clickhouse/tests/test_clickhouse.py b/toolkits/clickhouse/tests/test_clickhouse.py deleted file mode 100644 index bad0a42d4..000000000 --- a/toolkits/clickhouse/tests/test_clickhouse.py +++ /dev/null @@ -1,195 +0,0 @@ -import os -from os import environ -from unittest.mock import MagicMock - -import pytest -import pytest_asyncio -from arcade_clickhouse.tools.clickhouse import ( - DatabaseEngine, - discover_schemas, - discover_tables, - execute_select_query, - get_table_schema, -) -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import RetryableToolError - -CLICKHOUSE_DATABASE_CONNECTION_STRING = ( - environ.get("TEST_CLICKHOUSE_DATABASE_CONNECTION_STRING") - or "clickhouse+native://localhost:9000/default" -) - - -@pytest.fixture -def mock_context(): - context = MagicMock(spec=Context) - context.get_secret = MagicMock(return_value=CLICKHOUSE_DATABASE_CONNECTION_STRING) - return context - - -# before the tests, restore the database from the dump -@pytest_asyncio.fixture(autouse=True) -async def restore_database(): - import clickhouse_connect - - # Create client for database setup - client = clickhouse_connect.get_client(host="localhost", port=8123) - - # Clear existing tables first to avoid duplicates - client.command("DROP TABLE IF EXISTS default.messages") - client.command("DROP TABLE IF EXISTS default.users") - - # Read and execute the dump file - with open(f"{os.path.dirname(__file__)}/dump.sql") as f: - queries = f.read().split(";") - for query in queries: - if query.strip(): - client.command(query) - - client.close() - - -@pytest_asyncio.fixture(autouse=True) -async def cleanup_engines(): - """Clean up database engines after each test to prevent connection leaks.""" - yield - # Clean up all cached engines after each test - await DatabaseEngine.cleanup() - - -@pytest.mark.asyncio -async def test_discover_schemas(mock_context) -> None: - assert await discover_schemas(mock_context) == ["default"] - - -@pytest.mark.asyncio -async def test_discover_tables(mock_context) -> None: - tables = await discover_tables(mock_context) - assert sorted(tables) == ["messages", "users"] - - -@pytest.mark.asyncio -async def test_get_table_schema(mock_context) -> None: - users_schema = await get_table_schema(mock_context, "default", "users") - expected_users = [ - "id: UInt32 (PRIMARY KEY)", - "name: String", - "email: String", - "password_hash: String", - "created_at: DateTime (PRIMARY KEY)", - "updated_at: DateTime", - "status: String", - ] - assert users_schema == expected_users - - messages_schema = await get_table_schema(mock_context, "default", "messages") - expected_messages = [ - "id: UInt32 (PRIMARY KEY)", - "body: String", - "user_id: UInt32", - "created_at: DateTime (PRIMARY KEY)", - "updated_at: DateTime", - ] - assert messages_schema == expected_messages - - -@pytest.mark.asyncio -async def test_execute_select_query(mock_context) -> None: - # Test specific user query with limit - result1 = await execute_select_query( - mock_context, - select_clause="id, name, email", - from_clause="users", - where_clause="id = 1", - limit=1, - ) - assert result1 == ["(1, 'Alice', 'alice@example.com')"] - - # Test query with offset - result2 = await execute_select_query( - mock_context, - select_clause="id, name, email", - from_clause="users", - order_by_clause="id", - limit=1, - offset=1, - ) - assert result2 == ["(2, 'Bob', 'bob@example.com')"] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_keywords(mock_context) -> None: - result = await execute_select_query( - mock_context, - select_clause="SELECT id, name, email", - from_clause="FROM users", - limit=1, - ) - assert result == ["(1, 'Alice', 'alice@example.com')"] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_join(mock_context) -> None: - result = await execute_select_query( - mock_context, - select_clause="u.id, u.name, u.email, m.id, m.body", - from_clause="users u", - join_clause="messages m ON u.id = m.user_id", - limit=1, - ) - assert result == ["(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')"] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_group_by(mock_context) -> None: - result = await execute_select_query( - mock_context, - select_clause="u.name, COUNT(m.id) AS message_count", - from_clause="messages m", - join_clause="users u ON m.user_id = u.id", - group_by_clause="u.name", - order_by_clause="message_count DESC", - limit=2, - ) - assert result == [ - "('Evan', 13)", - "('Alice', 3)", - ] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_no_results(mock_context) -> None: - # does not raise an error - assert ( - await execute_select_query( - mock_context, - select_clause="id, name, email", - from_clause="users", - where_clause="id = 9999999999", - ) - == [] - ) - - -@pytest.mark.asyncio -async def test_execute_select_query_with_problem(mock_context) -> None: - # 'foo' is not a valid id - with pytest.raises(RetryableToolError) as e: - await execute_select_query( - mock_context, - select_clause="*", - from_clause="users", - where_clause="id = 'foo'", - ) - assert "Do not use * in the select clause" in str(e.value) - - -@pytest.mark.asyncio -async def test_execute_select_query_rejects_non_select(mock_context) -> None: - with pytest.raises(RetryableToolError) as e: - await execute_select_query( - mock_context, - select_clause="INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')", - from_clause="users", - ) - assert "Only SELECT queries are allowed" in str(e.value) diff --git a/toolkits/clickhouse/tests/test_setup.sh b/toolkits/clickhouse/tests/test_setup.sh deleted file mode 100755 index 583aa4081..000000000 --- a/toolkits/clickhouse/tests/test_setup.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -docker run -d --name some-clickhouse-server --ulimit nofile=262144:262144 -p 8123:8123 -p 8443:8443 -p 9000:9000 yandex/clickhouse-server diff --git a/toolkits/linkedin/.pre-commit-config.yaml b/toolkits/linkedin/.pre-commit-config.yaml deleted file mode 100644 index bc01e2436..000000000 --- a/toolkits/linkedin/.pre-commit-config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -files: ^.*/linkedin/.* -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.4.0" - hooks: - - id: check-case-conflict - - id: check-merge-conflict - - id: check-toml - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 - hooks: - - id: ruff - args: [--fix] - - id: ruff-format diff --git a/toolkits/linkedin/.ruff.toml b/toolkits/linkedin/.ruff.toml deleted file mode 100644 index f1aed90fc..000000000 --- a/toolkits/linkedin/.ruff.toml +++ /dev/null @@ -1,46 +0,0 @@ -target-version = "py310" -line-length = 100 -fix = true - -[lint] -select = [ - # flake8-2020 - "YTT", - # flake8-bandit - "S", - # flake8-bugbear - "B", - # flake8-builtins - "A", - # flake8-comprehensions - "C4", - # flake8-debugger - "T10", - # flake8-simplify - "SIM", - # isort - "I", - # mccabe - "C90", - # pycodestyle - "E", "W", - # pyflakes - "F", - # pygrep-hooks - "PGH", - # pyupgrade - "UP", - # ruff - "RUF", - # tryceratops - "TRY", -] - -[lint.per-file-ignores] -"*" = ["TRY003", "B904"] -"**/tests/*" = ["S101", "E501"] -"**/evals/*" = ["S101", "E501"] - -[format] -preview = true -skip-magic-trailing-comma = false diff --git a/toolkits/linkedin/LICENSE b/toolkits/linkedin/LICENSE deleted file mode 100644 index dfbb8b76d..000000000 --- a/toolkits/linkedin/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025, Arcade AI - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/toolkits/linkedin/Makefile b/toolkits/linkedin/Makefile deleted file mode 100644 index 0a8969beb..000000000 --- a/toolkits/linkedin/Makefile +++ /dev/null @@ -1,55 +0,0 @@ -.PHONY: help - -help: - @echo "๐Ÿ› ๏ธ github Commands:\n" - @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - -.PHONY: install -install: ## Install the uv environment and install all packages with dependencies - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras --no-sources - @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: install-local -install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras - @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: build -build: clean-build ## Build wheel file using poetry - @echo "๐Ÿš€ Creating wheel file" - uv build - -.PHONY: clean-build -clean-build: ## clean build artifacts - @echo "๐Ÿ—‘๏ธ Cleaning dist directory" - rm -rf dist - -.PHONY: test -test: ## Test the code with pytest - @echo "๐Ÿš€ Testing code: Running pytest" - @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml - -.PHONY: coverage -coverage: ## Generate coverage report - @echo "coverage report" - @uv run --no-sources coverage report - @echo "Generating coverage report" - @uv run --no-sources coverage html - -.PHONY: bump-version -bump-version: ## Bump the version in the pyproject.toml file by a patch version - @echo "๐Ÿš€ Bumping version in pyproject.toml" - uv version --no-sources --bump patch - -.PHONY: check -check: ## Run code quality tools. - @if [ -f .pre-commit-config.yaml ]; then\ - echo "๐Ÿš€ Linting code: Running pre-commit";\ - uv run --no-sources pre-commit run -a;\ - fi - @echo "๐Ÿš€ Static type checking: Running mypy" - @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/linkedin/arcade_linkedin/__init__.py b/toolkits/linkedin/arcade_linkedin/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/linkedin/arcade_linkedin/__main__.py b/toolkits/linkedin/arcade_linkedin/__main__.py deleted file mode 100644 index 01f392d0c..000000000 --- a/toolkits/linkedin/arcade_linkedin/__main__.py +++ /dev/null @@ -1,28 +0,0 @@ -import sys -from typing import cast - -from arcade_mcp_server import MCPApp -from arcade_mcp_server.mcp_app import TransportType - -import arcade_linkedin - -app = MCPApp( - name="LinkedIn", - instructions=( - "Use this server when you need to interact with LinkedIn to help users " - "create and share posts on their LinkedIn profile." - ), -) - -app.add_tools_from_module(arcade_linkedin) - - -def main() -> None: - transport = sys.argv[1] if len(sys.argv) > 1 else "stdio" - host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1" - port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000 - app.run(transport=cast(TransportType, transport), host=host, port=port) - - -if __name__ == "__main__": - main() diff --git a/toolkits/linkedin/arcade_linkedin/tools/__init__.py b/toolkits/linkedin/arcade_linkedin/tools/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/linkedin/arcade_linkedin/tools/constants.py b/toolkits/linkedin/arcade_linkedin/tools/constants.py deleted file mode 100644 index 187eeaed5..000000000 --- a/toolkits/linkedin/arcade_linkedin/tools/constants.py +++ /dev/null @@ -1 +0,0 @@ -LINKEDIN_BASE_URL = "https://api.linkedin.com/v2" diff --git a/toolkits/linkedin/arcade_linkedin/tools/share.py b/toolkits/linkedin/arcade_linkedin/tools/share.py deleted file mode 100644 index 53a6dbaaf..000000000 --- a/toolkits/linkedin/arcade_linkedin/tools/share.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Annotated - -from arcade_mcp_server import Context, tool -from arcade_mcp_server.auth import LinkedIn -from arcade_mcp_server.exceptions import ToolExecutionError -from arcade_mcp_server.metadata import ( - Behavior, - Classification, - Operation, - ServiceDomain, - ToolMetadata, -) - -from arcade_linkedin.tools.utils import _handle_linkedin_api_error, _send_linkedin_request - - -@tool( - requires_auth=LinkedIn( - scopes=["w_member_social"], - ), - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.SOCIAL_MEDIA], - ), - behavior=Behavior( - operations=[Operation.CREATE], - read_only=False, - destructive=False, - idempotent=False, - open_world=True, - ), - ), -) -async def create_text_post( - context: Context, - text: Annotated[str, "The text content of the post"], -) -> Annotated[str, "URL of the shared post"]: - """Share a new text post to LinkedIn.""" - endpoint = "/ugcPosts" - - # The LinkedIn user ID is required to create a post, even though we're using - # the user's access token. - # Arcade Engine gets the current user's info from LinkedIn and automatically - # populates context.authorization.user_info. - # LinkedIn calls the user ID "sub" in their user_info data payload. See: - # https://learn.microsoft.com/en-us/linkedin/consumer/integrations/self-serve/sign-in-with-linkedin-v2#api-request-to-retreive-member-details - user_id = context.authorization.user_info.get("sub") if context.authorization else None - - if not user_id: - raise ToolExecutionError( - "User ID not found.", - developer_message="User ID not found in `context.authorization.user_info.sub`", - ) - - author_id = f"urn:li:person:{user_id}" - payload = { - "author": author_id, - "lifecycleState": "PUBLISHED", - "specificContent": { - "com.linkedin.ugc.ShareContent": { - "shareCommentary": {"text": text}, - "shareMediaCategory": "NONE", - } - }, - "visibility": {"com.linkedin.ugc.MemberNetworkVisibility": "PUBLIC"}, - } - - response = await _send_linkedin_request(context, "POST", endpoint, json_data=payload) - - if response.status_code >= 200 and response.status_code < 300: - share_id = response.json().get("id") - return f"https://www.linkedin.com/feed/update/{share_id}/" - - _handle_linkedin_api_error(response) - - return "" diff --git a/toolkits/linkedin/arcade_linkedin/tools/utils.py b/toolkits/linkedin/arcade_linkedin/tools/utils.py deleted file mode 100644 index fb457ec22..000000000 --- a/toolkits/linkedin/arcade_linkedin/tools/utils.py +++ /dev/null @@ -1,68 +0,0 @@ -import httpx -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import ToolExecutionError - -from arcade_linkedin.tools.constants import LINKEDIN_BASE_URL - - -async def _send_linkedin_request( - context: Context, - method: str, - endpoint: str, - params: dict | None = None, - json_data: dict | None = None, -) -> httpx.Response: - """ - Send an asynchronous request to the LinkedIn API. - - Args: - context: The tool context containing the authorization token. - method: The HTTP method (GET, POST, PUT, DELETE, etc.). - endpoint: The API endpoint path (e.g., "/ugcPosts"). - params: Query parameters to include in the request. - json_data: JSON data to include in the request body. - - Returns: - The response object from the API request. - - Raises: - ToolExecutionError: If the request fails for any reason. - """ - url = f"{LINKEDIN_BASE_URL}{endpoint}" - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - headers = {"Authorization": f"Bearer {token}"} - - async with httpx.AsyncClient() as client: - try: - response = await client.request( - method, url, headers=headers, params=params, json=json_data - ) - response.raise_for_status() - except httpx.RequestError as e: - raise ToolExecutionError(f"Failed to send request to LinkedIn API: {e}") - - return response - - -def _handle_linkedin_api_error(response: httpx.Response) -> None: - """ - Handle errors from the LinkedIn API by mapping common status codes to ToolExecutionErrors. - - Args: - response: The response object from the API request. - - Raises: - ToolExecutionError: If the response contains an error status code. - """ - status_code_map = { - 401: ToolExecutionError("Unauthorized: Invalid or expired token"), - 403: ToolExecutionError("Forbidden: User does not have Spotify Premium"), - 429: ToolExecutionError("Too Many Requests: Rate limit exceeded"), - } - - if response.status_code in status_code_map: - raise status_code_map[response.status_code] - elif response.status_code >= 400: - raise ToolExecutionError(f"Error: {response.status_code} - {response.text}") diff --git a/toolkits/linkedin/conftest.py b/toolkits/linkedin/conftest.py deleted file mode 100644 index 59c0562de..000000000 --- a/toolkits/linkedin/conftest.py +++ /dev/null @@ -1,24 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from arcade_mcp_server import Context - - -@pytest.fixture -def tool_context(): - """Fixture for the tool Context with mock authorization.""" - context = MagicMock(spec=Context) - authorization = MagicMock() - authorization.token = "test_token" # noqa: S105 - authorization.user_info = {"sub": "test_user"} - context.authorization = authorization - return context - - -@pytest.fixture -def mock_httpx_client(mocker): - """Fixture to mock the httpx.AsyncClient.""" - # Mock the AsyncClient context manager - mock_client = mocker.patch("httpx.AsyncClient", autospec=True) - async_mock_client = mock_client.return_value.__aenter__.return_value - return async_mock_client diff --git a/toolkits/linkedin/evals/eval_linkedin.py b/toolkits/linkedin/evals/eval_linkedin.py deleted file mode 100644 index 9156aa502..000000000 --- a/toolkits/linkedin/evals/eval_linkedin.py +++ /dev/null @@ -1,48 +0,0 @@ -from arcade_core import ToolCatalog -from arcade_evals import ( - EvalRubric, - EvalSuite, - ExpectedToolCall, - SimilarityCritic, - tool_eval, -) - -import arcade_linkedin -from arcade_linkedin.tools.share import create_text_post - -rubric = EvalRubric( - fail_threshold=0.85, - warn_threshold=0.95, -) - - -catalog = ToolCatalog() -catalog.add_module(arcade_linkedin) - - -@tool_eval() -def linkedin_eval_suite() -> EvalSuite: - suite = EvalSuite( - name="LinkedIn Tools Evaluation", - system_message="You are an AI assistant with access to LinkedIn tools. Use them to help the user with their tasks.", - catalog=catalog, - rubric=rubric, - ) - - suite.add_case( - name="Run code", - user_message="post this transcription to linkedin. there may be some things that you need to clean up since it was spoken.: 'It is with great pleasure that I announce that I am now a member of the LinkedIn community! I'd like to thank the LinkedIn team for their support and encouragement in my journey to success. hash tag Y2K'", - expected_tool_calls=[ - ExpectedToolCall( - func=create_text_post, - args={ - "text": "It is with great pleasure that I announce that I am now a member of the LinkedIn community! I'd like to thank the LinkedIn team for their support and encouragement in my journey to success. #Y2K", - }, - ) - ], - critics=[ - SimilarityCritic(critic_field="text", weight=1.0), - ], - ) - - return suite diff --git a/toolkits/linkedin/pyproject.toml b/toolkits/linkedin/pyproject.toml deleted file mode 100644 index e2c8b937b..000000000 --- a/toolkits/linkedin/pyproject.toml +++ /dev/null @@ -1,59 +0,0 @@ -[build-system] -requires = [ "hatchling",] -build-backend = "hatchling.build" - -[project] -name = "arcade_linkedin" -version = "0.3.0" -description = "Arcade.dev LLM tools for LinkedIn" -requires-python = ">=3.10" -dependencies = [ - "arcade-mcp-server>=1.17.0,<2.0.0", - "httpx>=0.27.2,<1.0.0", -] -[[project.authors]] -name = "Arcade" -email = "dev@arcade.dev" - -[project.scripts] -arcade-linkedin = "arcade_linkedin.__main__:main" -arcade_linkedin = "arcade_linkedin.__main__:main" - -[project.optional-dependencies] -dev = [ - "arcade-mcp[all]>=1.2.0,<2.0.0", - "pytest>=8.3.0,<8.4.0", - "pytest-cov>=4.0.0,<4.1.0", - "pytest-asyncio>=0.24.0,<0.25.0", - "pytest-mock>=3.11.1,<3.12.0", - "mypy>=1.5.1,<1.6.0", - "pre-commit>=3.4.0,<3.5.0", - "tox>=4.11.1,<4.12.0", - "ruff>=0.7.4,<0.8.0", -] - -# Use local path sources for arcade libs when working locally -[tool.uv.sources] -arcade-mcp = {path = "../../", editable = true} -arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true } - -[tool.mypy] -files = [ "arcade_linkedin/**/*.py",] -python_version = "3.10" -disallow_untyped_defs = "True" -disallow_any_unimported = "True" -no_implicit_optional = "True" -check_untyped_defs = "True" -warn_return_any = "True" -warn_unused_ignores = "True" -show_error_codes = "True" -ignore_missing_imports = "True" - -[tool.pytest.ini_options] -testpaths = [ "tests",] - -[tool.coverage.report] -skip_empty = true - -[tool.hatch.build.targets.wheel] -packages = [ "arcade_linkedin",] diff --git a/toolkits/linkedin/tests/__init__.py b/toolkits/linkedin/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/linkedin/tests/test_share.py b/toolkits/linkedin/tests/test_share.py deleted file mode 100644 index 590f8c97d..000000000 --- a/toolkits/linkedin/tests/test_share.py +++ /dev/null @@ -1,35 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock - -import pytest -from arcade_mcp_server.exceptions import ToolExecutionError - -from arcade_linkedin.tools.share import create_text_post - - -@pytest.mark.asyncio -async def test_create_text_post_success(tool_context, mock_httpx_client): - """Test successful creation of a LinkedIn text post.""" - # Mock response for a successful post creation - mock_response = MagicMock() - mock_response.status_code = 201 - mock_response.json.return_value = {"id": "1234567890"} - # Ensure the mock is awaited properly - mock_httpx_client.request = AsyncMock(return_value=mock_response) - - post_text = "Hello, LinkedIn!" - result = await create_text_post(tool_context, post_text) - - expected_url = "https://www.linkedin.com/feed/update/1234567890/" - assert result == expected_url - mock_httpx_client.request.assert_called_once() - - -@pytest.mark.asyncio -async def test_create_text_post_no_user_id(tool_context): - """Test error when user ID is not found in the context.""" - # Simulate missing user ID in the context - tool_context.authorization.user_info = {} - - post_text = "Hello, LinkedIn!" - with pytest.raises(ToolExecutionError, match="User ID not found"): - await create_text_post(tool_context, post_text) diff --git a/toolkits/math/.pre-commit-config.yaml b/toolkits/math/.pre-commit-config.yaml deleted file mode 100644 index 3e1fd287d..000000000 --- a/toolkits/math/.pre-commit-config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -files: ^.*/math/.* -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.4.0" - hooks: - - id: check-case-conflict - - id: check-merge-conflict - - id: check-toml - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 - hooks: - - id: ruff - args: [--fix] - - id: ruff-format diff --git a/toolkits/math/.ruff.toml b/toolkits/math/.ruff.toml deleted file mode 100644 index 19364180c..000000000 --- a/toolkits/math/.ruff.toml +++ /dev/null @@ -1,47 +0,0 @@ -target-version = "py310" -line-length = 100 -fix = true - -[lint] -select = [ - # flake8-2020 - "YTT", - # flake8-bandit - "S", - # flake8-bugbear - "B", - # flake8-builtins - "A", - # flake8-comprehensions - "C4", - # flake8-debugger - "T10", - # flake8-simplify - "SIM", - # isort - "I", - # mccabe - "C90", - # pycodestyle - "E", "W", - # pyflakes - "F", - # pygrep-hooks - "PGH", - # pyupgrade - "UP", - # ruff - "RUF", - # tryceratops - "TRY", -] - -[lint.per-file-ignores] -"*" = ["TRY003", "B904"] -"**/tests/*" = ["S101", "E501"] -"**/evals/*" = ["S101", "E501"] - - -[format] -preview = true -skip-magic-trailing-comma = false diff --git a/toolkits/math/LICENSE b/toolkits/math/LICENSE deleted file mode 100644 index dfbb8b76d..000000000 --- a/toolkits/math/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025, Arcade AI - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/toolkits/math/Makefile b/toolkits/math/Makefile deleted file mode 100644 index 0a8969beb..000000000 --- a/toolkits/math/Makefile +++ /dev/null @@ -1,55 +0,0 @@ -.PHONY: help - -help: - @echo "๐Ÿ› ๏ธ github Commands:\n" - @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - -.PHONY: install -install: ## Install the uv environment and install all packages with dependencies - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras --no-sources - @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: install-local -install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras - @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: build -build: clean-build ## Build wheel file using poetry - @echo "๐Ÿš€ Creating wheel file" - uv build - -.PHONY: clean-build -clean-build: ## clean build artifacts - @echo "๐Ÿ—‘๏ธ Cleaning dist directory" - rm -rf dist - -.PHONY: test -test: ## Test the code with pytest - @echo "๐Ÿš€ Testing code: Running pytest" - @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml - -.PHONY: coverage -coverage: ## Generate coverage report - @echo "coverage report" - @uv run --no-sources coverage report - @echo "Generating coverage report" - @uv run --no-sources coverage html - -.PHONY: bump-version -bump-version: ## Bump the version in the pyproject.toml file by a patch version - @echo "๐Ÿš€ Bumping version in pyproject.toml" - uv version --no-sources --bump patch - -.PHONY: check -check: ## Run code quality tools. - @if [ -f .pre-commit-config.yaml ]; then\ - echo "๐Ÿš€ Linting code: Running pre-commit";\ - uv run --no-sources pre-commit run -a;\ - fi - @echo "๐Ÿš€ Static type checking: Running mypy" - @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/math/arcade_math/__init__.py b/toolkits/math/arcade_math/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/math/arcade_math/__main__.py b/toolkits/math/arcade_math/__main__.py deleted file mode 100644 index 3703f55a5..000000000 --- a/toolkits/math/arcade_math/__main__.py +++ /dev/null @@ -1,29 +0,0 @@ -import sys -from typing import cast - -from arcade_mcp_server import MCPApp -from arcade_mcp_server.mcp_app import TransportType - -import arcade_math - -app = MCPApp( - name="Math", - instructions=( - "Use this server when you need to perform mathematical calculations to help users " - "with arithmetic, trigonometry, statistics, exponents, rounding, and other math operations." - ), -) - -app.add_tools_from_module(arcade_math) - - -def main() -> None: - transport = sys.argv[1] if len(sys.argv) > 1 else "stdio" - host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1" - port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000 - - app.run(transport=cast(TransportType, transport), host=host, port=port) - - -if __name__ == "__main__": - main() diff --git a/toolkits/math/arcade_math/tools/__init__.py b/toolkits/math/arcade_math/tools/__init__.py deleted file mode 100644 index bee1b80af..000000000 --- a/toolkits/math/arcade_math/tools/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -from arcade_math.tools.arithmetic import ( - add, - divide, - mod, - multiply, - subtract, - sum_list, - sum_range, -) -from arcade_math.tools.exponents import ( - log, - power, -) -from arcade_math.tools.miscellaneous import ( - abs_val, - factorial, - sqrt, -) -from arcade_math.tools.random import ( - generate_random_float, - generate_random_int, -) -from arcade_math.tools.rational import ( - gcd, - lcm, -) -from arcade_math.tools.rounding import ( - ceil, - floor, - round_num, -) -from arcade_math.tools.statistics import ( - avg, - median, -) -from arcade_math.tools.trigonometry import ( - deg_to_rad, - rad_to_deg, -) - -__all__ = [ - "abs_val", - "add", - "avg", - "ceil", - "deg_to_rad", - "divide", - "factorial", - "floor", - "gcd", - "generate_random_float", - "generate_random_int", - "lcm", - "log", - "median", - "mod", - "multiply", - "power", - "rad_to_deg", - "round_num", - "sqrt", - "subtract", - "sum_list", - "sum_range", -] diff --git a/toolkits/math/arcade_math/tools/arithmetic.py b/toolkits/math/arcade_math/tools/arithmetic.py deleted file mode 100644 index 2ae5f8f8d..000000000 --- a/toolkits/math/arcade_math/tools/arithmetic.py +++ /dev/null @@ -1,161 +0,0 @@ -import decimal -from decimal import Decimal -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - -decimal.getcontext().prec = 100 - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def add( - a: Annotated[str, "The first number as a string"], - b: Annotated[str, "The second number as a string"], -) -> Annotated[str, "The sum of the two numbers as a string"]: - """ - Add two numbers together - """ - # Use Decimal for arbitrary precision - a_decimal = Decimal(a) - b_decimal = Decimal(b) - return str(a_decimal + b_decimal) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def subtract( - a: Annotated[str, "The first number as a string"], - b: Annotated[str, "The second number as a string"], -) -> Annotated[str, "The difference of the two numbers as a string"]: - """ - Subtract two numbers - """ - # Use Decimal for arbitrary precision - a_decimal = Decimal(a) - b_decimal = Decimal(b) - return str(a_decimal - b_decimal) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def multiply( - a: Annotated[str, "The first number as a string"], - b: Annotated[str, "The second number as a string"], -) -> Annotated[str, "The product of the two numbers as a string"]: - """ - Multiply two numbers together - """ - # Use Decimal for arbitrary precision - a_decimal = Decimal(a) - b_decimal = Decimal(b) - return str(a_decimal * b_decimal) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def divide( - a: Annotated[str, "The first number as a string"], - b: Annotated[str, "The second number as a string"], -) -> Annotated[str, "The quotient of the two numbers as a string"]: - """ - Divide two numbers - """ - # Use Decimal for arbitrary precision - a_decimal = Decimal(a) - b_decimal = Decimal(b) - return str(a_decimal / b_decimal) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def sum_list( - numbers: Annotated[list[str], "The list of numbers as strings"], -) -> Annotated[str, "The sum of the numbers in the list as a string"]: - """ - Sum all numbers in a list - """ - # Use Decimal for arbitrary precision - return str(sum([Decimal(n) for n in numbers])) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def sum_range( - start: Annotated[str, "The start of the range to sum as a string"], - end: Annotated[str, "The end of the range to sum as a string"], -) -> Annotated[str, "The sum of the numbers in the list as a string"]: - """ - Sum all numbers from start through end - """ - return str(sum(list(range(int(start), int(end) + 1)))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def mod( - a: Annotated[str, "The dividend as a string"], - b: Annotated[str, "The divisor as a string"], -) -> Annotated[str, "The remainder after dividing a by b as a string"]: - """ - Calculate the remainder (modulus) of one number divided by another - """ - # Use Decimal for arbitrary precision - return str(Decimal(a) % Decimal(b)) diff --git a/toolkits/math/arcade_math/tools/exponents.py b/toolkits/math/arcade_math/tools/exponents.py deleted file mode 100644 index c4658e3d3..000000000 --- a/toolkits/math/arcade_math/tools/exponents.py +++ /dev/null @@ -1,51 +0,0 @@ -import decimal -import math -from decimal import Decimal -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - -decimal.getcontext().prec = 100 - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def log( - a: Annotated[str, "The number to take the logarithm of as a string"], - base: Annotated[str, "The logarithmic base as a string"], -) -> Annotated[str, "The logarithm of the number with the specified base as a string"]: - """ - Calculate the logarithm of a number with a given base - """ - # Use Decimal for arbitrary precision - return str(math.log(Decimal(a), Decimal(base))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def power( - a: Annotated[str, "The base number as a string"], - b: Annotated[str, "The exponent as a string"], -) -> Annotated[str, "The result of raising a to the power of b as a string"]: - """ - Calculate one number raised to the power of another - """ - # Use Decimal for arbitrary precision - return str(Decimal(a) ** Decimal(b)) diff --git a/toolkits/math/arcade_math/tools/miscellaneous.py b/toolkits/math/arcade_math/tools/miscellaneous.py deleted file mode 100644 index 85a4315c2..000000000 --- a/toolkits/math/arcade_math/tools/miscellaneous.py +++ /dev/null @@ -1,70 +0,0 @@ -import decimal -import math -from decimal import Decimal -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - -decimal.getcontext().prec = 100 - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def abs_val( - a: Annotated[str, "The number as a string"], -) -> Annotated[str, "The absolute value of the number as a string"]: - """ - Calculate the absolute value of a number - """ - # Use Decimal for arbitrary precision - return str(abs(Decimal(a))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def factorial( - a: Annotated[str, "The non-negative integer to compute the factorial for as a string"], -) -> Annotated[str, "The factorial of the number as a string"]: - """ - Compute the factorial of a non-negative integer - Returns "1" for "0" - """ - return str(math.factorial(int(a))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def sqrt( - a: Annotated[str, "The number to square root as a string"], -) -> Annotated[str, "The square root of the number as a string"]: - """ - Get the square root of a number - """ - # Use Decimal for arbitrary precision - a_decimal = Decimal(a) - return str(a_decimal.sqrt()) diff --git a/toolkits/math/arcade_math/tools/random.py b/toolkits/math/arcade_math/tools/random.py deleted file mode 100644 index 304823c8c..000000000 --- a/toolkits/math/arcade_math/tools/random.py +++ /dev/null @@ -1,57 +0,0 @@ -import random -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=False, - open_world=False, - ), - ), -) -def generate_random_int( - min_value: Annotated[str, "The minimum value of the random integer as a string"], - max_value: Annotated[str, "The maximum value of the random integer as a string"], - seed: Annotated[ - str | None, - "The seed for the random number generator as a string." - " If None, the current system time is used.", - ] = None, -) -> Annotated[str, "A random integer between min_value and max_value as a string"]: - """Generate a random integer between min_value and max_value (inclusive).""" - if seed is not None: - random.seed(int(seed)) - - return str(random.randint(int(min_value), int(max_value))) # noqa: S311 - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=False, - open_world=False, - ), - ), -) -def generate_random_float( - min_value: Annotated[str, "The minimum value of the random float as a string"], - max_value: Annotated[str, "The maximum value of the random float as a string"], - seed: Annotated[ - str | None, - "The seed for the random number generator as a string." - " If None, the current system time is used.", - ] = None, -) -> Annotated[str, "A random float between min_value and max_value as a string"]: - """Generate a random float between min_value and max_value.""" - if seed is not None: - random.seed(int(seed)) - - return str(random.uniform(float(min_value), float(max_value))) # noqa: S311 diff --git a/toolkits/math/arcade_math/tools/rational.py b/toolkits/math/arcade_math/tools/rational.py deleted file mode 100644 index bf9f60390..000000000 --- a/toolkits/math/arcade_math/tools/rational.py +++ /dev/null @@ -1,49 +0,0 @@ -import math -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def gcd( - a: Annotated[str, "First integer as a string"], - b: Annotated[str, "Second integer as a string"], -) -> Annotated[str, "The greatest common divisor of a and b as a string"]: - """ - Calculate the greatest common divisor (GCD) of two integers. - """ - return str(math.gcd(int(a), int(b))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def lcm( - a: Annotated[str, "First integer as a string"], - b: Annotated[str, "Second integer as a string"], -) -> Annotated[str, "The least common multiple of a and b as a string"]: - """ - Calculate the least common multiple (LCM) of two integers. - Returns "0" if either integer is 0. - """ - a_int, b_int = int(a), int(b) - if a_int == 0 or b_int == 0: - return "0" - return str(abs(a_int * b_int) // math.gcd(a_int, b_int)) diff --git a/toolkits/math/arcade_math/tools/rounding.py b/toolkits/math/arcade_math/tools/rounding.py deleted file mode 100644 index 1630f0b47..000000000 --- a/toolkits/math/arcade_math/tools/rounding.py +++ /dev/null @@ -1,75 +0,0 @@ -import decimal -import math -from decimal import Decimal -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - -decimal.getcontext().prec = 100 - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def ceil( - a: Annotated[str, "The number to round up as a string"], -) -> Annotated[str, "The smallest integer greater than or equal to the number as a string"]: - """ - Return the ceiling of a number - """ - # Use Decimal for arbitrary precision - return str(math.ceil(Decimal(a))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def floor( - a: Annotated[str, "The number to round down as a string"], -) -> Annotated[str, "The largest integer less than or equal to the number as a string"]: - """ - Return the floor of a number - """ - # Use Decimal for arbitrary precision - return str(math.floor(Decimal(a))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def round_num( - value: Annotated[str, "The number to round as a string"], - ndigits: Annotated[str, "The number of digits after the decimal point as a string"], -) -> Annotated[str, "The number rounded to the specified number of digits as a string"]: - """ - Round a number to a specified number of positive digits - """ - ndigits_int = int(ndigits) - if ndigits_int >= 0: - # Use Decimal for arbitrary precision - return str(round(Decimal(value), int(ndigits_int))) - # cast value from str -> float -> int here because rounding with negative - # decimals is only useful for weird math - return str(round(int(float(value)), int(ndigits_int))) diff --git a/toolkits/math/arcade_math/tools/statistics.py b/toolkits/math/arcade_math/tools/statistics.py deleted file mode 100644 index 9e996543c..000000000 --- a/toolkits/math/arcade_math/tools/statistics.py +++ /dev/null @@ -1,53 +0,0 @@ -import decimal -from decimal import Decimal -from statistics import median as stats_median -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - -decimal.getcontext().prec = 100 - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def avg( - numbers: Annotated[list[str], "The list of numbers as strings"], -) -> Annotated[str, "The average (mean) of the numbers in the list as a string"]: - """ - Calculate the average (mean) of a list of numbers. - Returns "0.0" if the list is empty. - """ - # Use Decimal for arbitrary precision - d_numbers = [Decimal(n) for n in numbers] - return str(sum(d_numbers) / len(d_numbers)) if d_numbers else "0.0" - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def median( - numbers: Annotated[list[str], "A list of numbers as strings"], -) -> Annotated[str, "The median value of the numbers in the list as a string"]: - """ - Calculate the median of a list of numbers. - Returns "0.0" if the list is empty. - """ - # Use Decimal for arbitrary precision - d_numbers = [Decimal(n) for n in numbers] - return str(stats_median(d_numbers)) if d_numbers else "0.0" diff --git a/toolkits/math/arcade_math/tools/trigonometry.py b/toolkits/math/arcade_math/tools/trigonometry.py deleted file mode 100644 index e31e129fc..000000000 --- a/toolkits/math/arcade_math/tools/trigonometry.py +++ /dev/null @@ -1,49 +0,0 @@ -import decimal -import math -from decimal import Decimal -from typing import Annotated - -from arcade_mcp_server import tool -from arcade_mcp_server.metadata import Behavior, ToolMetadata - -decimal.getcontext().prec = 100 - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def deg_to_rad( - degrees: Annotated[str, "Angle in degrees as a string"], -) -> Annotated[str, "Angle in radians as a string"]: - """ - Convert an angle from degrees to radians. - """ - # Use Decimal for arbitrary precision - return str(math.radians(Decimal(degrees))) - - -@tool( - metadata=ToolMetadata( - behavior=Behavior( - read_only=True, - destructive=False, - idempotent=True, - open_world=False, - ), - ), -) -def rad_to_deg( - radians: Annotated[str, "Angle in radians as a string"], -) -> Annotated[str, "Angle in degrees as a string"]: - """ - Convert an angle from radians to degrees. - """ - # Use Decimal for arbitrary precision - return str(math.degrees(Decimal(radians))) diff --git a/toolkits/math/evals/eval_math_tools.py b/toolkits/math/evals/eval_math_tools.py deleted file mode 100644 index 74c3aba68..000000000 --- a/toolkits/math/evals/eval_math_tools.py +++ /dev/null @@ -1,137 +0,0 @@ -from collections.abc import Callable -from typing import Any - -from arcade_core import ToolCatalog -from arcade_evals import ( - BinaryCritic, - EvalRubric, - EvalSuite, - ExpectedToolCall, - tool_eval, -) - -import arcade_math -from arcade_math.tools.arithmetic import ( - add, - divide, - mod, - multiply, - subtract, - sum_list, - sum_range, -) -from arcade_math.tools.exponents import ( - log, - power, -) -from arcade_math.tools.miscellaneous import ( - abs_val, - factorial, - sqrt, -) -from arcade_math.tools.rational import ( - gcd, - lcm, -) -from arcade_math.tools.rounding import ( - ceil, - floor, - round_num, -) -from arcade_math.tools.statistics import ( - avg, - median, -) -from arcade_math.tools.trigonometry import ( - deg_to_rad, - rad_to_deg, -) - -# Type alias for test case tuples: (function, prompt_template, params) -TestCase = tuple[Callable[..., Any], str, dict[str, Any]] - -# Evaluation rubric -rubric = EvalRubric( - fail_threshold=0.85, - warn_threshold=0.95, -) - - -catalog = ToolCatalog() -catalog.add_module(arcade_math) - - -@tool_eval() -def math_eval_suite() -> EvalSuite: - suite = EvalSuite( - name="Math Tools Evaluation", - system_message="You're an AI assistant with access to math tools. Use them to help the user with their math-related tasks.", - catalog=catalog, - rubric=rubric, - ) - - list_param = ["1", "2", "3", "4", "5"] - funcs_to_expression_and_params: list[TestCase] = [ - # unary - (sqrt, "What's the square root of {a}?", {"a": "25"}), - (abs_val, "What's the absolute value of {a}?", {"a": "-10"}), - (factorial, "What's the factorial of {a}?", {"a": "5"}), - (deg_to_rad, "Convert {degrees} from degrees to radians", {"degrees": "180"}), - (rad_to_deg, "Convert {radians} from radias to degrees", {"radians": "3.14"}), - (ceil, "Compute the ceiling of {a}", {"a": "3.14"}), - (floor, "Compute the floor of {a}", {"a": "3.14"}), - # binary - (add, "Add {a} and {b}", {"a": "12345", "b": "987654321"}), - (subtract, "Subtract {b} from {a}", {"a": "987654321", "b": "12345"}), - (multiply, "Multiply {a} and {b}", {"a": "12345", "b": "567890"}), - (divide, "What is {a} divided by {b}?", {"a": "1234123479", "b": "123"}), - ( - sum_range, - "What's the sum of all numbers from {start} to {end}?", - {"start": "10", "end": "345"}, - ), - (mod, "What's the remainder of dividing {a} by {b}?", {"a": "234", "b": "17"}), - (power, "Raise {a} to the power of {b}", {"a": "2", "b": "8"}), - (log, "What's the logarithm of {a} with base {base}?", {"a": "8", "base": "2"}), - ( - round_num, - "Round {value} to {ndigits} decimal places", - {"value": "12.23746234", "ndigits": "3"}, - ), - (gcd, "Find the greatest common divisor of {a} and {b}", {"a": "50", "b": "10"}), - (lcm, "FInd the least common multiple of {a} and {b}", {"a": "7", "b": "13"}), - # n-nary - ( - sum_list, - f"Calculate the sum of these numbers: {' '.join(list_param)}", - {"numbers": list_param}, - ), - ( - avg, - f"Find the average of these numbers: {' '.join(list_param)}", - {"numbers": list_param}, - ), - ( - median, - f"Find the median of these numbers: {' '.join(list_param)}", - {"numbers": list_param}, - ), - ] - - for func, expression, params in funcs_to_expression_and_params: - parametrized_expression = expression.format(**params) - num_params = len(params) - suite.add_case( - name=parametrized_expression, - user_message=parametrized_expression, - expected_tool_calls=[ - ExpectedToolCall( - func=func, - args=params, - ) - ], - rubric=rubric, - critics=[BinaryCritic(critic_field=param, weight=1.0 / num_params) for param in params], - ) - - return suite diff --git a/toolkits/math/pyproject.toml b/toolkits/math/pyproject.toml deleted file mode 100644 index 0253036ee..000000000 --- a/toolkits/math/pyproject.toml +++ /dev/null @@ -1,58 +0,0 @@ -[build-system] -requires = [ "hatchling",] -build-backend = "hatchling.build" - -[project] -name = "arcade_math" -version = "1.2.0" -description = "Arcade.dev LLM tools for doing math" -requires-python = ">=3.10" -dependencies = [ - "arcade-mcp-server>=1.17.0,<2.0.0", -] -[[project.authors]] -name = "Arcade" -email = "dev@arcade.dev" - -[project.optional-dependencies] -dev = [ - "arcade-mcp[all]>=1.2.0,<2.0.0", - "pytest>=8.3.0,<8.4.0", - "pytest-cov>=4.0.0,<4.1.0", - "pytest-asyncio>=0.24.0,<0.25.0", - "pytest-mock>=3.11.1,<3.12.0", - "mypy>=1.5.1,<1.6.0", - "pre-commit>=3.4.0,<3.5.0", - "tox>=4.11.1,<4.12.0", - "ruff>=0.7.4,<0.8.0", -] - -[project.scripts] -arcade-math = "arcade_math.__main__:main" -arcade_math = "arcade_math.__main__:main" - -# Use local path sources for arcade libs when working locally -[tool.uv.sources] -arcade-mcp = {path = "../../", editable = true} -arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true } - -[tool.mypy] -files = [ "arcade_math/**/*.py",] -python_version = "3.10" -disallow_untyped_defs = "True" -disallow_any_unimported = "True" -no_implicit_optional = "True" -check_untyped_defs = "True" -warn_return_any = "True" -warn_unused_ignores = "True" -show_error_codes = "True" -ignore_missing_imports = "True" - -[tool.pytest.ini_options] -testpaths = [ "tests",] - -[tool.coverage.report] -skip_empty = true - -[tool.hatch.build.targets.wheel] -packages = [ "arcade_math",] diff --git a/toolkits/math/tests/__init__.py b/toolkits/math/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/math/tests/test_arithmetic.py b/toolkits/math/tests/test_arithmetic.py deleted file mode 100644 index a9f7617e3..000000000 --- a/toolkits/math/tests/test_arithmetic.py +++ /dev/null @@ -1,147 +0,0 @@ -import pytest -from arcade_mcp_server.exceptions import ToolExecutionError - -from arcade_math.tools.arithmetic import ( - add, - divide, - mod, - multiply, - subtract, - sum_list, - sum_range, -) - - -@pytest.mark.parametrize( - "a, b, expected", - [ - ("1", "2", "3"), - ("-1", "1", "0"), - ("0.5", "10.9", "11.4"), - # Big ints - ("12345678901234567890", "9876543210987654321", "22222222112222222211"), - # Big floats - ( - "12345678901234567890.120", - "9876543210987654321.987", - "22222222112222222212.107", - ), - ], -) -def test_add(a, b, expected): - assert add(a, b) == expected - - -@pytest.mark.parametrize( - "a, b, expected", - [ - ("1", "2", "-1"), - ("-1", "1", "-2"), - ("0.5", "10.9", "-10.4"), - # Big ints - ("12345678901234567890", "12323456679012345668", "22222222222222222"), - # Big floats - ( - "12345678901234567890.120", - "12343557689113355768.9079", - "2121212121212121.2121", - ), - ], -) -def test_subtract(a, b, expected): - assert subtract(a, b) == expected - - -@pytest.mark.parametrize( - "a, b, expected", - [ - ("-1", "2", "-2"), - ("-10", "0", "-0"), - ("0.5", "10.9", "5.45"), - # Big ints - ( - "12345678901234567890", - "18000000162000001474380013420000", - "222222222222222222222222222261233060226101083800000", - ), - # Big floats - ( - "12345678901234567890.120", - "12345678901234567890.120", - "152415787532388367504868162811315348393.614400", - ), - ], -) -def test_multiply(a, b, expected): - assert multiply(a, b) == expected - - -@pytest.mark.parametrize( - "a, b, expected", - [ - ("-1", "2", "-0.5"), - ("-10", "1", "-10"), - ( - "0.5", - "10.9", - "0.0458715596330275229357798165137614678899082568807339" - "4495412844036697247706422018348623853211009174312", - ), - # Big ints - ("152407406035740740602050", "12345678901234567890", "12345"), - # Big floats - ( - "152407406035740740603531.400", - "12345678901234567890.120", - "12345", - ), - ], -) -def test_divide(a, b, expected): - assert divide(a, b) == expected - - -def text_zero_division(): - with pytest.raises(ToolExecutionError): - divide("1", "0") - with pytest.raises(ToolExecutionError): - divide("1", "0.0") - with pytest.raises(ToolExecutionError): - divide("1", "0.000000") - - -def test_sum_list(): - assert sum_list(["1", "2", "3", "4", "5", "6"]) == "21" - assert sum_list([]) == "0" - assert sum_list(["-1", "-2", "-3", "-4", "-5", "-6"]) == "-21" - assert sum_list(["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"]) == "2.1" - - -def test_sum_range(): - assert sum_range("8", "2") == "0" - assert sum_range("-8", "2") == "-33" - assert sum_range("8", "-2") == "0" - assert sum_range("2", "3") == "5" - assert sum_range("0", "10") == "55" - with pytest.raises(ToolExecutionError): - sum_range("2", "0.5") - with pytest.raises(ToolExecutionError): - sum_range("-1", "0.5") - with pytest.raises(ToolExecutionError): - sum_range("2.", "0.5") - with pytest.raises(ToolExecutionError): - sum_range("-1", "0.5") - - -def test_mod(): - assert mod("-1", "0.5") == "-0.0" - assert mod("-8", "2") == "-0" - assert mod("0", "10") == "0" - assert mod("2", "0.5") == "0.0" - assert mod("2", "3") == "2" - assert mod("2.", "-0.5") == "0.0" - assert mod("2.1234", "0.6") == "0.3234" - assert mod("2.1234", "1") == "0.1234" - assert mod("2.1234", "3") == "2.1234" - assert mod("8", "-2") == "0" - assert mod("8", "2") == "0" diff --git a/toolkits/math/tests/test_exponents.py b/toolkits/math/tests/test_exponents.py deleted file mode 100644 index f76325909..000000000 --- a/toolkits/math/tests/test_exponents.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -from arcade_mcp_server.exceptions import ToolExecutionError - -from arcade_math.tools.exponents import ( - log, - power, -) - - -def test_log(): - assert log("8", "2") == "3.0" - assert log("2", "3") == "0.6309297535714574" - assert log("2", "0.5") == "-1.0" - with pytest.raises(ToolExecutionError): - log("-1", "0.5") - with pytest.raises(ToolExecutionError): - log("0", "10") - - -def test_power(): - assert power("-8", "2") == "64" - assert power("0", "10") == "0" - assert ( - power("2", "0.5") == "1.41421356237309504880168872420969807856" - "9671875376948073176679737990732478462107038850387534327641573" - ) - assert power("2", "3") == "8" - assert ( - power("2.", "-0.5") == "0.707106781186547524400844362104849039" - "2848359376884740365883398689953662392310535194251937671638207864" - ) - assert ( - power("2.1234", "0.6") == "1.571155202490495156807227174573016145" - "282682479346448636509576776014844055570115193494685328114403375" - ) - assert power("2.1234", "1") == "2.1234" - assert power("2.1234", "3") == "9.574044440904" - assert power("8", "-2") == "0.015625" - assert power("8", "2") == "64" - with pytest.raises(ToolExecutionError): - power("-1", "0.5") diff --git a/toolkits/math/tests/test_miscellaneous.py b/toolkits/math/tests/test_miscellaneous.py deleted file mode 100644 index 2addab6a2..000000000 --- a/toolkits/math/tests/test_miscellaneous.py +++ /dev/null @@ -1,81 +0,0 @@ -import pytest -from arcade_mcp_server.exceptions import ToolExecutionError - -from arcade_math.tools.miscellaneous import ( - abs_val, - factorial, - sqrt, -) - - -def test_abs_val(): - assert abs_val("2") == "2" - assert abs_val("-1") == "1" - assert abs_val("-1.12341234") == "1.12341234" - - -def test_factorial(): - assert factorial("1") == "1" - assert factorial("0") == "1" - assert factorial("-0") == "1" - assert factorial("23") == "25852016738884976640000" - assert factorial("24") == "620448401733239439360000" - assert factorial("10") == "3628800" - with pytest.raises(ToolExecutionError): - factorial("-1") - with pytest.raises(ToolExecutionError): - factorial("-10") - with pytest.raises(ToolExecutionError): - factorial("0.0000") - with pytest.raises(ToolExecutionError): - factorial("-0.0") - with pytest.raises(ToolExecutionError): - factorial("1.0") - with pytest.raises(ToolExecutionError): - factorial("-1.0") - with pytest.raises(ToolExecutionError): - factorial("23.0") - - -def test_sqrt(): - assert sqrt("1") == "1" - assert sqrt("0") == "0" - assert sqrt("-0") == "-0" - assert ( - sqrt("23") == "4.79583152331271954159743806416269391999670704190" - "4129346485309114448257235907464082492191446436918861" - ) - assert ( - sqrt("24") == "4.89897948556635619639456814941178278393189496131" - "3340256865385134501920754914630053079718866209280470" - ) - assert ( - sqrt("10") == "3.16227766016837933199889354443271853371955513932" - "5216826857504852792594438639238221344248108379300295" - ) - assert sqrt("0.0") == "0.0" - assert sqrt("0.0000") == "0.00" - assert sqrt("-0.0") == "-0.0" - assert sqrt("1.0") == "1.0" - assert ( - sqrt("3.14") == "1.772004514666935040199112509753631525073608516" - "162942966817771970290992972348902551472561151153909188" - ) - assert ( - sqrt("0.4") == "0.6324555320336758663997787088865437067439110278" - "650433653715009705585188877278476442688496216758600590" - ) - assert ( - sqrt("10.0") == "3.162277660168379331998893544432718533719555139" - "325216826857504852792594438639238221344248108379300295" - ) - with pytest.raises(ToolExecutionError): - sqrt("-1") - with pytest.raises(ToolExecutionError): - sqrt("-10") - with pytest.raises(ToolExecutionError): - sqrt("-1.0") - with pytest.raises(ToolExecutionError): - sqrt("-1.3") - with pytest.raises(ToolExecutionError): - sqrt("-10.0") diff --git a/toolkits/math/tests/test_rational.py b/toolkits/math/tests/test_rational.py deleted file mode 100644 index 88d76028d..000000000 --- a/toolkits/math/tests/test_rational.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest -from arcade_mcp_server.exceptions import ToolExecutionError - -from arcade_math.tools.rational import ( - gcd, - lcm, -) - - -def test_gcd(): - assert gcd("-15", "-5") == "5" - assert gcd("15", "0") == "15" - assert gcd("15", "-2") == "1" - assert gcd("15", "-0") == "15" - assert gcd("15", "5") == "5" - assert gcd("7", "13") == "1" - assert gcd("-13", "13") == "13" - with pytest.raises(ToolExecutionError): - gcd("15.0", "5.0") - - -def test_lcm(): - assert lcm("-15", "-5") == "15" - assert lcm("15", "0") == "0" - assert lcm("15", "-2") == "30" - assert lcm("15", "-0") == "0" - assert lcm("15", "5") == "15" - assert lcm("7", "13") == "91" - assert lcm("-13", "13") == "13" - with pytest.raises(ToolExecutionError): - lcm("15.0", "5.0") diff --git a/toolkits/math/tests/test_rounding.py b/toolkits/math/tests/test_rounding.py deleted file mode 100644 index 22e252a5a..000000000 --- a/toolkits/math/tests/test_rounding.py +++ /dev/null @@ -1,54 +0,0 @@ -from arcade_math.tools.rounding import ( - ceil, - floor, - round_num, -) - - -def test_ceil(): - assert ceil("1") == "1" - assert ceil("-1") == "-1" - assert ceil("0") == "0" - assert ceil("-0") == "0" - assert ceil("0.0") == "0" - assert ceil("0.0000") == "0" - assert ceil("-0.0") == "0" - assert ceil("1.0") == "1" - assert ceil("-1.0") == "-1" - assert ceil("3.14") == "4" - assert ceil("0.4") == "1" - assert ceil("-1.3") == "-1" - - -def test_floor(): - assert floor("1") == "1" - assert floor("-1") == "-1" - assert floor("0") == "0" - assert floor("-0") == "0" - assert floor("10") == "10" - assert floor("0.0") == "0" - assert floor("0.0000") == "0" - assert floor("-0.0") == "0" - assert floor("1.0") == "1" - assert floor("-1.0") == "-1" - assert floor("3.14") == "3" - assert floor("0.4") == "0" - assert floor("-1.3") == "-2" - - -def test_round_num(): - # TODO(mateo): ok with scientific notatin? ok with negative round digits? - assert round_num("1.2345", "-2") == "0" - assert round_num("1.2345", "-1") == "0" - assert round_num("1.2345", "0") == "1" - assert round_num("1.2345", "1") == "1.2" - assert round_num("1.2345", "2") == "1.23" - assert round_num("1.2345", "3") == "1.234" - assert round_num("1.2345", "8") == "1.23450000" - assert round_num("1.654321", "-2") == "0" - assert round_num("1.654321", "-1") == "0" - assert round_num("1.654321", "0") == "2" - assert round_num("1.654321", "1") == "1.7" - assert round_num("1.654321", "2") == "1.65" - assert round_num("1.654321", "3") == "1.654" - assert round_num("1.654321", "8") == "1.65432100" diff --git a/toolkits/math/tests/test_statistics.py b/toolkits/math/tests/test_statistics.py deleted file mode 100644 index 90a542af3..000000000 --- a/toolkits/math/tests/test_statistics.py +++ /dev/null @@ -1,18 +0,0 @@ -from arcade_math.tools.statistics import ( - avg, - median, -) - - -def test_avg(): - assert avg(["1", "2", "3", "4", "5", "6"]) == "3.5" - assert avg([]) == "0.0" - assert avg(["-1", "-2", "-3", "-4", "-5", "-6"]) == "-3.5" - assert avg(["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"]) == "0.35" - - -def test_median(): - assert median(["1", "2", "3", "4", "5", "6"]) == "3.5" - assert median([]) == "0.0" - assert median(["-1", "-2", "-3", "-4", "-5", "-6"]) == "-3.5" - assert median(["0.1", "0.2", "0.3", "0.3", "0.5", "0.7"]) == "0.3" diff --git a/toolkits/math/tests/test_trigonometry.py b/toolkits/math/tests/test_trigonometry.py deleted file mode 100644 index 98e1f8fe3..000000000 --- a/toolkits/math/tests/test_trigonometry.py +++ /dev/null @@ -1,45 +0,0 @@ -from arcade_math.tools.trigonometry import ( - deg_to_rad, - rad_to_deg, -) - - -def test_deg_to_rad(): - assert deg_to_rad("1") == "0.017453292519943295" - assert deg_to_rad("-1") == "-0.017453292519943295" - assert deg_to_rad("0") == "0.0" - assert deg_to_rad("-0") == "-0.0" - assert deg_to_rad("23") == "0.4014257279586958" - assert deg_to_rad("24") == "0.4188790204786391" - assert deg_to_rad("-10") == "-0.17453292519943295" - assert deg_to_rad("10") == "0.17453292519943295" - assert deg_to_rad("180") == "3.141592653589793" - assert deg_to_rad("0.0") == "0.0" - assert deg_to_rad("0.0000") == "0.0" - assert deg_to_rad("-0.0") == "-0.0" - assert deg_to_rad("1.0") == "0.017453292519943295" - assert deg_to_rad("-1.0") == "-0.017453292519943295" - assert deg_to_rad("23.0") == "0.4014257279586958" - assert deg_to_rad("0.4") == "0.006981317007977318" - assert deg_to_rad("-10.0") == "-0.17453292519943295" - assert deg_to_rad("10.0") == "0.17453292519943295" - - -def test_rad_to_deg(): - assert rad_to_deg("1") == "57.29577951308232" - assert rad_to_deg("-1") == "-57.29577951308232" - assert rad_to_deg("0") == "0.0" - assert rad_to_deg("-0") == "-0.0" - assert rad_to_deg("23") == "1317.8029288008934" - assert rad_to_deg("24") == "1375.0987083139757" - assert rad_to_deg("-10") == "-572.9577951308232" - assert rad_to_deg("10") == "572.9577951308232" - assert rad_to_deg("0.0") == "0.0" - assert rad_to_deg("0.0000") == "0.0" - assert rad_to_deg("-0.0") == "-0.0" - assert rad_to_deg("1.0") == "57.29577951308232" - assert rad_to_deg("-1.0") == "-57.29577951308232" - assert rad_to_deg("3.14") == "179.9087476710785" - assert rad_to_deg("0.4") == "22.918311805232932" - assert rad_to_deg("-10.0") == "-572.9577951308232" - assert rad_to_deg("10.0") == "572.9577951308232" diff --git a/toolkits/mongodb/Makefile b/toolkits/mongodb/Makefile deleted file mode 100644 index 7e2c686e1..000000000 --- a/toolkits/mongodb/Makefile +++ /dev/null @@ -1,53 +0,0 @@ -.PHONY: help - -help: - @echo "๐Ÿ› ๏ธ github Commands:\n" - @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - -.PHONY: install -install: ## Install the uv environment and install all packages with dependencies - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras --no-sources - @uv run pre-commit install - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: install-local -install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras - @uv run pre-commit install - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: build -build: clean-build ## Build wheel file using poetry - @echo "๐Ÿš€ Creating wheel file" - uv build - -.PHONY: clean-build -clean-build: ## clean build artifacts - @echo "๐Ÿ—‘๏ธ Cleaning dist directory" - rm -rf dist - -.PHONY: test -test: ## Test the code with pytest - @echo "๐Ÿš€ Testing code: Running pytest" - @uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml - -.PHONY: coverage -coverage: ## Generate coverage report - @echo "coverage report" - coverage report - @echo "Generating coverage report" - coverage html - -.PHONY: bump-version -bump-version: ## Bump the version in the pyproject.toml file by a patch version - @echo "๐Ÿš€ Bumping version in pyproject.toml" - uv version --bump patch - -.PHONY: check -check: ## Run code quality tools. - @echo "๐Ÿš€ Linting code: Running pre-commit" - @uv run pre-commit run -a - @echo "๐Ÿš€ Static type checking: Running mypy" - @uv run mypy --config-file=pyproject.toml diff --git a/toolkits/mongodb/arcade_mongodb/__init__.py b/toolkits/mongodb/arcade_mongodb/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/mongodb/arcade_mongodb/__main__.py b/toolkits/mongodb/arcade_mongodb/__main__.py deleted file mode 100644 index ece2973bf..000000000 --- a/toolkits/mongodb/arcade_mongodb/__main__.py +++ /dev/null @@ -1,29 +0,0 @@ -import sys -from typing import cast - -from arcade_mcp_server import MCPApp -from arcade_mcp_server.mcp_app import TransportType - -import arcade_mongodb - -app = MCPApp( - name="MongoDB", - instructions=( - "Use this server when you need to interact with MongoDB to help users " - "query, explore, and manage their MongoDB databases and collections." - ), -) - -app.add_tools_from_module(arcade_mongodb) - - -def main() -> None: - transport = sys.argv[1] if len(sys.argv) > 1 else "stdio" - host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1" - port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000 - - app.run(transport=cast(TransportType, transport), host=host, port=port) - - -if __name__ == "__main__": - main() diff --git a/toolkits/mongodb/arcade_mongodb/database_engine.py b/toolkits/mongodb/arcade_mongodb/database_engine.py deleted file mode 100644 index 7eda9b033..000000000 --- a/toolkits/mongodb/arcade_mongodb/database_engine.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Any, ClassVar - -from arcade_mcp_server.exceptions import RetryableToolError -from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase -from pymongo.errors import ServerSelectionTimeoutError - -MAX_RECORDS_RETURNED = 1000 -TEST_QUERY = {"ping": 1} - - -class DatabaseEngine: - _instance: ClassVar[None] = None - _clients: ClassVar[dict[str, AsyncIOMotorClient]] = {} - - @classmethod - async def get_instance(cls, connection_string: str) -> AsyncIOMotorClient: - key = connection_string - if key not in cls._clients: - cls._clients[key] = AsyncIOMotorClient(connection_string) - - # try a simple query to see if the connection is valid - try: - admin_db = cls._clients[key].admin - await admin_db.command(TEST_QUERY) - return cls._clients[key] - except ServerSelectionTimeoutError: - # close and try again - cls._clients[key].close() - cls._clients[key] = AsyncIOMotorClient(connection_string) - - try: - admin_db = cls._clients[key].admin - await admin_db.command(TEST_QUERY) - return cls._clients[key] - except Exception as e: - raise RetryableToolError( - f"Connection failed: {e}", - developer_message="Connection to MongoDB failed.", - additional_prompt_content="Check the connection string and try again.", - ) from e - - @classmethod - async def get_database(cls, connection_string: str, database_name: str) -> Any: - client = await cls.get_instance(connection_string) - - class DatabaseContextManager: - def __init__(self, client: AsyncIOMotorClient, database_name: str) -> None: - self.client = client - self.database_name = database_name - self.database = client[database_name] - - async def __aenter__(self) -> AsyncIOMotorDatabase: - return self.database - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - # Connection cleanup is handled by the client cache - pass - - return DatabaseContextManager(client, database_name) - - @classmethod - async def cleanup(cls) -> None: - """Clean up all cached clients. Call this when shutting down.""" - for client in cls._clients.values(): - client.close() - cls._clients.clear() - - @classmethod - def clear_cache(cls) -> None: - """Clear the client cache without closing clients. Use with caution.""" - cls._clients.clear() - - @classmethod - def sanitize_query_params( - cls, - database_name: str, - collection_name: str, - filter_dict: dict[str, Any] | None, - projection: dict[str, Any] | None, - sort: list[dict[str, Any]] | None, - limit: int, - skip: int, - ) -> tuple[ - str, str, dict[str, Any], dict[str, Any] | None, list[dict[str, Any]] | None, int, int - ]: - if not database_name: - raise RetryableToolError( - "Database name is required.", - developer_message="Database name cannot be empty.", - ) - - if not collection_name: - raise RetryableToolError( - "Collection name is required.", - developer_message="Collection name cannot be empty.", - ) - - if filter_dict is None: - filter_dict = {} - - if limit > MAX_RECORDS_RETURNED: - raise RetryableToolError( - f"Limit is too high. Maximum is {MAX_RECORDS_RETURNED}.", - ) - - if skip < 0: - raise RetryableToolError( - "Skip must be greater than or equal to 0.", - developer_message="Skip must be greater than or equal to 0.", - ) - - if limit <= 0: - raise RetryableToolError( - "Limit must be greater than 0.", - developer_message="Limit must be greater than 0.", - ) - - return database_name, collection_name, filter_dict, projection, sort, limit, skip diff --git a/toolkits/mongodb/arcade_mongodb/tools/__init__.py b/toolkits/mongodb/arcade_mongodb/tools/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/mongodb/arcade_mongodb/tools/mongodb.py b/toolkits/mongodb/arcade_mongodb/tools/mongodb.py deleted file mode 100644 index 8f31841a1..000000000 --- a/toolkits/mongodb/arcade_mongodb/tools/mongodb.py +++ /dev/null @@ -1,434 +0,0 @@ -import json -from typing import Annotated, Any - -from arcade_mcp_server import Context, tool -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata - -from ..database_engine import MAX_RECORDS_RETURNED, DatabaseEngine -from .utils import ( - _infer_schema_from_docs, - _parse_json_list_parameter, - _parse_json_parameter, - _serialize_document, -) - -# class UserStatus(str, Enum): -# """User status enumeration.""" - -# ACTIVE = "active" -# INACTIVE = "inactive" -# BANNED = "banned" - - -@tool( - requires_secrets=["MONGODB_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def discover_databases( - context: Context, -) -> list[str]: - """Discover all the databases in the MongoDB instance.""" - client = await DatabaseEngine.get_instance(context.get_secret("MONGODB_CONNECTION_STRING")) - databases = await client.list_database_names() - # Filter out admin and config databases by default - databases = [db for db in databases if db not in ["admin", "config", "local"]] - return databases - - -@tool( - requires_secrets=["MONGODB_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def discover_collections( - context: Context, - database_name: Annotated[str, "The database name to discover collections in"], -) -> list[str]: - """Discover all the collections in the MongoDB database when the list of collections is not known. - - ALWAYS use this tool before any other tool that requires a collection name. - """ - async with await DatabaseEngine.get_database( - context.get_secret("MONGODB_CONNECTION_STRING"), database_name - ) as db: - collections = await db.list_collection_names() - return list(collections) - - -@tool( - requires_secrets=["MONGODB_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def get_collection_schema( - context: Context, - database_name: Annotated[str, "The database name to get the collection schema of"], - collection_name: Annotated[str, "The collection to get the schema of"], - sample_size: Annotated[ - int, - f"The number of documents to sample for schema discovery (default: {MAX_RECORDS_RETURNED})", - ] = MAX_RECORDS_RETURNED, -) -> dict[str, Any]: - """ - Get the schema/structure of a MongoDB collection by sampling documents. - - Since MongoDB is schema-less, this tool samples a configurable number of documents - to infer the schema structure and data types. - - This tool should ALWAYS be used before executing any query. All collections in the query must be discovered first using the tool. - """ - async with await DatabaseEngine.get_database( - context.get_secret("MONGODB_CONNECTION_STRING"), database_name - ) as db: - collection = db[collection_name] - - # Sample documents at random to infer schema - # Use MongoDB's $sample aggregation to get random documents - sample_docs = [] - async for doc in collection.aggregate([{"$sample": {"size": sample_size}}]): - sample_docs.append(doc) - - if not sample_docs: - return {"message": "Collection is empty", "schema": {}} - - # Infer schema from sampled documents - schema = _infer_schema_from_docs(sample_docs) - - return { - "total_documents_sampled": len(sample_docs), - "sample_size_requested": sample_size, - "schema": schema, - } - - -@tool( - requires_secrets=["MONGODB_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def find_documents( - context: Context, - database_name: Annotated[str, "The database name to query"], - collection_name: Annotated[str, "The collection name to query"], - filter_dict: Annotated[ - str | None, - 'MongoDB filter/query as JSON string. Leave None for no filter (find all documents). Example: \'{"status": "active", "age": {"$gte": 18}}\'', - ] = None, - projection: Annotated[ - str | None, - 'Fields to include/exclude as JSON string. Use 1 to include, 0 to exclude. Example: \'{"name": 1, "email": 1, "_id": 0}\'. Leave None to include all fields.', - ] = None, - sort: Annotated[ - list[str] | None, - 'Sort criteria as list of JSON strings, each containing \'field\' and \'direction\' keys. Use 1 for ascending, -1 for descending. Example: [\'{"field": "name", "direction": 1}\', \'{"field": "created_at", "direction": -1}\']', - ] = None, - limit: Annotated[ - int, - f"The maximum number of documents to return. Default: {MAX_RECORDS_RETURNED}.", - ] = MAX_RECORDS_RETURNED, - skip: Annotated[int, "The number of documents to skip. Default: 0."] = 0, -) -> list[str]: - """ - Find documents in a MongoDB collection. - - ONLY use this tool if you have already loaded the schema of the collection you need to query. - Use the tool to load the schema if not already known. - - Returns a list of JSON strings, where each string represents a document from the collection (tools cannot return complex types). - - When running queries, follow these rules which will help avoid errors: - * Always specify projection to limit fields returned if you don't need all data. - * Always sort your results by the most important fields first. If you aren't sure, sort by '_id'. - * Use appropriate MongoDB query operators for complex filtering ($gte, $lte, $in, $regex, etc.). - * Be mindful of case sensitivity when querying string fields. - * Use indexes when possible (typically on _id and commonly queried fields). - """ - # Initialize variables to avoid UnboundLocalError in exception handler - parsed_filter = None - parsed_projection = None - parsed_sort = None - - try: - # Parse JSON string inputs - parsed_filter = _parse_json_parameter(filter_dict, "filter_dict") - parsed_projection = _parse_json_parameter(projection, "projection") - parsed_sort = _parse_json_list_parameter(sort, "sort") - - ( - database_name, - collection_name, - parsed_filter, - parsed_projection, - parsed_sort, - limit, - skip, - ) = DatabaseEngine.sanitize_query_params( - database_name=database_name, - collection_name=collection_name, - filter_dict=parsed_filter, - projection=parsed_projection, - sort=parsed_sort, - limit=limit, - skip=skip, - ) - - async with await DatabaseEngine.get_database( - context.get_secret("MONGODB_CONNECTION_STRING"), database_name - ) as db: - collection = db[collection_name] - - # Build the query - cursor = collection.find(parsed_filter, parsed_projection) - - if parsed_sort: - # Convert list of dicts to list of tuples for MongoDB sort - sort_tuples = [(str(item["field"]), int(item["direction"])) for item in parsed_sort] - cursor = cursor.sort(sort_tuples) - - cursor = cursor.skip(skip).limit(limit) - - # Execute query and collect results - documents = [] - async for doc in cursor: - # Convert ObjectId and other non-serializable types to strings - doc = _serialize_document(doc) - documents.append(json.dumps(doc)) - - return documents - - except RetryableToolError: - # Re-raise RetryableToolError as-is to preserve JSON validation messages - raise - except Exception as e: - raise RetryableToolError( - f"Query failed: {e}", - developer_message=f"Query failed with parameters: database_name={database_name}, collection_name={collection_name}, filter_dict={parsed_filter}, projection={parsed_projection}, sort={parsed_sort}, limit={limit}, skip={skip}.", - additional_prompt_content="Load the collection schema or use the tool to discover the collections and try again.", - retry_after_ms=10, - ) from e - - -@tool( - requires_secrets=["MONGODB_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def count_documents( - context: Context, - database_name: Annotated[str, "The database name to query"], - collection_name: Annotated[str, "The collection name to query"], - filter_dict: Annotated[ - str | None, - 'MongoDB filter/query as JSON string. Leave None for no filter (count all documents). Example: \'{"status": "active"}\'', - ] = None, -) -> int: - """Count documents in a MongoDB collection matching the given filter.""" - parsed_filter = None - - try: - # Parse JSON string input - parsed_filter = _parse_json_parameter(filter_dict, "filter_dict") or {} - - async with await DatabaseEngine.get_database( - context.get_secret("MONGODB_CONNECTION_STRING"), database_name - ) as db: - collection = db[collection_name] - - count = await collection.count_documents(parsed_filter) - return int(count) - - except RetryableToolError: - # Re-raise RetryableToolError as-is to preserve JSON validation messages - raise - except Exception as e: - raise RetryableToolError( - f"Count query failed: {e}", - developer_message=f"Count query failed with parameters: database_name={database_name}, collection_name={collection_name}, filter_dict={parsed_filter}.", - additional_prompt_content="Check the collection name and filter criteria and try again.", - retry_after_ms=10, - ) from e - - -@tool( - requires_secrets=["MONGODB_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def aggregate_documents( - context: Context, - database_name: Annotated[str, "The database name to query"], - collection_name: Annotated[str, "The collection name to query"], - pipeline: Annotated[ - list[str], - 'MongoDB aggregation pipeline as a list of JSON strings, each representing a stage. Example: [\'{"$match": {"status": "active"}}\', \'{"$group": {"_id": "$category", "count": {"$sum": 1}}}\']', - ], - limit: Annotated[ - int, - f"The maximum number of results to return from the aggregation. Default: {MAX_RECORDS_RETURNED}.", - ] = MAX_RECORDS_RETURNED, -) -> list[str]: - """ - Execute a MongoDB aggregation pipeline on a collection. - - ONLY use this tool if you have already loaded the schema of the collection you need to query. - Use the tool to load the schema if not already known. - - Returns a list of JSON strings, where each string represents a result document from the aggregation (tools cannot return complex types). - - Aggregation pipelines allow for complex data processing including: - * $match - filter documents - * $group - group documents and perform calculations - * $project - reshape documents - * $sort - sort documents - * $limit - limit results - * $lookup - join with other collections - * And many more stages - """ - parsed_pipeline = None - - try: - # Parse JSON string inputs - parsed_pipeline = _parse_json_list_parameter(pipeline, "pipeline") - - if parsed_pipeline is None: - raise RetryableToolError( # noqa: TRY301 - "Pipeline cannot be empty", - developer_message="The pipeline parameter is required and cannot be None", - ) - - async with await DatabaseEngine.get_database( - context.get_secret("MONGODB_CONNECTION_STRING"), database_name - ) as db: - collection = db[collection_name] - - # Add limit to pipeline if not already present - pipeline_with_limit = parsed_pipeline.copy() - has_limit = any("$limit" in stage for stage in pipeline_with_limit) - if not has_limit: - pipeline_with_limit.append({"$limit": limit}) - - # Execute aggregation - cursor = collection.aggregate(pipeline_with_limit) - - documents = [] - async for doc in cursor: - # Convert ObjectId and other non-serializable types to strings - doc = _serialize_document(doc) - documents.append(json.dumps(doc)) - - return documents - - except RetryableToolError: - # Re-raise RetryableToolError as-is to preserve JSON validation messages - raise - except Exception as e: - raise RetryableToolError( - f"Aggregation query failed: {e}", - developer_message=f"Aggregation query failed with parameters: database_name={database_name}, collection_name={collection_name}, pipeline={parsed_pipeline}, limit={limit}.", - additional_prompt_content="Check the aggregation pipeline syntax and collection schema, then try again.", - retry_after_ms=10, - ) from e - - -# @tool(requires_secrets=["MONGODB_CONNECTION_STRING"]) -# async def update_user_status( -# context: ToolContext, -# database_name: Annotated[str, "The database name containing the users collection"], -# collection_name: Annotated[str, "The collection name containing user documents"], -# user_id: Annotated[str, "The _id of the user to update"], -# status: Annotated[UserStatus, "The new status for the user"], -# ) -> dict[str, Any]: -# """ -# [CUSTOM TOOL] -# Update the status of a user in the MongoDB collection. - -# This tool updates a user document by setting the status field to the specified value. -# The status must be one of: active, inactive, or banned. - -# Returns information about the update operation including the number of documents modified. -# """ - -# try: -# async with await DatabaseEngine.get_database( -# context.get_secret("MONGODB_CONNECTION_STRING"), database_name -# ) as db: -# collection = db[collection_name] - -# # cast the user_id to int if it looks like an integer -# if isinstance(user_id, str) and user_id.isdigit(): -# user_id = int(user_id) - -# result = await collection.update_one( -# {"_id": user_id}, {"$set": {"status": status.value}} -# ) - -# print(result) - -# if result.matched_count == 0: -# return { -# "success": False, -# "message": f"No user found with _id: {user_id}", -# "matched_count": 0, -# "modified_count": 0, -# } - -# return { -# "success": True, -# "message": f"User status updated to '{status.value}'", -# "user_id": user_id, -# "new_status": status.value, -# "matched_count": result.matched_count, -# "modified_count": result.modified_count, -# } - -# except Exception as e: -# raise RetryableToolError( -# f"Failed to update user status: {e}", -# developer_message=f"Update operation failed with parameters: database_name={database_name}, collection_name={collection_name}, user_id={user_id}, status={status}.", -# additional_prompt_content="Check the database name, collection name, and user ID, then try again.", -# retry_after_ms=10, -# ) from e diff --git a/toolkits/mongodb/arcade_mongodb/tools/utils.py b/toolkits/mongodb/arcade_mongodb/tools/utils.py deleted file mode 100644 index d6f074b09..000000000 --- a/toolkits/mongodb/arcade_mongodb/tools/utils.py +++ /dev/null @@ -1,281 +0,0 @@ -import json -from datetime import datetime -from typing import Any - -from arcade_mcp_server.exceptions import RetryableToolError -from bson import ObjectId - - -def _validate_no_write_operations(obj: Any, parameter_name: str, path: str = "") -> None: - """ - Recursively validate that an object doesn't contain MongoDB write operations. - - Args: - obj: The object to validate - parameter_name: Name of the parameter for error messages - path: Current path in the object (for nested validation) - - Raises: - RetryableToolError: If write operations are detected - """ - # MongoDB write/update operators that should be blocked - WRITE_OPERATORS = { - # Update operators - "$set", - "$unset", - "$inc", - "$mul", - "$rename", - "$min", - "$max", - "$currentDate", - "$addToSet", - "$pop", - "$pull", - "$push", - "$pullAll", - "$each", - "$slice", - "$sort", - "$position", - "$bit", - "$isolated", - # Array update operators - "$", - "$[]", - "$[]", - # Pipeline update operators - "$addFields", - "$replaceRoot", - "$replaceWith", - # Aggregation stages that can modify (in case they're misused) - "$out", - "$merge", - # Other potentially dangerous operators - "$where", # Can execute JavaScript - } - - if isinstance(obj, dict): - for key, value in obj.items(): - current_path = f"{path}.{key}" if path else key - - # Special check for $where operator which can execute JavaScript (check this first) - if key == "$where": - raise RetryableToolError( - f"JavaScript execution operator '$where' not allowed in {parameter_name}", - developer_message=f"Found '$where' operator at path '{current_path}' in parameter '{parameter_name}'. JavaScript execution is not allowed for security reasons.", - additional_prompt_content=f"The {parameter_name} parameter cannot use the $where operator. Use other query operators instead.", - ) - - # Check if this key is a write operator - if key in WRITE_OPERATORS: - raise RetryableToolError( - f"Write operation '{key}' not allowed in {parameter_name}", - developer_message=f"Found write operation '{key}' at path '{current_path}' in parameter '{parameter_name}'. Only read operations are allowed.", - additional_prompt_content=f"The {parameter_name} parameter cannot contain write operations like '{key}'. Use only query/read operations such as $match, $gte, $lte, $in, $regex, etc.", - ) - - # Recursively validate nested objects - _validate_no_write_operations(value, parameter_name, current_path) - - elif isinstance(obj, list): - for i, item in enumerate(obj): - current_path = f"{path}[{i}]" if path else f"[{i}]" - _validate_no_write_operations(item, parameter_name, current_path) - - -def _parse_json_parameter( - json_string: str | None, parameter_name: str, validate_read_only: bool = True -) -> Any | None: - """ - Parse a JSON string parameter with proper error handling and optional write operation validation. - - Args: - json_string: The JSON string to parse (can be None) - parameter_name: Name of the parameter for error messages - validate_read_only: Whether to validate that no write operations are present - - Returns: - Parsed JSON object or None if json_string is None - - Raises: - RetryableToolError: If JSON parsing fails or write operations are detected - """ - if json_string is None: - return None - - try: - parsed_obj = json.loads(json_string) - - # Validate that no write operations are present - if validate_read_only and parsed_obj is not None: - _validate_no_write_operations(parsed_obj, parameter_name) - - except json.JSONDecodeError as e: - raise RetryableToolError( - f"Invalid JSON in {parameter_name}: {e}", - developer_message=f"Failed to parse JSON string for parameter '{parameter_name}': {json_string}. Error: {e}", - additional_prompt_content=f"Please provide valid JSON for the {parameter_name} parameter. Check for proper escaping of quotes and valid JSON syntax.", - ) from e - else: - return parsed_obj - - -def _validate_aggregation_pipeline(pipeline: list[Any], parameter_name: str) -> None: - """ - Validate that an aggregation pipeline only contains read operations. - - Args: - pipeline: The aggregation pipeline to validate - parameter_name: Name of the parameter for error messages - - Raises: - RetryableToolError: If write operations are detected in the pipeline - """ - # MongoDB aggregation stages that can modify data - WRITE_STAGES = { - "$out", - "$merge", # These stages write to collections - } - - # Aggregation stages that are potentially dangerous - DANGEROUS_STAGES = { - "$where", # Can execute JavaScript - } - - for i, stage in enumerate(pipeline): - if isinstance(stage, dict): - for stage_name in stage: - if stage_name in WRITE_STAGES: - raise RetryableToolError( - f"Write stage '{stage_name}' not allowed in {parameter_name}", - developer_message=f"Found write stage '{stage_name}' at pipeline index {i} in parameter '{parameter_name}'. Only read operations are allowed.", - additional_prompt_content=f"The {parameter_name} parameter cannot contain write stages like '{stage_name}'. Use only read stages such as $match, $group, $project, $sort, $limit, etc.", - ) - - if stage_name in DANGEROUS_STAGES: - raise RetryableToolError( - f"Dangerous stage '{stage_name}' not allowed in {parameter_name}", - developer_message=f"Found dangerous stage '{stage_name}' at pipeline index {i} in parameter '{parameter_name}'. JavaScript execution is not allowed for security reasons.", - additional_prompt_content=f"The {parameter_name} parameter cannot use the {stage_name} stage. Use other aggregation stages instead.", - ) - - # Also validate the stage content for write operations - _validate_no_write_operations( - stage[stage_name], f"{parameter_name}[{i}].{stage_name}" - ) - - -def _parse_json_list_parameter( - json_strings: list[str] | None, parameter_name: str, validate_read_only: bool = True -) -> list[Any] | None: - """ - Parse a list of JSON strings with proper error handling and optional write operation validation. - - Args: - json_strings: List of JSON strings to parse (can be None) - parameter_name: Name of the parameter for error messages - validate_read_only: Whether to validate that no write operations are present - - Returns: - List of parsed JSON objects or None if json_strings is None - - Raises: - RetryableToolError: If JSON parsing fails for any string or write operations are detected - """ - if json_strings is None: - return None - - try: - parsed_list = [json.loads(json_str) for json_str in json_strings] - - # Validate that no write operations are present - if validate_read_only and parsed_list is not None: - # Special handling for pipeline parameters - if parameter_name == "pipeline": - _validate_aggregation_pipeline(parsed_list, parameter_name) - else: - # For non-pipeline lists, validate each item - for i, item in enumerate(parsed_list): - _validate_no_write_operations(item, f"{parameter_name}[{i}]") - - except json.JSONDecodeError as e: - raise RetryableToolError( - f"Invalid JSON in {parameter_name}: {e}", - developer_message=f"Failed to parse JSON string list for parameter '{parameter_name}': {json_strings}. Error: {e}", - additional_prompt_content=f"Please provide valid JSON strings for the {parameter_name} parameter. Each string must be valid JSON with proper escaping of quotes.", - ) from e - else: - return parsed_list - - -def _infer_schema_from_docs(docs: list[dict[str, Any]]) -> dict[str, Any]: - """Infer schema structure from a list of documents.""" - schema: dict[str, Any] = {} - - for doc in docs: - _update_schema_with_doc(schema, doc) - - # Convert sets to lists for serialization - for key in schema: - if isinstance(schema[key]["types"], set): - schema[key]["types"] = list(schema[key]["types"]) - - return schema - - -def _update_schema_with_doc(schema: dict[str, Any], doc: dict[str, Any], prefix: str = "") -> None: - """Recursively update schema with document structure.""" - for key, value in doc.items(): - full_key = f"{prefix}.{key}" if prefix else key - - if full_key not in schema: - schema[full_key] = { - "types": set(), - "sample_values": [], - "null_count": 0, - "total_count": 0, - } - - schema[full_key]["total_count"] += 1 - - if value is None: - schema[full_key]["null_count"] += 1 - schema[full_key]["types"].add("null") - else: - value_type = type(value).__name__ - schema[full_key]["types"].add(value_type) - - # Store sample values (limit to 3 unique samples) - if ( - len(schema[full_key]["sample_values"]) < 3 - and value not in schema[full_key]["sample_values"] - ): - schema[full_key]["sample_values"].append(value) - - # Handle nested objects - if isinstance(value, dict): - _update_schema_with_doc(schema, value, full_key) - elif isinstance(value, list) and value and isinstance(value[0], dict): - # Handle arrays of objects by sampling the first few - for i, item in enumerate(value[:3]): # Sample first 3 array items - if isinstance(item, dict): - _update_schema_with_doc(schema, item, f"{full_key}[{i}]") - - -def _serialize_document(doc: dict[str, Any]) -> dict[str, Any]: - """Convert MongoDB document to JSON-serializable format.""" - - if isinstance(doc, dict): - result = {} - for key, value in doc.items(): - result[key] = _serialize_document(value) - return result - elif isinstance(doc, list): - return [_serialize_document(item) for item in doc] - elif isinstance(doc, ObjectId): - return str(doc) - elif isinstance(doc, datetime): - return doc.isoformat() - else: - return doc diff --git a/toolkits/mongodb/evals/eval_mongodb.py b/toolkits/mongodb/evals/eval_mongodb.py deleted file mode 100644 index 34aa20497..000000000 --- a/toolkits/mongodb/evals/eval_mongodb.py +++ /dev/null @@ -1,190 +0,0 @@ -# RUN ME WITH `uv run arcade evals evals --host api.arcade.dev` - -import arcade_mongodb -from arcade_core import ToolCatalog -from arcade_evals import ( - BinaryCritic, - EvalRubric, - EvalSuite, - ExpectedToolCall, - SimilarityCritic, - tool_eval, -) -from arcade_mongodb.tools.mongodb import ( - aggregate_documents, - count_documents, - discover_collections, - discover_databases, - find_documents, - get_collection_schema, -) - -# Evaluation rubric -rubric = EvalRubric( - fail_threshold=0.85, - warn_threshold=0.95, -) - - -catalog = ToolCatalog() -catalog.add_module(arcade_mongodb) - - -@tool_eval() -def mongodb_eval_suite() -> EvalSuite: - suite = EvalSuite( - name="MongoDB Tools Evaluation", - system_message=( - "You are an AI assistant with access to MongoDB tools. " - "Use them to help the user with their tasks." - ), - catalog=catalog, - rubric=rubric, - ) - - suite.add_case( - name="Discover databases", - user_message="What databases are available in my MongoDB instance?", - expected_tool_calls=[ - ExpectedToolCall(func=discover_databases, args={}), - ], - rubric=rubric, - ) - - suite.add_case( - name="Discover collections", - user_message="What collections are in the 'admin' database?", - expected_tool_calls=[ - ExpectedToolCall(func=discover_collections, args={"database_name": "admin"}), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="database_name", weight=1.0), - ], - ) - - suite.add_case( - name="Get collection schema (single tool call)", - user_message="Get the schema of the 'system.users' collection in the 'admin' database.", - expected_tool_calls=[ - ExpectedToolCall( - func=get_collection_schema, - args={"database_name": "admin", "collection_name": "system.users"}, - ), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="database_name", weight=0.5), - BinaryCritic(critic_field="collection_name", weight=0.5), - ], - ) - - suite.add_case( - name="Find documents (direct call)", - user_message="Find documents in the 'startup_log' collection of the 'local' database, limited to 5 results.", - additional_messages=[ - { - "role": "user", - "content": "You can call find_documents directly without discovering collections first for this test.", - } - ], - expected_tool_calls=[ - ExpectedToolCall( - func=find_documents, - args={ - "database_name": "local", - "collection_name": "startup_log", - "limit": 5, - }, - ), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="database_name", weight=0.33), - BinaryCritic(critic_field="collection_name", weight=0.33), - BinaryCritic(critic_field="limit", weight=0.34), - ], - ) - - suite.add_case( - name="Count documents", - user_message="Count all documents in the 'startup_log' collection of the 'local' database.", - additional_messages=[ - { - "role": "user", - "content": "You can call count_documents directly without discovering collections first for this test.", - } - ], - expected_tool_calls=[ - ExpectedToolCall( - func=count_documents, - args={ - "database_name": "local", - "collection_name": "startup_log", - }, - ), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="database_name", weight=0.5), - BinaryCritic(critic_field="collection_name", weight=0.5), - ], - ) - - suite.add_case( - name="Count documents with filter", - user_message="Count documents in the 'startup_log' collection of the 'local' database where the level is 'INFO'.", - additional_messages=[ - { - "role": "user", - "content": "You can call count_documents directly without discovering collections first for this test.", - } - ], - expected_tool_calls=[ - ExpectedToolCall( - func=count_documents, - args={ - "database_name": "local", - "collection_name": "startup_log", - "filter_dict": '{"level": "INFO"}', - }, - ), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="database_name", weight=0.25), - BinaryCritic(critic_field="collection_name", weight=0.25), - SimilarityCritic(critic_field="filter_dict", weight=0.5), - ], - ) - - suite.add_case( - name="Aggregate documents", - user_message="Group documents in the 'startup_log' collection of the 'local' database by level and count them.", - additional_messages=[ - { - "role": "user", - "content": "You can call aggregate_documents directly without discovering collections first for this test.", - } - ], - expected_tool_calls=[ - ExpectedToolCall( - func=aggregate_documents, - args={ - "database_name": "local", - "collection_name": "startup_log", - "pipeline": [ - '{"$group": {"_id": "$level", "count": {"$sum": 1}}}', - ], - }, - ), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="database_name", weight=0.2), - BinaryCritic(critic_field="collection_name", weight=0.2), - SimilarityCritic(critic_field="pipeline", weight=0.6), - ], - ) - - return suite diff --git a/toolkits/mongodb/pyproject.toml b/toolkits/mongodb/pyproject.toml deleted file mode 100644 index 1fb228315..000000000 --- a/toolkits/mongodb/pyproject.toml +++ /dev/null @@ -1,62 +0,0 @@ -[build-system] -requires = [ "hatchling",] -build-backend = "hatchling.build" - -[project] -name = "arcade_mongodb" -version = "0.3.0" -description = "Tools to query and explore a MongoDB database" -requires-python = ">=3.10" -dependencies = [ - "arcade-mcp-server>=1.17.0,<2.0.0", - "pymongo>=4.10.1", - "pydantic>=2.11.7", - "motor>=3.6.0", -] -[[project.authors]] -name = "evantahler" -email = "support@arcade.dev" - -[project.optional-dependencies] -dev = [ - "arcade-mcp[all]>=1.2.0,<2.0.0", - "pytest>=8.3.0,<8.4.0", - "pytest-cov>=4.0.0,<4.1.0", - "pytest-mock>=3.11.1,<3.12.0", - "pytest-asyncio>=0.24.0,<0.25.0", - "mypy>=1.5.1,<1.6.0", - "pre-commit>=3.4.0,<3.5.0", - "tox>=4.11.1,<4.12.0", - "ruff>=0.7.4,<0.8.0", -] - -[project.scripts] -arcade-mongodb = "arcade_mongodb.__main__:main" -arcade_mongodb = "arcade_mongodb.__main__:main" - -# Use local path sources for arcade libs when working locally -[tool.uv.sources] -arcade-mcp = { path = "../../", editable = true } -arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true } - -[tool.mypy] -files = [ "arcade_mongodb/**/*.py",] -python_version = "3.10" -disallow_untyped_defs = "True" -disallow_any_unimported = "True" -no_implicit_optional = "True" -check_untyped_defs = "True" -warn_return_any = "True" -warn_unused_ignores = "True" -show_error_codes = "True" -ignore_missing_imports = "True" - -[tool.pytest.ini_options] -testpaths = [ "tests",] -asyncio_default_fixture_loop_scope = "function" - -[tool.coverage.report] -skip_empty = true - -[tool.hatch.build.targets.wheel] -packages = [ "arcade_mongodb",] diff --git a/toolkits/mongodb/tests/__init__.py b/toolkits/mongodb/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/mongodb/tests/conftest.py b/toolkits/mongodb/tests/conftest.py deleted file mode 100644 index d635dfcae..000000000 --- a/toolkits/mongodb/tests/conftest.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -import shutil -import subprocess -from os import environ - -import pytest_asyncio -from arcade_mongodb.database_engine import DatabaseEngine - -TEST_MONGODB_CONNECTION_STRING = ( - environ.get("TEST_MONGODB_CONNECTION_STRING") or "mongodb://localhost:27017" -) - - -@pytest_asyncio.fixture(autouse=True) -async def restore_database(): - """Restore the database from the dump before each test.""" - - dump_file = f"{os.path.dirname(__file__)}/dump.js" - - # Execute the MongoDB dump script to restore test data - mongosh_path = shutil.which("mongosh") - if not mongosh_path: - raise RuntimeError("mongosh executable not found in PATH") - - result = subprocess.run( - [mongosh_path, TEST_MONGODB_CONNECTION_STRING, dump_file], - check=True, - capture_output=True, - text=True, - ) - - if result.returncode != 0: - print(f"Error loading test data: {result.stderr}") - raise RuntimeError(f"Failed to load test data: {result.stderr}") - - yield # This allows tests to run - - # Optional cleanup could go here if needed - - -@pytest_asyncio.fixture(autouse=True) -async def cleanup_engines(): - """Clean up database engines after each test to prevent connection leaks.""" - yield - await DatabaseEngine.cleanup() diff --git a/toolkits/mongodb/tests/dump.js b/toolkits/mongodb/tests/dump.js deleted file mode 100644 index b878bf309..000000000 --- a/toolkits/mongodb/tests/dump.js +++ /dev/null @@ -1,378 +0,0 @@ -// MongoDB test data dump - equivalent to PostgreSQL dump.sql -// This script sets up test data for the MongoDB toolkit - -// Switch to test database -use('test_database'); - -// Clear existing data -db.users.drop(); -db.messages.drop(); - -// Create users collection with data -db.users.insertMany([ - { - _id: 1, - name: 'Alice', - email: 'alice@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E', - created_at: new Date('2024-09-01T20:49:38.759Z'), - updated_at: new Date('2024-09-02T03:49:39.927Z'), - status: 'active' - }, - { - _id: 2, - name: 'Bob', - email: 'bob@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY', - created_at: new Date('2024-09-02T17:49:23.377Z'), - updated_at: new Date('2024-09-02T17:49:23.377Z'), - status: 'active' - }, - { - _id: 3, - name: 'Charlie', - email: 'charlie@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo', - created_at: new Date('2024-09-03T10:30:15.123Z'), - updated_at: new Date('2024-09-03T10:30:15.123Z'), - status: 'active' - }, - { - _id: 4, - name: 'Diana', - email: 'diana@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123', - created_at: new Date('2024-09-04T14:20:30.654Z'), - updated_at: new Date('2024-09-04T14:20:30.654Z'), - status: 'active' - }, - { - _id: 5, - name: 'Evan', - email: 'evan@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456', - created_at: new Date('2024-09-05T09:15:45.987Z'), - updated_at: new Date('2024-09-05T09:15:45.987Z'), - status: 'active' - }, - { - _id: 6, - name: 'Fiona', - email: 'fiona@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789', - created_at: new Date('2024-09-06T16:45:12.345Z'), - updated_at: new Date('2024-09-06T16:45:12.345Z'), - status: 'active' - }, - { - _id: 7, - name: 'George', - email: 'george@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012', - created_at: new Date('2024-09-07T11:30:25.876Z'), - updated_at: new Date('2024-09-07T11:30:25.876Z'), - status: 'active' - }, - { - _id: 8, - name: 'Helen', - email: 'helen@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345', - created_at: new Date('2024-09-08T13:25:40.234Z'), - updated_at: new Date('2024-09-08T13:25:40.234Z'), - status: 'active' - }, - { - _id: 9, - name: 'Ian', - email: 'ian@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678', - created_at: new Date('2024-09-09T08:40:55.765Z'), - updated_at: new Date('2024-09-09T08:40:55.765Z'), - status: 'active' - }, - { - _id: 10, - name: 'Julia', - email: 'julia@example.com', - password_hash: '$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901', - created_at: new Date('2024-09-10T15:55:18.123Z'), - updated_at: new Date('2024-09-10T15:55:18.123Z'), - status: 'active' - } -]); - -// Create messages collection with data -db.messages.insertMany([ - // User 1 (Alice) - 3 messages - { - _id: 1, - body: 'Hello everyone!', - user_id: 1, - created_at: new Date('2025-01-10T10:00:00.000Z'), - updated_at: new Date('2025-01-10T10:00:00.000Z') - }, - { - _id: 2, - body: 'How is everyone doing today?', - user_id: 1, - created_at: new Date('2025-01-10T11:30:00.000Z'), - updated_at: new Date('2025-01-10T11:30:00.000Z') - }, - { - _id: 3, - body: 'Great to see you all here!', - user_id: 1, - created_at: new Date('2025-01-10T14:15:00.000Z'), - updated_at: new Date('2025-01-10T14:15:00.000Z') - }, - // User 2 (Bob) - 2 messages - { - _id: 4, - body: 'Hi Alice! Doing well, thanks for asking.', - user_id: 2, - created_at: new Date('2025-01-10T11:35:00.000Z'), - updated_at: new Date('2025-01-10T11:35:00.000Z') - }, - { - _id: 5, - body: 'Anyone up for a game later?', - user_id: 2, - created_at: new Date('2025-01-10T16:20:00.000Z'), - updated_at: new Date('2025-01-10T16:20:00.000Z') - }, - // User 3 (Charlie) - 3 messages - { - _id: 6, - body: 'Count me in for the game!', - user_id: 3, - created_at: new Date('2025-01-10T16:25:00.000Z'), - updated_at: new Date('2025-01-10T16:25:00.000Z') - }, - { - _id: 7, - body: 'What time works for everyone?', - user_id: 3, - created_at: new Date('2025-01-10T16:30:00.000Z'), - updated_at: new Date('2025-01-10T16:30:00.000Z') - }, - { - _id: 8, - body: 'I can play around 8 PM', - user_id: 3, - created_at: new Date('2025-01-10T17:00:00.000Z'), - updated_at: new Date('2025-01-10T17:00:00.000Z') - }, - // User 4 (Diana) - 2 messages - { - _id: 9, - body: '8 PM works for me too!', - user_id: 4, - created_at: new Date('2025-01-10T17:05:00.000Z'), - updated_at: new Date('2025-01-10T17:05:00.000Z') - }, - { - _id: 10, - body: 'What game should we play?', - user_id: 4, - created_at: new Date('2025-01-10T17:10:00.000Z'), - updated_at: new Date('2025-01-10T17:10:00.000Z') - }, - // User 5 (Evan) - 13 messages (including 10 additional ones) - { - _id: 11, - body: 'I suggest we try the new arcade game!', - user_id: 5, - created_at: new Date('2025-01-10T17:15:00.000Z'), - updated_at: new Date('2025-01-10T17:15:00.000Z') - }, - { - _id: 12, - body: 'It has great multiplayer features', - user_id: 5, - created_at: new Date('2025-01-10T17:20:00.000Z'), - updated_at: new Date('2025-01-10T17:20:00.000Z') - }, - { - _id: 13, - body: 'Perfect timing for a weekend session', - user_id: 5, - created_at: new Date('2025-01-10T18:00:00.000Z'), - updated_at: new Date('2025-01-10T18:00:00.000Z') - }, - { - _id: 26, - body: 'Just finished setting up the game server!', - user_id: 5, - created_at: new Date('2025-01-10T20:00:00.000Z'), - updated_at: new Date('2025-01-10T20:00:00.000Z') - }, - { - _id: 27, - body: 'Everyone should be able to connect now', - user_id: 5, - created_at: new Date('2025-01-10T20:05:00.000Z'), - updated_at: new Date('2025-01-10T20:05:00.000Z') - }, - { - _id: 28, - body: 'I added some custom maps too', - user_id: 5, - created_at: new Date('2025-01-10T20:10:00.000Z'), - updated_at: new Date('2025-01-10T20:10:00.000Z') - }, - { - _id: 29, - body: 'The graphics look amazing on this new version', - user_id: 5, - created_at: new Date('2025-01-10T20:15:00.000Z'), - updated_at: new Date('2025-01-10T20:15:00.000Z') - }, - { - _id: 30, - body: 'Hope you all enjoy the new features', - user_id: 5, - created_at: new Date('2025-01-10T20:20:00.000Z'), - updated_at: new Date('2025-01-10T20:20:00.000Z') - }, - { - _id: 31, - body: 'I also set up a leaderboard system', - user_id: 5, - created_at: new Date('2025-01-10T20:25:00.000Z'), - updated_at: new Date('2025-01-10T20:25:00.000Z') - }, - { - _id: 32, - body: 'We can track high scores now', - user_id: 5, - created_at: new Date('2025-01-10T20:30:00.000Z'), - updated_at: new Date('2025-01-10T20:30:00.000Z') - }, - { - _id: 33, - body: 'The game supports up to 8 players simultaneously', - user_id: 5, - created_at: new Date('2025-01-10T20:35:00.000Z'), - updated_at: new Date('2025-01-10T20:35:00.000Z') - }, - { - _id: 34, - body: 'I tested it earlier and it runs smoothly', - user_id: 5, - created_at: new Date('2025-01-10T20:40:00.000Z'), - updated_at: new Date('2025-01-10T20:40:00.000Z') - }, - { - _id: 35, - body: 'Cannot wait to see everyone online tonight!', - user_id: 5, - created_at: new Date('2025-01-10T20:45:00.000Z'), - updated_at: new Date('2025-01-10T20:45:00.000Z') - }, - // User 6 (Fiona) - 2 messages - { - _id: 14, - body: 'Sounds like fun! I love arcade games.', - user_id: 6, - created_at: new Date('2025-01-10T18:05:00.000Z'), - updated_at: new Date('2025-01-10T18:05:00.000Z') - }, - { - _id: 15, - body: 'Should I bring snacks?', - user_id: 6, - created_at: new Date('2025-01-10T18:10:00.000Z'), - updated_at: new Date('2025-01-10T18:10:00.000Z') - }, - // User 7 (George) - 3 messages - { - _id: 16, - body: 'Snacks are always welcome!', - user_id: 7, - created_at: new Date('2025-01-10T18:15:00.000Z'), - updated_at: new Date('2025-01-10T18:15:00.000Z') - }, - { - _id: 17, - body: 'I can bring some drinks', - user_id: 7, - created_at: new Date('2025-01-10T18:20:00.000Z'), - updated_at: new Date('2025-01-10T18:20:00.000Z') - }, - { - _id: 18, - body: 'This is going to be awesome', - user_id: 7, - created_at: new Date('2025-01-10T19:00:00.000Z'), - updated_at: new Date('2025-01-10T19:00:00.000Z') - }, - // User 8 (Helen) - 2 messages - { - _id: 19, - body: 'I agree! Cannot wait for the game night.', - user_id: 8, - created_at: new Date('2025-01-10T19:05:00.000Z'), - updated_at: new Date('2025-01-10T19:05:00.000Z') - }, - { - _id: 20, - body: 'Should we set up a Discord call?', - user_id: 8, - created_at: new Date('2025-01-10T19:10:00.000Z'), - updated_at: new Date('2025-01-10T19:10:00.000Z') - }, - // User 9 (Ian) - 3 messages - { - _id: 21, - body: 'Discord would be perfect for voice chat', - user_id: 9, - created_at: new Date('2025-01-10T19:15:00.000Z'), - updated_at: new Date('2025-01-10T19:15:00.000Z') - }, - { - _id: 22, - body: 'I will create a server for us', - user_id: 9, - created_at: new Date('2025-01-10T19:20:00.000Z'), - updated_at: new Date('2025-01-10T19:20:00.000Z') - }, - { - _id: 23, - body: 'Link will be shared in a few minutes', - user_id: 9, - created_at: new Date('2025-01-10T19:25:00.000Z'), - updated_at: new Date('2025-01-10T19:25:00.000Z') - }, - // User 10 (Julia) - 2 messages - { - _id: 24, - body: 'Thanks Ian! You are the best.', - user_id: 10, - created_at: new Date('2025-01-10T19:30:00.000Z'), - updated_at: new Date('2025-01-10T19:30:00.000Z') - }, - { - _id: 25, - body: 'See you all at 8 PM!', - user_id: 10, - created_at: new Date('2025-01-10T19:35:00.000Z'), - updated_at: new Date('2025-01-10T19:35:00.000Z') - },{ - _id: 99, - body: 'You are a mean jerk, you shithead!', - user_id: 10, - created_at: new Date('2025-01-10T19:35:00.000Z'), - updated_at: new Date('2025-01-10T19:35:00.000Z') - } -]); - -// Create indexes for better performance (equivalent to PostgreSQL indexes) -db.users.createIndex({ "name": 1 }, { unique: true }); -db.users.createIndex({ "email": 1 }, { unique: true }); -db.messages.createIndex({ "user_id": 1 }); -db.messages.createIndex({ "created_at": 1 }); - -print("MongoDB test data setup completed successfully!"); -print("Users collection: " + db.users.countDocuments()); -print("Messages collection: " + db.messages.countDocuments()); diff --git a/toolkits/mongodb/tests/test_json_validation.py b/toolkits/mongodb/tests/test_json_validation.py deleted file mode 100644 index e99e8c240..000000000 --- a/toolkits/mongodb/tests/test_json_validation.py +++ /dev/null @@ -1,221 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from arcade_core.errors import ToolExecutionError -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mongodb.tools.mongodb import aggregate_documents, count_documents, find_documents - -from .conftest import TEST_MONGODB_CONNECTION_STRING - - -@pytest.fixture -def mock_context(): - context = MagicMock(spec=Context) - context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING) - return context - - -@pytest.mark.asyncio -async def test_invalid_json_in_filter_dict(mock_context) -> None: - """Test that invalid JSON in filter_dict returns a reasonable error message.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": "active",}', # Invalid JSON - trailing comma - limit=1, - ) - - # Check that this is a JSON validation error - error_message = str(exc_info.value) - assert "Invalid JSON in filter_dict" in error_message - - # Check that the developer message contains helpful information - assert "filter_dict" in exc_info.value.developer_message - assert "JSON" in exc_info.value.additional_prompt_content - - # Check that the original JSON error is in the cause chain - assert exc_info.value.__cause__ is not None - - -@pytest.mark.asyncio -async def test_invalid_json_in_projection(mock_context) -> None: - """Test that invalid JSON in projection returns a reasonable error message.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - projection='{"name": 1, "email": 1,}', # Invalid JSON - trailing comma - limit=1, - ) - - # Check that this is a JSON validation error - error_message = str(exc_info.value) - assert "Invalid JSON in projection" in error_message - - # Check that the error message is helpful - assert "projection" in exc_info.value.developer_message - assert "JSON" in exc_info.value.additional_prompt_content - - # Check that the original JSON error is in the cause chain - assert exc_info.value.__cause__ is not None - - -@pytest.mark.asyncio -async def test_invalid_json_in_sort(mock_context) -> None: - """Test that invalid JSON in sort returns a reasonable error message.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - sort=['{"field": "name", "direction": 1,}'], # Invalid JSON - trailing comma - limit=1, - ) - - # Check that this is a JSON validation error - error_message = str(exc_info.value) - assert "Invalid JSON in sort" in error_message - - # Check that the error message is helpful - assert "sort" in exc_info.value.developer_message - assert "JSON" in exc_info.value.additional_prompt_content - - # Check that the original JSON error is in the cause chain - assert exc_info.value.__cause__ is not None - - -@pytest.mark.asyncio -async def test_invalid_json_in_count_filter(mock_context) -> None: - """Test that invalid JSON in count_documents filter returns a reasonable error message.""" - with pytest.raises(RetryableToolError) as exc_info: - await count_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": "active",}', # Invalid JSON - trailing comma - ) - - # Check that this is a JSON validation error - error_message = str(exc_info.value) - assert "Invalid JSON in filter_dict" in error_message - - # Check that the error message is helpful - assert "filter_dict" in exc_info.value.developer_message - assert "JSON" in exc_info.value.additional_prompt_content - - # Check that the original JSON error is in the cause chain - assert exc_info.value.__cause__ is not None - - -@pytest.mark.asyncio -async def test_invalid_json_in_pipeline(mock_context) -> None: - """Test that invalid JSON in aggregation pipeline returns a reasonable error message.""" - with pytest.raises(RetryableToolError) as exc_info: - await aggregate_documents( - mock_context, - database_name="test_database", - collection_name="users", - pipeline=['{"$match": {"status": "active",}}'], # Invalid JSON - trailing comma - ) - - # Check that this is a JSON validation error - error_message = str(exc_info.value) - assert "Invalid JSON in pipeline" in error_message - - # Check that the error message is helpful - assert "pipeline" in exc_info.value.developer_message - assert "JSON" in exc_info.value.additional_prompt_content - - # Check that the original JSON error is in the cause chain - assert exc_info.value.__cause__ is not None - - -@pytest.mark.asyncio -async def test_malformed_json_string(mock_context) -> None: - """Test various malformed JSON strings return reasonable error messages.""" - test_cases = [ - ('{"unclosed": "string}', "Unterminated string"), - ('{"missing_quotes": value}', "Expecting"), - ('{missing_closing_brace: "value"}', "Expecting"), - ('[{"array": "with"}, {"missing": }]', "Expecting"), - ] - - for invalid_json, expected_error_fragment in test_cases: - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict=invalid_json, - limit=1, - ) - - # Check that this is a JSON validation error - error_message = str(exc_info.value) - assert "Invalid JSON in filter_dict" in error_message - - # Check that specific error details are included when expected - if expected_error_fragment: - assert ( - expected_error_fragment in error_message - or expected_error_fragment in exc_info.value.developer_message - ) - - # Ensure helpful context is provided - assert "filter_dict" in exc_info.value.developer_message - assert "JSON" in exc_info.value.additional_prompt_content - assert "escaping" in exc_info.value.additional_prompt_content - - # Check that the original JSON error is in the cause chain - assert exc_info.value.__cause__ is not None - - -@pytest.mark.asyncio -async def test_valid_json_does_not_error(mock_context) -> None: - """Test that valid JSON does not raise JSON parsing errors.""" - # This should not raise a JSON parsing error (might raise other errors, but not JSON-related) - try: - result = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": "active"}', - projection='{"name": 1, "_id": 0}', - sort=['{"field": "name", "direction": 1}'], - limit=1, - ) - # If we get here, JSON parsing succeeded - assert isinstance(result, list) - except (ToolExecutionError, RetryableToolError) as e: - # If we get an error, it should not be about JSON parsing - # Check both the outer error and any nested error - error_message = str(e) - nested_message = str(e.__cause__) if e.__cause__ else "" - assert "Invalid JSON" not in error_message - assert "Invalid JSON" not in nested_message - - -@pytest.mark.asyncio -async def test_duplicate_keys_are_valid_json(mock_context) -> None: - """Test that duplicate keys in JSON are valid (Python JSON allows this).""" - # This should NOT raise a JSON parsing error because duplicate keys are valid JSON - try: - result = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"duplicate": "key", "duplicate": "key"}', # Valid JSON - last value wins - limit=1, - ) - # If we get here, JSON parsing succeeded (might get empty results, but no JSON error) - assert isinstance(result, list) - except (ToolExecutionError, RetryableToolError) as e: - # If we get an error, it should not be about JSON parsing - error_message = str(e) - nested_message = str(e.__cause__) if e.__cause__ else "" - assert "Invalid JSON" not in error_message - assert "Invalid JSON" not in nested_message diff --git a/toolkits/mongodb/tests/test_mongodb.py b/toolkits/mongodb/tests/test_mongodb.py deleted file mode 100644 index 14d850a69..000000000 --- a/toolkits/mongodb/tests/test_mongodb.py +++ /dev/null @@ -1,292 +0,0 @@ -import json -from unittest.mock import MagicMock - -import pytest -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mongodb.database_engine import DatabaseEngine -from arcade_mongodb.tools.mongodb import ( - # UserStatus, - aggregate_documents, - count_documents, - discover_collections, - discover_databases, - find_documents, - get_collection_schema, - # update_user_status, -) - -from .conftest import TEST_MONGODB_CONNECTION_STRING - - -@pytest.fixture -def mock_context(): - context = MagicMock(spec=Context) - context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING) - return context - - -@pytest.mark.asyncio -async def test_discover_databases(mock_context) -> None: - databases = await discover_databases(mock_context) - assert isinstance(databases, list) - # Should not include system databases like admin, config, local - for db in databases: - assert db not in ["admin", "config", "local"] - - -@pytest.mark.asyncio -async def test_discover_collections(mock_context) -> None: - collections = await discover_collections(mock_context, "test_database") - assert "users" in collections - assert "messages" in collections - - -@pytest.mark.asyncio -async def test_get_collection_schema(mock_context) -> None: - schema_result = await get_collection_schema( - mock_context, "test_database", "users", sample_size=10 - ) - - assert "schema" in schema_result - assert "total_documents_sampled" in schema_result - assert schema_result["total_documents_sampled"] == 10 # We have 10 users - - schema = schema_result["schema"] - assert "_id" in schema - assert "name" in schema - assert "email" in schema - assert "password_hash" in schema - assert "status" in schema - assert "created_at" in schema - assert "updated_at" in schema - - -@pytest.mark.asyncio -async def test_find_documents_basic(mock_context) -> None: - # Find all users - result = await find_documents( - mock_context, database_name="test_database", collection_name="users", limit=10 - ) - - assert len(result) == 10 - # Parse JSON strings to check contents - docs = [json.loads(doc_str) for doc_str in result] - assert all("name" in doc for doc in docs) - assert all("email" in doc for doc in docs) - - -@pytest.mark.asyncio -async def test_find_documents_with_filter(mock_context) -> None: - # Find active users - result = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": "active"}', - limit=10, - ) - - assert len(result) == 10 # All users in dump are active - docs = [json.loads(doc_str) for doc_str in result] - assert all(doc["status"] == "active" for doc in docs) - - -@pytest.mark.asyncio -async def test_find_documents_with_projection(mock_context) -> None: - # Find users with only name and email - result = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - projection='{"name": 1, "email": 1, "_id": 0}', - limit=10, - ) - - assert len(result) == 10 - docs = [json.loads(doc_str) for doc_str in result] - for doc in docs: - assert "name" in doc - assert "email" in doc - assert "_id" not in doc - assert "password_hash" not in doc - - -@pytest.mark.asyncio -async def test_find_documents_with_sort(mock_context) -> None: - # Find users sorted by _id descending - result = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - sort=['{"field": "_id", "direction": -1}'], - limit=3, - ) - - assert len(result) == 3 - docs = [json.loads(doc_str) for doc_str in result] - ids = [doc["_id"] for doc in docs] - assert ids == [10, 9, 8] # Descending order - - -@pytest.mark.asyncio -async def test_count_documents(mock_context) -> None: - # Count all users - count = await count_documents( - mock_context, database_name="test_database", collection_name="users" - ) - assert count == 10 - - # Count active users - active_count = await count_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": "active"}', - ) - assert active_count == 10 - - -@pytest.mark.asyncio -async def test_aggregate_documents(mock_context) -> None: - # Aggregate to count users by status - pipeline = ['{"$group": {"_id": "$status", "count": {"$sum": 1}}}', '{"$sort": {"count": -1}}'] - - result = await aggregate_documents( - mock_context, database_name="test_database", collection_name="users", pipeline=pipeline - ) - - assert len(result) == 1 # Only active users - # Should be sorted by count descending - doc = json.loads(result[0]) - assert doc["_id"] == "active" - assert doc["count"] == 10 - - -@pytest.mark.asyncio -async def test_find_documents_with_skip_and_limit(mock_context) -> None: - # Test pagination - result1 = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - sort=['{"field": "name", "direction": 1}'], - limit=2, - skip=0, - ) - - result2 = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - sort=['{"field": "name", "direction": 1}'], - limit=2, - skip=2, - ) - - assert len(result1) == 2 - assert len(result2) == 2 - - docs1 = [json.loads(doc_str) for doc_str in result1] - docs2 = [json.loads(doc_str) for doc_str in result2] - - assert docs1[0]["name"] == "Alice" - assert docs1[1]["name"] == "Bob" - assert docs2[0]["name"] == "Charlie" - assert docs2[1]["name"] == "Diana" - - -@pytest.mark.asyncio -async def test_error_handling_invalid_database(mock_context) -> None: - # Test with non-existent database - should not error but return empty results - collections = await discover_collections(mock_context, "nonexistent_database") - assert collections == [] - - -@pytest.mark.asyncio -async def test_error_handling_invalid_collection(mock_context) -> None: - # Test with non-existent collection - result = await find_documents( - mock_context, - database_name="test_database", - collection_name="nonexistent_collection", - limit=10, - ) - assert result == [] - - -@pytest.mark.asyncio -async def test_sanitize_query_params() -> None: - # Test parameter validation - with pytest.raises(RetryableToolError) as e: - DatabaseEngine.sanitize_query_params("", "users", {}, None, None, 10, 0) - assert "Database name is required" in str(e.value) - - with pytest.raises(RetryableToolError) as e: - DatabaseEngine.sanitize_query_params("test_db", "", {}, None, None, 10, 0) - assert "Collection name is required" in str(e.value) - - with pytest.raises(RetryableToolError) as e: - DatabaseEngine.sanitize_query_params( - "test_db", "users", {}, None, None, 2000, 0 - ) # Too high limit - assert "Limit is too high" in str(e.value) - - -# @pytest.mark.asyncio -# async def test_update_user_status_success(mock_context) -> None: -# """Test successful user status update.""" -# # First, find a user to update -# users = await find_documents( -# mock_context, database_name="test_database", collection_name="users", limit=1 -# ) -# assert len(users) > 0 - -# user_doc = json.loads(users[0]) -# user_id = user_doc["_id"] - -# # Update user status to inactive -# result = await update_user_status( -# mock_context, -# database_name="test_database", -# collection_name="users", -# user_id=user_id, -# status=UserStatus.INACTIVE, -# ) - -# assert result["success"] is True -# assert result["user_id"] == user_id -# assert result["new_status"] == "inactive" -# assert result["matched_count"] == 1 -# assert result["modified_count"] == 1 - -# # Verify the update by finding the user again -# # Convert user_id to int since the test data uses integer IDs -# user_id_int = int(user_id) -# updated_users = await find_documents( -# mock_context, -# database_name="test_database", -# collection_name="users", -# filter_dict=f'{{"_id": {user_id_int}}}', -# limit=1, -# ) -# assert len(updated_users) == 1 -# updated_user = json.loads(updated_users[0]) -# assert updated_user["status"] == "inactive" - - -# @pytest.mark.asyncio -# async def test_update_user_status_user_not_found(mock_context) -> None: -# """Test updating status for non-existent user.""" -# result = await update_user_status( -# mock_context, -# database_name="test_database", -# collection_name="users", -# user_id="nonexistent_user_id", -# status=UserStatus.BANNED, -# ) - -# assert result["success"] is False -# assert "No user found with _id" in result["message"] -# assert result["matched_count"] == 0 -# assert result["modified_count"] == 0 diff --git a/toolkits/mongodb/tests/test_setup.sh b/toolkits/mongodb/tests/test_setup.sh deleted file mode 100755 index 23b155dfd..000000000 --- a/toolkits/mongodb/tests/test_setup.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -# install mongosh to load sample data -sudo apt-get update -sudo apt-get install -y wget gnupg -wget -qO - https://www.mongodb.org/static/pgp/server-6.0.asc | sudo apt-key add - -echo "deb [ arch=amd64,arm64 ] https://repo.mongodb.org/apt/ubuntu jammy/mongodb-org/6.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-6.0.list -sudo apt-get update -sudo apt-get install -y mongodb-mongosh - -# Run mongodb container -docker run -d --name some-mongodb-server -p 27017:27017 mongo diff --git a/toolkits/mongodb/tests/test_write_validation.py b/toolkits/mongodb/tests/test_write_validation.py deleted file mode 100644 index f08e4caec..000000000 --- a/toolkits/mongodb/tests/test_write_validation.py +++ /dev/null @@ -1,248 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mongodb.tools.mongodb import aggregate_documents, count_documents, find_documents - -from .conftest import TEST_MONGODB_CONNECTION_STRING - - -@pytest.fixture -def mock_context(): - context = MagicMock(spec=Context) - context.get_secret = MagicMock(return_value=TEST_MONGODB_CONNECTION_STRING) - return context - - -@pytest.mark.asyncio -async def test_filter_dict_blocks_set_operation(mock_context) -> None: - """Test that $set operation in filter_dict is blocked.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"$set": {"status": "modified"}}', # Write operation - limit=1, - ) - - error_message = str(exc_info.value) - assert "Write operation '$set' not allowed in filter_dict" in error_message - assert "$set" in exc_info.value.developer_message - assert "Only read operations are allowed" in exc_info.value.developer_message - - -@pytest.mark.asyncio -async def test_filter_dict_blocks_update_operations(mock_context) -> None: - """Test that various update operations in filter_dict are blocked.""" - update_ops = ["$inc", "$unset", "$push", "$pull", "$rename", "$currentDate"] - - for op in update_ops: - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict=f'{{"{op}": {{"field": "value"}}}}', - limit=1, - ) - - error_message = str(exc_info.value) - assert f"Write operation '{op}' not allowed in filter_dict" in error_message - - -@pytest.mark.asyncio -async def test_projection_blocks_write_operations(mock_context) -> None: - """Test that write operations in projection are blocked.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - projection='{"$set": {"modified": true}, "name": 1}', # Write operation in projection - limit=1, - ) - - error_message = str(exc_info.value) - assert "Write operation '$set' not allowed in projection" in error_message - - -@pytest.mark.asyncio -async def test_sort_blocks_write_operations(mock_context) -> None: - """Test that write operations in sort are blocked.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - sort=['{"field": "name", "direction": 1, "$inc": {"counter": 1}}'], # Write op in sort - limit=1, - ) - - error_message = str(exc_info.value) - assert "Write operation '$inc' not allowed in sort[0]" in error_message - - -@pytest.mark.asyncio -async def test_count_filter_blocks_write_operations(mock_context) -> None: - """Test that write operations in count filter are blocked.""" - with pytest.raises(RetryableToolError) as exc_info: - await count_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": "active", "$unset": {"password": ""}}', # Write operation - ) - - error_message = str(exc_info.value) - assert "Write operation '$unset' not allowed in filter_dict" in error_message - - -@pytest.mark.asyncio -async def test_aggregation_pipeline_blocks_out_stage(mock_context) -> None: - """Test that $out stage in aggregation pipeline is blocked.""" - with pytest.raises(RetryableToolError) as exc_info: - await aggregate_documents( - mock_context, - database_name="test_database", - collection_name="users", - pipeline=[ - '{"$match": {"status": "active"}}', - '{"$out": "output_collection"}', # Write stage - ], - ) - - error_message = str(exc_info.value) - assert "Write stage '$out' not allowed in pipeline" in error_message - - -@pytest.mark.asyncio -async def test_aggregation_pipeline_blocks_merge_stage(mock_context) -> None: - """Test that $merge stage in aggregation pipeline is blocked.""" - with pytest.raises(RetryableToolError) as exc_info: - await aggregate_documents( - mock_context, - database_name="test_database", - collection_name="users", - pipeline=[ - '{"$match": {"status": "active"}}', - '{"$merge": {"into": "target_collection"}}', # Write stage - ], - ) - - error_message = str(exc_info.value) - assert "Write stage '$merge' not allowed in pipeline" in error_message - - -@pytest.mark.asyncio -async def test_where_operator_blocked(mock_context) -> None: - """Test that $where operator is blocked for security reasons.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"$where": "this.name == \'admin\'"}', # JavaScript execution - limit=1, - ) - - error_message = str(exc_info.value) - assert "JavaScript execution operator '$where' not allowed in filter_dict" in error_message - assert ( - "JavaScript execution is not allowed for security reasons" - in exc_info.value.developer_message - ) - - -@pytest.mark.asyncio -async def test_nested_write_operations_blocked(mock_context) -> None: - """Test that nested write operations are blocked.""" - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": "active", "nested": {"$set": {"field": "value"}}}', # Nested write op - limit=1, - ) - - error_message = str(exc_info.value) - assert "Write operation '$set' not allowed in filter_dict" in error_message - assert "nested.$set" in exc_info.value.developer_message # Should show the path - - -@pytest.mark.asyncio -async def test_valid_read_operations_allowed(mock_context) -> None: - """Test that valid read operations are allowed.""" - # These should not raise write operation errors - try: - # Test query operators - result = await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict='{"status": {"$in": ["active", "inactive"]}, "name": {"$regex": "^A"}}', - projection='{"name": 1, "email": 1, "_id": 0}', - sort=['{"field": "name", "direction": 1}'], - limit=1, - ) - assert isinstance(result, list) - - # Test aggregation pipeline with read-only stages - pipeline_result = await aggregate_documents( - mock_context, - database_name="test_database", - collection_name="users", - pipeline=[ - '{"$match": {"status": "active"}}', - '{"$group": {"_id": "$status", "count": {"$sum": 1}}}', - '{"$sort": {"count": -1}}', - ], - ) - assert isinstance(pipeline_result, list) - - except RetryableToolError as e: - # If we get an error, it should not be about write operations - error_message = str(e) - nested_message = str(e.__cause__) if e.__cause__ else "" - assert "Write operation" not in error_message - assert "Write stage" not in error_message - assert "Write operation" not in nested_message - assert "Write stage" not in nested_message - - -@pytest.mark.asyncio -async def test_array_write_operations_blocked(mock_context) -> None: - """Test that array write operations are blocked.""" - array_write_ops = ["$addToSet", "$pop", "$pull", "$push", "$pullAll"] - - for op in array_write_ops: - with pytest.raises(RetryableToolError) as exc_info: - await find_documents( - mock_context, - database_name="test_database", - collection_name="users", - filter_dict=f'{{"{op}": {{"tags": "new_tag"}}}}', - limit=1, - ) - - error_message = str(exc_info.value) - assert f"Write operation '{op}' not allowed in filter_dict" in error_message - - -@pytest.mark.asyncio -async def test_aggregation_stage_content_validated(mock_context) -> None: - """Test that content within aggregation stages is also validated for write operations.""" - with pytest.raises(RetryableToolError) as exc_info: - await aggregate_documents( - mock_context, - database_name="test_database", - collection_name="users", - pipeline=[ - '{"$match": {"status": "active", "$set": {"modified": true}}}' # Write op inside $match - ], - ) - - error_message = str(exc_info.value) - assert "Write operation '$set' not allowed in pipeline[0].$match" in error_message diff --git a/toolkits/postgres/Makefile b/toolkits/postgres/Makefile deleted file mode 100644 index 7e2c686e1..000000000 --- a/toolkits/postgres/Makefile +++ /dev/null @@ -1,53 +0,0 @@ -.PHONY: help - -help: - @echo "๐Ÿ› ๏ธ github Commands:\n" - @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - -.PHONY: install -install: ## Install the uv environment and install all packages with dependencies - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras --no-sources - @uv run pre-commit install - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: install-local -install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras - @uv run pre-commit install - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: build -build: clean-build ## Build wheel file using poetry - @echo "๐Ÿš€ Creating wheel file" - uv build - -.PHONY: clean-build -clean-build: ## clean build artifacts - @echo "๐Ÿ—‘๏ธ Cleaning dist directory" - rm -rf dist - -.PHONY: test -test: ## Test the code with pytest - @echo "๐Ÿš€ Testing code: Running pytest" - @uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml - -.PHONY: coverage -coverage: ## Generate coverage report - @echo "coverage report" - coverage report - @echo "Generating coverage report" - coverage html - -.PHONY: bump-version -bump-version: ## Bump the version in the pyproject.toml file by a patch version - @echo "๐Ÿš€ Bumping version in pyproject.toml" - uv version --bump patch - -.PHONY: check -check: ## Run code quality tools. - @echo "๐Ÿš€ Linting code: Running pre-commit" - @uv run pre-commit run -a - @echo "๐Ÿš€ Static type checking: Running mypy" - @uv run mypy --config-file=pyproject.toml diff --git a/toolkits/postgres/arcade_postgres/__init__.py b/toolkits/postgres/arcade_postgres/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/postgres/arcade_postgres/__main__.py b/toolkits/postgres/arcade_postgres/__main__.py deleted file mode 100644 index 0b31719b2..000000000 --- a/toolkits/postgres/arcade_postgres/__main__.py +++ /dev/null @@ -1,29 +0,0 @@ -import sys -from typing import cast - -from arcade_mcp_server import MCPApp -from arcade_mcp_server.mcp_app import TransportType - -import arcade_postgres - -app = MCPApp( - name="PostgreSQL", - instructions=( - "Use this server when you need to interact with PostgreSQL to help users " - "query, explore, and manage their PostgreSQL databases." - ), -) - -app.add_tools_from_module(arcade_postgres) - - -def main() -> None: - transport = sys.argv[1] if len(sys.argv) > 1 else "stdio" - host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1" - port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000 - - app.run(transport=cast(TransportType, transport), host=host, port=port) - - -if __name__ == "__main__": - main() diff --git a/toolkits/postgres/arcade_postgres/database_engine.py b/toolkits/postgres/arcade_postgres/database_engine.py deleted file mode 100644 index 2b2c86b79..000000000 --- a/toolkits/postgres/arcade_postgres/database_engine.py +++ /dev/null @@ -1,180 +0,0 @@ -from typing import Any, ClassVar -from urllib.parse import urlparse - -from arcade_mcp_server.exceptions import RetryableToolError -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine - -MAX_ROWS_RETURNED = 1000 -TEST_QUERY = "SELECT 1" - - -class DatabaseEngine: - _instance: ClassVar[None] = None - _engines: ClassVar[dict[str, AsyncEngine]] = {} - - @classmethod - async def get_instance(cls, connection_string: str) -> AsyncEngine: - parsed_url = urlparse(connection_string) - - # TODO: something strange with sslmode= and friends - # query_params = parse_qs(parsed_url.query) - # query_params = { - # k: v[0] for k, v in query_params.items() - # } # assume one value allowed for each query param - - async_connection_string = f"{parsed_url.scheme.replace('postgresql', 'postgresql+asyncpg')}://{parsed_url.netloc}{parsed_url.path}" - key = f"{async_connection_string}" - if key not in cls._engines: - cls._engines[key] = create_async_engine(async_connection_string) - - # try a simple query to see if the connection is valid - try: - async with cls._engines[key].connect() as connection: - await connection.execute(text(TEST_QUERY)) - return cls._engines[key] - except Exception: - await cls._engines[key].dispose() - - # try again - try: - async with cls._engines[key].connect() as connection: - await connection.execute(text(TEST_QUERY)) - return cls._engines[key] - except Exception as e: - raise RetryableToolError( - f"Connection failed: {e}", - developer_message="Connection to postgres failed.", - additional_prompt_content="Check the connection string and try again.", - ) from e - - @classmethod - async def get_engine(cls, connection_string: str) -> Any: - engine = await cls.get_instance(connection_string) - - class ConnectionContextManager: - def __init__(self, engine: AsyncEngine) -> None: - self.engine = engine - - async def __aenter__(self) -> AsyncEngine: - return self.engine - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - # Connection cleanup is handled by the async context manager - pass - - return ConnectionContextManager(engine) - - @classmethod - async def cleanup(cls) -> None: - """Clean up all cached engines. Call this when shutting down.""" - for engine in cls._engines.values(): - await engine.dispose() - cls._engines.clear() - - @classmethod - def clear_cache(cls) -> None: - """Clear the engine cache without disposing engines. Use with caution.""" - cls._engines.clear() - - @classmethod - def sanitize_query( # noqa: C901 - cls, - select_clause: str, - from_clause: str, - limit: int, - offset: int, - join_clause: str | None, - where_clause: str | None, - having_clause: str | None, - group_by_clause: str | None, - order_by_clause: str | None, - with_clause: str | None, - ) -> tuple[str, dict[str, Any]]: - # Remove the leading keywords from the clauses if they are present - if select_clause.strip().split(" ")[0].upper() == "SELECT": - select_clause = select_clause.strip()[6:] - - if from_clause.strip().split(" ")[0].upper() == "FROM": - from_clause = from_clause.strip()[4:] - - if join_clause and join_clause.strip().split(" ")[0].upper() == "JOIN": - join_clause = join_clause.strip()[4:] - - if where_clause and where_clause.strip().split(" ")[0].upper() == "WHERE": - where_clause = where_clause.strip()[5:] - - if group_by_clause and group_by_clause.strip().split(" ")[0].upper() == "GROUP BY": - group_by_clause = group_by_clause.strip()[8:] - - if order_by_clause and order_by_clause.strip().split(" ")[0].upper() == "ORDER BY": - order_by_clause = order_by_clause.strip()[8:] - - if having_clause and having_clause.strip().split(" ")[0].upper() == "HAVING": - having_clause = having_clause.strip()[6:] - - first_select_word = select_clause.strip().split(" ")[0].upper() - if first_select_word in [ - "INSERT", - "UPDATE", - "DELETE", - "CREATE", - "ALTER", - "DROP", - "TRUNCATE", - "REINDEX", - "VACUUM", - "ANALYZE", - "COMMENT", - ]: - raise RetryableToolError( - "Only SELECT queries are allowed.", - ) - - if select_clause.strip() == "*": - raise RetryableToolError( - "Do not use * in the select clause. Use a comma separated list of columns you wish to return.", - ) - - if limit > MAX_ROWS_RETURNED: - raise RetryableToolError( - f"Limit is too high. Maximum is {MAX_ROWS_RETURNED}.", - ) - - if offset < 0: - raise RetryableToolError( - "Offset must be greater than or equal to 0.", - developer_message="Offset must be greater than or equal to 0.", - ) - - if limit <= 0: - raise RetryableToolError( - "Limit must be greater than 0.", - developer_message="Limit must be greater than 0.", - ) - - # Build query with identifiers directly interpolated, but use parameters for values - parts = [] - if with_clause: - parts.append(f"WITH {with_clause}") - parts.append(f"SELECT {select_clause} FROM {from_clause}") # noqa: S608 - if join_clause: - parts.append(f"JOIN {join_clause}") - if where_clause: - parts.append(f"WHERE {where_clause}") - if group_by_clause: - parts.append(f"GROUP BY {group_by_clause}") - if having_clause: - parts.append(f"HAVING {having_clause}") - if order_by_clause: - parts.append(f"ORDER BY {order_by_clause}") - parts.append("LIMIT :limit OFFSET :offset") - query = " ".join(parts) - - # Only use parameters for values, not identifiers - parameters = { - "limit": limit, - "offset": offset, - } - - return query, parameters diff --git a/toolkits/postgres/arcade_postgres/tools/__init__.py b/toolkits/postgres/arcade_postgres/tools/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/postgres/arcade_postgres/tools/postgres.py b/toolkits/postgres/arcade_postgres/tools/postgres.py deleted file mode 100644 index 5bd79a2eb..000000000 --- a/toolkits/postgres/arcade_postgres/tools/postgres.py +++ /dev/null @@ -1,300 +0,0 @@ -from typing import Annotated, Any - -from arcade_mcp_server import Context, tool -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mcp_server.metadata import Behavior, Operation, ToolMetadata -from sqlalchemy import inspect, text -from sqlalchemy.ext.asyncio import AsyncEngine - -from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine - - -@tool( - requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def discover_schemas( - context: Context, -) -> list[str]: - """Discover all the schemas in the postgres database.""" - async with await DatabaseEngine.get_engine( - context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING") - ) as engine: - schemas = await _get_schemas(engine) - return schemas - - -@tool( - requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def discover_tables( - context: Context, - schema_name: Annotated[ - str, "The database schema to discover tables in (default value: 'public')" - ] = "public", -) -> list[str]: - """Discover all the tables in the postgres database when the list of tables is not known. - - ALWAYS use this tool before any other tool that requires a table name. - """ - async with await DatabaseEngine.get_engine( - context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING") - ) as engine: - tables = await _get_tables(engine, schema_name) - return tables - - -@tool( - requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def get_table_schema( - context: Context, - schema_name: Annotated[str, "The database schema to get the table schema of"], - table_name: Annotated[str, "The table to get the schema of"], -) -> list[str]: - """ - Get the schema/structure of a postgres table in the postgres database when the schema is not known, and the name of the table is provided. - - This tool should ALWAYS be used before executing any query. All tables in the query must be discovered first using the tool. - """ - async with await DatabaseEngine.get_engine( - context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING") - ) as engine: - return await _get_table_schema(engine, schema_name, table_name) - - -@tool( - requires_secrets=["POSTGRES_DATABASE_CONNECTION_STRING"], - metadata=ToolMetadata( - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def execute_select_query( - context: Context, - select_clause: Annotated[ - str, - "This is the part of the SQL query that comes after the SELECT keyword wish a comma separated list of columns you wish to return. Do not include the SELECT keyword.", - ], - from_clause: Annotated[ - str, - "This is the part of the SQL query that comes after the FROM keyword. Do not include the FROM keyword.", - ], - limit: Annotated[ - int, - "The maximum number of rows to return. This is the LIMIT clause of the query. Default: 100.", - ] = 100, - offset: Annotated[ - int, "The number of rows to skip. This is the OFFSET clause of the query. Default: 0." - ] = 0, - join_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the JOIN keyword. Do not include the JOIN keyword. If no join is needed, leave this blank.", - ] = None, - where_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the WHERE keyword. Do not include the WHERE keyword. If no where clause is needed, leave this blank.", - ] = None, - having_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the HAVING keyword. Do not include the HAVING keyword. If no having clause is needed, leave this blank.", - ] = None, - group_by_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the GROUP BY keyword. Do not include the GROUP BY keyword. If no group by clause is needed, leave this blank.", - ] = None, - order_by_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the ORDER BY keyword. Do not include the ORDER BY keyword. If no order by clause is needed, leave this blank.", - ] = None, - with_clause: Annotated[ - str | None, - "This is the part of the SQL query that comes after the WITH keyword when basing the query on a virtual table. If no WITH clause is needed, leave this blank.", - ] = None, -) -> list[str]: - """ - You have a connection to a postgres database. - Execute a SELECT query and return the results against the postgres database. No other queries (INSERT, UPDATE, DELETE, etc.) are allowed. - - ONLY use this tool if you have already loaded the schema of the tables you need to query. Use the tool to load the schema if not already known. - - The final query will be constructed as follows: - SELECT {select_query_part} FROM {from_clause} JOIN {join_clause} WHERE {where_clause} HAVING {having_clause} ORDER BY {order_by_clause} LIMIT {limit} OFFSET {offset} - - When running queries, follow these rules which will help avoid errors: - * Never "select *" from a table. Always select the columns you need. - * Always order your results by the most important columns first. If you aren't sure, order by the primary key. - * Always use case-insensitive queries to match strings in the query. - * Always trim strings in the query. - * Prefer LIKE queries over direct string matches or regex queries. - * Only join on columns that are indexed or the primary key. Do not join on arbitrary columns. - """ - async with await DatabaseEngine.get_engine( - context.get_secret("POSTGRES_DATABASE_CONNECTION_STRING") - ) as engine: - try: - return await _execute_query( - engine, - select_clause=select_clause, - from_clause=from_clause, - limit=limit, - offset=offset, - join_clause=join_clause, - where_clause=where_clause, - having_clause=having_clause, - group_by_clause=group_by_clause, - order_by_clause=order_by_clause, - with_clause=with_clause, - ) - except Exception as e: - raise RetryableToolError( - f"Query failed: {e}", - developer_message=f"Query failed with parameters: select_clause={select_clause}, from_clause={from_clause}, limit={limit}, offset={offset}, join_clause={join_clause}, where_clause={where_clause}, having_clause={having_clause}, order_by_clause={order_by_clause}, with_clause={with_clause}.", - additional_prompt_content="Load the database schema or use the tool to discover the tables and try again.", - retry_after_ms=10, - ) from e - - -async def _get_schemas(engine: AsyncEngine) -> list[str]: - """Get all the schemas in the database""" - async with engine.connect() as conn: - - def get_schema_names(sync_conn: Any) -> list[str]: - return list(inspect(sync_conn).get_schema_names()) - - schemas: list[str] = await conn.run_sync(get_schema_names) - schemas = [schema for schema in schemas if schema != "information_schema"] - - return schemas - - -async def _get_tables(engine: AsyncEngine, schema_name: str) -> list[str]: - """Get all the tables in the database""" - async with engine.connect() as conn: - - def get_schema_names(sync_conn: Any) -> list[str]: - return list(inspect(sync_conn).get_schema_names()) - - schemas: list[str] = await conn.run_sync(get_schema_names) - tables = [] - for schema in schemas: - if schema == schema_name: - - def get_table_names(sync_conn: Any, s: str = schema) -> list[str]: - return list(inspect(sync_conn).get_table_names(schema=s)) - - these_tables = await conn.run_sync(get_table_names) - tables.extend(these_tables) - - tables.sort() - return tables - - -async def _get_table_schema(engine: AsyncEngine, schema_name: str, table_name: str) -> list[str]: - """Get the schema of a table""" - async with engine.connect() as connection: - - def get_columns(sync_conn: Any, t: str = table_name, s: str = schema_name) -> list[Any]: - return list(inspect(sync_conn).get_columns(t, s)) - - columns_table = await connection.run_sync(get_columns) - - # Get primary key information - pk_constraint = await connection.run_sync( - lambda sync_conn: inspect(sync_conn).get_pk_constraint(table_name, schema_name) - ) - primary_keys = set(pk_constraint.get("constrained_columns", [])) - - # Get index information - indexes = await connection.run_sync( - lambda sync_conn: inspect(sync_conn).get_indexes(table_name, schema_name) - ) - indexed_columns = set() - for index in indexes: - indexed_columns.update(index.get("column_names", [])) - - results = [] - for column in columns_table: - column_name = column["name"] - column_type = column["type"].python_type.__name__ - - # Build column description - description = f"{column_name}: {column_type}" - - # Add primary key indicator - if column_name in primary_keys: - description += " (PRIMARY KEY)" - - # Add index indicator - if column_name in indexed_columns: - description += " (INDEXED)" - - results.append(description) - - return results[:MAX_ROWS_RETURNED] - - -async def _execute_query( - engine: AsyncEngine, - select_clause: str, - from_clause: str, - limit: int, - offset: int, - join_clause: str | None, - where_clause: str | None, - having_clause: str | None, - group_by_clause: str | None, - order_by_clause: str | None, - with_clause: str | None, -) -> list[str]: - """Execute a query and return the results.""" - async with engine.connect() as connection: - query, parameters = DatabaseEngine.sanitize_query( - select_clause=select_clause, - from_clause=from_clause, - limit=limit, - offset=offset, - join_clause=join_clause, - where_clause=where_clause, - having_clause=having_clause, - group_by_clause=group_by_clause, - order_by_clause=order_by_clause, - with_clause=with_clause, - ) - print(f"Query: {query}") - print(f"Parameters: {parameters}") - result = await connection.execute(text(query), parameters) - rows = result.fetchall() - results = [str(row) for row in rows] - return results[:MAX_ROWS_RETURNED] diff --git a/toolkits/postgres/evals/eval_postgres.py b/toolkits/postgres/evals/eval_postgres.py deleted file mode 100644 index f54f77fe3..000000000 --- a/toolkits/postgres/evals/eval_postgres.py +++ /dev/null @@ -1,94 +0,0 @@ -import arcade_postgres -from arcade_core import ToolCatalog -from arcade_evals import ( - BinaryCritic, - EvalRubric, - EvalSuite, - ExpectedToolCall, - SimilarityCritic, - tool_eval, -) -from arcade_postgres.tools.postgres import ( - discover_tables, - execute_query, - get_table_schema, -) - -# Evaluation rubric -rubric = EvalRubric( - fail_threshold=0.85, - warn_threshold=0.95, -) - - -catalog = ToolCatalog() -catalog.add_module(arcade_postgres) - - -@tool_eval() -def sql_eval_suite() -> EvalSuite: - suite = EvalSuite( - name="sql Tools Evaluation", - system_message=( - "You are an AI assistant with access to sql tools. " - "Use them to help the user with their tasks." - ), - catalog=catalog, - rubric=rubric, - ) - - suite.add_case( - name="Get user by id (schema known)", - user_message="Tell me the name and email of user #1 in my database. The table 'users' has the following schema: id: int, name: str, email: str, password_hash: str, created_at: datetime, updated_at: datetime", - expected_tool_calls=[ - ExpectedToolCall( - func=execute_query, args={"query": "SELECT name, email FROM users WHERE id = 1"} - ) - ], - rubric=rubric, - critics=[SimilarityCritic(critic_field="query", weight=1.0)], - ) - - suite.add_case( - name="Discover tables", - user_message="What tables are in my database?", - expected_tool_calls=[ - ExpectedToolCall(func=discover_tables, args={}), - ], - rubric=rubric, - ) - - suite.add_case( - name="Get table schema (schema provided)", - user_message="What columns are in the table 'public.users' in my database?", - expected_tool_calls=[ - ExpectedToolCall( - func=get_table_schema, args={"schema_name": "public", "table_name": "users"} - ), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="schema_name", weight=0.5), - BinaryCritic(critic_field="table_name", weight=0.5), - ], - ) - - suite.add_case( - name="Get table schema (schema not provided)", - user_message="What columns are in the table 'users' in my database?", - additional_messages=[ - {"role": "user", "content": "When not provided, the schema is 'public'."} - ], - expected_tool_calls=[ - ExpectedToolCall( - func=get_table_schema, args={"schema_name": "public", "table_name": "users"} - ), - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="schema_name", weight=0.5), - BinaryCritic(critic_field="table_name", weight=0.5), - ], - ) - - return suite diff --git a/toolkits/postgres/pyproject.toml b/toolkits/postgres/pyproject.toml deleted file mode 100644 index fe969c7ac..000000000 --- a/toolkits/postgres/pyproject.toml +++ /dev/null @@ -1,65 +0,0 @@ -[build-system] -requires = [ "hatchling",] -build-backend = "hatchling.build" - -[project] -name = "arcade_postgres" -version = "0.5.0" -description = "Tools to query and explore a postgres database" -requires-python = ">=3.10" -dependencies = [ - "arcade-mcp-server>=1.17.0,<2.0.0", - "psycopg2-binary>=2.9.10", - "pydantic>=2.11.7", - "sqlalchemy>=2.0.41", - "psycopg2-binary>=2.9.10", - "asyncpg>=0.30.0", - "greenlet>=3.2.3", -] -[[project.authors]] -name = "evantahler" -email = "support@arcade.dev" - -[project.optional-dependencies] -dev = [ - "arcade-mcp[all]>=1.2.0,<2.0.0", - "pytest>=8.3.0,<8.4.0", - "pytest-cov>=4.0.0,<4.1.0", - "pytest-mock>=3.11.1,<3.12.0", - "pytest-asyncio>=0.24.0,<0.25.0", - "mypy>=1.5.1,<1.6.0", - "pre-commit>=3.4.0,<3.5.0", - "tox>=4.11.1,<4.12.0", - "ruff>=0.7.4,<0.8.0", -] - -[project.scripts] -arcade-postgres = "arcade_postgres.__main__:main" -arcade_postgres = "arcade_postgres.__main__:main" - -# Use local path sources for arcade libs when working locally -[tool.uv.sources] -arcade-mcp = { path = "../../", editable = true } -arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true } - -[tool.mypy] -files = [ "arcade_postgres/**/*.py",] -python_version = "3.10" -disallow_untyped_defs = "True" -disallow_any_unimported = "True" -no_implicit_optional = "True" -check_untyped_defs = "True" -warn_return_any = "True" -warn_unused_ignores = "True" -show_error_codes = "True" -ignore_missing_imports = "True" - -[tool.pytest.ini_options] -testpaths = [ "tests",] -asyncio_default_fixture_loop_scope = "function" - -[tool.coverage.report] -skip_empty = true - -[tool.hatch.build.targets.wheel] -packages = [ "arcade_postgres",] diff --git a/toolkits/postgres/tests/__init__.py b/toolkits/postgres/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/postgres/tests/dump.sql b/toolkits/postgres/tests/dump.sql deleted file mode 100644 index a94b7ee16..000000000 --- a/toolkits/postgres/tests/dump.sql +++ /dev/null @@ -1,399 +0,0 @@ -DROP TABLE IF EXISTS "public"."messages"; --- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup. --- Sequence and defined type -CREATE SEQUENCE IF NOT EXISTS messages_id_seq; --- Table Definition -CREATE TABLE "public"."messages" ( - "id" int4 NOT NULL DEFAULT nextval('messages_id_seq'::regclass), - "body" text NOT NULL, - "user_id" int4 NOT NULL, - "created_at" timestamp NOT NULL DEFAULT now(), - "updated_at" timestamp NOT NULL DEFAULT now(), - PRIMARY KEY ("id") -); -DROP TABLE IF EXISTS "public"."users"; --- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup. --- Sequence and defined type -CREATE SEQUENCE IF NOT EXISTS users_id_seq; --- Table Definition -CREATE TABLE "public"."users" ( - "id" int4 NOT NULL DEFAULT nextval('users_id_seq'::regclass), - "name" varchar(256) NOT NULL, - "email" text NOT NULL, - "password_hash" text NOT NULL, - "created_at" timestamp NOT NULL DEFAULT now(), - "updated_at" timestamp NOT NULL DEFAULT now(), - "status" varchar, - PRIMARY KEY ("id") -); -INSERT INTO "public"."messages" ( - "id", - "body", - "user_id", - "created_at", - "updated_at" - ) -VALUES -- User 1 (Alice) - 3 messages - ( - 1, - 'Hello everyone!', - 1, - '2025-01-10 10:00:00.000000', - '2025-01-10 10:00:00.000000' - ), - ( - 2, - 'How is everyone doing today?', - 1, - '2025-01-10 11:30:00.000000', - '2025-01-10 11:30:00.000000' - ), - ( - 3, - 'Great to see you all here!', - 1, - '2025-01-10 14:15:00.000000', - '2025-01-10 14:15:00.000000' - ), - -- User 2 (Bob) - 2 messages - ( - 4, - 'Hi Alice! Doing well, thanks for asking.', - 2, - '2025-01-10 11:35:00.000000', - '2025-01-10 11:35:00.000000' - ), - ( - 5, - 'Anyone up for a game later?', - 2, - '2025-01-10 16:20:00.000000', - '2025-01-10 16:20:00.000000' - ), - -- User 3 (Charlie) - 3 messages - ( - 6, - 'Count me in for the game!', - 3, - '2025-01-10 16:25:00.000000', - '2025-01-10 16:25:00.000000' - ), - ( - 7, - 'What time works for everyone?', - 3, - '2025-01-10 16:30:00.000000', - '2025-01-10 16:30:00.000000' - ), - ( - 8, - 'I can play around 8 PM', - 3, - '2025-01-10 17:00:00.000000', - '2025-01-10 17:00:00.000000' - ), - -- User 4 (Diana) - 2 messages - ( - 9, - '8 PM works for me too!', - 4, - '2025-01-10 17:05:00.000000', - '2025-01-10 17:05:00.000000' - ), - ( - 10, - 'What game should we play?', - 4, - '2025-01-10 17:10:00.000000', - '2025-01-10 17:10:00.000000' - ), - -- User 5 (Evan) - 3 messages - ( - 11, - 'I suggest we try the new arcade game!', - 5, - '2025-01-10 17:15:00.000000', - '2025-01-10 17:15:00.000000' - ), - ( - 12, - 'It has great multiplayer features', - 5, - '2025-01-10 17:20:00.000000', - '2025-01-10 17:20:00.000000' - ), - ( - 13, - 'Perfect timing for a weekend session', - 5, - '2025-01-10 18:00:00.000000', - '2025-01-10 18:00:00.000000' - ), - -- User 6 (Fiona) - 2 messages - ( - 14, - 'Sounds like fun! I love arcade games.', - 6, - '2025-01-10 18:05:00.000000', - '2025-01-10 18:05:00.000000' - ), - ( - 15, - 'Should I bring snacks?', - 6, - '2025-01-10 18:10:00.000000', - '2025-01-10 18:10:00.000000' - ), - -- User 7 (George) - 3 messages - ( - 16, - 'Snacks are always welcome!', - 7, - '2025-01-10 18:15:00.000000', - '2025-01-10 18:15:00.000000' - ), - ( - 17, - 'I can bring some drinks', - 7, - '2025-01-10 18:20:00.000000', - '2025-01-10 18:20:00.000000' - ), - ( - 18, - 'This is going to be awesome', - 7, - '2025-01-10 19:00:00.000000', - '2025-01-10 19:00:00.000000' - ), - -- User 8 (Helen) - 2 messages - ( - 19, - 'I agree! Cannot wait for the game night.', - 8, - '2025-01-10 19:05:00.000000', - '2025-01-10 19:05:00.000000' - ), - ( - 20, - 'Should we set up a Discord call?', - 8, - '2025-01-10 19:10:00.000000', - '2025-01-10 19:10:00.000000' - ), - -- User 9 (Ian) - 3 messages - ( - 21, - 'Discord would be perfect for voice chat', - 9, - '2025-01-10 19:15:00.000000', - '2025-01-10 19:15:00.000000' - ), - ( - 22, - 'I will create a server for us', - 9, - '2025-01-10 19:20:00.000000', - '2025-01-10 19:20:00.000000' - ), - ( - 23, - 'Link will be shared in a few minutes', - 9, - '2025-01-10 19:25:00.000000', - '2025-01-10 19:25:00.000000' - ), - -- User 10 (Julia) - 2 messages - ( - 24, - 'Thanks Ian! You are the best.', - 10, - '2025-01-10 19:30:00.000000', - '2025-01-10 19:30:00.000000' - ), - ( - 25, - 'See you all at 8 PM!', - 10, - '2025-01-10 19:35:00.000000', - '2025-01-10 19:35:00.000000' - ), - -- Additional messages for Evan (user_id 5) - 10 more messages - ( - 26, - 'Just finished setting up the game server!', - 5, - '2025-01-10 20:00:00.000000', - '2025-01-10 20:00:00.000000' - ), - ( - 27, - 'Everyone should be able to connect now', - 5, - '2025-01-10 20:05:00.000000', - '2025-01-10 20:05:00.000000' - ), - ( - 28, - 'I added some custom maps too', - 5, - '2025-01-10 20:10:00.000000', - '2025-01-10 20:10:00.000000' - ), - ( - 29, - 'The graphics look amazing on this new version', - 5, - '2025-01-10 20:15:00.000000', - '2025-01-10 20:15:00.000000' - ), - ( - 30, - 'Hope you all enjoy the new features', - 5, - '2025-01-10 20:20:00.000000', - '2025-01-10 20:20:00.000000' - ), - ( - 31, - 'I also set up a leaderboard system', - 5, - '2025-01-10 20:25:00.000000', - '2025-01-10 20:25:00.000000' - ), - ( - 32, - 'We can track high scores now', - 5, - '2025-01-10 20:30:00.000000', - '2025-01-10 20:30:00.000000' - ), - ( - 33, - 'The game supports up to 8 players simultaneously', - 5, - '2025-01-10 20:35:00.000000', - '2025-01-10 20:35:00.000000' - ), - ( - 34, - 'I tested it earlier and it runs smoothly', - 5, - '2025-01-10 20:40:00.000000', - '2025-01-10 20:40:00.000000' - ), - ( - 35, - 'Cannot wait to see everyone online tonight!', - 5, - '2025-01-10 20:45:00.000000', - '2025-01-10 20:45:00.000000' - ); -INSERT INTO "public"."users" ( - "id", - "name", - "email", - "password_hash", - "created_at", - "updated_at", - "status" - ) -VALUES ( - 1, - 'Alice', - 'alice@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E', - '2024-09-01 20:49:38.759432', - '2024-09-02 03:49:39.927', - 'active' - ), - ( - 2, - 'Bob', - 'bob@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY', - '2024-09-02 17:49:23.377425', - '2024-09-02 17:49:23.377425', - 'active' - ), - ( - 3, - 'Charlie', - 'charlie@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo', - '2024-09-03 10:30:15.123456', - '2024-09-03 10:30:15.123456', - 'active' - ), - ( - 4, - 'Diana', - 'diana@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$xyz123ABC456DEF789GHI$SampleHashForDiana123', - '2024-09-04 14:20:30.654321', - '2024-09-04 14:20:30.654321', - 'active' - ), - ( - 5, - 'Evan', - 'evan@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$evanHash123$EvanPasswordHash456', - '2024-09-05 09:15:45.987654', - '2024-09-05 09:15:45.987654', - 'active' - ), - ( - 6, - 'Fiona', - 'fiona@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$fionaHash456$FionaPasswordHash789', - '2024-09-06 16:45:12.345678', - '2024-09-06 16:45:12.345678', - 'active' - ), - ( - 7, - 'George', - 'george@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$georgeHash789$GeorgePasswordHash012', - '2024-09-07 11:30:25.876543', - '2024-09-07 11:30:25.876543', - 'active' - ), - ( - 8, - 'Helen', - 'helen@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$helenHash012$HelenPasswordHash345', - '2024-09-08 13:25:40.234567', - '2024-09-08 13:25:40.234567', - 'active' - ), - ( - 9, - 'Ian', - 'ian@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$ianHash345$IanPasswordHash678', - '2024-09-09 08:40:55.765432', - '2024-09-09 08:40:55.765432', - 'active' - ), - ( - 10, - 'Julia', - 'julia@example.com', - '$argon2id$v=19$m=65536,t=2,p=1$juliaHash678$JuliaPasswordHash901', - '2024-09-10 15:55:18.123456', - '2024-09-10 15:55:18.123456', - 'active' - ); -ALTER TABLE "public"."messages" -ADD FOREIGN KEY ("user_id") REFERENCES "public"."users"("id"); --- set pk to 11 -ALTER SEQUENCE users_id_seq RESTART WITH 11; --- Indices -CREATE UNIQUE INDEX name_idx ON public.users USING btree (name); -CREATE UNIQUE INDEX email_idx ON public.users USING btree (email); -DROP INDEX IF EXISTS users_email_unique; -CREATE UNIQUE INDEX users_email_unique ON public.users USING btree (email); diff --git a/toolkits/postgres/tests/test_postgres.py b/toolkits/postgres/tests/test_postgres.py deleted file mode 100644 index 99c35a802..000000000 --- a/toolkits/postgres/tests/test_postgres.py +++ /dev/null @@ -1,188 +0,0 @@ -import os -from os import environ -from unittest.mock import MagicMock - -import pytest -import pytest_asyncio -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_postgres.tools.postgres import ( - DatabaseEngine, - discover_schemas, - discover_tables, - execute_select_query, - get_table_schema, -) -from sqlalchemy import text -from sqlalchemy.ext.asyncio import create_async_engine - -POSTGRES_DATABASE_CONNECTION_STRING = ( - environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING") - or "postgresql://postgres@localhost:5432/postgres" -) - - -@pytest.fixture -def mock_context(): - context = MagicMock(spec=Context) - context.get_secret = MagicMock(return_value=POSTGRES_DATABASE_CONNECTION_STRING) - return context - - -# before the tests, restore the database from the dump -@pytest_asyncio.fixture(autouse=True) -async def restore_database(): - with open(f"{os.path.dirname(__file__)}/dump.sql") as f: - engine = create_async_engine( - POSTGRES_DATABASE_CONNECTION_STRING.replace("postgresql", "postgresql+asyncpg").split( - "?" - )[0] - ) - async with engine.connect() as c: - queries = f.read().split(";") - await c.execute(text("BEGIN")) - for query in queries: - if query.strip(): - await c.execute(text(query)) - await c.commit() - await engine.dispose() - - -@pytest_asyncio.fixture(autouse=True) -async def cleanup_engines(): - """Clean up database engines after each test to prevent connection leaks.""" - yield - # Clean up all cached engines after each test - await DatabaseEngine.cleanup() - - -@pytest.mark.asyncio -async def test_discover_schemas(mock_context) -> None: - assert await discover_schemas(mock_context) == ["public"] - - -@pytest.mark.asyncio -async def test_discover_tables(mock_context) -> None: - assert await discover_tables(mock_context) == ["messages", "users"] - - -@pytest.mark.asyncio -async def test_get_table_schema(mock_context) -> None: - assert await get_table_schema(mock_context, "public", "users") == [ - "id: int (PRIMARY KEY)", - "name: str (INDEXED)", - "email: str (INDEXED)", - "password_hash: str", - "created_at: datetime", - "updated_at: datetime", - "status: str", - ] - - assert await get_table_schema(mock_context, "public", "messages") == [ - "id: int (PRIMARY KEY)", - "body: str", - "user_id: int", - "created_at: datetime", - "updated_at: datetime", - ] - - -@pytest.mark.asyncio -async def test_execute_select_query(mock_context) -> None: - assert await execute_select_query( - mock_context, - select_clause="id, name, email", - from_clause="users", - where_clause="id = 1", - ) == [ - "(1, 'Alice', 'alice@example.com')", - ] - assert await execute_select_query( - mock_context, - select_clause="id, name, email", - from_clause="users", - order_by_clause="id", - limit=1, - offset=1, - ) == [ - "(2, 'Bob', 'bob@example.com')", - ] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_keywords(mock_context) -> None: - assert await execute_select_query( - mock_context, - select_clause="SELECT id, name, email", - from_clause="FROM users", - limit=1, - ) == [ - "(1, 'Alice', 'alice@example.com')", - ] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_join(mock_context) -> None: - assert await execute_select_query( - mock_context, - select_clause="u.id, u.name, u.email, m.id, m.body", - from_clause="users u", - join_clause="messages m ON u.id = m.user_id", - limit=1, - ) == [ - "(1, 'Alice', 'alice@example.com', 1, 'Hello everyone!')", - ] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_group_by(mock_context) -> None: - assert await execute_select_query( - mock_context, - select_clause="u.name, COUNT(m.id) AS message_count", - from_clause="messages m", - join_clause="users u ON m.user_id = u.id", - group_by_clause="u.name", - order_by_clause="message_count DESC", - limit=2, - ) == [ - "('Evan', 13)", - "('Alice', 3)", - ] - - -@pytest.mark.asyncio -async def test_execute_select_query_with_no_results(mock_context) -> None: - # does not raise an error - assert ( - await execute_select_query( - mock_context, - select_clause="id, name, email", - from_clause="users", - where_clause="id = 9999999999", - ) - == [] - ) - - -@pytest.mark.asyncio -async def test_execute_select_query_with_problem(mock_context) -> None: - # 'foo' is not a valid id - with pytest.raises(RetryableToolError) as e: - await execute_select_query( - mock_context, - select_clause="*", - from_clause="users", - where_clause="id = 'foo'", - ) - assert "Do not use * in the select clause" in str(e.value) - - -@pytest.mark.asyncio -async def test_execute_select_query_rejects_non_select(mock_context) -> None: - with pytest.raises(RetryableToolError) as e: - await execute_select_query( - mock_context, - select_clause="INSERT INTO users (name, email, password_hash) VALUES ('Luigi', 'luigi@example.com', 'password')", - from_clause="users", - ) - assert "Only SELECT queries are allowed" in str(e.value) diff --git a/toolkits/postgres/tests/test_setup.sh b/toolkits/postgres/tests/test_setup.sh deleted file mode 100755 index bac017577..000000000 --- a/toolkits/postgres/tests/test_setup.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -# Run PostgreSQL container -docker run -d --name some-postgres-server -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres:latest - -# Wait for PostgreSQL to be ready -echo "Waiting for PostgreSQL to be ready..." -for i in {1..30}; do - if docker exec some-postgres-server pg_isready -U postgres > /dev/null 2>&1; then - echo "PostgreSQL is ready!" - break - fi - echo "Waiting... ($i/30)" - sleep 1 -done diff --git a/toolkits/zendesk/.pre-commit-config.yaml b/toolkits/zendesk/.pre-commit-config.yaml deleted file mode 100644 index 6a345ba48..000000000 --- a/toolkits/zendesk/.pre-commit-config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -files: ^.*/zendesk/.* -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: "v4.4.0" - hooks: - - id: check-case-conflict - - id: check-merge-conflict - - id: check-toml - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 - hooks: - - id: ruff - args: [--fix] - - id: ruff-format diff --git a/toolkits/zendesk/.ruff.toml b/toolkits/zendesk/.ruff.toml deleted file mode 100644 index 2315c4aa8..000000000 --- a/toolkits/zendesk/.ruff.toml +++ /dev/null @@ -1,46 +0,0 @@ -target-version = "py310" -line-length = 100 -fix = true - -[lint] -select = [ - # flake8-2020 - "YTT", - # flake8-bandit - "S", - # flake8-bugbear - "B", - # flake8-builtins - "A", - # flake8-comprehensions - "C4", - # flake8-debugger - "T10", - # flake8-simplify - "SIM", - # isort - "I", - # mccabe - "C90", - # pycodestyle - "E", "W", - # pyflakes - "F", - # pygrep-hooks - "PGH", - # pyupgrade - "UP", - # ruff - "RUF", - # tryceratops - "TRY", -] - -ignore = ["C901"] - -[lint.per-file-ignores] -"**/tests/*" = ["S101"] - -[format] -preview = true -skip-magic-trailing-comma = false diff --git a/toolkits/zendesk/Makefile b/toolkits/zendesk/Makefile deleted file mode 100644 index 0a8969beb..000000000 --- a/toolkits/zendesk/Makefile +++ /dev/null @@ -1,55 +0,0 @@ -.PHONY: help - -help: - @echo "๐Ÿ› ๏ธ github Commands:\n" - @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' - -.PHONY: install -install: ## Install the uv environment and install all packages with dependencies - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras --no-sources - @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: install-local -install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources - @echo "๐Ÿš€ Creating virtual environment and installing all packages using uv" - @uv sync --active --all-extras - @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi - @echo "โœ… All packages and dependencies installed via uv" - -.PHONY: build -build: clean-build ## Build wheel file using poetry - @echo "๐Ÿš€ Creating wheel file" - uv build - -.PHONY: clean-build -clean-build: ## clean build artifacts - @echo "๐Ÿ—‘๏ธ Cleaning dist directory" - rm -rf dist - -.PHONY: test -test: ## Test the code with pytest - @echo "๐Ÿš€ Testing code: Running pytest" - @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml - -.PHONY: coverage -coverage: ## Generate coverage report - @echo "coverage report" - @uv run --no-sources coverage report - @echo "Generating coverage report" - @uv run --no-sources coverage html - -.PHONY: bump-version -bump-version: ## Bump the version in the pyproject.toml file by a patch version - @echo "๐Ÿš€ Bumping version in pyproject.toml" - uv version --no-sources --bump patch - -.PHONY: check -check: ## Run code quality tools. - @if [ -f .pre-commit-config.yaml ]; then\ - echo "๐Ÿš€ Linting code: Running pre-commit";\ - uv run --no-sources pre-commit run -a;\ - fi - @echo "๐Ÿš€ Static type checking: Running mypy" - @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/zendesk/arcade_zendesk/__init__.py b/toolkits/zendesk/arcade_zendesk/__init__.py deleted file mode 100644 index 16ac06623..000000000 --- a/toolkits/zendesk/arcade_zendesk/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from arcade_zendesk.tools import ( - add_ticket_comment, - get_ticket_comments, - list_tickets, - mark_ticket_solved, - search_articles, -) - -__all__ = [ - "add_ticket_comment", - "get_ticket_comments", - "list_tickets", - "mark_ticket_solved", - "search_articles", -] diff --git a/toolkits/zendesk/arcade_zendesk/__main__.py b/toolkits/zendesk/arcade_zendesk/__main__.py deleted file mode 100644 index 20a432882..000000000 --- a/toolkits/zendesk/arcade_zendesk/__main__.py +++ /dev/null @@ -1,28 +0,0 @@ -import sys -from typing import cast - -from arcade_mcp_server import MCPApp -from arcade_mcp_server.mcp_app import TransportType - -import arcade_zendesk - -app = MCPApp( - name="Zendesk", - instructions=( - "Use this server when you need to interact with Zendesk to help users " - "manage support tickets and search knowledge base articles." - ), -) - -app.add_tools_from_module(arcade_zendesk) - - -def main() -> None: - transport = sys.argv[1] if len(sys.argv) > 1 else "stdio" - host = sys.argv[2] if len(sys.argv) > 2 else "127.0.0.1" - port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000 - app.run(transport=cast(TransportType, transport), host=host, port=port) - - -if __name__ == "__main__": - main() diff --git a/toolkits/zendesk/arcade_zendesk/enums.py b/toolkits/zendesk/arcade_zendesk/enums.py deleted file mode 100644 index f135bf445..000000000 --- a/toolkits/zendesk/arcade_zendesk/enums.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Enums for the Zendesk toolkit.""" - -from enum import Enum - - -class ArticleSortBy(Enum): - """Sort fields for article search results.""" - - CREATED_AT = "created_at" - RELEVANCE = "relevance" - - -class SortOrder(Enum): - """Sort order direction.""" - - ASC = "asc" - DESC = "desc" - - -class TicketStatus(Enum): - """Valid ticket statuses.""" - - NEW = "new" - OPEN = "open" - PENDING = "pending" - SOLVED = "solved" - CLOSED = "closed" diff --git a/toolkits/zendesk/arcade_zendesk/tools/__init__.py b/toolkits/zendesk/arcade_zendesk/tools/__init__.py deleted file mode 100644 index d2e90cbf8..000000000 --- a/toolkits/zendesk/arcade_zendesk/tools/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from arcade_zendesk.tools.search_articles import search_articles -from arcade_zendesk.tools.system_context import who_am_i -from arcade_zendesk.tools.tickets import ( - add_ticket_comment, - get_ticket_comments, - list_tickets, - mark_ticket_solved, -) - -__all__ = [ - "add_ticket_comment", - "get_ticket_comments", - "list_tickets", - "mark_ticket_solved", - "search_articles", - "who_am_i", -] diff --git a/toolkits/zendesk/arcade_zendesk/tools/search_articles.py b/toolkits/zendesk/arcade_zendesk/tools/search_articles.py deleted file mode 100644 index 6d9bd5395..000000000 --- a/toolkits/zendesk/arcade_zendesk/tools/search_articles.py +++ /dev/null @@ -1,219 +0,0 @@ -import logging -from typing import Annotated, Any - -import httpx -from arcade_mcp_server import Context, tool -from arcade_mcp_server.auth import OAuth2 -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mcp_server.metadata import ( - Behavior, - Classification, - Operation, - ServiceDomain, - ToolMetadata, -) - -from arcade_zendesk.enums import ArticleSortBy, SortOrder -from arcade_zendesk.utils import ( - fetch_paginated_results, - get_zendesk_subdomain, - process_search_results, - validate_date_format, -) - -logger = logging.getLogger(__name__) - - -@tool( - requires_auth=OAuth2(id="zendesk", scopes=["read"]), - requires_secrets=["ZENDESK_SUBDOMAIN"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.CUSTOMER_SUPPORT], - ), - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def search_articles( - context: Context, - query: Annotated[ - str | None, - "Search text to match against articles. Supports quoted expressions for exact matching", - ] = None, - label_names: Annotated[ - list[str] | None, - "List of label names to filter by (case-insensitive). Article must have at least " - "one matching label. Available on Professional/Enterprise plans only", - ] = None, - created_after: Annotated[ - str | None, - "Filter articles created after this date (format: YYYY-MM-DD)", - ] = None, - created_before: Annotated[ - str | None, - "Filter articles created before this date (format: YYYY-MM-DD)", - ] = None, - created_at: Annotated[ - str | None, - "Filter articles created on this exact date (format: YYYY-MM-DD)", - ] = None, - sort_by: Annotated[ - ArticleSortBy | None, - "Field to sort articles by. Defaults to relevance according to the search query", - ] = None, - sort_order: Annotated[ - SortOrder | None, - "Sort order direction. Defaults to descending", - ] = None, - limit: Annotated[ - int, - "Number of articles to return. Defaults to 30", - ] = 30, - offset: Annotated[ - int, - "Number of articles to skip before returning results. Defaults to 0", - ] = 0, - include_body: Annotated[ - bool, - "Include article body content in results. Bodies will be cleaned of HTML and truncated", - ] = True, - max_article_length: Annotated[ - int | None, - "Maximum length for article body content in characters. " - "Set to None for no limit. Defaults to 500", - ] = 500, -) -> Annotated[ - dict[str, Any], - "Article search results with pagination metadata. Includes 'next_offset' when more " - "results are available. Simply use this value as the 'offset' parameter in your next " - "call to fetch the next batch", -]: - """ - Search for Help Center articles in your Zendesk knowledge base. - - This tool searches specifically for published knowledge base articles that provide - solutions and guidance to users. At least one search parameter (query or label_names) - must be provided. - - PAGINATION: - - The response includes 'next_offset' when more results are available - - To fetch the next batch, simply pass the 'next_offset' value as the 'offset' parameter - - If 'next_offset' is not present, you've reached the end of available results - - The tool automatically handles fetching from the correct page based on your offset - - IMPORTANT: ALL FILTERS CAN BE COMBINED IN A SINGLE CALL - You can combine multiple filters (query, labels, dates) in one search request. - Do NOT make separate tool calls - combine all relevant filters together. - """ - - # Validate date parameters - date_params = { - "created_after": created_after, - "created_before": created_before, - "created_at": created_at, - } - - for param_name, param_value in date_params.items(): - if param_value and not validate_date_format(param_value): - raise RetryableToolError( - message=( - f"Invalid date format for {param_name}: '{param_value}'. " - "Please use YYYY-MM-DD format." - ), - developer_message=( - f"Date validation failed for parameter '{param_name}' " - f"with value '{param_value}'" - ), - retry_after_ms=500, - additional_prompt_content="Use format YYYY-MM-DD.", - ) - - # Validate limit and offset parameters - if limit < 1: - raise RetryableToolError( - message="limit must be at least 1.", - developer_message=f"Invalid limit value: {limit}", - retry_after_ms=100, - additional_prompt_content="Provide a positive limit value", - ) - - if offset < 0: - raise RetryableToolError( - message="offset cannot be negative.", - developer_message=f"Invalid offset value: {offset}", - retry_after_ms=100, - additional_prompt_content="Provide a non-negative offset value", - ) - - # Validate that at least one search parameter is provided - if not any([query, label_names]): - raise RetryableToolError( - message="At least one search parameter must be provided.", - developer_message="No search parameters were provided", - retry_after_ms=100, - additional_prompt_content=( - "Provide at least one of: query text or a list of label names" - ), - ) - - auth_token = context.get_auth_token_or_empty() - subdomain = get_zendesk_subdomain(context) - - url = f"https://{subdomain}.zendesk.com/api/v2/help_center/articles/search" - - # Base parameters for the search - base_params: dict[str, Any] = { - "per_page": 100, # Max allowed per page - } - - if query: - base_params["query"] = query - - if label_names: - base_params["label_names"] = ",".join(label_names) - - if created_after: - base_params["created_after"] = created_after - - if created_before: - base_params["created_before"] = created_before - - if created_at: - base_params["created_at"] = created_at - - if sort_by: - base_params["sort_by"] = sort_by.value - - if sort_order: - base_params["sort_order"] = sort_order.value - - async with httpx.AsyncClient() as client: - headers = { - "Authorization": f"Bearer {auth_token}", - "Content-Type": "application/json", - "Accept": "application/json", - } - - data = await fetch_paginated_results( - client=client, - url=url, - headers=headers, - params=base_params, - offset=offset, - limit=limit, - ) - - if "results" in data: - data["results"] = process_search_results( - data["results"], include_body=include_body, max_body_length=max_article_length - ) - - logger.info(f"Article search returned {data.get('count', 0)} results") - - return data diff --git a/toolkits/zendesk/arcade_zendesk/tools/system_context.py b/toolkits/zendesk/arcade_zendesk/tools/system_context.py deleted file mode 100644 index a74f18d1f..000000000 --- a/toolkits/zendesk/arcade_zendesk/tools/system_context.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Annotated, Any - -from arcade_mcp_server import Context, tool -from arcade_mcp_server.auth import OAuth2 -from arcade_mcp_server.metadata import ( - Behavior, - Classification, - Operation, - ServiceDomain, - ToolMetadata, -) - -from arcade_zendesk.who_am_i_util import build_who_am_i_response - - -@tool( - requires_auth=OAuth2(id="zendesk", scopes=["read"]), - requires_secrets=["ZENDESK_SUBDOMAIN"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.CUSTOMER_SUPPORT], - ), - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def who_am_i( - context: Context, -) -> Annotated[ - dict[str, Any], - "Get comprehensive user profile and Zendesk account information.", -]: - """ - Get comprehensive user profile and Zendesk account information. - - This tool provides detailed information about the authenticated user including - their name, email, role, organization details, and Zendesk account context. - """ - user_info = await build_who_am_i_response(context) - return dict(user_info) diff --git a/toolkits/zendesk/arcade_zendesk/tools/tickets.py b/toolkits/zendesk/arcade_zendesk/tools/tickets.py deleted file mode 100644 index d5ca67749..000000000 --- a/toolkits/zendesk/arcade_zendesk/tools/tickets.py +++ /dev/null @@ -1,367 +0,0 @@ -from typing import Annotated, Any - -import httpx -from arcade_mcp_server import Context, tool -from arcade_mcp_server.auth import OAuth2 -from arcade_mcp_server.exceptions import RetryableToolError -from arcade_mcp_server.metadata import ( - Behavior, - Classification, - Operation, - ServiceDomain, - ToolMetadata, -) - -from arcade_zendesk.enums import SortOrder, TicketStatus -from arcade_zendesk.utils import fetch_paginated_results, get_zendesk_subdomain - - -def _handle_ticket_not_found(response: httpx.Response, ticket_id: int) -> None: - """Handle 404 responses for ticket operations.""" - if response.status_code == 404: - raise RetryableToolError( - message=f"Ticket #{ticket_id} not found.", - developer_message=f"Ticket with ID {ticket_id} does not exist", - retry_after_ms=500, - additional_prompt_content="Please verify the ticket ID and try again", - ) - - -@tool( - requires_auth=OAuth2(id="zendesk", scopes=["read"]), - requires_secrets=["ZENDESK_SUBDOMAIN"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.CUSTOMER_SUPPORT], - ), - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def list_tickets( - context: Context, - status: Annotated[ - TicketStatus, - "The status of tickets to filter by. Defaults to 'open'", - ] = TicketStatus.OPEN, - limit: Annotated[ - int, - "Number of tickets to return. Defaults to 30", - ] = 30, - offset: Annotated[ - int, - "Number of tickets to skip before returning results. Defaults to 0", - ] = 0, - sort_order: Annotated[ - SortOrder, - "Sort order for tickets by ID. 'asc' returns oldest first, 'desc' returns newest first. " - "Defaults to 'desc'", - ] = SortOrder.DESC, -) -> Annotated[ - dict[str, Any], - "A dictionary containing tickets list (each with html_url), count, and pagination metadata. " - "Includes 'next_offset' when more results are available", -]: - """List tickets from your Zendesk account with offset-based pagination. - - By default, returns tickets sorted by ID with newest tickets first (desc). - - Each ticket in the response includes an 'html_url' field with the direct link - to view the ticket in Zendesk. - - PAGINATION: - - The response includes 'next_offset' when more results are available - - To fetch the next batch, simply pass the 'next_offset' value as the 'offset' parameter - - If 'next_offset' is not present, you've reached the end of available results - """ - - # Validate limit and offset parameters - if limit < 1: - raise RetryableToolError( - message="limit must be at least 1.", - developer_message=f"Invalid limit value: {limit}", - retry_after_ms=100, - additional_prompt_content="Provide a positive limit value", - ) - - if offset < 0: - raise RetryableToolError( - message="offset cannot be negative.", - developer_message=f"Invalid offset value: {offset}", - retry_after_ms=100, - additional_prompt_content="Provide a non-negative offset value", - ) - - # Get the authorization token - token = context.get_auth_token_or_empty() - subdomain = get_zendesk_subdomain(context) - - # Build the API URL - url = f"https://{subdomain}.zendesk.com/api/v2/tickets.json" - - # Base parameters for the request - base_params: dict[str, Any] = { - "status": status.value, - "per_page": 100, # Max allowed per page - "sort_order": sort_order.value, - } - - # Make the API request - async with httpx.AsyncClient() as client: - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } - - # Use the fetch_paginated_results utility - data = await fetch_paginated_results( - client=client, - url=url, - headers=headers, - params=base_params, - offset=offset, - limit=limit, - ) - - # Process tickets to add html_url and remove api url - tickets = data.get("results", []) - for ticket in tickets: - if "id" in ticket: - ticket["html_url"] = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket['id']}" - # Remove API url to avoid confusion - if "url" in ticket: - del ticket["url"] - - # Build the result with consistent structure - result = { - "tickets": tickets, - "count": data.get("count", len(tickets)), - } - - # Add next_offset if present - if "next_offset" in data: - result["next_offset"] = data["next_offset"] - return result - - -@tool( - requires_auth=OAuth2(id="zendesk", scopes=["read"]), - requires_secrets=["ZENDESK_SUBDOMAIN"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.CUSTOMER_SUPPORT], - ), - behavior=Behavior( - operations=[Operation.READ], - read_only=True, - destructive=False, - idempotent=True, - open_world=True, - ), - ), -) -async def get_ticket_comments( - context: Context, - ticket_id: Annotated[int, "The ID of the ticket to get comments for"], -) -> Annotated[ - dict[str, Any], "A dictionary containing the ticket comments, metadata, and ticket URL" -]: - """Get all comments for a specific Zendesk ticket, including the original description. - - The first comment is always the ticket's original description/content. - Subsequent comments show the conversation history. - - Each comment includes: - - author_id: ID of the comment author - - body: The comment text - - created_at: Timestamp when comment was created - - public: Whether the comment is public or internal - - attachments: List of file attachments (if any) with file_name, content_url, size, etc. - """ - - # Get the authorization token - token = context.get_auth_token_or_empty() - subdomain = get_zendesk_subdomain(context) - - # Zendesk API endpoint for ticket comments - url = f"https://{subdomain}.zendesk.com/api/v2/tickets/{ticket_id}/comments.json" - - # Make the API request - async with httpx.AsyncClient() as client: - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } - - response = await client.get(url, headers=headers) - _handle_ticket_not_found(response, ticket_id) - response.raise_for_status() - - data = response.json() - comments = data.get("comments", []) - - return { - "ticket_id": ticket_id, - "comments": comments, - "count": len(comments), - } - - -@tool( - requires_auth=OAuth2(id="zendesk", scopes=["tickets:write"]), - requires_secrets=["ZENDESK_SUBDOMAIN"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.CUSTOMER_SUPPORT], - ), - behavior=Behavior( - operations=[Operation.CREATE], - read_only=False, - destructive=False, - idempotent=False, - open_world=True, - ), - ), -) -async def add_ticket_comment( - context: Context, - ticket_id: Annotated[int, "The ID of the ticket to comment on"], - comment_body: Annotated[str, "The text of the comment"], - public: Annotated[ - bool, "Whether the comment is public (visible to requester) or internal. Defaults to True" - ] = True, -) -> Annotated[ - dict[str, Any], "A dictionary containing the result of the comment operation and ticket URL" -]: - """Add a comment to an existing Zendesk ticket. - - The returned ticket object includes an 'html_url' field with the direct link - to view the ticket in Zendesk. - """ - - # Get the authorization token - token = context.get_auth_token_or_empty() - subdomain = get_zendesk_subdomain(context) - - # Zendesk API endpoint for updating ticket - url = f"https://{subdomain}.zendesk.com/api/v2/tickets/{ticket_id}.json" - - # Prepare the request body - request_body = {"ticket": {"comment": {"body": comment_body, "public": public}}} - - # Make the API request - async with httpx.AsyncClient() as client: - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } - - response = await client.put(url, headers=headers, json=request_body) - _handle_ticket_not_found(response, ticket_id) - response.raise_for_status() - - data = response.json() - ticket = data.get("ticket", {}) - - # Add web interface URL if not present - if "id" in ticket and "html_url" not in ticket: - ticket["html_url"] = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket['id']}" - # Remove API url to avoid confusion - if "url" in ticket: - del ticket["url"] - - return { - "success": True, - "ticket_id": ticket_id, - "comment_type": "public" if public else "internal", - "ticket": ticket, - } - - -@tool( - requires_auth=OAuth2(id="zendesk", scopes=["tickets:write"]), - requires_secrets=["ZENDESK_SUBDOMAIN"], - metadata=ToolMetadata( - classification=Classification( - service_domains=[ServiceDomain.CUSTOMER_SUPPORT], - ), - behavior=Behavior( - operations=[Operation.UPDATE, Operation.CREATE], - read_only=False, - destructive=False, - idempotent=False, - open_world=True, - ), - ), -) -async def mark_ticket_solved( - context: Context, - ticket_id: Annotated[int, "The ID of the ticket to mark as solved"], - comment_body: Annotated[ - str | None, - "Optional final comment to add when solving the ticket", - ] = None, - comment_public: Annotated[ - bool, "Whether the comment is visible to the requester. Defaults to False" - ] = False, -) -> Annotated[dict[str, Any], "A dictionary containing the result of the solve operation"]: - """Mark a Zendesk ticket as solved, optionally with a final comment. - - The returned ticket object includes an 'html_url' field with the direct link - to view the ticket in Zendesk. - """ - - # Get the authorization token - token = context.get_auth_token_or_empty() - subdomain = get_zendesk_subdomain(context) - - # Zendesk API endpoint for updating ticket - url = f"https://{subdomain}.zendesk.com/api/v2/tickets/{ticket_id}.json" - - # Prepare the request body - request_body: dict[str, Any] = {"ticket": {"status": "solved"}} - - # Add resolution comment if provided - if comment_body: - request_body["ticket"]["comment"] = { - "body": comment_body, - "public": comment_public, - } - - # Make the API request - async with httpx.AsyncClient() as client: - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } - - response = await client.put(url, headers=headers, json=request_body) - _handle_ticket_not_found(response, ticket_id) - response.raise_for_status() - - data = response.json() - ticket = data.get("ticket", {}) - - # Add web interface URL if not present - if "id" in ticket and "html_url" not in ticket: - ticket["html_url"] = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket['id']}" - # Remove API url to avoid confusion - if "url" in ticket: - del ticket["url"] - - result = { - "success": True, - "ticket_id": ticket_id, - "status": "solved", - "ticket": ticket, - } - if comment_body: - result["comment_added"] = True - result["comment_type"] = "public" if comment_public else "internal" - - return result diff --git a/toolkits/zendesk/arcade_zendesk/utils.py b/toolkits/zendesk/arcade_zendesk/utils.py deleted file mode 100644 index 8d8498d55..000000000 --- a/toolkits/zendesk/arcade_zendesk/utils.py +++ /dev/null @@ -1,216 +0,0 @@ -import logging -import re -from typing import Any - -import httpx -from arcade_mcp_server import Context -from arcade_mcp_server.exceptions import ToolExecutionError -from bs4 import BeautifulSoup - -logger = logging.getLogger(__name__) - -DEFAULT_MAX_BODY_LENGTH = 500 # Default max length for article body content - - -async def fetch_paginated_results( - client: httpx.AsyncClient, - url: str, - headers: dict[str, str], - params: dict[str, Any], - offset: int, - limit: int, -) -> dict[str, Any]: - """ - Fetch paginated results using offset and limit pattern. - - This function internally manages pagination to fulfill the requested offset and limit, - fetching multiple pages as needed. - - Args: - client: The HTTP client to use - url: The API endpoint URL - headers: Request headers including authorization - params: Base query parameters (without pagination params) - offset: Number of items to skip - limit: Number of items to return - - Returns: - Dict containing: - - results: List of fetched items - - count: Number of items returned - - next_offset: Present only if more results are available - """ - # Calculate pagination parameters - # Most Zendesk APIs use 1-based page numbering - items_per_page = params.get("per_page", 100) # Use per_page from params or default to 100 - start_page = (offset // items_per_page) + 1 - start_index = offset % items_per_page - - # Collect results across multiple pages if needed - all_results = [] - current_page = start_page - items_collected = 0 - has_more = False - last_page_had_more_items = False - - while items_collected < limit: - # Set the current page - page_params = params.copy() - page_params["page"] = current_page - - response = await client.get(url, headers=headers, params=page_params, timeout=30.0) - response.raise_for_status() - page_data = response.json() - - # Extract results from current page (handle both "results" and "tickets" keys) - page_results = page_data.get("results", page_data.get("tickets", [])) - - # If this is the first page, skip to the start index - if current_page == start_page: - page_results = page_results[start_index:] - - # Take only what we need to reach the limit - items_needed = limit - items_collected - results_to_add = page_results[:items_needed] - all_results.extend(results_to_add) - items_collected += len(results_to_add) - - # Check if we left items on this page - if len(page_results) > items_needed: - last_page_had_more_items = True - - # Check if there are more pages - has_more = page_data.get("next_page") is not None - - # Stop if we've collected enough or no more pages - if items_collected >= limit or not has_more: - break - - current_page += 1 - - # Build the response - result = { - "results": all_results, - "count": len(all_results), - } - - # Add next_offset if there might be more results - # This happens when: - # 1. We got exactly the limit requested AND (there are more pages OR we left items on the page) - # 2. We didn't get the full limit but there are more pages available - if (len(all_results) == limit and (has_more or last_page_had_more_items)) or ( - len(all_results) < limit and has_more - ): - result["next_offset"] = offset + len(all_results) - - return result - - -def clean_html_text(text: str | None) -> str: - """Remove HTML tags and clean up text.""" - if not text: - return "" - - soup = BeautifulSoup(text, "html.parser") - clean_text: str = soup.get_text(separator=" ") - - clean_text = re.sub(r"\n+", "\n", clean_text) - - clean_text = re.sub(r"\s+", " ", clean_text) - - clean_text = "\n".join(line.strip() for line in clean_text.split("\n")) - - return clean_text.strip() - - -def truncate_text( - text: str | None, max_length: int, suffix: str = " ... [truncated]" -) -> str | None: - """Truncate text to a maximum length with a suffix.""" - if not text or len(text) <= max_length: - return text - - truncate_at = max_length - len(suffix) - if truncate_at <= 0: - return suffix - - return text[:truncate_at] + suffix - - -def process_article_body(body: str | None, max_length: int | None = None) -> str | None: - """Process article body by cleaning HTML and optionally truncating.""" - if not body: - return None - - cleaned_text: str = clean_html_text(body) - - if max_length and len(cleaned_text) > max_length: - result = truncate_text(cleaned_text, max_length) - return result - - return cleaned_text - - -def process_search_results( - results: list[dict[str, Any]], - include_body: bool = False, - max_body_length: int | None = DEFAULT_MAX_BODY_LENGTH, -) -> list[dict[str, Any]]: - """Process search results to clean up data and restructure with content and metadata.""" - processed_results = [] - - for result in results: - body_content = result.get("body", "") - cleaned_content = None - - if include_body and body_content: - cleaned_content = process_article_body(body_content, max_body_length) - - processed_result: dict[str, Any] = {"content": cleaned_content, "metadata": {}} - - for key, value in result.items(): - if key != "body": - processed_result["metadata"][key] = value - - processed_results.append(processed_result) - - return processed_results - - -def validate_date_format(date_string: str) -> bool: - """Validate that a date string matches YYYY-MM-DD format and is a valid date.""" - from datetime import datetime - - try: - parsed_date = datetime.strptime(date_string, "%Y-%m-%d") - # Ensure the input matches the expected format exactly - return parsed_date.strftime("%Y-%m-%d") == date_string - except ValueError: - return False - - -def get_zendesk_subdomain(context: Context) -> str: - """ - Get the Zendesk subdomain from secrets with proper error handling. - - Args: - context: The tool context containing secrets - - Returns: - The Zendesk subdomain - - Raises: - ToolExecutionError: If the subdomain secret is not configured - """ - try: - subdomain = context.get_secret("ZENDESK_SUBDOMAIN") - except ValueError: - raise ToolExecutionError( - message="Zendesk subdomain is not set.", - developer_message=( - "Zendesk subdomain is not set. Make sure to set the " - "'ZENDESK_SUBDOMAIN' secret in the Arcade Dashboard." - ), - ) from None - else: - return subdomain diff --git a/toolkits/zendesk/arcade_zendesk/who_am_i_util.py b/toolkits/zendesk/arcade_zendesk/who_am_i_util.py deleted file mode 100644 index 6f773da55..000000000 --- a/toolkits/zendesk/arcade_zendesk/who_am_i_util.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Any, TypedDict - -import httpx -from arcade_mcp_server import Context - - -class WhoAmIResponse(TypedDict, total=False): - user_id: int - name: str - email: str - role: str - active: bool - verified: bool - locale: str - time_zone: str - organization_id: int - organization_name: str - organization_domains: list[str] - zendesk_access: bool - - -async def build_who_am_i_response(context: Context) -> WhoAmIResponse: - """Build comprehensive who am I response for Zendesk.""" - user_info = await _get_current_user(context) - organization_info = await _get_organization_info(context, user_info.get("organization_id")) - - response_data = {} - response_data.update(_extract_user_info(user_info)) - response_data.update(_extract_organization_info(organization_info)) - response_data["zendesk_access"] = True - - return response_data # type: ignore[return-value] - - -async def _get_current_user(context: Context) -> dict[str, Any]: - """Get current user information from Zendesk API.""" - subdomain = context.get_secret("ZENDESK_SUBDOMAIN") - base_url = f"https://{subdomain}.zendesk.com" - - headers = { - "Authorization": f"Bearer {context.get_auth_token_or_empty()}", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient() as client: - response = await client.get(f"{base_url}/api/v2/users/me", headers=headers) - response.raise_for_status() - return response.json().get("user", {}) # type: ignore[no-any-return] - - -async def _get_organization_info(context: Context, organization_id: int | None) -> dict[str, Any]: - """Get organization information from Zendesk API.""" - if not organization_id: - return {} - - subdomain = context.get_secret("ZENDESK_SUBDOMAIN") - base_url = f"https://{subdomain}.zendesk.com" - - headers = { - "Authorization": f"Bearer {context.get_auth_token_or_empty()}", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient() as client: - response = await client.get( - f"{base_url}/api/v2/organizations/{organization_id}", headers=headers - ) - response.raise_for_status() - return response.json().get("organization", {}) # type: ignore[no-any-return] - - -def _extract_user_info(user_info: dict[str, Any]) -> dict[str, Any]: - """Extract user information from Zendesk user response.""" - extracted = {} - - if user_info.get("id"): - extracted["user_id"] = user_info["id"] - - if user_info.get("name"): - extracted["name"] = user_info["name"] - - if user_info.get("email"): - extracted["email"] = user_info["email"] - - if user_info.get("role"): - extracted["role"] = user_info["role"] - - if "active" in user_info: - extracted["active"] = user_info["active"] - - if "verified" in user_info: - extracted["verified"] = user_info["verified"] - - if user_info.get("locale"): - extracted["locale"] = user_info["locale"] - - if user_info.get("time_zone"): - extracted["time_zone"] = user_info["time_zone"] - - if user_info.get("organization_id"): - extracted["organization_id"] = user_info["organization_id"] - - return extracted - - -def _extract_organization_info(organization_info: dict[str, Any]) -> dict[str, Any]: - """Extract organization information from Zendesk organization response.""" - extracted = {} - - if organization_info.get("name"): - extracted["organization_name"] = organization_info["name"] - - if organization_info.get("domain_names"): - domains = organization_info["domain_names"] - if domains: - extracted["organization_domains"] = domains - - return extracted diff --git a/toolkits/zendesk/evals/eval_articles.py b/toolkits/zendesk/evals/eval_articles.py deleted file mode 100644 index 26c6b2dcf..000000000 --- a/toolkits/zendesk/evals/eval_articles.py +++ /dev/null @@ -1,360 +0,0 @@ -from datetime import timedelta - -from arcade_core import ToolCatalog -from arcade_evals import ( - DatetimeCritic, - EvalRubric, - EvalSuite, - ExpectedToolCall, - tool_eval, -) -from arcade_evals.critic import BinaryCritic, SimilarityCritic - -import arcade_zendesk -from arcade_zendesk.enums import ArticleSortBy, SortOrder -from arcade_zendesk.tools.search_articles import search_articles - -# Evaluation rubric -rubric = EvalRubric( - fail_threshold=0.85, - warn_threshold=0.95, -) - - -catalog = ToolCatalog() -catalog.add_module(arcade_zendesk) - - -@tool_eval() -def zendesk_search_articles_eval_suite() -> EvalSuite: - suite = EvalSuite( - name="Zendesk Search Articles Evaluation", - system_message=( - "You are an AI assistant with access to Zendesk Search Articles tool. " - "Use it to help users search for knowledge base articles and documentation." - ), - catalog=catalog, - rubric=rubric, - ) - - # Basic search scenarios - suite.add_case( - name="Basic search with query only", - user_message="Find articles about password reset", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={ - "query": "password reset", - }, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=1.0), - ], - ) - - suite.add_case( - name="Search with specific result count", - user_message="Show me 25 articles about API documentation", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={"query": "API documentation", "limit": 25}, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.7), - BinaryCritic(critic_field="limit", weight=0.3), - ], - ) - - # Date filtering scenarios - suite.add_case( - name="Search with created after date filter", - user_message="Find articles about security updates created after January 15, 2024", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={"query": "security updates", "created_after": "2024-01-15"}, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.6), - DatetimeCritic(critic_field="created_after", weight=0.4, tolerance=timedelta(days=1)), - ], - ) - - suite.add_case( - name="Search with date range filter", - user_message="Show me articles about new features created between January and June 2024", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={ - "query": "new features", - "created_after": "2024-01-01", - "created_before": "2024-06-30", - }, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.4), - DatetimeCritic(critic_field="created_after", weight=0.3, tolerance=timedelta(days=1)), - DatetimeCritic(critic_field="created_before", weight=0.3, tolerance=timedelta(days=1)), - ], - ) - - # Label filtering (Professional/Enterprise) - suite.add_case( - name="Search by labels only", - user_message="Show me articles tagged with windows and setup labels", - expected_tool_calls=[ - ExpectedToolCall(func=search_articles, args={"label_names": ["windows", "setup"]}) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="label_names", weight=1.0), - ], - ) - - suite.add_case( - name="Search with query and labels", - user_message="Find installation guides with labels: macos, quickstart", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={"query": "installation guide", "label_names": ["macos", "quickstart"]}, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.5), - SimilarityCritic(critic_field="label_names", weight=0.5), - ], - ) - - # Sorting scenarios - suite.add_case( - name="Search sorted by creation date ascending", - user_message="Find onboarding articles sorted by oldest first", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={ - "query": "onboarding", - "sort_by": ArticleSortBy.CREATED_AT, - "sort_order": SortOrder.ASC, - }, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.4), - BinaryCritic(critic_field="sort_by", weight=0.3), - BinaryCritic(critic_field="sort_order", weight=0.3), - ], - ) - - suite.add_case( - name="Search sorted by most recently created", - user_message="Show me troubleshooting guides sorted by latest creation", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={ - "query": "troubleshooting guide", - "sort_by": ArticleSortBy.CREATED_AT, - "sort_order": SortOrder.DESC, - }, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.4), - BinaryCritic(critic_field="sort_by", weight=0.3), - BinaryCritic(critic_field="sort_order", weight=0.3), - ], - ) - - # Pagination scenarios - suite.add_case( - name="Search with higher limit", - user_message="Show me 100 installation guides", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={"query": "installation guide", "limit": 100}, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.7), - BinaryCritic(critic_field="limit", weight=0.3), - ], - ) - - suite.add_case( - name="Search with offset pagination", - user_message="Find API documentation, skip the first 50 results and show me the next 50", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={"query": "API documentation", "offset": 50, "limit": 50}, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.4), - BinaryCritic(critic_field="offset", weight=0.3), - BinaryCritic(critic_field="limit", weight=0.3), - ], - ) - - # Complex search scenarios - suite.add_case( - name="Complex search with multiple filters", - user_message="Find recent troubleshooting articles about login issues " - "created after March 31, 2024, sorted by newest first", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={ - "query": "login issues troubleshooting", - "created_after": "2024-03-31", - "sort_by": ArticleSortBy.CREATED_AT, - "sort_order": SortOrder.DESC, - }, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.4), - DatetimeCritic(critic_field="created_after", weight=0.3, tolerance=timedelta(days=1)), - BinaryCritic(critic_field="sort_by", weight=0.15), - BinaryCritic(critic_field="sort_order", weight=0.15), - ], - ) - - # Content control - suite.add_case( - name="Search without article body content", - user_message="List article titles about billing without the full content", - expected_tool_calls=[ - ExpectedToolCall(func=search_articles, args={"query": "billing", "include_body": False}) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.7), - BinaryCritic(critic_field="include_body", weight=0.3), - ], - ) - - # Edge cases - suite.add_case( - name="Search with exact phrase matching", - user_message='Find articles with the exact phrase "password reset procedure"', - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={ - "query": '"password reset procedure"', - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="query", weight=1.0), - ], - ) - - return suite - - -@tool_eval() -def zendesk_search_articles_pagination_eval_suite() -> EvalSuite: - """Separate suite for pagination scenarios with context.""" - suite = EvalSuite( - name="Zendesk Pagination Evaluation", - system_message=( - "You are an AI assistant with access to Zendesk Help Center tools. " - "Use them to help users search for knowledge base articles. " - "When users ask for more results, use appropriate pagination parameters." - ), - catalog=catalog, - rubric=rubric, - ) - - # Pagination with context - suite.add_case( - name="Initial search with pagination context", - user_message="I need to find all troubleshooting articles. " - "Start by showing me the first 20.", - expected_tool_calls=[ - ExpectedToolCall(func=search_articles, args={"query": "troubleshooting", "limit": 20}) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.6), - BinaryCritic(critic_field="limit", weight=0.4), - ], - ) - - suite.add_case( - name="Request for more results after initial search", - user_message="Show me the next 20 troubleshooting articles", - expected_tool_calls=[ - ExpectedToolCall( - func=search_articles, - args={"query": "troubleshooting", "offset": 20, "limit": 20}, - ) - ], - rubric=rubric, - critics=[ - SimilarityCritic(critic_field="query", weight=0.5), - BinaryCritic(critic_field="offset", weight=0.25), - BinaryCritic(critic_field="limit", weight=0.25), - ], - additional_messages=[ - { - "role": "user", - "content": "I need to find all troubleshooting articles. " - "Start by showing me the first 20.", - }, - { - "role": "assistant", - "content": "I'll search for troubleshooting articles and " - "show you the first 20 results.", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "search_articles", - "arguments": '{"query": "troubleshooting", "limit": 20}', - }, - } - ], - }, - { - "role": "tool", - "content": '{"results": [{"content": "Troubleshooting guide 1", ' - '"metadata": {"id": 1, "title": "How to troubleshoot login issues"}}], ' - '"count": 20, "next_offset": 20}', - "tool_call_id": "call_1", - "name": "search_articles", - }, - { - "role": "assistant", - "content": "I found 20 troubleshooting articles, and there are more available. " - "The first one is 'How to troubleshoot login issues'. " - "Would you like to see more results?", - }, - ], - ) - - return suite diff --git a/toolkits/zendesk/evals/eval_tickets.py b/toolkits/zendesk/evals/eval_tickets.py deleted file mode 100644 index 628dd5038..000000000 --- a/toolkits/zendesk/evals/eval_tickets.py +++ /dev/null @@ -1,631 +0,0 @@ -from arcade_core import ToolCatalog -from arcade_evals import ( - BinaryCritic, - EvalRubric, - EvalSuite, - ExpectedToolCall, - SimilarityCritic, - tool_eval, -) - -import arcade_zendesk -from arcade_zendesk.enums import SortOrder, TicketStatus -from arcade_zendesk.tools.tickets import ( - add_ticket_comment, - get_ticket_comments, - list_tickets, - mark_ticket_solved, -) - -# Evaluation rubric -rubric = EvalRubric( - fail_threshold=0.85, - warn_threshold=0.95, -) - -catalog = ToolCatalog() -catalog.add_module(arcade_zendesk) - - -@tool_eval() -def zendesk_tickets_read_eval_suite() -> EvalSuite: - """Evaluation suite for ticket reading operations.""" - suite = EvalSuite( - name="Zendesk Tickets Read Operations", - system_message=( - "You are an AI assistant with access to Zendesk ticket tools. " - "Use them to help users view and manage support tickets." - ), - catalog=catalog, - rubric=rubric, - ) - - # Basic ticket listing - suite.add_case( - name="List all open tickets", - user_message="Show me all open tickets", - expected_tool_calls=[ - ExpectedToolCall( - func=list_tickets, - args={}, - ) - ], - rubric=rubric, - critics=[], # No args to validate - ) - - suite.add_case( - name="List tickets with explicit status request", - user_message="Can you list the open support tickets?", - expected_tool_calls=[ - ExpectedToolCall( - func=list_tickets, - args={}, - ) - ], - rubric=rubric, - critics=[], - ) - - suite.add_case( - name="Request for ticket overview", - user_message="I need to see what tickets are currently open", - expected_tool_calls=[ - ExpectedToolCall( - func=list_tickets, - args={}, - ) - ], - rubric=rubric, - critics=[], - ) - - # Test pagination - suite.add_case( - name="List tickets with limit", - user_message="Show me the first 5 open tickets", - expected_tool_calls=[ - ExpectedToolCall( - func=list_tickets, - args={"limit": 5}, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="limit", weight=1.0), - ], - ) - - # Test status filter - suite.add_case( - name="List tickets with specific status", - user_message="Show me all pending tickets", - expected_tool_calls=[ - ExpectedToolCall( - func=list_tickets, - args={"status": TicketStatus.PENDING}, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="status", weight=1.0), - ], - ) - - # Test sort order - suite.add_case( - name="List tickets oldest first", - user_message="Show me tickets sorted from oldest to newest", - expected_tool_calls=[ - ExpectedToolCall( - func=list_tickets, - args={"sort_order": SortOrder.ASC}, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="sort_order", weight=1.0), - ], - ) - - return suite - - -@tool_eval() -def zendesk_get_ticket_comments_eval_suite() -> EvalSuite: - """Evaluation suite for getting ticket comments.""" - suite = EvalSuite( - name="Zendesk Get Ticket Comments", - system_message=( - "You are an AI assistant with access to Zendesk ticket tools. " - "Use them to help users view ticket comments and conversation history." - ), - catalog=catalog, - rubric=rubric, - ) - - # Get comments for a ticket - suite.add_case( - name="Get comments for specific ticket", - user_message="Show me the comments for ticket 123", - expected_tool_calls=[ - ExpectedToolCall( - func=get_ticket_comments, - args={"ticket_id": 123}, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=1.0), - ], - ) - - suite.add_case( - name="View ticket conversation", - user_message="Can you show me the conversation history for ticket #456?", - expected_tool_calls=[ - ExpectedToolCall( - func=get_ticket_comments, - args={"ticket_id": 456}, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=1.0), - ], - ) - - suite.add_case( - name="Get ticket description", - user_message="What is the original description of ticket 789?", - expected_tool_calls=[ - ExpectedToolCall( - func=get_ticket_comments, - args={"ticket_id": 789}, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=1.0), - ], - ) - - return suite - - -@tool_eval() -def zendesk_ticket_comments_eval_suite() -> EvalSuite: - """Evaluation suite for ticket comment operations.""" - suite = EvalSuite( - name="Zendesk Ticket Comments", - system_message=( - "You are an AI assistant with access to Zendesk ticket tools. " - "Use them to help users add comments to support tickets." - ), - catalog=catalog, - rubric=rubric, - ) - - # Public comments - suite.add_case( - name="Add public comment to ticket", - user_message="Add a comment to ticket 123 saying 'We are investigating this issue'", - expected_tool_calls=[ - ExpectedToolCall( - func=add_ticket_comment, - args={ - "ticket_id": 123, - "comment_body": "We are investigating this issue", - "public": True, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="public", weight=0.2), - ], - ) - - suite.add_case( - name="Add public comment without specifying visibility", - user_message="Please comment on ticket #456: " - "The issue has been escalated to our engineering team", - expected_tool_calls=[ - ExpectedToolCall( - func=add_ticket_comment, - args={ - "ticket_id": 456, - "comment_body": "The issue has been escalated to our engineering team", - "public": True, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="public", weight=0.2), - ], - ) - - # Internal comments - suite.add_case( - name="Add internal comment to ticket", - user_message="Add an internal note to ticket 789: Customer is VIP, prioritize this issue", - expected_tool_calls=[ - ExpectedToolCall( - func=add_ticket_comment, - args={ - "ticket_id": 789, - "comment_body": "Customer is VIP, prioritize this issue", - "public": False, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="public", weight=0.2), - ], - ) - - suite.add_case( - name="Add private comment to ticket", - user_message="Add a private comment to ticket 321 for agents only: " - "Check with backend team about API limits", - expected_tool_calls=[ - ExpectedToolCall( - func=add_ticket_comment, - args={ - "ticket_id": 321, - "comment_body": "Check with backend team about API limits", - "public": False, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="public", weight=0.2), - ], - ) - - # Complex comment scenarios - suite.add_case( - name="Add detailed public update", - user_message="Update ticket 555 with: 'We've identified the root cause. " - "A fix will be deployed within 24 hours. We apologize for the inconvenience.'", - expected_tool_calls=[ - ExpectedToolCall( - func=add_ticket_comment, - args={ - "ticket_id": 555, - "comment_body": "We've identified the root cause. " - "A fix will be deployed within 24 hours. We apologize for the inconvenience.", - "public": True, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.6), - BinaryCritic(critic_field="public", weight=0.1), - ], - ) - - return suite - - -@tool_eval() -def zendesk_ticket_resolution_eval_suite() -> EvalSuite: - """Evaluation suite for ticket resolution operations.""" - suite = EvalSuite( - name="Zendesk Ticket Resolution", - system_message=( - "You are an AI assistant with access to Zendesk ticket tools. " - "Use them to help users resolve support tickets." - "Consider that closing a ticket is the same as marking it as solved." - ), - catalog=catalog, - rubric=rubric, - ) - - # Simple resolution - suite.add_case( - name="Mark ticket as solved without comment", - user_message="Mark ticket 100 as solved", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 100, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=1.0), - ], - ) - - suite.add_case( - name="Close ticket", - user_message="Please close ticket #200", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 200, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=1.0), - ], - ) - - # Resolution with public comment - suite.add_case( - name="Solve ticket with public resolution comment", - user_message="Resolve ticket 300 with comment: " - "'Issue resolved by updating your account settings'", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 300, - "comment_body": "Issue resolved by updating your account settings", - "comment_public": True, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="comment_public", weight=0.2), - ], - ) - - suite.add_case( - name="Close ticket with customer-facing message", - user_message="Close ticket 400 and tell the customer: " - "Your refund has been processed successfully", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 400, - "comment_body": "Your refund has been processed successfully", - "comment_public": True, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="comment_public", weight=0.2), - ], - ) - - # Resolution with internal comment - suite.add_case( - name="Solve ticket with internal note", - user_message="Mark ticket 500 as solved with internal note: " - "'Resolved via backend database fix'", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 500, - "comment_body": "Resolved via backend database fix", - "comment_public": False, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="comment_public", weight=0.2), - ], - ) - - # Default internal comment behavior - suite.add_case( - name="Solve ticket with comment defaults to internal", - user_message="Mark ticket 550 as solved with comment: 'Fixed by applying patch #2345'", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 550, - "comment_body": "Fixed by applying patch #2345", - # comment_public should default to False if not specified - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.4), - SimilarityCritic(critic_field="comment_body", weight=0.6), - ], - ) - - suite.add_case( - name="Close ticket with private resolution details", - user_message="Close ticket 600 with a private note for agents: " - "'Customer account had duplicate entries, merged successfully'", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 600, - "comment_body": "Customer account had duplicate entries, merged successfully", - "comment_public": False, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="comment_public", weight=0.2), - ], - ) - - return suite - - -@tool_eval() -def zendesk_ticket_workflow_eval_suite() -> EvalSuite: - """Evaluation suite for ticket workflow scenarios with context.""" - suite = EvalSuite( - name="Zendesk Ticket Workflows", - system_message=( - "You are an AI assistant with access to Zendesk ticket tools. " - "Use them to help users manage support ticket workflows." - ), - catalog=catalog, - rubric=rubric, - ) - - # Workflow: View then comment - suite.add_case( - name="Comment on specific ticket after viewing", - user_message="Add a comment to the login issue ticket saying we're working on it", - expected_tool_calls=[ - ExpectedToolCall( - func=add_ticket_comment, - args={ - "ticket_id": 1, - "comment_body": "We're currently working on resolving your login issue.", - "public": True, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="public", weight=0.2), - ], - additional_messages=[ - { - "role": "user", - "content": "Show me all open tickets", - }, - { - "role": "assistant", - "content": "I'll list all open tickets for you.", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "list_tickets", - "arguments": "{}", - }, - } - ], - }, - { - "role": "tool", - "content": '{"tickets": [{"id": 1, "subject": "Login issue", "status": "open", ' - '"html_url": "https://example.zendesk.com/agent/tickets/1"}, ' - '{"id": 2, "subject": "Password reset request", "status": "open", ' - '"html_url": "https://example.zendesk.com/agent/tickets/2"}], "count": 2}', - "tool_call_id": "call_1", - "name": "list_tickets", - }, - { - "role": "assistant", - "content": "I found 2 open tickets:\n" - "1. Ticket #1: Login issue\n2. Ticket #2: Password reset request", - }, - ], - ) - - # Workflow: Comment then resolve - suite.add_case( - name="Resolve ticket after adding solution", - user_message="Now mark that ticket as solved", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 789, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=1.0), - ], - additional_messages=[ - { - "role": "user", - "content": "Add a comment to ticket 789: " - "'Reset your password using the forgot password link on the login page'", - }, - { - "role": "assistant", - "content": "I'll add that comment to ticket 789.", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "add_ticket_comment", - "arguments": '{"ticket_id": 789, "comment_body": ' - '"Reset your password using the forgot password link on the login ' - 'page", "public": true}', - }, - } - ], - }, - { - "role": "tool", - "content": '{"success": true, "ticket_id": 789, "comment_type": "public", ' - '"ticket": {"id": 789, "html_url": "https://example.zendesk.com/agent/tickets/789"}}', - "tool_call_id": "call_1", - "name": "add_ticket_comment", - }, - { - "role": "assistant", - "content": "I've added the comment with password reset instructions " - "to ticket #789.", - }, - ], - ) - - # Workflow: Multiple updates - suite.add_case( - name="Add final comment and close ticket", - user_message="Add 'This issue has been fully resolved' and close ticket 999", - expected_tool_calls=[ - ExpectedToolCall( - func=mark_ticket_solved, - args={ - "ticket_id": 999, - "comment_body": "This issue has been fully resolved", - "comment_public": True, - }, - ) - ], - rubric=rubric, - critics=[ - BinaryCritic(critic_field="ticket_id", weight=0.3), - SimilarityCritic(critic_field="comment_body", weight=0.5), - BinaryCritic(critic_field="comment_public", weight=0.2), - ], - ) - - return suite diff --git a/toolkits/zendesk/pyproject.toml b/toolkits/zendesk/pyproject.toml deleted file mode 100644 index 52cbc0712..000000000 --- a/toolkits/zendesk/pyproject.toml +++ /dev/null @@ -1,60 +0,0 @@ -[build-system] -requires = [ "hatchling",] -build-backend = "hatchling.build" - -[project] -name = "arcade_zendesk" -version = "0.5.0" -requires-python = ">=3.10" -dependencies = [ - "arcade-mcp-server>=1.17.0,<2.0.0", - "httpx>=0.25.0,<1.0.0", - "beautifulsoup4>=4.0.0,<5" -] - -[project.scripts] -arcade-zendesk = "arcade_zendesk.__main__:main" -arcade_zendesk = "arcade_zendesk.__main__:main" - -[project.optional-dependencies] -dev = [ - "arcade-mcp[all]>=1.2.0,<2.0.0", - "pytest>=8.3.0,<8.4.0", - "pytest-cov>=4.0.0,<4.1.0", - "pytest-mock>=3.11.1,<3.12.0", - "pytest-asyncio>=0.24.0,<0.25.0", - "mypy>=1.5.1,<1.6.0", - "pre-commit>=3.4.0,<3.5.0", - "tox>=4.11.1,<4.12.0", - "ruff>=0.7.4,<0.8.0", -] - -# Use local path sources for arcade libs when working locally -[tool.uv.sources] -arcade-mcp = { path = "../../", editable = true } -arcade-mcp-server = { path = "../../libs/arcade-mcp-server/", editable = true } - - -[tool.mypy] -files = [ "arcade_zendesk/**/*.py",] -python_version = "3.10" -disallow_untyped_defs = "True" -disallow_any_unimported = "True" -no_implicit_optional = "True" -check_untyped_defs = "True" -warn_return_any = "True" -warn_unused_ignores = "True" -show_error_codes = "True" -ignore_missing_imports = "True" - -[tool.pytest.ini_options] -testpaths = [ "tests",] - -[tool.coverage.report] -skip_empty = true - -[tool.ruff.lint] -ignore = ["C901"] - -[tool.hatch.build.targets.wheel] -packages = [ "arcade_zendesk",] diff --git a/toolkits/zendesk/tests/__init__.py b/toolkits/zendesk/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/toolkits/zendesk/tests/conftest.py b/toolkits/zendesk/tests/conftest.py deleted file mode 100644 index 60d17f9ae..000000000 --- a/toolkits/zendesk/tests/conftest.py +++ /dev/null @@ -1,84 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock - -import pytest -from arcade_mcp_server import Context - - -@pytest.fixture -def mock_context(): - """Standard mock context fixture used across all arcade toolkits.""" - context = MagicMock(spec=Context) - - context.get_auth_token_or_empty = MagicMock(return_value="fake-token") - context.get_secret = MagicMock() - - return context - - -@pytest.fixture -def mock_httpx_client(mocker): - """Mock httpx.AsyncClient for API calls.""" - mock_client_class = mocker.patch("httpx.AsyncClient", autospec=True) - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - return mock_client - - -@pytest.fixture -def sample_article_response(): - """Sample article data for testing.""" - return { - "id": 123456, - "title": "How to reset your password", - "body": "

To reset your password, follow these steps:

" - "
  1. Click forgot password
  2. Enter your email
", - "url": "https://support.example.com/hc/en-us/articles/123456", - "created_at": "2024-01-15T10:00:00Z", - "updated_at": "2024-06-01T15:30:00Z", - "section_id": 789, - "category_id": 456, - "label_names": ["password", "security", "account"], - } - - -@pytest.fixture -def build_search_response(sample_article_response): - """Builder for search API responses.""" - - def builder(articles=None, next_page=None, count=None): - if articles is None: - articles = [sample_article_response] - - response = { - "results": articles, - "next_page": next_page, - "page": 1, - "per_page": len(articles), - "page_count": 1, - } - - if count is not None: - response["count"] = count - - return response - - return builder - - -@pytest.fixture -def mock_http_response(): - """Factory for creating mock HTTP responses.""" - - def create_response(json_data=None, status_code=200, raise_for_status=True): - response = MagicMock() - response.json.return_value = json_data - response.status_code = status_code - - if raise_for_status and status_code >= 400: - response.raise_for_status.side_effect = Exception(f"HTTP {status_code}") - else: - response.raise_for_status.return_value = None - - return response - - return create_response diff --git a/toolkits/zendesk/tests/test_search_articles.py b/toolkits/zendesk/tests/test_search_articles.py deleted file mode 100644 index 738831a00..000000000 --- a/toolkits/zendesk/tests/test_search_articles.py +++ /dev/null @@ -1,360 +0,0 @@ -import pytest -from arcade_mcp_server.exceptions import RetryableToolError, ToolExecutionError - -from arcade_zendesk.enums import ArticleSortBy, SortOrder -from arcade_zendesk.tools.search_articles import search_articles - - -class TestSearchArticlesValidation: - """Test input validation for search_articles.""" - - @pytest.mark.asyncio - async def test_missing_subdomain(self, mock_context): - """Test error when subdomain is not configured.""" - mock_context.get_secret.side_effect = ValueError("Secret not found") - - with pytest.raises(ToolExecutionError) as exc_info: - await search_articles(context=mock_context, query="test") - - assert "subdomain is not set" in str(exc_info.value.message) - - @pytest.mark.asyncio - async def test_missing_search_params(self, mock_context): - """Test error when no search parameters provided.""" - mock_context.get_secret.return_value = "test-subdomain" - - with pytest.raises(RetryableToolError) as exc_info: - await search_articles(context=mock_context) - - assert "At least one search parameter" in str(exc_info.value.message) - - @pytest.mark.parametrize( - "date_param,date_value", - [ - ("created_after", "2024/01/01"), - ("created_before", "01-15-2024"), - ("created_at", "2024-1-15"), - ("created_after", "2024-01-1"), - ("created_before", "20240115"), - ("created_at", "not-a-date"), - ], - ) - @pytest.mark.asyncio - async def test_invalid_date_format(self, mock_context, date_param, date_value): - """Test validation of date format parameters.""" - mock_context.get_secret.return_value = "test-subdomain" - - with pytest.raises(RetryableToolError) as exc_info: - await search_articles(context=mock_context, query="test", **{date_param: date_value}) - - assert "Invalid date format" in str(exc_info.value.message) - assert "YYYY-MM-DD" in str(exc_info.value.message) - assert date_param in str(exc_info.value.message) - - @pytest.mark.parametrize("limit", [0, -1, -10]) - @pytest.mark.asyncio - async def test_invalid_limit(self, mock_context, limit): - """Test validation of limit parameter.""" - mock_context.get_secret.return_value = "test-subdomain" - - with pytest.raises(RetryableToolError) as exc_info: - await search_articles(context=mock_context, query="test", limit=limit) - - assert "at least 1" in str(exc_info.value.message) - - @pytest.mark.parametrize("offset", [-1, -10]) - @pytest.mark.asyncio - async def test_invalid_offset(self, mock_context, offset): - """Test validation of offset parameter.""" - mock_context.get_secret.return_value = "test-subdomain" - - with pytest.raises(RetryableToolError) as exc_info: - await search_articles(context=mock_context, query="test", offset=offset) - - assert "cannot be negative" in str(exc_info.value.message) - - -class TestSearchArticlesSuccess: - """Test successful search scenarios.""" - - @pytest.mark.asyncio - async def test_basic_search( - self, mock_context, mock_httpx_client, build_search_response, mock_http_response - ): - """Test basic search with query parameter.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Setup mock response - search_response = build_search_response() - mock_httpx_client.get.return_value = mock_http_response(search_response) - - result = await search_articles(context=mock_context, query="password reset") - - assert "results" in result - assert len(result["results"]) == 1 - assert result["results"][0]["metadata"]["title"] == "How to reset your password" - - mock_httpx_client.get.assert_called_once() - call_args = mock_httpx_client.get.call_args - assert ( - call_args[0][0] - == "https://test-subdomain.zendesk.com/api/v2/help_center/articles/search" - ) - assert call_args[1]["params"]["query"] == "password reset" - # Check that pagination params were set correctly - assert call_args[1]["params"]["page"] == 1 - assert call_args[1]["params"]["per_page"] == 100 - - @pytest.mark.asyncio - async def test_search_with_filters( - self, mock_context, mock_httpx_client, build_search_response, mock_http_response - ): - """Test search with multiple filter parameters.""" - mock_context.get_secret.return_value = "test-subdomain" - - search_response = build_search_response() - mock_httpx_client.get.return_value = mock_http_response(search_response) - - result = await search_articles( - context=mock_context, - query="API", - created_after="2024-01-01", - sort_by=ArticleSortBy.CREATED_AT, - sort_order=SortOrder.DESC, - limit=25, - ) - - assert "results" in result - - # Verify all parameters were passed - call_params = mock_httpx_client.get.call_args[1]["params"] - assert call_params["query"] == "API" - assert call_params["created_after"] == "2024-01-01" - assert call_params["sort_by"] == "created_at" - assert call_params["sort_order"] == "desc" - # Should fetch first page with 100 items per page - assert call_params["page"] == 1 - assert call_params["per_page"] == 100 - - @pytest.mark.asyncio - async def test_search_without_body( - self, - mock_context, - mock_httpx_client, - sample_article_response, - mock_http_response, - ): - """Test search with include_body=False.""" - mock_context.get_secret.return_value = "test-subdomain" - - search_response = {"results": [sample_article_response], "next_page": None} - mock_httpx_client.get.return_value = mock_http_response(search_response) - - result = await search_articles(context=mock_context, query="test", include_body=False) - - assert result["results"][0]["content"] is None - assert result["results"][0]["metadata"]["title"] == sample_article_response["title"] - - @pytest.mark.asyncio - async def test_search_by_labels( - self, mock_context, mock_httpx_client, build_search_response, mock_http_response - ): - """Test search by label names.""" - mock_context.get_secret.return_value = "test-subdomain" - - search_response = build_search_response() - mock_httpx_client.get.return_value = mock_http_response(search_response) - - result = await search_articles(context=mock_context, label_names=["password", "security"]) - - assert "results" in result - assert mock_httpx_client.get.call_args[1]["params"]["label_names"] == "password,security" - - -class TestSearchArticlesPagination: - """Test pagination scenarios.""" - - @pytest.mark.asyncio - async def test_single_page_default( - self, mock_context, mock_httpx_client, build_search_response, mock_http_response - ): - """Test default behavior returns single page.""" - mock_context.get_secret.return_value = "test-subdomain" - - search_response = build_search_response(count=100) - mock_httpx_client.get.return_value = mock_http_response(search_response) - - result = await search_articles(context=mock_context, query="test") - - assert len(result["results"]) == 1 - assert mock_httpx_client.get.call_count == 1 - - @pytest.mark.asyncio - async def test_fetch_with_limit_across_pages( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test fetching results across multiple pages with limit.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Setup pagination responses - 100 items per page - articles_page1 = [ - {"id": i, "title": f"Article {i}", "body": f"Content {i}"} for i in range(1, 101) - ] - articles_page2 = [ - {"id": i, "title": f"Article {i}", "body": f"Content {i}"} for i in range(101, 201) - ] - - page1 = {"results": articles_page1, "next_page": "page2"} - page2 = {"results": articles_page2, "next_page": "page3"} - - mock_httpx_client.get.side_effect = [ - mock_http_response(page1), - mock_http_response(page2), - ] - - # Request 150 items starting from offset 0 - result = await search_articles(context=mock_context, query="test", limit=150) - - assert result["count"] == 150 - assert "next_offset" in result # More results available - assert result["next_offset"] == 150 - assert mock_httpx_client.get.call_count == 2 # Fetched 2 pages - - @pytest.mark.asyncio - async def test_fetch_with_offset(self, mock_context, mock_httpx_client, mock_http_response): - """Test fetching with offset parameter.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Setup response - page 2 would have items 101-200 - # We want items starting from offset 150 (which is item 151, at index 50 on page 2) - articles_page2 = [ - {"id": i, "title": f"Article {i}", "body": f"Content {i}"} for i in range(101, 201) - ] - response = {"results": articles_page2, "next_page": "page3"} - - mock_httpx_client.get.return_value = mock_http_response(response) - - # Request 30 items starting from offset 150 - result = await search_articles(context=mock_context, query="test", offset=150, limit=30) - - assert result["count"] == 30 - assert "next_offset" in result - assert result["next_offset"] == 180 - - # Should request page 2 (offset 150 = page 2, starting at index 50) - call_params = mock_httpx_client.get.call_args[1]["params"] - assert call_params["page"] == 2 - - @pytest.mark.asyncio - async def test_no_next_offset_when_no_more_results( - self, mock_context, mock_httpx_client, build_search_response, mock_http_response - ): - """Test that next_offset is not included when no more results.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Setup response with no next page - articles = [ - {"id": i, "title": f"Article {i}", "body": f"Content {i}"} for i in range(1, 21) - ] - response = {"results": articles, "next_page": None} - - mock_httpx_client.get.return_value = mock_http_response(response) - - result = await search_articles(context=mock_context, query="test", limit=20) - - assert result["count"] == 20 - assert "next_offset" not in result # No more results - - @pytest.mark.asyncio - async def test_partial_page_with_more_items( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test that next_offset is included when there are more items on the current page.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Setup response with 50 items on a page, but we only request 30 - articles = [ - {"id": i, "title": f"Article {i}", "body": f"Content {i}"} for i in range(1, 51) - ] - response = {"results": articles, "next_page": None} - - mock_httpx_client.get.return_value = mock_http_response(response) - - # Request only 30 items when page has 50 - result = await search_articles(context=mock_context, query="test", limit=30) - - assert result["count"] == 30 - assert "next_offset" in result # More items available on current page - assert result["next_offset"] == 30 - - @pytest.mark.asyncio - async def test_request_more_than_available( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test when requesting more items than are available returns only what's available.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Setup response with only 15 items total - articles = [ - {"id": i, "title": f"Article {i}", "body": f"Content {i}"} for i in range(1, 16) - ] - response = {"results": articles, "next_page": None} - - mock_httpx_client.get.return_value = mock_http_response(response) - - # Request 30 items when only 15 are available - result = await search_articles(context=mock_context, query="test", limit=30) - - assert result["count"] == 15 # Only returns what's available - assert "next_offset" not in result # No more results - - -class TestSearchArticlesContentProcessing: - """Test content processing and formatting.""" - - @pytest.mark.asyncio - async def test_html_cleaning(self, mock_context, mock_httpx_client, mock_http_response): - """Test HTML content is properly cleaned.""" - mock_context.get_secret.return_value = "test-subdomain" - - article_with_html = { - "id": 1, - "title": "Test Article", - "body": "

Header

Paragraph with bold and " - "italic.


Div content
", - "url": "https://example.com/article/1", - } - - search_response = {"results": [article_with_html], "next_page": None} - mock_httpx_client.get.return_value = mock_http_response(search_response) - - result = await search_articles(context=mock_context, query="test", include_body=True) - - content = result["results"][0]["content"] - assert content == "Header Paragraph with bold and italic . Div content" - - @pytest.mark.asyncio - async def test_max_article_length(self, mock_context, mock_httpx_client, mock_http_response): - """Test article length limiting.""" - mock_context.get_secret.return_value = "test-subdomain" - - long_article = { - "id": 1, - "title": "Long Article", - "body": "A" * 1000, # 1000 character body - } - - search_response = {"results": [long_article], "next_page": None} - mock_httpx_client.get.return_value = mock_http_response(search_response) - - # Test with default 500 char limit - result = await search_articles(context=mock_context, query="test") - assert len(result["results"][0]["content"]) < 520 # 500 + truncation suffix - - # Test with custom limit - result = await search_articles(context=mock_context, query="test", max_article_length=100) - assert len(result["results"][0]["content"]) < 120 # 100 + truncation suffix - - # Test with no limit - result = await search_articles(context=mock_context, query="test", max_article_length=None) - assert len(result["results"][0]["content"]) == 1000 diff --git a/toolkits/zendesk/tests/test_tickets.py b/toolkits/zendesk/tests/test_tickets.py deleted file mode 100644 index 1f10174df..000000000 --- a/toolkits/zendesk/tests/test_tickets.py +++ /dev/null @@ -1,526 +0,0 @@ -from unittest.mock import MagicMock - -import httpx -import pytest -from arcade_core.errors import ToolExecutionError - -from arcade_zendesk.enums import SortOrder, TicketStatus -from arcade_zendesk.tools.tickets import ( - add_ticket_comment, - get_ticket_comments, - list_tickets, - mark_ticket_solved, -) - - -class TestListTickets: - """Test list_tickets functionality.""" - - @pytest.mark.asyncio - async def test_list_tickets_success(self, mock_context, mock_httpx_client, mock_http_response): - """Test successful listing of open tickets.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Mock response data - includes url field that should be removed - tickets_response = { - "tickets": [ - { - "id": 1, - "subject": "Login issue", - "status": "open", - "url": "https://test-subdomain.zendesk.com/api/v2/tickets/1.json", - }, - { - "id": 2, - "subject": "Password reset request", - "status": "open", - "url": "https://test-subdomain.zendesk.com/api/v2/tickets/2.json", - }, - ] - } - - mock_httpx_client.get.return_value = mock_http_response(tickets_response) - - result = await list_tickets(mock_context, status=TicketStatus.OPEN) - - # Verify the result is structured data - assert isinstance(result, dict) - assert "tickets" in result - assert "count" in result - assert result["count"] == 2 - - # Verify tickets have html_url but not url - for ticket in result["tickets"]: - assert "url" not in ticket - assert "html_url" in ticket - assert ticket["html_url"].startswith( - "https://test-subdomain.zendesk.com/agent/tickets/" - ) - - # Verify the API call with default parameters - mock_httpx_client.get.assert_called() - # The fetch_paginated_results makes the actual call - call_args = mock_httpx_client.get.call_args - assert "https://test-subdomain.zendesk.com/api/v2/tickets.json" in call_args[0][0] - assert call_args[1]["params"]["status"] == "open" - assert call_args[1]["params"]["per_page"] == 100 - assert call_args[1]["params"]["sort_order"] == "desc" - assert call_args[1]["params"]["page"] == 1 # First page - - @pytest.mark.asyncio - async def test_list_tickets_with_offset_limit( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test listing tickets with offset and limit.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Mock response for page 2 (offset 10, limit 5) - tickets_response = { - "tickets": [ - {"id": 11, "subject": "Test 11", "status": "open"}, - {"id": 12, "subject": "Test 12", "status": "open"}, - {"id": 13, "subject": "Test 13", "status": "open"}, - {"id": 14, "subject": "Test 14", "status": "open"}, - {"id": 15, "subject": "Test 15", "status": "open"}, - ], - "next_page": "https://test.zendesk.com/api/v2/tickets.json?page=3", - } - - mock_httpx_client.get.return_value = mock_http_response(tickets_response) - - result = await list_tickets( - mock_context, status=TicketStatus.OPEN, limit=5, offset=10, sort_order=SortOrder.ASC - ) - - # Verify response structure - assert result["count"] == 5 - assert len(result["tickets"]) == 5 - assert "next_offset" in result - assert result["next_offset"] == 15 # offset + limit - - # Verify API call parameters - call_args = mock_httpx_client.get.call_args - assert ( - call_args[1]["params"]["page"] == 2 - ) # offset 10 / per_page 100 = page 2 (but adjusted for limit) - assert call_args[1]["params"]["per_page"] == 100 - assert call_args[1]["params"]["sort_order"] == "asc" - - @pytest.mark.asyncio - async def test_list_tickets_no_more_results( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test listing tickets when no more results are available.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Mock response with no next_page - simulating the last page - tickets_response = { - "tickets": [{"id": 21, "subject": "Test", "status": "pending"}], - # No next_page means no more results - } - - mock_httpx_client.get.return_value = mock_http_response(tickets_response) - - result = await list_tickets( - mock_context, - status=TicketStatus.PENDING, - limit=10, - offset=0, # Start from beginning - ) - - # Verify no next_offset when no more results - assert "next_offset" not in result - assert result["count"] == 1 - - # Verify API call parameters - call_args = mock_httpx_client.get.call_args - assert call_args[1]["params"]["status"] == "pending" - assert call_args[1]["params"]["page"] == 1 - - @pytest.mark.asyncio - async def test_list_tickets_no_tickets( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test when no tickets are found.""" - mock_context.get_secret.return_value = "test-subdomain" - - mock_httpx_client.get.return_value = mock_http_response({"tickets": []}) - - result = await list_tickets(mock_context, status=TicketStatus.OPEN) - - assert result["tickets"] == [] - assert result["count"] == 0 - - @pytest.mark.asyncio - async def test_list_tickets_error(self, mock_context, mock_httpx_client): - """Test error handling for failed API call.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Mock error response that raise_for_status will catch - error_response = MagicMock() - error_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Unauthorized", request=MagicMock(), response=MagicMock(status_code=401) - ) - mock_httpx_client.get.return_value = error_response - - with pytest.raises(ToolExecutionError): - await list_tickets(mock_context) - - @pytest.mark.asyncio - async def test_list_tickets_no_subdomain(self, mock_context): - """Test when subdomain is not configured.""" - mock_context.get_secret.return_value = None - - with pytest.raises(ToolExecutionError): - await list_tickets(mock_context) - - -class TestGetTicketComments: - """Test get_ticket_comments functionality.""" - - @pytest.mark.asyncio - async def test_get_ticket_comments_success( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test successfully getting ticket comments.""" - mock_context.get_secret.return_value = "test-subdomain" - - # Mock response data - comments_response = { - "comments": [ - { - "id": 1, - "body": "I cannot access my account. Please help!", - "author_id": 12345, - "created_at": "2024-01-15T10:00:00Z", - "public": True, - "attachments": [], - }, - { - "id": 2, - "body": "I'll help you reset your password.", - "author_id": 67890, - "created_at": "2024-01-15T10:30:00Z", - "public": True, - "attachments": [ - { - "file_name": "screenshot.png", - "content_url": "https://example.com/screenshot.png", - "size": 12345, - } - ], - }, - ] - } - - mock_httpx_client.get.return_value = mock_http_response(comments_response) - - result = await get_ticket_comments(mock_context, ticket_id=123) - - # Verify the result is structured data - assert isinstance(result, dict) - assert result["ticket_id"] == 123 - assert result["count"] == 2 - assert len(result["comments"]) == 2 - - # Verify attachments are included - assert result["comments"][1]["attachments"][0]["file_name"] == "screenshot.png" - - # Verify the API call - mock_httpx_client.get.assert_called_once_with( - "https://test-subdomain.zendesk.com/api/v2/tickets/123/comments.json", - headers={ - "Authorization": "Bearer fake-token", - "Content-Type": "application/json", - }, - ) - - @pytest.mark.asyncio - async def test_get_ticket_comments_no_comments( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test when no comments are found.""" - mock_context.get_secret.return_value = "test-subdomain" - - mock_httpx_client.get.return_value = mock_http_response({"comments": []}) - - result = await get_ticket_comments(mock_context, ticket_id=123) - - assert result["comments"] == [] - assert result["count"] == 0 - - @pytest.mark.asyncio - async def test_get_ticket_comments_not_found( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test when ticket is not found.""" - mock_context.get_secret.return_value = "test-subdomain" - - mock_httpx_client.get.return_value = mock_http_response({}, status_code=404) - - with pytest.raises(ToolExecutionError): - await get_ticket_comments(mock_context, ticket_id=999) - - @pytest.mark.asyncio - async def test_get_ticket_comments_error(self, mock_context, mock_httpx_client): - """Test error handling when API fails.""" - mock_context.get_secret.return_value = "test-subdomain" - - error_response = MagicMock() - error_response.status_code = 500 - error_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Server Error", request=MagicMock(), response=MagicMock(status_code=500) - ) - - mock_httpx_client.get.return_value = error_response - - with pytest.raises(ToolExecutionError): - await get_ticket_comments(mock_context, ticket_id=123) - - -class TestAddTicketComment: - """Test add_ticket_comment functionality.""" - - @pytest.mark.asyncio - async def test_add_public_comment_success( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test successfully adding a public comment.""" - mock_context.get_secret.return_value = "test-subdomain" - - ticket_response = { - "ticket": { - "id": 123, - "subject": "Test ticket", - "url": "https://test-subdomain.zendesk.com/api/v2/tickets/123.json", - } - } - - mock_httpx_client.put.return_value = mock_http_response(ticket_response) - - result = await add_ticket_comment( - mock_context, - ticket_id=123, - comment_body="This is a test comment", - public=True, - ) - - # Verify structured response - assert isinstance(result, dict) - assert result["success"] is True - assert result["ticket_id"] == 123 - assert result["comment_type"] == "public" - assert "ticket" in result - - # Verify ticket has html_url but not url - assert "url" not in result["ticket"] - assert "html_url" in result["ticket"] - - # Verify the API call - mock_httpx_client.put.assert_called_once_with( - "https://test-subdomain.zendesk.com/api/v2/tickets/123.json", - headers={ - "Authorization": "Bearer fake-token", - "Content-Type": "application/json", - }, - json={"ticket": {"comment": {"body": "This is a test comment", "public": True}}}, - ) - - @pytest.mark.asyncio - async def test_add_comment_default_public( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test that comment defaults to public when not specified.""" - mock_context.get_secret.return_value = "test-subdomain" - - mock_httpx_client.put.return_value = mock_http_response({"ticket": {"id": 123}}) - - result = await add_ticket_comment( - mock_context, - ticket_id=123, - comment_body="Test comment", - # Not specifying public parameter - ) - - assert result["comment_type"] == "public" - - # Verify the API call has public=True - call_args = mock_httpx_client.put.call_args - assert call_args[1]["json"]["ticket"]["comment"]["public"] is True - - @pytest.mark.asyncio - async def test_add_internal_comment_success( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test successfully adding an internal comment.""" - mock_context.get_secret.return_value = "test-subdomain" - - mock_httpx_client.put.return_value = mock_http_response({"ticket": {"id": 456}}) - - result = await add_ticket_comment( - mock_context, - ticket_id=456, - comment_body="Internal note for agents", - public=False, - ) - - assert result["comment_type"] == "internal" - - @pytest.mark.asyncio - async def test_add_comment_error(self, mock_context, mock_httpx_client): - """Test error handling when adding comment fails.""" - mock_context.get_secret.return_value = "test-subdomain" - - error_response = MagicMock() - error_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Not Found", request=MagicMock(), response=MagicMock(status_code=404) - ) - mock_httpx_client.put.return_value = error_response - - with pytest.raises(ToolExecutionError): - await add_ticket_comment(mock_context, ticket_id=999, comment_body="Test comment") - - -class TestMarkTicketSolved: - """Test mark_ticket_solved functionality.""" - - @pytest.mark.asyncio - async def test_mark_solved_without_comment( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test marking ticket as solved without a comment.""" - mock_context.get_secret.return_value = "test-subdomain" - - ticket_response = { - "ticket": { - "id": 789, - "status": "solved", - "url": "https://test-subdomain.zendesk.com/api/v2/tickets/789.json", - } - } - - mock_httpx_client.put.return_value = mock_http_response(ticket_response) - - result = await mark_ticket_solved(mock_context, ticket_id=789) - - # Verify structured response - assert isinstance(result, dict) - assert result["success"] is True - assert result["ticket_id"] == 789 - assert result["status"] == "solved" - assert "comment_added" not in result - - # Verify ticket has html_url - assert "html_url" in result["ticket"] - assert "url" not in result["ticket"] - - # Verify the API call - mock_httpx_client.put.assert_called_once_with( - "https://test-subdomain.zendesk.com/api/v2/tickets/789.json", - headers={ - "Authorization": "Bearer fake-token", - "Content-Type": "application/json", - }, - json={"ticket": {"status": "solved"}}, - ) - - @pytest.mark.asyncio - async def test_mark_solved_with_public_comment( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test marking ticket as solved with a public comment.""" - mock_context.get_secret.return_value = "test-subdomain" - - mock_httpx_client.put.return_value = mock_http_response({"ticket": {"id": 123}}) - - result = await mark_ticket_solved( - mock_context, - ticket_id=123, - comment_body="Issue resolved by resetting password", - comment_public=True, - ) - - assert result["comment_added"] is True - assert result["comment_type"] == "public" - - # Verify the request body includes the comment - call_args = mock_httpx_client.put.call_args - request_body = call_args[1]["json"] - assert request_body["ticket"]["status"] == "solved" - assert request_body["ticket"]["comment"]["body"] == "Issue resolved by resetting password" - assert request_body["ticket"]["comment"]["public"] is True - - @pytest.mark.asyncio - async def test_mark_solved_with_comment_default_internal( - self, mock_context, mock_httpx_client, mock_http_response - ): - """Test marking ticket as solved with comment defaults to internal.""" - mock_context.get_secret.return_value = "test-subdomain" - - mock_httpx_client.put.return_value = mock_http_response({"ticket": {"id": 555}}) - - result = await mark_ticket_solved( - mock_context, - ticket_id=555, - comment_body="Internal resolution note", - # Not specifying comment_public, should default to False - ) - - assert result["comment_added"] is True - assert result["comment_type"] == "internal" - - # Verify the comment is internal by default - call_args = mock_httpx_client.put.call_args - request_body = call_args[1]["json"] - assert request_body["ticket"]["comment"]["public"] is False - - @pytest.mark.asyncio - async def test_mark_solved_error(self, mock_context, mock_httpx_client): - """Test error handling when marking ticket as solved fails.""" - mock_context.get_secret.return_value = "test-subdomain" - - error_response = MagicMock() - error_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Forbidden", request=MagicMock(), response=MagicMock(status_code=403) - ) - mock_httpx_client.put.return_value = error_response - - with pytest.raises(ToolExecutionError): - await mark_ticket_solved(mock_context, ticket_id=999) - - -class TestAuthenticationAndSecrets: - """Test authentication and secrets handling.""" - - @pytest.mark.asyncio - async def test_no_auth_token(self, mock_context, mock_httpx_client): - """Test behavior when auth token is empty.""" - mock_context.get_auth_token_or_empty.return_value = "" - mock_context.get_secret.return_value = "test-subdomain" - - # The tools should still attempt the API call with empty token - error_response = MagicMock() - error_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Unauthorized", request=MagicMock(), response=MagicMock(status_code=401) - ) - mock_httpx_client.get.return_value = error_response - - with pytest.raises(ToolExecutionError): - await list_tickets(mock_context) - - # Should be called with empty Bearer token - call_args = mock_httpx_client.get.call_args - assert call_args[1]["headers"]["Authorization"] == "Bearer " - - @pytest.mark.asyncio - async def test_subdomain_from_secret(self, mock_context, mock_httpx_client, mock_http_response): - """Test that subdomain is correctly retrieved from secrets.""" - mock_context.get_secret.return_value = "my-company" - - mock_httpx_client.get.return_value = mock_http_response({"tickets": []}) - - await list_tickets(mock_context, status=TicketStatus.OPEN) - - # Verify the correct subdomain was used - call_args = mock_httpx_client.get.call_args - assert "https://my-company.zendesk.com" in call_args[0][0] diff --git a/toolkits/zendesk/tests/test_utils.py b/toolkits/zendesk/tests/test_utils.py deleted file mode 100644 index 995fa52c2..000000000 --- a/toolkits/zendesk/tests/test_utils.py +++ /dev/null @@ -1,291 +0,0 @@ -import pytest - -from arcade_zendesk.utils import ( - clean_html_text, - process_article_body, - process_search_results, - truncate_text, - validate_date_format, -) - - -class TestCleanHtmlText: - """Test HTML cleaning functionality.""" - - def test_clean_simple_html(self): - """Test cleaning basic HTML tags.""" - html = "

Hello World

" - assert clean_html_text(html) == "Hello World" - - def test_clean_complex_html(self): - """Test cleaning complex HTML with multiple tags.""" - html = """ -

Title

-

Paragraph with emphasis and bold.

-
    -
  • Item 1
  • -
  • Item 2
  • -
- - """ - cleaned = clean_html_text(html) - assert "Title" in cleaned - assert "Paragraph with emphasis and bold" in cleaned - assert "Item 1" in cleaned - assert "Item 2" in cleaned - assert "Footer content" in cleaned - assert "

" not in cleaned - assert "
  • " not in cleaned - - def test_clean_html_with_special_chars(self): - """Test cleaning HTML with special characters.""" - html = "

    Price: £100 & €120

    " - cleaned = clean_html_text(html) - assert "ยฃ100" in cleaned or "100" in cleaned # Depends on BeautifulSoup version - assert "&" in cleaned - assert "โ‚ฌ120" in cleaned or "120" in cleaned - - @pytest.mark.parametrize( - "input_value,expected", - [ - (None, ""), - ("", ""), - (" ", ""), - ("

    ", ""), - ("

    ", ""), - ], - ) - def test_clean_html_edge_cases(self, input_value, expected): - """Test edge cases for HTML cleaning.""" - assert clean_html_text(input_value) == expected - - def test_clean_html_preserves_line_breaks(self): - """Test that meaningful line breaks are preserved.""" - html = "

    Line 1

    Line 2

    " - cleaned = clean_html_text(html) - # Should have text from both lines - assert "Line 1" in cleaned - assert "Line 2" in cleaned - - -class TestTruncateText: - """Test text truncation functionality.""" - - def test_truncate_long_text(self): - """Test truncating text longer than max length.""" - text = "This is a very long text that needs to be truncated" - result = truncate_text(text, 20) - assert result == "This ... [truncated]" - assert len(result) == 20 - - def test_no_truncation_needed(self): - """Test text shorter than max length is not truncated.""" - text = "Short text" - assert truncate_text(text, 20) == text - - def test_truncate_with_custom_suffix(self): - """Test truncation with custom suffix.""" - text = "This is a long text for testing" - result = truncate_text(text, 15, "...") - assert result == "This is a lo..." - assert len(result) == 15 - - @pytest.mark.parametrize( - "text,max_length,expected", - [ - (None, 10, None), - ("", 10, ""), - ("Hello", 5, "Hello"), - ("Hello World", 5, " ... [truncated]"), # Suffix is longer than allowed - ], - ) - def test_truncate_edge_cases(self, text, max_length, expected): - """Test edge cases for truncation.""" - result = truncate_text(text, max_length) - if expected == " ... [truncated]": - # When suffix is longer than max_length, only suffix is returned - assert result == expected - else: - assert result == expected - - def test_truncate_at_word_boundary(self): - """Test that truncation happens cleanly.""" - text = "The quick brown fox jumps over the lazy dog" - result = truncate_text(text, 25) - assert result == "The quick ... [truncated]" - assert len(result) == 25 - - -class TestProcessArticleBody: - """Test article body processing.""" - - def test_process_html_body(self): - """Test processing HTML body content.""" - body = "

    Article Title

    Article content with formatting.

    " - result = process_article_body(body) - assert "Article Title" in result - assert "Article content with formatting" in result - assert "

    " not in result - assert "" not in result - - def test_process_body_with_truncation(self): - """Test processing body with max length.""" - body = "

    " + "Long content " * 50 + "

    " - result = process_article_body(body, max_length=100) - assert len(result) <= 100 + len(" ... [truncated]") - assert result.endswith(" ... [truncated]") - - @pytest.mark.parametrize( - "body,max_length,expected", - [ - (None, None, None), - ("", None, None), - ("

    Short

    ", 100, "Short"), - ( - "

    ", - None, - "", - ), # Empty paragraph returns empty string after cleaning - ], - ) - def test_process_body_edge_cases(self, body, max_length, expected): - """Test edge cases for body processing.""" - result = process_article_body(body, max_length) - assert result == expected - - -class TestProcessSearchResults: - """Test search results processing.""" - - def test_process_results_with_body(self): - """Test processing results with body content included.""" - results = [ - { - "id": 1, - "title": "Article 1", - "body": "

    Content 1

    ", - "url": "https://example.com/1", - }, - { - "id": 2, - "title": "Article 2", - "body": "

    Content 2

    ", - "url": "https://example.com/2", - }, - ] - - processed = process_search_results(results, include_body=True) - - assert len(processed) == 2 - assert processed[0]["content"] == "Content 1" - assert processed[0]["metadata"]["id"] == 1 - assert processed[0]["metadata"]["title"] == "Article 1" - assert "body" not in processed[0]["metadata"] - - assert processed[1]["content"] == "Content 2" - assert processed[1]["metadata"]["id"] == 2 - - def test_process_results_without_body(self): - """Test processing results without body content.""" - results = [ - { - "id": 1, - "title": "Article 1", - "body": "

    Content 1

    ", - "url": "https://example.com/1", - } - ] - - processed = process_search_results(results, include_body=False) - - assert processed[0]["content"] is None - assert processed[0]["metadata"]["id"] == 1 - assert processed[0]["metadata"]["title"] == "Article 1" - assert "body" not in processed[0]["metadata"] - - def test_process_results_with_max_body_length(self): - """Test processing results with body length limit.""" - results = [ - { - "id": 1, - "title": "Article", - "body": "

    " + "Long content " * 100 + "

    ", - } - ] - - processed = process_search_results(results, include_body=True, max_body_length=50) - - content = processed[0]["content"] - assert len(content) <= 50 + len(" ... [truncated]") - assert content.endswith(" ... [truncated]") - - def test_process_empty_results(self): - """Test processing empty results list.""" - processed = process_search_results([]) - assert processed == [] - - def test_process_results_preserves_all_metadata(self): - """Test that all non-body fields are preserved in metadata.""" - results = [ - { - "id": 1, - "title": "Article", - "body": "

    Content

    ", - "url": "https://example.com/1", - "created_at": "2024-01-01", - "custom_field": "value", - "nested": {"key": "value"}, - } - ] - - processed = process_search_results(results, include_body=True) - - metadata = processed[0]["metadata"] - assert metadata["id"] == 1 - assert metadata["title"] == "Article" - assert metadata["url"] == "https://example.com/1" - assert metadata["created_at"] == "2024-01-01" - assert metadata["custom_field"] == "value" - assert metadata["nested"] == {"key": "value"} - assert "body" not in metadata - - -class TestValidateDateFormat: - """Test date format validation.""" - - @pytest.mark.parametrize( - "date_string", - [ - "2024-01-15", - "2024-12-31", - "2000-01-01", - "1999-12-31", - "2030-06-15", - ], - ) - def test_valid_date_formats(self, date_string): - """Test valid YYYY-MM-DD date formats.""" - assert validate_date_format(date_string) is True - - @pytest.mark.parametrize( - "date_string", - [ - "2024/01/15", - "01-15-2024", - "2024-1-15", - "2024-01-1", - "24-01-15", - "2024.01.15", - "20240115", - "January 15, 2024", - "15/01/2024", - "2024", - "2024-01", - "", - "not-a-date", - # Note: These have valid format but invalid values - regex only checks format - ], - ) - def test_invalid_date_formats(self, date_string): - """Test invalid date formats.""" - assert validate_date_format(date_string) is False diff --git a/toolkits/zendesk/tests/test_who_am_i_util.py b/toolkits/zendesk/tests/test_who_am_i_util.py deleted file mode 100644 index 07113c3a9..000000000 --- a/toolkits/zendesk/tests/test_who_am_i_util.py +++ /dev/null @@ -1,330 +0,0 @@ -from unittest.mock import Mock, patch - -import httpx -import pytest - -from arcade_zendesk.who_am_i_util import ( - WhoAmIResponse, - _extract_organization_info, - _extract_user_info, - _get_current_user, - _get_organization_info, - build_who_am_i_response, -) - - -@pytest.fixture -def mock_context(): - """Create a mock ToolContext for testing.""" - context = Mock() - context.get_secret.return_value = "test-subdomain" - context.get_auth_token_or_empty.return_value = "test-token" - return context - - -@pytest.fixture -def sample_user_data(): - """Sample user data from Zendesk API.""" - return { - "id": 12345, - "name": "John Doe", - "email": "john.doe@example.com", - "role": "admin", - "active": True, - "verified": True, - "locale": "en-US", - "time_zone": "America/New_York", - "organization_id": 67890, - "created_at": "2023-01-01T00:00:00Z", - "updated_at": "2023-12-01T00:00:00Z", - } - - -@pytest.fixture -def sample_organization_data(): - """Sample organization data from Zendesk API.""" - return { - "id": 67890, - "name": "Example Corp", - "domain_names": ["example.com", "example.org"], - "created_at": "2022-01-01T00:00:00Z", - "updated_at": "2023-11-01T00:00:00Z", - "details": "Main organization", - "notes": "Primary customer", - "group_id": 123, - "shared_tickets": True, - "shared_comments": False, - } - - -class TestBuildWhoAmIResponse: - """Test the main build_who_am_i_response function.""" - - @pytest.mark.asyncio - async def test_build_complete_response( - self, mock_context, sample_user_data, sample_organization_data - ): - """Test building a complete who am I response.""" - with ( - patch("arcade_zendesk.who_am_i_util._get_current_user") as mock_get_user, - patch("arcade_zendesk.who_am_i_util._get_organization_info") as mock_get_org, - ): - mock_get_user.return_value = sample_user_data - mock_get_org.return_value = sample_organization_data - - result = await build_who_am_i_response(mock_context) - - assert isinstance(result, dict) - assert result["user_id"] == 12345 - assert result["name"] == "John Doe" - assert result["email"] == "john.doe@example.com" - assert result["role"] == "admin" - assert result["active"] is True - assert result["verified"] is True - assert result["locale"] == "en-US" - assert result["time_zone"] == "America/New_York" - assert result["organization_id"] == 67890 - assert result["organization_name"] == "Example Corp" - assert result["organization_domains"] == ["example.com", "example.org"] - assert result["zendesk_access"] is True - - mock_get_user.assert_called_once_with(mock_context) - mock_get_org.assert_called_once_with(mock_context, 67890) - - @pytest.mark.asyncio - async def test_build_response_without_organization(self, mock_context, sample_user_data): - """Test building response when user has no organization.""" - user_data_no_org = sample_user_data.copy() - del user_data_no_org["organization_id"] - - with ( - patch("arcade_zendesk.who_am_i_util._get_current_user") as mock_get_user, - patch("arcade_zendesk.who_am_i_util._get_organization_info") as mock_get_org, - ): - mock_get_user.return_value = user_data_no_org - mock_get_org.return_value = {} - - result = await build_who_am_i_response(mock_context) - - assert result["user_id"] == 12345 - assert result["name"] == "John Doe" - assert result["zendesk_access"] is True - assert "organization_name" not in result - assert "organization_domains" not in result - - mock_get_org.assert_called_once_with(mock_context, None) - - -class TestGetCurrentUser: - """Test the _get_current_user function.""" - - @pytest.mark.asyncio - async def test_get_current_user_success(self, mock_context, sample_user_data): - """Test successful user retrieval.""" - mock_response = Mock() - mock_response.json.return_value = {"user": sample_user_data} - mock_response.raise_for_status.return_value = None - - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.return_value = mock_response - - result = await _get_current_user(mock_context) - - assert result == sample_user_data - mock_client.return_value.__aenter__.return_value.get.assert_called_once_with( - "https://test-subdomain.zendesk.com/api/v2/users/me", - headers={ - "Authorization": "Bearer test-token", - "Content-Type": "application/json", - }, - ) - - @pytest.mark.asyncio - async def test_get_current_user_http_error(self, mock_context): - """Test user retrieval with HTTP error.""" - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.HTTPStatusError("404 Not Found", request=Mock(), response=Mock()) - ) - - with pytest.raises(httpx.HTTPStatusError): - await _get_current_user(mock_context) - - -class TestGetOrganizationInfo: - """Test the _get_organization_info function.""" - - @pytest.mark.asyncio - async def test_get_organization_info_success(self, mock_context, sample_organization_data): - """Test successful organization retrieval.""" - mock_response = Mock() - mock_response.json.return_value = {"organization": sample_organization_data} - mock_response.raise_for_status.return_value = None - - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.return_value = mock_response - - result = await _get_organization_info(mock_context, 67890) - - assert result == sample_organization_data - mock_client.return_value.__aenter__.return_value.get.assert_called_once_with( - "https://test-subdomain.zendesk.com/api/v2/organizations/67890", - headers={ - "Authorization": "Bearer test-token", - "Content-Type": "application/json", - }, - ) - - @pytest.mark.asyncio - async def test_get_organization_info_no_id(self, mock_context): - """Test organization retrieval with no organization ID.""" - result = await _get_organization_info(mock_context, None) - assert result == {} - - @pytest.mark.asyncio - async def test_get_organization_info_http_error(self, mock_context): - """Test organization retrieval with HTTP error.""" - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.HTTPStatusError("404 Not Found", request=Mock(), response=Mock()) - ) - - with pytest.raises(httpx.HTTPStatusError): - await _get_organization_info(mock_context, 67890) - - -class TestExtractUserInfo: - """Test the _extract_user_info function.""" - - def test_extract_complete_user_info(self, sample_user_data): - """Test extracting complete user information.""" - result = _extract_user_info(sample_user_data) - - expected = { - "user_id": 12345, - "name": "John Doe", - "email": "john.doe@example.com", - "role": "admin", - "active": True, - "verified": True, - "locale": "en-US", - "time_zone": "America/New_York", - "organization_id": 67890, - } - - assert result == expected - - def test_extract_partial_user_info(self): - """Test extracting partial user information.""" - partial_data = { - "id": 12345, - "name": "John Doe", - "email": "john.doe@example.com", - } - - result = _extract_user_info(partial_data) - - expected = { - "user_id": 12345, - "name": "John Doe", - "email": "john.doe@example.com", - } - - assert result == expected - - def test_extract_empty_user_info(self): - """Test extracting from empty user data.""" - result = _extract_user_info({}) - assert result == {} - - @pytest.mark.parametrize( - "field,value,expected_key", - [ - ("active", False, "active"), - ("verified", False, "verified"), - ("active", True, "active"), - ("verified", True, "verified"), - ], - ) - def test_extract_boolean_fields(self, field, value, expected_key): - """Test extracting boolean fields correctly.""" - user_data = {field: value} - result = _extract_user_info(user_data) - assert result[expected_key] == value - - -class TestExtractOrganizationInfo: - """Test the _extract_organization_info function.""" - - def test_extract_complete_organization_info(self, sample_organization_data): - """Test extracting complete organization information.""" - result = _extract_organization_info(sample_organization_data) - - assert result["organization_name"] == "Example Corp" - assert result["organization_domains"] == ["example.com", "example.org"] - - def test_extract_organization_info_no_domains(self): - """Test extracting organization info without domain names.""" - org_data = { - "name": "Example Corp", - "domain_names": [], - } - - result = _extract_organization_info(org_data) - - assert result["organization_name"] == "Example Corp" - assert "organization_domains" not in result - - def test_extract_organization_info_multiple_domains(self): - """Test extracting organization info with multiple domains.""" - org_data = { - "name": "Example Corp", - "domain_names": ["primary.com", "secondary.com", "tertiary.com"], - } - - result = _extract_organization_info(org_data) - - assert result["organization_name"] == "Example Corp" - assert result["organization_domains"] == ["primary.com", "secondary.com", "tertiary.com"] - - def test_extract_empty_organization_info(self): - """Test extracting from empty organization data.""" - result = _extract_organization_info({}) - assert "organization_name" not in result - assert "organization_domains" not in result - assert result == {} - - -class TestWhoAmIResponseType: - """Test the WhoAmIResponse TypedDict.""" - - def test_typed_dict_structure(self): - """Test that WhoAmIResponse accepts expected fields.""" - response: WhoAmIResponse = { - "user_id": 12345, - "name": "John Doe", - "email": "john.doe@example.com", - "role": "admin", - "active": True, - "verified": True, - "locale": "en-US", - "time_zone": "America/New_York", - "organization_id": 67890, - "organization_name": "Example Corp", - "organization_domains": ["example.com", "example.org"], - "zendesk_access": True, - } - - assert response["user_id"] == 12345 - assert response["zendesk_access"] is True - - def test_typed_dict_partial(self): - """Test that WhoAmIResponse works with partial data.""" - response: WhoAmIResponse = { - "user_id": 12345, - "name": "John Doe", - "zendesk_access": True, - } - - assert response["user_id"] == 12345 - assert response["zendesk_access"] is True From 98b02222365781f4e1b23ac6a44aa1b15abcd4d1 Mon Sep 17 00:00:00 2001 From: Eric Gustin Date: Wed, 25 Feb 2026 22:07:55 -0800 Subject: [PATCH 2/3] Simplify workflows --- .github/actions/setup-uv-env/action.yml | 24 +------------------ .../workflows/release-on-version-change.yml | 3 --- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/.github/actions/setup-uv-env/action.yml b/.github/actions/setup-uv-env/action.yml index ee514ffe7..e00155243 100644 --- a/.github/actions/setup-uv-env/action.yml +++ b/.github/actions/setup-uv-env/action.yml @@ -6,18 +6,6 @@ inputs: required: false description: "The python version to use" default: "3.11" - is-contrib: - required: false - description: "Whether this is a contrib package" - default: "false" - is-lib: - required: false - description: "Whether this is a library package" - default: "false" - working-directory: - required: false - description: "Working directory for the installation" - default: "." runs: using: "composite" @@ -25,18 +13,8 @@ runs: - name: Install uv uses: astral-sh/setup-uv@v6 with: - working-directory: ${{ inputs.working-directory }} python-version: ${{ inputs.python-version }} - - name: Install contrib dependencies - if: inputs.is-contrib == 'true' - working-directory: ${{ inputs.working-directory }} - run: | - echo "Installing dependencies for ${{ inputs.working-directory }}" - make install - shell: bash - - - name: Install libs dependencies - if: inputs.is-contrib != 'true' + - name: Install dependencies run: uv sync --extra all --extra dev shell: bash diff --git a/.github/workflows/release-on-version-change.yml b/.github/workflows/release-on-version-change.yml index 986b972d1..58d09d706 100644 --- a/.github/workflows/release-on-version-change.yml +++ b/.github/workflows/release-on-version-change.yml @@ -83,9 +83,6 @@ jobs: uses: ./.github/actions/setup-uv-env with: python-version: "3.10" - is-contrib: ${{ startsWith(matrix.package, 'contrib/') }} - is-lib: ${{ startsWith(matrix.package, 'libs/') }} - working-directory: ${{ matrix.package }} - name: Run tests working-directory: ${{ matrix.package }} From ca9e4aa6cb54dd8d4c47dc898876d510241407e6 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Thu, 26 Feb 2026 06:51:35 +0000 Subject: [PATCH 3/3] Remove redundant docker-base target from Makefile Applied via @cursor push command --- Makefile | 8 -------- 1 file changed, 8 deletions(-) diff --git a/Makefile b/Makefile index 329a6e4bb..7b765c40e 100644 --- a/Makefile +++ b/Makefile @@ -83,14 +83,6 @@ docker: ## Build and run the Docker container @cd docker && make docker-build @cd docker && make docker-run -.PHONY: docker-base -docker-base: ## Build and run the Docker container - @echo "๐Ÿš€ Building lib packages..." - @make full-dist - @echo "๐Ÿš€ Building Docker image" - @cd docker && make docker-build - @cd docker && make docker-run - .PHONY: publish-ghcr publish-ghcr: ## Publish to the GHCR @cd docker && make publish-ghcr