
Commit 661bfa3

feat: pod overlay for kubernetes scheduler (#1067,#1068)
1 parent 1d26b39 commit 661bfa3

2 files changed: +246, -1 lines

2 files changed

+246
-1
lines changed

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 95 additions & 1 deletion
@@ -27,6 +27,65 @@
See the
`Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
for more information.

Pod Overlay
===========

You can overlay arbitrary Kubernetes Pod fields on generated pods by setting
the ``kubernetes`` metadata on your role. The structure follows the Kubernetes
Pod spec with ``metadata`` and ``spec`` fields.

The metadata value can be:

- A dict with the overlay structure
- A resource URI pointing to a YAML file (e.g. ``file://``, ``s3://``, ``gs://``)

.. code:: python

    from torchx.specs import AppDef, Role

    # Option 1: Dict
    role = Role(
        name="trainer",
        image="my-image:latest",
        entrypoint="train.py",
        metadata={
            "kubernetes": {
                "spec": {
                    "nodeSelector": {"gpu": "true"},
                    "tolerations": [{
                        "key": "nvidia.com/gpu",
                        "operator": "Exists",
                        "effect": "NoSchedule"
                    }]
                }
            }
        }
    )

    # Option 2: Resource URI
    role = Role(
        name="trainer",
        image="my-image:latest",
        entrypoint="train.py",
        metadata={
            "kubernetes": "file:///path/to/pod_overlay.yaml"
        }
    )

Example ``pod_overlay.yaml``:

.. code:: yaml

    spec:
      nodeSelector:
        node.kubernetes.io/instance-type: p4d.24xlarge
      tolerations:
        - key: nvidia.com/gpu
          operator: Exists
          effect: NoSchedule

The overlay is deep-merged with the generated pod, preserving existing fields
and adding or overriding specified ones.
"""

import json
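
For reference, the deep-merge behavior described in the docstring above can be illustrated with plain dicts. This is a standalone sketch of the merge rule (nested dicts merge key by key, all other values are replaced), not the scheduler's own helper, which is added further down in this diff:

    from typing import Any, Dict

    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
        # Nested dicts are merged recursively; lists and scalars replace the
        # existing value outright.
        for key, value in overlay.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                deep_merge(base[key], value)
            else:
                base[key] = value

    generated = {"spec": {"nodeSelector": {"existing": "label"}, "restartPolicy": "Never"}}
    overlay = {"spec": {"nodeSelector": {"gpu": "true"}}}
    deep_merge(generated, overlay)
    # generated["spec"]["nodeSelector"] is now {"existing": "label", "gpu": "true"};
    # generated["spec"]["restartPolicy"] is preserved as "Never".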
@@ -45,6 +104,7 @@
    Tuple,
    TYPE_CHECKING,
    TypedDict,
    Union,
)

import torchx
@@ -97,6 +157,29 @@
RESERVED_MILLICPU = 100
RESERVED_MEMMB = 1024


def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
    """Apply overlay dict to V1Pod object, merging nested fields."""
    from kubernetes import client

    api = client.ApiClient()
    pod_dict = api.sanitize_for_serialization(pod)

    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
        for key, value in overlay.items():
            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
                deep_merge(base[key], value)
            else:
                base[key] = value

    deep_merge(pod_dict, overlay)

    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
    for attr in ["api_version", "kind", "metadata", "spec", "status"]:
        if hasattr(merged_pod, attr):
            setattr(pod, attr, getattr(merged_pod, attr))


RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
    RetryPolicy.REPLICA: [],
    RetryPolicy.APPLICATION: [
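
A minimal usage sketch for this helper, mirroring the unit tests added below (it assumes the ``kubernetes`` client package is installed; overlay keys use the Kubernetes camelCase field names, as in a Pod manifest):

    from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec

    from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay

    # A pod roughly like the one the scheduler generates, before the overlay.
    pod = V1Pod(
        metadata=V1ObjectMeta(name="trainer-0"),
        spec=V1PodSpec(containers=[V1Container(name="trainer", image="my-image:latest")]),
    )

    _apply_pod_overlay(pod, {"spec": {"nodeSelector": {"gpu": "true"}}})

    assert pod.spec.node_selector == {"gpu": "true"}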
@@ -402,6 +485,17 @@ def app_to_resource(
            replica_role.env["TORCHX_IMAGE"] = replica_role.image

            pod = role_to_pod(name, replica_role, service_account)
            if k8s_metadata := role.metadata.get("kubernetes"):
                if isinstance(k8s_metadata, str):
                    import fsspec  # pyre-ignore[21]

                    with fsspec.open(k8s_metadata, "r") as f:
                        k8s_metadata = yaml.safe_load(f)
                elif not isinstance(k8s_metadata, dict):
                    raise ValueError(
                        f"metadata['kubernetes'] must be a dict or resource URI, got {type(k8s_metadata)}"
                    )
                _apply_pod_overlay(pod, k8s_metadata)
            pod.metadata.labels.update(
                pod_labels(
                    app=app,
@@ -636,7 +730,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
             else:
                 raise

-        return f'{namespace}:{resp["metadata"]["name"]}'
+        return f"{namespace}:{resp['metadata']['name']}"

     def _submit_dryrun(
         self, app: AppDef, cfg: KubernetesOpts

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 151 additions & 0 deletions
@@ -929,6 +929,157 @@ def test_min_replicas(self) -> None:
        ]
        self.assertEqual(min_available, [1, 1, 0])

    def test_apply_pod_overlay(self) -> None:
        from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec
        from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay

        pod = V1Pod(
            spec=V1PodSpec(
                containers=[V1Container(name="test", image="test:latest")],
                node_selector={"existing": "label"},
            ),
            metadata=V1ObjectMeta(name="test-pod"),
        )

        overlay = {
            "spec": {
                "nodeSelector": {"gpu": "true"},
                "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],
            }
        }

        _apply_pod_overlay(pod, overlay)

        self.assertEqual(pod.spec.node_selector, {"existing": "label", "gpu": "true"})
        self.assertEqual(len(pod.spec.tolerations), 1)
        self.assertEqual(pod.spec.tolerations[0].key, "nvidia.com/gpu")

    def test_apply_pod_overlay_new_fields(self) -> None:
        from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec
        from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay

        # Pod without nodeSelector or tolerations
        pod = V1Pod(
            spec=V1PodSpec(containers=[V1Container(name="test", image="test:latest")]),
            metadata=V1ObjectMeta(name="test-pod"),
        )

        # Overlay adds fields not present in original
        overlay = {
            "spec": {
                "nodeSelector": {"gpu": "true"},
                "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],
                "affinity": {
                    "nodeAffinity": {
                        "requiredDuringSchedulingIgnoredDuringExecution": {
                            "nodeSelectorTerms": [
                                {
                                    "matchExpressions": [
                                        {
                                            "key": "gpu",
                                            "operator": "In",
                                            "values": ["true"],
                                        }
                                    ]
                                }
                            ]
                        }
                    }
                },
            }
        }

        _apply_pod_overlay(pod, overlay)

        self.assertEqual(pod.spec.node_selector, {"gpu": "true"})
        self.assertEqual(len(pod.spec.tolerations), 1)
        self.assertIsNotNone(pod.spec.affinity)
        self.assertIsNotNone(pod.spec.affinity.node_affinity)

    def test_submit_dryrun_with_pod_overlay(self) -> None:
        scheduler = create_scheduler("test")

        # Create app with metadata
        trainer_role = specs.Role(
            name="trainer",
            image="pytorch/torchx:latest",
            entrypoint="main",
            resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
            metadata={"kubernetes": {"spec": {"nodeSelector": {"gpu": "true"}}}},
        )
        app = specs.AppDef("test", roles=[trainer_role])
        cfg = KubernetesOpts({"queue": "testqueue"})

        info = scheduler.submit_dryrun(app, cfg)
        resource = info.request.resource

        # Check that overlay was applied to all pods
        tasks = resource["spec"]["tasks"]  # pyre-ignore[16]
        for task in tasks:
            pod = task["template"]
            self.assertIn("gpu", pod.spec.node_selector)
            self.assertEqual(pod.spec.node_selector["gpu"], "true")

    def test_submit_dryrun_with_pod_overlay_file_uri(self) -> None:
        import tempfile

        import yaml

        scheduler = create_scheduler("test")

        # Create overlay file
        overlay = {"spec": {"nodeSelector": {"instance-type": "p4d.24xlarge"}}}
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            yaml.dump(overlay, f)
            overlay_path = f.name

        try:
            # Create app with file URI
            trainer_role = specs.Role(
                name="trainer",
                image="pytorch/torchx:latest",
                entrypoint="main",
                resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
                metadata={"kubernetes": f"file://{overlay_path}"},
            )
            app = specs.AppDef("test", roles=[trainer_role])
            cfg = KubernetesOpts({"queue": "testqueue"})

            info = scheduler.submit_dryrun(app, cfg)
            resource = info.request.resource

            # Check that overlay was applied
            tasks = resource["spec"]["tasks"]  # pyre-ignore[16]
            for task in tasks:
                pod = task["template"]
                self.assertIn("instance-type", pod.spec.node_selector)
                self.assertEqual(
                    pod.spec.node_selector["instance-type"], "p4d.24xlarge"
                )
        finally:
            import os

            os.unlink(overlay_path)

    def test_submit_dryrun_with_pod_overlay_invalid_type(self) -> None:
        scheduler = create_scheduler("test")

        # Create app with invalid metadata type
        trainer_role = specs.Role(
            name="trainer",
            image="pytorch/torchx:latest",
            entrypoint="main",
            resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
            metadata={"kubernetes": 123},  # Invalid type
        )
        app = specs.AppDef("test", roles=[trainer_role])
        cfg = KubernetesOpts({"queue": "testqueue"})

        with self.assertRaises(ValueError) as ctx:
            scheduler.submit_dryrun(app, cfg)

        self.assertIn("must be a dict or resource URI", str(ctx.exception))


class KubernetesSchedulerNoImportTest(unittest.TestCase):
    """
