src/dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -5,9 +5,9 @@
from datetime import timedelta
from typing import Dict, List, Optional

from sqlalchemy import select
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, load_only
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only

from dstack._internal import settings
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
@@ -139,25 +139,8 @@ async def _process_next_running_job():


async def _process_running_job(session: AsyncSession, job_model: JobModel):
# Refetch to load related attributes.
res = await session.execute(
select(JobModel)
.where(JobModel.id == job_model.id)
.options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
.options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak))
.execution_options(populate_existing=True)
)
job_model = res.unique().scalar_one()
res = await session.execute(
select(RunModel)
.where(RunModel.id == job_model.run_id)
.options(joinedload(RunModel.project))
.options(joinedload(RunModel.user))
.options(joinedload(RunModel.repo))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(joinedload(RunModel.jobs))
)
run_model = res.unique().scalar_one()
job_model = await _refetch_job_model(session, job_model)
run_model = await _fetch_run_model(session, job_model.run_id)
repo_model = run_model.repo
project = run_model.project
run = run_model_to_run(run_model, include_sensitive=True)
@@ -421,6 +404,53 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
await session.commit()


async def _refetch_job_model(session: AsyncSession, job_model: JobModel) -> JobModel:
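# Refetch to load related attributes; populate_existing makes the query
# refresh the already-loaded job_model instance in the identity map.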
res = await session.execute(
select(JobModel)
.where(JobModel.id == job_model.id)
.options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
.options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak))
.execution_options(populate_existing=True)
)
return res.unique().scalar_one()


async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel:
# Select only the latest submission for every job.
latest_submissions_sq = (
select(
JobModel.run_id.label("run_id"),
JobModel.replica_num.label("replica_num"),
JobModel.job_num.label("job_num"),
func.max(JobModel.submission_num).label("max_submission_num"),
)
.where(JobModel.run_id == run_id)
.group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num)
.subquery()
)
job_alias = aliased(JobModel)
res = await session.execute(
select(RunModel)
.where(RunModel.id == run_id)
.join(job_alias, job_alias.run_id == RunModel.id)
.join(
latest_submissions_sq,
onclause=and_(
job_alias.run_id == latest_submissions_sq.c.run_id,
job_alias.replica_num == latest_submissions_sq.c.replica_num,
job_alias.job_num == latest_submissions_sq.c.job_num,
job_alias.submission_num == latest_submissions_sq.c.max_submission_num,
),
)
.options(joinedload(RunModel.project))
.options(joinedload(RunModel.user))
.options(joinedload(RunModel.repo))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(contains_eager(RunModel.jobs, alias=job_alias))
)
return res.unique().scalar_one()
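
The subquery in _fetch_run_model is the classic groupwise-max pattern: take the highest submission_num per (run_id, replica_num, job_num) group, join it back against the job rows so only each job's latest submission survives, and let contains_eager populate RunModel.jobs from that join. Below is a minimal standalone sketch of the pattern using a hypothetical Submission model, not dstack's models; the same pattern reappears in _refetch_run_model in process_runs.py further down.

from sqlalchemy import and_, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, aliased, mapped_column

class Base(DeclarativeBase):
    pass

class Submission(Base):
    __tablename__ = "submissions"
    id: Mapped[int] = mapped_column(primary_key=True)
    job_num: Mapped[int] = mapped_column()
    submission_num: Mapped[int] = mapped_column()

# One row per job_num, carrying that group's highest submission_num.
latest_sq = (
    select(
        Submission.job_num.label("job_num"),
        func.max(Submission.submission_num).label("max_submission_num"),
    )
    .group_by(Submission.job_num)
    .subquery()
)

# Join back on the group key plus the max value: a submission row survives
# only if its submission_num equals its group's maximum.
sub = aliased(Submission)
stmt = select(sub).join(
    latest_sq,
    and_(
        sub.job_num == latest_sq.c.job_num,
        sub.submission_num == latest_sq.c.max_submission_num,
    ),
)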


async def _wait_for_instance_provisioning_data(session: AsyncSession, job_model: JobModel):
"""
This function will be called until the instance IP address appears
108 changes: 78 additions & 30 deletions src/dstack/_internal/server/background/tasks/process_runs.py
@@ -2,9 +2,9 @@
import datetime
from typing import List, Optional, Set, Tuple

from sqlalchemy import and_, or_, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, load_only, selectinload
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only

import dstack._internal.server.services.services.autoscalers as autoscalers
from dstack._internal.core.errors import ServerError
@@ -33,6 +33,7 @@
get_job_specs_from_run_spec,
group_jobs_by_replica_latest,
is_master_job,
job_model_to_job_submission,
switch_job_status,
)
from dstack._internal.server.services.locking import get_locker
@@ -144,22 +145,7 @@ async def _process_next_run():


async def _process_run(session: AsyncSession, run_model: RunModel):
# Refetch to load related attributes.
res = await session.execute(
select(RunModel)
.where(RunModel.id == run_model.id)
.execution_options(populate_existing=True)
.options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name))
.options(joinedload(RunModel.user).load_only(UserModel.name))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(
selectinload(RunModel.jobs)
.joinedload(JobModel.instance)
.load_only(InstanceModel.fleet_id)
)
.execution_options(populate_existing=True)
)
run_model = res.unique().scalar_one()
run_model = await _refetch_run_model(session, run_model)
logger.debug("%s: processing run", fmt(run_model))
try:
if run_model.status == RunStatus.PENDING:
@@ -181,6 +167,46 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
await session.commit()


async def _refetch_run_model(session: AsyncSession, run_model: RunModel) -> RunModel:
# Select only the latest submission for every job.
latest_submissions_sq = (
select(
JobModel.run_id.label("run_id"),
JobModel.replica_num.label("replica_num"),
JobModel.job_num.label("job_num"),
func.max(JobModel.submission_num).label("max_submission_num"),
)
.where(JobModel.run_id == run_model.id)
.group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num)
.subquery()
)
job_alias = aliased(JobModel)
res = await session.execute(
select(RunModel)
.where(RunModel.id == run_model.id)
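# Unlike the inner joins in _fetch_run_model, outer joins are used here,
# presumably so that runs whose jobs are not created yet (pending runs)
# still match the query.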
.outerjoin(latest_submissions_sq, latest_submissions_sq.c.run_id == RunModel.id)
.outerjoin(
job_alias,
onclause=and_(
job_alias.run_id == latest_submissions_sq.c.run_id,
job_alias.replica_num == latest_submissions_sq.c.replica_num,
job_alias.job_num == latest_submissions_sq.c.job_num,
job_alias.submission_num == latest_submissions_sq.c.max_submission_num,
),
)
.options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name))
.options(joinedload(RunModel.user).load_only(UserModel.name))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(
contains_eager(RunModel.jobs, alias=job_alias)
.joinedload(JobModel.instance)
.load_only(InstanceModel.fleet_id)
)
.execution_options(populate_existing=True)
)
return res.unique().scalar_one()


async def _process_pending_run(session: AsyncSession, run_model: RunModel):
"""Jobs are not created yet"""
run = run_model_to_run(run_model)
@@ -294,7 +320,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
and job_model.termination_reason
not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN}
):
current_duration = _should_retry_job(run, job, job_model)
current_duration = await _should_retry_job(session, run, job, job_model)
if current_duration is None:
replica_statuses.add(RunStatus.FAILED)
run_termination_reasons.add(RunTerminationReason.JOB_FAILED)
@@ -552,19 +578,44 @@ def _has_out_of_date_replicas(run: RunModel) -> bool:
return False


def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
async def _should_retry_job(
session: AsyncSession,
run: Run,
job: Job,
job_model: JobModel,
) -> Optional[datetime.timedelta]:
"""
Checks if the job should be retried.
Returns the current duration of retrying if retry is enabled.
Retrying duration is calculated as the time since `last_processed_at`
of the latest provisioned submission.
"""
if job.job_spec.retry is None:
return None

last_provisioned_submission = None
for job_submission in reversed(job.job_submissions):
if job_submission.job_provisioning_data is not None:
last_provisioned_submission = job_submission
break
if len(job.job_submissions) > 0:
last_submission = job.job_submissions[-1]
if last_submission.job_provisioning_data is not None:
last_provisioned_submission = last_submission
else:
# The caller passes at most one latest submission in job.job_submissions, so check the db.
res = await session.execute(
select(JobModel)
.where(
JobModel.run_id == job_model.run_id,
JobModel.replica_num == job_model.replica_num,
JobModel.job_num == job_model.job_num,
JobModel.job_provisioning_data.is_not(None),
)
.order_by(JobModel.last_processed_at.desc())
.limit(1)
)
last_provisioned_submission_model = res.scalar()
if last_provisioned_submission_model is not None:
last_provisioned_submission = job_model_to_job_submission(
last_provisioned_submission_model
)

if (
job_model.termination_reason is not None
@@ -574,13 +625,10 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
):
return common.get_current_datetime() - run.submitted_at

if last_provisioned_submission is None:
return None

if (
last_provisioned_submission.termination_reason is not None
and JobTerminationReason(last_provisioned_submission.termination_reason).to_retry_event()
in job.job_spec.retry.on_events
job_model.termination_reason is not None
and job_model.termination_reason.to_retry_event() in job.job_spec.retry.on_events
and last_provisioned_submission is not None
):
return common.get_current_datetime() - last_provisioned_submission.last_processed_at

45 changes: 42 additions & 3 deletions src/tests/_internal/server/background/tasks/test_process_runs.py
@@ -1,6 +1,6 @@
import datetime
from collections.abc import Iterable
from typing import Union, cast
from typing import Optional, Union, cast
from unittest.mock import patch

import pytest
@@ -15,7 +15,7 @@
TaskConfiguration,
)
from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.profiles import Profile, ProfileRetry, Schedule
from dstack._internal.core.models.profiles import Profile, ProfileRetry, RetryEvent, Schedule
from dstack._internal.core.models.resources import Range
from dstack._internal.core.models.runs import (
JobSpec,
@@ -48,6 +48,7 @@ async def make_run(
deployment_num: int = 0,
image: str = "ubuntu:latest",
probes: Iterable[ProbeConfig] = (),
retry: Optional[ProfileRetry] = None,
) -> RunModel:
project = await create_project(session=session)
user = await create_user(session=session)
@@ -58,7 +59,7 @@
run_name = "test-run"
profile = Profile(
name="test-profile",
retry=True,
retry=retry or True,
)
run_spec = get_run_spec(
repo_id=repo.name,
@@ -230,6 +231,44 @@ async def test_retry_running_to_failed(self, test_db, session: AsyncSession):
assert run.status == RunStatus.TERMINATING
assert run.termination_reason == RunTerminationReason.JOB_FAILED

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_calculates_retry_duration_since_last_successful_submission(
self, test_db, session: AsyncSession
):
run = await make_run(
session,
status=RunStatus.RUNNING,
replicas=1,
retry=ProfileRetry(duration=300, on_events=[RetryEvent.NO_CAPACITY]),
)
now = run.submitted_at + datetime.timedelta(minutes=10)
# Retry logic should look at this job and calculate retry duration since its last_processed_at.
await create_job(
session=session,
run=run,
status=JobStatus.FAILED,
termination_reason=JobTerminationReason.EXECUTOR_ERROR,
last_processed_at=now - datetime.timedelta(minutes=4),
replica_num=0,
job_provisioning_data=get_job_provisioning_data(),
)
await create_job(
session=session,
run=run,
status=JobStatus.FAILED,
termination_reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY,
replica_num=0,
submission_num=1,
last_processed_at=now - datetime.timedelta(minutes=2),
job_provisioning_data=None,
)
with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock:
datetime_mock.return_value = now
await process_runs.process_runs()
await session.refresh(run)
assert run.status == RunStatus.PENDING

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_pending_to_submitted(self, test_db, session: AsyncSession):
Expand Down