1010import torch
1111import torch .distributed
1212from torch .distributed import ReduceOp
13+ from vllm .distributed .device_communicators .pynccl import PyNcclCommunicator
14+ from vllm .distributed .device_communicators .pynccl_wrapper import (
15+ Function ,
16+ NCCLLibrary ,
17+ buffer_type ,
18+ ncclComm_t ,
19+ ncclResult_t ,
20+ )
1321from vllm .distributed .utils import StatelessProcessGroup
1422from vllm .utils import current_stream
1523
@@ -68,6 +76,83 @@ def _common_all_gather_object(comm, device, world_size, object_list, object):
6876 object_list [i ] = _tensor_to_object (tensor , tensor_size )
6977
7078
class ncclConfig_t(ctypes.Structure):
    """ctypes mirror of NCCL's ``ncclConfig_t`` communicator configuration.

    Only the memory layout matters here: the struct is passed by pointer to
    ``ncclCommSplit``. Field order and types must therefore match the NCCL
    header exactly — do not reorder.

    NOTE(review): this layout tracks a recent NCCL release (it includes
    ``trafficClass``/``commName``/``shrinkShare`` etc.) — confirm it matches
    the NCCL build actually loaded at runtime before passing a non-NULL
    config.
    """

    _fields_ = [
        ("size", ctypes.c_size_t),
        ("magic", ctypes.c_uint),
        ("version", ctypes.c_uint),
        ("blocking", ctypes.c_int),
        ("cgaClusterSize", ctypes.c_int),
        ("minCTAs", ctypes.c_int),
        ("maxCTAs", ctypes.c_int),
        ("netName", ctypes.c_char_p),
        ("splitShare", ctypes.c_int),
        ("trafficClass", ctypes.c_int),
        ("commName", ctypes.c_char_p),
        ("collnetEnable", ctypes.c_int),
        ("CTAPolicy", ctypes.c_int),
        ("shrinkShare", ctypes.c_int),
        ("nvlsCTAs", ctypes.c_int),
        ("nChannelsPerNetPeer", ctypes.c_int),
        ("nvlinkCentricSched", ctypes.c_int),
        ("graphUsageMode", ctypes.c_int),
        ("numRmdCtx", ctypes.c_int),
    ]
101+
# Keep a reference to the wrapper's original export table; the extension
# below is applied on top of it.
nccl_orig_exported_functions = NCCLLibrary.exported_functions

# Extra NCCL entry points that vllm's NCCLLibrary does not wrap.
nccl_extended_functions = [
    # C signature:
    #   ncclResult_t ncclCommSplit(ncclComm_t comm, int color, int key,
    #                              ncclComm_t* newcomm, ncclConfig_t* config)
    Function(
        "ncclCommSplit",
        ncclResult_t,
        [
            ncclComm_t,
            ctypes.c_int,
            ctypes.c_int,
            ctypes.POINTER(ncclComm_t),
            ctypes.POINTER(ncclConfig_t),
        ],
    ),
]
119+
120+
def nccl_comm_split(
    self,
    comm: ncclComm_t,
    color: int,
    key: int,
) -> ncclComm_t:
    """ctypes binding for ``ncclCommSplit``; returns the new communicator.

    A NULL config pointer is passed, so NCCL's default (blocking) split
    behavior applies. Ranks splitting with NCCL_SPLIT_NOCOLOR (-1) receive
    a NULL communicator handle.
    """
    new_comm = ncclComm_t()
    result = self._funcs["ncclCommSplit"](
        comm, color, key, ctypes.byref(new_comm), None
    )
    self.NCCL_CHECK(result)
    return new_comm
133+
134+
# Extend NCCLLibrary in place so every communicator instance picks up
# ncclCommSplit. Guard against double-patching: re-executing this module
# (e.g. importlib.reload) would otherwise append the Function entry a
# second time and re-register it in _funcs.
if all(f.name != "ncclCommSplit" for f in NCCLLibrary.exported_functions):
    NCCLLibrary.exported_functions = (
        nccl_orig_exported_functions + nccl_extended_functions
    )
NCCLLibrary.ncclCommSplit = nccl_comm_split
138+
139+
class PyNcclCommunicatorEx(PyNcclCommunicator):
    """PyNcclCommunicator extended with comm destruction and splitting."""

    def destroy_comm(self, comm=None):
        """Destroy *comm* when a truthy handle is given, else the default comm."""
        target = comm if comm else self.comm
        self.nccl.ncclCommDestroy(target)

    def create_newcomm(self, ranks):
        """Split the default comm over *ranks*.

        Ranks not listed split with NCCL_SPLIT_NOCOLOR (-1) and get a NULL
        handle back. The global rank is used as the split key, so members
        keep their relative ordering in the new comm.
        """
        color = 0 if self.rank in ranks else -1  # -1 == NCCL_SPLIT_NOCOLOR
        return self.nccl.ncclCommSplit(self.comm, color, self.rank)
