Skip to content

Commit a6a4a1d

Browse files
committed
Replica group spot_policy and reservation
Allow to define `spot_policy` and `reservation` per replica group in service configurations. ```yaml type: service image: my-image port: 80 replicas: - name: baseline reservation: my-reservation count: 1 - name: overflow spot_policy: auto count: 0..3 scaling: metric: rps target: 1 ```
1 parent 60bbb7e commit a6a4a1d

7 files changed

Lines changed: 190 additions & 4 deletions

File tree

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def run_job(
318318
user=run.user,
319319
ssh_keys=[SSHKey(public=project_ssh_public_key.strip())],
320320
volumes=volumes,
321-
reservation=run.run_spec.configuration.reservation,
321+
reservation=job.job_spec.requirements.reservation,
322322
tags=run.run_spec.merged_profile.tags,
323323
)
324324
instance_offer = instance_offer.copy()

src/dstack/_internal/core/compatibility/runs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
122122
replica_group_excludes["nvcc"] = True
123123
if all(g.privileged is None for g in replicas):
124124
replica_group_excludes["privileged"] = True
125+
if all(g.spot_policy is None for g in replicas):
126+
replica_group_excludes["spot_policy"] = True
127+
if all(g.reservation is None for g in replicas):
128+
replica_group_excludes["reservation"] = True
125129
if replica_group_excludes:
126130
configuration_excludes["replicas"] = {"__all__": replica_group_excludes}
127131

src/dstack/_internal/core/models/configurations.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from dstack._internal.core.models.profiles import (
2626
ProfileParams,
2727
ProfileParamsConfig,
28+
SpotPolicy,
2829
parse_duration,
2930
parse_off_duration,
3031
)
@@ -836,6 +837,24 @@ class ReplicaGroup(CoreModel):
836837
ResourcesSpec,
837838
Field(description="The resources requirements for replicas in this group"),
838839
] = ResourcesSpec()
840+
spot_policy: Annotated[
841+
Optional[SpotPolicy],
842+
Field(
843+
description=(
844+
"The policy for provisioning spot or on-demand instances for replicas in this group:"
845+
f" {list_enum_values_for_annotation(SpotPolicy)}"
846+
)
847+
),
848+
] = None
849+
reservation: Annotated[
850+
Optional[str],
851+
Field(
852+
description=(
853+
"The existing reservation to use for replicas in this group."
854+
" Supports AWS Capacity Reservations, AWS Capacity Blocks, and GCP reservations"
855+
)
856+
),
857+
] = None
839858

