Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 24 additions & 31 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Awaitable, Callable, List, Optional
from typing import Annotated, Awaitable, Callable, List, Optional

import sentry_sdk
from fastapi import FastAPI, Request, Response, status
from fastapi import Depends, FastAPI, Request, Response, status
from fastapi.datastructures import URL
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from packaging.version import Version
from prometheus_client import Counter, Histogram
from sentry_sdk.types import SamplingContext

from dstack._internal import settings as core_settings
from dstack._internal.cli.utils.common import console
from dstack._internal.core.errors import ForbiddenError, ServerClientError
from dstack._internal.core.services.configs import update_default_project
Expand Down Expand Up @@ -68,7 +70,6 @@
get_client_version,
get_server_client_error_details,
)
from dstack._internal.settings import DSTACK_VERSION
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import check_required_ssh_version

Expand All @@ -91,6 +92,9 @@ def create_app() -> FastAPI:
app = FastAPI(
docs_url="/api/docs",
lifespan=lifespan,
dependencies=[
Depends(_check_client_version),
],
)
app.state.proxy_dependency_injector = ServerProxyDependencyInjector()
return app
Expand All @@ -102,7 +106,7 @@ async def lifespan(app: FastAPI):
if settings.SENTRY_DSN is not None:
sentry_sdk.init(
dsn=settings.SENTRY_DSN,
release=DSTACK_VERSION,
release=core_settings.DSTACK_VERSION,
environment=settings.SERVER_ENVIRONMENT,
enable_tracing=True,
traces_sampler=_sentry_traces_sampler,
Expand Down Expand Up @@ -164,7 +168,9 @@ async def lifespan(app: FastAPI):
else:
logger.info("Background processing is disabled")
PROBES_SCHEDULER.start()
dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
dstack_version = (
core_settings.DSTACK_VERSION if core_settings.DSTACK_VERSION else "(no version)"
)
job_network_mode_log = (
logger.info
if settings.JOB_NETWORK_MODE != settings.DEFAULT_JOB_NETWORK_MODE
Expand Down Expand Up @@ -336,32 +342,6 @@ def _extract_endpoint_label(request: Request, response: Response) -> str:
).inc()
return response

@app.middleware("http")
async def check_client_version(request: Request, call_next):
if (
not request.url.path.startswith("/api/")
or request.url.path in _NO_API_VERSION_CHECK_ROUTES
):
return await call_next(request)
try:
client_version = get_client_version(request)
except ValueError as e:
return CustomORJSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": [error_detail(str(e))]},
)
client_release: Optional[tuple[int, ...]] = None
if client_version is not None:
client_release = client_version.release
request.state.client_release = client_release
response = check_client_server_compatibility(
client_version=client_version,
server_version=DSTACK_VERSION,
)
if response is not None:
return response
return await call_next(request)

@app.get("/healthcheck")
async def healthcheck():
    # Always answers with the static JSON body {"status": "running"}; no state is inspected here.
    return CustomORJSONResponse(content={"status": "running"})
Expand Down Expand Up @@ -396,6 +376,19 @@ async def index():
return RedirectResponse("/api/docs")


def _check_client_version(
    request: Request, client_version: Annotated[Optional[Version], Depends(get_client_version)]
) -> None:
    """
    FastAPI dependency that verifies the client version is compatible with this server.

    The compatibility check is applied only to paths that start with `/api/` and are
    not listed in `_NO_API_VERSION_CHECK_ROUTES`; for every other path this is a no-op.
    Any error is raised by `check_client_server_compatibility`.

    NOTE: `client_version` is injected via `get_client_version`, so the version header
    is parsed before the path exemptions below are evaluated.
    """
    path = request.url.path
    if not path.startswith("/api/"):
        return
    if path in _NO_API_VERSION_CHECK_ROUTES:
        return
    check_client_server_compatibility(
        client_version=client_version,
        server_version=core_settings.DSTACK_VERSION,
    )


def _is_proxy_request(request: Request) -> bool:
if request.url.path.startswith("/proxy"):
return True
Expand Down
Empty file.
20 changes: 20 additions & 0 deletions src/dstack/_internal/server/compatibility/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Optional

