
Commit 2ccf1c9

feat: add het-job support for ray slurm (#407)
1 parent a610e89 commit 2ccf1c9

6 files changed: 1372 additions & 16 deletions


nemo_run/run/ray/slurm.py

Lines changed: 75 additions & 3 deletions
@@ -169,9 +169,67 @@ def materialize(self) -> str:
             parameters.update(self.executor.additional_parameters)

         sbatch_flags = []
-        assert not self.executor.heterogeneous, "heterogeneous is not supported for ray clusters"
-        for k in sorted(parameters):
-            sbatch_flags.append(_as_sbatch_flag(k, parameters[k]))
+        if self.executor.heterogeneous:
+            # Validate resource_group exists
+            assert self.executor.resource_group, "heterogeneous requires resource_group to be set"
+            assert len(self.executor.resource_group) > 0, "resource_group must not be empty"
+
+            # Validate het-group-0 has at least 1 node for Ray head
+            head_group = self.executor.resource_group[0]
+            assert head_group.nodes >= 1, "het-group-0 must have at least 1 node for Ray head"
+
+            # Determine the final het group index (for hetjob separator placement)
+            final_group_index = len(self.executor.resource_group) - 1
+            if self.executor.het_group_indices:
+                final_group_index = self.executor.het_group_indices.index(
+                    max(self.executor.het_group_indices)
+                )
+
+            # Generate SBATCH blocks for each het group
+            for i, resource_req in enumerate(self.executor.resource_group):
+                # Skip duplicate het groups (when het_group_index is shared)
+                if resource_req.het_group_index is not None:
+                    if (
+                        i > 0
+                        and resource_req.het_group_index
+                        == self.executor.resource_group[i - 1].het_group_index
+                    ):
+                        continue
+
+                # Build het-specific parameters
+                het_parameters = parameters.copy()
+                het_parameters.update(
+                    {
+                        "nodes": resource_req.nodes,
+                        "ntasks_per_node": resource_req.ntasks_per_node,
+                    }
+                )
+
+                # Update job name to include het group index
+                het_parameters["job_name"] = f"{job_details.job_name}-{i}"
+
+                # Only update GPU parameters if they're explicitly set in resource_req
+                if resource_req.gpus_per_node is not None:
+                    het_parameters["gpus_per_node"] = resource_req.gpus_per_node
+                if resource_req.gpus_per_task is not None:
+                    het_parameters["gpus_per_task"] = resource_req.gpus_per_task
+
+                # Update output/error paths to include het group index
+                het_parameters["output"] = parameters["output"].replace("%t", str(i))
+                if "error" in het_parameters:
+                    het_parameters["error"] = parameters["error"].replace("%t", str(i))
+
+                # Generate SBATCH flags for this het group
+                for k in sorted(het_parameters):
+                    sbatch_flags.append(_as_sbatch_flag(k, het_parameters[k]))
+
+                # Add hetjob separator (except after last group)
+                if i != final_group_index:
+                    sbatch_flags.append("#SBATCH hetjob")
+        else:
+            # Non-heterogeneous: use existing logic
+            for k in sorted(parameters):
+                sbatch_flags.append(_as_sbatch_flag(k, parameters[k]))

         if self.executor.dependencies:
             slurm_deps = self.executor.parse_deps()
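
The hetjob emission is easiest to see on a rendered header. Below is a minimal, self-contained sketch of the same pattern; the ResourceReq dataclass and the hard-coded flag strings are stand-ins for illustration only, not NeMo-Run's real resource classes or its _as_sbatch_flag conversion.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ResourceReq:
    # Hypothetical stand-in for one resource_group entry; not NeMo-Run's real class.
    nodes: int
    ntasks_per_node: int
    het_group_index: Optional[int] = None


def emit_sbatch_header(resource_group: list, job_name: str) -> str:
    # Same shape as the loop above: one #SBATCH block per het group,
    # separated by a bare "#SBATCH hetjob" line (Slurm's component delimiter).
    lines = []
    final_group_index = len(resource_group) - 1
    for i, req in enumerate(resource_group):
        if (
            req.het_group_index is not None
            and i > 0
            and req.het_group_index == resource_group[i - 1].het_group_index
        ):
            continue  # groups sharing a het_group_index collapse into one block
        lines.append(f"#SBATCH --job-name={job_name}-{i}")
        lines.append(f"#SBATCH --nodes={req.nodes}")
        lines.append(f"#SBATCH --ntasks-per-node={req.ntasks_per_node}")
        if i != final_group_index:
            lines.append("#SBATCH hetjob")
    return "\n".join(lines)


# A GPU head/worker group plus a CPU-only group:
print(emit_sbatch_header([ResourceReq(2, 8), ResourceReq(1, 1)], "ray-cluster"))

For two groups this prints two #SBATCH blocks separated by a single "#SBATCH hetjob" line, which is how Slurm delimits the components of a heterogeneous job.
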
@@ -236,6 +294,8 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
             "command_workdir": self.workdir,
             "gres_specification": get_gres_specification(),
             "ray_log_prefix": ray_log_prefix,
+            "heterogeneous": self.executor.heterogeneous,
+            "resource_group": self.executor.resource_group if self.executor.heterogeneous else [],
         }

         if self.command_groups:
@@ -271,12 +331,24 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
                     os.path.join(logs_dir, f"{ray_log_prefix}overlap-{idx}.err"),
                 ]

+                # Determine het group for this command (if heterogeneous)
+                het_group_flag = []
+                if self.executor.heterogeneous and self.executor.run_as_group:
+                    if len(self.executor.resource_group) == len(self.command_groups):
+                        # Use resource_group mapping
+                        req = self.executor.resource_group[idx]
+                        het_group_num = (
+                            req.het_group_index if req.het_group_index is not None else idx
+                        )
+                        het_group_flag = [f"--het-group={het_group_num}"]
+
                 srun_cmd = " ".join(
                     list(
                         map(
                             lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
                             [
                                 "srun",
+                                *het_group_flag,
                                 "--output",
                                 noquote(stdout_path),
                                 *stderr_flags,
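
The mapping from command-group position to --het-group is the part that is easiest to get wrong, so here is a small sketch of the same selection logic. The function and parameter names are illustrative, not NeMo-Run's API; in the real code the indices come from the executor's resource_group entries.

from typing import List, Optional


def srun_het_group_flag(idx: int, het_group_indices: List[Optional[int]], n_command_groups: int) -> List[str]:
    # Illustrative mirror of the mapping above: one command group per resource group,
    # each srun pinned to its group's het_group_index (or to its own position if unset).
    if len(het_group_indices) != n_command_groups:
        return []  # no 1:1 mapping, so emit no --het-group flag at all
    group = het_group_indices[idx] if het_group_indices[idx] is not None else idx
    return [f"--het-group={group}"]


# Three command groups where the last two share Slurm het group 1:
for idx in range(3):
    print(idx, srun_het_group_flag(idx, [0, 1, 1], 3))

Falling back to the positional index keeps the common case (one command group per het group, no explicit het_group_index) working without extra configuration, while the length check means a mismatched layout degrades to plain srun calls rather than mis-pinned ones.
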

nemo_run/run/ray/templates/ray.sub.j2

Lines changed: 14 additions & 7 deletions
@@ -170,6 +170,13 @@ head_node=${nodes_array[0]}
 head_node_ip=${ip_addresses_array[0]}

 ip_head=$head_node_ip:$PORT
+{%- if heterogeneous %}
+
+# Extract het group hostnames for heterogeneous jobs
+{% for i in range(resource_group|length) %}
+het_group_host_{{i}}=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{{i}} | head -n1)
+{%- endfor %}
+{%- endif %}

 {%- if setup_lines %}
 {{setup_lines}}
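
To see what this new template block produces, the fragment can be rendered on its own with Jinja2. This is only a sketch: the fragment is adapted from the hunk above, and since the loop only uses the length of resource_group, plain placeholders stand in for the real resource objects.

from jinja2 import Template

# Fragment adapted from the hunk above; only `heterogeneous` and the length of
# `resource_group` matter for rendering, so placeholders suffice here.
fragment = (
    "{%- if heterogeneous %}\n"
    "# Extract het group hostnames for heterogeneous jobs\n"
    "{% for i in range(resource_group|length) %}\n"
    "het_group_host_{{i}}=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{{i}} | head -n1)\n"
    "{%- endfor %}\n"
    "{%- endif %}"
)
print(Template(fragment).render(heterogeneous=True, resource_group=[None, None]))

With two groups this yields one het_group_host_<i>= assignment per component, each reading the first hostname from $SLURM_JOB_NODELIST_HET_GROUP_<i>, the per-component node list Slurm exports for heterogeneous jobs; with heterogeneous unset the block disappears entirely.
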
@@ -279,12 +286,12 @@ touch $LOG_DIR/ENDED
     exit 1
 EOF
 )
-srun {{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" &
+srun {% if heterogeneous %}--het-group=0 {% endif %}{{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" &
 SRUN_PIDS["ray-head"]=$!

 # Wait for the head node container to start and for Ray to be ready
 elapsed_time=0
-while ! (srun --overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do
+while ! (srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do
     if [[ $elapsed_time -ge $RAY_HEAD_START_TIMEOUT ]]; then
         echo "[ERROR][$(date)] Ray head node failed to start within $RAY_HEAD_START_TIMEOUT seconds. Exiting..."
         touch $LOG_DIR/ENDED
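
The hunks below apply the same conditional prefix to every srun that must land in the head node's allocation (worker launches, ray status polling, the job command, and the generated attach script), pinning them to het group 0. When heterogeneous is false the prefix renders to nothing, so homogeneous jobs are unchanged. A quick standalone check of how the prefix renders (the trailing "--overlap ..." is just placeholder text):

from jinja2 import Template

prefix = Template("srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap ...")
print(prefix.render(heterogeneous=True))   # -> srun --het-group=0 --overlap ...
print(prefix.render(heterogeneous=False))  # -> srun --overlap ...
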
@@ -368,7 +375,7 @@ EOF
     if [[ $i -eq 0 ]]; then
         OVERLAP_HEAD_AND_WORKER_ARG="--overlap"
     fi
-    srun {{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/{{ ray_log_prefix }}worker-$i.log bash -x -c "$worker_cmd" &
+    srun {% if heterogeneous %}--het-group=0 {% endif %}{{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/{{ ray_log_prefix }}worker-$i.log bash -x -c "$worker_cmd" &
     SRUN_PIDS["ray-worker-$i"]=$!
     sleep 3
 done
@@ -377,7 +384,7 @@ done
 # Before we launch a job on this cluster we need to make sure that the bringup is complete
 # We do so by querying the number of worker_units in the ray cluster and asserting = NUM_ACTORS
 extract_worker_units() {
-    status_output=$(srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head)
+    status_output=$(srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head)
     if echo "$status_output" | grep -q "worker_units"; then
         worker_units=$(echo "$status_output" | grep "worker_units" | awk -F'[/. ]' '{print $4}')
         echo $worker_units
@@ -447,18 +454,18 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}"
 COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}

 if [[ -n "$COMMAND" ]]; then
-    srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
+    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
 else
     echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
     cat <<EOF >$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh
 # No args launches on the head node
 WORKER_NUM=\${1:-}
 if [[ -z "\$WORKER_NUM" ]]; then
     # Empty means we are on the head node
-    srun --no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash
+    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash
 else
     nodes_array=($nodes)
-    srun --no-container-mount-home {%- if gres_specification %}{{gres_specification}}{% endif %} -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash
+    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home {%- if gres_specification %}{{gres_specification}}{% endif %} -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash
 fi
 EOF
 chmod +x $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh

nemo_run/run/torchx_backend/schedulers/slurm.py

Lines changed: 11 additions & 0 deletions
@@ -113,6 +113,17 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]:  # t
                 srun_cmd = [role.entrypoint] + role.args
                 srun_cmds.append([" ".join(srun_cmd)])

+            # For heterogeneous jobs, ensure run_as_group is set for command group mapping
+            if executor.heterogeneous and executor.resource_group:
+                executor.run_as_group = True
+                # Validate that command groups align with resource groups
+                if len(srun_cmds) != len(executor.resource_group):
+                    log.warning(
+                        f"Heterogeneous job has {len(executor.resource_group)} resource groups "
+                        f"but {len(srun_cmds)} roles. Command groups should match resource groups "
+                        f"for proper het-group mapping."
+                    )
+
             command = [app.roles[0].entrypoint] + app.roles[0].args
             # Allow selecting Ray template via environment variable
             ray_template_name = os.environ.get("NEMO_RUN_SLURM_RAY_TEMPLATE", "ray.sub.j2")
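
The warning fires whenever roles and resource groups cannot be paired one-to-one, because the --het-group assignment in slurm.py above indexes resource_group by command-group position. A tiny shape check with made-up role and group names, just to show the condition:

# Hypothetical inputs, mirroring the alignment check above.
resource_group = ["het-group-0 (head + GPU workers)", "het-group-1 (CPU workers)"]
srun_cmds = [["python driver.py"]]  # only one role/command group

if len(srun_cmds) != len(resource_group):
    print(
        f"Heterogeneous job has {len(resource_group)} resource groups "
        f"but {len(srun_cmds)} roles; per the mapping above, no --het-group flags are emitted."
    )
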
