Skip to content

Commit 5266ac1

Browse files
author
yexin
committed
implement PyNcclCommunicatorEx
1 parent 77f7b57 commit 5266ac1

File tree

1 file changed

+129
-39
lines changed

1 file changed

+129
-39
lines changed

checkpoint_engine/distributed.py

Lines changed: 129 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
import torch
1111
import torch.distributed
1212
from torch.distributed import ReduceOp
13+
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
14+
from vllm.distributed.device_communicators.pynccl_wrapper import (
15+
Function,
16+
NCCLLibrary,
17+
buffer_type,
18+
ncclComm_t,
19+
ncclResult_t,
20+
)
1321
from vllm.distributed.utils import StatelessProcessGroup
1422
from vllm.utils import current_stream
1523

@@ -68,6 +76,83 @@ def _common_all_gather_object(comm, device, world_size, object_list, object):
6876
object_list[i] = _tensor_to_object(tensor, tensor_size)
6977

7078

79+
class ncclConfig_t(ctypes.Structure):
    # Python mirror of NCCL's ``ncclConfig_t`` (nccl.h), passed by pointer to
    # ncclCommSplit. The field names, types, and ORDER define the C ABI layout,
    # so they must match the ncclConfig_t of the libnccl build actually loaded.
    # NOTE(review): the later fields (collnetEnable, CTAPolicy, shrinkShare,
    # nvlsCTAs, nChannelsPerNetPeer, nvlinkCentricSched, graphUsageMode,
    # numRmdCtx) only exist in recent NCCL releases -- confirm against the
    # pinned NCCL version; a mismatch silently corrupts the config struct.
    _fields_ = [
        ("size", ctypes.c_size_t),  # sizeof(ncclConfig_t) as seen by the caller
        ("magic", ctypes.c_uint),  # sanity value set by NCCL_CONFIG_INITIALIZER
        ("version", ctypes.c_uint),  # NCCL version the struct was built for
        ("blocking", ctypes.c_int),
        ("cgaClusterSize", ctypes.c_int),
        ("minCTAs", ctypes.c_int),
        ("maxCTAs", ctypes.c_int),
        ("netName", ctypes.c_char_p),
        ("splitShare", ctypes.c_int),  # share resources between split comms
        ("trafficClass", ctypes.c_int),
        ("commName", ctypes.c_char_p),
        ("collnetEnable", ctypes.c_int),
        ("CTAPolicy", ctypes.c_int),
        ("shrinkShare", ctypes.c_int),
        ("nvlsCTAs", ctypes.c_int),
        ("nChannelsPerNetPeer", ctypes.c_int),
        ("nvlinkCentricSched", ctypes.c_int),
        ("graphUsageMode", ctypes.c_int),
        ("numRmdCtx", ctypes.c_int),
    ]
101+
102+
# Function table shipped with vllm's NCCLLibrary, captured unchanged so the
# patch further down can prepend it to the extended table.
nccl_orig_exported_functions = NCCLLibrary.exported_functions

# Extra NCCL entry points that vllm's wrapper does not expose by default.
nccl_extended_functions = [
    # ncclResult_t ncclCommSplit(
    #     ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t *config
    # )
    Function(
        "ncclCommSplit",
        ncclResult_t,
        [
            ncclComm_t,
            ctypes.c_int,
            ctypes.c_int,
            ctypes.POINTER(ncclComm_t),
            ctypes.POINTER(ncclConfig_t),
        ],
    ),
]
119+
120+
121+
def nccl_comm_split(
    self,
    comm: ncclComm_t,
    color: int,
    key: int,
) -> ncclComm_t:
    """Split ``comm`` with ncclCommSplit and return the new communicator.

    Installed onto ``NCCLLibrary`` as a method, so ``self`` is an
    NCCLLibrary instance. All ranks of ``comm`` must call this collectively;
    ranks passing the same ``color`` land in the same new communicator,
    ordered by ``key``. Color -1 (NCCL_SPLIT_NOCOLOR) opts the rank out and
    yields a NULL handle.
    """
    newcomm = ncclComm_t()
    # config=NULL (the final None) selects NCCL's default communicator config.
    self.NCCL_CHECK(
        self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None)
    )
    return newcomm
133+
134+
135+
# Extend NCCLLibrary with ncclCommSplit.
# Guard against double-application (e.g. a module re-import/reload): the
# previous unconditional assignment would register ncclCommSplit a second
# time after nccl_orig_exported_functions re-captured the already-extended
# table, growing the function list on every reload.
if not any(f.name == "ncclCommSplit" for f in NCCLLibrary.exported_functions):
    NCCLLibrary.exported_functions = (
        nccl_orig_exported_functions + nccl_extended_functions
    )
NCCLLibrary.ncclCommSplit = nccl_comm_split
138+
139+
140+
class PyNcclCommunicatorEx(PyNcclCommunicator):
    """PyNcclCommunicator extended with communicator lifecycle helpers
    (destroy and split), which the base class does not provide."""

    def destroy_comm(self, comm=None):
        """Destroy ``comm`` when given (and non-NULL); otherwise destroy
        this communicator's own handle."""
        target = comm if comm else self.comm
        self.nccl.ncclCommDestroy(target)

    def create_newcomm(self, ranks):
        """Split off a new communicator containing ``ranks``.

        Collective over the parent comm: ranks outside ``ranks`` join with
        color -1 (NCCL_SPLIT_NOCOLOR) and receive a NULL handle. The original
        rank is used as the split key, so relative rank order is preserved
        inside the new communicator.
        """
        color = 0 if self.rank in ranks else -1  # -1 == NCCL_SPLIT_NOCOLOR
        return self.nccl.ncclCommSplit(self.comm, color, self.rank)
