From 70adccc302e7666e95aad2484b14c258f3d73db0 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 14 Jan 2026 17:20:43 +0100 Subject: [PATCH] Do not return `NO_BALANCE` to older clients Since only newer CLIs can correctly display `InstanceAvailability.NO_BALANCE`, replace `NO_BALANCE` with `NOT_AVAILABLE` in server responses for older clients for the following API methods: - `/api/project/{project_name}/fleets/get_plan` - `/api/project/{project_name}/runs/get_plan` - `/api/project/{project_name}/gpus/list` Additionally, refactor the code to make it easy to retrieve the client version using FastAPI dependencies. ```python client_version: Annotated[Optional[Version], Depends(get_client_version)] ``` --- src/dstack/_internal/server/app.py | 55 ++++++------- .../server/compatibility/__init__.py | 0 .../_internal/server/compatibility/common.py | 20 +++++ .../_internal/server/compatibility/gpus.py | 22 +++++ src/dstack/_internal/server/routers/fleets.py | 7 +- src/dstack/_internal/server/routers/gpus.py | 14 +++- src/dstack/_internal/server/routers/runs.py | 17 ++-- src/dstack/_internal/server/utils/routers.py | 37 ++++----- .../_internal/server/routers/test_fleets.py | 63 +++++++++++++++ .../_internal/server/routers/test_gpus.py | 47 ++++++++++- .../_internal/server/routers/test_runs.py | 73 +++++++++++++++++ src/tests/_internal/server/test_app.py | 80 +++++++++++++++++++ .../_internal/server/utils/test_routers.py | 68 ++++++---------- 13 files changed, 399 insertions(+), 104 deletions(-) create mode 100644 src/dstack/_internal/server/compatibility/__init__.py create mode 100644 src/dstack/_internal/server/compatibility/common.py create mode 100644 src/dstack/_internal/server/compatibility/gpus.py diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 488a5a9e0e..b41152c149 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -5,16 +5,18 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from pathlib import Path -from typing import Awaitable, Callable, List, Optional +from typing import Annotated, Awaitable, Callable, List, Optional import sentry_sdk -from fastapi import FastAPI, Request, Response, status +from fastapi import Depends, FastAPI, Request, Response, status from fastapi.datastructures import URL from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles +from packaging.version import Version from prometheus_client import Counter, Histogram from sentry_sdk.types import SamplingContext +from dstack._internal import settings as core_settings from dstack._internal.cli.utils.common import console from dstack._internal.core.errors import ForbiddenError, ServerClientError from dstack._internal.core.services.configs import update_default_project @@ -68,7 +70,6 @@ get_client_version, get_server_client_error_details, ) -from dstack._internal.settings import DSTACK_VERSION from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import check_required_ssh_version @@ -91,6 +92,9 @@ def create_app() -> FastAPI: app = FastAPI( docs_url="/api/docs", lifespan=lifespan, + dependencies=[ + Depends(_check_client_version), + ], ) app.state.proxy_dependency_injector = ServerProxyDependencyInjector() return app @@ -102,7 +106,7 @@ async def lifespan(app: FastAPI): if settings.SENTRY_DSN is not None: sentry_sdk.init( dsn=settings.SENTRY_DSN, - release=DSTACK_VERSION, + release=core_settings.DSTACK_VERSION, environment=settings.SERVER_ENVIRONMENT, enable_tracing=True, traces_sampler=_sentry_traces_sampler, @@ -164,7 +168,9 @@ async def lifespan(app: FastAPI): else: logger.info("Background processing is disabled") PROBES_SCHEDULER.start() - dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)" + dstack_version = ( + core_settings.DSTACK_VERSION if core_settings.DSTACK_VERSION else "(no version)" + ) job_network_mode_log = ( logger.info if settings.JOB_NETWORK_MODE != settings.DEFAULT_JOB_NETWORK_MODE @@ -336,32 +342,6 @@ def _extract_endpoint_label(request: Request, response: Response) -> str: ).inc() return response - @app.middleware("http") - async def check_client_version(request: Request, call_next): - if ( - not request.url.path.startswith("/api/") - or request.url.path in _NO_API_VERSION_CHECK_ROUTES - ): - return await call_next(request) - try: - client_version = get_client_version(request) - except ValueError as e: - return CustomORJSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": [error_detail(str(e))]}, - ) - client_release: Optional[tuple[int, ...]] = None - if client_version is not None: - client_release = client_version.release - request.state.client_release = client_release - response = check_client_server_compatibility( - client_version=client_version, - server_version=DSTACK_VERSION, - ) - if response is not None: - return response - return await call_next(request) - @app.get("/healthcheck") async def healthcheck(): return CustomORJSONResponse(content={"status": "running"}) @@ -396,6 +376,19 @@ async def index(): return RedirectResponse("/api/docs") +def _check_client_version( + request: Request, client_version: Annotated[Optional[Version], Depends(get_client_version)] +) -> None: + if ( + request.url.path.startswith("/api/") + and request.url.path not in _NO_API_VERSION_CHECK_ROUTES + ): + check_client_server_compatibility( + client_version=client_version, + server_version=core_settings.DSTACK_VERSION, + ) + + def _is_proxy_request(request: Request) -> bool: if request.url.path.startswith("/proxy"): return True diff --git a/src/dstack/_internal/server/compatibility/__init__.py b/src/dstack/_internal/server/compatibility/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/_internal/server/compatibility/common.py b/src/dstack/_internal/server/compatibility/common.py new file mode 100644 index 0000000000..227b45fdaf --- /dev/null +++ b/src/dstack/_internal/server/compatibility/common.py @@ -0,0 +1,20 @@ +from typing import Optional + +from packaging.version import Version + +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, +) + + +def patch_offers_list( + offers: list[InstanceOfferWithAvailability], client_version: Optional[Version] +) -> None: + if client_version is None: + return + # CLIs prior to 0.20.4 incorrectly display the `no_balance` availability in the run/fleet plan + if client_version < Version("0.20.4"): + for offer in offers: + if offer.availability == InstanceAvailability.NO_BALANCE: + offer.availability = InstanceAvailability.NOT_AVAILABLE diff --git a/src/dstack/_internal/server/compatibility/gpus.py b/src/dstack/_internal/server/compatibility/gpus.py new file mode 100644 index 0000000000..8548e58bf9 --- /dev/null +++ b/src/dstack/_internal/server/compatibility/gpus.py @@ -0,0 +1,22 @@ +from typing import Optional + +from packaging.version import Version + +from dstack._internal.core.models.instances import InstanceAvailability +from dstack._internal.server.schemas.gpus import ListGpusResponse + + +def patch_list_gpus_response( + response: ListGpusResponse, client_version: Optional[Version] +) -> None: + if client_version is None: + return + # CLIs prior to 0.20.4 incorrectly display the `no_balance` availability in `dstack offer --group-by gpu` + if client_version < Version("0.20.4"): + for gpu in response.gpus: + if InstanceAvailability.NO_BALANCE in gpu.availability: + gpu.availability = [ + a for a in gpu.availability if a != InstanceAvailability.NO_BALANCE + ] + if InstanceAvailability.NOT_AVAILABLE not in gpu.availability: + gpu.availability.append(InstanceAvailability.NOT_AVAILABLE) diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index 7e7126f4bf..d423134675 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -1,11 +1,13 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple from fastapi import APIRouter, Depends +from packaging.version import Version from sqlalchemy.ext.asyncio import AsyncSession import dstack._internal.server.services.fleets as fleets_services from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.fleets import Fleet, FleetPlan +from dstack._internal.server.compatibility.common import patch_offers_list from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.fleets import ( @@ -21,6 +23,7 @@ from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, + get_client_version, ) root_router = APIRouter( @@ -101,6 +104,7 @@ async def get_plan( body: GetFleetPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + client_version: Optional[Version] = Depends(get_client_version), ): """ Returns a fleet plan for the given fleet configuration. @@ -112,6 +116,7 @@ async def get_plan( user=user, spec=body.spec, ) + patch_offers_list(plan.offers, client_version) return CustomORJSONResponse(plan) diff --git a/src/dstack/_internal/server/routers/gpus.py b/src/dstack/_internal/server/routers/gpus.py index 45f0e8bf1f..3a701fb1e8 100644 --- a/src/dstack/_internal/server/routers/gpus.py +++ b/src/dstack/_internal/server/routers/gpus.py @@ -1,12 +1,17 @@ -from typing import Tuple +from typing import Annotated, Optional, Tuple from fastapi import APIRouter, Depends +from packaging.version import Version +from dstack._internal.server.compatibility.gpus import patch_list_gpus_response from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.gpus import ListGpusRequest, ListGpusResponse from dstack._internal.server.security.permissions import ProjectMember from dstack._internal.server.services.gpus import list_gpus_grouped -from dstack._internal.server.utils.routers import get_base_api_additional_responses +from dstack._internal.server.utils.routers import ( + get_base_api_additional_responses, + get_client_version, +) project_router = APIRouter( prefix="/api/project/{project_name}/gpus", @@ -18,7 +23,10 @@ @project_router.post("/list", response_model=ListGpusResponse, response_model_exclude_none=True) async def list_gpus( body: ListGpusRequest, + client_version: Annotated[Optional[Version], Depends(get_client_version)], user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> ListGpusResponse: _, project = user_project - return await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by) + resp = await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by) + patch_list_gpus_response(resp, client_version) + return resp diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index a4a09b3fb8..27d378d8ba 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -1,10 +1,12 @@ -from typing import Annotated, List, Optional, Tuple, cast +from typing import Annotated, List, Optional, Tuple -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends +from packaging.version import Version from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.runs import Run, RunPlan +from dstack._internal.server.compatibility.common import patch_offers_list from dstack._internal.server.db import get_session from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.runs import ( @@ -21,6 +23,7 @@ from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, + get_client_version, ) root_router = APIRouter( @@ -35,9 +38,10 @@ ) -def use_legacy_repo_dir(request: Request) -> bool: - client_release = cast(Optional[tuple[int, ...]], request.state.client_release) - return client_release is not None and client_release < (0, 19, 27) +def use_legacy_repo_dir( + client_version: Annotated[Optional[Version], Depends(get_client_version)], +) -> bool: + return client_version is not None and client_version < Version("0.19.27") @root_router.post( @@ -110,6 +114,7 @@ async def get_plan( body: GetRunPlanRequest, session: Annotated[AsyncSession, Depends(get_session)], user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())], + client_version: Annotated[Optional[Version], Depends(get_client_version)], legacy_repo_dir: Annotated[bool, Depends(use_legacy_repo_dir)], ): """ @@ -127,6 +132,8 @@ async def get_plan( max_offers=body.max_offers, legacy_repo_dir=legacy_repo_dir, ) + for job_plan in run_plan.job_plans: + patch_offers_list(job_plan.offers, client_version) return CustomORJSONResponse(run_plan) diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py index a625ccd9a2..5aff751868 100644 --- a/src/dstack/_internal/server/utils/routers.py +++ b/src/dstack/_internal/server/utils/routers.py @@ -124,19 +124,28 @@ def get_request_size(request: Request) -> int: def get_client_version(request: Request) -> Optional[packaging.version.Version]: + """ + FastAPI dependency that returns the dstack client version or None if the version is latest/dev. + """ + version = request.headers.get("x-api-version") if version is None: return None - return parse_version(version) + try: + return parse_version(version) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=[error_detail(str(e))], + ) def check_client_server_compatibility( client_version: Optional[packaging.version.Version], server_version: Optional[str], -) -> Optional[CustomORJSONResponse]: +) -> None: """ - Returns `JSONResponse` with error if client/server versions are incompatible. - Returns `None` otherwise. + Raise HTTP exception if the client is incompatible with the server. """ if client_version is None or server_version is None: return None @@ -149,21 +158,9 @@ def check_client_server_compatibility( client_version.major > parsed_server_version.major or client_version.minor > parsed_server_version.minor ): - return error_incompatible_versions( - str(client_version), server_version, ask_cli_update=False + msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})." + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=get_server_client_error_details(ServerClientError(msg=msg)), ) return None - - -def error_incompatible_versions( - client_version: Optional[str], - server_version: str, - ask_cli_update: bool, -) -> CustomORJSONResponse: - msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})." - if ask_cli_update: - msg += f" Update the dstack CLI: `pip install dstack=={server_version}`." - return CustomORJSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": get_server_client_error_details(ServerClientError(msg=msg))}, - ) diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 12e439111e..afa68b788d 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1,5 +1,6 @@ import json from datetime import datetime, timezone +from typing import Optional from unittest.mock import Mock, patch from uuid import UUID, uuid4 @@ -1167,6 +1168,68 @@ async def test_returns_create_plan_for_existing_fleet( "action": "create", } + @pytest.mark.parametrize( + ("client_version", "expected_availability"), + [ + ("0.20.3", InstanceAvailability.NOT_AVAILABLE), + ("0.20.4", InstanceAvailability.NO_BALANCE), + (None, InstanceAvailability.NO_BALANCE), + ], + ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_replaces_no_balance_with_not_available_for_old_clients( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + client_version: Optional[str], + expected_availability: InstanceAvailability, + ): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + offers = [ + InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance-1", + resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + ), + InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance-2", + resources=Resources(cpus=2, memory_mib=1024, spot=False, gpus=[]), + ), + region="us", + price=2.0, + availability=InstanceAvailability.NO_BALANCE, + ), + ] + headers = get_auth_headers(user.token) + if client_version is not None: + headers["X-API-Version"] = client_version + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value.get_offers.return_value = offers + response = await client.post( + f"/api/project/{project.name}/fleets/get_plan", + headers=headers, + json={"spec": get_fleet_spec().dict()}, + ) + + assert response.status_code == 200 + offers = response.json()["offers"] + assert len(offers) == 2 + assert offers[0]["availability"] == InstanceAvailability.AVAILABLE.value + assert offers[1]["availability"] == expected_availability.value + def _fleet_model_to_json_dict(fleet: FleetModel) -> dict: return json.loads(fleet_model_to_fleet(fleet).json()) diff --git a/src/tests/_internal/server/routers/test_gpus.py b/src/tests/_internal/server/routers/test_gpus.py index d07a92bb2f..32c862231a 100644 --- a/src/tests/_internal/server/routers/test_gpus.py +++ b/src/tests/_internal/server/routers/test_gpus.py @@ -96,15 +96,19 @@ async def call_gpus_api( user_token: str, run_spec: RunSpec, group_by: Optional[List[str]] = None, + client_version: Optional[str] = None, ): """Helper to call the GPUs API with standard parameters.""" json_data = {"run_spec": run_spec.dict()} if group_by is not None: json_data["group_by"] = group_by + headers = get_auth_headers(user_token) + if client_version is not None: + headers["X-API-Version"] = client_version return await client.post( f"/api/project/{project_name}/gpus/list", - headers=get_auth_headers(user_token), + headers=headers, json=json_data, ) @@ -511,3 +515,44 @@ async def test_exact_aggregation_values( assert rtx_runpod_euwest1["region"] == "eu-west-1" assert rtx_runpod_euwest1["price"]["min"] == 0.65 assert rtx_runpod_euwest1["price"]["max"] == 0.65 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + ("client_version", "expected_availability"), + [ + ("0.20.3", InstanceAvailability.NOT_AVAILABLE), + ("0.20.4", InstanceAvailability.NO_BALANCE), + (None, InstanceAvailability.NO_BALANCE), + ], + ) + async def test_replaces_no_balance_with_not_available_for_old_clients( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + client_version: Optional[str], + expected_availability: InstanceAvailability, + ): + user, project, repo, run_spec = await gpu_test_setup(session) + + available_offer = create_gpu_offer( + BackendType.AWS, "T4", 16384, 0.50, availability=InstanceAvailability.AVAILABLE + ) + no_balance_offer = create_gpu_offer( + BackendType.AWS, "L4", 24 * 1024, 1.0, availability=InstanceAvailability.NO_BALANCE + ) + offers_by_backend = {BackendType.AWS: [available_offer, no_balance_offer]} + mocked_backends = create_mock_backends_with_offers(offers_by_backend) + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = mocked_backends + response = await call_gpus_api( + client, project.name, user.token, run_spec, client_version=client_version + ) + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["gpus"]) == 2 + assert response_data["gpus"][0]["availability"] == [InstanceAvailability.AVAILABLE.value] + assert response_data["gpus"][1]["availability"] == [expected_availability.value] diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 4f3ab2ed2d..627fa8a167 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -1280,6 +1280,79 @@ async def test_returns_run_plan_instance_volumes( assert response.status_code == 200, response.json() assert response.json() == run_plan_dict + @pytest.mark.parametrize( + ("client_version", "expected_availability"), + [ + ("0.20.3", InstanceAvailability.NOT_AVAILABLE), + ("0.20.4", InstanceAvailability.NO_BALANCE), + (None, InstanceAvailability.NO_BALANCE), + ], + ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_replaces_no_balance_with_not_available_for_old_clients( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + client_version: Optional[str], + expected_availability: InstanceAvailability, + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=None) + await create_fleet(session=session, project=project, spec=fleet_spec) + repo = await create_repo(session=session, project_id=project.id) + offers = [ + InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance-1", + resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + ), + InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance-2", + resources=Resources(cpus=2, memory_mib=1024, spot=False, gpus=[]), + ), + region="us", + price=2.0, + availability=InstanceAvailability.NO_BALANCE, + ), + ] + run_plan_dict = get_dev_env_run_plan_dict( + project_name=project.name, + username=user.name, + repo_id=repo.name, + offers=offers, + total_offers=1, + max_price=1.0, + ) + body = {"run_spec": run_plan_dict["run_spec"]} + headers = get_auth_headers(user.token) + if client_version is not None: + headers["X-API-Version"] = client_version + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value.get_offers.return_value = offers + m.return_value = [backend_mock] + response = await client.post( + f"/api/project/{project.name}/runs/get_plan", + headers=headers, + json=body, + ) + offers = response.json()["job_plans"][0]["offers"] + assert len(offers) == 2 + assert offers[0]["availability"] == InstanceAvailability.AVAILABLE.value + assert offers[1]["availability"] == expected_availability.value + @pytest.mark.asyncio @pytest.mark.parametrize( ("old_conf", "new_conf", "action"), diff --git a/src/tests/_internal/server/test_app.py b/src/tests/_internal/server/test_app.py index 8f11660d35..4fafb04e31 100644 --- a/src/tests/_internal/server/test_app.py +++ b/src/tests/_internal/server/test_app.py @@ -1,9 +1,14 @@ +from typing import Optional +from unittest.mock import patch + import pytest from fastapi.testclient import TestClient from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal import settings from dstack._internal.server.main import app +from dstack._internal.server.testing.common import create_user, get_auth_headers client = TestClient(app) @@ -16,3 +21,78 @@ async def test_returns_html(self, test_db, session: AsyncSession, client: AsyncC response = await client.get("/") assert response.status_code == 200 assert response.content.startswith(b'<') + + +class TestCheckXApiVersion: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + ("client_version", "server_version", "is_compatible"), + [ + ("12.12.12", None, True), + ("0.12.4", "0.12.4", True), + (None, "0.1.12", True), + ("0.13.0", "0.12.4", False), + # For test performance, only a few cases are covered here. + # More cases are covered in `TestCheckClientServerCompatibility`. + ], + ) + @pytest.mark.parametrize("endpoint", ["/api/users/list", "/api/projects/list"]) + async def test_check_client_compatibility( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + endpoint: str, + client_version: Optional[str], + server_version: Optional[str], + is_compatible: bool, + ): + user = await create_user(session=session) + headers = get_auth_headers(user.token) + if client_version is not None: + headers["X-API-Version"] = client_version + + with patch.object(settings, "DSTACK_VERSION", server_version): + response = await client.post(endpoint, headers=headers, json={}) + + if is_compatible: + assert response.status_code == 200, response.text + else: + assert response.status_code == 400 + assert response.json() == { + "detail": [ + { + "code": "error", + "msg": f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version}).", + } + ] + } + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize("endpoint", ["/api/users/list", "/api/projects/list"]) + @pytest.mark.parametrize("invalid_value", ["", "1..0", "version1"]) + async def test_invalid_x_api_version_header( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + endpoint: str, + invalid_value: str, + ): + user = await create_user(session=session) + headers = get_auth_headers(user.token) + headers["X-API-Version"] = invalid_value + + response = await client.post(endpoint, headers=headers, json={}) + + assert response.status_code == 400 + assert response.json() == { + "detail": [ + { + "code": None, + "msg": f"Invalid version: {invalid_value}", + } + ] + } diff --git a/src/tests/_internal/server/utils/test_routers.py b/src/tests/_internal/server/utils/test_routers.py index d3ea11213c..0aeb4be8b8 100644 --- a/src/tests/_internal/server/utils/test_routers.py +++ b/src/tests/_internal/server/utils/test_routers.py @@ -2,69 +2,51 @@ import packaging.version import pytest +from fastapi import HTTPException from dstack._internal.server.utils.routers import check_client_server_compatibility class TestCheckClientServerCompatibility: - @pytest.mark.parametrize("client_version", [packaging.version.parse("12.12.12"), None]) - def test_returns_none_if_server_version_is_none( - self, client_version: Optional[packaging.version.Version] - ): - assert ( - check_client_server_compatibility( - client_version=client_version, - server_version=None, - ) - is None - ) - @pytest.mark.parametrize( - "client_version,server_version", + ("client_version", "server_version"), [ + ("0.12.5", "0.12.4"), + ("0.12.5rc1", "0.12.4"), + ("0.12.4rc1", "0.12.4"), ("0.12.4", "0.12.4"), ("0.12.4", "0.12.5"), ("0.12.4", "0.13.0"), ("0.12.4", "1.12.0"), ("0.12.4", "0.12.5rc1"), ("1.0.5", "1.0.6"), + ("12.12.12", None), + (None, "0.1.12"), + (None, None), ], ) - def test_returns_none_if_compatible(self, client_version: str, server_version: str): - assert ( - check_client_server_compatibility( - client_version=packaging.version.parse(client_version), - server_version=server_version, - ) - is None - ) + def test_compatible( + self, client_version: Optional[str], server_version: Optional[str] + ) -> None: + parsed_client_version = None + if client_version is not None: + parsed_client_version = packaging.version.parse(client_version) - @pytest.mark.parametrize( - "client_version,server_version", - [ - ("0.13.0", "0.12.4"), - ("1.12.0", "0.12.0"), - ], - ) - def test_returns_error_if_client_version_larger( - self, client_version: str, server_version: str - ): - res = check_client_server_compatibility( - client_version=packaging.version.parse(client_version), + check_client_server_compatibility( + client_version=parsed_client_version, server_version=server_version, ) - assert res is not None @pytest.mark.parametrize( - "server_version", + ("client_version", "server_version"), [ - None, - "0.1.12", + ("0.13.0", "0.12.4"), + ("1.12.0", "0.12.0"), ], ) - def test_returns_none_if_client_version_is_latest(self, server_version: Optional[str]): - res = check_client_server_compatibility( - client_version=None, - server_version=server_version, - ) - assert res is None + def test_incompatible(self, client_version: str, server_version: str) -> None: + with pytest.raises(HTTPException): + check_client_server_compatibility( + client_version=packaging.version.parse(client_version), + server_version=server_version, + )