from packaging.version import Version

from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
)


def patch_offers_list(
    offers: list[InstanceOfferWithAvailability], client_version: Optional[Version]
) -> None:
    """
    Adjust `offers` in place so that the requesting client can display them.

    Nothing is changed when `client_version` is `None` or is 0.20.4 or newer.
    For older clients every offer whose availability is `NO_BALANCE` is
    rewritten to `NOT_AVAILABLE`.
    """
    # CLIs prior to 0.20.4 incorrectly display the `no_balance` availability in the run/fleet plan
    needs_patch = client_version is not None and client_version < Version("0.20.4")
    if not needs_patch:
        return
    for item in offers:
        if item.availability == InstanceAvailability.NO_BALANCE:
            item.availability = InstanceAvailability.NOT_AVAILABLE
22 changes: 22 additions & 0 deletions src/dstack/_internal/server/compatibility/gpus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Optional

from packaging.version import Version

from dstack._internal.core.models.instances import InstanceAvailability
from dstack._internal.server.schemas.gpus import ListGpusResponse


def patch_list_gpus_response(
    response: ListGpusResponse, client_version: Optional[Version]
) -> None:
    """
    Adjust `response` in place so that the requesting client can display it.

    Nothing is changed when `client_version` is `None` or is 0.20.4 or newer.
    For older clients, `NO_BALANCE` is removed from each GPU's availability list
    and `NOT_AVAILABLE` is added in its place unless it is already present.
    """
    if client_version is None:
        return
    # CLIs prior to 0.20.4 incorrectly display the `no_balance` availability in `dstack offer --group-by gpu`
    if client_version >= Version("0.20.4"):
        return
    no_balance = InstanceAvailability.NO_BALANCE
    not_available = InstanceAvailability.NOT_AVAILABLE
    for gpu in response.gpus:
        if no_balance not in gpu.availability:
            continue
        gpu.availability = [a for a in gpu.availability if a != no_balance]
        if not_available not in gpu.availability:
            gpu.availability.append(not_available)
7 changes: 6 additions & 1 deletion src/dstack/_internal/server/routers/fleets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import List, Tuple
from typing import List, Optional, Tuple

from fastapi import APIRouter, Depends
from packaging.version import Version
from sqlalchemy.ext.asyncio import AsyncSession

