78 changes: 75 additions & 3 deletions nemo_run/run/ray/slurm.py
@@ -169,9 +169,67 @@ def materialize(self) -> str:
             parameters.update(self.executor.additional_parameters)
 
         sbatch_flags = []
-        assert not self.executor.heterogeneous, "heterogeneous is not supported for ray clusters"
-        for k in sorted(parameters):
-            sbatch_flags.append(_as_sbatch_flag(k, parameters[k]))
+        if self.executor.heterogeneous:
+            # Validate resource_group exists
+            assert self.executor.resource_group, "heterogeneous requires resource_group to be set"
+            assert len(self.executor.resource_group) > 0, "resource_group must not be empty"
+
+            # Validate het-group-0 has at least 1 node for Ray head
+            head_group = self.executor.resource_group[0]
+            assert head_group.nodes >= 1, "het-group-0 must have at least 1 node for Ray head"
+
+            # Determine the final het group index (for hetjob separator placement)
+            final_group_index = len(self.executor.resource_group) - 1
+            if self.executor.het_group_indices:
+                final_group_index = self.executor.het_group_indices.index(
+                    max(self.executor.het_group_indices)
+                )
+
+            # Generate SBATCH blocks for each het group
+            for i, resource_req in enumerate(self.executor.resource_group):
+                # Skip duplicate het groups (when het_group_index is shared)
+                if resource_req.het_group_index is not None:
+                    if (
+                        i > 0
+                        and resource_req.het_group_index
+                        == self.executor.resource_group[i - 1].het_group_index
+                    ):
+                        continue
+
+                # Build het-specific parameters
+                het_parameters = parameters.copy()
+                het_parameters.update(
+                    {
+                        "nodes": resource_req.nodes,
+                        "ntasks_per_node": resource_req.ntasks_per_node,
+                    }
+                )
+
+                # Update job name to include het group index
+                het_parameters["job_name"] = f"{job_details.job_name}-{i}"
+
+                # Only update GPU parameters if they're explicitly set in resource_req
+                if resource_req.gpus_per_node is not None:
+                    het_parameters["gpus_per_node"] = resource_req.gpus_per_node
+                if resource_req.gpus_per_task is not None:
+                    het_parameters["gpus_per_task"] = resource_req.gpus_per_task
+
+                # Update output/error paths to include het group index
+                het_parameters["output"] = parameters["output"].replace("%t", str(i))
+                if "error" in het_parameters:
+                    het_parameters["error"] = parameters["error"].replace("%t", str(i))
+
+                # Generate SBATCH flags for this het group
+                for k in sorted(het_parameters):
+                    sbatch_flags.append(_as_sbatch_flag(k, het_parameters[k]))
+
+                # Add hetjob separator (except after last group)
+                if i != final_group_index:
+                    sbatch_flags.append("#SBATCH hetjob")
+        else:
+            # Non-heterogeneous: use existing logic
+            for k in sorted(parameters):
+                sbatch_flags.append(_as_sbatch_flag(k, parameters[k]))
 
         if self.executor.dependencies:
             slurm_deps = self.executor.parse_deps()
@@ -238,6 +296,8 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
             "command_workdir": self.workdir,
             "gres_specification": get_gres_specification(),
             "ray_log_prefix": ray_log_prefix,
+            "heterogeneous": self.executor.heterogeneous,
+            "resource_group": self.executor.resource_group if self.executor.heterogeneous else [],
         }
 
         if self.command_groups:
@@ -273,12 +333,24 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
                     os.path.join(logs_dir, f"{ray_log_prefix}overlap-{idx}.err"),
                 ]
 
+                # Determine het group for this command (if heterogeneous)
+                het_group_flag = []
+                if self.executor.heterogeneous and self.executor.run_as_group:
+                    if len(self.executor.resource_group) == len(self.command_groups):
+                        # Use resource_group mapping
+                        req = self.executor.resource_group[idx]
+                        het_group_num = (
+                            req.het_group_index if req.het_group_index is not None else idx
+                        )
+                        het_group_flag = [f"--het-group={het_group_num}"]
+
                 srun_cmd = " ".join(
                     list(
                         map(
                             lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg),
                             [
                                 "srun",
+                                *het_group_flag,
                                 "--output",
                                 noquote(stdout_path),
                                 *stderr_flags,
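
To illustrate the preamble this branch emits, here is a self-contained sketch of the per-group SBATCH layout. It uses plain dicts in place of the executor's resource_group entries and a simplified stand-in for `_as_sbatch_flag`, so the exact flag formatting is an assumption; only the grouping and the `#SBATCH hetjob` separator placement mirror the logic above.

```python
# Self-contained sketch of the het-group SBATCH layout (not the PR's API).
# _as_sbatch_flag here is a simplified stand-in for the real helper.
def _as_sbatch_flag(key: str, value) -> str:
    return f"#SBATCH --{key.replace('_', '-')}={value}"

def build_het_preamble(base: dict, groups: list[dict]) -> list[str]:
    flags = []
    final_group_index = len(groups) - 1
    for i, req in enumerate(groups):
        # Per-group copy of the shared parameters, overridden by the group's resources
        het = {**base, **req, "job_name": f"{base['job_name']}-{i}"}
        for k in sorted(het):
            flags.append(_as_sbatch_flag(k, het[k]))
        if i != final_group_index:
            flags.append("#SBATCH hetjob")  # separator between het groups
    return flags

print("\n".join(build_het_preamble(
    {"job_name": "ray-cluster", "partition": "gpu", "time": "01:00:00"},
    [
        {"nodes": 1, "ntasks_per_node": 1},                      # het-group-0: Ray head
        {"nodes": 4, "ntasks_per_node": 8, "gpus_per_node": 8},  # het-group-1: workers
    ],
)))
```

Running the sketch prints one block of `#SBATCH` directives per group, with a single `#SBATCH hetjob` line between group 0 (head) and group 1 (workers), which is the structure sbatch expects for a heterogeneous submission.
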
21 changes: 14 additions & 7 deletions nemo_run/run/ray/templates/ray.sub.j2
@@ -170,6 +170,13 @@ head_node=${nodes_array[0]}
 head_node_ip=${ip_addresses_array[0]}
 
 ip_head=$head_node_ip:$PORT
+{%- if heterogeneous %}
+
+# Extract het group hostnames for heterogeneous jobs
+{% for i in range(resource_group|length) %}
+het_group_host_{{i}}=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{{i}} | head -n1)
+{%- endfor %}
+{%- endif %}
 
 {%- if setup_lines %}
 {{setup_lines}}
@@ -279,12 +286,12 @@ touch $LOG_DIR/ENDED
     exit 1
 EOF
 )
-srun {{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" &
+srun {% if heterogeneous %}--het-group=0 {% endif %}{{ common_srun_args }} --container-name=ray-head --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}head.log bash -x -c "$head_cmd" &
 SRUN_PIDS["ray-head"]=$!
 
 # Wait for the head node container to start and for Ray to be ready
 elapsed_time=0
-while ! (srun --overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do
+while ! (srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --nodes=1 --ntasks=1 -w $head_node test -f $LOG_DIR/STARTED_RAY_HEAD && srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --container-name=ray-head --nodes=1 --ntasks=1 -w $head_node ray status --address $ip_head 2>/dev/null); do
     if [[ $elapsed_time -ge $RAY_HEAD_START_TIMEOUT ]]; then
         echo "[ERROR][$(date)] Ray head node failed to start within $RAY_HEAD_START_TIMEOUT seconds. Exiting..."
         touch $LOG_DIR/ENDED
@@ -368,7 +375,7 @@ EOF
     if [[ $i -eq 0 ]]; then
         OVERLAP_HEAD_AND_WORKER_ARG="--overlap"
     fi
-    srun {{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/{{ ray_log_prefix }}worker-$i.log bash -x -c "$worker_cmd" &
+    srun {% if heterogeneous %}--het-group=0 {% endif %}{{ common_srun_args }} ${OVERLAP_HEAD_AND_WORKER_ARG:-} --container-name=ray-worker-$i --exact --nodes=1 --ntasks=1 --cpus-per-task=$CPUS_PER_WORKER -w "$node_i" -o $LOG_DIR/{{ ray_log_prefix }}worker-$i.log bash -x -c "$worker_cmd" &
     SRUN_PIDS["ray-worker-$i"]=$!
     sleep 3
 done
@@ -377,7 +384,7 @@ done
 # Before we launch a job on this cluster we need to make sure that the bringup is complete
 # We do so by querying the number of worker_units in the ray cluster and asserting = NUM_ACTORS
 extract_worker_units() {
-    status_output=$(srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head)
+    status_output=$(srun {% if heterogeneous %}--het-group=0 {% endif %}--overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head)
     if echo "$status_output" | grep -q "worker_units"; then
         worker_units=$(echo "$status_output" | grep "worker_units" | awk -F'[/. ]' '{print $4}')
         echo $worker_units
@@ -447,18 +454,18 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}"
 COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}
 
 if [[ -n "$COMMAND" ]]; then
-    srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
+    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
 else
     echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
     cat <<EOF >$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh
 # No args launches on the head node
 WORKER_NUM=\${1:-}
 if [[ -z "\$WORKER_NUM" ]]; then
     # Empty means we are on the head node
-    srun --no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash
+    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash
 else
     nodes_array=($nodes)
-    srun --no-container-mount-home {%- if gres_specification %}{{gres_specification}}{% endif %} -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash
+    srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home {%- if gres_specification %}{{gres_specification}}{% endif %} -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash
 fi
 EOF
 chmod +x $CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh
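
The template changes all follow one pattern: every srun that must land on the head node gains a `{% if heterogeneous %}--het-group=0 {% endif %}` guard, so that in a heterogeneous allocation Slurm directs those steps to het job component 0, where the Ray head runs. A minimal rendering sketch of that guard, using a trimmed-down srun line rather than the full ones in ray.sub.j2:

```python
# Minimal sketch: how the {% if heterogeneous %} guard renders.
# The srun flags here are a trimmed subset of the real template's lines.
from jinja2 import Template

snippet = Template(
    "srun {% if heterogeneous %}--het-group=0 {% endif %}"
    '--overlap --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head'
)

print(snippet.render(heterogeneous=False))
# srun --overlap --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head
print(snippet.render(heterogeneous=True))
# srun --het-group=0 --overlap --nodes=1 --ntasks=1 -w "$head_node" ray status --address $ip_head
```

For homogeneous jobs the rendered script is unchanged, which keeps the existing behavior intact.
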
11 changes: 11 additions & 0 deletions nemo_run/run/torchx_backend/schedulers/slurm.py
@@ -113,6 +113,17 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
                 srun_cmd = [role.entrypoint] + role.args
                 srun_cmds.append([" ".join(srun_cmd)])
 
+            # For heterogeneous jobs, ensure run_as_group is set for command group mapping
+            if executor.heterogeneous and executor.resource_group:
+                executor.run_as_group = True
+                # Validate that command groups align with resource groups
+                if len(srun_cmds) != len(executor.resource_group):
+                    log.warning(
+                        f"Heterogeneous job has {len(executor.resource_group)} resource groups "
+                        f"but {len(srun_cmds)} roles. Command groups should match resource groups "
+                        f"for proper het-group mapping."
+                    )
+
             command = [app.roles[0].entrypoint] + app.roles[0].args
             # Allow selecting Ray template via environment variable
             ray_template_name = os.environ.get("NEMO_RUN_SLURM_RAY_TEMPLATE", "ray.sub.j2")
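
The warning only fires when the role count and the resource_group length diverge; a quick standalone sketch of that check with made-up counts (standard logging only, not the scheduler's logger):

```python
# Standalone sketch of the alignment check above, with made-up counts.
import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("slurm_scheduler_sketch")

def check_het_alignment(num_resource_groups: int, num_roles: int) -> None:
    # Mirrors the scheduler's validation: warn when command groups and
    # resource groups cannot be mapped one-to-one onto het groups.
    if num_roles != num_resource_groups:
        log.warning(
            f"Heterogeneous job has {num_resource_groups} resource groups "
            f"but {num_roles} roles. Command groups should match resource groups "
            f"for proper het-group mapping."
        )

check_het_alignment(num_resource_groups=2, num_roles=3)  # warns
check_het_alignment(num_resource_groups=2, num_roles=2)  # silent
```
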