@@ -39,6 +39,7 @@ def init_process_group(
         rank: int,
         world_size: int,
         timeout: timedelta,
+        **kwargs,
     ):
         raise NotImplementedError
 
@@ -100,22 +101,21 @@ def new_group(
 
 
 class TorchBackend(Distributed):
-    def __init__(self, backend_type: str):
-        self.backend_type = backend_type
-
     def init_process_group(
         self,
         host: str,
         port: int,
         rank: int,
         world_size: int,
         timeout: timedelta,
+        **kwargs,
     ):
+        backend = kwargs.get("backend", "nccl")
         store = torch.distributed.TCPStore(
             host, port, world_size, timeout=timeout, is_master=(rank == 0)
         )
         torch.distributed.init_process_group(
-            backend=self.backend_type,
+            backend=backend,
             world_size=world_size,
             rank=rank,
             timeout=timeout,
@@ -159,7 +159,7 @@ def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
 
 
 # specific device instance
-_BACKEND_INSTANCE: Distributed = TorchBackend(backend_type="nccl")
+_BACKEND_INSTANCE: Distributed = TorchBackend()
 
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -223,33 +223,34 @@ def _common_all_gather_object(
         object_list[i] = _tensor_to_object(tensor, tensor_size)
 
 
+def use_backend(backend: str | None):
+    global _BACKEND_INSTANCE
+
+    if not backend:
+        return
+
+    mapping = {
+        "vllm_nccl": ".nccl.DistributedNccl",
+        "vllm_hccl": ".hccl.DistributedHccl",
+    }
+    if backend not in mapping:
+        raise ValueError(f"Unsupported custom backend: {backend}")
+
+    module_path, class_name = mapping[backend].rsplit(".", 1)
+    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
+    backend_class = getattr(module, class_name)
+    _BACKEND_INSTANCE = backend_class()
+
+
 def init_process_group(
     host: str,
     port: int,
     rank: int,
     world_size: int,
-    custom_dist: bool,
-    backend: str,
     timeout: timedelta = timedelta(seconds=300),
+    **kwargs,
 ):
-    global _BACKEND_INSTANCE
-
-    if not custom_dist:
-        _BACKEND_INSTANCE = TorchBackend(backend_type=backend)
-    else:
-        mapping = {
-            "nccl": ".nccl.DistributedNccl",
-            "hccl": ".hccl.DistributedHccl",
-        }
-        if backend not in mapping:
-            raise ValueError(f"Unsupported custom backend: {backend}")
-
-        module_path, class_name = mapping[backend].rsplit(".", 1)
-        module = importlib.import_module(module_path, "checkpoint_engine.distributed")
-        backend_class = getattr(module, class_name)
-        _BACKEND_INSTANCE = backend_class()
-
-    _BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
+    _BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout, **kwargs)
 
 
 def destroy_process_group(group: DistributedProcessGroup | None = None):
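
For context, a minimal usage sketch of the refactored entry points: backend selection now happens once through use_backend(), and torch-level options such as backend travel through **kwargs instead of the removed custom_dist/backend parameters. The import path, host, port, and world size below are placeholders for illustration, not values taken from this change.

from datetime import timedelta

# Module path assumed from the package name referenced in the diff.
from checkpoint_engine import distributed as dist

# Default path: the module-level TorchBackend() reads the process-group backend
# from kwargs inside init_process_group and falls back to "nccl".
dist.init_process_group(
    host="127.0.0.1",  # placeholder rendezvous host
    port=29500,        # placeholder rendezvous port
    rank=0,
    world_size=1,
    timeout=timedelta(seconds=300),
    backend="nccl",    # forwarded via **kwargs to torch.distributed.init_process_group
)

# Alternative path (instead of, not in addition to, the call above):
# swap in a custom backend before initializing.
dist.use_backend("vllm_nccl")  # resolves .nccl.DistributedNccl via importlib
dist.init_process_group(host="127.0.0.1", port=29500, rank=0, world_size=1)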
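
Not part of this diff, but for orientation: the classes that use_backend() imports dynamically (.nccl.DistributedNccl, .hccl.DistributedHccl) are expected to be Distributed subclasses exposing the same interface the diff shows for TorchBackend. A rough skeleton under that assumption; the import locations and method bodies here are hypothetical, only the signatures come from the diff.

# Hypothetical checkpoint_engine/distributed/nccl.py -- interface sketch only.
from datetime import timedelta

# Assumed export location of the base types shown in the diff.
from checkpoint_engine.distributed import Distributed, DistributedProcessGroup


class DistributedNccl(Distributed):
    def init_process_group(
        self,
        host: str,
        port: int,
        rank: int,
        world_size: int,
        timeout: timedelta,
        **kwargs,
    ):
        # A real backend would bootstrap its own NCCL communicator here, e.g. by
        # exchanging a unique id through a rendezvous rooted at host:port,
        # honoring timeout and any backend-specific kwargs.
        raise NotImplementedError("sketch only")

    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
        # Would return a backend-specific process group restricted to `ranks`.
        raise NotImplementedError("sketch only")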