import dstack._internal.server.services.fleets as fleets_services
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.fleets import Fleet, FleetPlan
from dstack._internal.server.compatibility.common import patch_offers_list
from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.fleets import (
Expand All @@ -21,6 +23,7 @@
from dstack._internal.server.utils.routers import (
CustomORJSONResponse,
get_base_api_additional_responses,
get_client_version,
)

root_router = APIRouter(
Expand Down Expand Up @@ -101,6 +104,7 @@ async def get_plan(
body: GetFleetPlanRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
client_version: Optional[Version] = Depends(get_client_version),
):
"""
Returns a fleet plan for the given fleet configuration.
Expand All @@ -112,6 +116,7 @@ async def get_plan(
user=user,
spec=body.spec,
)
patch_offers_list(plan.offers, client_version)
return CustomORJSONResponse(plan)


Expand Down
14 changes: 11 additions & 3 deletions src/dstack/_internal/server/routers/gpus.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Tuple
from typing import Annotated, Optional, Tuple

from fastapi import APIRouter, Depends
from packaging.version import Version

from dstack._internal.server.compatibility.gpus import patch_list_gpus_response
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.gpus import ListGpusRequest, ListGpusResponse
from dstack._internal.server.security.permissions import ProjectMember
from dstack._internal.server.services.gpus import list_gpus_grouped
from dstack._internal.server.utils.routers import get_base_api_additional_responses
from dstack._internal.server.utils.routers import (
get_base_api_additional_responses,
get_client_version,
)

project_router = APIRouter(
prefix="/api/project/{project_name}/gpus",
Expand All @@ -18,7 +23,10 @@
@project_router.post("/list", response_model=ListGpusResponse, response_model_exclude_none=True)
async def list_gpus(
    body: ListGpusRequest,
    client_version: Annotated[Optional[Version], Depends(get_client_version)],
    user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> ListGpusResponse:
    # Only the project half of the (user, project) pair is needed here.
    _, project = user_project
    resp = await list_gpus_grouped(project=project, run_spec=body.run_spec, group_by=body.group_by)
    # Patch the response in place for the requesting client version before returning it;
    # returning the raw `list_gpus_grouped` result directly would bypass this step.
    patch_list_gpus_response(resp, client_version)
    return resp
17 changes: 12 additions & 5 deletions src/dstack/_internal/server/routers/runs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Annotated, List, Optional, Tuple, cast
from typing import Annotated, List, Optional, Tuple

from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends
from packaging.version import Version
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.runs import Run, RunPlan
from dstack._internal.server.compatibility.common import patch_offers_list
from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.schemas.runs import (
Expand All @@ -21,6 +23,7 @@
from dstack._internal.server.utils.routers import (
CustomORJSONResponse,
get_base_api_additional_responses,
get_client_version,
)

root_router = APIRouter(
Expand All @@ -35,9 +38,10 @@
)


def use_legacy_repo_dir(
    client_version: Annotated[Optional[Version], Depends(get_client_version)],
) -> bool:
    """
    FastAPI dependency that computes the `legacy_repo_dir` flag for a request.

    Returns `True` only when the client reports a version older than 0.19.27;
    returns `False` when `client_version` is `None`.

    The version comes from the `get_client_version` dependency rather than from
    `request.state`, so this function does not rely on any middleware having run.
    """
    if client_version is None:
        return False
    return client_version < Version("0.19.27")


@root_router.post(
Expand Down Expand Up @@ -110,6 +114,7 @@ async def get_plan(
body: GetRunPlanRequest,
session: Annotated[AsyncSession, Depends(get_session)],
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
client_version: Annotated[Optional[Version], Depends(get_client_version)],
legacy_repo_dir: Annotated[bool, Depends(use_legacy_repo_dir)],
):
"""
Expand All @@ -127,6 +132,8 @@ async def get_plan(
max_offers=body.max_offers,
legacy_repo_dir=legacy_repo_dir,
)
for job_plan in run_plan.job_plans:
patch_offers_list(job_plan.offers, client_version)
return CustomORJSONResponse(run_plan)


Expand Down
37 changes: 17 additions & 20 deletions src/dstack/_internal/server/utils/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,28 @@ def get_request_size(request: Request) -> int:


def get_client_version(request: Request) -> Optional[packaging.version.Version]:
    """
    FastAPI dependency that returns the dstack client version or None if the version is latest/dev.

    The version is read from the `x-api-version` request header; `None` is also
    returned when the header is absent.

    Raises:
        HTTPException: 400 Bad Request when `parse_version` rejects the header
            value with a `ValueError`.
    """
    raw_version = request.headers.get("x-api-version")
    if raw_version is None:
        return None
    try:
        return parse_version(raw_version)
    except ValueError as e:
        # Turn the parsing failure into a client error and keep the original
        # exception attached as the explicit cause.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=[error_detail(str(e))],
        ) from e


def check_client_server_compatibility(
client_version: Optional[packaging.version.Version],
server_version: Optional[str],
) -> Optional[CustomORJSONResponse]:
) -> None:
"""
Returns `JSONResponse` with error if client/server versions are incompatible.
Returns `None` otherwise.
Raise HTTP exception if the client is incompatible with the server.
"""
if client_version is None or server_version is None:
return None
Expand All @@ -149,21 +158,9 @@ def check_client_server_compatibility(
client_version.major > parsed_server_version.major
or client_version.minor > parsed_server_version.minor
):
return error_incompatible_versions(
str(client_version), server_version, ask_cli_update=False
msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})."
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=get_server_client_error_details(ServerClientError(msg=msg)),
)
return None


def error_incompatible_versions(
client_version: Optional[str],
server_version: str,
ask_cli_update: bool,
) -> CustomORJSONResponse:
msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})."
if ask_cli_update:
msg += f" Update the dstack CLI: `pip install dstack=={server_version}`."
return CustomORJSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": get_server_client_error_details(ServerClientError(msg=msg))},
)
Loading