@@ -929,6 +929,157 @@ def test_min_replicas(self) -> None:
929929 ]
930930 self .assertEqual (min_available , [1 , 1 , 0 ])
931931
932+ def test_apply_pod_overlay (self ) -> None :
933+ from kubernetes .client .models import V1Container , V1ObjectMeta , V1Pod , V1PodSpec
934+ from torchx .schedulers .kubernetes_scheduler import _apply_pod_overlay
935+
936+ pod = V1Pod (
937+ spec = V1PodSpec (
938+ containers = [V1Container (name = "test" , image = "test:latest" )],
939+ node_selector = {"existing" : "label" },
940+ ),
941+ metadata = V1ObjectMeta (name = "test-pod" ),
942+ )
943+
944+ overlay = {
945+ "spec" : {
946+ "nodeSelector" : {"gpu" : "true" },
947+ "tolerations" : [{"key" : "nvidia.com/gpu" , "operator" : "Exists" }],
948+ }
949+ }
950+
951+ _apply_pod_overlay (pod , overlay )
952+
953+ self .assertEqual (pod .spec .node_selector , {"existing" : "label" , "gpu" : "true" })
954+ self .assertEqual (len (pod .spec .tolerations ), 1 )
955+ self .assertEqual (pod .spec .tolerations [0 ].key , "nvidia.com/gpu" )
956+
957+ def test_apply_pod_overlay_new_fields (self ) -> None :
958+ from kubernetes .client .models import V1Container , V1ObjectMeta , V1Pod , V1PodSpec
959+ from torchx .schedulers .kubernetes_scheduler import _apply_pod_overlay
960+
961+ # Pod without nodeSelector or tolerations
962+ pod = V1Pod (
963+ spec = V1PodSpec (containers = [V1Container (name = "test" , image = "test:latest" )]),
964+ metadata = V1ObjectMeta (name = "test-pod" ),
965+ )
966+
967+ # Overlay adds fields not present in original
968+ overlay = {
969+ "spec" : {
970+ "nodeSelector" : {"gpu" : "true" },
971+ "tolerations" : [{"key" : "nvidia.com/gpu" , "operator" : "Exists" }],
972+ "affinity" : {
973+ "nodeAffinity" : {
974+ "requiredDuringSchedulingIgnoredDuringExecution" : {
975+ "nodeSelectorTerms" : [
976+ {
977+ "matchExpressions" : [
978+ {
979+ "key" : "gpu" ,
980+ "operator" : "In" ,
981+ "values" : ["true" ],
982+ }
983+ ]
984+ }
985+ ]
986+ }
987+ }
988+ },
989+ }
990+ }
991+
992+ _apply_pod_overlay (pod , overlay )
993+
994+ self .assertEqual (pod .spec .node_selector , {"gpu" : "true" })
995+ self .assertEqual (len (pod .spec .tolerations ), 1 )
996+ self .assertIsNotNone (pod .spec .affinity )
997+ self .assertIsNotNone (pod .spec .affinity .node_affinity )
998+
999+ def test_submit_dryrun_with_pod_overlay (self ) -> None :
1000+ scheduler = create_scheduler ("test" )
1001+
1002+ # Create app with metadata
1003+ trainer_role = specs .Role (
1004+ name = "trainer" ,
1005+ image = "pytorch/torchx:latest" ,
1006+ entrypoint = "main" ,
1007+ resource = specs .Resource (cpu = 1 , memMB = 1000 , gpu = 0 ),
1008+ metadata = {"kubernetes" : {"spec" : {"nodeSelector" : {"gpu" : "true" }}}},
1009+ )
1010+ app = specs .AppDef ("test" , roles = [trainer_role ])
1011+ cfg = KubernetesOpts ({"queue" : "testqueue" })
1012+
1013+ info = scheduler .submit_dryrun (app , cfg )
1014+ resource = info .request .resource
1015+
1016+ # Check that overlay was applied to all pods
1017+ tasks = resource ["spec" ]["tasks" ] # pyre-ignore[16]
1018+ for task in tasks :
1019+ pod = task ["template" ]
1020+ self .assertIn ("gpu" , pod .spec .node_selector )
1021+ self .assertEqual (pod .spec .node_selector ["gpu" ], "true" )
1022+
1023+ def test_submit_dryrun_with_pod_overlay_file_uri (self ) -> None :
1024+ import tempfile
1025+
1026+ import yaml
1027+
1028+ scheduler = create_scheduler ("test" )
1029+
1030+ # Create overlay file
1031+ overlay = {"spec" : {"nodeSelector" : {"instance-type" : "p4d.24xlarge" }}}
1032+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".yaml" , delete = False ) as f :
1033+ yaml .dump (overlay , f )
1034+ overlay_path = f .name
1035+
1036+ try :
1037+ # Create app with file URI
1038+ trainer_role = specs .Role (
1039+ name = "trainer" ,
1040+ image = "pytorch/torchx:latest" ,
1041+ entrypoint = "main" ,
1042+ resource = specs .Resource (cpu = 1 , memMB = 1000 , gpu = 0 ),
1043+ metadata = {"kubernetes" : f"file://{ overlay_path } " },
1044+ )
1045+ app = specs .AppDef ("test" , roles = [trainer_role ])
1046+ cfg = KubernetesOpts ({"queue" : "testqueue" })
1047+
1048+ info = scheduler .submit_dryrun (app , cfg )
1049+ resource = info .request .resource
1050+
1051+ # Check that overlay was applied
1052+ tasks = resource ["spec" ]["tasks" ] # pyre-ignore[16]
1053+ for task in tasks :
1054+ pod = task ["template" ]
1055+ self .assertIn ("instance-type" , pod .spec .node_selector )
1056+ self .assertEqual (
1057+ pod .spec .node_selector ["instance-type" ], "p4d.24xlarge"
1058+ )
1059+ finally :
1060+ import os
1061+
1062+ os .unlink (overlay_path )
1063+
1064+ def test_submit_dryrun_with_pod_overlay_invalid_type (self ) -> None :
1065+ scheduler = create_scheduler ("test" )
1066+
1067+ # Create app with invalid metadata type
1068+ trainer_role = specs .Role (
1069+ name = "trainer" ,
1070+ image = "pytorch/torchx:latest" ,
1071+ entrypoint = "main" ,
1072+ resource = specs .Resource (cpu = 1 , memMB = 1000 , gpu = 0 ),
1073+ metadata = {"kubernetes" : 123 }, # Invalid type
1074+ )
1075+ app = specs .AppDef ("test" , roles = [trainer_role ])
1076+ cfg = KubernetesOpts ({"queue" : "testqueue" })
1077+
1078+ with self .assertRaises (ValueError ) as ctx :
1079+ scheduler .submit_dryrun (app , cfg )
1080+
1081+ self .assertIn ("must be a dict or resource URI" , str (ctx .exception ))
1082+
9321083
9331084class KubernetesSchedulerNoImportTest (unittest .TestCase ):
9341085 """
0 commit comments