
Commit 1eaaa28

feat: pod overlay for kubernetes scheduler (#1067,#1068)
1 parent 1d26b39 commit 1eaaa28

File tree

2 files changed: +207 −2 lines changed

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 97 additions & 2 deletions
@@ -27,6 +27,50 @@
 See the
 `Volcano Quickstart <https://github.com/volcano-sh/volcano>`_
 for more information.
+
+Pod Overlay
+===========
+
+You can overlay arbitrary Kubernetes PodSpec fields on generated pods using the ``pod``
+scheduler argument.
+
+The overlay can be provided as a dict or YAML file path:
+
+.. code:: bash
+
+    # Inline dict
+    torchx run --scheduler kubernetes \\
+        --scheduler_args 'pod={"spec":{"nodeSelector":{"gpu":"true"}}}' \\
+        my_component.py
+
+    # From YAML file
+    torchx run --scheduler kubernetes \\
+        --scheduler_args pod=pod_overlay.yaml \\
+        my_component.py
+
+Example ``pod_overlay.yaml``:
+
+.. code:: yaml
+
+    spec:
+      nodeSelector:
+        node.kubernetes.io/instance-type: p4d.24xlarge
+      tolerations:
+      - key: nvidia.com/gpu
+        operator: Exists
+        effect: NoSchedule
+      affinity:
+        podAntiAffinity:
+          requiredDuringSchedulingIgnoredDuringExecution:
+          - labelSelector:
+              matchExpressions:
+              - key: app
+                operator: In
+                values: [trainer]
+            topologyKey: kubernetes.io/hostname
+
+The overlay is deep-merged with the generated pod spec, preserving existing fields
+and adding or overriding specified ones.
 """
 
 import json
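
As an aside on the deep-merge behaviour documented above: a minimal, self-contained sketch (not part of the commit) of how a nested overlay combines with an already-generated pod dict. Here `generated` is a hypothetical stand-in for a serialized pod, and the recursive merge mirrors the `deep_merge` helper introduced further down in this file.

    # Hypothetical illustration of the deep-merge described in the docstring above.
    # `generated` stands in for a pod already serialized by the scheduler; overlay
    # keys use the Kubernetes manifest (camelCase) field names.
    generated = {
        "spec": {
            "nodeSelector": {"existing": "label"},
            "containers": [{"name": "trainer", "image": "trainer:latest"}],
        }
    }
    overlay = {"spec": {"nodeSelector": {"gpu": "true"}}}


    def deep_merge(base: dict, overlay: dict) -> None:
        # Nested dicts are merged key by key; any other value replaces the base value.
        for key, value in overlay.items():
            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
                deep_merge(base[key], value)
            else:
                base[key] = value


    deep_merge(generated, overlay)
    # generated["spec"]["nodeSelector"] == {"existing": "label", "gpu": "true"}
    # generated["spec"]["containers"] is left untouched.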
@@ -45,6 +89,7 @@
     Tuple,
     TYPE_CHECKING,
     TypedDict,
+    Union,
 )
 
 import torchx
@@ -97,6 +142,42 @@
 RESERVED_MILLICPU = 100
 RESERVED_MEMMB = 1024
 
+
+def _load_pod_overlay(pod: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
+    """Load pod overlay from dict or YAML file path."""
+    if isinstance(pod, str):
+        try:
+            with open(pod) as f:
+                return yaml.safe_load(f) or {}
+        except Exception as e:
+            raise ValueError(f"Failed to load pod overlay from file {pod}: {e}") from e
+    elif isinstance(pod, dict):
+        return pod
+    else:
+        raise ValueError(f"pod must be a dict or file path string, got {type(pod)}")
+
+
+def _apply_pod_overlay(pod: "V1Pod", overlay: Dict[str, Any]) -> None:
+    """Apply overlay dict to V1Pod object, merging nested fields."""
+    from kubernetes import client
+
+    api = client.ApiClient()
+    pod_dict = api.sanitize_for_serialization(pod)
+
+    def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> None:
+        for key, value in overlay.items():
+            if isinstance(value, dict) and key in base and isinstance(base[key], dict):
+                deep_merge(base[key], value)
+            else:
+                base[key] = value
+
+    deep_merge(pod_dict, overlay)
+
+    merged_pod = api._ApiClient__deserialize(pod_dict, "V1Pod")
+    pod.spec = merged_pod.spec
+    pod.metadata = merged_pod.metadata
+
+
 RETRY_POLICIES: Mapping[str, Iterable[Mapping[str, str]]] = {
     RetryPolicy.REPLICA: [],
     RetryPolicy.APPLICATION: [
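
A standalone usage sketch of the helpers above, assuming the kubernetes Python client is installed; it mirrors the new unit tests rather than adding behaviour, and the pod/image names are placeholders. Because `_apply_pod_overlay` round-trips the pod through `ApiClient.sanitize_for_serialization` and the client's deserializer, overlay keys use the serialized camelCase names (`nodeSelector`), not the Python attribute names (`node_selector`):

    from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec

    from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay

    # Build a minimal pod (placeholder names), then apply a camelCase overlay to it.
    pod = V1Pod(
        metadata=V1ObjectMeta(name="trainer-0"),
        spec=V1PodSpec(
            containers=[V1Container(name="trainer", image="trainer:latest")],
            node_selector={"existing": "label"},
        ),
    )
    _apply_pod_overlay(pod, {"spec": {"nodeSelector": {"gpu": "true"}}})
    print(pod.spec.node_selector)  # {'existing': 'label', 'gpu': 'true'}

Note that the merge relies on `api._ApiClient__deserialize`, a name-mangled private method of the client, so the round trip is tied to the installed client version.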
@@ -369,6 +450,7 @@ def app_to_resource(
     queue: str,
     service_account: Optional[str],
     priority_class: Optional[str] = None,
+    pod_overlay: Optional[Dict[str, Any]] = None,
 ) -> Dict[str, object]:
     """
     app_to_resource creates a volcano job kubernetes resource definition from
@@ -402,6 +484,8 @@
             replica_role.env["TORCHX_IMAGE"] = replica_role.image
 
             pod = role_to_pod(name, replica_role, service_account)
+            if pod_overlay:
+                _apply_pod_overlay(pod, pod_overlay)
             pod.metadata.labels.update(
                 pod_labels(
                     app=app,
@@ -471,6 +555,7 @@ class KubernetesOpts(TypedDict, total=False):
     image_repo: Optional[str]
     service_account: Optional[str]
     priority_class: Optional[str]
+    pod: Union[str, Dict[str, Any]]
 
 
 class KubernetesScheduler(
@@ -636,7 +721,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
             else:
                 raise
 
-        return f'{namespace}:{resp["metadata"]["name"]}'
+        return f"{namespace}:{resp['metadata']['name']}"
 
     def _submit_dryrun(
         self, app: AppDef, cfg: KubernetesOpts
@@ -658,7 +743,12 @@ def _submit_dryrun(
             priority_class, str
         ), "priority_class must be a str"
 
-        resource = app_to_resource(app, queue, service_account, priority_class)
+        pod = cfg.get("pod")
+        pod_overlay = _load_pod_overlay(pod) if pod else None
+
+        resource = app_to_resource(
+            app, queue, service_account, priority_class, pod_overlay
+        )
         req = KubernetesJob(
             resource=resource,
             images_to_push=images_to_push,
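
For clarity on the `cfg.get("pod")` path above, a quick sketch of the two forms `_load_pod_overlay` accepts; `pod_overlay.yaml` is a hypothetical path, not a file shipped with the commit.

    from torchx.schedulers.kubernetes_scheduler import _load_pod_overlay

    # 1) A dict overlay is returned unchanged.
    overlay = _load_pod_overlay({"spec": {"nodeSelector": {"gpu": "true"}}})

    # 2) A string is treated as a YAML file path and read with yaml.safe_load;
    #    "pod_overlay.yaml" is a hypothetical path.
    overlay = _load_pod_overlay("pod_overlay.yaml")

    # Any other type (or an unreadable file) raises ValueError.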
@@ -703,6 +793,11 @@ def _run_opts(self) -> runopts:
             type_=str,
             help="The name of the PriorityClass to set on the job specs",
         )
+        opts.add(
+            "pod",
+            type_=Union[str, dict],
+            help="Pod overlay as dict or YAML file path to merge with generated pod specs",
+        )
         return opts
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 110 additions & 0 deletions
@@ -726,6 +726,7 @@ def test_runopts(self) -> None:
                 "image_repo",
                 "service_account",
                 "priority_class",
+                "pod",
             },
         )

@@ -929,6 +930,115 @@ def test_min_replicas(self) -> None:
         ]
         self.assertEqual(min_available, [1, 1, 0])
 
+    def test_load_pod_overlay_dict(self) -> None:
+        from torchx.schedulers.kubernetes_scheduler import _load_pod_overlay
+
+        overlay = {"spec": {"nodeSelector": {"gpu": "true"}}}
+        result = _load_pod_overlay(overlay)
+        self.assertEqual(result, overlay)
+
+    def test_load_pod_overlay_file(self) -> None:
+        import tempfile
+
+        from torchx.schedulers.kubernetes_scheduler import _load_pod_overlay
+
+        overlay = {"spec": {"nodeSelector": {"gpu": "true"}}}
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            import yaml
+
+            yaml.dump(overlay, f)
+        result = _load_pod_overlay(f.name)
+        self.assertEqual(result, overlay)
+
+    def test_load_pod_overlay_invalid(self) -> None:
+        from torchx.schedulers.kubernetes_scheduler import _load_pod_overlay
+
+        with self.assertRaises(ValueError):
+            _load_pod_overlay(123)
+
+    def test_apply_pod_overlay(self) -> None:
+        from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec
+        from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay
+
+        pod = V1Pod(
+            spec=V1PodSpec(
+                containers=[V1Container(name="test", image="test:latest")],
+                node_selector={"existing": "label"},
+            ),
+            metadata=V1ObjectMeta(name="test-pod"),
+        )
+
+        overlay = {
+            "spec": {
+                "nodeSelector": {"gpu": "true"},
+                "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],
+            }
+        }
+
+        _apply_pod_overlay(pod, overlay)
+
+        self.assertEqual(pod.spec.node_selector, {"existing": "label", "gpu": "true"})
+        self.assertEqual(len(pod.spec.tolerations), 1)
+        self.assertEqual(pod.spec.tolerations[0].key, "nvidia.com/gpu")
+
+    def test_apply_pod_overlay_new_fields(self) -> None:
+        from kubernetes.client.models import V1Container, V1ObjectMeta, V1Pod, V1PodSpec
+        from torchx.schedulers.kubernetes_scheduler import _apply_pod_overlay
+
+        # Pod without nodeSelector or tolerations
+        pod = V1Pod(
+            spec=V1PodSpec(containers=[V1Container(name="test", image="test:latest")]),
+            metadata=V1ObjectMeta(name="test-pod"),
+        )
+
+        # Overlay adds fields not present in original
+        overlay = {
+            "spec": {
+                "nodeSelector": {"gpu": "true"},
+                "tolerations": [{"key": "nvidia.com/gpu", "operator": "Exists"}],
+                "affinity": {
+                    "nodeAffinity": {
+                        "requiredDuringSchedulingIgnoredDuringExecution": {
+                            "nodeSelectorTerms": [
+                                {
+                                    "matchExpressions": [
+                                        {
+                                            "key": "gpu",
+                                            "operator": "In",
+                                            "values": ["true"],
+                                        }
+                                    ]
+                                }
+                            ]
+                        }
+                    }
+                },
+            }
+        }
+
+        _apply_pod_overlay(pod, overlay)
+
+        self.assertEqual(pod.spec.node_selector, {"gpu": "true"})
+        self.assertEqual(len(pod.spec.tolerations), 1)
+        self.assertIsNotNone(pod.spec.affinity)
+        self.assertIsNotNone(pod.spec.affinity.node_affinity)
+
+    def test_submit_dryrun_with_pod_overlay(self) -> None:
+        scheduler = create_scheduler("test")
+        app = _test_app()
+        cfg = KubernetesOpts(
+            {"queue": "testqueue", "pod": {"spec": {"nodeSelector": {"gpu": "true"}}}}
+        )
+
+        info = scheduler.submit_dryrun(app, cfg)
+        resource = info.request.resource
+
+        # Check that overlay was applied to all pods
+        for task in resource["spec"]["tasks"]:
+            pod = task["template"]
+            self.assertIn("gpu", pod.spec.node_selector)
+            self.assertEqual(pod.spec.node_selector["gpu"], "true")
+
 
 class KubernetesSchedulerNoImportTest(unittest.TestCase):
     """
