@@ -964,3 +964,367 @@ def test_heterogeneous_artifact(
964964 expected_script = f .read ()
965965
966966 assert generated_script .strip () == expected_script .strip ()
967+
def test_heterogeneous_with_het_group_indices(self):
    """Render a het Ray job whose resource groups carry explicit het_group_index values.

    Two resource groups are configured with indices 0 and 2 (deliberately
    non-sequential). The rendered script is expected to contain the
    ``#SBATCH hetjob`` separator and to mention both ``het_group_host_0``
    and ``het_group_host_1``.
    """
    from unittest.mock import Mock

    slurm_executor = SlurmExecutor(
        account="test_account",
        partition="gpu",
        heterogeneous=True,
    )
    slurm_executor.run_as_group = True

    # GPU group pinned to het group 0.
    gpu_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=2,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="gpu_image",
        container_mounts=["/data:/data"],
        het_group_index=0,
    )
    # CPU group uses index 2, leaving a gap after index 0.
    cpu_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        gpus_per_node=0,
        container_image="cpu_image",
        container_mounts=["/data:/data"],
        het_group_index=2,
    )
    slurm_executor.resource_group = [gpu_request, cpu_request]

    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"
    slurm_executor.tunnel = ssh_tunnel

    ray_request = SlurmRayRequest(
        name="test-ray-het-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
        template_name="ray.sub.j2",
        executor=slurm_executor,
        launch_cmd=["sbatch", "--parsable"],
    )

    rendered = ray_request.materialize()

    # Het job structure plus a hostname entry for both het groups.
    for expected in ("#SBATCH hetjob", "het_group_host_0", "het_group_host_1"):
        assert expected in rendered
1017+
def test_heterogeneous_duplicate_het_group_index_skipped(self):
    """Check that resource requests sharing a het_group_index collapse into one het group.

    Three resource requests are supplied, two of which use index 0. The
    rendered script is expected to contain exactly one ``#SBATCH hetjob``
    separator (i.e. two het groups) and hostname entries for groups 0 and 1.
    """
    from unittest.mock import Mock

    slurm_executor = SlurmExecutor(
        account="test_account",
        partition="gpu",
        heterogeneous=True,
    )
    slurm_executor.run_as_group = True

    first_gpu_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=2,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="gpu_image",
        container_mounts=["/data:/data"],
        het_group_index=0,
    )
    # Reuses het group 0, so the test expects it to be skipped.
    duplicate_gpu_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=2,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="gpu_image2",
        container_mounts=["/data:/data"],
        het_group_index=0,
    )
    cpu_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        gpus_per_node=0,
        container_image="cpu_image",
        container_mounts=["/data:/data"],
        het_group_index=1,
    )
    slurm_executor.resource_group = [
        first_gpu_request,
        duplicate_gpu_request,
        cpu_request,
    ]

    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"
    slurm_executor.tunnel = ssh_tunnel

    ray_request = SlurmRayRequest(
        name="test-ray-het-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
        template_name="ray.sub.j2",
        executor=slurm_executor,
        launch_cmd=["sbatch", "--parsable"],
    )

    rendered = ray_request.materialize()

    # One separator means two het groups rather than three.
    separator_count = rendered.count("#SBATCH hetjob")
    assert separator_count == 1
    # Each unique het group gets its own hostname entry.
    for hostname_var in ("het_group_host_0", "het_group_host_1"):
        assert hostname_var in rendered
