fix: use tcp store_based_barrier to control p2p update synchronization #51
Changes from all commits: 0c8d3f2, f9b5a0f, 7c6054c, 5a26fbf, 214fc86
```diff
@@ -677,20 +677,6 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port
 
 
-def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
-    """
-    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
-    which are generated in self.init_process_group_for_ranks
-    """
-    bcast_rank_map: dict[int, int] = {}
-    if not ranks:
-        bcast_rank_map = {r: r for r in range(world_size)}
-    else:
-        for i, r in enumerate(ranks):
-            bcast_rank_map[r] = i
-    return bcast_rank_map
-
-
 class P2PStore:
     def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
```
```diff
@@ -1045,12 +1031,36 @@ def init_process_group(
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
+    def store_based_barrier(
+        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
+    ) -> None:
+        """
+        Perform a store-based barrier synchronization across all ranks.
+
+        This barrier uses a TCP store directly rather than a process group,
+        allowing all ranks to synchronize regardless of which process group
+        they belong to.
+
+        Args:
+            store: The TCPStore instance to use for synchronization.
+        """
+        dist.distributed_c10d._store_based_barrier(
+            rank=self._rank,
+            store=store,
+            group_name="parameter_server_barrier",
+            rendezvous_count=self._world_size,
+            timeout=timeout,
+        )
+
     def update(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
         *,
+        timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
```
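For readers who have not seen `_store_based_barrier` before: it is a private helper inside `torch.distributed.distributed_c10d` that the new `store_based_barrier` method wraps, rendezvousing `world_size` participants on a plain TCP store instead of on a process group. The sketch below illustrates the same idea using only public `TCPStore` operations; the function name, key name, and polling interval are made up for illustration and are not what the PR uses.

```python
import time
from datetime import timedelta

import torch.distributed as dist


def toy_store_barrier(
    store: dist.TCPStore,
    world_size: int,
    key: str = "toy_barrier",  # illustrative key, not the PR's group_name
    timeout: timedelta = timedelta(minutes=5),
) -> None:
    """Illustrative store-based barrier: all ranks rendezvous on a shared counter."""
    # Each rank bumps the counter once to announce that it has arrived.
    arrived = store.add(key, 1)
    deadline = time.monotonic() + timeout.total_seconds()
    # Poll until every rank has announced itself; add(key, 0) just reads the counter.
    while arrived < world_size:
        if time.monotonic() > deadline:
            raise TimeoutError(f"barrier timed out: {arrived}/{world_size} ranks arrived")
        time.sleep(0.1)
        arrived = store.add(key, 0)
```

Because the rendezvous happens on the store rather than on a process group, ranks that never join the NCCL subgroup can still take part, which is the property the added method's docstring calls out.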
```diff
@@ -1062,34 +1072,45 @@ def update(
                 which is the fastest way to update weights, especially in colocated architecture.
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
+            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
+            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
+            timeout: The timeout of the barrier operation.
         """
         assert req_func is not None, "req_func is required"
+        ranks_group = None
         try:
-            # if both ranks is None or [], it will use fully broadcast to update to all ranks
-            if not ranks:
-                if self._auto_pg and not dist.is_initialized():
-                    self.init_process_group()
-                self._update_per_bucket(checkpoint_name, req_func)
+            master_addr = os.getenv("MASTER_ADDR") or master_addr
+            assert master_addr, "master_addr is required"
+            if self._auto_pg:
+                if not dist.is_initialized():
+                    self.init_process_group(
+                        timeout=timeout, master_addr=master_addr, master_port=master_port
+                    )
+                manager_store = dist.distributed_c10d._get_default_store()
             else:
-                if self._auto_pg:
-                    if dist.is_initialized():
-                        dist.destroy_process_group()
-                        # HACK: wait 2s to ensure destroy is finished
-                        time.sleep(2)
-                    self.init_process_group_for_ranks(ranks)
-                if self._rank not in ranks:
-                    return
-                self._update_per_bucket(checkpoint_name, req_func, ranks)
+                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
+                # If master_port is provided, use master_port+1 for barrier store
+                manager_store = dist.TCPStore(
+                    master_addr,
+                    _get_master_port(master_port) + 1,
+                    self._world_size,
+                    timeout=timeout,
+                    is_master=self._rank == 0,
+                )
+            # if ranks is None or [], it will use fully broadcast to update to all ranks
+            ranks_group = dist.new_group(ranks if ranks else None)
+            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
+            self.store_based_barrier(manager_store)
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
         finally:
-            if self._auto_pg and (not ranks or self._rank in ranks):
+            if ranks_group:
+                dist.destroy_process_group(ranks_group)
+            if self._auto_pg and dist.is_initialized():
                 dist.destroy_process_group()
```
Collaborator: Is this necessary? I think the GPU mem from NCCL may be released after …

Author: No. If only …

Author: But, we can only call …

Author: For the case where …
```diff
             self.device_manager.device_module.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
```
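To make the calling convention concrete, here is a hedged sketch of a caller, assuming a hypothetical `ps` object that exposes the `update` method shown above and an `apply_weights` callback standing in for `req_func`; the address and rank list are made up. As the diff reads, every rank invokes `update`: ranks outside `ranks` do not become members of `ranks_group`, but all ranks create or reuse the barrier store and meet in `store_based_barrier` before returning.

```python
# Hypothetical driver; `ps` stands in for the object whose update() appears above.
def apply_weights(named_handles: list[tuple[str, str]]) -> None:
    ...  # hand the weight handles to the inference engine


def run_one_update(ps, checkpoint_name: str) -> None:
    # Invoked on every rank in the world, not only the receiving ranks:
    # the TCP-store barrier inside update() rendezvouses on world_size.
    ps.update(
        checkpoint_name,
        apply_weights,
        ranks=[0, 1, 2, 3],      # subset that receives weights via p2p
        master_addr="10.0.0.1",  # or rely on the MASTER_ADDR environment variable
    )
```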
```diff
@@ -1107,7 +1128,9 @@ def zmq_handle(device_uuid: str) -> str:
             self._zmq_addr_counter += 1
         return socket, socket_paths
 
-    def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
+    def _detect_bucket_size(
+        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+    ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
         tensor = torch.tensor(
```
```diff
@@ -1123,7 +1146,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
             dtype=torch.int64,
             device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
```
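A note on the `group=ranks_group` argument introduced here and in the hunks below: a collective issued with a subgroup only synchronizes that subgroup's members, and a rank that was not included when `dist.new_group` was called receives a non-member placeholder, for which the collective returns immediately. A toy example, independent of this PR:

```python
# Toy illustration of subgroup collectives in plain PyTorch (not this repo's code).
# Run with: torchrun --nproc_per_node=4 toy_subgroup.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # gloo so the toy runs without GPUs
rank = dist.get_rank()

# Every rank must call new_group(), but only ranks 0 and 1 become members.
subgroup = dist.new_group([0, 1])

t = torch.tensor([rank], dtype=torch.int64)
# Members reduce only among themselves; non-members skip the call and
# their tensor is left untouched.
dist.all_reduce(t, op=dist.ReduceOp.MIN, group=subgroup)
print(f"rank {rank}: t = {t.item()}")

dist.destroy_process_group()
```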
```diff
@@ -1186,51 +1209,6 @@ def _copy_to_buffer(
         self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
         self.device_manager.device_module.synchronize()
 
-    def init_process_group_for_ranks(
-        self,
-        ranks: list[int],
-        *,
-        master_port: int | None = None,
-        timeout: timedelta = timedelta(minutes=10),
-    ):
-        """
-        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
-
-        Args:
-            ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
-            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
-            timeout: The timeout of the process group.
-        """
-        assert not dist.is_initialized()
-        assert ranks, "ranks should be set"
-        if self._rank not in ranks:
-            return
-        assert self._all_hosts, "all_hosts should be set"
-        assert len(self._all_hosts) == self._world_size // self._gpu_count, (
-            f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
-        )
-        rank = ranks.index(self._rank)
-        master_addr = self._all_hosts[ranks[0] // self._gpu_count]
-        master_port = _get_master_port(master_port)
-        logger.info(
-            f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
-            f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
-        )
-        # only initialize process group and store for ranks, other nodes are not initialized
-        # and will not participate in this update. Since they have registered memory addresses
-        # to p2p_store at the beginning, update ranks can directly get the memory addresses
-        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
-        )
-        dist.init_process_group(
-            backend=self.device_manager.backend,
-            world_size=len(ranks),
-            rank=rank,
-            timeout=timeout,
-            store=store,
-        )
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
```
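The deleted helper rebuilt the default process group so that it contained only the participating ranks, which is why the old `update` had to destroy the world group and sleep before re-initializing. The new flow keeps the world group alive and derives a per-update subgroup with `dist.new_group` instead. Below is a minimal sketch of that lifecycle using only public `torch.distributed` APIs and none of this repository's classes:

```python
import torch.distributed as dist


def with_per_update_subgroup(ranks: list[int] | None) -> None:
    """Create a subgroup for one update, use it, then destroy only the subgroup."""
    assert dist.is_initialized(), "the default (world) process group must already exist"

    # None -> a group spanning the whole world; otherwise only `ranks`.
    group = dist.new_group(ranks if ranks else None)
    is_member = group is not dist.GroupMember.NON_GROUP_MEMBER
    try:
        if is_member:
            dist.barrier(group=group)  # stand-in for the real broadcast traffic
    finally:
        # Tearing down the subgroup leaves the default group intact, so the next
        # update does not need the destroy-and-sleep workaround of the old code.
        if is_member:
            dist.destroy_process_group(group)
```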
```diff
@@ -1260,10 +1238,12 @@ def _update_per_bucket(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks_group: dist.ProcessGroup,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
+        assert dist.is_initialized(), "process group is not initialized"
 
         # if both ranks is None or [], it will use fully broadcast to update to all ranks
         if not ranks:
             logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
```
```diff
@@ -1281,9 +1261,9 @@ def _update_per_bucket(
         if not need_update:
             return
         # first execute a barrier to avoid subsequent device oom
-        dist.barrier()
+        dist.barrier(group=ranks_group)
 
-        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
         buckets = _gen_h2d_buckets(
             self._current_global_parameter_metas,
             bucket_size,
```
```diff
@@ -1330,7 +1310,6 @@ def _update_per_bucket(
 
         gidx = 0
         ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
-        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
         try:
             for i in range(max_len):
                 if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
```
```diff
@@ -1360,16 +1339,15 @@ def _update_per_bucket(
                         self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                     else:
                         buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                brank = bcast_rank_map[receiver_rank]
-                dist.broadcast(buffer_b, src=brank)
+                dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                 resp = socket.recv()
                 if resp != b"":
                     msg = resp.decode("utf-8")
                     logger.error(
                         f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                     )
                     ret_code.fill_(1)
-                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                 self.device_manager.device_module.synchronize()
                 if ret_code.item() != 0:
                     # quit early if any rank failed
```
```diff
@@ -1383,7 +1361,7 @@ def _update_per_bucket(
                 socket.recv()
         finally:
             req_thread.join()
-            dist.barrier()
+            dist.barrier(group=ranks_group)
             socket.close()
             if ranks and h2d_buffer is not None:
                 self._p2p_store.unregister_named_tensors([h2d_buffer_name])
```
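On the removal of `_get_bcast_rank_map` and the switch to `src=receiver_rank`: PyTorch collectives address `src`/`dst` by global rank even when a `group=` argument is passed, so no remapping is needed when the subgroup comes from `dist.new_group`. The old remapping existed only because `init_process_group_for_ranks` renumbered the participating ranks from zero. A toy demonstration, independent of this repository:

```python
# Toy: broadcast inside a subgroup still names the source by its GLOBAL rank.
# Run with: torchrun --nproc_per_node=4 toy_broadcast.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

subgroup = dist.new_group([2, 3])  # members are global ranks 2 and 3

t = torch.tensor([rank if rank == 2 else -1], dtype=torch.int64)
if rank in (2, 3):
    # src is the global rank 2, not the subgroup-local rank 0.
    dist.broadcast(t, src=2, group=subgroup)
print(f"rank {rank}: t = {t.item()}")  # ranks 2 and 3 print 2; ranks 0 and 1 print -1

dist.destroy_process_group()
```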
This will cause a compatibility problem. If the user does not use auto pg and initializes a process group only on the ranks, using the same logic as `init_process_group_for_ranks`, this will break. But whether we should be compatible with that situation may need discussion.

I would assume that if the user initializes the PG by themselves, the `ranks` param should also correspond to the PG? In which case it should be OK?

hmmm maybe I was wrong. Is there any document about, in the case of not `_auto_pg`, which ranks should form a PG & which ranks should call `update` & the meaning of `ranks`?