
Commit a435796

Author: kip-cxj
Message: add dist.use_backend
1 parent 5ab13a1 commit a435796

File tree: 6 files changed (+34, -31 lines)

checkpoint_engine/distributed/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
     init_process_group,
     is_initialized,
     new_group,
+    use_backend,
 )


@@ -21,4 +22,5 @@
     "init_process_group",
     "is_initialized",
     "new_group",
+    "use_backend",
 ]
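
With use_backend re-exported here, backend selection becomes part of the package's public API alongside init_process_group. A minimal sketch of the new import surface, assuming dist aliases checkpoint_engine.distributed as the example script below suggests:

# Sketch only: use_backend is now importable from the distributed package.
from checkpoint_engine import distributed as dist

# Select a custom backend; a falsy argument keeps the default TorchBackend.
dist.use_backend("vllm_nccl")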

checkpoint_engine/distributed/base.py

Lines changed: 26 additions & 25 deletions
@@ -39,6 +39,7 @@ def init_process_group(
         rank: int,
         world_size: int,
         timeout: timedelta,
+        **kwargs,
     ):
         raise NotImplementedError

@@ -100,22 +101,21 @@ def new_group(


 class TorchBackend(Distributed):
-    def __init__(self, backend_type: str):
-        self.backend_type = backend_type
-
     def init_process_group(
         self,
         host: str,
         port: int,
         rank: int,
         world_size: int,
         timeout: timedelta,
+        **kwargs,
     ):
+        backend = kwargs.get("backend", "nccl")
         store = torch.distributed.TCPStore(
             host, port, world_size, timeout=timeout, is_master=(rank == 0)
         )
         torch.distributed.init_process_group(
-            backend=self.backend_type,
+            backend=backend,
             world_size=world_size,
             rank=rank,
             timeout=timeout,
@@ -159,7 +159,7 @@ def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None


 # specific device instance
-_BACKEND_INSTANCE: Distributed = TorchBackend(backend_type="nccl")
+_BACKEND_INSTANCE: Distributed = TorchBackend()

 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -223,33 +223,34 @@ def _common_all_gather_object(
         object_list[i] = _tensor_to_object(tensor, tensor_size)


+def use_backend(backend: str | None):
+    global _BACKEND_INSTANCE
+
+    if not backend:
+        return
+
+    mapping = {
+        "vllm_nccl": ".nccl.DistributedNccl",
+        "vllm_hccl": ".hccl.DistributedHccl",
+    }
+    if backend not in mapping:
+        raise ValueError(f"Unsupported custom backend: {backend}")
+
+    module_path, class_name = mapping[backend].rsplit(".", 1)
+    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
+    backend_class = getattr(module, class_name)
+    _BACKEND_INSTANCE = backend_class()
+
+
 def init_process_group(
     host: str,
     port: int,
     rank: int,
     world_size: int,
-    custom_dist: bool,
-    backend: str,
     timeout: timedelta = timedelta(seconds=300),
+    **kwargs,
 ):
-    global _BACKEND_INSTANCE
-
-    if not custom_dist:
-        _BACKEND_INSTANCE = TorchBackend(backend_type=backend)
-    else:
-        mapping = {
-            "nccl": ".nccl.DistributedNccl",
-            "hccl": ".hccl.DistributedHccl",
-        }
-        if backend not in mapping:
-            raise ValueError(f"Unsupported custom backend: {backend}")
-
-        module_path, class_name = mapping[backend].rsplit(".", 1)
-        module = importlib.import_module(module_path, "checkpoint_engine.distributed")
-        backend_class = getattr(module, class_name)
-        _BACKEND_INSTANCE = backend_class()
-
-    _BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
+    _BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout, **kwargs)


 def destroy_process_group(group: DistributedProcessGroup | None = None):
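
The net effect in base.py: backend selection moves out of init_process_group and into the new use_backend helper, while init_process_group now forwards extra keyword arguments (such as backend) to whichever _BACKEND_INSTANCE is active; the default TorchBackend reads kwargs.get("backend", "nccl"). A minimal sketch of the resulting call flow; the rendezvous values below are placeholders, not taken from this diff:

# Illustrative flow after this commit (placeholder host/port/rank values).
from datetime import timedelta

from checkpoint_engine import distributed as dist

# Optional: swap in a custom backend first; passing None or "" is a no-op.
# dist.use_backend("vllm_hccl")

# With the default TorchBackend, the torch backend name travels via **kwargs.
dist.init_process_group(
    host="127.0.0.1",                # placeholder master address
    port=29500,                      # placeholder master port
    rank=0,
    world_size=1,
    timeout=timedelta(seconds=300),
    backend="nccl",                  # forwarded through **kwargs; "nccl" is also the default
)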

checkpoint_engine/distributed/hccl.py

Lines changed: 1 addition & 0 deletions
@@ -235,6 +235,7 @@ def init_process_group(
         rank: int,
         world_size: int,
         timeout: timedelta = timedelta(seconds=300),
+        **kwargs,
     ):
         assert not self.initialized, "already initialized"

checkpoint_engine/distributed/nccl.py

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ def init_process_group(
         rank: int,
         world_size: int,
         timeout: timedelta = timedelta(seconds=300),
+        **kwargs,
     ):
         assert not self.initialized, "already initialized"

checkpoint_engine/ps.py

Lines changed: 1 addition & 4 deletions
@@ -176,7 +176,6 @@ def __init__(
         auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
-        custom_dist: bool = False,
     ):
         """
         Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -197,7 +196,6 @@ def __init__(
         self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
         self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
         self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))
-        self._custom_dist = custom_dist

         assert self._rank is not None and self._rank >= 0, self._rank
         assert self._world_size and self._world_size > 0, self._world_size
@@ -498,9 +496,8 @@ def init_process_group(
             port=_get_master_port(master_port),
             rank=self._rank,
             world_size=self._world_size,
-            custom_dist=self._custom_dist,
-            backend=self.device_manager.backend,
             timeout=timeout,
+            backend=self.device_manager.backend,
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")

examples/update.py

Lines changed: 3 additions & 2 deletions
@@ -159,13 +159,14 @@ def join(
     parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
     parser.add_argument("--update-method", type=str, default="broadcast")
     parser.add_argument("--uds", type=str, default=None)
-    parser.add_argument("--custom-dist", action="store_true")
+    parser.add_argument("--custom-dist", type=str, default=None)
     args = parser.parse_args()
     rank = int(os.getenv("RANK"))
     world_size = int(os.getenv("WORLD_SIZE"))

     req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds)
-    ps = ParameterServer(auto_pg=True, custom_dist=args.custom_dist)
+    dist.use_backend(args.custom_dist)
+    ps = ParameterServer(auto_pg=True)
     if args.load_metas_file:
         join(
             ps,
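
In the example script, --custom-dist changes from a boolean switch to an optional string naming the backend, and its value is handed directly to dist.use_backend before the ParameterServer is constructed. A hedged sketch of the accepted values, mirroring the mapping added in base.py (the custom backends will only import cleanly on machines with the matching NCCL or HCCL stack):

# Sketch of the new --custom-dist semantics (values come from the mapping in base.py).
from checkpoint_engine import distributed as dist

dist.use_backend(None)           # flag omitted: the default TorchBackend stays in place
dist.use_backend("vllm_nccl")    # loads .nccl.DistributedNccl from checkpoint_engine.distributed
dist.use_backend("vllm_hccl")    # loads .hccl.DistributedHccl from checkpoint_engine.distributed
# Any other non-empty value raises ValueError("Unsupported custom backend: ...").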