1076+
def test_heterogeneous_with_gpus_per_task(self):
    """Check that gpus_per_task on the first resource group reaches its SBATCH header.

    The first resource group sets ``gpus_per_task=1``; the second does not.
    The test looks only at the part of the rendered script that precedes the
    first ``#SBATCH hetjob`` separator and expects
    ``#SBATCH --gpus-per-task=1`` to appear there.
    """
    from unittest.mock import Mock

    executor = SlurmExecutor(
        account="test_account",
        partition="gpu",
        heterogeneous=True,
    )
    executor.run_as_group = True
    executor.resource_group = [
        SlurmExecutor.ResourceRequest(
            packager=Mock(),
            nodes=1,
            ntasks_per_node=8,
            gpus_per_node=8,
            gpus_per_task=1,  # Explicit gpus_per_task
            container_image="gpu_image",
            container_mounts=["/data:/data"],
            het_group_index=0,
        ),
        SlurmExecutor.ResourceRequest(
            packager=Mock(),
            nodes=1,
            ntasks_per_node=1,
            gpus_per_node=0,
            container_image="cpu_image",
            container_mounts=["/data:/data"],
            het_group_index=1,
        ),
    ]
    executor.tunnel = Mock(spec=SSHTunnel)
    executor.tunnel.job_dir = "/tmp/test_jobs"

    request = SlurmRayRequest(
        name="test-ray-het-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
        template_name="ray.sub.j2",
        executor=executor,
        launch_cmd=["sbatch", "--parsable"],
    )

    script = request.materialize()

    # Locate the first het job separator; everything above it belongs to
    # the first het group's SBATCH header.
    lines = script.split("\n")
    het_job_idx = next(
        (i for i, line in enumerate(lines) if "#SBATCH hetjob" in line),
        None,
    )
    # Without this guard a missing separator would leave the index as None,
    # and lines[:None] would cover the whole script, so the check below
    # could pass on a flag emitted for a different het group.
    assert het_job_idx is not None, "expected a '#SBATCH hetjob' separator in the script"

    before_hetjob = "\n".join(lines[:het_job_idx])
    assert "#SBATCH --gpus-per-task=1" in before_hetjob
1131+
def test_heterogeneous_with_separate_stderr(self):
    """Check that a het job with stderr_to_stdout disabled emits an error-path directive.

    Note that the assertion only requires ``#SBATCH --error=`` to occur at
    least once in the rendered script; it does not count occurrences per
    het group.
    """
    from unittest.mock import Mock

    slurm_executor = SlurmExecutor(
        account="test_account",
        partition="gpu",
        heterogeneous=True,
    )
    # Keep stderr in its own file instead of merging it into stdout.
    slurm_executor.stderr_to_stdout = False
    slurm_executor.run_as_group = True

    gpu_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=2,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="gpu_image",
        container_mounts=["/data:/data"],
        het_group_index=0,
    )
    cpu_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        gpus_per_node=0,
        container_image="cpu_image",
        container_mounts=["/data:/data"],
        het_group_index=1,
    )
    slurm_executor.resource_group = [gpu_request, cpu_request]

    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"
    slurm_executor.tunnel = ssh_tunnel

    ray_request = SlurmRayRequest(
        name="test-ray-het-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
        template_name="ray.sub.j2",
        executor=slurm_executor,
        launch_cmd=["sbatch", "--parsable"],
    )

    rendered = ray_request.materialize()

    # A separate error output path must be present in the script.
    error_directive = "#SBATCH --error="
    assert error_directive in rendered
1178+
def test_heterogeneous_command_groups_without_het_group_index(self):
    """Check the positional fallback when resource groups omit het_group_index.

    Neither resource request sets ``het_group_index``. With two command
    groups supplied, the rendered script is expected to contain
    ``--het-group=1``, i.e. the second command group falls back to its
    position in the list.
    """
    from unittest.mock import Mock

    slurm_executor = SlurmExecutor(
        account="test_account",
        heterogeneous=True,
    )
    slurm_executor.run_as_group = True

    # Both requests deliberately leave het_group_index unset.
    first_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="image1",
        container_mounts=["/data:/data"],
    )
    second_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        gpus_per_node=0,
        container_image="image2",
        container_mounts=["/data:/data"],
    )
    slurm_executor.resource_group = [first_request, second_request]

    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"
    slurm_executor.tunnel = ssh_tunnel

    ray_request = SlurmRayRequest(
        name="test-ray-het-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
        template_name="ray.sub.j2",
        executor=slurm_executor,
        command_groups=[["cmd0"], ["cmd1"]],
        launch_cmd=["sbatch", "--parsable"],
    )

    rendered = ray_request.materialize()

    # command_groups[1] should be tagged with het group 1 via the index fallback.
    fallback_flag = "--het-group=1"
    assert fallback_flag in rendered