154+
155+
71156
class DistributedNccl:
72157
def __init__(self):
73158
self.pg = None
@@ -93,83 +178,88 @@ def init_process_group(
93178
host, port, rank, world_size, store_timeout=int(timeout.total_seconds())
94179
)
95180

96-
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
97-
98-
self.pynccl = PyNcclCommunicator(group=self.pg, device=self._device)
181+
self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self._device)
182+
self._comm = self.pynccl.comm
99183

100184
def destroy_process_group(self, group=None):
    """Tear down a communicator.

    With ``group`` (a non-zero handle returned by ``new_group``), destroy
    only that sub-communicator and forget it. With no ``group``, destroy
    the main communicator and drop the stateless process group.
    """
    if group:
        # Fix: previously an unknown group id fell through and destroyed
        # the *main* communicator; validate membership the same way the
        # other group-aware methods do.
        assert group in self.sub_groups, "invalid sub_group"
        self.pynccl.destroy_comm(ctypes.c_void_p(group))
        del self.sub_groups[group]
        return

    self.pynccl.destroy_comm()

    self.pynccl = None
    self.pg = None
106196
def is_initialized(self) -> bool:
    """Return True while the NCCL communicator exists, i.e. between
    ``init_process_group`` and ``destroy_process_group``."""
    return self.pynccl is not None
108198

109199
def all_gather_object(self, object_list: list[Any], obj: Any, group=None):
    """All-gather picklable ``obj`` from every rank into ``object_list``.

    ``group`` is a sub-communicator handle from ``new_group``; when given,
    the sub-comm is temporarily swapped onto ``self.pynccl`` for the
    collective.

    NOTE(review): ``self._world_size`` is the default group's size; for a
    gather over a smaller sub-group this looks wrong -- confirm callers
    only pass full-world groups here.
    """
    if group:
        assert group in self.sub_groups, "invalid sub_group"
        self.pynccl.comm = ctypes.c_void_p(group)
    try:
        _common_all_gather_object(
            self.pynccl, self._device, self._world_size, object_list, obj
        )
        current_stream().synchronize()
    finally:
        # Fix: restore the default communicator even if the collective
        # raises; otherwise every later collective silently runs on the
        # sub-communicator.
        if group:
            self.pynccl.comm = self._comm
210+
119211
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None):
    """All-reduce ``tensor`` in place across the (sub-)communicator.

    ``op`` is a ``torch.distributed.ReduceOp``; ``group`` is a handle from
    ``new_group``. The reduced result is copied back into ``tensor``.
    """
    if group:
        assert group in self.sub_groups, "invalid sub_group"
        self.pynccl.comm = ctypes.c_void_p(group)
    try:
        out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op)
        current_stream().synchronize()
        tensor.copy_(out_tensor)
    finally:
        # Fix: restore the default communicator even on failure, so a
        # raised collective cannot leave self.pynccl pointed at the
        # sub-communicator.
        if group:
            self.pynccl.comm = self._comm
223+
130224
def broadcast(self, tensor: torch.Tensor, src=None, group=None):
    """Broadcast ``tensor`` from rank ``src`` to all ranks of the
    (sub-)communicator."""
    if group:
        assert group in self.sub_groups, "invalid sub_group"
        assert src in self.sub_groups[group], "src rank not in group"
        self.pynccl.comm = ctypes.c_void_p(group)
        # NOTE(review): NCCL's broadcast root is a rank *within* the
        # communicator. For a split comm that should be the index of
        # ``src`` in the (key-ordered) ranks list, i.e. roughly
        # ``src = self.sub_groups[group].index(src)``. Left disabled
        # exactly as in the original -- confirm the rank translation
        # against ncclCommSplit key ordering before enabling.
        # src = self.sub_groups[group].index(src)
    try:
        self.pynccl.broadcast(tensor, src)
        current_stream().synchronize()
    finally:
        # Restore the default communicator even if the broadcast raises.
        if group:
            self.pynccl.comm = self._comm
238+
140239
def barrier(self, group=None):
    """Block until every rank of the (sub-)communicator reaches this point.

    Implemented as a 1-element all-reduce, since PyNcclCommunicator does
    not expose a native barrier.
    """
    if group:
        assert group in self.sub_groups, "invalid sub_group"
        self.pynccl.comm = ctypes.c_void_p(group)
    try:
        # Fix: allocate the dummy tensor on this process's device, not on
        # CUDA ordinal == global rank (``device=self._rank``), which is
        # wrong on multi-node runs where the rank exceeds the local
        # device count.
        data = torch.zeros(1, device=self._device)
        self.pynccl.all_reduce(data)
        current_stream().synchronize()
    finally:
        # Restore the default communicator even if the all-reduce raises.
        if group:
            self.pynccl.comm = self._comm
251+
151252
def new_group(self, ranks):
    """Create a sub-communicator over ``ranks`` via ncclCommSplit.

    Collective: every rank of the default group must call this, including
    ranks not in ``ranks`` (they split with NCCL_SPLIT_NOCOLOR). Returns an
    int handle usable as ``group`` in the collectives above, or 0 when this
    rank is not a member of the new group.
    """
    # ranks is None or [] -> the whole world.
    if not ranks:
        ranks = list(range(self._world_size))

    newcomm = self.pynccl.create_newcomm(ranks)
    value = 0
    if newcomm and newcomm.value:
        value = newcomm.value
        # Fix: register only real communicators. Previously non-member
        # ranks (NULL handle) registered ``sub_groups[0]``; a later
        # destroy of "group 0" then tore down the *main* communicator,
        # and distinct groups overwrote each other under key 0.
        self.sub_groups[value] = ranks
    return value
173263

174264

175265
try:

0 commit comments

Comments
 (0)