@@ -169,9 +169,67 @@ def materialize(self) -> str:
169169 parameters .update (self .executor .additional_parameters )
170170
171171 sbatch_flags = []
172- assert not self .executor .heterogeneous , "heterogeneous is not supported for ray clusters"
173- for k in sorted (parameters ):
174- sbatch_flags .append (_as_sbatch_flag (k , parameters [k ]))
172+ if self .executor .heterogeneous :
173+ # Validate resource_group exists
174+ assert self .executor .resource_group , "heterogeneous requires resource_group to be set"
175+ assert len (self .executor .resource_group ) > 0 , "resource_group must not be empty"
176+
177+ # Validate het-group-0 has at least 1 node for Ray head
178+ head_group = self .executor .resource_group [0 ]
179+ assert head_group .nodes >= 1 , "het-group-0 must have at least 1 node for Ray head"
180+
181+ # Determine the final het group index (for hetjob separator placement)
182+ final_group_index = len (self .executor .resource_group ) - 1
183+ if self .executor .het_group_indices :
184+ final_group_index = self .executor .het_group_indices .index (
185+ max (self .executor .het_group_indices )
186+ )
187+
188+ # Generate SBATCH blocks for each het group
189+ for i , resource_req in enumerate (self .executor .resource_group ):
190+ # Skip an entry that shares its het_group_index with the immediately preceding entry (only adjacent duplicates are detected)
191+ if resource_req .het_group_index is not None :
192+ if (
193+ i > 0
194+ and resource_req .het_group_index
195+ == self .executor .resource_group [i - 1 ].het_group_index
196+ ):
197+ continue
198+
199+ # Build het-specific parameters
200+ het_parameters = parameters .copy ()
201+ het_parameters .update (
202+ {
203+ "nodes" : resource_req .nodes ,
204+ "ntasks_per_node" : resource_req .ntasks_per_node ,
205+ }
206+ )
207+
208+ # Update job name to include het group index
209+ het_parameters ["job_name" ] = f"{ job_details .job_name } -{ i } "
210+
211+ # Only update GPU parameters if they're explicitly set in resource_req
212+ if resource_req .gpus_per_node is not None :
213+ het_parameters ["gpus_per_node" ] = resource_req .gpus_per_node
214+ if resource_req .gpus_per_task is not None :
215+ het_parameters ["gpus_per_task" ] = resource_req .gpus_per_task
216+
217+ # Update output/error paths to include het group index
218+ het_parameters ["output" ] = parameters ["output" ].replace ("%t" , str (i ))
219+ if "error" in het_parameters :
220+ het_parameters ["error" ] = parameters ["error" ].replace ("%t" , str (i ))
221+
222+ # Generate SBATCH flags for this het group
223+ for k in sorted (het_parameters ):
224+ sbatch_flags .append (_as_sbatch_flag (k , het_parameters [k ]))
225+
226+ # Add hetjob separator (except after last group)
227+ if i != final_group_index :
228+ sbatch_flags .append ("#SBATCH hetjob" )
229+ else :
230+ # Non-heterogeneous: use existing logic
231+ for k in sorted (parameters ):
232+ sbatch_flags .append (_as_sbatch_flag (k , parameters [k ]))
175233
176234 if self .executor .dependencies :
177235 slurm_deps = self .executor .parse_deps ()
@@ -236,6 +294,8 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
236294 "command_workdir" : self .workdir ,
237295 "gres_specification" : get_gres_specification (),
238296 "ray_log_prefix" : ray_log_prefix ,
297+ "heterogeneous" : self .executor .heterogeneous ,
298+ "resource_group" : self .executor .resource_group if self .executor .heterogeneous else [],
239299 }
240300
241301 if self .command_groups :
@@ -271,12 +331,24 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
271331 os .path .join (logs_dir , f"{ ray_log_prefix } overlap-{ idx } .err" ),
272332 ]
273333
334+ # Determine het group for this command (if heterogeneous)
335+ het_group_flag = []
336+ if self .executor .heterogeneous and self .executor .run_as_group :
337+ if len (self .executor .resource_group ) == len (self .command_groups ):
338+ # Lengths match, so command group idx pairs positionally with resource_group[idx]
339+ req = self .executor .resource_group [idx ]
340+ het_group_num = (
341+ req .het_group_index if req .het_group_index is not None else idx
342+ )
343+ het_group_flag = [f"--het-group={ het_group_num } " ]
344+
274345 srun_cmd = " " .join (
275346 list (
276347 map (
277348 lambda arg : arg if isinstance (arg , noquote ) else shlex .quote (arg ),
278349 [
279350 "srun" ,
351+ * het_group_flag ,
280352 "--output" ,
281353 noquote (stdout_path ),
282354 * stderr_flags ,
0 commit comments