840859
commands: Annotated[
841860
CommandsList,
@@ -1144,7 +1163,7 @@ def validate_top_level_properties_with_replica_groups(cls, values):
11441163
@root_validator()
11451164
def validate_no_mixed_service_and_group_container_fields(cls, values):
11461165
"""
1147-
When replicas is a list (image, docker, privileged) may be set
1166+
When replicas is a list, certain fields may be set
11481167
at the service level OR in replica groups, never both. Mixing is
11491168
rejected — including partial mixing, where only some groups set a
11501169
field the service also sets — because it leaves precedence ambiguous.
@@ -1179,6 +1198,16 @@ def validate_no_mixed_service_and_group_container_fields(cls, values):
11791198
values.get("nvcc") is True,
11801199
lambda g: g.nvcc is not None,
11811200
),
1201+
(
1202+
"spot_policy",
1203+
values.get("spot_policy") is not None,
1204+
lambda g: g.spot_policy is not None,
1205+
),
1206+
(
1207+
"reservation",
1208+
values.get("reservation") is not None,
1209+
lambda g: g.reservation is not None,
1210+
),
11821211
]
11831212

11841213
for field, service_set, group_set in checks:

src/dstack/_internal/server/services/jobs/configurators/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def _default_max_duration(self) -> Optional[int]:
127127
def _spot_policy(self) -> SpotPolicy:
128128
pass
129129

130+
def _reservation(self) -> Optional[str]:
131+
return self.run_spec.merged_profile.reservation
132+
130133
@abstractmethod
131134
def _ports(self) -> List[PortMapping]:
132135
pass
@@ -334,7 +337,7 @@ def _requirements(self, jobs_per_replica: int) -> Requirements:
334337
resources=resources,
335338
max_price=self.run_spec.merged_profile.max_price,
336339
spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT),
337-
reservation=self.run_spec.merged_profile.reservation,
340+
reservation=self._reservation(),
338341
multinode=jobs_per_replica > 1,
339342
backend_options=self.run_spec.merged_profile.backend_options,
340343
)

src/dstack/_internal/server/services/jobs/configurators/service.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,16 @@ def _default_max_duration(self) -> Optional[int]:
113113
return None
114114

115115
def _spot_policy(self) -> SpotPolicy:
116+
group = self._current_replica_group()
117+
if group is not None and group.spot_policy is not None:
118+
return group.spot_policy
116119
return self.run_spec.merged_profile.spot_policy or SpotPolicy.ONDEMAND
117120

121+
def _reservation(self) -> Optional[str]:
122+
group = self._current_replica_group()
123+
if group is not None and group.reservation is not None:
124+
return group.reservation
125+
return super()._reservation()
126+
118127
def _ports(self) -> List[PortMapping]:
119128
return []

src/tests/_internal/core/models/test_configurations.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,46 @@ def test_replica_group_router_forbids_service_level_router(self):
142142
):
143143
parse_run_configuration(conf)
144144

145+
def test_spot_policy_set_at_both_service_and_group_rejected(self):
146+
with pytest.raises(
147+
ConfigurationError,
148+
match="`spot_policy` is set at both",
149+
):
150+
parse_run_configuration(
151+
{
152+
"type": "service",
153+
"port": 8000,
154+
"spot_policy": "spot",
155+
"replicas": [
156+
{
157+
"count": 1,
158+
"commands": ["x"],
159+
"spot_policy": "on-demand",
160+
},
161+
],
162+
}
163+
)
164+
165+
def test_reservation_set_at_both_service_and_group_rejected(self):
166+
with pytest.raises(
167+
ConfigurationError,
168+
match="`reservation` is set at both",
169+
):
170+
parse_run_configuration(
171+
{
172+
"type": "service",
173+
"port": 8000,
174+
"image": "x",
175+
"reservation": "svc-res",
176+
"replicas": [
177+
{
178+
"count": 1,
179+
"reservation": "grp-res",
180+
},
181+
],
182+
}
183+
)
184+
145185
@pytest.mark.parametrize("shell", [None, "sh", "bash", "/usr/bin/zsh"])
146186
def test_shell_valid(self, shell: Optional[str]):
147187
conf = {

src/tests/_internal/server/services/jobs/configurators/test_service.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ReplicaGroup,
1111
ServiceConfiguration,
1212
)
13+
from dstack._internal.core.models.profiles import SpotPolicy
1314
from dstack._internal.core.models.resources import Range
1415
from dstack._internal.core.models.services import OpenAIChatModel
1516
from dstack._internal.server.services.docker import ImageConfig
@@ -118,7 +119,7 @@ def _make_run_spec(replicas, **service_kwargs):
118119
@pytest.mark.usefixtures("image_config_mock")
119120
class TestPerGroupOverrides:
120121
"""Verifies that ServiceJobConfigurator picks up per-replica-group
121-
image-source fields (image, docker, python, nvcc, privileged)."""
122+
image-source fields (image, docker, python, nvcc, privileged, etc)."""
122123

123124
async def test_image_name_uses_group_image(self):
124125
run_spec = _make_run_spec(
@@ -331,3 +332,103 @@ async def test_user_does_not_lookup_for_group_docker(self, monkeypatch: pytest.M
331332
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
332333
await configurator._user()
333334
mock_get_image_config.assert_not_called()
335+
336+
async def test_spot_policy_uses_group_value(self):
337+
run_spec = _make_run_spec(
338+
replicas=[
339+
ReplicaGroup(
340+
name="a",
341+
count=Range(min=1, max=1),
342+
commands=["x"],
343+
spot_policy=SpotPolicy.SPOT,
344+
)
345+
],
346+
)
347+
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
348+
assert configurator._spot_policy() == SpotPolicy.SPOT
349+
350+
async def test_spot_policy_defaults_to_ondemand_when_group_unset(self):
351+
run_spec = _make_run_spec(
352+
replicas=[
353+
ReplicaGroup(
354+
name="a",
355+
count=Range(min=1, max=1),
356+
commands=["x"],
357+
)
358+
],
359+
)
360+
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
361+
assert configurator._spot_policy() == SpotPolicy.ONDEMAND
362+
363+
async def test_different_groups_different_spot_policies(self):
364+
run_spec = _make_run_spec(
365+
replicas=[
366+
ReplicaGroup(
367+
name="spot",
368+
count=Range(min=1, max=1),
369+
commands=["x"],
370+
spot_policy=SpotPolicy.SPOT,
371+
),
372+
ReplicaGroup(
373+
name="od",
374+
count=Range(min=1, max=1),
375+
commands=["y"],
376+
spot_policy=SpotPolicy.ONDEMAND,
377+
),
378+
],
379+
)
380+
assert (
381+
ServiceJobConfigurator(run_spec, replica_group_name="spot")._spot_policy()
382+
== SpotPolicy.SPOT
383+
)
384+
assert (
385+
ServiceJobConfigurator(run_spec, replica_group_name="od")._spot_policy()
386+
== SpotPolicy.ONDEMAND
387+
)
388+
389+
async def test_reservation_uses_group_value(self):
390+
run_spec = _make_run_spec(
391+
replicas=[
392+
ReplicaGroup(
393+
name="a",
394+
count=Range(min=1, max=1),
395+
commands=["x"],
396+
reservation="my-reservation",
397+
)
398+
],
399+
)
400+
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
401+
assert configurator._reservation() == "my-reservation"
402+
403+
async def test_reservation_defaults_to_none_when_group_unset(self):
404+
run_spec = _make_run_spec(
405+
replicas=[
406+
ReplicaGroup(
407+
name="a",
408+
count=Range(min=1, max=1),
409+
commands=["x"],
410+
)
411+
],
412+
)
413+
configurator = ServiceJobConfigurator(run_spec, replica_group_name="a")
414+
assert configurator._reservation() is None
415+
416+
async def test_different_groups_different_reservations(self):
417+
run_spec = _make_run_spec(
418+
replicas=[
419+
ReplicaGroup(
420+
name="a",
421+
count=Range(min=1, max=1),
422+
commands=["x"],
423+
reservation="res-a",
424+
),
425+
ReplicaGroup(
426+
name="b",
427+
count=Range(min=1, max=1),
428+
commands=["y"],
429+
reservation="res-b",
430+
),
431+
],
432+
)
433+
assert ServiceJobConfigurator(run_spec, replica_group_name="a")._reservation() == "res-a"
434+
assert ServiceJobConfigurator(run_spec, replica_group_name="b")._reservation() == "res-b"

0 commit comments

Comments
 (0)