Skip to content

Commit 6e9f6a6

Browse files
committed
fix
Signed-off-by: Hemil Desai <[email protected]>
1 parent 41d58e0 commit 6e9f6a6

File tree

1 file changed

+364
-0
lines changed

1 file changed

+364
-0
lines changed

test/run/ray/test_slurm_ray_request.py

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,3 +964,367 @@ def test_heterogeneous_artifact(
964964
expected_script = f.read()
965965

966966
assert generated_script.strip() == expected_script.strip()
967+
968+
def test_heterogeneous_with_het_group_indices(self):
969+
"""Test het job with explicit het_group_indices for final_group_index calculation."""
970+
from unittest.mock import Mock
971+
972+
executor = SlurmExecutor(
973+
account="test_account",
974+
partition="gpu",
975+
heterogeneous=True,
976+
)
977+
executor.run_as_group = True
978+
# Create resource groups with explicit het_group_indices
979+
executor.resource_group = [
980+
SlurmExecutor.ResourceRequest(
981+
packager=Mock(),
982+
nodes=2,
983+
ntasks_per_node=8,
984+
gpus_per_node=8,
985+
container_image="gpu_image",
986+
container_mounts=["/data:/data"],
987+
het_group_index=0,
988+
),
989+
SlurmExecutor.ResourceRequest(
990+
packager=Mock(),
991+
nodes=1,
992+
ntasks_per_node=1,
993+
gpus_per_node=0,
994+
container_image="cpu_image",
995+
container_mounts=["/data:/data"],
996+
het_group_index=2, # Non-sequential index
997+
),
998+
]
999+
executor.tunnel = Mock(spec=SSHTunnel)
1000+
executor.tunnel.job_dir = "/tmp/test_jobs"
1001+
1002+
request = SlurmRayRequest(
1003+
name="test-ray-het-cluster",
1004+
cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
1005+
template_name="ray.sub.j2",
1006+
executor=executor,
1007+
launch_cmd=["sbatch", "--parsable"],
1008+
)
1009+
1010+
script = request.materialize()
1011+
1012+
# Should have het job structure
1013+
assert "#SBATCH hetjob" in script
1014+
# Should have both het group hostnames
1015+
assert "het_group_host_0" in script
1016+
assert "het_group_host_1" in script
1017+
1018+
def test_heterogeneous_duplicate_het_group_index_skipped(self):
1019+
"""Test that duplicate het_group_index ResourceRequests are skipped."""
1020+
from unittest.mock import Mock
1021+
1022+
executor = SlurmExecutor(
1023+
account="test_account",
1024+
partition="gpu",
1025+
heterogeneous=True,
1026+
)
1027+
executor.run_as_group = True
1028+
# Create resource groups where two share the same het_group_index
1029+
executor.resource_group = [
1030+
SlurmExecutor.ResourceRequest(
1031+
packager=Mock(),
1032+
nodes=2,
1033+
ntasks_per_node=8,
1034+
gpus_per_node=8,
1035+
container_image="gpu_image",
1036+
container_mounts=["/data:/data"],
1037+
het_group_index=0,
1038+
),
1039+
SlurmExecutor.ResourceRequest(
1040+
packager=Mock(),
1041+
nodes=2,
1042+
ntasks_per_node=8,
1043+
gpus_per_node=8,
1044+
container_image="gpu_image2",
1045+
container_mounts=["/data:/data"],
1046+
het_group_index=0, # Same as previous - should be skipped
1047+
),
1048+
SlurmExecutor.ResourceRequest(
1049+
packager=Mock(),
1050+
nodes=1,
1051+
ntasks_per_node=1,
1052+
gpus_per_node=0,
1053+
container_image="cpu_image",
1054+
container_mounts=["/data:/data"],
1055+
het_group_index=1,
1056+
),
1057+
]
1058+
executor.tunnel = Mock(spec=SSHTunnel)
1059+
executor.tunnel.job_dir = "/tmp/test_jobs"
1060+
1061+
request = SlurmRayRequest(
1062+
name="test-ray-het-cluster",
1063+
cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
1064+
template_name="ray.sub.j2",
1065+
executor=executor,
1066+
launch_cmd=["sbatch", "--parsable"],
1067+
)
1068+
1069+
script = request.materialize()
1070+
1071+
# Should only have one #SBATCH hetjob separator (2 het groups, not 3)
1072+
assert script.count("#SBATCH hetjob") == 1
1073+
# Should have het group hostnames for each unique het group
1074+
assert "het_group_host_0" in script
1075+
assert "het_group_host_1" in script
1076+
1077+
def test_heterogeneous_with_gpus_per_task(self):
1078+
"""Test het job with gpus_per_task set in ResourceRequest."""
1079+
from unittest.mock import Mock
1080+
1081+
executor = SlurmExecutor(
1082+
account="test_account",
1083+
partition="gpu",
1084+
heterogeneous=True,
1085+
)
1086+
executor.run_as_group = True
1087+
executor.resource_group = [
1088+
SlurmExecutor.ResourceRequest(
1089+
packager=Mock(),
1090+
nodes=1,
1091+
ntasks_per_node=8,
1092+
gpus_per_node=8,
1093+
gpus_per_task=1, # Explicit gpus_per_task
1094+
container_image="gpu_image",
1095+
container_mounts=["/data:/data"],
1096+
het_group_index=0,
1097+
),
1098+
SlurmExecutor.ResourceRequest(
1099+
packager=Mock(),
1100+
nodes=1,
1101+
ntasks_per_node=1,
1102+
gpus_per_node=0,
1103+
container_image="cpu_image",
1104+
container_mounts=["/data:/data"],
1105+
het_group_index=1,
1106+
),
1107+
]
1108+
executor.tunnel = Mock(spec=SSHTunnel)
1109+
executor.tunnel.job_dir = "/tmp/test_jobs"
1110+
1111+
request = SlurmRayRequest(
1112+
name="test-ray-het-cluster",
1113+
cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
1114+
template_name="ray.sub.j2",
1115+
executor=executor,
1116+
launch_cmd=["sbatch", "--parsable"],
1117+
)
1118+
1119+
script = request.materialize()
1120+
1121+
# First het group should have gpus-per-task
1122+
lines = script.split("\n")
1123+
het_job_idx = None
1124+
for i, line in enumerate(lines):
1125+
if "#SBATCH hetjob" in line:
1126+
het_job_idx = i
1127+
break
1128+
1129+
before_hetjob = "\n".join(lines[:het_job_idx])
1130+
assert "#SBATCH --gpus-per-task=1" in before_hetjob
1131+
1132+
def test_heterogeneous_with_separate_stderr(self):
1133+
"""Test het job with stderr_to_stdout=False generates error paths."""
1134+
from unittest.mock import Mock
1135+
1136+
executor = SlurmExecutor(
1137+
account="test_account",
1138+
partition="gpu",
1139+
heterogeneous=True,
1140+
)
1141+
executor.stderr_to_stdout = False # Separate stderr
1142+
executor.run_as_group = True
1143+
executor.resource_group = [
1144+
SlurmExecutor.ResourceRequest(
1145+
packager=Mock(),
1146+
nodes=2,
1147+
ntasks_per_node=8,
1148+
gpus_per_node=8,
1149+
container_image="gpu_image",
1150+
container_mounts=["/data:/data"],
1151+
het_group_index=0,
1152+
),
1153+
SlurmExecutor.ResourceRequest(
1154+
packager=Mock(),
1155+
nodes=1,
1156+
ntasks_per_node=1,
1157+
gpus_per_node=0,
1158+
container_image="cpu_image",
1159+
container_mounts=["/data:/data"],
1160+
het_group_index=1,
1161+
),
1162+
]
1163+
executor.tunnel = Mock(spec=SSHTunnel)
1164+
executor.tunnel.job_dir = "/tmp/test_jobs"
1165+
1166+
request = SlurmRayRequest(
1167+
name="test-ray-het-cluster",
1168+
cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
1169+
template_name="ray.sub.j2",
1170+
executor=executor,
1171+
launch_cmd=["sbatch", "--parsable"],
1172+
)
1173+
1174+
script = request.materialize()
1175+
1176+
# Should have separate error output paths for each het group
1177+
assert "#SBATCH --error=" in script
1178+
1179+
def test_heterogeneous_command_groups_without_het_group_index(self):
1180+
"""Test het command groups fallback to idx when het_group_index is None."""
1181+
from unittest.mock import Mock
1182+
1183+
executor = SlurmExecutor(
1184+
account="test_account",
1185+
heterogeneous=True,
1186+
)
1187+
executor.run_as_group = True
1188+
# Resource groups WITHOUT het_group_index set
1189+
executor.resource_group = [
1190+
SlurmExecutor.ResourceRequest(
1191+
packager=Mock(),
1192+
nodes=1,
1193+
ntasks_per_node=8,
1194+
gpus_per_node=8,
1195+
container_image="image1",
1196+
container_mounts=["/data:/data"],
1197+
# het_group_index not set - should fall back to idx
1198+
),
1199+
SlurmExecutor.ResourceRequest(
1200+
packager=Mock(),
1201+
nodes=1,
1202+
ntasks_per_node=1,
1203+
gpus_per_node=0,
1204+
container_image="image2",
1205+
container_mounts=["/data:/data"],
1206+
# het_group_index not set - should fall back to idx
1207+
),
1208+
]
1209+
executor.tunnel = Mock(spec=SSHTunnel)
1210+
executor.tunnel.job_dir = "/tmp/test_jobs"
1211+
1212+
request = SlurmRayRequest(
1213+
name="test-ray-het-cluster",
1214+
cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
1215+
template_name="ray.sub.j2",
1216+
executor=executor,
1217+
command_groups=[["cmd0"], ["cmd1"]],
1218+
launch_cmd=["sbatch", "--parsable"],
1219+
)
1220+
1221+
script = request.materialize()
1222+
1223+
# Should have het-group flags using idx fallback
1224+
assert "--het-group=1" in script # command_groups[1] uses het-group=1 (idx fallback)
1225+
1226+
def test_heterogeneous_without_run_as_group(self):
1227+
"""Test het job without run_as_group does not add het-group flags to commands."""
1228+
from unittest.mock import Mock
1229+
1230+
executor = SlurmExecutor(
1231+
account="test_account",
1232+
heterogeneous=True,
1233+
)
1234+
# run_as_group NOT set
1235+
executor.resource_group = [
1236+
SlurmExecutor.ResourceRequest(
1237+
packager=Mock(),
1238+
nodes=1,
1239+
ntasks_per_node=8,
1240+
gpus_per_node=8,
1241+
container_image="image1",
1242+
container_mounts=["/data:/data"],
1243+
het_group_index=0,
1244+
),
1245+
SlurmExecutor.ResourceRequest(
1246+
packager=Mock(),
1247+
nodes=1,
1248+
ntasks_per_node=1,
1249+
gpus_per_node=0,
1250+
container_image="image2",
1251+
container_mounts=["/data:/data"],
1252+
het_group_index=1,
1253+
),
1254+
]
1255+
executor.tunnel = Mock(spec=SSHTunnel)
1256+
executor.tunnel.job_dir = "/tmp/test_jobs"
1257+
1258+
request = SlurmRayRequest(
1259+
name="test-ray-het-cluster",
1260+
cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
1261+
template_name="ray.sub.j2",
1262+
executor=executor,
1263+
command_groups=[["cmd0"], ["cmd1"]],
1264+
launch_cmd=["sbatch", "--parsable"],
1265+
)
1266+
1267+
script = request.materialize()
1268+
1269+
# SBATCH het job structure should still exist
1270+
assert "#SBATCH hetjob" in script
1271+
# But command groups should NOT have --het-group flags (run_as_group not set)
1272+
# The overlap srun commands should not have --het-group=1
1273+
# Find the overlap srun command
1274+
lines = script.split("\n")
1275+
overlap_srun_lines = [line for line in lines if "overlap" in line and "srun" in line]
1276+
for line in overlap_srun_lines:
1277+
# These should NOT have --het-group since run_as_group is not set
1278+
if "cmd1" in line:
1279+
assert "--het-group=1" not in line
1280+
1281+
def test_heterogeneous_mismatched_command_groups_length(self):
1282+
"""Test het job when command_groups length doesn't match resource_group length."""
1283+
from unittest.mock import Mock
1284+
1285+
executor = SlurmExecutor(
1286+
account="test_account",
1287+
heterogeneous=True,
1288+
)
1289+
executor.run_as_group = True
1290+
executor.resource_group = [
1291+
SlurmExecutor.ResourceRequest(
1292+
packager=Mock(),
1293+
nodes=1,
1294+
ntasks_per_node=8,
1295+
gpus_per_node=8,
1296+
container_image="image1",
1297+
container_mounts=["/data:/data"],
1298+
het_group_index=0,
1299+
),
1300+
SlurmExecutor.ResourceRequest(
1301+
packager=Mock(),
1302+
nodes=1,
1303+
ntasks_per_node=1,
1304+
gpus_per_node=0,
1305+
container_image="image2",
1306+
container_mounts=["/data:/data"],
1307+
het_group_index=1,
1308+
),
1309+
]
1310+
executor.tunnel = Mock(spec=SSHTunnel)
1311+
executor.tunnel.job_dir = "/tmp/test_jobs"
1312+
1313+
# 3 command groups but only 2 resource groups - mismatched
1314+
request = SlurmRayRequest(
1315+
name="test-ray-het-cluster",
1316+
cluster_dir="/tmp/test_jobs/test-ray-het-cluster",
1317+
template_name="ray.sub.j2",
1318+
executor=executor,
1319+
command_groups=[["cmd0"], ["cmd1"], ["cmd2"]],
1320+
launch_cmd=["sbatch", "--parsable"],
1321+
)
1322+
1323+
script = request.materialize()
1324+
1325+
# Should still generate script but WITHOUT het-group flags
1326+
# (because lengths don't match)
1327+
assert "#SBATCH hetjob" in script
1328+
# Overlap commands should NOT have --het-group flags
1329+
assert "--het-group=1" not in script
1330+
assert "--het-group=2" not in script

0 commit comments

Comments
 (0)