1111import sys
1212import unittest
1313from datetime import datetime
14- from typing import Any , Dict
14+ from typing import Any , cast , Dict
1515from unittest .mock import MagicMock , patch
1616
1717import torchx
@@ -111,10 +111,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
111111 make_unique_ctx .return_value = unique_app_name
112112 resource = app_to_resource (app , "test_queue" , service_account = None )
113113 actual_cmd = (
114- # pyre-ignore [16]
115- resource ["spec" ]["tasks" ][0 ]["template" ]
116- .spec .containers [0 ]
117- .command
114+ resource ["spec" ]["tasks" ][0 ]["template" ].spec .containers [0 ].command
118115 )
119116 expected_cmd = [
120117 "main" ,
@@ -135,7 +132,6 @@ def test_retry_policy_not_set(self) -> None:
135132 {"event" : "PodEvicted" , "action" : "RestartJob" },
136133 {"event" : "PodFailed" , "action" : "RestartJob" },
137134 ],
138- # pyre-ignore [16]
139135 resource ["spec" ]["tasks" ][0 ]["policies" ],
140136 )
141137 for role in app .roles :
@@ -251,7 +247,11 @@ def test_role_to_pod(self) -> None:
251247 want ,
252248 )
253249
254- def test_submit_dryrun (self ) -> None :
250+ @patch (
251+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
252+ )
253+ def test_submit_dryrun (self , mock_api : MagicMock ) -> None :
254+ mock_api .return_value .create_namespaced_custom_object .return_value = {}
255255 scheduler = create_scheduler ("test" )
256256 app = _test_app ()
257257 cfg = KubernetesOpts ({"queue" : "testqueue" })
@@ -262,6 +262,9 @@ def test_submit_dryrun(self) -> None:
262262 info = scheduler .submit_dryrun (app , cfg )
263263
264264 resource = str (info .request )
265+ mock_api .return_value .create_namespaced_custom_object .assert_called_once ()
266+ call_kwargs = mock_api .return_value .create_namespaced_custom_object .call_args [1 ]
267+ self .assertEqual (call_kwargs ["dry_run" ], "All" )
265268
266269 print (resource )
267270
@@ -505,7 +508,11 @@ def test_instance_type(self) -> None:
505508 },
506509 )
507510
508- def test_rank0_env (self ) -> None :
511+ @patch (
512+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
513+ )
514+ def test_rank0_env (self , mock_api : MagicMock ) -> None :
515+ mock_api .return_value .create_namespaced_custom_object .return_value = {}
509516 from kubernetes .client .models import V1EnvVar
510517
511518 scheduler = create_scheduler ("test" )
@@ -517,7 +524,7 @@ def test_rank0_env(self) -> None:
517524 make_unique_ctx .return_value = "app-name-42"
518525 info = scheduler .submit_dryrun (app , cfg )
519526
520- tasks = info .request .resource ["spec" ]["tasks" ] # pyre-ignore[16]
527+ tasks = info .request .resource ["spec" ]["tasks" ]
521528 container0 = tasks [0 ]["template" ].spec .containers [0 ]
522529 self .assertIn ("TORCHX_RANK0_HOST" , container0 .command )
523530 self .assertIn (
@@ -528,8 +535,16 @@ def test_rank0_env(self) -> None:
528535 )
529536 container1 = tasks [1 ]["template" ].spec .containers [0 ]
530537 self .assertIn ("VC_TRAINERFOO_0_HOSTS" , container1 .command )
538+ mock_api .return_value .create_namespaced_custom_object .assert_called_once ()
539+ call_kwargs = mock_api .return_value .create_namespaced_custom_object .call_args [1 ]
540+ self .assertEqual (call_kwargs ["dry_run" ], "All" )
541+ self .assertEqual (call_kwargs ["namespace" ], "default" )
531542
532- def test_submit_dryrun_patch (self ) -> None :
543+ @patch (
544+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
545+ )
546+ def test_submit_dryrun_patch (self , mock_api : MagicMock ) -> None :
547+ mock_api .return_value .create_namespaced_custom_object .return_value = {}
533548 scheduler = create_scheduler ("test" )
534549 app = _test_app ()
535550 app .roles [0 ].image = "sha256:testhash"
@@ -555,8 +570,15 @@ def test_submit_dryrun_patch(self) -> None:
555570 ),
556571 },
557572 )
573+ mock_api .return_value .create_namespaced_custom_object .assert_called_once ()
574+ call_kwargs = mock_api .return_value .create_namespaced_custom_object .call_args [1 ]
575+ self .assertEqual (call_kwargs ["dry_run" ], "All" )
558576
559- def test_submit_dryrun_service_account (self ) -> None :
577+ @patch (
578+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
579+ )
580+ def test_submit_dryrun_service_account (self , mock_api : MagicMock ) -> None :
581+ mock_api .return_value .create_namespaced_custom_object .return_value = {}
560582 scheduler = create_scheduler ("test" )
561583 self .assertIn ("service_account" , scheduler .run_opts ()._opts )
562584 app = _test_app ()
@@ -573,7 +595,17 @@ def test_submit_dryrun_service_account(self) -> None:
573595 info = scheduler .submit_dryrun (app , cfg )
574596 self .assertIn ("service_account_name': None" , str (info .request .resource ))
575597
576- def test_submit_dryrun_priority_class (self ) -> None :
598+ self .assertEqual (
599+ mock_api .return_value .create_namespaced_custom_object .call_count , 2
600+ )
601+ call_kwargs = mock_api .return_value .create_namespaced_custom_object .call_args [1 ]
602+ self .assertEqual (call_kwargs ["dry_run" ], "All" )
603+
604+ @patch (
605+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
606+ )
607+ def test_submit_dryrun_priority_class (self , mock_api : MagicMock ) -> None :
608+ mock_api .return_value .create_namespaced_custom_object .return_value = {}
577609 scheduler = create_scheduler ("test" )
578610 self .assertIn ("priority_class" , scheduler .run_opts ()._opts )
579611 app = _test_app ()
@@ -591,6 +623,12 @@ def test_submit_dryrun_priority_class(self) -> None:
591623 info = scheduler .submit_dryrun (app , cfg )
592624 self .assertNotIn ("'priorityClassName'" , str (info .request .resource ))
593625
626+ self .assertEqual (
627+ mock_api .return_value .create_namespaced_custom_object .call_count , 2
628+ )
629+ call_kwargs = mock_api .return_value .create_namespaced_custom_object .call_args [1 ]
630+ self .assertEqual (call_kwargs ["dry_run" ], "All" )
631+
594632 @patch ("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object" )
595633 def test_submit (self , create_namespaced_custom_object : MagicMock ) -> None :
596634 create_namespaced_custom_object .return_value = {
@@ -624,7 +662,7 @@ def test_submit_job_name_conflict(
624662
625663 api_exc = ApiException (status = 409 , reason = "Conflict" )
626664 api_exc .body = '{"details":{"name": "test_job"}}'
627- create_namespaced_custom_object .side_effect = api_exc
665+ create_namespaced_custom_object .side_effect = [{}, api_exc ]
628666
629667 scheduler = create_scheduler ("test" )
630668 app = _test_app ()
@@ -638,6 +676,14 @@ def test_submit_job_name_conflict(
638676 with self .assertRaises (ValueError ):
639677 scheduler .schedule (info )
640678
679+ self .assertEqual (create_namespaced_custom_object .call_count , 2 )
680+ # First call is spec validation
681+ first_call_kwargs = create_namespaced_custom_object .call_args_list [0 ][1 ]
682+ self .assertEqual (first_call_kwargs ["dry_run" ], "All" )
683+ # Second call is actual schedule
684+ second_call_kwargs = create_namespaced_custom_object .call_args_list [1 ][1 ]
685+ self .assertNotIn ("dry_run" , second_call_kwargs )
686+
641687 @patch ("kubernetes.client.CustomObjectsApi.get_namespaced_custom_object_status" )
642688 def test_describe (self , get_namespaced_custom_object_status : MagicMock ) -> None :
643689 get_namespaced_custom_object_status .return_value = {
@@ -752,6 +798,7 @@ def test_runopts(self) -> None:
752798 "image_repo" ,
753799 "service_account" ,
754800 "priority_class" ,
801+ "validate_spec" ,
755802 },
756803 )
757804
@@ -949,12 +996,103 @@ def test_min_replicas(self) -> None:
949996 app .roles [0 ].min_replicas = 2
950997
951998 resource = app_to_resource (app , "test_queue" , service_account = None )
952- min_available = [
953- task ["minAvailable" ]
954- for task in resource ["spec" ]["tasks" ] # pyre-ignore[16]
955- ]
999+ min_available = [task ["minAvailable" ] for task in resource ["spec" ]["tasks" ]]
9561000 self .assertEqual (min_available , [1 , 1 , 0 ])
9571001
1002+ @patch (
1003+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1004+ )
1005+ def test_validate_spec_invalid_name (self , mock_api : MagicMock ) -> None :
1006+ from kubernetes .client .rest import ApiException
1007+
1008+ scheduler = create_scheduler ("test" )
1009+ app = _test_app ()
1010+ app .name = "Invalid_Name"
1011+
1012+ mock_api_instance = MagicMock ()
1013+ mock_api_instance .create_namespaced_custom_object .side_effect = ApiException (
1014+ status = 422 ,
1015+ reason = "Invalid" ,
1016+ )
1017+ mock_api .return_value = mock_api_instance
1018+
1019+ cfg = cast (KubernetesOpts , {"queue" : "testqueue" , "validate_spec" : True })
1020+
1021+ with self .assertRaises (ValueError ) as ctx :
1022+ scheduler .submit_dryrun (app , cfg )
1023+
1024+ self .assertIn ("Invalid job spec" , str (ctx .exception ))
1025+ mock_api_instance .create_namespaced_custom_object .assert_called_once ()
1026+ call_kwargs = mock_api_instance .create_namespaced_custom_object .call_args [1 ]
1027+ self .assertEqual (call_kwargs ["dry_run" ], "All" )
1028+
1029+ def test_validate_spec_disabled (self ) -> None :
1030+ scheduler = create_scheduler ("test" )
1031+ app = _test_app ()
1032+
1033+ cfg = KubernetesOpts ({"queue" : "testqueue" , "validate_spec" : False })
1034+
1035+ with patch (
1036+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1037+ ) as mock_api :
1038+ mock_api_instance = MagicMock ()
1039+ mock_api_instance .create_namespaced_custom_object .return_value = {}
1040+ mock_api .return_value = mock_api_instance
1041+
1042+ info = scheduler .submit_dryrun (app , cfg )
1043+
1044+ self .assertIsNotNone (info )
1045+ mock_api_instance .create_namespaced_custom_object .assert_not_called ()
1046+
1047+ @patch (
1048+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1049+ )
1050+ def test_validate_spec_invalid_task_name (self , mock_api : MagicMock ) -> None :
1051+ from kubernetes .client .rest import ApiException
1052+
1053+ scheduler = create_scheduler ("test" )
1054+ app = _test_app ()
1055+ app .roles [0 ].name = "Invalid-Task-Name"
1056+
1057+ mock_api_instance = MagicMock ()
1058+ mock_api_instance .create_namespaced_custom_object .side_effect = ApiException (
1059+ status = 422 ,
1060+ reason = "Invalid" ,
1061+ )
1062+ mock_api .return_value = mock_api_instance
1063+
1064+ cfg = cast (KubernetesOpts , {"queue" : "testqueue" , "validate_spec" : True })
1065+
1066+ with self .assertRaises (ValueError ) as ctx :
1067+ scheduler .submit_dryrun (app , cfg )
1068+
1069+ self .assertIn ("Invalid job spec" , str (ctx .exception ))
1070+
1071+ @patch (
1072+ "torchx.schedulers.kubernetes_scheduler.KubernetesScheduler._custom_objects_api"
1073+ )
1074+ def test_validate_spec_long_pod_name (self , mock_api : MagicMock ) -> None :
1075+ scheduler = create_scheduler ("test" )
1076+ app = _test_app ()
1077+ app .name = "x" * 50
1078+ app .roles [0 ].name = "y" * 20
1079+
1080+ mock_api_instance = MagicMock ()
1081+ mock_api_instance .create_namespaced_custom_object .return_value = {}
1082+ mock_api .return_value = mock_api_instance
1083+
1084+ cfg = cast (KubernetesOpts , {"queue" : "testqueue" , "validate_spec" : True })
1085+
1086+ with patch (
1087+ "torchx.schedulers.kubernetes_scheduler.make_unique"
1088+ ) as make_unique_ctx :
1089+ make_unique_ctx .return_value = "x" * 50
1090+ with self .assertRaises (ValueError ) as ctx :
1091+ scheduler .submit_dryrun (app , cfg )
1092+
1093+ self .assertIn ("Pod name" , str (ctx .exception ))
1094+ self .assertIn ("exceeds 63 character limit" , str (ctx .exception ))
1095+
9581096
9591097class KubernetesSchedulerNoImportTest (unittest .TestCase ):
9601098 """
0 commit comments