Merged
148 changes: 63 additions & 85 deletions checkpoint_engine/ps.py
@@ -677,20 +677,6 @@ def _get_master_port(master_port: int | None = None) -> int:
return master_port


def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
"""
map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
which are generated in self.init_process_group_for_ranks
"""
bcast_rank_map: dict[int, int] = {}
if not ranks:
bcast_rank_map = {r: r for r in range(world_size)}
else:
for i, r in enumerate(ranks):
bcast_rank_map[r] = i
return bcast_rank_map


class P2PStore:
def __init__(self, device_manager: DeviceManager):
from mooncake.engine import TransferEngine
@@ -1045,12 +1031,36 @@ def init_process_group(
)
logger.info(f"[rank{self._rank}] init process group successfully.")

def store_based_barrier(
self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
) -> None:
"""
Perform a store-based barrier synchronization across all ranks.

This barrier uses a TCP store directly rather than a process group,
allowing all ranks to synchronize regardless of which process group
they belong to.

Args:
store: The TCPStore instance to use for synchronization.
timeout: Maximum time to wait for all ranks to reach the barrier.
"""
dist.distributed_c10d._store_based_barrier(
rank=self._rank,
store=store,
group_name="parameter_server_barrier",
rendezvous_count=self._world_size,
timeout=timeout,
)
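
For context, the private torch helper called above can be approximated with public TCPStore primitives; the following is a minimal sketch under that assumption, not the actual torch implementation:

```python
# Hypothetical stand-in for torch.distributed.distributed_c10d._store_based_barrier,
# built only from public TCPStore APIs. Single-shot: reusing it would need fresh keys.
import torch.distributed as dist

def simple_store_barrier(store: dist.TCPStore, world_size: int, key: str = "barrier") -> None:
    arrived = store.add(key, 1)        # atomic increment, serialized by the store
    if arrived == world_size:
        store.set(f"{key}/done", "1")  # the last rank to arrive publishes completion
    store.wait([f"{key}/done"])        # every rank blocks until the key appears
```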

def update(
self,
checkpoint_name: str,
req_func: Callable[[list[tuple[str, str]]], None],
*,
timeout: timedelta = timedelta(minutes=10),
ranks: list[int] | None = None,
master_addr: str | None = None,
master_port: int | None = None,
) -> None:
"""
Update the checkpoint to the inference engine. This function should be called after gather_metas.
@@ -1062,34 +1072,45 @@ def update(
which is the fastest way to update weights, especially in a colocated architecture.
If set, p2p will be used to update the given ranks; this is flexible for updating a group of ranks,
which is useful in a disaggregated architecture.
master_addr: The master address for process group initialization. If not set, the MASTER_ADDR environment variable is used.
master_port: The master port for process group initialization. If not set, _get_master_port derives the port (MASTER_PORT+1).
timeout: The timeout of the barrier operation.
"""
assert req_func is not None, "req_func is required"
ranks_group = None
try:
# if ranks is None or [], use full broadcast to update all ranks
if not ranks:
if self._auto_pg and not dist.is_initialized():
self.init_process_group()
self._update_per_bucket(checkpoint_name, req_func)
master_addr = os.getenv("MASTER_ADDR") or master_addr
assert master_addr, "master_addr is required"
if self._auto_pg:
if not dist.is_initialized():
self.init_process_group(
timeout=timeout, master_addr=master_addr, master_port=master_port
)
manager_store = dist.distributed_c10d._get_default_store()
else:
if self._auto_pg:
if dist.is_initialized():
dist.destroy_process_group()
# HACK: wait 2s to ensure destroy is finished
time.sleep(2)
self.init_process_group_for_ranks(ranks)
if self._rank not in ranks:
return
self._update_per_bucket(checkpoint_name, req_func, ranks)

# HACK: if master_port is not provided, the barrier store uses MASTER_PORT+2, since _get_master_port() returns MASTER_PORT+1.
# If master_port is provided, master_port+1 is used for the barrier store.
manager_store = dist.TCPStore(
master_addr,
_get_master_port(master_port) + 1,
self._world_size,
timeout=timeout,
is_master=self._rank == 0,
)
# if ranks is None or [], use full broadcast to update all ranks
ranks_group = dist.new_group(ranks if ranks else None)
Collaborator:

This will cause a compatibility problem. If the user does not use auto pg and initializes a process group only on `ranks` (using the same logic as `init_process_group_for_ranks`), this will break.
But whether we should stay compatible with this situation may need discussion.

Collaborator:

I would assume that if the user initializes the PG themselves, the `ranks` param should also correspond to that PG? In that case it should be OK?

Collaborator (@blahgeek, Dec 10, 2025):

Hmm, maybe I was wrong.

Is there any documentation on which ranks should form a PG, which ranks should call update, and the meaning of `ranks` in the non-_auto_pg case?
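
A minimal sketch of the contract this thread is asking about, as the new code appears to assume it (torchrun-style env vars and the NCCL backend are assumptions): dist.new_group is a collective over the default group, so every rank of the global PG must execute it with the same `ranks`, even ranks outside the subgroup.

```python
# Assumed contract: new_group must be executed by ALL ranks of the default
# group with the same `ranks` list; only members then use the subgroup.
import os
import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

group = dist.new_group([0, 1])     # every rank must reach this line
if rank in (0, 1):
    dist.barrier(group=group)      # only subgroup members synchronize here
```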

self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
self.store_based_barrier(manager_store)
except Exception as e:
logger.exception(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
)
raise
finally:
if self._auto_pg and (not ranks or self._rank in ranks):
if ranks_group:
dist.destroy_process_group(ranks_group)
if self._auto_pg and dist.is_initialized():
dist.destroy_process_group()
Collaborator:

Is this necessary? I think the GPU memory held by NCCL may already be released by dist.destroy_process_group(ranks_group), so the bare dist.destroy_process_group() may not be needed. Please test and check whether my view is correct.

Collaborator (author):

No. If only dist.destroy_process_group(ranks_group) is called, 1306 MB remains, versus 980 MB when both are called.

Collaborator (author):

Alternatively, we can call only dist.destroy_process_group(): when given no arguments, it destroys all process groups, including ranks_group.

Collaborator (author):

For the case where auto_pg == False, I think we'd better not leave ranks_group undestroyed.

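A short sketch of the teardown order this thread settles on (semantics as documented for destroy_process_group; the memory numbers above suggest the default group's communicators hold the extra ~300 MB):

```python
# Destroy the subgroup first, then everything else: a bare
# destroy_process_group() tears down all remaining groups, including the
# default one, which is what releases the rest of the NCCL GPU memory.
if ranks_group is not None:
    dist.destroy_process_group(ranks_group)
if dist.is_initialized():
    dist.destroy_process_group()
```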

self.device_manager.device_module.empty_cache()
logger.info(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
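
Putting the pieces together, a hedged sketch of how a caller might drive the two update modes (the req_func body and checkpoint names are assumptions, not part of this PR):

```python
from datetime import timedelta

def req_func(socket_paths: list[tuple[str, str]]) -> None:
    # assumed contract: forward the ZMQ socket paths to the inference engine
    ...

ps.update("ckpt-step-100", req_func)            # broadcast mode: all ranks
ps.update(                                      # p2p mode: subset of ranks
    "ckpt-step-200",
    req_func,
    ranks=[0, 1, 2, 3],
    timeout=timedelta(minutes=10),
)
```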
@@ -1107,7 +1128,9 @@ def zmq_handle(device_uuid: str) -> str:
self._zmq_addr_counter += 1
return socket, socket_paths

def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
def _detect_bucket_size(
self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
) -> tuple[int, bool]:
GiB = 1 << 30 # noqa: N806
# auto detect bucket size
tensor = torch.tensor(
@@ -1123,7 +1146,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
dtype=torch.int64,
device=self.device_manager.device_type,
)
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
tensor = tensor.cpu()
free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
max_tensor_bytes = 0
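
The reduction above is what lets every participant settle on one safe bucket size; a standalone sketch of the idea (names assumed, ranks_group as created in update):

```python
# Each rank proposes its free device memory; all_reduce(MIN) over the update
# group picks the smallest, so no participant over-allocates its bucket.
import torch
import torch.distributed as dist

free_bytes, _total = torch.cuda.mem_get_info()
proposal = torch.tensor([free_bytes], device="cuda", dtype=torch.int64)
dist.all_reduce(proposal, op=dist.ReduceOp.MIN, group=ranks_group)
bucket_size = int(proposal.item())  # identical on every rank in the group
```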
@@ -1186,51 +1209,6 @@ def _copy_to_buffer(
self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
self.device_manager.device_module.synchronize()

def init_process_group_for_ranks(
self,
ranks: list[int],
*,
master_port: int | None = None,
timeout: timedelta = timedelta(minutes=10),
):
"""
Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.

Args:
ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
timeout: The timeout of the process group.
"""
assert not dist.is_initialized()
assert ranks, "ranks should be set"
if self._rank not in ranks:
return
assert self._all_hosts, "all_hosts should be set"
assert len(self._all_hosts) == self._world_size // self._gpu_count, (
f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
)
rank = ranks.index(self._rank)
master_addr = self._all_hosts[ranks[0] // self._gpu_count]
master_port = _get_master_port(master_port)
logger.info(
f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
)
# only initialize process group and store for ranks, other nodes are not initialized
# and will not participate in this update. Since they have registered memory addresses
# to p2p_store at the beginning, update ranks can directly get the memory addresses
# from other nodes and put the weights into the buffer.
store = dist.TCPStore(
master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
)
dist.init_process_group(
backend=self.device_manager.backend,
world_size=len(ranks),
rank=rank,
timeout=timeout,
store=store,
)

def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1260,10 +1238,12 @@ def _update_per_bucket(
self,
checkpoint_name: str,
req_func: Callable[[list[tuple[str, str]]], None],
ranks_group: dist.ProcessGroup,
ranks: list[int] | None = None,
):
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
assert dist.is_initialized(), "process group is not initialized"

# if ranks is None or [], use full broadcast to update all ranks
if not ranks:
logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1281,9 +1261,9 @@ def _update_per_bucket(
if not need_update:
return
# first execute a barrier to avoid subsequent device oom
dist.barrier()
dist.barrier(group=ranks_group)

bucket_size, disable_h2d_buffer = self._detect_bucket_size()
bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
buckets = _gen_h2d_buckets(
self._current_global_parameter_metas,
bucket_size,
@@ -1330,7 +1310,6 @@ def _update_per_bucket(

gidx = 0
ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
try:
for i in range(max_len):
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1360,16 +1339,15 @@ def _update_per_bucket(
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
else:
buffer_b.data.copy_(h2d_buffer[: bucket.size])
brank = bcast_rank_map[receiver_rank]
dist.broadcast(buffer_b, src=brank)
dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
resp = socket.recv()
if resp != b"":
msg = resp.decode("utf-8")
logger.error(
f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
)
ret_code.fill_(1)
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
self.device_manager.device_module.synchronize()
if ret_code.item() != 0:
# quit early if any rank failed
@@ -1383,7 +1361,7 @@ def _update_per_bucket(
socket.recv()
finally:
req_thread.join()
dist.barrier()
dist.barrier(group=ranks_group)
socket.close()
if ranks and h2d_buffer is not None:
self._p2p_store.unregister_named_tensors([h2d_buffer_name])
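
Finally, the change that makes _get_bcast_rank_map removable: when a group argument is passed, dist.broadcast still interprets src as a global rank, so the receiver's real rank works directly. A minimal sketch of that semantics (buffer shape and names assumed):

```python
# With a subgroup, `src` remains the GLOBAL rank id — no remapping of
# receiver_rank to its index within `ranks` is required.
import torch
import torch.distributed as dist

group = dist.new_group(ranks)                     # collective over all ranks
if dist.get_rank() in ranks:
    buffer_b = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")
    dist.broadcast(buffer_b, src=receiver_rank, group=group)
```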
10 changes: 8 additions & 2 deletions tests/test_update.py
@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
try:
trigger_error(socket_paths)
except RuntimeError as e:
assert str(e) == "Failed to update weights due to remote errors"
assert str(e) == "Some workers failed to update weights"


def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -177,7 +177,13 @@ def run(
],
),
("test_with_remote_error", [[]]),
# ("long_test_no_error", [list(random.sample(range(get_world_size()), k=num_ranks)) for num_ranks in range(get_world_size() + 1)]),
(
"test_no_error",
[
list(random.sample(range(get_world_size()), k=num_ranks))
for num_ranks in range(get_world_size() + 1)
],
),
],
)
def test_update(test_name: str, rank_list: list[list[int]] | None):