Commit c631421

refactor: use a shared TCPStore in ParameterServer and create ProcessGroup using PrefixStore
1 parent 15446dd commit c631421
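
For context, here is a minimal sketch of the pattern this commit adopts: one long-lived TCPStore per rank, and a fresh PrefixStore namespace for every process-group initialization, so repeated init/destroy cycles against the same host/port cannot collide on store keys. The backend choice, the loop, and the env-var handling below are assumptions for illustration, not the project's actual flow.

import os
from datetime import timedelta

import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Created once per process; rank 0 hosts the store.
store = dist.TCPStore(
    os.environ["MASTER_ADDR"],
    int(os.environ["MASTER_PORT"]),
    world_size,
    is_master=(rank == 0),
    timeout=timedelta(minutes=10),
)

counter = 0
for _ in range(2):  # e.g. two update rounds, each with its own process group
    counter += 1
    # A new prefix per init keeps this group's keys separate from earlier ones.
    sub_store = dist.PrefixStore(f"prefix-{counter}", store)
    dist.init_process_group(
        backend="gloo",  # assumed here; ps.py uses self.device_manager.backend
        world_size=world_size,
        rank=rank,
        store=sub_store,
    )
    # ... collective work (e.g. broadcasting checkpoint buckets) goes here ...
    dist.destroy_process_group()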

File tree

2 files changed: +23 −37 lines changed


checkpoint_engine/ps.py

Lines changed: 21 additions & 37 deletions
@@ -175,6 +175,8 @@ def __init__(
         auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ):
         """
         Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -228,6 +230,17 @@ def __init__(
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+        master_addr = master_addr or os.getenv("MASTER_ADDR")
+        assert master_addr, "master_addr is required"
+        self._store = torch.distributed.TCPStore(
+            master_addr,
+            _get_master_port(master_port),
+            self._world_size,
+            timeout=timedelta(minutes=10),
+            is_master=self._rank == 0,
+        )
+        self._store_counter = 0
+
     def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
         if checkpoint_name == self._current_shared_memory_pool_user:
             assert self._memory_pool[self.shared_memory_pool_name], (
@@ -487,8 +500,6 @@ def gather_metas(self, checkpoint_name: str):
     def init_process_group(
         self,
         *,
-        master_addr: str | None = None,
-        master_port: int | None = None,
         timeout: timedelta = timedelta(minutes=10),
     ):
         """
@@ -498,27 +509,18 @@ def init_process_group(
             master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
             timeout: The timeout of the process group.
         """
-        master_addr = master_addr or os.getenv("MASTER_ADDR")
-        assert master_addr, "master_addr is required"
-        store = dist.TCPStore(
-            master_addr,
-            _get_master_port(master_port),
-            self._world_size,
-            timeout=timeout,
-            is_master=self._rank == 0,
-        )
+        self._store_counter += 1
+        sub_store = torch.distributed.PrefixStore(f"prefix-{self._store_counter}", self._store)
         dist.init_process_group(
             backend=self.device_manager.backend,
             world_size=self._world_size,
             rank=self._rank,
             timeout=timeout,
-            store=store,
+            store=sub_store,
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
-    def store_based_barrier(
-        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
-    ) -> None:
+    def store_based_barrier(self, timeout: timedelta = timedelta(minutes=5)) -> None:
         """
         Perform a store-based barrier synchronization across all ranks.
 
@@ -531,7 +533,7 @@ def store_based_barrier(
         """
         dist.distributed_c10d._store_based_barrier(
             rank=self._rank,
-            store=store,
+            store=self._store,
             group_name="parameter_server_barrier",
             rendezvous_count=self._world_size,
             timeout=timeout,
@@ -544,8 +546,6 @@ def update(
         *,
         timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
-        master_addr: str | None = None,
-        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -566,28 +566,12 @@ def update(
         assert req_func is not None, "req_func is required"
         ranks_group = None
         try:
-            master_addr = os.getenv("MASTER_ADDR") or master_addr
-            assert master_addr, "master_addr is required"
-            if self._auto_pg:
-                if not dist.is_initialized():
-                    self.init_process_group(
-                        timeout=timeout, master_addr=master_addr, master_port=master_port
-                    )
-                manager_store = dist.distributed_c10d._get_default_store()
-            else:
-                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
-                # If master_port is provided, use master_port+1 for barrier store
-                manager_store = dist.TCPStore(
-                    master_addr,
-                    _get_master_port(master_port) + 1,
-                    self._world_size,
-                    timeout=timeout,
-                    is_master=self._rank == 0,
-                )
+            if self._auto_pg and not dist.is_initialized():
+                self.init_process_group(timeout=timeout)
             # if ranks is None or [], it will use fully broadcast to update to all ranks
             ranks_group = dist.new_group(ranks) if ranks else None
             self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
-            self.store_based_barrier(manager_store)
+            self.store_based_barrier()
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"

tests/test_reuse_pin_memory.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
 def test_register_pin_memory():
     os.environ["RANK"] = "0"
     os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "25400"
     ps = ParameterServer()
     checkpoint1 = generate_dummy_checkpoint()
     checkpoint_shared1 = generate_dummy_checkpoint()
