diff --git a/checkpoint_engine/distributed/__init__.py b/checkpoint_engine/distributed/__init__.py new file mode 100644 index 0000000..f4f1161 --- /dev/null +++ b/checkpoint_engine/distributed/__init__.py @@ -0,0 +1,23 @@ +from .base import ( + Distributed, + init_process_group, + destroy_process_group, + is_initialized, + all_gather_object, + all_reduce, + broadcast, + barrier, + new_group, +) + +__all__ = [ + "Distributed", + "init_process_group", + "destroy_process_group", + "is_initialized", + "all_gather_object", + "all_reduce", + "broadcast", + "barrier", + "new_group", +] diff --git a/checkpoint_engine/distributed/base.py b/checkpoint_engine/distributed/base.py new file mode 100644 index 0000000..e7b55cb --- /dev/null +++ b/checkpoint_engine/distributed/base.py @@ -0,0 +1,210 @@ +from abc import ABC, abstractmethod +import io +import pickle +from datetime import timedelta +from typing import Any, List +import importlib + +import torch +from torch.distributed import ReduceOp + + +class Distributed(ABC): + @abstractmethod + def init_process_group( + self, + host: str, + port: int, + rank: int, + world_size: int, + timeout: timedelta, + ): + raise NotImplementedError + + @abstractmethod + def destroy_process_group( + self, + group, + ): + raise NotImplementedError + + @abstractmethod + def is_initialized(self) -> bool: + raise NotImplementedError + + @abstractmethod + def all_gather_object( + self, + object_list: list[Any], + obj: Any, + group, + ): + raise NotImplementedError + + @abstractmethod + def all_reduce( + self, + tensor: torch.Tensor, + op :ReduceOp, + group, + ): + raise NotImplementedError + + @abstractmethod + def broadcast( + self, + tensor: torch.Tensor, + src: int, + group, + ): + raise NotImplementedError + + @abstractmethod + def barrier( + self, + group, + ): + raise NotImplementedError + + @abstractmethod + def new_group( + self, + ranks: list[int], + ): + raise NotImplementedError + + +# specific device instance +_BACKEND_INSTANCE = None + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +def _object_to_tensor(obj, device): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) + byte_tensor = torch.ByteTensor(byte_storage).to(device) + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size): + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + if not tensor_list: + raise RuntimeError("Received an empty list.") + t = tensor_list[0] + buffer_shape = [len(tensor_list)] + list(t.shape) + + buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) + if copy: + for i, tensor in enumerate(tensor_list): + buffer[i].copy_(tensor) + return buffer + + +def _common_all_gather_object(comm, device, world_size, object_list, object): + input_tensor, local_size = _object_to_tensor(object, device) + object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device) + comm.all_gather(object_sizes_tensor, local_size) + object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)] + max_object_size = int(max(object_size_list).item()) + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * world_size, dtype=torch.uint8, device=device + ) + + comm.all_gather(coalesced_output_tensor, input_tensor) + output_tensors = 
[
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(world_size)
+    ]
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        tensor_size = object_size_list[i]
+        object_list[i] = _tensor_to_object(tensor, tensor_size)
+
+
+def init_process_group(
+    host: str,
+    port: int,
+    rank: int,
+    world_size: int,
+    device_type: str,
+    timeout: timedelta = timedelta(seconds=300),
+):
+    global _BACKEND_INSTANCE
+
+    mapping = {
+        "cuda": ".nccl.DistributedNccl",
+        "npu": ".hccl.DistributedHccl",
+    }
+
+    if device_type not in mapping:
+        raise ValueError(f"Unsupported device type: {device_type}")
+
+    module_path, class_name = mapping[device_type].rsplit(".", 1)
+    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
+    backend_class = getattr(module, class_name)
+
+    _BACKEND_INSTANCE = backend_class()
+    _BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
+
+
+def destroy_process_group(group=None):
+    if _BACKEND_INSTANCE is None:
+        raise RuntimeError("distributed module not initialized")
+    _BACKEND_INSTANCE.destroy_process_group(group)
+
+
+def is_initialized() -> bool:
+    if _BACKEND_INSTANCE is None:
+        return False
+    return _BACKEND_INSTANCE.is_initialized()
+
+def all_gather_object(
+    object_list: list[Any],
+    obj: Any,
+    group=None,
+):
+    if _BACKEND_INSTANCE is None:
+        raise RuntimeError("distributed module not initialized")
+    _BACKEND_INSTANCE.all_gather_object(object_list, obj, group)
+
+
+def all_reduce(
+    tensor: torch.Tensor,
+    op=ReduceOp.SUM,
+    group=None,
+):
+    if _BACKEND_INSTANCE is None:
+        raise RuntimeError("distributed module not initialized")
+    _BACKEND_INSTANCE.all_reduce(tensor, op, group)
+
+
+def broadcast(
+    tensor: torch.Tensor,
+    src=None,
+    group=None,
+):
+    if _BACKEND_INSTANCE is None:
+        raise RuntimeError("distributed module not initialized")
+    _BACKEND_INSTANCE.broadcast(tensor, src, group)
+
+
+def barrier(group=None):
+    if _BACKEND_INSTANCE is None:
+        raise RuntimeError("distributed module not initialized")
+    _BACKEND_INSTANCE.barrier(group)
+
+
+def new_group(ranks: list[int]):
+    if _BACKEND_INSTANCE is None:
+        raise RuntimeError("distributed module not initialized")
+    return _BACKEND_INSTANCE.new_group(ranks)
diff --git a/checkpoint_engine/distributed/hccl.py b/checkpoint_engine/distributed/hccl.py
new file mode 100644
index 0000000..b469e17
--- /dev/null
+++ b/checkpoint_engine/distributed/hccl.py
@@ -0,0 +1,341 @@
+import ctypes
+from datetime import timedelta
+from typing import Any, List, Optional
+
+import torch
+from torch.distributed import ReduceOp
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator
+from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
+    Function,
+    HCCLLibrary,
+    aclrtStream_t,
+    buffer_type,
+    hcclComm_t,
+    hcclDataType_t,
+    hcclDataTypeEnum,
+    hcclRedOp_t,
+    hcclRedOpTypeEnum,
+    hcclResult_t,
+    hcclUniqueId,
+)
+from vllm_ascend.utils import current_stream
+from checkpoint_engine.distributed.base import Distributed, _common_all_gather_object
+
+
+class HcclCommConfig(ctypes.Structure):
+    _fields_ = [
+        ("size", ctypes.c_size_t),
+        ("magic_word", ctypes.c_uint32),
+        ("version", ctypes.c_uint32),
+        ("reserved", ctypes.c_uint64),
+        ("hccl_buffer_size", ctypes.c_uint32),
+        ("hccl_deterministic", ctypes.c_uint32),
+        ("hccl_comm_name", ctypes.c_char * 128),
+        ("hccl_udi", ctypes.c_char * 128),
+        ("hccl_op_expansion_mode", ctypes.c_uint32),
+        ("hccl_rdma_traffic_class", ctypes.c_uint32),
+        ("hccl_rdma_service_level", ctypes.c_uint32),
+        ("hccl_world_rank_id", ctypes.c_uint32),
+        ("hccl_job_id", ctypes.c_uint64),
+        ("comm_engine", ctypes.c_int32),
+        ("thread_num", ctypes.c_uint32),
+        ("notify_num_per_thread", ctypes.c_uint32),
+        ("acl_graph_zero_copy_enable", ctypes.c_uint8),
+    ]
+
+
+orig_exported_functions = HCCLLibrary.exported_functions
+extended_functions = [
+    # HcclResult HcclAllGather(
+    #   void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType,
+    #   HcclComm comm, aclrtStream stream
+    # )
+    Function(
+        "HcclAllGather",
+        hcclResult_t,
+        [
+            buffer_type,
+            buffer_type,
+            ctypes.c_uint64,
+            hcclDataType_t,
+            hcclComm_t,
+            aclrtStream_t,
+        ],
+    ),
+    # HcclResult HcclCreateSubCommConfig(
+    #   HcclComm *comm, uint32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
+    #   uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm
+    # )
+    Function(
+        "HcclCreateSubCommConfig",
+        hcclResult_t,
+        [
+            ctypes.POINTER(hcclComm_t),
+            ctypes.c_uint32,
+            ctypes.POINTER(ctypes.c_uint32),
+            ctypes.c_uint64,
+            ctypes.c_uint32,
+            ctypes.POINTER(HcclCommConfig),
+            ctypes.POINTER(hcclComm_t),
+        ],
+    ),
+]
+
+
+def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream):
+    self.HCCL_CHECK(
+        self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
+    )
+
+
+def hccl_create_subcomm_config(
+    self, comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
+):
+    subcomm = hcclComm_t()
+    self.HCCL_CHECK(
+        self._funcs["HcclCreateSubCommConfig"](
+            ctypes.byref(comm),
+            ranks_size,
+            c_rank_ids,
+            subcomm_id,
+            subcomm_rank,
+            ctypes.byref(comm_config),
+            ctypes.byref(subcomm),
+        )
+    )
+    return subcomm
+
+
+# extend HCCLLibrary
+HCCLLibrary.exported_functions = orig_exported_functions + extended_functions
+HCCLLibrary.hcclAllGather = hccl_all_gather
+HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config
+
+
+class PyHcclCommunicatorEx(PyHcclCommunicator):
+    def __init__(self, group, device):
+        super().__init__(group, device)
+        self.subcomm_id = 1
+
+    def destroy_comm(self, comm=None):
+        if comm:
+            self.hccl.hcclCommDestroy(comm)
+        else:
+            self.hccl.hcclCommDestroy(self.comm)
+
+    def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=None):
+        if self.disabled:
+            return
+        assert in_tensor.device == self.device, (
+            f"this hccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {in_tensor.device}"
+        )
+        if stream is None:
+            stream = current_stream()
+        self.hccl.hcclAllGather(
+            buffer_type(in_tensor.data_ptr()),
+            buffer_type(out_tensor.data_ptr()),
+            in_tensor.numel(),
+            hcclDataTypeEnum.from_torch(in_tensor.dtype),
+            self.comm,  # may be swapped to a subcomm by DistributedHccl for subgroup collectives
+            aclrtStream_t(stream.npu_stream),
+        )
+        return out_tensor
+
+    def create_subcomm(self, ranks):
+        comm_config = HcclCommConfig(
+            size=312,
+            magic_word=0xF0F0F0F0,
+            version=6,
+            reserved=0,
+            hccl_buffer_size=0xFFFFFFFF,
+            hccl_deterministic=0xFFFFFFFF,
+            hccl_comm_name=b"\0",
+            hccl_udi=b"\0",
+            hccl_op_expansion_mode=0,
+            hccl_rdma_traffic_class=0xFFFFFFFF,
+            hccl_rdma_service_level=0xFFFFFFFF,
+            hccl_world_rank_id=0,
+            hccl_job_id=0,
+            comm_engine=-1,
+            thread_num=0xFFFFFFFF,
+            notify_num_per_thread=0xFFFFFFFF,
+            acl_graph_zero_copy_enable=0,
+        )
+        uint32_array = ctypes.c_uint32 * len(ranks)
+        c_rank_ids = uint32_array(*ranks)
+        subcomm_rank = ranks.index(self.rank)
+        ranks_size = len(ranks)
+        subcomm_id = self.subcomm_id
+
+        subcomm = self.hccl.hcclCreateSubCommConfig(
self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config + ) + self.subcomm_id += 1 + return subcomm + + +class DistributedHccl(Distributed): + def __init__(self): + self.pg: StatelessProcessGroup = None + self.pyhccl: PyHcclCommunicatorEx = None + self.sub_groups: dict[int, list[int]] = {} + self.comm: hcclComm_t = None + + self.host: str = None + self.port: int = None + self.rank: int = None + self.world_size: int = None + self.device: torch.device = None + + self.initialized: bool = False + + def init_process_group( + self, + host: str, + port: int, + rank: int, + world_size: int, + timeout: timedelta = timedelta(seconds=300), + ): + assert not self.initialized, "already initialized" + + self.host = host + self.port = port + self.rank = rank + self.world_size = world_size + self.device = torch.device("npu", rank) + + self.pg = StatelessProcessGroup.create( + host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) + ) + self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device) + self.comm = self.pyhccl.comm + self.initialized = True + + + def destroy_process_group( + self, + group=None, + ): + assert self.initialized, "not initialized" + + if group in self.sub_groups: + subcomm = ctypes.c_void_p(group) + self.pyhccl.destroy_comm(subcomm) + del self.sub_groups[group] + return + + self.pyhccl.destroy_comm() + self.initialized = False + + + def is_initialized(self) -> bool: + return self.initialized + + + def all_gather_object( + self, + object_list: list[Any], + obj: Any, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + + _common_all_gather_object(self.pyhccl, self.device, self.world_size, object_list, obj) + current_stream().synchronize() + + if group: + self.pyhccl.comm = self.comm + + + def all_reduce( + self, + tensor: torch.Tensor, + op=ReduceOp.SUM, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + + out_tensor = self.pyhccl.all_reduce(tensor, op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + if group: + self.pyhccl.comm = self.comm + + + def broadcast( + self, + tensor: torch.Tensor, + src=None, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + assert src in self.sub_groups[group], "src rank not in group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + # convert src rank id in default world to subcomm + src = self.sub_groups[group].index(src) + self.pyhccl.rank = self.sub_groups[group].index(self.rank) + + self.pyhccl.broadcast(tensor, src) + current_stream().synchronize() + + if group: + self.pyhccl.comm = self.comm + self.pyhccl.rank = self.rank + + + def barrier( + self, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + + data = torch.zeros(1, device=self.device) + self.pyhccl.all_reduce(data) + current_stream().synchronize() + + if group: + self.pyhccl.comm = self.comm + + + def new_group( + self, + ranks + ): + assert self.initialized, "not initialized" + + # if ranks is None or [], using the world instead + if not ranks: + ranks = list(range(self.world_size)) + + if self.rank 
not in ranks: + return + + subcomm = self.pyhccl.create_subcomm(ranks) + value = 0 + if subcomm: + value = subcomm.value + self.sub_groups[value] = ranks + return value diff --git a/checkpoint_engine/distributed/nccl.py b/checkpoint_engine/distributed/nccl.py new file mode 100644 index 0000000..9c5199a --- /dev/null +++ b/checkpoint_engine/distributed/nccl.py @@ -0,0 +1,259 @@ +import ctypes +from datetime import timedelta +from typing import Any, List, Optional + +import torch +from torch.distributed import ReduceOp +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import ( + Function, + NCCLLibrary, + buffer_type, + ncclComm_t, + ncclResult_t, +) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.utils import current_stream +from checkpoint_engine.distributed.base import Distributed, _common_all_gather_object + + +class ncclConfig_t(ctypes.Structure): + _fields_ = [ + ("size", ctypes.c_size_t), + ("magic", ctypes.c_uint), + ("version", ctypes.c_uint), + ("blocking", ctypes.c_int), + ("cgaClusterSize", ctypes.c_int), + ("minCTAs", ctypes.c_int), + ("maxCTAs", ctypes.c_int), + ("netName", ctypes.c_char_p), + ("splitShare", ctypes.c_int), + ("trafficClass", ctypes.c_int), + ("commName", ctypes.c_char_p), + ("collnetEnable", ctypes.c_int), + ("CTAPolicy", ctypes.c_int), + ("shrinkShare", ctypes.c_int), + ("nvlsCTAs", ctypes.c_int), + ("nChannelsPerNetPeer", ctypes.c_int), + ("nvlinkCentricSched", ctypes.c_int), + ("graphUsageMode", ctypes.c_int), + ("numRmaCtx", ctypes.c_int), + ] + + +nccl_orig_exported_functions = NCCLLibrary.exported_functions +nccl_extended_functions = [ + # ncclResult_t ncclCommSplit( + # ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t *config + # ) + Function( + "ncclCommSplit", + ncclResult_t, + [ + ncclComm_t, + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ncclComm_t), + ctypes.POINTER(ncclConfig_t), + ], + ), +] + + +def nccl_comm_split( + self, + comm: ncclComm_t, + color: int, + key: int, +) -> ncclComm_t: + newcomm = ncclComm_t() + + self.NCCL_CHECK(self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None)) + return newcomm + + +# extend NCCLLibrary +NCCLLibrary.exported_functions = nccl_orig_exported_functions + nccl_extended_functions +NCCLLibrary.ncclCommSplit = nccl_comm_split + + +class PyNcclCommunicatorEx(PyNcclCommunicator): + def destroy_comm(self, comm=None): + if comm: + self.nccl.ncclCommDestroy(comm) + else: + self.nccl.ncclCommDestroy(self.comm) + + def create_newcomm(self, ranks): + if self.rank in ranks: + color = 0 + else: + color = -1 # NCCL_SPLIT_NOCOLOR + newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank) + return newcomm + + +class DistributedNccl(Distributed): + def __init__(self): + self.pg: StatelessProcessGroup = None + self.pynccl: PyNcclCommunicatorEx = None + self.sub_groups: dict[int, list[int]] = {} + self.comm: ncclComm_t = None + + self.host: str = None + self.port: int = None + self.rank: int = None + self.world_size: int = None + self.device: torch.device = None + + self.initialized: bool = False + + def init_process_group( + self, + host: str, + port: int, + rank: int, + world_size: int, + timeout: timedelta = timedelta(seconds=300), + ): + assert not self.initialized, "already initialized" + + self.host = host + self.port = port + self.rank = rank + self.world_size = world_size + self.device = torch.device("cuda", rank) + + self.pg = StatelessProcessGroup.create( + host, 
port, rank, world_size, store_timeout=int(timeout.total_seconds()) + ) + + self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self.device) + self.comm = self.pynccl.comm + self.initialized = True + + + def destroy_process_group( + self, + group=None, + ): + assert self.initialized, "not initialized" + + if group in self.sub_groups: + newcomm = ctypes.c_void_p(group) + self.pynccl.destroy_comm(newcomm) + del self.sub_groups[group] + return + + self.pynccl.destroy_comm() + + self.pynccl = None + self.pg = None + self.initialized = False + + + def is_initialized(self) -> bool: + return self.initialized + + + def all_gather_object( + self, + object_list: list[Any], + obj: Any, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + + _common_all_gather_object(self.pynccl, self.device, self.world_size, object_list, obj) + current_stream().synchronize() + + if group: + self.pynccl.comm = self.comm + + + def all_reduce( + self, + tensor: torch.Tensor, + op=ReduceOp.SUM, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + + out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + if group: + self.pynccl.comm = self.comm + + + def broadcast( + self, + tensor: torch.Tensor, + src=None, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + assert src in self.sub_groups[group], "src rank not in group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + # convert src rank id in default world to newcomm + src = self.sub_groups[group].index(src) + self.pynccl.rank = self.sub_groups[group].index(self.rank) + + self.pynccl.broadcast(tensor, src) + current_stream().synchronize() + + if group: + self.pynccl.comm = self.comm + self.pynccl.rank = self.rank + + + def barrier( + self, + group=None + ): + assert self.initialized, "not initialized" + + if group: + assert group in self.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + + data = torch.zeros(1, device=self.device) + self.pynccl.all_reduce(data) + current_stream().synchronize() + + if group: + self.pynccl.comm = self.comm + + + def new_group( + self, + ranks + ): + assert self.initialized, "not initialized" + + # ranks is None or [] + if not ranks: + ranks = list(range(self.world_size)) + + newcomm = self.pynccl.create_newcomm(ranks) + value = 0 + if newcomm: + value = newcomm.value + self.sub_groups[value] = ranks + return value diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index e5cd655..570734a 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -175,6 +175,7 @@ def __init__( auto_pg: bool = True, gpu_count: int | None = None, mem_fraction: float | None = None, + device_type: str | None = None, ): """ Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set. 
@@ -195,6 +196,12 @@ def __init__( self._local_rdma_devices: dict[str, set[int]] = defaultdict(set) self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set) self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9")) + global dist + if device_type is not None: + import checkpoint_engine.distributed as dist + self._device_type = device_type + else: + self._device_type = "torch" assert self._rank is not None and self._rank >= 0, self._rank assert self._world_size and self._world_size > 0, self._world_size @@ -490,20 +497,30 @@ def init_process_group( """ master_addr = master_addr or os.getenv("MASTER_ADDR") assert master_addr, "master_addr is required" - store = dist.TCPStore( - master_addr, - _get_master_port(master_port), - self._world_size, - timeout=timeout, - is_master=self._rank == 0, - ) - dist.init_process_group( - backend=self.device_manager.backend, - world_size=self._world_size, - rank=self._rank, - timeout=timeout, - store=store, - ) + if self._device_type == "torch": + store = torch.distributed.TCPStore( + master_addr, + _get_master_port(master_port), + self._world_size, + timeout=timeout, + is_master=self._rank == 0, + ) + torch.distributed.init_process_group( + backend=self.device_manager.backend, + world_size=self._world_size, + rank=self._rank, + timeout=timeout, + store=store, + ) + else: + dist.init_process_group( + host=master_addr, + port=_get_master_port(master_port), + rank=self._rank, + world_size=self._world_size, + device_type=self._device_type, + timeout=timeout, + ) logger.info(f"[rank{self._rank}] init process group successfully.") def store_based_barrier( @@ -519,7 +536,7 @@ def store_based_barrier( Args: store: The TCPStore instance to use for synchronization. """ - dist.distributed_c10d._store_based_barrier( + torch.distributed.distributed_c10d._store_based_barrier( rank=self._rank, store=store, group_name="parameter_server_barrier", @@ -563,21 +580,10 @@ def update( self.init_process_group( timeout=timeout, master_addr=master_addr, master_port=master_port ) - manager_store = dist.distributed_c10d._get_default_store() - else: - # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1 - # If master_port is provided, use master_port+1 for barrier store - manager_store = dist.TCPStore( - master_addr, - _get_master_port(master_port) + 1, - self._world_size, - timeout=timeout, - is_master=self._rank == 0, - ) # if ranks is None or [], it will use fully broadcast to update to all ranks ranks_group = dist.new_group(ranks) if ranks else None self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks) - self.store_based_barrier(manager_store) + dist.barrier() except Exception as e: logger.exception( f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}" @@ -623,7 +629,7 @@ def _detect_bucket_size( dtype=torch.int64, device=self.device_manager.device_type, ) - dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group) + dist.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN, group=ranks_group) tensor = tensor.cpu() free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item() max_tensor_bytes = 0 @@ -834,7 +840,7 @@ def _update_per_bucket( f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}" ) ret_code.fill_(1) - dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group) + dist.all_reduce(ret_code, op=torch.distributed.ReduceOp.SUM, 
group=ranks_group) self.device_manager.device_module.synchronize() if ret_code.item() != 0: # quit early if any rank failed diff --git a/examples/update.py b/examples/update.py index 51cb189..c24ad0d 100644 --- a/examples/update.py +++ b/examples/update.py @@ -14,7 +14,8 @@ from loguru import logger from safetensors import safe_open -from checkpoint_engine.ps import ParameterServer, request_inference_to_update +from checkpoint_engine.ps import ParameterServer +from checkpoint_engine.api import request_inference_to_update @contextmanager @@ -158,11 +159,16 @@ def join( parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0") parser.add_argument("--update-method", type=str, default="broadcast") parser.add_argument("--uds", type=str, default=None) + parser.add_argument("--device_type", type=str, default=None) args = parser.parse_args() rank = int(os.getenv("RANK")) world_size = int(os.getenv("WORLD_SIZE")) + + if args.device_type is not None: + import checkpoint_engine.distributed as dist + req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds) - ps = ParameterServer(auto_pg=True) + ps = ParameterServer(auto_pg=True, device_type=args.device_type) if args.load_metas_file: join( ps,
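A minimal end-to-end sketch of the new checkpoint_engine.distributed front-end introduced above (not part of the diff; the master address, port, and two-rank layout are hypothetical, and it assumes each process is launched with RANK and WORLD_SIZE set):

    import os
    from datetime import timedelta

    import torch
    import checkpoint_engine.distributed as dist

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # "cuda" selects the NCCL backend, "npu" the HCCL backend
    dist.init_process_group(
        host="10.0.0.1",          # hypothetical master address
        port=29500,               # hypothetical port
        rank=rank,
        world_size=world_size,
        device_type="cuda",
        timeout=timedelta(seconds=300),
    )

    t = torch.ones(1, device=torch.device("cuda", rank))
    dist.all_reduce(t)            # in-place, defaults to ReduceOp.SUM over the full world
    dist.barrier()

    group = dist.new_group([0, 1])         # opaque communicator handle for subgroup collectives
    dist.broadcast(t, src=0, group=group)  # src is the global rank of the broadcast root inside the group
    dist.destroy_process_group(group)      # tear down the sub-communicator first
    dist.destroy_process_group()           # then the world communicator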