9 changes: 8 additions & 1 deletion src/dstack/_internal/server/background/__init__.py
@@ -42,7 +42,14 @@ def get_scheduler() -> AsyncIOScheduler:


def start_background_tasks() -> AsyncIOScheduler:
# We try to process as many resources as possible without exhausting DB connections.
# Background processing is implemented via in-memory locks on SQLite
# and SELECT FOR UPDATE on Postgres. Locks may be held for a long time.
# This is currently the main bottleneck for scaling dstack processing
# as processing more resources requires more DB connections.
# TODO: Make background processing efficient by committing locks to DB
# and processing outside of DB transactions.
#
# Now we just try to process as many resources as possible without exhausting DB connections.
#
# Quick tasks can process multiple resources per transaction.
# Potentially long tasks process one resource per transaction
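As an aside, here is a minimal sketch of the Postgres half of the locking scheme that comment describes, built around the same with_for_update(skip_locked=True, key_share=True) call the batch tasks below use. ResourceModel is a hypothetical stand-in for FleetModel/InstanceModel/JobModel:

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

async def lock_next_batch(session: AsyncSession, batch_size: int) -> list:
    # On Postgres this renders SELECT ... FOR NO KEY UPDATE SKIP LOCKED:
    # each worker locks a disjoint batch, and rows already locked by
    # another transaction are skipped rather than waited on. On SQLite,
    # which has no row locks, dstack uses in-memory locksets instead.
    res = await session.execute(
        select(ResourceModel)  # hypothetical model with last_processed_at
        .order_by(ResourceModel.last_processed_at.asc())
        .limit(batch_size)
        .with_for_update(skip_locked=True, key_share=True)
    )
    return list(res.scalars().all())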
11 changes: 9 additions & 2 deletions src/dstack/_internal/server/background/tasks/process_fleets.py
@@ -5,7 +5,7 @@

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, load_only, selectinload
from sqlalchemy.orm import joinedload, load_only, selectinload, with_loader_criteria

from dstack._internal.core.models.fleets import FleetSpec, FleetStatus
from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason
@@ -60,6 +60,9 @@ async def process_fleets():
.options(
load_only(FleetModel.id, FleetModel.name),
selectinload(FleetModel.instances).load_only(InstanceModel.id),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
),
)
.order_by(FleetModel.last_processed_at.asc())
.limit(BATCH_SIZE)
@@ -72,6 +75,7 @@
.where(
InstanceModel.id.not_in(instance_lockset),
InstanceModel.fleet_id.in_(fleet_ids),
InstanceModel.deleted == False,
)
.options(load_only(InstanceModel.id, InstanceModel.fleet_id))
.order_by(InstanceModel.id)
@@ -113,8 +117,11 @@ async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel])
.where(FleetModel.id.in_(fleet_ids))
.options(
joinedload(FleetModel.instances).joinedload(InstanceModel.jobs).load_only(JobModel.id),
joinedload(FleetModel.project),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
),
)
.options(joinedload(FleetModel.project))
.options(joinedload(FleetModel.runs).load_only(RunModel.status))
.execution_options(populate_existing=True)
)
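Since with_loader_criteria recurs throughout this PR, a note on what it does: it attaches an extra WHERE criterion to every load of the given entity within the statement, including eager-loaded relationship collections, which a plain .where() on the outer query cannot reach. A minimal sketch, assuming toy Fleet/Instance models with a boolean deleted soft-delete column like dstack's:

from sqlalchemy import select
from sqlalchemy.orm import joinedload, with_loader_criteria

stmt = (
    select(Fleet)
    .options(
        joinedload(Fleet.instances),
        # Injects "AND instances.deleted = false" into the eager load,
        # so fleet.instances comes back pre-filtered; a .where() clause
        # would filter the fleets, not the loaded collections.
        with_loader_criteria(Instance, Instance.deleted == False, include_aliases=True),
    )
)
fleets = (await session.execute(stmt)).unique().scalars().all()
assert all(not i.deleted for f in fleets for i in f.instances)

include_aliases=True makes the criterion apply to aliased() forms of the entity as well, which is why the PR passes it at every call site.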
65 changes: 49 additions & 16 deletions src/dstack/_internal/server/background/tasks/process_instances.py
@@ -11,7 +11,7 @@
from pydantic import ValidationError
from sqlalchemy import and_, delete, func, not_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, with_loader_criteria

from dstack._internal import settings
from dstack._internal.core.backends.base.compute import (
@@ -79,7 +79,6 @@
fleet_model_to_fleet,
get_create_instance_offers,
is_cloud_cluster,
is_fleet_master_instance,
)
from dstack._internal.server.services.instances import (
get_instance_configuration,
@@ -218,7 +217,12 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
.where(InstanceModel.id == instance.id)
.options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
.options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
.options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
.options(
joinedload(InstanceModel.fleet).joinedload(FleetModel.instances),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
),
)
.execution_options(populate_existing=True)
)
instance = res.unique().scalar_one()
@@ -228,7 +232,12 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
.where(InstanceModel.id == instance.id)
.options(joinedload(InstanceModel.project))
.options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
.options(joinedload(InstanceModel.fleet).joinedload(FleetModel.instances))
.options(
joinedload(InstanceModel.fleet).joinedload(FleetModel.instances),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
),
)
.execution_options(populate_existing=True)
)
instance = res.unique().scalar_one()
@@ -543,8 +552,11 @@ def _deploy_instance(


async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
if _need_to_wait_fleet_provisioning(instance):
logger.debug("Waiting for the first instance in the fleet to be provisioned")
master_instance = await _get_fleet_master_instance(session, instance)
if _need_to_wait_fleet_provisioning(instance, master_instance):
logger.debug(
"%s: waiting for the first instance in the fleet to be provisioned", fmt(instance)
)
return

try:
@@ -576,6 +588,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
placement_group_model = get_placement_group_model_for_instance(
placement_group_models=placement_group_models,
instance_model=instance,
master_instance_model=master_instance,
)
offers = await get_create_instance_offers(
project=instance.project,
@@ -594,11 +607,15 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
continue
compute = backend.compute()
assert isinstance(compute, ComputeWithCreateInstanceSupport)
instance_offer = _get_instance_offer_for_instance(instance_offer, instance)
instance_offer = _get_instance_offer_for_instance(
instance_offer=instance_offer,
instance=instance,
master_instance=master_instance,
)
if (
instance.fleet
and is_cloud_cluster(instance.fleet)
and is_fleet_master_instance(instance)
and instance.id == master_instance.id
and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
and isinstance(compute, ComputeWithPlacementGroupSupport)
and (
@@ -667,7 +684,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
"instance_status": InstanceStatus.PROVISIONING.value,
},
)
if instance.fleet_id and is_fleet_master_instance(instance):
if instance.fleet_id and instance.id == master_instance.id:
# Clean up placement groups that did not end up being used.
# Flush to update still uncommitted placement groups.
await session.flush()
@@ -685,7 +702,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
InstanceTerminationReason.NO_OFFERS,
"All offers failed" if offers else "No offers found",
)
if instance.fleet and is_fleet_master_instance(instance) and is_cloud_cluster(instance.fleet):
if instance.fleet and instance.id == master_instance.id and is_cloud_cluster(instance.fleet):
# Do not attempt to deploy other instances, as they won't determine the correct cluster
# backend, region, and placement group without a successfully deployed master instance
for sibling_instance in instance.fleet.instances:
@@ -694,6 +711,20 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> None:
_mark_terminated(sibling_instance, InstanceTerminationReason.MASTER_FAILED)


async def _get_fleet_master_instance(
session: AsyncSession, instance: InstanceModel
) -> InstanceModel:
# The "master" fleet instance is relevant for cloud clusters only:
# it can be any fixed instance that is chosen to be provisioned first.
res = await session.execute(
select(InstanceModel)
.where(InstanceModel.fleet_id == instance.fleet_id)
.order_by(InstanceModel.instance_num, InstanceModel.created_at)
.limit(1)
)
return res.scalar_one()


def _mark_terminated(
instance: InstanceModel,
termination_reason: InstanceTerminationReason,
@@ -1182,15 +1213,17 @@ def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime:
return instance.first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION


def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
def _need_to_wait_fleet_provisioning(
instance: InstanceModel, master_instance: InstanceModel
) -> bool:
# Cluster cloud instances should wait for the first fleet instance to be provisioned
# so that they are provisioned in the same backend/region
if instance.fleet is None:
return False
if (
is_fleet_master_instance(instance)
or instance.fleet.instances[0].job_provisioning_data is not None
or instance.fleet.instances[0].status == InstanceStatus.TERMINATED
instance.id == master_instance.id
or master_instance.job_provisioning_data is not None
or master_instance.status == InstanceStatus.TERMINATED
):
return False
return is_cloud_cluster(instance.fleet)
@@ -1199,13 +1232,13 @@ def _need_to_wait_fleet_provisioning(instance: InstanceModel) -> bool:
def _get_instance_offer_for_instance(
instance_offer: InstanceOfferWithAvailability,
instance: InstanceModel,
master_instance: InstanceModel,
) -> InstanceOfferWithAvailability:
if instance.fleet is None:
return instance_offer
fleet = fleet_model_to_fleet(instance.fleet)
master_instance = instance.fleet.instances[0]
master_job_provisioning_data = get_instance_provisioning_data(master_instance)
if fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER:
master_job_provisioning_data = get_instance_provisioning_data(master_instance)
return get_instance_offer_with_restricted_az(
instance_offer=instance_offer,
master_job_provisioning_data=master_job_provisioning_data,
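One design note on this file: _get_fleet_master_instance replaces the positional fleet.instances[0] convention (see the removal in fleets.py below). Position zero depends on whatever order the relationship happened to be loaded in, and now that with_loader_criteria filters deleted rows out of fleet.instances, the first element could silently change; the explicit query pins the choice to a fixed ordering instead. Roughly:

# Old convention (removed in this PR): positional and load-order dependent.
master = instance.fleet.instances[0]

# New approach: deterministic regardless of what the relationship load
# produced (condensed from _get_fleet_master_instance above).
res = await session.execute(
    select(InstanceModel)
    .where(InstanceModel.fleet_id == instance.fleet_id)
    .order_by(InstanceModel.instance_num, InstanceModel.created_at)
    .limit(1)
)
master = res.scalar_one()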
22 changes: 19 additions & 3 deletions src/dstack/_internal/server/background/tasks/process_runs.py
@@ -4,7 +4,7 @@

from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only, with_loader_criteria

import dstack._internal.server.services.services.autoscalers as autoscalers
from dstack._internal.core.errors import ServerError
@@ -111,7 +111,15 @@ async def _process_next_run():
),
),
)
.options(joinedload(RunModel.jobs).load_only(JobModel.id))
.options(
joinedload(RunModel.jobs).load_only(JobModel.id),
# No need to lock finished jobs
with_loader_criteria(
JobModel,
JobModel.status.not_in(JobStatus.finished_statuses()),
include_aliases=True,
),
)
.options(load_only(RunModel.id))
.order_by(RunModel.last_processed_at.asc())
.limit(1)
@@ -126,12 +134,20 @@
JobModel.run_id == run_model.id,
JobModel.id.not_in(job_lockset),
)
.options(
load_only(JobModel.id),
with_loader_criteria(
JobModel,
JobModel.status.not_in(JobStatus.finished_statuses()),
include_aliases=True,
),
)
.order_by(JobModel.id) # take locks in order
.with_for_update(skip_locked=True, key_share=True)
)
job_models = res.scalars().all()
if len(run_model.jobs) != len(job_models):
# Some jobs are locked
# Some jobs are locked or there was a non-repeatable read
return
job_ids = [j.id for j in run_model.jobs]
run_lockset.add(run_model.id)
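The widened comment ("or there was a non-repeatable read") covers the case where the job set changes between the first SELECT, which loads run_model.jobs under the loader criteria, and the second SELECT ... FOR UPDATE: under READ COMMITTED the two reads can disagree even when skip_locked dropped nothing. Comparing the lengths folds both causes into one conservative outcome. Annotated, the guard amounts to:

job_models = res.scalars().all()  # the rows actually locked just now
if len(run_model.jobs) != len(job_models):
    # Fewer (or different) rows than the earlier read: either another
    # worker holds some of the locks (skip_locked dropped them) or the
    # jobs changed in between (non-repeatable read). Either way, skip
    # this run rather than process an inconsistent job set; a later
    # tick will pick it up again.
    return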
src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -7,7 +7,14 @@

from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload, load_only, noload, selectinload
from sqlalchemy.orm import (
contains_eager,
joinedload,
load_only,
noload,
selectinload,
with_loader_criteria,
)

from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.base.compute import (
@@ -213,15 +220,25 @@ async def _process_submitted_job(
select(JobModel)
.where(JobModel.id == job_model.id)
.options(joinedload(JobModel.instance))
.options(joinedload(JobModel.fleet).joinedload(FleetModel.instances))
.options(
joinedload(JobModel.fleet).joinedload(FleetModel.instances),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
),
)
)
job_model = res.unique().scalar_one()
res = await session.execute(
select(RunModel)
.where(RunModel.id == job_model.run_id)
.options(joinedload(RunModel.project).joinedload(ProjectModel.backends))
.options(joinedload(RunModel.user).load_only(UserModel.name))
.options(joinedload(RunModel.fleet).joinedload(FleetModel.instances))
.options(
joinedload(RunModel.fleet).joinedload(FleetModel.instances),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
),
)
)
run_model = res.unique().scalar_one()
logger.debug("%s: provisioning has started", fmt(job_model))
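The identical with_loader_criteria(InstanceModel, InstanceModel.deleted == False, include_aliases=True) option is now inlined at six call sites across three task modules. If it keeps spreading, a shared helper would keep the call sites uniform; a hypothetical sketch, not part of this PR:

from sqlalchemy.orm import with_loader_criteria

def _not_deleted_instances():
    # Hypothetical helper: one place to define the soft-delete filter
    # applied to InstanceModel wherever a statement loads it.
    return with_loader_criteria(
        InstanceModel, InstanceModel.deleted == False, include_aliases=True
    )

# Usage: select(FleetModel).options(joinedload(FleetModel.instances), _not_deleted_instances())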
4 changes: 0 additions & 4 deletions src/dstack/_internal/server/services/fleets.py
@@ -728,10 +728,6 @@ def is_cloud_cluster(fleet_model: FleetModel) -> bool:
)


def is_fleet_master_instance(instance: InstanceModel) -> bool:
return instance.fleet is not None and instance.id == instance.fleet.instances[0].id


def get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
profile = fleet_spec.merged_profile
requirements = Requirements(
7 changes: 2 additions & 5 deletions src/dstack/_internal/server/services/placement.py
@@ -98,9 +98,10 @@ async def schedule_fleet_placement_groups_deletion(
def get_placement_group_model_for_instance(
placement_group_models: list[PlacementGroupModel],
instance_model: InstanceModel,
master_instance_model: InstanceModel,
) -> Optional[PlacementGroupModel]:
placement_group_model = None
if not _is_fleet_master_instance(instance_model):
if instance_model.id != master_instance_model.id:
if placement_group_models:
placement_group_model = placement_group_models[0]
if len(placement_group_models) > 1:
@@ -231,7 +232,3 @@ async def create_placement_group(
)
placement_group_model.provisioning_data = pgpd.json()
return placement_group_model


def _is_fleet_master_instance(instance: InstanceModel) -> bool:
return instance.fleet is not None and instance.id == instance.fleet.instances[0].id