Skip to content

Commit a01f925

Browse files
clumsyazzhipa
andauthored
feat: add option to validate k8s spec (#1152) (#1153)
Co-authored-by: Alexander Zhipa <azzhipa@amazon.com>
1 parent 5957532 commit a01f925

File tree

2 files changed

+195
-20
lines changed

2 files changed

+195
-20
lines changed

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def app_to_resource(
369369
queue: str,
370370
service_account: Optional[str],
371371
priority_class: Optional[str] = None,
372-
) -> Dict[str, object]:
372+
) -> Dict[str, Any]:
373373
"""
374374
app_to_resource creates a volcano job kubernetes resource definition from
375375
the provided AppDef. The resource definition can be used to launch the
@@ -444,7 +444,7 @@ def app_to_resource(
444444
if priority_class is not None:
445445
job_spec["priorityClassName"] = priority_class
446446

447-
resource: Dict[str, object] = {
447+
resource: Dict[str, Any] = {
448448
"apiVersion": "batch.volcano.sh/v1alpha1",
449449
"kind": "Job",
450450
"metadata": {"name": f"{unique_app_id}"},
@@ -456,7 +456,7 @@ def app_to_resource(
456456
@dataclass
457457
class KubernetesJob:
458458
images_to_push: Dict[str, Tuple[str, str]]
459-
resource: Dict[str, object]
459+
resource: Dict[str, Any]
460460

461461
def __str__(self) -> str:
462462
return yaml.dump(sanitize_for_serialization(self.resource))
@@ -471,6 +471,7 @@ class KubernetesOpts(TypedDict, total=False):
471471
image_repo: Optional[str]
472472
service_account: Optional[str]
473473
priority_class: Optional[str]
474+
validate_spec: Optional[bool]
474475

475476

476477
class KubernetesScheduler(
@@ -659,6 +660,36 @@ def _submit_dryrun(
659660
), "priority_class must be a str"
660661

661662
resource = app_to_resource(app, queue, service_account, priority_class)
663+
664+
if cfg.get("validate_spec"):
665+
try:
666+
self._custom_objects_api().create_namespaced_custom_object(
667+
group="batch.volcano.sh",
668+
version="v1alpha1",
669+
namespace=cfg.get("namespace") or "default",
670+
plural="jobs",
671+
body=resource,
672+
dry_run="All",
673+
)
674+
except Exception as e:
675+
from kubernetes.client.rest import ApiException
676+
677+
if isinstance(e, ApiException):
678+
raise ValueError(f"Invalid job spec: {e.reason}") from e
679+
raise
680+
681+
job_name = resource["metadata"]["name"]
682+
for task in resource["spec"]["tasks"]:
683+
task_name = task["name"]
684+
replicas = task.get("replicas", 1)
685+
max_index = replicas - 1
686+
pod_name = f"{job_name}-{task_name}-{max_index}"
687+
if len(pod_name) > 63:
688+
raise ValueError(
689+
f"Pod name '{pod_name}' ({len(pod_name)} chars) exceeds 63 character limit. "
690+
f"Shorten app.name or role names"
691+
)
692+
662693
req = KubernetesJob(
663694
resource=resource,
664695
images_to_push=images_to_push,
@@ -703,6 +734,12 @@ def _run_opts(self) -> runopts:
703734
type_=str,
704735
help="The name of the PriorityClass to set on the job specs",
705736
)
737+
opts.add(
738+
"validate_spec",
739+
type_=bool,
740+
help="Validate job spec using Kubernetes API dry-run before submission",
741+
default=True,
742+
)
706743
return opts
707744

708745
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 155 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import sys
1212
import unittest
1313
from datetime import datetime
14-
from typing import Any, Dict
14+
from typing import Any, cast, Dict
1515
from unittest.mock import MagicMock, patch
1616

1717
import torchx
@@ -111,10 +111,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
111111
make_unique_ctx.return_value = unique_app_name
112112
resource = app_to_resource(app, "test_queue", service_account=None)
113113
actual_cmd = (
114-
# pyre-ignore [16]
115-
resource["spec"]["tasks"][0]["template"]
116-
.spec.containers[0]
117-
.command
114+
resource["spec"]["tasks"][0]["template"].spec.containers[0].command
118115
)
119116
expected_cmd = [
120117
"main",
@@ -135,7 +132,6 @@ def test_retry_policy_not_set(self) -> None:
135132
{"event": "PodEvicted", "action": "RestartJob"},
136133
{"event": "PodFailed", "action": "RestartJob"},
137134
],
138-
# pyre-ignore [16]
139135
resource["spec"]["tasks"][0]["policies"],
140136
)
141137
for role in app.roles:
@@ -251,7 +247,11 @@ def test_role_to_pod(self) -> None:
251247
want,
252248
)
253249

254-
def test_submit_dryrun(self) -> None:
250+
@patch(
251+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
252+
)
253+
def test_submit_dryrun(self, mock_api: MagicMock) -> None:
254+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
255255
scheduler = create_scheduler("test")
256256
app = _test_app()
257257
cfg = KubernetesOpts({"queue": "testqueue"})
@@ -262,6 +262,9 @@ def test_submit_dryrun(self) -> None:
262262
info = scheduler.submit_dryrun(app, cfg)
263263

264264
resource = str(info.request)
265+
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
266+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
267+
self.assertEqual(call_kwargs["dry_run"], "All")
265268

266269
print(resource)
267270

@@ -505,7 +508,11 @@ def test_instance_type(self) -> None:
505508
},
506509
)
507510

508-
def test_rank0_env(self) -> None:
511+
@patch(
512+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
513+
)
514+
def test_rank0_env(self, mock_api: MagicMock) -> None:
515+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
509516
from kubernetes.client.models import V1EnvVar
510517

511518
scheduler = create_scheduler("test")
@@ -517,7 +524,7 @@ def test_rank0_env(self) -> None:
517524
make_unique_ctx.return_value = "app-name-42"
518525
info = scheduler.submit_dryrun(app, cfg)
519526

520-
tasks = info.request.resource["spec"]["tasks"] # pyre-ignore[16]
527+
tasks = info.request.resource["spec"]["tasks"]
521528
container0 = tasks[0]["template"].spec.containers[0]
522529
self.assertIn("TORCHX_RANK0_HOST", container0.command)
523530
self.assertIn(
@@ -528,8 +535,16 @@ def test_rank0_env(self) -> None:
528535
)
529536
container1 = tasks[1]["template"].spec.containers[0]
530537
self.assertIn("VC_TRAINERFOO_0_HOSTS", container1.command)
538+
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
539+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
540+
self.assertEqual(call_kwargs["dry_run"], "All")
541+
self.assertEqual(call_kwargs["namespace"], "default")
531542

532-
def test_submit_dryrun_patch(self) -> None:
543+
@patch(
544+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
545+
)
546+
def test_submit_dryrun_patch(self, mock_api: MagicMock) -> None:
547+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
533548
scheduler = create_scheduler("test")
534549
app = _test_app()
535550
app.roles[0].image = "sha256:testhash"
@@ -555,8 +570,15 @@ def test_submit_dryrun_patch(self) -> None:
555570
),
556571
},
557572
)
573+
mock_api.return_value.create_namespaced_custom_object.assert_called_once()
574+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
575+
self.assertEqual(call_kwargs["dry_run"], "All")
558576

559-
def test_submit_dryrun_service_account(self) -> None:
577+
@patch(
578+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
579+
)
580+
def test_submit_dryrun_service_account(self, mock_api: MagicMock) -> None:
581+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
560582
scheduler = create_scheduler("test")
561583
self.assertIn("service_account", scheduler.run_opts()._opts)
562584
app = _test_app()
@@ -573,7 +595,17 @@ def test_submit_dryrun_service_account(self) -> None:
573595
info = scheduler.submit_dryrun(app, cfg)
574596
self.assertIn("service_account_name': None", str(info.request.resource))
575597

576-
def test_submit_dryrun_priority_class(self) -> None:
598+
self.assertEqual(
599+
mock_api.return_value.create_namespaced_custom_object.call_count, 2
600+
)
601+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
602+
self.assertEqual(call_kwargs["dry_run"], "All")
603+
604+
@patch(
605+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
606+
)
607+
def test_submit_dryrun_priority_class(self, mock_api: MagicMock) -> None:
608+
mock_api.return_value.create_namespaced_custom_object.return_value = {}
577609
scheduler = create_scheduler("test")
578610
self.assertIn("priority_class", scheduler.run_opts()._opts)
579611
app = _test_app()
@@ -591,6 +623,12 @@ def test_submit_dryrun_priority_class(self) -> None:
591623
info = scheduler.submit_dryrun(app, cfg)
592624
self.assertNotIn("'priorityClassName'", str(info.request.resource))
593625

626+
self.assertEqual(
627+
mock_api.return_value.create_namespaced_custom_object.call_count, 2
628+
)
629+
call_kwargs = mock_api.return_value.create_namespaced_custom_object.call_args[1]
630+
self.assertEqual(call_kwargs["dry_run"], "All")
631+
594632
@patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object")
595633
def test_submit(self, create_namespaced_custom_object: MagicMock) -> None:
596634
create_namespaced_custom_object.return_value = {
@@ -624,7 +662,7 @@ def test_submit_job_name_conflict(
624662

625663
api_exc = ApiException(status=409, reason="Conflict")
626664
api_exc.body = '{"details":{"name": "test_job"}}'
627-
create_namespaced_custom_object.side_effect = api_exc
665+
create_namespaced_custom_object.side_effect = [{}, api_exc]
628666

629667
scheduler = create_scheduler("test")
630668
app = _test_app()
@@ -638,6 +676,14 @@ def test_submit_job_name_conflict(
638676
with self.assertRaises(ValueError):
639677
scheduler.schedule(info)
640678

679+
self.assertEqual(create_namespaced_custom_object.call_count, 2)
680+
# First call is spec validation
681+
first_call_kwargs = create_namespaced_custom_object.call_args_list[0][1]
682+
self.assertEqual(first_call_kwargs["dry_run"], "All")
683+
# Second call is actual schedule
684+
second_call_kwargs = create_namespaced_custom_object.call_args_list[1][1]
685+
self.assertNotIn("dry_run", second_call_kwargs)
686+
641687
@patch("kubernetes.client.CustomObjectsApi.get_namespaced_custom_object_status")
642688
def test_describe(self, get_namespaced_custom_object_status: MagicMock) -> None:
643689
get_namespaced_custom_object_status.return_value = {
@@ -752,6 +798,7 @@ def test_runopts(self) -> None:
752798
"image_repo",
753799
"service_account",
754800
"priority_class",
801+
"validate_spec",
755802
},
756803
)
757804

@@ -949,12 +996,103 @@ def test_min_replicas(self) -> None:
949996
app.roles[0].min_replicas = 2
950997

951998
resource = app_to_resource(app, "test_queue", service_account=None)
952-
min_available = [
953-
task["minAvailable"]
954-
for task in resource["spec"]["tasks"] # pyre-ignore[16]
955-
]
999+
min_available = [task["minAvailable"] for task in resource["spec"]["tasks"]]
9561000
self.assertEqual(min_available, [1, 1, 0])
9571001

1002+
@patch(
1003+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1004+
)
1005+
def test_validate_spec_invalid_name(self, mock_api: MagicMock) -> None:
1006+
from kubernetes.client.rest import ApiException
1007+
1008+
scheduler = create_scheduler("test")
1009+
app = _test_app()
1010+
app.name = "Invalid_Name"
1011+
1012+
mock_api_instance = MagicMock()
1013+
mock_api_instance.create_namespaced_custom_object.side_effect = ApiException(
1014+
status=422,
1015+
reason="Invalid",
1016+
)
1017+
mock_api.return_value = mock_api_instance
1018+
1019+
cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})
1020+
1021+
with self.assertRaises(ValueError) as ctx:
1022+
scheduler.submit_dryrun(app, cfg)
1023+
1024+
self.assertIn("Invalid job spec", str(ctx.exception))
1025+
mock_api_instance.create_namespaced_custom_object.assert_called_once()
1026+
call_kwargs = mock_api_instance.create_namespaced_custom_object.call_args[1]
1027+
self.assertEqual(call_kwargs["dry_run"], "All")
1028+
1029+
def test_validate_spec_disabled(self) -> None:
1030+
scheduler = create_scheduler("test")
1031+
app = _test_app()
1032+
1033+
cfg = KubernetesOpts({"queue": "testqueue", "validate_spec": False})
1034+
1035+
with patch(
1036+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1037+
) as mock_api:
1038+
mock_api_instance = MagicMock()
1039+
mock_api_instance.create_namespaced_custom_object.return_value = {}
1040+
mock_api.return_value = mock_api_instance
1041+
1042+
info = scheduler.submit_dryrun(app, cfg)
1043+
1044+
self.assertIsNotNone(info)
1045+
mock_api_instance.create_namespaced_custom_object.assert_not_called()
1046+
1047+
@patch(
1048+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1049+
)
1050+
def test_validate_spec_invalid_task_name(self, mock_api: MagicMock) -> None:
1051+
from kubernetes.client.rest import ApiException
1052+
1053+
scheduler = create_scheduler("test")
1054+
app = _test_app()
1055+
app.roles[0].name = "Invalid-Task-Name"
1056+
1057+
mock_api_instance = MagicMock()
1058+
mock_api_instance.create_namespaced_custom_object.side_effect = ApiException(
1059+
status=422,
1060+
reason="Invalid",
1061+
)
1062+
mock_api.return_value = mock_api_instance
1063+
1064+
cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})
1065+
1066+
with self.assertRaises(ValueError) as ctx:
1067+
scheduler.submit_dryrun(app, cfg)
1068+
1069+
self.assertIn("Invalid job spec", str(ctx.exception))
1070+
1071+
@patch(
1072+
"torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1073+
)
1074+
def test_validate_spec_long_pod_name(self, mock_api: MagicMock) -> None:
1075+
scheduler = create_scheduler("test")
1076+
app = _test_app()
1077+
app.name = "x" * 50
1078+
app.roles[0].name = "y" * 20
1079+
1080+
mock_api_instance = MagicMock()
1081+
mock_api_instance.create_namespaced_custom_object.return_value = {}
1082+
mock_api.return_value = mock_api_instance
1083+
1084+
cfg = cast(KubernetesOpts, {"queue": "testqueue", "validate_spec": True})
1085+
1086+
with patch(
1087+
"torchx.schedulers.kubernetes_scheduler.make_unique"
1088+
) as make_unique_ctx:
1089+
make_unique_ctx.return_value = "x" * 50
1090+
with self.assertRaises(ValueError) as ctx:
1091+
scheduler.submit_dryrun(app, cfg)
1092+
1093+
self.assertIn("Pod name", str(ctx.exception))
1094+
self.assertIn("exceeds 63 character limit", str(ctx.exception))
1095+
9581096

9591097
class KubernetesSchedulerNoImportTest(unittest.TestCase):
9601098
"""

0 commit comments

Comments
 (0)