1225+
def test_heterogeneous_without_run_as_group(self):
    """Check that command lines get no het-group flag when run_as_group is left unset.

    The SBATCH het job structure must still be rendered, but any overlap
    ``srun`` line that mentions ``cmd1`` must not carry ``--het-group=1``.
    """
    from unittest.mock import Mock

    slurm_executor = SlurmExecutor(
        account="test_account",
        heterogeneous=True,
    )
    # run_as_group is intentionally not assigned on the executor.
    first_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="image1",
        container_mounts=["/data:/data"],
        het_group_index=0,
    )
    second_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        gpus_per_node=0,
        container_image="image2",
        container_mounts=["/data:/data"],
        het_group_index=1,
    )
    slurm_executor.resource_group = [first_request, second_request]

    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"
    slurm_executor.tunnel = ssh_tunnel

    ray_request = SlurmRayRequest(
        name="test-ray-het-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
        template_name="ray.sub.j2",
        executor=slurm_executor,
        command_groups=[["cmd0"], ["cmd1"]],
        launch_cmd=["sbatch", "--parsable"],
    )

    rendered = ray_request.materialize()

    # The het job header is emitted regardless of run_as_group.
    assert "#SBATCH hetjob" in rendered

    # Inspect only overlap srun lines that launch cmd1; none may be tagged
    # with het group 1. If no such line exists, nothing is asserted here.
    for script_line in rendered.split("\n"):
        is_overlap_srun = "overlap" in script_line and "srun" in script_line
        if is_overlap_srun and "cmd1" in script_line:
            assert "--het-group=1" not in script_line
1280+
def test_heterogeneous_mismatched_command_groups_length(self):
    """Check rendering when there are more command groups than resource groups.

    Three command groups are paired with two resource groups. The script is
    still expected to contain the ``#SBATCH hetjob`` separator, but neither
    ``--het-group=1`` nor ``--het-group=2`` may appear anywhere in it.
    """
    from unittest.mock import Mock

    slurm_executor = SlurmExecutor(
        account="test_account",
        heterogeneous=True,
    )
    slurm_executor.run_as_group = True

    first_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=8,
        gpus_per_node=8,
        container_image="image1",
        container_mounts=["/data:/data"],
        het_group_index=0,
    )
    second_request = SlurmExecutor.ResourceRequest(
        packager=Mock(),
        nodes=1,
        ntasks_per_node=1,
        gpus_per_node=0,
        container_image="image2",
        container_mounts=["/data:/data"],
        het_group_index=1,
    )
    slurm_executor.resource_group = [first_request, second_request]

    ssh_tunnel = Mock(spec=SSHTunnel)
    ssh_tunnel.job_dir = "/tmp/test_jobs"
    slurm_executor.tunnel = ssh_tunnel

    # Three command groups against two resource groups: lengths disagree.
    ray_request = SlurmRayRequest(
        name="test-ray-het-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
        template_name="ray.sub.j2",
        executor=slurm_executor,
        command_groups=[["cmd0"], ["cmd1"], ["cmd2"]],
        launch_cmd=["sbatch", "--parsable"],
    )

    rendered = ray_request.materialize()

    # The het job header is still produced.
    assert "#SBATCH hetjob" in rendered
    # Because the lengths differ, no command may be tagged with a het group.
    for forbidden_flag in ("--het-group=1", "--het-group=2"):
        assert forbidden_flag not in rendered
0 commit comments