From 8e747c0786713f5412b68cc009cfcac799362315 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 14 Jan 2026 16:14:55 +0500 Subject: [PATCH 1/4] Optimize process_running_jobs select --- .../background/tasks/process_running_jobs.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 341b47a38b..353ef92573 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -5,9 +5,9 @@ from datetime import timedelta from typing import Dict, List, Optional -from sqlalchemy import select +from sqlalchemy import and_, func, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, load_only +from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only from dstack._internal import settings from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT @@ -148,14 +148,37 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): .execution_options(populate_existing=True) ) job_model = res.unique().scalar_one() + # Select only latest submissions for every job. + latest_submissions_sq = ( + select( + JobModel.run_id.label("run_id"), + JobModel.replica_num.label("replica_num"), + JobModel.job_num.label("job_num"), + func.max(JobModel.submission_num).label("max_submission_num"), + ) + .where(JobModel.run_id == job_model.run_id) + .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num) + .subquery() + ) + job_alias = aliased(JobModel) res = await session.execute( select(RunModel) .where(RunModel.id == job_model.run_id) + .join(job_alias, job_alias.run_id == RunModel.id) + .join( + latest_submissions_sq, + onclause=and_( + job_alias.run_id == latest_submissions_sq.c.run_id, + job_alias.replica_num == latest_submissions_sq.c.replica_num, + job_alias.job_num == latest_submissions_sq.c.job_num, + job_alias.submission_num == latest_submissions_sq.c.max_submission_num, + ), + ) .options(joinedload(RunModel.project)) .options(joinedload(RunModel.user)) .options(joinedload(RunModel.repo)) .options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name)) - .options(joinedload(RunModel.jobs)) + .options(contains_eager(RunModel.jobs, alias=job_alias)) ) run_model = res.unique().scalar_one() repo_model = run_model.repo From e94fa835ed5488e4ff2dcdb9d165d4e99d8e4087 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 15 Jan 2026 11:16:21 +0500 Subject: [PATCH 2/4] Optimize process_runs select --- .../background/tasks/process_running_jobs.py | 91 ++++++++++--------- .../server/background/tasks/process_runs.py | 61 +++++++++---- 2 files changed, 92 insertions(+), 60 deletions(-) diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 353ef92573..f5ca6c61ae 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -139,48 +139,8 @@ async def _process_next_running_job(): async def _process_running_job(session: AsyncSession, job_model: JobModel): - # Refetch to load related attributes. - res = await session.execute( - select(JobModel) - .where(JobModel.id == job_model.id) - .options(joinedload(JobModel.instance).joinedload(InstanceModel.project)) - .options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak)) - .execution_options(populate_existing=True) - ) - job_model = res.unique().scalar_one() - # Select only latest submissions for every job. - latest_submissions_sq = ( - select( - JobModel.run_id.label("run_id"), - JobModel.replica_num.label("replica_num"), - JobModel.job_num.label("job_num"), - func.max(JobModel.submission_num).label("max_submission_num"), - ) - .where(JobModel.run_id == job_model.run_id) - .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num) - .subquery() - ) - job_alias = aliased(JobModel) - res = await session.execute( - select(RunModel) - .where(RunModel.id == job_model.run_id) - .join(job_alias, job_alias.run_id == RunModel.id) - .join( - latest_submissions_sq, - onclause=and_( - job_alias.run_id == latest_submissions_sq.c.run_id, - job_alias.replica_num == latest_submissions_sq.c.replica_num, - job_alias.job_num == latest_submissions_sq.c.job_num, - job_alias.submission_num == latest_submissions_sq.c.max_submission_num, - ), - ) - .options(joinedload(RunModel.project)) - .options(joinedload(RunModel.user)) - .options(joinedload(RunModel.repo)) - .options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name)) - .options(contains_eager(RunModel.jobs, alias=job_alias)) - ) - run_model = res.unique().scalar_one() + job_model = await _refetch_job_model(session, job_model) + run_model = await _fetch_run_model(session, job_model.run_id) repo_model = run_model.repo project = run_model.project run = run_model_to_run(run_model, include_sensitive=True) @@ -444,6 +404,53 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): await session.commit() +async def _refetch_job_model(session: AsyncSession, job_model: JobModel) -> JobModel: + res = await session.execute( + select(JobModel) + .where(JobModel.id == job_model.id) + .options(joinedload(JobModel.instance).joinedload(InstanceModel.project)) + .options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak)) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one() + + +async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel: + # Select only latest submissions for every job. + latest_submissions_sq = ( + select( + JobModel.run_id.label("run_id"), + JobModel.replica_num.label("replica_num"), + JobModel.job_num.label("job_num"), + func.max(JobModel.submission_num).label("max_submission_num"), + ) + .where(JobModel.run_id == run_id) + .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num) + .subquery() + ) + job_alias = aliased(JobModel) + res = await session.execute( + select(RunModel) + .where(RunModel.id == run_id) + .join(job_alias, job_alias.run_id == RunModel.id) + .join( + latest_submissions_sq, + onclause=and_( + job_alias.run_id == latest_submissions_sq.c.run_id, + job_alias.replica_num == latest_submissions_sq.c.replica_num, + job_alias.job_num == latest_submissions_sq.c.job_num, + job_alias.submission_num == latest_submissions_sq.c.max_submission_num, + ), + ) + .options(joinedload(RunModel.project)) + .options(joinedload(RunModel.user)) + .options(joinedload(RunModel.repo)) + .options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name)) + .options(contains_eager(RunModel.jobs, alias=job_alias)) + ) + return res.unique().scalar_one() + + async def _wait_for_instance_provisioning_data(session: AsyncSession, job_model: JobModel): """ This function will be called until instance IP address appears diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index af2dcee8d8..56648ceb2a 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -2,9 +2,9 @@ import datetime from typing import List, Optional, Set, Tuple -from sqlalchemy import and_, or_, select +from sqlalchemy import and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, load_only, selectinload +from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only import dstack._internal.server.services.services.autoscalers as autoscalers from dstack._internal.core.errors import ServerError @@ -144,22 +144,7 @@ async def _process_next_run(): async def _process_run(session: AsyncSession, run_model: RunModel): - # Refetch to load related attributes. - res = await session.execute( - select(RunModel) - .where(RunModel.id == run_model.id) - .execution_options(populate_existing=True) - .options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name)) - .options(joinedload(RunModel.user).load_only(UserModel.name)) - .options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name)) - .options( - selectinload(RunModel.jobs) - .joinedload(JobModel.instance) - .load_only(InstanceModel.fleet_id) - ) - .execution_options(populate_existing=True) - ) - run_model = res.unique().scalar_one() + run_model = await _refetch_run_model(session, run_model) logger.debug("%s: processing run", fmt(run_model)) try: if run_model.status == RunStatus.PENDING: @@ -181,6 +166,46 @@ async def _process_run(session: AsyncSession, run_model: RunModel): await session.commit() +async def _refetch_run_model(session: AsyncSession, run_model: RunModel) -> RunModel: + # Select only latest submissions for every job. + latest_submissions_sq = ( + select( + JobModel.run_id.label("run_id"), + JobModel.replica_num.label("replica_num"), + JobModel.job_num.label("job_num"), + func.max(JobModel.submission_num).label("max_submission_num"), + ) + .where(JobModel.run_id == run_model.id) + .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num) + .subquery() + ) + job_alias = aliased(JobModel) + res = await session.execute( + select(RunModel) + .where(RunModel.id == run_model.id) + .outerjoin(latest_submissions_sq, latest_submissions_sq.c.run_id == RunModel.id) + .outerjoin( + job_alias, + onclause=and_( + job_alias.run_id == latest_submissions_sq.c.run_id, + job_alias.replica_num == latest_submissions_sq.c.replica_num, + job_alias.job_num == latest_submissions_sq.c.job_num, + job_alias.submission_num == latest_submissions_sq.c.max_submission_num, + ), + ) + .options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name)) + .options(joinedload(RunModel.user).load_only(UserModel.name)) + .options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name)) + .options( + contains_eager(RunModel.jobs, alias=job_alias) + .joinedload(JobModel.instance) + .load_only(InstanceModel.fleet_id) + ) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one() + + async def _process_pending_run(session: AsyncSession, run_model: RunModel): """Jobs are not created yet""" run = run_model_to_run(run_model) From 97735749aa7db38236bc7a38f2406d92c4eb5ff7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 15 Jan 2026 12:15:48 +0500 Subject: [PATCH 3/4] Add test_calculates_retry_duration_since_last_successful_submission --- .../background/tasks/test_process_runs.py | 44 +++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/tasks/test_process_runs.py index 81c1ef0026..b34b5e5cb7 100644 --- a/src/tests/_internal/server/background/tasks/test_process_runs.py +++ b/src/tests/_internal/server/background/tasks/test_process_runs.py @@ -1,6 +1,6 @@ import datetime from collections.abc import Iterable -from typing import Union, cast +from typing import Optional, Union, cast from unittest.mock import patch import pytest @@ -15,7 +15,7 @@ TaskConfiguration, ) from dstack._internal.core.models.instances import InstanceStatus -from dstack._internal.core.models.profiles import Profile, ProfileRetry, Schedule +from dstack._internal.core.models.profiles import Profile, ProfileRetry, RetryEvent, Schedule from dstack._internal.core.models.resources import Range from dstack._internal.core.models.runs import ( JobSpec, @@ -48,6 +48,7 @@ async def make_run( deployment_num: int = 0, image: str = "ubuntu:latest", probes: Iterable[ProbeConfig] = (), + retry: Optional[ProfileRetry] = None, ) -> RunModel: project = await create_project(session=session) user = await create_user(session=session) @@ -58,7 +59,7 @@ async def make_run( run_name = "test-run" profile = Profile( name="test-profile", - retry=True, + retry=retry or True, ) run_spec = get_run_spec( repo_id=repo.name, @@ -230,6 +231,43 @@ async def test_retry_running_to_failed(self, test_db, session: AsyncSession): assert run.status == RunStatus.TERMINATING assert run.termination_reason == RunTerminationReason.JOB_FAILED + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_calculates_retry_duration_since_last_successful_submission( + self, test_db, session: AsyncSession + ): + run = await make_run( + session, + status=RunStatus.RUNNING, + replicas=1, + retry=ProfileRetry(duration=300, on_events=[RetryEvent.NO_CAPACITY]), + ) + now = run.submitted_at + datetime.timedelta(minutes=10) + # Retry logic should look at this job and calculate retry duration since its last_processed_at. + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.EXECUTOR_ERROR, + last_processed_at=now - datetime.timedelta(minutes=4), + replica_num=0, + ) + await create_job( + session=session, + run=run, + status=JobStatus.FAILED, + termination_reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, + replica_num=0, + submission_num=1, + last_processed_at=now - datetime.timedelta(minutes=2), + job_provisioning_data=None, + ) + with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: + datetime_mock.return_value = now + await process_runs.process_runs() + await session.refresh(run) + assert run.status == RunStatus.PENDING + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_pending_to_submitted(self, test_db, session: AsyncSession): From 3676f002288b5dc3373a716385dd65f7017abc96 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 15 Jan 2026 12:34:54 +0500 Subject: [PATCH 4/4] Fix _should_retry_job --- .../server/background/tasks/process_runs.py | 47 ++++++++++++++----- .../background/tasks/test_process_runs.py | 1 + 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 56648ceb2a..b4397b95e0 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -33,6 +33,7 @@ get_job_specs_from_run_spec, group_jobs_by_replica_latest, is_master_job, + job_model_to_job_submission, switch_job_status, ) from dstack._internal.server.services.locking import get_locker @@ -319,7 +320,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel): and job_model.termination_reason not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN} ): - current_duration = _should_retry_job(run, job, job_model) + current_duration = await _should_retry_job(session, run, job, job_model) if current_duration is None: replica_statuses.add(RunStatus.FAILED) run_termination_reasons.add(RunTerminationReason.JOB_FAILED) @@ -577,19 +578,44 @@ def _has_out_of_date_replicas(run: RunModel) -> bool: return False -def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]: +async def _should_retry_job( + session: AsyncSession, + run: Run, + job: Job, + job_model: JobModel, +) -> Optional[datetime.timedelta]: """ Checks if the job should be retried. Returns the current duration of retrying if retry is enabled. + Retrying duration is calculated as the time since `last_processed_at` + of the latest provisioned submission. """ if job.job_spec.retry is None: return None last_provisioned_submission = None - for job_submission in reversed(job.job_submissions): - if job_submission.job_provisioning_data is not None: - last_provisioned_submission = job_submission - break + if len(job.job_submissions) > 0: + last_submission = job.job_submissions[-1] + if last_submission.job_provisioning_data is not None: + last_provisioned_submission = last_submission + else: + # The caller passes at most one latest submission in job.job_submissions, so check the db. + res = await session.execute( + select(JobModel) + .where( + JobModel.run_id == job_model.run_id, + JobModel.replica_num == job_model.replica_num, + JobModel.job_num == job_model.job_num, + JobModel.job_provisioning_data.is_not(None), + ) + .order_by(JobModel.last_processed_at.desc()) + .limit(1) + ) + last_provisioned_submission_model = res.scalar() + if last_provisioned_submission_model is not None: + last_provisioned_submission = job_model_to_job_submission( + last_provisioned_submission_model + ) if ( job_model.termination_reason is not None @@ -599,13 +625,10 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datet ): return common.get_current_datetime() - run.submitted_at - if last_provisioned_submission is None: - return None - if ( - last_provisioned_submission.termination_reason is not None - and JobTerminationReason(last_provisioned_submission.termination_reason).to_retry_event() - in job.job_spec.retry.on_events + job_model.termination_reason is not None + and job_model.termination_reason.to_retry_event() in job.job_spec.retry.on_events + and last_provisioned_submission is not None ): return common.get_current_datetime() - last_provisioned_submission.last_processed_at diff --git a/src/tests/_internal/server/background/tasks/test_process_runs.py b/src/tests/_internal/server/background/tasks/test_process_runs.py index b34b5e5cb7..46aaa9b48e 100644 --- a/src/tests/_internal/server/background/tasks/test_process_runs.py +++ b/src/tests/_internal/server/background/tasks/test_process_runs.py @@ -251,6 +251,7 @@ async def test_calculates_retry_duration_since_last_successful_submission( termination_reason=JobTerminationReason.EXECUTOR_ERROR, last_processed_at=now - datetime.timedelta(minutes=4), replica_num=0, + job_provisioning_data=get_job_provisioning_data(), ) await create_job( session=session,