diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py
index 85af7d3315..8577cce6f1 100644
--- a/src/dstack/_internal/server/background/__init__.py
+++ b/src/dstack/_internal/server/background/__init__.py
@@ -42,7 +42,14 @@ def get_scheduler() -> AsyncIOScheduler:
 
 
 def start_background_tasks() -> AsyncIOScheduler:
-    # We try to process as many resources as possible without exhausting DB connections.
+    # Background processing is implemented via in-memory locks on SQLite
+    # and SELECT FOR UPDATE on Postgres. Locks may be held for a long time.
+    # This is currently the main bottleneck for scaling dstack processing
+    # as processing more resources requires more DB connections.
+    # TODO: Make background processing efficient by committing locks to DB
+    # and processing outside of DB transactions.
+    #
+    # Now we just try to process as many resources as possible without exhausting DB connections.
     #
     # Quick tasks can process multiple resources per transaction.
     # Potentially long tasks process one resource per transaction
diff --git a/src/dstack/_internal/server/background/tasks/process_fleets.py b/src/dstack/_internal/server/background/tasks/process_fleets.py
index 733029abf8..d369c7d242 100644
--- a/src/dstack/_internal/server/background/tasks/process_fleets.py
+++ b/src/dstack/_internal/server/background/tasks/process_fleets.py
@@ -5,7 +5,7 @@
 
 from sqlalchemy import select, update
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload, load_only, selectinload
+from sqlalchemy.orm import joinedload, load_only, selectinload, with_loader_criteria
 
 from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
 from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason
@@ -60,6 +60,9 @@ async def process_fleets():
                 .options(
                     load_only(FleetModel.id, FleetModel.name),
                     selectinload(FleetModel.instances).load_only(InstanceModel.id),
+                    with_loader_criteria(
+                        InstanceModel, InstanceModel.deleted == False, include_aliases=True
+                    ),
                 )
                 .order_by(FleetModel.last_processed_at.asc())
                 .limit(BATCH_SIZE)
@@ -72,6 +75,7 @@
                 .where(
                     InstanceModel.id.not_in(instance_lockset),
                     InstanceModel.fleet_id.in_(fleet_ids),
+                    InstanceModel.deleted == False,
                 )
                 .options(load_only(InstanceModel.id, InstanceModel.fleet_id))
                 .order_by(InstanceModel.id)
@@ -113,8 +117,11 @@ async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel])
         .where(FleetModel.id.in_(fleet_ids))
         .options(
             joinedload(FleetModel.instances).joinedload(InstanceModel.jobs).load_only(JobModel.id),
-            joinedload(FleetModel.project),
+            with_loader_criteria(
+                InstanceModel, InstanceModel.deleted == False, include_aliases=True
+            ),
         )
+        .options(joinedload(FleetModel.project))
         .options(joinedload(FleetModel.runs).load_only(RunModel.status))
         .execution_options(populate_existing=True)
     )
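
The recurring hunk here attaches `with_loader_criteria(InstanceModel, InstanceModel.deleted == False, include_aliases=True)` so that soft-deleted instances are filtered out of every eager load of `FleetModel.instances` in the statement, not just the top-level SELECT. A minimal, self-contained sketch of the pattern against an in-memory SQLite database — the toy `Fleet`/`Instance` models below are stand-ins for dstack's `FleetModel`/`InstanceModel`, not its actual schema:

```python
from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    selectinload,
    with_loader_criteria,
)


class Base(DeclarativeBase):
    pass


class Fleet(Base):
    __tablename__ = "fleets"
    id: Mapped[int] = mapped_column(primary_key=True)
    instances: Mapped[list["Instance"]] = relationship(back_populates="fleet")


class Instance(Base):
    __tablename__ = "instances"
    id: Mapped[int] = mapped_column(primary_key=True)
    fleet_id: Mapped[int] = mapped_column(ForeignKey("fleets.id"))
    deleted: Mapped[bool] = mapped_column(default=False)
    fleet: Mapped[Fleet] = relationship(back_populates="instances")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Fleet(instances=[Instance(), Instance(deleted=True)]))
    session.commit()

    stmt = select(Fleet).options(
        selectinload(Fleet.instances),
        # Applied wherever Instance appears in the statement, including the
        # eager-loaded relationship; include_aliases extends it to aliased
        # forms of the entity.
        with_loader_criteria(Instance, Instance.deleted == False, include_aliases=True),
    )
    loaded = session.scalars(stmt).one()
    assert len(loaded.instances) == 1  # the soft-deleted instance is filtered out
```

Because `with_loader_criteria` is attached per statement rather than per relationship, the same one-liner covers the `selectinload` and `joinedload` variants used across these background tasks.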
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index 2241c4c6a4..9a14bdc30d 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -11,7 +11,7 @@
 from pydantic import ValidationError
 from sqlalchemy import and_, delete, func, not_, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload, with_loader_criteria
 
 from dstack._internal import settings
 from dstack._internal.core.backends.base.compute import (
@@ -79,7 +79,6 @@
     fleet_model_to_fleet,
     get_create_instance_offers,
     is_cloud_cluster,
-    is_fleet_master_instance,
 )
 from dstack._internal.server.services.instances import (
     get_instance_configuration,
@@ -218,7 +217,12 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
         .where(InstanceModel.id == instance.id)
         .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
         .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
-        .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
+        .options(
+            joinedload(InstanceModel.fleet).joinedload(FleetModel.instances),
+            with_loader_criteria(
+                InstanceModel, InstanceModel.deleted == False, include_aliases=True
+            ),
+        )
         .execution_options(populate_existing=True)
     )
     instance = res.unique().scalar_one()
@@ -228,7 +232,12 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
         .where(InstanceModel.id == instance.id)
         .options(joinedload(InstanceModel.project))
         .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
-        .options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
+        .options(
+            joinedload(InstanceModel.fleet).joinedload(FleetModel.instances),
+            with_loader_criteria(
+                InstanceModel, InstanceModel.deleted == False, include_aliases=True
+            ),
+        )
         .execution_options(populate_existing=True)
     )
     instance = res.unique().scalar_one()
@@ -543,8 +552,11 @@ def _deploy_instance(
 
 
 async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
-    if _need_to_wait_fleet_provisioning(instance):
-        logger.debug("Waiting for the first instance in the fleet to be provisioned")
+    master_instance = await _get_fleet_master_instance(session, instance)
+    if _need_to_wait_fleet_provisioning(instance, master_instance):
+        logger.debug(
+            "%s: waiting for the first instance in the fleet to be provisioned", fmt(instance)
+        )
         return
 
     try:
@@ -576,6 +588,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No
         placement_group_model = get_placement_group_model_for_instance(
             placement_group_models=placement_group_models,
             instance_model=instance,
+            master_instance_model=master_instance,
         )
         offers = await get_create_instance_offers(
             project=instance.project,
@@ -594,11 +607,15 @@
             continue
         compute = backend.compute()
         assert isinstance(compute, ComputeWithCreateInstanceSupport)
-        instance_offer = _get_instance_offer_for_instance(instance_offer, instance)
+        instance_offer = _get_instance_offer_for_instance(
+            instance_offer=instance_offer,
+            instance=instance,
+            master_instance=master_instance,
+        )
         if (
             instance.fleet
             and is_cloud_cluster(instance.fleet)
-            and is_fleet_master_instance(instance)
+            and instance.id == master_instance.id
             and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
             and isinstance(compute, ComputeWithPlacementGroupSupport)
             and (
@@ -667,7 +684,7 @@
                 "instance_status": InstanceStatus.PROVISIONING.value,
             },
         )
-        if instance.fleet_id and is_fleet_master_instance(instance):
+        if instance.fleet_id and instance.id == master_instance.id:
            # Clean up placement groups that did not end up being used.
            # Flush to update still uncommitted placement groups.
            await session.flush()
@@ -685,7 +702,7 @@
         InstanceTerminationReason.NO_OFFERS,
         "All offers failed" if offers else "No offers found",
     )
-    if instance.fleet and is_fleet_master_instance(instance) and is_cloud_cluster(instance.fleet):
+    if instance.fleet and instance.id == master_instance.id and is_cloud_cluster(instance.fleet):
         # Do not attempt to deploy other instances, as they won't determine the correct cluster
         # backend, region, and placement group without a successfully deployed master instance
         for sibling_instance in instance.fleet.instances:
@@ -694,6 +711,20 @@
             _mark_terminated(sibling_instance, InstanceTerminationReason.MASTER_FAILED)
 
 
+async def _get_fleet_master_instance(
+    session: AsyncSession, instance: InstanceModel
+) -> InstanceModel:
+    # The "master" fleet instance is relevant for cloud clusters only:
+    # it can be any fixed instance that is chosen to be provisioned first.
+    res = await session.execute(
+        select(InstanceModel)
+        .where(InstanceModel.fleet_id == instance.fleet_id)
+        .order_by(InstanceModel.instance_num, InstanceModel.created_at)
+        .limit(1)
+    )
+    return res.scalar_one()
+
+
 def _mark_terminated(
     instance: InstanceModel,
     termination_reason: InstanceTerminationReason,
@@ -1182,15 +1213,17 @@ def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime:
     return instance.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION
 
 
-def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
+def _need_to_wait_fleet_provisioning(
+    instance: InstanceModel, master_instance: InstanceModel
+) -> bool:
     # Cluster cloud instances should wait for the first fleet instance to be provisioned
     # so that they are provisioned in the same backend/region
     if instance.fleet is None:
         return False
     if (
-        is_fleet_master_instance(instance)
-        or instance.fleet.instances[0].job_provisioning_data is not None
-        or instance.fleet.instances[0].status == InstanceStatus.TERMINATED
+        instance.id == master_instance.id
+        or master_instance.job_provisioning_data is not None
+        or master_instance.status == InstanceStatus.TERMINATED
     ):
         return False
     return is_cloud_cluster(instance.fleet)
@@ -1199,13 +1232,13 @@ def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
 def _get_instance_offer_for_instance(
     instance_offer: InstanceOfferWithAvailability,
     instance: InstanceModel,
+    master_instance: InstanceModel,
 ) -> InstanceOfferWithAvailability:
     if instance.fleet is None:
         return instance_offer
     fleet = fleet_model_to_fleet(instance.fleet)
-    master_instance = instance.fleet.instances[0]
-    master_job_provisioning_data = get_instance_provisioning_data(master_instance)
     if fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER:
+        master_job_provisioning_data = get_instance_provisioning_data(master_instance)
         return get_instance_offer_with_restricted_az(
             instance_offer=instance_offer,
             master_job_provisioning_data=master_job_provisioning_data,
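
With deleted instances now filtered out of `FleetModel.instances`, `instances[0]` no longer reliably identifies the fleet's master, so the patch resolves the master once per `_create_instance` call with a dedicated query ordered by `(instance_num, created_at)`. A toy illustration of the selection rule — the dataclass below is a hypothetical stand-in, not dstack's `InstanceModel`:

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class Instance:
    id: int
    instance_num: int
    created_at: datetime


def pick_master(instances: list[Instance]) -> Instance:
    # Same rule as the SQL: ORDER BY instance_num, created_at LIMIT 1.
    return min(instances, key=lambda i: (i.instance_num, i.created_at))


now = datetime.now(timezone.utc)
fleet = [
    Instance(id=3, instance_num=1, created_at=now),
    Instance(id=7, instance_num=0, created_at=now),  # master, even though it is not instances[0]
]
assert pick_master(fleet).id == 7
```

Because the master is pinned by a total order on stored columns, the choice is deterministic across workers and independent of whichever rows happen to be loaded into the relationship collection.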
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index b4397b95e0..ad42e7ed40 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -4,7 +4,7 @@
 
 from sqlalchemy import and_, func, or_, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only
+from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only, with_loader_criteria
 
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
@@ -111,7 +111,15 @@ async def _process_next_run():
                     ),
                 ),
             )
-            .options(joinedload(RunModel.jobs).load_only(JobModel.id))
+            .options(
+                joinedload(RunModel.jobs).load_only(JobModel.id),
+                # No need to lock finished jobs
+                with_loader_criteria(
+                    JobModel,
+                    JobModel.status.not_in(JobStatus.finished_statuses()),
+                    include_aliases=True,
+                ),
+            )
             .options(load_only(RunModel.id))
             .order_by(RunModel.last_processed_at.asc())
             .limit(1)
@@ -126,12 +134,20 @@
                 JobModel.run_id == run_model.id,
                 JobModel.id.not_in(job_lockset),
             )
+            .options(
+                load_only(JobModel.id),
+                with_loader_criteria(
+                    JobModel,
+                    JobModel.status.not_in(JobStatus.finished_statuses()),
+                    include_aliases=True,
+                ),
+            )
             .order_by(JobModel.id)  # take locks in order
             .with_for_update(skip_locked=True, key_share=True)
         )
         job_models = res.scalars().all()
         if len(run_model.jobs) != len(job_models):
-            # Some jobs are locked
+            # Some jobs are locked or there was a non-repeatable read
             return
         job_ids = [j.id for j in run_model.jobs]
         run_lockset.add(run_model.id)
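
The run-processing hunks narrow both the eager load and the lock query to unfinished jobs, and the updated comment acknowledges that the two result sets can now disagree for a second reason: a job may reach a finished status between the SELECT that populates `run_model.jobs` and the locking SELECT, i.e. a non-repeatable read. The `skip_locked=True` flag makes already-locked rows silently disappear from the result instead of blocking, which is why a shorter list means "skip this run". A self-contained sketch of what the lock clause compiles to on Postgres, using a toy table rather than dstack's schema:

```python
from sqlalchemy import Column, MetaData, Table, Uuid, select
from sqlalchemy.dialects import postgresql

metadata = MetaData()
jobs = Table("jobs", metadata, Column("id", Uuid, primary_key=True))

stmt = (
    select(jobs.c.id)
    .order_by(jobs.c.id)  # take row locks in a fixed order to avoid deadlocks
    .with_for_update(skip_locked=True, key_share=True)
)
print(stmt.compile(dialect=postgresql.dialect()))
# Prints roughly:
#   SELECT jobs.id FROM jobs ORDER BY jobs.id FOR NO KEY UPDATE SKIP LOCKED
```

`key_share=True` (without `read=True`) yields `FOR NO KEY UPDATE`, a weaker lock than plain `FOR UPDATE` that still excludes concurrent writers but does not block inserts of rows referencing the locked keys.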
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index d1d86c41aa..e132f83a49 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -7,7 +7,14 @@
 
 from sqlalchemy import func, or_, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import contains_eager, joinedload, load_only, noload, selectinload
+from sqlalchemy.orm import (
+    contains_eager,
+    joinedload,
+    load_only,
+    noload,
+    selectinload,
+    with_loader_criteria,
+)
 
 from dstack._internal.core.backends.base.backend import Backend
 from dstack._internal.core.backends.base.compute import (
@@ -213,7 +220,12 @@ async def _process_submitted_job(
         select(JobModel)
         .where(JobModel.id == job_model.id)
         .options(joinedload(JobModel.instance))
-        .options(joinedload(JobModel.fleet).joinedload(FleetModel.instances))
+        .options(
+            joinedload(JobModel.fleet).joinedload(FleetModel.instances),
+            with_loader_criteria(
+                InstanceModel, InstanceModel.deleted == False, include_aliases=True
+            ),
+        )
     )
     job_model = res.unique().scalar_one()
     res = await session.execute(
@@ -221,7 +233,12 @@
         .where(RunModel.id == job_model.run_id)
         .options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
         .options(joinedload(RunModel.user).load_only(UserModel.name))
-        .options(joinedload(RunModel.fleet).joinedload(FleetModel.instances))
+        .options(
+            joinedload(RunModel.fleet).joinedload(FleetModel.instances),
+            with_loader_criteria(
+                InstanceModel, InstanceModel.deleted == False, include_aliases=True
+            ),
+        )
     )
     run_model = res.unique().scalar_one()
     logger.debug("%s: provisioning has started", fmt(job_model))
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index e347829fa4..95ae519d07 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -728,10 +728,6 @@ def is_cloud_cluster(fleet_model: FleetModel) -> bool:
     )
 
 
-def is_fleet_master_instance(instance: InstanceModel) -> bool:
-    return instance.fleet is not None and instance.id == instance.fleet.instances[0].id
-
-
 def get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
     profile = fleet_spec.merged_profile
     requirements = Requirements(
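
The removed `is_fleet_master_instance` helper (and its private copy dropped from placement.py below) compared an instance against `instance.fleet.instances[0]`, which is exactly the kind of check that breaks once the relationship collection is filtered. A toy illustration of the failure mode, not dstack code:

```python
# Once deleted rows are filtered out of the relationship, index 0 can point
# at a different instance on different loads, so "am I the master?" flips.
fleet_rows = [
    {"id": "a", "deleted": True},
    {"id": "b", "deleted": False},
]
unfiltered = fleet_rows
filtered = [r for r in fleet_rows if not r["deleted"]]
assert unfiltered[0]["id"] == "a"
assert filtered[0]["id"] == "b"  # instances[0] changed identity after filtering
```

Hence the callers above now compare against an explicitly resolved `master_instance.id` instead.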
diff --git a/src/dstack/_internal/server/services/placement.py b/src/dstack/_internal/server/services/placement.py
index f0c63f891c..d0c045cdc9 100644
--- a/src/dstack/_internal/server/services/placement.py
+++ b/src/dstack/_internal/server/services/placement.py
@@ -98,9 +98,10 @@ async def schedule_fleet_placement_groups_deletion(
 def get_placement_group_model_for_instance(
     placement_group_models: list[PlacementGroupModel],
     instance_model: InstanceModel,
+    master_instance_model: InstanceModel,
 ) -> Optional[PlacementGroupModel]:
     placement_group_model = None
-    if not _is_fleet_master_instance(instance_model):
+    if instance_model.id != master_instance_model.id:
         if placement_group_models:
             placement_group_model = placement_group_models[0]
         if len(placement_group_models) > 1:
@@ -231,7 +232,3 @@ async def create_placement_group(
     )
     placement_group_model.provisioning_data = pgpd.json()
     return placement_group_model
-
-
-def _is_fleet_master_instance(instance: InstanceModel) -> bool:
-    return instance.fleet is not None and instance.id == instance.fleet.instances[0].id
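
Under the new signature, callers resolve the master once and pass it in, so placement-group selection no longer depends on relationship load order. A condensed, self-contained stand-in for the patched contract — toy dataclasses, and the `len(...) > 1` branch (presumably logging) is omitted:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class InstanceModel:
    id: str


@dataclass
class PlacementGroupModel:
    name: str


def get_placement_group_model_for_instance(
    placement_group_models: list[PlacementGroupModel],
    instance_model: InstanceModel,
    master_instance_model: InstanceModel,
) -> Optional[PlacementGroupModel]:
    # Non-master instances reuse the fleet's existing placement group (if any);
    # for the master, None is returned so the caller may create a fresh one.
    if instance_model.id != master_instance_model.id and placement_group_models:
        return placement_group_models[0]
    return None


master = InstanceModel(id="i-0")
worker = InstanceModel(id="i-1")
pgs = [PlacementGroupModel(name="pg-1")]
assert get_placement_group_model_for_instance(pgs, worker, master) is pgs[0]
assert get_placement_group_model_for_instance(pgs, master, master) is None
```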