155+
71156class DistributedNccl :
72157 def __init__ (self ):
73158 self .pg = None
@@ -93,83 +178,88 @@ def init_process_group(
93178 host , port , rank , world_size , store_timeout = int (timeout .total_seconds ())
94179 )
95180
96- from vllm .distributed .device_communicators .pynccl import PyNcclCommunicator
97-
98- self .pynccl = PyNcclCommunicator (group = self .pg , device = self ._device )
181+ self .pynccl = PyNcclCommunicatorEx (group = self .pg , device = self ._device )
182+ self ._comm = self .pynccl .comm
99183
100184 def destroy_process_group (self , group = None ):
101- # PyNcclCommunicator does not provide destroy method
102185 if group in self .sub_groups :
186+ newcomm = ctypes .c_void_p (group )
187+ self .pynccl .destroy_comm (newcomm )
103188 del self .sub_groups [group ]
104189 return
105190
191+ self .pynccl .destroy_comm ()
192+
193+ self .pynccl = None
194+ self .pg = None
195+
    def is_initialized(self) -> bool:
        # True once init_process_group has created the PyNccl communicator;
        # destroy_process_group(None) resets it back to None.
        return self.pynccl is not None
108198
109199 def all_gather_object (self , object_list : list [Any ], obj : Any , group = None ):
110200 if group :
111201 assert group in self .sub_groups , "invalid sub_group"
112- pynccl = group .pynccl
113- else :
114- pynccl = self .pynccl
202+ newcomm = ctypes .c_void_p (group )
203+ self .pynccl .comm = newcomm
115204
116- _common_all_gather_object (pynccl , self ._device , self ._world_size , object_list , object )
205+ _common_all_gather_object (self . pynccl , self ._device , self ._world_size , object_list , obj )
117206 current_stream ().synchronize ()
118207
208+ if group :
209+ self .pynccl .comm = self ._comm
210+
119211 def all_reduce (self , tensor : torch .Tensor , op = ReduceOp .SUM , group = None ):
120212 if group :
121213 assert group in self .sub_groups , "invalid sub_group"
122- pynccl = group .pynccl
123- else :
124- pynccl = self .pynccl
214+ newcomm = ctypes .c_void_p (group )
215+ self .pynccl .comm = newcomm
125216
126- out_tensor = pynccl .all_reduce (in_tensor = tensor , op = op )
217+ out_tensor = self . pynccl .all_reduce (in_tensor = tensor , op = op )
127218 current_stream ().synchronize ()
128219 tensor .copy_ (out_tensor )
129220
221+ if group :
222+ self .pynccl .comm = self ._comm
223+
130224 def broadcast (self , tensor : torch .Tensor , src = None , group = None ):
131225 if group :
132226 assert group in self .sub_groups , "invalid sub_group"
133- pynccl = group .pynccl
134- else :
135- pynccl = self .pynccl
227+ assert src in self .sub_groups [group ], "src rank not in group"
228+ newcomm = ctypes .c_void_p (group )
229+ self .pynccl .comm = newcomm
230+ # convert src rank id in default world to newcomm
231+ #src = self.sub_groups[group].index(src)
136232
137- pynccl .broadcast (tensor , src )
233+ self . pynccl .broadcast (tensor , src )
138234 current_stream ().synchronize ()
139235
236+ if group :
237+ self .pynccl .comm = self ._comm
238+
140239 def barrier (self , group = None ):
141240 if group :
142241 assert group in self .sub_groups , "invalid sub_group"
143- pynccl = group .pynccl
144- else :
145- pynccl = self .pynccl
242+ newcomm = ctypes .c_void_p (group )
243+ self .pynccl .comm = newcomm
146244
147245 data = torch .zeros (1 , device = self ._rank )
148- pynccl .all_reduce (data )
246+ self . pynccl .all_reduce (data )
149247 current_stream ().synchronize ()
150248
249+ if group :
250+ self .pynccl .comm = self ._comm
251+
151252 def new_group (self , ranks ):
152253 # ranks is None or []
153254 if not ranks :
154- return self
155-
156- host = self ._host
157- port = self ._port
158- rank = self ._rank
159-
160- if rank not in ranks :
161- return
162-
163- new_rank = ranks .index (rank )
164- new_world_size = len (ranks )
165-
166- new_dist = DistributedNccl ()
167- new_dist .init_process_group (
168- host , port + 10 , new_rank , new_world_size
169- ) # todo host maybe incorrect
170- self .sub_groups .append (new_dist )
171-
172- return new_dist
255+ ranks = list (range (self ._world_size ))
256+
257+ newcomm = self .pynccl .create_newcomm (ranks )
258+ value = 0
259+ if newcomm :
260+ value = newcomm .value
261+ self .sub_groups [value ] = ranks
262+ return value
173263
174264
175265try :
0 commit comments