From 0d1e7c9ec6929a1644e991fb3bacf215906e500f Mon Sep 17 00:00:00 2001 From: kip-cxj Date: Tue, 16 Dec 2025 11:42:53 +0800 Subject: [PATCH 01/14] 1. add collective communication for npu 2. cache uuid in inference engine --- checkpoint_engine/collective.py | 541 ++++++++++++++++++++++++++++++++ 1 file changed, 541 insertions(+) create mode 100644 checkpoint_engine/collective.py diff --git a/checkpoint_engine/collective.py b/checkpoint_engine/collective.py new file mode 100644 index 0000000..dff523c --- /dev/null +++ b/checkpoint_engine/collective.py @@ -0,0 +1,541 @@ +import base64 +import ctypes +import datetime +import io +import logging +import os +import pickle +from enum import Enum +from typing import Any, List + +import torch +import torch_npu + + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +class ReduceOp(Enum): + SUM = 0 + PRODUCT = 1 + MIN = 2 + MAX = 3 + + +logger = logging.getLogger(__name__) +libhccl = None +try: + libhccl = ctypes.CDLL("libhccl.so") +except OSError: + raise ImportError + + +class HcclRootInfo(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 4108)] + + +buffer_type = ctypes.c_void_p +npuStream_t = ctypes.c_void_p +hcclComm_t = ctypes.c_void_p + + +class HcclDataTypeEnum: + HCCL_DATA_TYPE_INT8 = 0 + HCCL_DATA_TYPE_INT16 = 1 + HCCL_DATA_TYPE_INT32 = 2 + HCCL_DATA_TYPE_FP16 = 3 + HCCL_DATA_TYPE_FP32 = 4 + HCCL_DATA_TYPE_INT64 = 5 + HCCL_DATA_TYPE_UINT8 = 7 + HCCL_DATA_TYPE_FP64 = 10 + HCCL_DATA_TYPE_BFP16 = 11 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + _DTYPE_MAP = { + torch.int8: cls.HCCL_DATA_TYPE_INT8, + torch.int16: cls.HCCL_DATA_TYPE_INT16, + torch.int32: cls.HCCL_DATA_TYPE_INT32, + torch.float16: cls.HCCL_DATA_TYPE_FP16, + torch.float32: cls.HCCL_DATA_TYPE_FP32, + torch.int64: cls.HCCL_DATA_TYPE_INT64, + torch.uint8: cls.HCCL_DATA_TYPE_UINT8, + torch.float64: cls.HCCL_DATA_TYPE_FP64, + torch.bfloat16: cls.HCCL_DATA_TYPE_BFP16, + } + hccl_dtype = _DTYPE_MAP.get(dtype) + if hccl_dtype is None: + raise ValueError(f"Unsupported dtype: {dtype}") + return hccl_dtype + + +class HcclRedOpTypeEnum: + HCCL_REDUCE_SUM = 0 + HCCL_REDUCE_PROD = 1 + HCCL_REDUCE_MAX = 2 + HCCL_REDUCE_MIN = 3 + + @classmethod + def from_base(cls, op: ReduceOp) -> int: + _OP_MAP = { + ReduceOp.SUM: cls.HCCL_REDUCE_SUM, + ReduceOp.PRODUCT: cls.HCCL_REDUCE_PROD, + ReduceOp.MAX: cls.HCCL_REDUCE_MAX, + ReduceOp.MIN: cls.HCCL_REDUCE_MIN, + } + hccl_op = _OP_MAP.get(op) + if hccl_op is None: + raise ValueError(f"Unsupported op: {op}") + return hccl_op + + +_name_map = {} + + +def is_group_exist(group_name: str = "default_group") -> bool: + return group_name in _name_map + + +def create_group( + group_size: int, + rank: int, + device_index: int, + group_name: str = "default_group", + master_addr: str | None = None, + master_port: int | None = None, + store: torch.distributed.TCPStore | None = None, +): + if group_name in _name_map: + return _name_map[group_name] + + g = HCCLGroup(group_size, rank, group_name, device_index, master_addr, master_port, store) + _name_map[group_name] = g + return g + + +def destroy_group(group_name: str = "default_group"): + assert isinstance(group_name, str) + if group_name not in _name_map: + return + + g = _name_map[group_name] + g.destroy() + del _name_map[group_name] + + +def get_handle_by_name(group_name: str): + assert group_name in _name_map, f"{group_name} not in _name_map" + return _name_map[group_name] + + +def get_default_handle(): + return get_handle_by_name("default_group") + + +def 
get_default_store(): + return get_handle_by_name("default_group").get_tcp_store() + + +class HcclCommConfig(ctypes.Structure): + _fields_ = [ + ("size", ctypes.c_size_t), + ("magic_word", ctypes.c_uint32), + ("version", ctypes.c_uint32), + ("reserved", ctypes.c_uint64), + ("hccl_buffer_size", ctypes.c_uint32), + ("hccl_deterministic", ctypes.c_uint32), + ("hccl_comm_name", ctypes.c_char * 128), + ("hccl_udi", ctypes.c_char * 128), + ("hccl_op_expansion_mode", ctypes.c_uint32), + ("hccl_rdma_traffic_class", ctypes.c_uint32), + ("hccl_rdma_service_level", ctypes.c_uint32), + ("hcll_world_rank_id", ctypes.c_uint32), + ("hccl_job_id", ctypes.c_uint64), + ("comm_engine", ctypes.c_int32), + ("thread_num", ctypes.c_uint32), + ("notify_num_per_thread", ctypes.c_uint32), + ("acl_graph_zero_copy_enable", ctypes.c_uint8), + ] + + +class HCCLGroup: + def __init__( + self, + group_size: int, + rank: int, + group_name: str, + device_index: int, + master_addr: str | None = None, + master_port: int | None = None, + store: torch.distributed.TCPStore | None = None, + ): + """Init an HCCL collective group.""" + + self.group_size = group_size + self.rank = rank + self.group_name = group_name + self.libhccl = libhccl + self.device = torch.device("npu", device_index) + self.store = store + torch.npu.set_device(self.device) + + self.rank_table_file = os.environ.get("RANK_TABLE_FILE", None) + + master_addr = master_addr or os.environ["MASTER_ADDR"] + master_port = master_port or int(os.environ["MASTER_PORT"]) + 100 + if self.store is None: + self.store = torch.distributed.TCPStore( + master_addr, + master_port, + group_size, + is_master=rank == 0, + timeout=datetime.timedelta(seconds=180), + ) + if rank == 0: + root_info = self._generate_hccl_root_info() + root_info_b64 = base64.b64encode(bytes(root_info)).decode("utf-8") + self.store.set(group_name, root_info_b64) + else: + root_info_b64 = self.store.get(group_name) + root_info_bytes = base64.b64decode(root_info_b64) + root_info = HcclRootInfo.from_buffer_copy(bytearray(root_info_bytes)) + + self.comm = self._create_hccl_comm(root_info) + self.stream = torch.npu.Stream() + self.subcomm_id = 1 + self.subcomms = {} + self.initialized = True + + def create_subcomm(self, ranks: list[int] | None) -> int: + assert self.initialized, "Not initialied, maybe destroyed" + + if ranks and self.rank not in ranks: + return 0 + + comm_config = HcclCommConfig( + size=312, + magic_word=0xF0F0F0F0, + version=6, + reserved=0, + hccl_buffer_size=0xFFFFFFFF, + hccl_deterministic=0xFFFFFFFF, + hccl_comm_name=b"\0", + hccl_udi=b"\0", + hccl_op_expansize_mode=0, + hccl_rdma_traffic_class=0xFFFFFFFF, + hccl_rdma_service_level=0xFFFFFFFF, + hccl_world_rank_id=0, + hccl_job_id=0, + comm_engine=-1, + thread_num=0xFFFFFFFF, + notify_num_per_thread=0xFFFFFFFF, + acl_graph_zero_copy_enable=0, + ) + subcomm = hcclComm_t() + subcomm_id = self.subcomm_id + + if ranks: + uint32_array = ctypes.c_uint32 * len(ranks) + c_rank_ids = uint32_array(*ranks) + subcomm_rank = ranks.index(self.rank) + else: + uint32_array = ctypes.c_uint32 * self.group_size + c_rank_ids = uint32_array(*list(range(self.group_size))) + subcomm_rank = self.rank + + ranks_size = len(ranks) if ranks else self.group_size + exec_result = self.libhccl.HcclCreateSubCommConfig( + ctypes.byref(self.comm), + ranks_size, + c_rank_ids, + subcomm_id, + subcomm_rank, + ctypes.byref(comm_config), + ctypes.byref(subcomm), + ) + assert exec_result == 0, ( + f"Failed to execute 'HcclCreateSubCommConfig'. 
Error code: {exec_result}" + ) + self.subcomms[subcomm_id] = subcomm + self.subcomm_id += 1 + return subcomm_id + + def destroy(self, subcomm_id=None): + if subcomm_id: + assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" + exec_result = self.libhccl.HcclCommDestroy(self.subcomms[subcomm_id]) + assert exec_result == 0, ( + f"Failed to execute 'HcclCommDestroy'. Error code: {exec_result}" + ) + del self.subcomms[subcomm_id] + return + + for _, subcomm in self.subcomms.items(): + exec_result = self.libhccl.HcclCommDestroy(subcomm) + assert exec_result == 0, ( + f"Failed to execute 'HcclCommDestroy'. Error code: {exec_result}" + ) + + exec_result = self.libhccl.HcclCommDestroy(self.comm) + assert exec_result == 0, f"Failed to execute 'HcclCommDestroy'. Error code: {exec_result}" + if self.rank == 0: + self.store.delete_key(self.group_name) + + self.store = None + self.comm = None + self.stream = None + self.subcomm_id = 1 + self.subcomms = {} + self.initialized = False + + def broadcast(self, tensor: torch.Tensor, src: int = 0, subcomm_id=None): + """Broadcast tensors to all other npus following options. + + Args: + tensor: tensor to be broadcast or received. + src: source rank on group. + + Returns: + None + """ + + assert self.initialized, "Not initialied, maybe destroyed" + + if subcomm_id: + assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" + comm = self.subcomms[subcomm_id] + else: + comm = self.comm + + with torch.npu.device(self.device): + exec_result = self.libhccl.HcclBroadcast( + buffer_type(tensor.data_ptr()), + tensor.numel(), + HcclDataTypeEnum.from_torch(tensor.dtype), + src, + comm, + npuStream_t(self.stream.npu_stream), + ) + self.stream.synchronize() + + assert exec_result == 0, f"Failed to execute 'HcclBroadcast'. Error code: {exec_result}." + + def all_gather(self, tensor_list: list[torch.Tensor], tensor: torch.Tensor, subcomm_id=None): + """Allgather tensors across npus into a list of tensors. + + Args: + tensor_list (List[Tensor]): allgathered tensors. + tensor (torch.Tensor): Tensor to be gathered from current process. + + Returns: + None + """ + assert self.initialized, "Not initialied, maybe destroyed" + + if subcomm_id: + assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" + comm = self.subcomms[subcomm_id] + else: + comm = self.comm + + output_flattened = _flatten_for_scatter_gather(tensor_list) + + with torch.npu.device(self.device): + exec_result = self.libhccl.HcclAllGather( + buffer_type(tensor.data_ptr()), + buffer_type(output_flattened.data_ptr()), + tensor.numel(), + HcclDataTypeEnum.from_torch(tensor.dtype), + comm, + npuStream_t(self.stream.npu_stream), + ) + self.stream.synchronize() + assert exec_result == 0, f"Failed to execute 'HcclAllGather'. Error code: {exec_result}." + + for i, x in enumerate(tensor_list): + x.copy_(output_flattened[i]) + + def all_gather_object(self, object_list: list[Any], object: Any, subcomm_id=None): + """Allgather python objects across npus into a list of objects. + + Args: + tensor_list (List[Any]): allgathered python objects. + tensor (Any): python object to be gathered from current process. 
+ + Returns: + None + """ + assert self.initialized, "Not initialied, maybe destroyed" + + input_tensor, local_size = self._object_to_tensor(object, self.device) + object_sizes_tensor = torch.zeros(self.group_size, dtype=torch.long, device=self.device) + object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(self.group_size)] + self.all_gather(object_size_list, local_size, subcomm_id) + max_object_size = int(max(object_size_list).item()) + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * self.group_size, dtype=torch.uint8, device=self.device + ) + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(self.group_size) + ] + self.all_gather(output_tensors, input_tensor, subcomm_id) + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_list[i] = self._tensor_to_object(tensor, tensor_size) + + def all_reduce(self, tensor, op=ReduceOp.SUM, subcomm_id=None): + """AllReduce tensor across the collective group following options. + + Args: + tensor: Input and output of the collective. Each tensor must reside on one NPU of the current process. + reduce_op: reduce options. + + Returns: + None + """ + assert self.initialized, "Not initialied, maybe destroyed" + + if subcomm_id: + assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" + comm = self.subcomms[subcomm_id] + else: + comm = self.comm + + with torch.npu.device(self.device): + exec_result = self.libhccl.HcclAllReduce( + buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), + tensor.numel(), + HcclDataTypeEnum.from_torch(tensor.dtype), + HcclRedOpTypeEnum.from_base(op), + comm, + npuStream_t(self.stream.npu_stream), + ) + self.stream.synchronize() + assert exec_result == 0, f"Failed to execute 'HcclAllReduce'. Error code: {exec_result}." + + def barrier(self, subcomm_id=None): + """Blocks until all processes reach this barrier. + + Returns: + None + """ + assert self.initialized, "Not initialied, maybe destroyed" + + tensor = torch.empty(1, dtype=torch.int8, device=self.device) + self.all_reduce(tensor, subcomm_id=subcomm_id) + + def get_tcp_store(self): + return self.store + + def _generate_hccl_root_info(self, dev=0): + root_info = HcclRootInfo() + + with torch.npu.device(f"npu:{dev}"): + exec_result = self.libhccl.HcclGetRootInfo(ctypes.byref(root_info)) + assert exec_result == 0, f"Failed to execute 'HcclGetRootInfo'. Error code: {exec_result}." + + return root_info + + def _create_hccl_comm(self, root_info): + comm = hcclComm_t() + + with torch.npu.device(self.device): + if self.rank_table_file is not None: + exec_result = self.libhccl.HcclCommInitClusterInfo( + self.rank_table_file.encode("utf-8"), + self.rank, + ctypes.byref(comm), + ) + else: + exec_result = self.libhccl.HcclCommInitRootInfo( + self.group_size, + ctypes.byref(root_info), + self.rank, + ctypes.byref(comm), + ) + assert exec_result == 0, ( + f"Failed to execute 'HcclCommInitRootInfo'. 
Error code: {exec_result}" + ) + + return comm + + def _object_to_tensor(self, obj, device): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + byte_tensor = torch.ByteTensor(byte_storage).to(device) + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + def _tensor_to_object(self, tensor, tensor_size): + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + """Flatten the tensor for gather/scatter operations. + + Args: + tensor_list: the list of tensors to be scattered/gathered. + copy: whether to copy the tensors in tensor_list into the buffer. + + Returns: + The flattened tensor buffer. + """ + if not tensor_list: + raise RuntimeError("Received an empty list.") + t: torch.Tensor = tensor_list[0] + buffer_shape = [len(tensor_list)] + list(t.shape) + + buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) + if copy: + for i, tensor in enumerate(tensor_list): + buffer[i].copy_(tensor) + return buffer + + +def _check_inputs_compatibility_for_scatter_gather( + tensors: List[torch.Tensor], tensor_lists: List[List[torch.Tensor]] +) -> None: + """Check the compatibility between tensor input and tensor list input.""" + if not tensors or not isinstance(tensors, list): + raise RuntimeError("The first argument 'tensors' expects a list of tensors.") + if not tensor_lists or not isinstance(tensor_lists, list): + raise RuntimeError("The second argument 'tensor_lists' expects a list of tensor list.") + dtype = tensors[0].dtype + shape = list(tensors[0].shape) + for i, tensor_list in enumerate(tensor_lists): + # check all tensor in `tensors` match. + dt = tensors[i].dtype + if dt != dtype: + raise RuntimeError( + "All tensor operands to scatter/gather must " + f"have the same dtype. Got '{dt}' and '{dtype}'." + ) + s = list(tensors[i].shape) + if s != shape: + raise RuntimeError( + "All tensor operands to scatter/gather must " + f"have the same shape. Got '{s}' and '{shape}'." + ) + # check all tensors in `tensor_lists` match. + for t in tensor_lists[i]: + # check dtype + dtl = t.dtype + if dtl != dtype: + raise RuntimeError( + "All tensor operands to scatter/gather must " + f"have the same dtype. Got '{dtl}' and '{dtype}'." + ) + sl = list(t.shape) + if sl != shape: + raise RuntimeError( + "All tensor operands to scatter/gather must " + f"have the same shape. Got '{sl}' and '{shape}'." 
+ ) From 4aa4097338a62cbe055d0a39e34078cdfdc7b422 Mon Sep 17 00:00:00 2001 From: kip-cxj Date: Tue, 30 Dec 2025 17:51:26 +0800 Subject: [PATCH 02/14] add statelesscommgroup --- checkpoint_engine/collective.py | 541 ------------------------------- checkpoint_engine/distributed.py | 490 ++++++++++++++++++++++++++++ checkpoint_engine/ps.py | 52 ++- checkpoint_engine/worker.py | 2 +- 4 files changed, 515 insertions(+), 570 deletions(-) delete mode 100644 checkpoint_engine/collective.py create mode 100644 checkpoint_engine/distributed.py diff --git a/checkpoint_engine/collective.py b/checkpoint_engine/collective.py deleted file mode 100644 index dff523c..0000000 --- a/checkpoint_engine/collective.py +++ /dev/null @@ -1,541 +0,0 @@ -import base64 -import ctypes -import datetime -import io -import logging -import os -import pickle -from enum import Enum -from typing import Any, List - -import torch -import torch_npu - - -_pickler = pickle.Pickler -_unpickler = pickle.Unpickler - - -class ReduceOp(Enum): - SUM = 0 - PRODUCT = 1 - MIN = 2 - MAX = 3 - - -logger = logging.getLogger(__name__) -libhccl = None -try: - libhccl = ctypes.CDLL("libhccl.so") -except OSError: - raise ImportError - - -class HcclRootInfo(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 4108)] - - -buffer_type = ctypes.c_void_p -npuStream_t = ctypes.c_void_p -hcclComm_t = ctypes.c_void_p - - -class HcclDataTypeEnum: - HCCL_DATA_TYPE_INT8 = 0 - HCCL_DATA_TYPE_INT16 = 1 - HCCL_DATA_TYPE_INT32 = 2 - HCCL_DATA_TYPE_FP16 = 3 - HCCL_DATA_TYPE_FP32 = 4 - HCCL_DATA_TYPE_INT64 = 5 - HCCL_DATA_TYPE_UINT8 = 7 - HCCL_DATA_TYPE_FP64 = 10 - HCCL_DATA_TYPE_BFP16 = 11 - - @classmethod - def from_torch(cls, dtype: torch.dtype) -> int: - _DTYPE_MAP = { - torch.int8: cls.HCCL_DATA_TYPE_INT8, - torch.int16: cls.HCCL_DATA_TYPE_INT16, - torch.int32: cls.HCCL_DATA_TYPE_INT32, - torch.float16: cls.HCCL_DATA_TYPE_FP16, - torch.float32: cls.HCCL_DATA_TYPE_FP32, - torch.int64: cls.HCCL_DATA_TYPE_INT64, - torch.uint8: cls.HCCL_DATA_TYPE_UINT8, - torch.float64: cls.HCCL_DATA_TYPE_FP64, - torch.bfloat16: cls.HCCL_DATA_TYPE_BFP16, - } - hccl_dtype = _DTYPE_MAP.get(dtype) - if hccl_dtype is None: - raise ValueError(f"Unsupported dtype: {dtype}") - return hccl_dtype - - -class HcclRedOpTypeEnum: - HCCL_REDUCE_SUM = 0 - HCCL_REDUCE_PROD = 1 - HCCL_REDUCE_MAX = 2 - HCCL_REDUCE_MIN = 3 - - @classmethod - def from_base(cls, op: ReduceOp) -> int: - _OP_MAP = { - ReduceOp.SUM: cls.HCCL_REDUCE_SUM, - ReduceOp.PRODUCT: cls.HCCL_REDUCE_PROD, - ReduceOp.MAX: cls.HCCL_REDUCE_MAX, - ReduceOp.MIN: cls.HCCL_REDUCE_MIN, - } - hccl_op = _OP_MAP.get(op) - if hccl_op is None: - raise ValueError(f"Unsupported op: {op}") - return hccl_op - - -_name_map = {} - - -def is_group_exist(group_name: str = "default_group") -> bool: - return group_name in _name_map - - -def create_group( - group_size: int, - rank: int, - device_index: int, - group_name: str = "default_group", - master_addr: str | None = None, - master_port: int | None = None, - store: torch.distributed.TCPStore | None = None, -): - if group_name in _name_map: - return _name_map[group_name] - - g = HCCLGroup(group_size, rank, group_name, device_index, master_addr, master_port, store) - _name_map[group_name] = g - return g - - -def destroy_group(group_name: str = "default_group"): - assert isinstance(group_name, str) - if group_name not in _name_map: - return - - g = _name_map[group_name] - g.destroy() - del _name_map[group_name] - - -def get_handle_by_name(group_name: str): - assert group_name in _name_map, 
f"{group_name} not in _name_map" - return _name_map[group_name] - - -def get_default_handle(): - return get_handle_by_name("default_group") - - -def get_default_store(): - return get_handle_by_name("default_group").get_tcp_store() - - -class HcclCommConfig(ctypes.Structure): - _fields_ = [ - ("size", ctypes.c_size_t), - ("magic_word", ctypes.c_uint32), - ("version", ctypes.c_uint32), - ("reserved", ctypes.c_uint64), - ("hccl_buffer_size", ctypes.c_uint32), - ("hccl_deterministic", ctypes.c_uint32), - ("hccl_comm_name", ctypes.c_char * 128), - ("hccl_udi", ctypes.c_char * 128), - ("hccl_op_expansion_mode", ctypes.c_uint32), - ("hccl_rdma_traffic_class", ctypes.c_uint32), - ("hccl_rdma_service_level", ctypes.c_uint32), - ("hcll_world_rank_id", ctypes.c_uint32), - ("hccl_job_id", ctypes.c_uint64), - ("comm_engine", ctypes.c_int32), - ("thread_num", ctypes.c_uint32), - ("notify_num_per_thread", ctypes.c_uint32), - ("acl_graph_zero_copy_enable", ctypes.c_uint8), - ] - - -class HCCLGroup: - def __init__( - self, - group_size: int, - rank: int, - group_name: str, - device_index: int, - master_addr: str | None = None, - master_port: int | None = None, - store: torch.distributed.TCPStore | None = None, - ): - """Init an HCCL collective group.""" - - self.group_size = group_size - self.rank = rank - self.group_name = group_name - self.libhccl = libhccl - self.device = torch.device("npu", device_index) - self.store = store - torch.npu.set_device(self.device) - - self.rank_table_file = os.environ.get("RANK_TABLE_FILE", None) - - master_addr = master_addr or os.environ["MASTER_ADDR"] - master_port = master_port or int(os.environ["MASTER_PORT"]) + 100 - if self.store is None: - self.store = torch.distributed.TCPStore( - master_addr, - master_port, - group_size, - is_master=rank == 0, - timeout=datetime.timedelta(seconds=180), - ) - if rank == 0: - root_info = self._generate_hccl_root_info() - root_info_b64 = base64.b64encode(bytes(root_info)).decode("utf-8") - self.store.set(group_name, root_info_b64) - else: - root_info_b64 = self.store.get(group_name) - root_info_bytes = base64.b64decode(root_info_b64) - root_info = HcclRootInfo.from_buffer_copy(bytearray(root_info_bytes)) - - self.comm = self._create_hccl_comm(root_info) - self.stream = torch.npu.Stream() - self.subcomm_id = 1 - self.subcomms = {} - self.initialized = True - - def create_subcomm(self, ranks: list[int] | None) -> int: - assert self.initialized, "Not initialied, maybe destroyed" - - if ranks and self.rank not in ranks: - return 0 - - comm_config = HcclCommConfig( - size=312, - magic_word=0xF0F0F0F0, - version=6, - reserved=0, - hccl_buffer_size=0xFFFFFFFF, - hccl_deterministic=0xFFFFFFFF, - hccl_comm_name=b"\0", - hccl_udi=b"\0", - hccl_op_expansize_mode=0, - hccl_rdma_traffic_class=0xFFFFFFFF, - hccl_rdma_service_level=0xFFFFFFFF, - hccl_world_rank_id=0, - hccl_job_id=0, - comm_engine=-1, - thread_num=0xFFFFFFFF, - notify_num_per_thread=0xFFFFFFFF, - acl_graph_zero_copy_enable=0, - ) - subcomm = hcclComm_t() - subcomm_id = self.subcomm_id - - if ranks: - uint32_array = ctypes.c_uint32 * len(ranks) - c_rank_ids = uint32_array(*ranks) - subcomm_rank = ranks.index(self.rank) - else: - uint32_array = ctypes.c_uint32 * self.group_size - c_rank_ids = uint32_array(*list(range(self.group_size))) - subcomm_rank = self.rank - - ranks_size = len(ranks) if ranks else self.group_size - exec_result = self.libhccl.HcclCreateSubCommConfig( - ctypes.byref(self.comm), - ranks_size, - c_rank_ids, - subcomm_id, - subcomm_rank, - 
ctypes.byref(comm_config), - ctypes.byref(subcomm), - ) - assert exec_result == 0, ( - f"Failed to execute 'HcclCreateSubCommConfig'. Error code: {exec_result}" - ) - self.subcomms[subcomm_id] = subcomm - self.subcomm_id += 1 - return subcomm_id - - def destroy(self, subcomm_id=None): - if subcomm_id: - assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" - exec_result = self.libhccl.HcclCommDestroy(self.subcomms[subcomm_id]) - assert exec_result == 0, ( - f"Failed to execute 'HcclCommDestroy'. Error code: {exec_result}" - ) - del self.subcomms[subcomm_id] - return - - for _, subcomm in self.subcomms.items(): - exec_result = self.libhccl.HcclCommDestroy(subcomm) - assert exec_result == 0, ( - f"Failed to execute 'HcclCommDestroy'. Error code: {exec_result}" - ) - - exec_result = self.libhccl.HcclCommDestroy(self.comm) - assert exec_result == 0, f"Failed to execute 'HcclCommDestroy'. Error code: {exec_result}" - if self.rank == 0: - self.store.delete_key(self.group_name) - - self.store = None - self.comm = None - self.stream = None - self.subcomm_id = 1 - self.subcomms = {} - self.initialized = False - - def broadcast(self, tensor: torch.Tensor, src: int = 0, subcomm_id=None): - """Broadcast tensors to all other npus following options. - - Args: - tensor: tensor to be broadcast or received. - src: source rank on group. - - Returns: - None - """ - - assert self.initialized, "Not initialied, maybe destroyed" - - if subcomm_id: - assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" - comm = self.subcomms[subcomm_id] - else: - comm = self.comm - - with torch.npu.device(self.device): - exec_result = self.libhccl.HcclBroadcast( - buffer_type(tensor.data_ptr()), - tensor.numel(), - HcclDataTypeEnum.from_torch(tensor.dtype), - src, - comm, - npuStream_t(self.stream.npu_stream), - ) - self.stream.synchronize() - - assert exec_result == 0, f"Failed to execute 'HcclBroadcast'. Error code: {exec_result}." - - def all_gather(self, tensor_list: list[torch.Tensor], tensor: torch.Tensor, subcomm_id=None): - """Allgather tensors across npus into a list of tensors. - - Args: - tensor_list (List[Tensor]): allgathered tensors. - tensor (torch.Tensor): Tensor to be gathered from current process. - - Returns: - None - """ - assert self.initialized, "Not initialied, maybe destroyed" - - if subcomm_id: - assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" - comm = self.subcomms[subcomm_id] - else: - comm = self.comm - - output_flattened = _flatten_for_scatter_gather(tensor_list) - - with torch.npu.device(self.device): - exec_result = self.libhccl.HcclAllGather( - buffer_type(tensor.data_ptr()), - buffer_type(output_flattened.data_ptr()), - tensor.numel(), - HcclDataTypeEnum.from_torch(tensor.dtype), - comm, - npuStream_t(self.stream.npu_stream), - ) - self.stream.synchronize() - assert exec_result == 0, f"Failed to execute 'HcclAllGather'. Error code: {exec_result}." - - for i, x in enumerate(tensor_list): - x.copy_(output_flattened[i]) - - def all_gather_object(self, object_list: list[Any], object: Any, subcomm_id=None): - """Allgather python objects across npus into a list of objects. - - Args: - tensor_list (List[Any]): allgathered python objects. - tensor (Any): python object to be gathered from current process. 
- - Returns: - None - """ - assert self.initialized, "Not initialied, maybe destroyed" - - input_tensor, local_size = self._object_to_tensor(object, self.device) - object_sizes_tensor = torch.zeros(self.group_size, dtype=torch.long, device=self.device) - object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(self.group_size)] - self.all_gather(object_size_list, local_size, subcomm_id) - max_object_size = int(max(object_size_list).item()) - input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * self.group_size, dtype=torch.uint8, device=self.device - ) - output_tensors = [ - coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] - for i in range(self.group_size) - ] - self.all_gather(output_tensors, input_tensor, subcomm_id) - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - tensor_size = object_size_list[i] - object_list[i] = self._tensor_to_object(tensor, tensor_size) - - def all_reduce(self, tensor, op=ReduceOp.SUM, subcomm_id=None): - """AllReduce tensor across the collective group following options. - - Args: - tensor: Input and output of the collective. Each tensor must reside on one NPU of the current process. - reduce_op: reduce options. - - Returns: - None - """ - assert self.initialized, "Not initialied, maybe destroyed" - - if subcomm_id: - assert subcomm_id in self.subcomms, f"{subcomm_id} not in subcomms" - comm = self.subcomms[subcomm_id] - else: - comm = self.comm - - with torch.npu.device(self.device): - exec_result = self.libhccl.HcclAllReduce( - buffer_type(tensor.data_ptr()), - buffer_type(tensor.data_ptr()), - tensor.numel(), - HcclDataTypeEnum.from_torch(tensor.dtype), - HcclRedOpTypeEnum.from_base(op), - comm, - npuStream_t(self.stream.npu_stream), - ) - self.stream.synchronize() - assert exec_result == 0, f"Failed to execute 'HcclAllReduce'. Error code: {exec_result}." - - def barrier(self, subcomm_id=None): - """Blocks until all processes reach this barrier. - - Returns: - None - """ - assert self.initialized, "Not initialied, maybe destroyed" - - tensor = torch.empty(1, dtype=torch.int8, device=self.device) - self.all_reduce(tensor, subcomm_id=subcomm_id) - - def get_tcp_store(self): - return self.store - - def _generate_hccl_root_info(self, dev=0): - root_info = HcclRootInfo() - - with torch.npu.device(f"npu:{dev}"): - exec_result = self.libhccl.HcclGetRootInfo(ctypes.byref(root_info)) - assert exec_result == 0, f"Failed to execute 'HcclGetRootInfo'. Error code: {exec_result}." - - return root_info - - def _create_hccl_comm(self, root_info): - comm = hcclComm_t() - - with torch.npu.device(self.device): - if self.rank_table_file is not None: - exec_result = self.libhccl.HcclCommInitClusterInfo( - self.rank_table_file.encode("utf-8"), - self.rank, - ctypes.byref(comm), - ) - else: - exec_result = self.libhccl.HcclCommInitRootInfo( - self.group_size, - ctypes.byref(root_info), - self.rank, - ctypes.byref(comm), - ) - assert exec_result == 0, ( - f"Failed to execute 'HcclCommInitRootInfo'. 
Error code: {exec_result}" - ) - - return comm - - def _object_to_tensor(self, obj, device): - f = io.BytesIO() - _pickler(f).dump(obj) - byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] - byte_tensor = torch.ByteTensor(byte_storage).to(device) - local_size = torch.LongTensor([byte_tensor.numel()]).to(device) - return byte_tensor, local_size - - def _tensor_to_object(self, tensor, tensor_size): - tensor = tensor.cpu() - buf = tensor.numpy().tobytes()[:tensor_size] - return _unpickler(io.BytesIO(buf)).load() - - -def _flatten_for_scatter_gather(tensor_list, copy=False): - """Flatten the tensor for gather/scatter operations. - - Args: - tensor_list: the list of tensors to be scattered/gathered. - copy: whether to copy the tensors in tensor_list into the buffer. - - Returns: - The flattened tensor buffer. - """ - if not tensor_list: - raise RuntimeError("Received an empty list.") - t: torch.Tensor = tensor_list[0] - buffer_shape = [len(tensor_list)] + list(t.shape) - - buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) - if copy: - for i, tensor in enumerate(tensor_list): - buffer[i].copy_(tensor) - return buffer - - -def _check_inputs_compatibility_for_scatter_gather( - tensors: List[torch.Tensor], tensor_lists: List[List[torch.Tensor]] -) -> None: - """Check the compatibility between tensor input and tensor list input.""" - if not tensors or not isinstance(tensors, list): - raise RuntimeError("The first argument 'tensors' expects a list of tensors.") - if not tensor_lists or not isinstance(tensor_lists, list): - raise RuntimeError("The second argument 'tensor_lists' expects a list of tensor list.") - dtype = tensors[0].dtype - shape = list(tensors[0].shape) - for i, tensor_list in enumerate(tensor_lists): - # check all tensor in `tensors` match. - dt = tensors[i].dtype - if dt != dtype: - raise RuntimeError( - "All tensor operands to scatter/gather must " - f"have the same dtype. Got '{dt}' and '{dtype}'." - ) - s = list(tensors[i].shape) - if s != shape: - raise RuntimeError( - "All tensor operands to scatter/gather must " - f"have the same shape. Got '{s}' and '{shape}'." - ) - # check all tensors in `tensor_lists` match. - for t in tensor_lists[i]: - # check dtype - dtl = t.dtype - if dtl != dtype: - raise RuntimeError( - "All tensor operands to scatter/gather must " - f"have the same dtype. Got '{dtl}' and '{dtype}'." - ) - sl = list(t.shape) - if sl != shape: - raise RuntimeError( - "All tensor operands to scatter/gather must " - f"have the same shape. Got '{sl}' and '{shape}'." 
- ) diff --git a/checkpoint_engine/distributed.py b/checkpoint_engine/distributed.py new file mode 100644 index 0000000..0f1952e --- /dev/null +++ b/checkpoint_engine/distributed.py @@ -0,0 +1,490 @@ +import ctypes +import io +import logging +import os +import pickle +from enum import Enum +from typing import Any, List, Optional + +import torch +import torch.distributed +from torch.distributed import ReduceOp +from vllm.distributed.utils import StatelessProcessGroup +from vllm.utils import current_stream + + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +def _object_to_tensor(obj, device): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) + byte_tensor = torch.ByteTensor(byte_storage).to(device) + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size): + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + if not tensor_list: + raise RuntimeError("Received an empty list.") + t = tensor_list[0] + buffer_shape = [len(tensor_list)] + list(t.shape) + + buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) + if copy: + for i, tensor in enumerate(tensor_list): + buffer[i].copy_(tensor) + return buffer + + +def _common_all_gather_object(comm, device, world_size, object_list, object): + input_tensor, local_size = _object_to_tensor(object, device) + object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device) + comm.all_gather(object_sizes_tensor, local_size) + object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)] + max_object_size = int(max(object_size_list).item()) + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * world_size, dtype=torch.uint8, device=device + ) + + comm.all_gather(coalesced_output_tensor, input_tensor) + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(world_size) + ] + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) + + +class DistributedNccl: + def __init__(self): + self.pg = None + self.pynccl = None + self.sub_groups = {} + + def init_process_group( + self, + host: str, + port: int, + rank: int, + world_size: int, + timeout: int = 300, + **kwargs, + ): + self._host = host + self._port = port + self._rank = rank + self._world_size = world_size + self._device = torch.device("cuda", rank) + + self.pg = StatelessProcessGroup.create(host, port, rank, world_size, store_timeout=timeout) + + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + + self.pynccl = PyNcclCommunicator(group=self.pg, device=self._device) + + def destroy_process_group(self, group=None): + # PyNcclCommunicator does not provide destroy method + if group in self.sub_groups: + del self.sub_groups[group] + return + + def is_initialized(self) -> bool: + return self.pynccl is not None + + def all_gather_object(self, object_list: list[Any], obj: Any, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + pynccl = group.pynccl + else: + pynccl = self.pynccl + + _common_all_gather_object(pynccl, self._device, self._world_size, object_list, object) + current_stream().synchronize() + 
+ def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + pynccl = group.pynccl + else: + pynccl = self.pynccl + + out_tensor = pynccl.all_reduce(in_tensor=tensor, op=op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + def broadcast(self, tensor: torch.Tensor, src=None, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + pynccl = group.pynccl + else: + pynccl = self.pynccl + + pynccl.broadcast(tensor, src) + current_stream().synchronize() + + def barrier(self, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + pynccl = group.pynccl + else: + pynccl = self.pynccl + + data = torch.zeros(1, device=self._rank) + pynccl.all_reduce(data) + current_stream().synchronize() + + def new_group(self, ranks): + # ranks is None or [] + if not ranks: + return self + + host = self._host + port = self._port + rank = self._rank + + if rank not in ranks: + return + + new_rank = ranks.index(rank) + new_world_size = len(ranks) + + new_dist = DistributedNccl() + new_dist.init_process_group( + host, port + 10, new_rank, new_world_size + ) # todo host maybe incorrect + self.sub_groups.append(new_dist) + + return new_dist + + +try: + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator + from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( + Function, + HCCLLibrary, + aclrtStream_t, + buffer_type, + hcclComm_t, + hcclDataType_t, + hcclDataTypeEnum, + hcclRedOp_t, + hcclRedOpTypeEnum, + hcclResult_t, + hcclUniqueId, + ) + from vllm_ascend.utils import current_stream + + orig_exported_functions = HCCLLibrary.exported_functions + extended_functions = [ + # HcclResult HcclAllGather( + # void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, + # HcclComm comm, alcrtStream stream + # ) + Function( + "HcclAllGather", + hcclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_uint64, + hcclDataType_t, + hcclComm_t, + aclrtStream_t, + ], + ), + # HcclResult HcclCreateSubCommConfig( + # HcclComm *comm, uin32_t rankNum, uint32_t *rankIds, uint64_t subCommId, + # uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm + # ) + Function( + "HcclCreateSubCommConfig", + hcclResult_t, + [ + ctypes.POINTER(hcclComm_t), + ctypes.c_uint32, + ctypes.POINTER(ctypes.c_uint32), + ctypes.c_uint64, + ctypes.c_uint32, + ctypes.POINTER(hcclUniqueId), + ctypes.POINTER(hcclComm_t), + ], + ), + ] + + def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream): + self.HCCL_CHECK( + self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream) + ) + + class HcclCommConfig(ctypes.Structure): + _fields_ = [ + ("size", ctypes.c_size_t), + ("magic_word", ctypes.c_uint32), + ("version", ctypes.c_uint32), + ("reserved", ctypes.c_uint64), + ("hccl_buffer_size", ctypes.c_uint32), + ("hccl_deterministic", ctypes.c_uint32), + ("hccl_comm_name", ctypes.c_char * 128), + ("hccl_udi", ctypes.c_char * 128), + ("hccl_op_expansion_mode", ctypes.c_uint32), + ("hccl_rdma_traffic_class", ctypes.c_uint32), + ("hccl_rdma_service_level", ctypes.c_uint32), + ("hcll_world_rank_id", ctypes.c_uint32), + ("hccl_job_id", ctypes.c_uint64), + ("comm_engine", ctypes.c_int32), + ("thread_num", ctypes.c_uint32), + ("notify_num_per_thread", ctypes.c_uint32), + ("acl_graph_zero_copy_enable", ctypes.c_uint8), + ] + + def hccl_create_subcomm_config( + self, comm, ranks_size, c_rank_ids, subcomm_id, 
subcomm_rank, comm_config + ): + subcomm = hcclComm_t() + self.HCCL_CHECK( + self._funcs["HcclCreateSubCommConfig"]( + ctypes.byref(comm), + ranks_size, + c_rank_ids, + subcomm_id, + subcomm_rank, + ctypes.byref(comm_config), + ctypes.byref(subcomm), + ) + ) + return subcomm + + # extend HCCLLibrary + HCCLLibrary.exported_functions = orig_exported_functions + extended_functions + HCCLLibrary.hcclAllGather = hccl_all_gather + HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config + + class PyHcclCommunicatorEx(PyHcclCommunicator): + def __init__(self, group, device): + super().__init__(group, device) + self.subcomms = {} + self.subcomm_id = 1 + + def destroy_comm(self): + self.hccl.hcclCommDestroy(self.comm) + + def all_reduce( + self, + in_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ) -> torch.Tensor: + if self.disabled: + return None + assert in_tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}" + ) + out_tensor = torch.empty_like(in_tensor) + if stream is None: + stream = current_stream() + self.hccl.hcclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + hcclDataTypeEnum.from_torch(in_tensor.dtype), + hcclRedOpTypeEnum.from_torch(op), + self.comm, # todo + aclrtStream_t(stream.npu_stream), + ) + return out_tensor + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return None + assert tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.hccl.hcclBroadcast( + buffer_type(tensor.data_ptr()), + tensor.numel(), + hcclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, # todo + aclrtStream_t(stream.npu_stream), + ) + + def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=None): + if self.disabled: + return + assert in_tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, " + f"but the input tensor in on {in_tensor.device}" + ) + if stream is None: + stream = current_stream() + self.hccl.hcclAllGather( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + hcclDataTypeEnum.from_torch(in_tensor.dtype), + self.comm, # todo + aclrtStream_t(stream.npu_stream), + ) + return out_tensor + + def create_subcomm( + self, + ranks, + ): + comm_config = HcclCommConfig( + size=312, + magic_word=0xF0F0F0F0, + version=6, + reserved=0, + hccl_buffer_size=0xFFFFFFFF, + hccl_deterministic=0xFFFFFFFF, + hccl_comm_name=b"\0", + hccl_udi=b"\0", + hccl_op_expansize_mode=0, + hccl_rdma_traffic_class=0xFFFFFFFF, + hccl_rdma_service_level=0xFFFFFFFF, + hccl_world_rank_id=0, + hccl_job_id=0, + comm_engine=-1, + thread_num=0xFFFFFFFF, + notify_num_per_thread=0xFFFFFFFF, + acl_graph_zero_copy_enable=0, + ) + uint32_array = ctypes.c_uint32 * len(ranks) + c_rank_ids = uint32_array(*ranks) + subcomm_rank = ranks.index(self.rank) + ranks_size = len(ranks) + subcomm_id = self.subcomm_id + + subcomm = self.hccl.hcclCreateSubCommConfig( + self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config + ) + self.subcomms[subcomm_id] = subcomm + self.subcomm_id += 1 + return subcomm + + class DistributedHccl: + def __init__(self): + self.pg = None + self.pyhccl = None + self.sub_groups = {} + + def init_process_group( + self, + 
host: str, + port: int, + rank: int, + world_size: int, + timeout: int = 300, + **kwargs, + ): + self._host = host + self._port = port + self._rank = rank + self._world_size = world_size + self._device = torch.device("npu", rank) + + self.pg = StatelessProcessGroup.create( + host, port, rank, world_size, store_timeout=timeout + ) + self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self._device) + + def destroy_process_group(self, group=None): + if group in self.sub_groups: + group.pyhccl.destroy_comm() + del self.sub_groups[group] + return + + self.pyhccl.destroy_comm() + + self.pyhccl = None + self.pg = None + + def is_initialized(self) -> bool: + return self.pyhccl is not None + + def all_gather_object(self, object_list: list[Any], obj: Any, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + pyhccl = group.pyhccl + else: + pyhccl = self.pyhccl + _common_all_gather_object(pyhccl, self._device, self._world_size, object_list, obj) + current_stream().synchronize() + + def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + pyhccl = group.pyhccl + else: + pyhccl = self.pyhccl + + out_tensor = pyhccl.all_reduce(tensor, op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + def broadcast(self, tensor: torch.Tensor, src=None, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + assert src in self.sub_groups[group], "src rank not in group" + pyhccl = group.pyhccl + # src is rank id in global world + src = self.sub_groups[group].index(src) + else: + pyhccl = self.pyhccl + + pyhccl.broadcast(tensor, src) + current_stream().synchronize() + + def barrier(self, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + pyhccl = group.pyhccl + else: + pyhccl = self.pyhccl + + data = torch.zeros(1, device=self._rank) + pyhccl.all_reduce(data) + current_stream().synchronize() + + def new_group(self, ranks): + # ranks is None or [] + if not ranks: + return self + + host = self._host + port = self._port + rank = self._rank + + if rank not in ranks: + return + + new_rank = ranks.index(rank) + new_world_size = len(ranks) + + new_dist = DistributedHccl() + new_dist.init_process_group( + host, port + 10, new_rank, new_world_size + ) # todo host maybe incorrect + self.sub_groups[new_dist] = ranks + + return new_dist + +except ImportError as e: + pass diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index e5cd655..0fbfd8b 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -24,6 +24,7 @@ from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid from checkpoint_engine.p2p_store import P2PStore from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint +from checkpoint_engine.distributed import DistributedNccl, DistributedHccl if TYPE_CHECKING: @@ -195,6 +196,12 @@ def __init__( self._local_rdma_devices: dict[str, set[int]] = defaultdict(set) self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set) self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9")) + if self.device_manager.backend == "nccl": + self.dist = DistributedNccl + elif self.device_manager.backend == "hccl": + self.dist = DistributedHccl + else: + self.dist = torch.distributed assert self._rank is not None and self._rank >= 0, self._rank assert self._world_size and self._world_size > 0, self._world_size @@ -415,9 +422,9 @@ def gather_metas(self, 
checkpoint_name: str): This function should be called before update and init a new value to `self._current_global_parameter_metas`, which can be exported by using `self.get_metas` function. """ - if self._auto_pg and not dist.is_initialized(): + if self._auto_pg and not self.dist.is_initialized(): self.init_process_group() - assert dist.is_initialized(), "process group is not initialized" + assert self.dist.is_initialized(), "process group is not initialized" metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore try: memory_pool = self._get_memory_pool(checkpoint_name) @@ -438,7 +445,7 @@ def gather_metas(self, checkpoint_name: str): rdma_device=self._rdma_device or "", ) - dist.all_gather_object(metas_lst, metas) + self.dist.all_gather_object(metas_lst, metas) self._current_global_parameter_metas = {} @@ -490,14 +497,14 @@ def init_process_group( """ master_addr = master_addr or os.getenv("MASTER_ADDR") assert master_addr, "master_addr is required" - store = dist.TCPStore( + store = self.dist.TCPStore( master_addr, _get_master_port(master_port), self._world_size, timeout=timeout, is_master=self._rank == 0, ) - dist.init_process_group( + self.dist.init_process_group( backend=self.device_manager.backend, world_size=self._world_size, rank=self._rank, @@ -559,25 +566,14 @@ def update( master_addr = os.getenv("MASTER_ADDR") or master_addr assert master_addr, "master_addr is required" if self._auto_pg: - if not dist.is_initialized(): + if not self.dist.is_initialized(): self.init_process_group( timeout=timeout, master_addr=master_addr, master_port=master_port ) - manager_store = dist.distributed_c10d._get_default_store() - else: - # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1 - # If master_port is provided, use master_port+1 for barrier store - manager_store = dist.TCPStore( - master_addr, - _get_master_port(master_port) + 1, - self._world_size, - timeout=timeout, - is_master=self._rank == 0, - ) # if ranks is None or [], it will use fully broadcast to update to all ranks - ranks_group = dist.new_group(ranks) if ranks else None + ranks_group = self.dist.new_group(ranks) if ranks else None self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks) - self.store_based_barrier(manager_store) + self.dist.barrier() except Exception as e: logger.exception( f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}" @@ -585,9 +581,9 @@ def update( raise finally: if ranks_group: - dist.destroy_process_group(ranks_group) - if self._auto_pg and dist.is_initialized(): - dist.destroy_process_group() + self.dist.destroy_process_group(ranks_group) + if self._auto_pg and self.dist.is_initialized(): + self.dist.destroy_process_group() self.device_manager.device_module.empty_cache() logger.info( f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. 
" @@ -623,7 +619,7 @@ def _detect_bucket_size( dtype=torch.int64, device=self.device_manager.device_type, ) - dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group) + self.dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group) tensor = tensor.cpu() free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item() max_tensor_bytes = 0 @@ -729,7 +725,7 @@ def _update_per_bucket( ranks: list[int] | None = None, ): assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty" - assert dist.is_initialized(), "process group is not initialized" + assert self.dist.is_initialized(), "process group is not initialized" # if both ranks is None or [], it will use fully broadcast to update to all ranks if not ranks: @@ -748,7 +744,7 @@ def _update_per_bucket( if not need_update: return # first execute a barrier to avoid subsequent device oom - dist.barrier(group=ranks_group) + self.dist.barrier(group=ranks_group) bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group) buckets = _gen_h2d_buckets( @@ -826,7 +822,7 @@ def _update_per_bucket( self._copy_to_buffer(checkpoint_name, bucket, buffer_b) else: buffer_b.data.copy_(h2d_buffer[: bucket.size]) - dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group) + self.dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group) resp = socket.recv() if resp != b"": msg = resp.decode("utf-8") @@ -834,7 +830,7 @@ def _update_per_bucket( f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}" ) ret_code.fill_(1) - dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group) + self.dist.all_reduce(ret_code, op=self.dist.ReduceOp.SUM, group=ranks_group) self.device_manager.device_module.synchronize() if ret_code.item() != 0: # quit early if any rank failed @@ -848,7 +844,7 @@ def _update_per_bucket( socket.recv() finally: req_thread.join() - dist.barrier(group=ranks_group) + self.dist.barrier(group=ranks_group) socket.close() if ranks and h2d_buffer is not None: self._p2p_store.unregister_named_tensors([h2d_buffer_name]) diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index c69815c..4cea898 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -168,7 +168,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): update_weights_from_ipc( self._zmq_ctx, - zmq_handles[self._device_uuid], + zmq_handles[device_uuid], device_id=self.device.index, run=self.model_runner.model.load_weights, post_hook=lambda: process_weights_after_loading( From 77f7b5790354a6b6748ec6cea9836205790327ee Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Sun, 4 Jan 2026 17:06:59 +0800 Subject: [PATCH 03/14] fix bugs --- checkpoint_engine/distributed.py | 186 ++++++++++++------------------- checkpoint_engine/ps.py | 8 +- 2 files changed, 77 insertions(+), 117 deletions(-) diff --git a/checkpoint_engine/distributed.py b/checkpoint_engine/distributed.py index 0f1952e..9069488 100644 --- a/checkpoint_engine/distributed.py +++ b/checkpoint_engine/distributed.py @@ -3,6 +3,7 @@ import logging import os import pickle +from datetime import timedelta from enum import Enum from typing import Any, List, Optional @@ -79,7 +80,7 @@ def init_process_group( port: int, rank: int, world_size: int, - timeout: int = 300, + timeout: timedelta = timedelta(seconds=300), **kwargs, ): self._host = host @@ -88,7 +89,9 @@ def init_process_group( self._world_size = world_size self._device = 
torch.device("cuda", rank) - self.pg = StatelessProcessGroup.create(host, port, rank, world_size, store_timeout=timeout) + self.pg = StatelessProcessGroup.create( + host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) + ) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -186,6 +189,27 @@ def new_group(self, ranks): ) from vllm_ascend.utils import current_stream + class HcclCommConfig(ctypes.Structure): + _fields_ = [ + ("size", ctypes.c_size_t), + ("magic_word", ctypes.c_uint32), + ("version", ctypes.c_uint32), + ("reserved", ctypes.c_uint64), + ("hccl_buffer_size", ctypes.c_uint32), + ("hccl_deterministic", ctypes.c_uint32), + ("hccl_comm_name", ctypes.c_char * 128), + ("hccl_udi", ctypes.c_char * 128), + ("hccl_op_expansion_mode", ctypes.c_uint32), + ("hccl_rdma_traffic_class", ctypes.c_uint32), + ("hccl_rdma_service_level", ctypes.c_uint32), + ("hcll_world_rank_id", ctypes.c_uint32), + ("hccl_job_id", ctypes.c_uint64), + ("comm_engine", ctypes.c_int32), + ("thread_num", ctypes.c_uint32), + ("notify_num_per_thread", ctypes.c_uint32), + ("acl_graph_zero_copy_enable", ctypes.c_uint8), + ] + orig_exported_functions = HCCLLibrary.exported_functions extended_functions = [ # HcclResult HcclAllGather( @@ -217,7 +241,7 @@ def new_group(self, ranks): ctypes.POINTER(ctypes.c_uint32), ctypes.c_uint64, ctypes.c_uint32, - ctypes.POINTER(hcclUniqueId), + ctypes.POINTER(HcclCommConfig), ctypes.POINTER(hcclComm_t), ], ), @@ -228,27 +252,6 @@ def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream): self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream) ) - class HcclCommConfig(ctypes.Structure): - _fields_ = [ - ("size", ctypes.c_size_t), - ("magic_word", ctypes.c_uint32), - ("version", ctypes.c_uint32), - ("reserved", ctypes.c_uint64), - ("hccl_buffer_size", ctypes.c_uint32), - ("hccl_deterministic", ctypes.c_uint32), - ("hccl_comm_name", ctypes.c_char * 128), - ("hccl_udi", ctypes.c_char * 128), - ("hccl_op_expansion_mode", ctypes.c_uint32), - ("hccl_rdma_traffic_class", ctypes.c_uint32), - ("hccl_rdma_service_level", ctypes.c_uint32), - ("hcll_world_rank_id", ctypes.c_uint32), - ("hccl_job_id", ctypes.c_uint64), - ("comm_engine", ctypes.c_int32), - ("thread_num", ctypes.c_uint32), - ("notify_num_per_thread", ctypes.c_uint32), - ("acl_graph_zero_copy_enable", ctypes.c_uint8), - ] - def hccl_create_subcomm_config( self, comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config ): @@ -274,55 +277,13 @@ def hccl_create_subcomm_config( class PyHcclCommunicatorEx(PyHcclCommunicator): def __init__(self, group, device): super().__init__(group, device) - self.subcomms = {} self.subcomm_id = 1 - def destroy_comm(self): - self.hccl.hcclCommDestroy(self.comm) - - def all_reduce( - self, - in_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None, - ) -> torch.Tensor: - if self.disabled: - return None - assert in_tensor.device == self.device, ( - f"this hccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}" - ) - out_tensor = torch.empty_like(in_tensor) - if stream is None: - stream = current_stream() - self.hccl.hcclAllReduce( - buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - hcclDataTypeEnum.from_torch(in_tensor.dtype), - hcclRedOpTypeEnum.from_torch(op), - self.comm, # todo - aclrtStream_t(stream.npu_stream), - ) - return out_tensor - - def broadcast(self, tensor: torch.Tensor, src: int, 
stream=None): - if self.disabled: - return None - assert tensor.device == self.device, ( - f"this hccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}" - ) - if stream is None: - stream = current_stream() - self.hccl.hcclBroadcast( - buffer_type(tensor.data_ptr()), - tensor.numel(), - hcclDataTypeEnum.from_torch(tensor.dtype), - src, - self.comm, # todo - aclrtStream_t(stream.npu_stream), - ) + def destroy_comm(self, comm=None): + if comm: + self.hccl.hcclCommDestroy(comm) + else: + self.hccl.hcclCommDestroy(self.comm) def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=None): if self.disabled: @@ -343,10 +304,7 @@ def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=N ) return out_tensor - def create_subcomm( - self, - ranks, - ): + def create_subcomm(self, ranks): comm_config = HcclCommConfig( size=312, magic_word=0xF0F0F0F0, @@ -375,7 +333,6 @@ def create_subcomm( subcomm = self.hccl.hcclCreateSubCommConfig( self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config ) - self.subcomms[subcomm_id] = subcomm self.subcomm_id += 1 return subcomm @@ -391,7 +348,7 @@ def init_process_group( port: int, rank: int, world_size: int, - timeout: int = 300, + timeout: timedelta = timedelta(seconds=300), **kwargs, ): self._host = host @@ -401,13 +358,15 @@ def init_process_group( self._device = torch.device("npu", rank) self.pg = StatelessProcessGroup.create( - host, port, rank, world_size, store_timeout=timeout + host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) ) self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self._device) + self._comm = self.pyhccl.comm def destroy_process_group(self, group=None): if group in self.sub_groups: - group.pyhccl.destroy_comm() + subcomm = ctypes.c_void_p(group) + self.pyhccl.destroy_comm(subcomm) del self.sub_groups[group] return @@ -422,69 +381,70 @@ def is_initialized(self) -> bool: def all_gather_object(self, object_list: list[Any], obj: Any, group=None): if group: assert group in self.sub_groups, "invalid sub_group" - pyhccl = group.pyhccl - else: - pyhccl = self.pyhccl - _common_all_gather_object(pyhccl, self._device, self._world_size, object_list, obj) + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + + _common_all_gather_object(self.pyhccl, self._device, self._world_size, object_list, obj) current_stream().synchronize() + if group: + self.pyhccl.comm = self._comm + def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): if group: assert group in self.sub_groups, "invalid sub_group" - pyhccl = group.pyhccl - else: - pyhccl = self.pyhccl + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm - out_tensor = pyhccl.all_reduce(tensor, op) + out_tensor = self.pyhccl.all_reduce(tensor, op) current_stream().synchronize() tensor.copy_(out_tensor) + if group: + self.pyhccl.comm = self._comm + def broadcast(self, tensor: torch.Tensor, src=None, group=None): if group: assert group in self.sub_groups, "invalid sub_group" assert src in self.sub_groups[group], "src rank not in group" - pyhccl = group.pyhccl - # src is rank id in global world + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + # convert src rank id in default world to subcomm src = self.sub_groups[group].index(src) - else: - pyhccl = self.pyhccl - pyhccl.broadcast(tensor, src) + self.pyhccl.broadcast(tensor, src) current_stream().synchronize() + if group: + self.pyhccl.comm = self._comm + def barrier(self, 
group=None): if group: assert group in self.sub_groups, "invalid sub_group" - pyhccl = group.pyhccl - else: - pyhccl = self.pyhccl + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm data = torch.zeros(1, device=self._rank) - pyhccl.all_reduce(data) + self.pyhccl.all_reduce(data) current_stream().synchronize() + if group: + self.pyhccl.comm = self._comm + def new_group(self, ranks): - # ranks is None or [] + # if ranks is None or [], using the world instead if not ranks: - return self - - host = self._host - port = self._port - rank = self._rank + ranks = list(range(self._world_size)) - if rank not in ranks: + if self._rank not in ranks: return - new_rank = ranks.index(rank) - new_world_size = len(ranks) - - new_dist = DistributedHccl() - new_dist.init_process_group( - host, port + 10, new_rank, new_world_size - ) # todo host maybe incorrect - self.sub_groups[new_dist] = ranks - - return new_dist + subcomm = self.pyhccl.create_subcomm(ranks) + value = 0 + if subcomm: + value = subcomm.value + self.sub_groups[value] = ranks + return value except ImportError as e: pass diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 0fbfd8b..095b853 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -197,9 +197,9 @@ def __init__( self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set) self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9")) if self.device_manager.backend == "nccl": - self.dist = DistributedNccl + self.dist = DistributedNccl() elif self.device_manager.backend == "hccl": - self.dist = DistributedHccl + self.dist = DistributedHccl() else: self.dist = torch.distributed @@ -497,7 +497,7 @@ def init_process_group( """ master_addr = master_addr or os.getenv("MASTER_ADDR") assert master_addr, "master_addr is required" - store = self.dist.TCPStore( + store = dist.TCPStore( master_addr, _get_master_port(master_port), self._world_size, @@ -830,7 +830,7 @@ def _update_per_bucket( f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}" ) ret_code.fill_(1) - self.dist.all_reduce(ret_code, op=self.dist.ReduceOp.SUM, group=ranks_group) + self.dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group) self.device_manager.device_module.synchronize() if ret_code.item() != 0: # quit early if any rank failed From 5266ac15b266e0c1deb1a36d0ec05ab6b168845c Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Mon, 5 Jan 2026 21:55:31 +0800 Subject: [PATCH 04/14] implement PyNcclCommunicatorEx --- checkpoint_engine/distributed.py | 168 ++++++++++++++++++++++++------- 1 file changed, 129 insertions(+), 39 deletions(-) diff --git a/checkpoint_engine/distributed.py b/checkpoint_engine/distributed.py index 9069488..f651c25 100644 --- a/checkpoint_engine/distributed.py +++ b/checkpoint_engine/distributed.py @@ -10,6 +10,14 @@ import torch import torch.distributed from torch.distributed import ReduceOp +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import ( + Function, + NCCLLibrary, + buffer_type, + ncclComm_t, + ncclResult_t, +) from vllm.distributed.utils import StatelessProcessGroup from vllm.utils import current_stream @@ -68,6 +76,83 @@ def _common_all_gather_object(comm, device, world_size, object_list, object): object_list[i] = _tensor_to_object(tensor, tensor_size) +class ncclConfig_t(ctypes.Structure): + _fields_ = [ + ("size", ctypes.c_size_t), + 
("magic", ctypes.c_uint), + ("version", ctypes.c_uint), + ("blocking", ctypes.c_int), + ("cgaClusterSize", ctypes.c_int), + ("minCTAs", ctypes.c_int), + ("maxCTAs", ctypes.c_int), + ("netName", ctypes.c_char_p), + ("splitShare", ctypes.c_int), + ("trafficClass", ctypes.c_int), + ("commName", ctypes.c_char_p), + ("collnetEnable", ctypes.c_int), + ("CTAPolicy", ctypes.c_int), + ("shrinkShare", ctypes.c_int), + ("nvlsCTAs", ctypes.c_int), + ("nChannelsPerNetPeer", ctypes.c_int), + ("nvlinkCentricSched", ctypes.c_int), + ("graphUsageMode", ctypes.c_int), + ("numRmdCtx", ctypes.c_int), + ] + +nccl_orig_exported_functions = NCCLLibrary.exported_functions +nccl_extended_functions = [ + # ncclResult_t ncclCommSplit( + # ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t *config + # ) + Function( + "ncclCommSplit", + ncclResult_t, + [ + ncclComm_t, + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ncclComm_t), + ctypes.POINTER(ncclConfig_t), + ], + ), +] + + +def nccl_comm_split( + self, + comm: ncclComm_t, + color: int, + key: int, +) -> ncclComm_t: + newcomm = ncclComm_t() + + self.NCCL_CHECK( + self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None) + ) + return newcomm + + +# extend NCCLLibrary +NCCLLibrary.exported_functions = nccl_orig_exported_functions + nccl_extended_functions +NCCLLibrary.ncclCommSplit = nccl_comm_split + + +class PyNcclCommunicatorEx(PyNcclCommunicator): + def destroy_comm(self, comm=None): + if comm: + self.nccl.ncclCommDestroy(comm) + else: + self.nccl.ncclCommDestroy(self.comm) + + def create_newcomm(self, ranks): + if self.rank in ranks: + color = 0 + else: + color = -1 # NCCL_SPLIT_NOCOLOR + newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank) + return newcomm + + class DistributedNccl: def __init__(self): self.pg = None @@ -93,83 +178,88 @@ def init_process_group( host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) ) - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - - self.pynccl = PyNcclCommunicator(group=self.pg, device=self._device) + self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self._device) + self._comm = self.pynccl.comm def destroy_process_group(self, group=None): - # PyNcclCommunicator does not provide destroy method if group in self.sub_groups: + newcomm = ctypes.c_void_p(group) + self.pynccl.destroy_comm(newcomm) del self.sub_groups[group] return + self.pynccl.destroy_comm() + + self.pynccl = None + self.pg = None + def is_initialized(self) -> bool: return self.pynccl is not None def all_gather_object(self, object_list: list[Any], obj: Any, group=None): if group: assert group in self.sub_groups, "invalid sub_group" - pynccl = group.pynccl - else: - pynccl = self.pynccl + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm - _common_all_gather_object(pynccl, self._device, self._world_size, object_list, object) + _common_all_gather_object(self.pynccl, self._device, self._world_size, object_list, obj) current_stream().synchronize() + if group: + self.pynccl.comm = self._comm + def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): if group: assert group in self.sub_groups, "invalid sub_group" - pynccl = group.pynccl - else: - pynccl = self.pynccl + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm - out_tensor = pynccl.all_reduce(in_tensor=tensor, op=op) + out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op) current_stream().synchronize() tensor.copy_(out_tensor) + if group: + self.pynccl.comm = 
self._comm + def broadcast(self, tensor: torch.Tensor, src=None, group=None): if group: assert group in self.sub_groups, "invalid sub_group" - pynccl = group.pynccl - else: - pynccl = self.pynccl + assert src in self.sub_groups[group], "src rank not in group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + # convert src rank id in default world to newcomm + #src = self.sub_groups[group].index(src) - pynccl.broadcast(tensor, src) + self.pynccl.broadcast(tensor, src) current_stream().synchronize() + if group: + self.pynccl.comm = self._comm + def barrier(self, group=None): if group: assert group in self.sub_groups, "invalid sub_group" - pynccl = group.pynccl - else: - pynccl = self.pynccl + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm data = torch.zeros(1, device=self._rank) - pynccl.all_reduce(data) + self.pynccl.all_reduce(data) current_stream().synchronize() + if group: + self.pynccl.comm = self._comm + def new_group(self, ranks): # ranks is None or [] if not ranks: - return self - - host = self._host - port = self._port - rank = self._rank - - if rank not in ranks: - return - - new_rank = ranks.index(rank) - new_world_size = len(ranks) - - new_dist = DistributedNccl() - new_dist.init_process_group( - host, port + 10, new_rank, new_world_size - ) # todo host maybe incorrect - self.sub_groups.append(new_dist) - - return new_dist + ranks = list(range(self._world_size)) + + newcomm = self.pynccl.create_newcomm(ranks) + value = 0 + if newcomm: + value = newcomm.value + self.sub_groups[value] = ranks + return value try: From 533bc5d86151577843fb519454b6a976470826d1 Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Tue, 6 Jan 2026 10:45:42 +0800 Subject: [PATCH 05/14] fix rebase issues --- checkpoint_engine/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 4cea898..c69815c 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -168,7 +168,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): update_weights_from_ipc( self._zmq_ctx, - zmq_handles[device_uuid], + zmq_handles[self._device_uuid], device_id=self.device.index, run=self.model_runner.model.load_weights, post_hook=lambda: process_weights_after_loading( From c7303b9f4e4aea211ed7984d12f6824a3933abb3 Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Tue, 6 Jan 2026 11:14:28 +0800 Subject: [PATCH 06/14] split distributed.py into distributed_nccl.py & distributed_hccl.py --- checkpoint_engine/distributed.py | 540 -------------------------- checkpoint_engine/distributed_hccl.py | 281 ++++++++++++++ checkpoint_engine/distributed_nccl.py | 262 +++++++++++++ 3 files changed, 543 insertions(+), 540 deletions(-) delete mode 100644 checkpoint_engine/distributed.py create mode 100644 checkpoint_engine/distributed_hccl.py create mode 100644 checkpoint_engine/distributed_nccl.py diff --git a/checkpoint_engine/distributed.py b/checkpoint_engine/distributed.py deleted file mode 100644 index f651c25..0000000 --- a/checkpoint_engine/distributed.py +++ /dev/null @@ -1,540 +0,0 @@ -import ctypes -import io -import logging -import os -import pickle -from datetime import timedelta -from enum import Enum -from typing import Any, List, Optional - -import torch -import torch.distributed -from torch.distributed import ReduceOp -from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator -from vllm.distributed.device_communicators.pynccl_wrapper import ( - Function, 
- NCCLLibrary, - buffer_type, - ncclComm_t, - ncclResult_t, -) -from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import current_stream - - -_pickler = pickle.Pickler -_unpickler = pickle.Unpickler - - -def _object_to_tensor(obj, device): - f = io.BytesIO() - _pickler(f).dump(obj) - byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) - byte_tensor = torch.ByteTensor(byte_storage).to(device) - local_size = torch.LongTensor([byte_tensor.numel()]).to(device) - return byte_tensor, local_size - - -def _tensor_to_object(tensor, tensor_size): - tensor = tensor.cpu() - buf = tensor.numpy().tobytes()[:tensor_size] - return _unpickler(io.BytesIO(buf)).load() - - -def _flatten_for_scatter_gather(tensor_list, copy=False): - if not tensor_list: - raise RuntimeError("Received an empty list.") - t = tensor_list[0] - buffer_shape = [len(tensor_list)] + list(t.shape) - - buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) - if copy: - for i, tensor in enumerate(tensor_list): - buffer[i].copy_(tensor) - return buffer - - -def _common_all_gather_object(comm, device, world_size, object_list, object): - input_tensor, local_size = _object_to_tensor(object, device) - object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device) - comm.all_gather(object_sizes_tensor, local_size) - object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)] - max_object_size = int(max(object_size_list).item()) - input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * world_size, dtype=torch.uint8, device=device - ) - - comm.all_gather(coalesced_output_tensor, input_tensor) - output_tensors = [ - coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] - for i in range(world_size) - ] - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - tensor_size = object_size_list[i] - object_list[i] = _tensor_to_object(tensor, tensor_size) - - -class ncclConfig_t(ctypes.Structure): - _fields_ = [ - ("size", ctypes.c_size_t), - ("magic", ctypes.c_uint), - ("version", ctypes.c_uint), - ("blocking", ctypes.c_int), - ("cgaClusterSize", ctypes.c_int), - ("minCTAs", ctypes.c_int), - ("maxCTAs", ctypes.c_int), - ("netName", ctypes.c_char_p), - ("splitShare", ctypes.c_int), - ("trafficClass", ctypes.c_int), - ("commName", ctypes.c_char_p), - ("collnetEnable", ctypes.c_int), - ("CTAPolicy", ctypes.c_int), - ("shrinkShare", ctypes.c_int), - ("nvlsCTAs", ctypes.c_int), - ("nChannelsPerNetPeer", ctypes.c_int), - ("nvlinkCentricSched", ctypes.c_int), - ("graphUsageMode", ctypes.c_int), - ("numRmdCtx", ctypes.c_int), - ] - -nccl_orig_exported_functions = NCCLLibrary.exported_functions -nccl_extended_functions = [ - # ncclResult_t ncclCommSplit( - # ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t *config - # ) - Function( - "ncclCommSplit", - ncclResult_t, - [ - ncclComm_t, - ctypes.c_int, - ctypes.c_int, - ctypes.POINTER(ncclComm_t), - ctypes.POINTER(ncclConfig_t), - ], - ), -] - - -def nccl_comm_split( - self, - comm: ncclComm_t, - color: int, - key: int, -) -> ncclComm_t: - newcomm = ncclComm_t() - - self.NCCL_CHECK( - self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None) - ) - return newcomm - - -# extend NCCLLibrary -NCCLLibrary.exported_functions = nccl_orig_exported_functions + nccl_extended_functions -NCCLLibrary.ncclCommSplit = nccl_comm_split - - -class PyNcclCommunicatorEx(PyNcclCommunicator): - def destroy_comm(self, 
comm=None): - if comm: - self.nccl.ncclCommDestroy(comm) - else: - self.nccl.ncclCommDestroy(self.comm) - - def create_newcomm(self, ranks): - if self.rank in ranks: - color = 0 - else: - color = -1 # NCCL_SPLIT_NOCOLOR - newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank) - return newcomm - - -class DistributedNccl: - def __init__(self): - self.pg = None - self.pynccl = None - self.sub_groups = {} - - def init_process_group( - self, - host: str, - port: int, - rank: int, - world_size: int, - timeout: timedelta = timedelta(seconds=300), - **kwargs, - ): - self._host = host - self._port = port - self._rank = rank - self._world_size = world_size - self._device = torch.device("cuda", rank) - - self.pg = StatelessProcessGroup.create( - host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) - ) - - self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self._device) - self._comm = self.pynccl.comm - - def destroy_process_group(self, group=None): - if group in self.sub_groups: - newcomm = ctypes.c_void_p(group) - self.pynccl.destroy_comm(newcomm) - del self.sub_groups[group] - return - - self.pynccl.destroy_comm() - - self.pynccl = None - self.pg = None - - def is_initialized(self) -> bool: - return self.pynccl is not None - - def all_gather_object(self, object_list: list[Any], obj: Any, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - - _common_all_gather_object(self.pynccl, self._device, self._world_size, object_list, obj) - current_stream().synchronize() - - if group: - self.pynccl.comm = self._comm - - def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - - out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op) - current_stream().synchronize() - tensor.copy_(out_tensor) - - if group: - self.pynccl.comm = self._comm - - def broadcast(self, tensor: torch.Tensor, src=None, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - assert src in self.sub_groups[group], "src rank not in group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - # convert src rank id in default world to newcomm - #src = self.sub_groups[group].index(src) - - self.pynccl.broadcast(tensor, src) - current_stream().synchronize() - - if group: - self.pynccl.comm = self._comm - - def barrier(self, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - - data = torch.zeros(1, device=self._rank) - self.pynccl.all_reduce(data) - current_stream().synchronize() - - if group: - self.pynccl.comm = self._comm - - def new_group(self, ranks): - # ranks is None or [] - if not ranks: - ranks = list(range(self._world_size)) - - newcomm = self.pynccl.create_newcomm(ranks) - value = 0 - if newcomm: - value = newcomm.value - self.sub_groups[value] = ranks - return value - - -try: - from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator - from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( - Function, - HCCLLibrary, - aclrtStream_t, - buffer_type, - hcclComm_t, - hcclDataType_t, - hcclDataTypeEnum, - hcclRedOp_t, - hcclRedOpTypeEnum, - hcclResult_t, - hcclUniqueId, - ) - from vllm_ascend.utils import current_stream - - class HcclCommConfig(ctypes.Structure): - _fields_ = [ - 
("size", ctypes.c_size_t), - ("magic_word", ctypes.c_uint32), - ("version", ctypes.c_uint32), - ("reserved", ctypes.c_uint64), - ("hccl_buffer_size", ctypes.c_uint32), - ("hccl_deterministic", ctypes.c_uint32), - ("hccl_comm_name", ctypes.c_char * 128), - ("hccl_udi", ctypes.c_char * 128), - ("hccl_op_expansion_mode", ctypes.c_uint32), - ("hccl_rdma_traffic_class", ctypes.c_uint32), - ("hccl_rdma_service_level", ctypes.c_uint32), - ("hcll_world_rank_id", ctypes.c_uint32), - ("hccl_job_id", ctypes.c_uint64), - ("comm_engine", ctypes.c_int32), - ("thread_num", ctypes.c_uint32), - ("notify_num_per_thread", ctypes.c_uint32), - ("acl_graph_zero_copy_enable", ctypes.c_uint8), - ] - - orig_exported_functions = HCCLLibrary.exported_functions - extended_functions = [ - # HcclResult HcclAllGather( - # void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, - # HcclComm comm, alcrtStream stream - # ) - Function( - "HcclAllGather", - hcclResult_t, - [ - buffer_type, - buffer_type, - ctypes.c_uint64, - hcclDataType_t, - hcclComm_t, - aclrtStream_t, - ], - ), - # HcclResult HcclCreateSubCommConfig( - # HcclComm *comm, uin32_t rankNum, uint32_t *rankIds, uint64_t subCommId, - # uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm - # ) - Function( - "HcclCreateSubCommConfig", - hcclResult_t, - [ - ctypes.POINTER(hcclComm_t), - ctypes.c_uint32, - ctypes.POINTER(ctypes.c_uint32), - ctypes.c_uint64, - ctypes.c_uint32, - ctypes.POINTER(HcclCommConfig), - ctypes.POINTER(hcclComm_t), - ], - ), - ] - - def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream): - self.HCCL_CHECK( - self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream) - ) - - def hccl_create_subcomm_config( - self, comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config - ): - subcomm = hcclComm_t() - self.HCCL_CHECK( - self._funcs["HcclCreateSubCommConfig"]( - ctypes.byref(comm), - ranks_size, - c_rank_ids, - subcomm_id, - subcomm_rank, - ctypes.byref(comm_config), - ctypes.byref(subcomm), - ) - ) - return subcomm - - # extend HCCLLibrary - HCCLLibrary.exported_functions = orig_exported_functions + extended_functions - HCCLLibrary.hcclAllGather = hccl_all_gather - HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config - - class PyHcclCommunicatorEx(PyHcclCommunicator): - def __init__(self, group, device): - super().__init__(group, device) - self.subcomm_id = 1 - - def destroy_comm(self, comm=None): - if comm: - self.hccl.hcclCommDestroy(comm) - else: - self.hccl.hcclCommDestroy(self.comm) - - def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=None): - if self.disabled: - return - assert in_tensor.device == self.device, ( - f"this hccl communicator is created to work on {self.device}, " - f"but the input tensor in on {in_tensor.device}" - ) - if stream is None: - stream = current_stream() - self.hccl.hcclAllGather( - buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - hcclDataTypeEnum.from_torch(in_tensor.dtype), - self.comm, # todo - aclrtStream_t(stream.npu_stream), - ) - return out_tensor - - def create_subcomm(self, ranks): - comm_config = HcclCommConfig( - size=312, - magic_word=0xF0F0F0F0, - version=6, - reserved=0, - hccl_buffer_size=0xFFFFFFFF, - hccl_deterministic=0xFFFFFFFF, - hccl_comm_name=b"\0", - hccl_udi=b"\0", - hccl_op_expansize_mode=0, - hccl_rdma_traffic_class=0xFFFFFFFF, - hccl_rdma_service_level=0xFFFFFFFF, - hccl_world_rank_id=0, - hccl_job_id=0, - 
comm_engine=-1, - thread_num=0xFFFFFFFF, - notify_num_per_thread=0xFFFFFFFF, - acl_graph_zero_copy_enable=0, - ) - uint32_array = ctypes.c_uint32 * len(ranks) - c_rank_ids = uint32_array(*ranks) - subcomm_rank = ranks.index(self.rank) - ranks_size = len(ranks) - subcomm_id = self.subcomm_id - - subcomm = self.hccl.hcclCreateSubCommConfig( - self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config - ) - self.subcomm_id += 1 - return subcomm - - class DistributedHccl: - def __init__(self): - self.pg = None - self.pyhccl = None - self.sub_groups = {} - - def init_process_group( - self, - host: str, - port: int, - rank: int, - world_size: int, - timeout: timedelta = timedelta(seconds=300), - **kwargs, - ): - self._host = host - self._port = port - self._rank = rank - self._world_size = world_size - self._device = torch.device("npu", rank) - - self.pg = StatelessProcessGroup.create( - host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) - ) - self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self._device) - self._comm = self.pyhccl.comm - - def destroy_process_group(self, group=None): - if group in self.sub_groups: - subcomm = ctypes.c_void_p(group) - self.pyhccl.destroy_comm(subcomm) - del self.sub_groups[group] - return - - self.pyhccl.destroy_comm() - - self.pyhccl = None - self.pg = None - - def is_initialized(self) -> bool: - return self.pyhccl is not None - - def all_gather_object(self, object_list: list[Any], obj: Any, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm - - _common_all_gather_object(self.pyhccl, self._device, self._world_size, object_list, obj) - current_stream().synchronize() - - if group: - self.pyhccl.comm = self._comm - - def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm - - out_tensor = self.pyhccl.all_reduce(tensor, op) - current_stream().synchronize() - tensor.copy_(out_tensor) - - if group: - self.pyhccl.comm = self._comm - - def broadcast(self, tensor: torch.Tensor, src=None, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - assert src in self.sub_groups[group], "src rank not in group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm - # convert src rank id in default world to subcomm - src = self.sub_groups[group].index(src) - - self.pyhccl.broadcast(tensor, src) - current_stream().synchronize() - - if group: - self.pyhccl.comm = self._comm - - def barrier(self, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm - - data = torch.zeros(1, device=self._rank) - self.pyhccl.all_reduce(data) - current_stream().synchronize() - - if group: - self.pyhccl.comm = self._comm - - def new_group(self, ranks): - # if ranks is None or [], using the world instead - if not ranks: - ranks = list(range(self._world_size)) - - if self._rank not in ranks: - return - - subcomm = self.pyhccl.create_subcomm(ranks) - value = 0 - if subcomm: - value = subcomm.value - self.sub_groups[value] = ranks - return value - -except ImportError as e: - pass diff --git a/checkpoint_engine/distributed_hccl.py b/checkpoint_engine/distributed_hccl.py new file mode 100644 index 0000000..34d6625 --- /dev/null +++ b/checkpoint_engine/distributed_hccl.py @@ -0,0 +1,281 @@ +import ctypes 
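
Editorial aside on the rank bookkeeping in create_subcomm above: HcclCreateSubCommConfig takes the member ranks as a C array of uint32 values plus the caller's position inside that array, not its global rank. A minimal standalone sketch of that conversion, using made-up rank values that are not part of this patch:

import ctypes

# Hypothetical global ranks that should form the sub-communicator.
ranks = [0, 2, 4, 6]

# HcclCreateSubCommConfig expects a uint32_t* plus the element count,
# so the Python list is copied into a ctypes array of matching length.
RankArray = ctypes.c_uint32 * len(ranks)
c_rank_ids = RankArray(*ranks)

# The caller identifies itself by its index inside the sub-group,
# which is why create_subcomm requires the calling rank to be in `ranks`.
my_global_rank = 4  # illustrative value only
subcomm_rank = ranks.index(my_global_rank)

assert list(c_rank_ids) == ranks
assert subcomm_rank == 2
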
+from datetime import timedelta +from typing import Any, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator +from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( + Function, + HCCLLibrary, + aclrtStream_t, + buffer_type, + hcclComm_t, + hcclDataType_t, + hcclDataTypeEnum, + hcclRedOp_t, + hcclRedOpTypeEnum, + hcclResult_t, + hcclUniqueId, +) +from vllm_ascend.utils import current_stream + +from .distributed_nccl import _common_all_gather_object + +class HcclCommConfig(ctypes.Structure): + _fields_ = [ + ("size", ctypes.c_size_t), + ("magic_word", ctypes.c_uint32), + ("version", ctypes.c_uint32), + ("reserved", ctypes.c_uint64), + ("hccl_buffer_size", ctypes.c_uint32), + ("hccl_deterministic", ctypes.c_uint32), + ("hccl_comm_name", ctypes.c_char * 128), + ("hccl_udi", ctypes.c_char * 128), + ("hccl_op_expansion_mode", ctypes.c_uint32), + ("hccl_rdma_traffic_class", ctypes.c_uint32), + ("hccl_rdma_service_level", ctypes.c_uint32), + ("hcll_world_rank_id", ctypes.c_uint32), + ("hccl_job_id", ctypes.c_uint64), + ("comm_engine", ctypes.c_int32), + ("thread_num", ctypes.c_uint32), + ("notify_num_per_thread", ctypes.c_uint32), + ("acl_graph_zero_copy_enable", ctypes.c_uint8), + ] + +orig_exported_functions = HCCLLibrary.exported_functions +extended_functions = [ + # HcclResult HcclAllGather( + # void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, + # HcclComm comm, alcrtStream stream + # ) + Function( + "HcclAllGather", + hcclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_uint64, + hcclDataType_t, + hcclComm_t, + aclrtStream_t, + ], + ), + # HcclResult HcclCreateSubCommConfig( + # HcclComm *comm, uin32_t rankNum, uint32_t *rankIds, uint64_t subCommId, + # uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm + # ) + Function( + "HcclCreateSubCommConfig", + hcclResult_t, + [ + ctypes.POINTER(hcclComm_t), + ctypes.c_uint32, + ctypes.POINTER(ctypes.c_uint32), + ctypes.c_uint64, + ctypes.c_uint32, + ctypes.POINTER(HcclCommConfig), + ctypes.POINTER(hcclComm_t), + ], + ), +] + +def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream): + self.HCCL_CHECK( + self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream) + ) + +def hccl_create_subcomm_config( + self, comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config +): + subcomm = hcclComm_t() + self.HCCL_CHECK( + self._funcs["HcclCreateSubCommConfig"]( + ctypes.byref(comm), + ranks_size, + c_rank_ids, + subcomm_id, + subcomm_rank, + ctypes.byref(comm_config), + ctypes.byref(subcomm), + ) + ) + return subcomm + +# extend HCCLLibrary +HCCLLibrary.exported_functions = orig_exported_functions + extended_functions +HCCLLibrary.hcclAllGather = hccl_all_gather +HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config + +class PyHcclCommunicatorEx(PyHcclCommunicator): + def __init__(self, group, device): + super().__init__(group, device) + self.subcomm_id = 1 + + def destroy_comm(self, comm=None): + if comm: + self.hccl.hcclCommDestroy(comm) + else: + self.hccl.hcclCommDestroy(self.comm) + + def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=None): + if self.disabled: + return + assert in_tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, " + f"but the input tensor in on {in_tensor.device}" + ) + if stream is None: + stream = current_stream() + 
self.hccl.hcclAllGather( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + hcclDataTypeEnum.from_torch(in_tensor.dtype), + self.comm, # todo + aclrtStream_t(stream.npu_stream), + ) + return out_tensor + + def create_subcomm(self, ranks): + comm_config = HcclCommConfig( + size=312, + magic_word=0xF0F0F0F0, + version=6, + reserved=0, + hccl_buffer_size=0xFFFFFFFF, + hccl_deterministic=0xFFFFFFFF, + hccl_comm_name=b"\0", + hccl_udi=b"\0", + hccl_op_expansize_mode=0, + hccl_rdma_traffic_class=0xFFFFFFFF, + hccl_rdma_service_level=0xFFFFFFFF, + hccl_world_rank_id=0, + hccl_job_id=0, + comm_engine=-1, + thread_num=0xFFFFFFFF, + notify_num_per_thread=0xFFFFFFFF, + acl_graph_zero_copy_enable=0, + ) + uint32_array = ctypes.c_uint32 * len(ranks) + c_rank_ids = uint32_array(*ranks) + subcomm_rank = ranks.index(self.rank) + ranks_size = len(ranks) + subcomm_id = self.subcomm_id + + subcomm = self.hccl.hcclCreateSubCommConfig( + self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config + ) + self.subcomm_id += 1 + return subcomm + +class DistributedHccl: + def __init__(self): + self.pg = None + self.pyhccl = None + self.sub_groups = {} + + def init_process_group( + self, + host: str, + port: int, + rank: int, + world_size: int, + timeout: timedelta = timedelta(seconds=300), + **kwargs, + ): + self._host = host + self._port = port + self._rank = rank + self._world_size = world_size + self._device = torch.device("npu", rank) + + self.pg = StatelessProcessGroup.create( + host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) + ) + self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self._device) + self._comm = self.pyhccl.comm + + def destroy_process_group(self, group=None): + if group in self.sub_groups: + subcomm = ctypes.c_void_p(group) + self.pyhccl.destroy_comm(subcomm) + del self.sub_groups[group] + return + + self.pyhccl.destroy_comm() + + self.pyhccl = None + self.pg = None + + def is_initialized(self) -> bool: + return self.pyhccl is not None + + def all_gather_object(self, object_list: list[Any], obj: Any, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + + _common_all_gather_object(self.pyhccl, self._device, self._world_size, object_list, obj) + current_stream().synchronize() + + if group: + self.pyhccl.comm = self._comm + + def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + + out_tensor = self.pyhccl.all_reduce(tensor, op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + if group: + self.pyhccl.comm = self._comm + + def broadcast(self, tensor: torch.Tensor, src=None, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + assert src in self.sub_groups[group], "src rank not in group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + # convert src rank id in default world to subcomm + src = self.sub_groups[group].index(src) + + self.pyhccl.broadcast(tensor, src) + current_stream().synchronize() + + if group: + self.pyhccl.comm = self._comm + + def barrier(self, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + self.pyhccl.comm = subcomm + + data = torch.zeros(1, device=self._rank) + self.pyhccl.all_reduce(data) + current_stream().synchronize() 
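
The group-aware wrappers above all follow the same swap-and-restore discipline: point the communicator's comm at the sub-communicator handle, run the collective, then restore the default comm. A hedged sketch of factoring that pattern into a context manager (not part of this patch; it only assumes an object with a writable comm attribute, as PyHcclCommunicatorEx and PyNcclCommunicatorEx have, and the hypothetical names below are illustrative):

import ctypes
from contextlib import contextmanager

@contextmanager
def use_subcomm(communicator, default_comm, group_handle=None):
    # Temporarily point a PyHccl/PyNccl-style wrapper at a sub-communicator.
    # `group_handle` is the integer value returned by new_group(), or None
    # to keep the default communicator.
    if group_handle:
        communicator.comm = ctypes.c_void_p(group_handle)
    try:
        yield communicator
    finally:
        # Always restore the default comm, even if the collective raises.
        communicator.comm = default_comm

# Usage sketch (names are illustrative, not defined by the patch):
# with use_subcomm(self.pyhccl, self._comm, group):
#     self.pyhccl.all_reduce(tensor)
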
+ + if group: + self.pyhccl.comm = self._comm + + def new_group(self, ranks): + # if ranks is None or [], using the world instead + if not ranks: + ranks = list(range(self._world_size)) + + if self._rank not in ranks: + return + + subcomm = self.pyhccl.create_subcomm(ranks) + value = 0 + if subcomm: + value = subcomm.value + self.sub_groups[value] = ranks + return value diff --git a/checkpoint_engine/distributed_nccl.py b/checkpoint_engine/distributed_nccl.py new file mode 100644 index 0000000..ca7c19f --- /dev/null +++ b/checkpoint_engine/distributed_nccl.py @@ -0,0 +1,262 @@ +import ctypes +import io +import logging +import os +import pickle +from datetime import timedelta +from enum import Enum +from typing import Any, List, Optional + +import torch +import torch.distributed +from torch.distributed import ReduceOp +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import ( + Function, + NCCLLibrary, + buffer_type, + ncclComm_t, + ncclResult_t, +) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.utils import current_stream + + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +def _object_to_tensor(obj, device): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) + byte_tensor = torch.ByteTensor(byte_storage).to(device) + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size): + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + if not tensor_list: + raise RuntimeError("Received an empty list.") + t = tensor_list[0] + buffer_shape = [len(tensor_list)] + list(t.shape) + + buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device) + if copy: + for i, tensor in enumerate(tensor_list): + buffer[i].copy_(tensor) + return buffer + + +def _common_all_gather_object(comm, device, world_size, object_list, object): + input_tensor, local_size = _object_to_tensor(object, device) + object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device) + comm.all_gather(object_sizes_tensor, local_size) + object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)] + max_object_size = int(max(object_size_list).item()) + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * world_size, dtype=torch.uint8, device=device + ) + + comm.all_gather(coalesced_output_tensor, input_tensor) + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(world_size) + ] + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) + + +class ncclConfig_t(ctypes.Structure): + _fields_ = [ + ("size", ctypes.c_size_t), + ("magic", ctypes.c_uint), + ("version", ctypes.c_uint), + ("blocking", ctypes.c_int), + ("cgaClusterSize", ctypes.c_int), + ("minCTAs", ctypes.c_int), + ("maxCTAs", ctypes.c_int), + ("netName", ctypes.c_char_p), + ("splitShare", ctypes.c_int), + ("trafficClass", ctypes.c_int), + ("commName", ctypes.c_char_p), + ("collnetEnable", ctypes.c_int), + ("CTAPolicy", ctypes.c_int), + ("shrinkShare", ctypes.c_int), + ("nvlsCTAs", ctypes.c_int), + ("nChannelsPerNetPeer", 
ctypes.c_int), + ("nvlinkCentricSched", ctypes.c_int), + ("graphUsageMode", ctypes.c_int), + ("numRmdCtx", ctypes.c_int), + ] + +nccl_orig_exported_functions = NCCLLibrary.exported_functions +nccl_extended_functions = [ + # ncclResult_t ncclCommSplit( + # ncclComm_t comm, int color, int key, ncclComm_t *newcomm, ncclConfig_t *config + # ) + Function( + "ncclCommSplit", + ncclResult_t, + [ + ncclComm_t, + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ncclComm_t), + ctypes.POINTER(ncclConfig_t), + ], + ), +] + + +def nccl_comm_split( + self, + comm: ncclComm_t, + color: int, + key: int, +) -> ncclComm_t: + newcomm = ncclComm_t() + + self.NCCL_CHECK( + self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None) + ) + return newcomm + + +# extend NCCLLibrary +NCCLLibrary.exported_functions = nccl_orig_exported_functions + nccl_extended_functions +NCCLLibrary.ncclCommSplit = nccl_comm_split + + +class PyNcclCommunicatorEx(PyNcclCommunicator): + def destroy_comm(self, comm=None): + if comm: + self.nccl.ncclCommDestroy(comm) + else: + self.nccl.ncclCommDestroy(self.comm) + + def create_newcomm(self, ranks): + if self.rank in ranks: + color = 0 + else: + color = -1 # NCCL_SPLIT_NOCOLOR + newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank) + return newcomm + + +class DistributedNccl: + def __init__(self): + self.pg = None + self.pynccl = None + self.sub_groups = {} + + def init_process_group( + self, + host: str, + port: int, + rank: int, + world_size: int, + timeout: timedelta = timedelta(seconds=300), + **kwargs, + ): + self._host = host + self._port = port + self._rank = rank + self._world_size = world_size + self._device = torch.device("cuda", rank) + + self.pg = StatelessProcessGroup.create( + host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) + ) + + self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self._device) + self._comm = self.pynccl.comm + + def destroy_process_group(self, group=None): + if group in self.sub_groups: + newcomm = ctypes.c_void_p(group) + self.pynccl.destroy_comm(newcomm) + del self.sub_groups[group] + return + + self.pynccl.destroy_comm() + + self.pynccl = None + self.pg = None + + def is_initialized(self) -> bool: + return self.pynccl is not None + + def all_gather_object(self, object_list: list[Any], obj: Any, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + + _common_all_gather_object(self.pynccl, self._device, self._world_size, object_list, obj) + current_stream().synchronize() + + if group: + self.pynccl.comm = self._comm + + def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + + out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + if group: + self.pynccl.comm = self._comm + + def broadcast(self, tensor: torch.Tensor, src=None, group=None): + if group: + assert group in self.sub_groups, "invalid sub_group" + assert src in self.sub_groups[group], "src rank not in group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + # convert src rank id in default world to newcomm + #src = self.sub_groups[group].index(src) + + self.pynccl.broadcast(tensor, src) + current_stream().synchronize() + + if group: + self.pynccl.comm = self._comm + + def barrier(self, group=None): + if group: + 
assert group in self.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + self.pynccl.comm = newcomm + + data = torch.zeros(1, device=self._rank) + self.pynccl.all_reduce(data) + current_stream().synchronize() + + if group: + self.pynccl.comm = self._comm + + def new_group(self, ranks): + # ranks is None or [] + if not ranks: + ranks = list(range(self._world_size)) + + newcomm = self.pynccl.create_newcomm(ranks) + value = 0 + if newcomm: + value = newcomm.value + self.sub_groups[value] = ranks + return value From 981ec100fe08c4488a78c9aeceb33e35f7e5a653 Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Tue, 6 Jan 2026 17:54:11 +0800 Subject: [PATCH 07/14] fix ncclBroadcast illegal memory access --- checkpoint_engine/distributed_nccl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/checkpoint_engine/distributed_nccl.py b/checkpoint_engine/distributed_nccl.py index ca7c19f..78b54c6 100644 --- a/checkpoint_engine/distributed_nccl.py +++ b/checkpoint_engine/distributed_nccl.py @@ -75,7 +75,6 @@ def _common_all_gather_object(comm, device, world_size, object_list, object): tensor_size = object_size_list[i] object_list[i] = _tensor_to_object(tensor, tensor_size) - class ncclConfig_t(ctypes.Structure): _fields_ = [ ("size", ctypes.c_size_t), @@ -96,7 +95,7 @@ class ncclConfig_t(ctypes.Structure): ("nChannelsPerNetPeer", ctypes.c_int), ("nvlinkCentricSched", ctypes.c_int), ("graphUsageMode", ctypes.c_int), - ("numRmdCtx", ctypes.c_int), + ("numRmaCtx", ctypes.c_int), ] nccl_orig_exported_functions = NCCLLibrary.exported_functions @@ -228,13 +227,15 @@ def broadcast(self, tensor: torch.Tensor, src=None, group=None): newcomm = ctypes.c_void_p(group) self.pynccl.comm = newcomm # convert src rank id in default world to newcomm - #src = self.sub_groups[group].index(src) + src = self.sub_groups[group].index(src) + self.pynccl.rank = self.sub_groups[group].index(self._rank) self.pynccl.broadcast(tensor, src) current_stream().synchronize() if group: self.pynccl.comm = self._comm + self.pynccl.rank = self._rank def barrier(self, group=None): if group: From 166f819aecd01239b4a0416a52cd5292f51839a3 Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Wed, 7 Jan 2026 09:22:35 +0800 Subject: [PATCH 08/14] export distributed functions --- checkpoint_engine/distributed_hccl.py | 210 ++++++++++++---------- checkpoint_engine/distributed_nccl.py | 246 ++++++++++++++------------ 2 files changed, 254 insertions(+), 202 deletions(-) diff --git a/checkpoint_engine/distributed_hccl.py b/checkpoint_engine/distributed_hccl.py index 34d6625..b7304a1 100644 --- a/checkpoint_engine/distributed_hccl.py +++ b/checkpoint_engine/distributed_hccl.py @@ -1,3 +1,4 @@ +from sre_parse import State import ctypes from datetime import timedelta from typing import Any, List, Optional @@ -5,6 +6,7 @@ import torch from torch.distributed import ReduceOp +from vllm.distributed.utils import StatelessProcessGroup from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( Function, @@ -171,111 +173,139 @@ def create_subcomm(self, ranks): return subcomm class DistributedHccl: - def __init__(self): - self.pg = None - self.pyhccl = None - self.sub_groups = {} - - def init_process_group( - self, - host: str, - port: int, - rank: int, - world_size: int, - timeout: timedelta = timedelta(seconds=300), - **kwargs, - ): - self._host = host - self._port = port - self._rank = rank - 
self._world_size = world_size - self._device = torch.device("npu", rank) - - self.pg = StatelessProcessGroup.create( - host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) - ) - self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self._device) - self._comm = self.pyhccl.comm - - def destroy_process_group(self, group=None): - if group in self.sub_groups: - subcomm = ctypes.c_void_p(group) - self.pyhccl.destroy_comm(subcomm) - del self.sub_groups[group] - return + pg: StatelessProcessGroup + pyhccl: PyHcclCommunicatorEx + sub_groups: dict[int, list[int]] + comm: hcclComm_t - self.pyhccl.destroy_comm() + host: str + port: int + rank: int + world_size: int + device: torch.device - self.pyhccl = None - self.pg = None + initialized: bool = False - def is_initialized(self) -> bool: - return self.pyhccl is not None - def all_gather_object(self, object_list: list[Any], obj: Any, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm +dist = DistributedHccl() - _common_all_gather_object(self.pyhccl, self._device, self._world_size, object_list, obj) - current_stream().synchronize() +def init_process_group( + host: str, + port: int, + rank: int, + world_size: int, + timeout: timedelta = timedelta(seconds=300), + **kwargs, +): + assert not dist.initialized, "already initialized" - if group: - self.pyhccl.comm = self._comm + dist.host = host + dist.port = port + dist.rank = rank + dist.world_size = world_size + dist.device = torch.device("npu", rank) - def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm + dist.pg = StatelessProcessGroup.create( + host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) + ) + dist.pyhccl = PyHcclCommunicatorEx(group=dist.pg, device=dist.device) + dist.comm = dist.pyhccl.comm + dist.initialized = True - out_tensor = self.pyhccl.all_reduce(tensor, op) - current_stream().synchronize() - tensor.copy_(out_tensor) +def destroy_process_group(group=None): + assert dist.initialized, "not initialized" - if group: - self.pyhccl.comm = self._comm + if group in dist.sub_groups: + subcomm = ctypes.c_void_p(group) + dist.pyhccl.destroy_comm(subcomm) + del dist.sub_groups[group] + return - def broadcast(self, tensor: torch.Tensor, src=None, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - assert src in self.sub_groups[group], "src rank not in group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm - # convert src rank id in default world to subcomm - src = self.sub_groups[group].index(src) + dist.pyhccl.destroy_comm() - self.pyhccl.broadcast(tensor, src) - current_stream().synchronize() + dist.pyhccl = None + dist.pg = None + dist.initialized = False - if group: - self.pyhccl.comm = self._comm +def is_initialized(dist) -> bool: + return dist.initialized - def barrier(self, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - subcomm = ctypes.c_void_p(group) - self.pyhccl.comm = subcomm +def all_gather_object(object_list: list[Any], obj: Any, group=None): + assert dist.initialized, "not initialized" - data = torch.zeros(1, device=self._rank) - self.pyhccl.all_reduce(data) - current_stream().synchronize() + if group: + assert group in dist.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + dist.pyhccl.comm = subcomm - 
if group: - self.pyhccl.comm = self._comm + _common_all_gather_object(dist.pyhccl, dist.device, dist.world_size, object_list, obj) + current_stream().synchronize() - def new_group(self, ranks): - # if ranks is None or [], using the world instead - if not ranks: - ranks = list(range(self._world_size)) + if group: + dist.pyhccl.comm = dist.comm - if self._rank not in ranks: - return +def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM, group=None): + assert dist.initialized, "not initialized" + + if group: + assert group in dist.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + dist.pyhccl.comm = subcomm + + out_tensor = dist.pyhccl.all_reduce(tensor, op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + if group: + dist.pyhccl.comm = dist.comm + +def broadcast(tensor: torch.Tensor, src=None, group=None): + assert dist.initialized, "not initialized" + + if group: + assert group in dist.sub_groups, "invalid sub_group" + assert src in dist.sub_groups[group], "src rank not in group" + subcomm = ctypes.c_void_p(group) + dist.pyhccl.comm = subcomm + # convert src rank id in default world to subcomm + src = dist.sub_groups[group].index(src) + dist.pyhccl.rank = dist.sub_groups[group].index(dist.rank) + + dist.pyhccl.broadcast(tensor, src) + current_stream().synchronize() + + if group: + dist.pyhccl.comm = dist.comm + dist.pyhccl.rank = dist.rank + +def barrier(group=None): + assert dist.initialized, "not initialized" + + if group: + assert group in dist.sub_groups, "invalid sub_group" + subcomm = ctypes.c_void_p(group) + dist.pyhccl.comm = subcomm + + data = torch.zeros(1, device=dist.rank) + dist.pyhccl.all_reduce(data) + current_stream().synchronize() + + if group: + dist.pyhccl.comm = dist.comm + +def new_group(ranks): + assert dist.initialized, "not initialized" + + # if ranks is None or [], using the world instead + if not ranks: + ranks = list(range(dist.world_size)) + + if dist.rank not in ranks: + return - subcomm = self.pyhccl.create_subcomm(ranks) - value = 0 - if subcomm: - value = subcomm.value - self.sub_groups[value] = ranks - return value + subcomm = dist.pyhccl.create_subcomm(ranks) + value = 0 + if subcomm: + value = subcomm.value + dist.sub_groups[value] = ranks + return value diff --git a/checkpoint_engine/distributed_nccl.py b/checkpoint_engine/distributed_nccl.py index 78b54c6..996bece 100644 --- a/checkpoint_engine/distributed_nccl.py +++ b/checkpoint_engine/distributed_nccl.py @@ -1,14 +1,10 @@ import ctypes import io -import logging -import os import pickle from datetime import timedelta -from enum import Enum from typing import Any, List, Optional import torch -import torch.distributed from torch.distributed import ReduceOp from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import ( @@ -153,111 +149,137 @@ def create_newcomm(self, ranks): class DistributedNccl: - def __init__(self): - self.pg = None - self.pynccl = None - self.sub_groups = {} - - def init_process_group( - self, - host: str, - port: int, - rank: int, - world_size: int, - timeout: timedelta = timedelta(seconds=300), - **kwargs, - ): - self._host = host - self._port = port - self._rank = rank - self._world_size = world_size - self._device = torch.device("cuda", rank) - - self.pg = StatelessProcessGroup.create( - host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) - ) - - self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self._device) - self._comm = self.pynccl.comm - 
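
For the NCCL path, new_group relies on ncclCommSplit: member ranks pass color 0, non-members pass NCCL_SPLIT_NOCOLOR (-1) and receive no communicator, and the key (the global rank here) only determines ordering inside the new communicator. A small illustrative sketch of that color/key computation, with hypothetical ranks:

NCCL_SPLIT_NOCOLOR = -1  # value create_newcomm uses for ranks outside the group

def split_args(global_rank: int, ranks: list[int]) -> tuple[int, int]:
    # Return the (color, key) pair a rank would pass to ncclCommSplit.
    # Ranks inside `ranks` share color 0 and keep their global rank as the
    # ordering key; everyone else opts out with NCCL_SPLIT_NOCOLOR.
    color = 0 if global_rank in ranks else NCCL_SPLIT_NOCOLOR
    return color, global_rank

# Example: world of 4 ranks, sub-group {1, 3}
assert split_args(1, [1, 3]) == (0, 1)
assert split_args(2, [1, 3]) == (NCCL_SPLIT_NOCOLOR, 2)
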
- def destroy_process_group(self, group=None): - if group in self.sub_groups: - newcomm = ctypes.c_void_p(group) - self.pynccl.destroy_comm(newcomm) - del self.sub_groups[group] - return - - self.pynccl.destroy_comm() - - self.pynccl = None - self.pg = None - - def is_initialized(self) -> bool: - return self.pynccl is not None - - def all_gather_object(self, object_list: list[Any], obj: Any, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - - _common_all_gather_object(self.pynccl, self._device, self._world_size, object_list, obj) - current_stream().synchronize() - - if group: - self.pynccl.comm = self._comm - - def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - - out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op) - current_stream().synchronize() - tensor.copy_(out_tensor) - - if group: - self.pynccl.comm = self._comm - - def broadcast(self, tensor: torch.Tensor, src=None, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - assert src in self.sub_groups[group], "src rank not in group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - # convert src rank id in default world to newcomm - src = self.sub_groups[group].index(src) - self.pynccl.rank = self.sub_groups[group].index(self._rank) - - self.pynccl.broadcast(tensor, src) - current_stream().synchronize() - - if group: - self.pynccl.comm = self._comm - self.pynccl.rank = self._rank - - def barrier(self, group=None): - if group: - assert group in self.sub_groups, "invalid sub_group" - newcomm = ctypes.c_void_p(group) - self.pynccl.comm = newcomm - - data = torch.zeros(1, device=self._rank) - self.pynccl.all_reduce(data) - current_stream().synchronize() - - if group: - self.pynccl.comm = self._comm - - def new_group(self, ranks): - # ranks is None or [] - if not ranks: - ranks = list(range(self._world_size)) - - newcomm = self.pynccl.create_newcomm(ranks) - value = 0 - if newcomm: - value = newcomm.value - self.sub_groups[value] = ranks - return value + pg: StatelessProcessGroup + pynccl: PyNcclCommunicatorEx + sub_groups: dict[int, list[int]] + comm: ncclComm_t + + host: str + port: int + rank: int + world_size: int + device: torch.device + + initialized: bool = False + + +dist = DistributedNccl() + +def init_process_group( + host: str, + port: int, + rank: int, + world_size: int, + timeout: timedelta = timedelta(seconds=300), + **kwargs, +): + assert not dist.initialized, "already initialized" + + dist.host = host + dist.port = port + dist.rank = rank + dist.world_size = world_size + dist.device = torch.device("cuda", rank) + + dist.pg = StatelessProcessGroup.create( + host, port, rank, world_size, store_timeout=int(timeout.total_seconds()) + ) + + dist.pynccl = PyNcclCommunicatorEx(group=dist.pg, device=dist.device) + dist.comm = dist.pynccl.comm + dist.initialized = True + +def destroy_process_group(group=None): + assert dist.initialized, "not initialized" + + if group in dist.sub_groups: + newcomm = ctypes.c_void_p(group) + dist.pynccl.destroy_comm(newcomm) + del dist.sub_groups[group] + return + + dist.pynccl.destroy_comm() + + dist.pynccl = None + dist.pg = None + dist.initialized = False + +def is_initialized(dist) -> bool: + return dist.initialized + +def all_gather_object(object_list: list[Any], obj: Any, group=None): + assert 
dist.initialized, "not initialized" + + if group: + assert group in dist.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + dist.pynccl.comm = newcomm + + _common_all_gather_object(dist.pynccl, dist.device, dist.world_size, object_list, obj) + current_stream().synchronize() + + if group: + dist.pynccl.comm = dist.comm + +def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM, group=None): + assert dist.initialized, "not initialized" + + if group: + assert group in dist.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + dist.pynccl.comm = newcomm + + out_tensor = dist.pynccl.all_reduce(in_tensor=tensor, op=op) + current_stream().synchronize() + tensor.copy_(out_tensor) + + if group: + dist.pynccl.comm = dist.comm + +def broadcast(tensor: torch.Tensor, src=None, group=None): + assert dist.initialized, "not initialized" + + if group: + assert group in dist.sub_groups, "invalid sub_group" + assert src in dist.sub_groups[group], "src rank not in group" + newcomm = ctypes.c_void_p(group) + dist.pynccl.comm = newcomm + # convert src rank id in default world to newcomm + src = dist.sub_groups[group].index(src) + dist.pynccl.rank = dist.sub_groups[group].index(dist.rank) + + dist.pynccl.broadcast(tensor, src) + current_stream().synchronize() + + if group: + dist.pynccl.comm = dist.comm + dist.pynccl.rank = dist.rank + +def barrier(group=None): + assert dist.initialized, "not initialized" + + if group: + assert group in dist.sub_groups, "invalid sub_group" + newcomm = ctypes.c_void_p(group) + dist.pynccl.comm = newcomm + + data = torch.zeros(1, device=dist.rank) + dist.pynccl.all_reduce(data) + current_stream().synchronize() + + if group: + dist.pynccl.comm = dist.comm + +def new_group(ranks): + assert dist.initialized, "not initialized" + + # ranks is None or [] + if not ranks: + ranks = list(range(dist.world_size)) + + newcomm = dist.pynccl.create_newcomm(ranks) + value = 0 + if newcomm: + value = newcomm.value + dist.sub_groups[value] = ranks + return value From f00d6d6b5f652c476d3b9972545adcd7355d049d Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Wed, 7 Jan 2026 09:35:43 +0800 Subject: [PATCH 09/14] fix bugs --- checkpoint_engine/distributed_hccl.py | 3 +-- checkpoint_engine/distributed_nccl.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/checkpoint_engine/distributed_hccl.py b/checkpoint_engine/distributed_hccl.py index b7304a1..9ee1588 100644 --- a/checkpoint_engine/distributed_hccl.py +++ b/checkpoint_engine/distributed_hccl.py @@ -1,4 +1,3 @@ -from sre_parse import State import ctypes from datetime import timedelta from typing import Any, List, Optional @@ -227,7 +226,7 @@ def destroy_process_group(group=None): dist.pg = None dist.initialized = False -def is_initialized(dist) -> bool: +def is_initialized() -> bool: return dist.initialized def all_gather_object(object_list: list[Any], obj: Any, group=None): diff --git a/checkpoint_engine/distributed_nccl.py b/checkpoint_engine/distributed_nccl.py index 996bece..e6507d3 100644 --- a/checkpoint_engine/distributed_nccl.py +++ b/checkpoint_engine/distributed_nccl.py @@ -204,7 +204,7 @@ def destroy_process_group(group=None): dist.pg = None dist.initialized = False -def is_initialized(dist) -> bool: +def is_initialized() -> bool: return dist.initialized def all_gather_object(object_list: list[Any], obj: Any, group=None): From 68ef1f67154a2190424e2c9f2384a0f064b1fdee Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Wed, 7 Jan 2026 11:03:36 +0800 Subject: [PATCH 
10/14] fix bugs --- checkpoint_engine/distributed_hccl.py | 18 ++++++++++++++++-- checkpoint_engine/distributed_nccl.py | 18 +++++++++++++----- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/checkpoint_engine/distributed_hccl.py b/checkpoint_engine/distributed_hccl.py index 9ee1588..c4d8935 100644 --- a/checkpoint_engine/distributed_hccl.py +++ b/checkpoint_engine/distributed_hccl.py @@ -4,7 +4,6 @@ import torch from torch.distributed import ReduceOp - from vllm.distributed.utils import StatelessProcessGroup from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( @@ -24,6 +23,7 @@ from .distributed_nccl import _common_all_gather_object + class HcclCommConfig(ctypes.Structure): _fields_ = [ ("size", ctypes.c_size_t), @@ -45,6 +45,7 @@ class HcclCommConfig(ctypes.Structure): ("acl_graph_zero_copy_enable", ctypes.c_uint8), ] + orig_exported_functions = HCCLLibrary.exported_functions extended_functions = [ # HcclResult HcclAllGather( @@ -82,11 +83,13 @@ class HcclCommConfig(ctypes.Structure): ), ] + def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream): self.HCCL_CHECK( self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream) ) + def hccl_create_subcomm_config( self, comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config ): @@ -104,11 +107,13 @@ def hccl_create_subcomm_config( ) return subcomm + # extend HCCLLibrary HCCLLibrary.exported_functions = orig_exported_functions + extended_functions HCCLLibrary.hcclAllGather = hccl_all_gather HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config + class PyHcclCommunicatorEx(PyHcclCommunicator): def __init__(self, group, device): super().__init__(group, device) @@ -171,10 +176,11 @@ def create_subcomm(self, ranks): self.subcomm_id += 1 return subcomm + class DistributedHccl: pg: StatelessProcessGroup pyhccl: PyHcclCommunicatorEx - sub_groups: dict[int, list[int]] + sub_groups: dict[int, list[int]] = {} comm: hcclComm_t host: str @@ -188,6 +194,7 @@ class DistributedHccl: dist = DistributedHccl() + def init_process_group( host: str, port: int, @@ -211,6 +218,7 @@ def init_process_group( dist.comm = dist.pyhccl.comm dist.initialized = True + def destroy_process_group(group=None): assert dist.initialized, "not initialized" @@ -226,9 +234,11 @@ def destroy_process_group(group=None): dist.pg = None dist.initialized = False + def is_initialized() -> bool: return dist.initialized + def all_gather_object(object_list: list[Any], obj: Any, group=None): assert dist.initialized, "not initialized" @@ -243,6 +253,7 @@ def all_gather_object(object_list: list[Any], obj: Any, group=None): if group: dist.pyhccl.comm = dist.comm + def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM, group=None): assert dist.initialized, "not initialized" @@ -258,6 +269,7 @@ def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM, group=None): if group: dist.pyhccl.comm = dist.comm + def broadcast(tensor: torch.Tensor, src=None, group=None): assert dist.initialized, "not initialized" @@ -277,6 +289,7 @@ def broadcast(tensor: torch.Tensor, src=None, group=None): dist.pyhccl.comm = dist.comm dist.pyhccl.rank = dist.rank + def barrier(group=None): assert dist.initialized, "not initialized" @@ -292,6 +305,7 @@ def barrier(group=None): if group: dist.pyhccl.comm = dist.comm + def new_group(ranks): assert dist.initialized, "not initialized" diff --git a/checkpoint_engine/distributed_nccl.py 
b/checkpoint_engine/distributed_nccl.py index e6507d3..5f7f919 100644 --- a/checkpoint_engine/distributed_nccl.py +++ b/checkpoint_engine/distributed_nccl.py @@ -71,6 +71,7 @@ def _common_all_gather_object(comm, device, world_size, object_list, object): tensor_size = object_size_list[i] object_list[i] = _tensor_to_object(tensor, tensor_size) + class ncclConfig_t(ctypes.Structure): _fields_ = [ ("size", ctypes.c_size_t), @@ -94,6 +95,7 @@ class ncclConfig_t(ctypes.Structure): ("numRmaCtx", ctypes.c_int), ] + nccl_orig_exported_functions = NCCLLibrary.exported_functions nccl_extended_functions = [ # ncclResult_t ncclCommSplit( @@ -121,9 +123,7 @@ def nccl_comm_split( ) -> ncclComm_t: newcomm = ncclComm_t() - self.NCCL_CHECK( - self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None) - ) + self.NCCL_CHECK(self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None)) return newcomm @@ -143,7 +143,7 @@ def create_newcomm(self, ranks): if self.rank in ranks: color = 0 else: - color = -1 # NCCL_SPLIT_NOCOLOR + color = -1 # NCCL_SPLIT_NOCOLOR newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank) return newcomm @@ -151,7 +151,7 @@ def create_newcomm(self, ranks): class DistributedNccl: pg: StatelessProcessGroup pynccl: PyNcclCommunicatorEx - sub_groups: dict[int, list[int]] + sub_groups: dict[int, list[int]] = {} comm: ncclComm_t host: str @@ -165,6 +165,7 @@ class DistributedNccl: dist = DistributedNccl() + def init_process_group( host: str, port: int, @@ -189,6 +190,7 @@ def init_process_group( dist.comm = dist.pynccl.comm dist.initialized = True + def destroy_process_group(group=None): assert dist.initialized, "not initialized" @@ -204,9 +206,11 @@ def destroy_process_group(group=None): dist.pg = None dist.initialized = False + def is_initialized() -> bool: return dist.initialized + def all_gather_object(object_list: list[Any], obj: Any, group=None): assert dist.initialized, "not initialized" @@ -221,6 +225,7 @@ def all_gather_object(object_list: list[Any], obj: Any, group=None): if group: dist.pynccl.comm = dist.comm + def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM, group=None): assert dist.initialized, "not initialized" @@ -236,6 +241,7 @@ def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM, group=None): if group: dist.pynccl.comm = dist.comm + def broadcast(tensor: torch.Tensor, src=None, group=None): assert dist.initialized, "not initialized" @@ -255,6 +261,7 @@ def broadcast(tensor: torch.Tensor, src=None, group=None): dist.pynccl.comm = dist.comm dist.pynccl.rank = dist.rank + def barrier(group=None): assert dist.initialized, "not initialized" @@ -270,6 +277,7 @@ def barrier(group=None): if group: dist.pynccl.comm = dist.comm + def new_group(ranks): assert dist.initialized, "not initialized" From e3816789847115b61bdd2a7311eca7129607f760 Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Wed, 7 Jan 2026 16:48:38 +0800 Subject: [PATCH 11/14] modify ps.py --- checkpoint_engine/ps.py | 85 ++++++++++++++++++++++++----------------- examples/update.py | 9 ++++- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 095b853..38f18b3 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -24,7 +24,6 @@ from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid from checkpoint_engine.p2p_store import P2PStore from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint -from checkpoint_engine.distributed import 
DistributedNccl, DistributedHccl if TYPE_CHECKING: @@ -176,6 +175,7 @@ def __init__( auto_pg: bool = True, gpu_count: int | None = None, mem_fraction: float | None = None, + device_type: str | None = None, ): """ Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set. @@ -196,12 +196,16 @@ def __init__( self._local_rdma_devices: dict[str, set[int]] = defaultdict(set) self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set) self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9")) - if self.device_manager.backend == "nccl": - self.dist = DistributedNccl() - elif self.device_manager.backend == "hccl": - self.dist = DistributedHccl() + if device_type == "npu" and self.device_manager.device_type == "npu": + import checkpoint_engine.distributed_hccl + dist = checkpoint_engine.distributed_hccl + self._device_type = "npu" + elif device_type == "cuda" and self.device_manager.device_type == "cuda": + import checkpoint_engine.distributed_nccl + dist = checkpoint_engine.distributed_nccl + self._device_type = "cuda" else: - self.dist = torch.distributed + self._device_type = "torch" assert self._rank is not None and self._rank >= 0, self._rank assert self._world_size and self._world_size > 0, self._world_size @@ -422,9 +426,9 @@ def gather_metas(self, checkpoint_name: str): This function should be called before update and init a new value to `self._current_global_parameter_metas`, which can be exported by using `self.get_metas` function. """ - if self._auto_pg and not self.dist.is_initialized(): + if self._auto_pg and not dist.is_initialized(): self.init_process_group() - assert self.dist.is_initialized(), "process group is not initialized" + assert dist.is_initialized(), "process group is not initialized" metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore try: memory_pool = self._get_memory_pool(checkpoint_name) @@ -445,7 +449,7 @@ def gather_metas(self, checkpoint_name: str): rdma_device=self._rdma_device or "", ) - self.dist.all_gather_object(metas_lst, metas) + dist.all_gather_object(metas_lst, metas) self._current_global_parameter_metas = {} @@ -497,20 +501,29 @@ def init_process_group( """ master_addr = master_addr or os.getenv("MASTER_ADDR") assert master_addr, "master_addr is required" - store = dist.TCPStore( - master_addr, - _get_master_port(master_port), - self._world_size, - timeout=timeout, - is_master=self._rank == 0, - ) - self.dist.init_process_group( - backend=self.device_manager.backend, - world_size=self._world_size, - rank=self._rank, - timeout=timeout, - store=store, - ) + if self._device_type == "torch": + store = torch.distributed.TCPStore( + master_addr, + _get_master_port(master_port), + self._world_size, + timeout=timeout, + is_master=self._rank == 0, + ) + torch.distributed.init_process_group( + backend=self.device_manager.backend, + world_size=self._world_size, + rank=self._rank, + timeout=timeout, + store=store, + ) + else: + dist.init_process_group( + host=master_addr, + port=_get_master_port(master_port), + rank=self._rank, + world_size=self._world_size, + timeout=timeout, + ) logger.info(f"[rank{self._rank}] init process group successfully.") def store_based_barrier( @@ -526,7 +539,7 @@ def store_based_barrier( Args: store: The TCPStore instance to use for synchronization. 
""" - dist.distributed_c10d._store_based_barrier( + torch.distributed.distributed_c10d._store_based_barrier( rank=self._rank, store=store, group_name="parameter_server_barrier", @@ -566,14 +579,14 @@ def update( master_addr = os.getenv("MASTER_ADDR") or master_addr assert master_addr, "master_addr is required" if self._auto_pg: - if not self.dist.is_initialized(): + if not dist.is_initialized(): self.init_process_group( timeout=timeout, master_addr=master_addr, master_port=master_port ) # if ranks is None or [], it will use fully broadcast to update to all ranks - ranks_group = self.dist.new_group(ranks) if ranks else None + ranks_group = dist.new_group(ranks) if ranks else None self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks) - self.dist.barrier() + dist.barrier() except Exception as e: logger.exception( f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}" @@ -581,9 +594,9 @@ def update( raise finally: if ranks_group: - self.dist.destroy_process_group(ranks_group) - if self._auto_pg and self.dist.is_initialized(): - self.dist.destroy_process_group() + dist.destroy_process_group(ranks_group) + if self._auto_pg and dist.is_initialized(): + dist.destroy_process_group() self.device_manager.device_module.empty_cache() logger.info( f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. " @@ -619,7 +632,7 @@ def _detect_bucket_size( dtype=torch.int64, device=self.device_manager.device_type, ) - self.dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group) + dist.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN, group=ranks_group) tensor = tensor.cpu() free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item() max_tensor_bytes = 0 @@ -725,7 +738,7 @@ def _update_per_bucket( ranks: list[int] | None = None, ): assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty" - assert self.dist.is_initialized(), "process group is not initialized" + assert dist.is_initialized(), "process group is not initialized" # if both ranks is None or [], it will use fully broadcast to update to all ranks if not ranks: @@ -744,7 +757,7 @@ def _update_per_bucket( if not need_update: return # first execute a barrier to avoid subsequent device oom - self.dist.barrier(group=ranks_group) + dist.barrier(group=ranks_group) bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group) buckets = _gen_h2d_buckets( @@ -822,7 +835,7 @@ def _update_per_bucket( self._copy_to_buffer(checkpoint_name, bucket, buffer_b) else: buffer_b.data.copy_(h2d_buffer[: bucket.size]) - self.dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group) + dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group) resp = socket.recv() if resp != b"": msg = resp.decode("utf-8") @@ -830,7 +843,7 @@ def _update_per_bucket( f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}" ) ret_code.fill_(1) - self.dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group) + dist.all_reduce(ret_code, op=torch.distributed.ReduceOp.SUM, group=ranks_group) self.device_manager.device_module.synchronize() if ret_code.item() != 0: # quit early if any rank failed @@ -844,7 +857,7 @@ def _update_per_bucket( socket.recv() finally: req_thread.join() - self.dist.barrier(group=ranks_group) + dist.barrier(group=ranks_group) socket.close() if ranks and h2d_buffer is not None: self._p2p_store.unregister_named_tensors([h2d_buffer_name]) diff --git 
a/examples/update.py b/examples/update.py index 51cb189..cfbd774 100644 --- a/examples/update.py +++ b/examples/update.py @@ -158,11 +158,18 @@ def join( parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0") parser.add_argument("--update-method", type=str, default="broadcast") parser.add_argument("--uds", type=str, default=None) + parser.add_argument("--device_type", type=str, default=None) args = parser.parse_args() rank = int(os.getenv("RANK")) world_size = int(os.getenv("WORLD_SIZE")) + + if args.device_type == "npu": + import checkpoint_engine.distributed_hccl as dist + elif args.device_type == "cuda": + import checkpoint_engine.distributed_nccl as dist + req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds) - ps = ParameterServer(auto_pg=True) + ps = ParameterServer(auto_pg=True, device_type=args.device_type) if args.load_metas_file: join( ps, From f2c0ae81b09c79cfbac340bcfcfc8fff03a1f98f Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Wed, 7 Jan 2026 17:16:06 +0800 Subject: [PATCH 12/14] fix import error --- examples/update.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/update.py b/examples/update.py index cfbd774..95f7e47 100644 --- a/examples/update.py +++ b/examples/update.py @@ -14,7 +14,8 @@ from loguru import logger from safetensors import safe_open -from checkpoint_engine.ps import ParameterServer, request_inference_to_update +from checkpoint_engine.ps import ParameterServer +from checkpoint_engine.api import request_inference_to_update @contextmanager From 8218a3b49c858cb2a10299a3bfb49e842237613e Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Wed, 7 Jan 2026 17:51:06 +0800 Subject: [PATCH 13/14] add missing global statement --- checkpoint_engine/ps.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 38f18b3..45e8a66 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -196,13 +196,12 @@ def __init__( self._local_rdma_devices: dict[str, set[int]] = defaultdict(set) self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set) self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9")) + global dist if device_type == "npu" and self.device_manager.device_type == "npu": - import checkpoint_engine.distributed_hccl - dist = checkpoint_engine.distributed_hccl + import checkpoint_engine.distributed_hccl as dist self._device_type = "npu" elif device_type == "cuda" and self.device_manager.device_type == "cuda": - import checkpoint_engine.distributed_nccl - dist = checkpoint_engine.distributed_nccl + import checkpoint_engine.distributed_nccl as dist self._device_type = "cuda" else: self._device_type = "torch" From c3badb409e3e2a06b20412bcb6acac8f179b5387 Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Thu, 8 Jan 2026 21:36:46 +0800 Subject: [PATCH 14/14] use dist.device instead of dist.rank --- checkpoint_engine/distributed_hccl.py | 2 +- checkpoint_engine/distributed_nccl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/checkpoint_engine/distributed_hccl.py b/checkpoint_engine/distributed_hccl.py index c4d8935..eca56c8 100644 --- a/checkpoint_engine/distributed_hccl.py +++ b/checkpoint_engine/distributed_hccl.py @@ -298,7 +298,7 @@ def barrier(group=None): subcomm = ctypes.c_void_p(group) dist.pyhccl.comm = subcomm - data = torch.zeros(1, device=dist.rank) + data = torch.zeros(1, device=dist.device) 
dist.pyhccl.all_reduce(data) current_stream().synchronize() diff --git a/checkpoint_engine/distributed_nccl.py b/checkpoint_engine/distributed_nccl.py index 5f7f919..cd7f103 100644 --- a/checkpoint_engine/distributed_nccl.py +++ b/checkpoint_engine/distributed_nccl.py @@ -270,7 +270,7 @@ def barrier(group=None): newcomm = ctypes.c_void_p(group) dist.pynccl.comm = newcomm - data = torch.zeros(1, device=dist.rank) + data = torch.zeros(1, device=dist.device) dist.pynccl.all_reduce(data) current_stream().synchronize()
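
For reference, a minimal usage sketch of the module-level API introduced by this series (checkpoint_engine/distributed_nccl.py and distributed_hccl.py). Keyword names mirror the call sites added to ps.py in PATCH 11; the device selection, sub-group rank list, tensor shape, and timeout are illustrative assumptions (and at least two launched ranks are assumed), not part of the patches themselves.

    import os
    from datetime import timedelta

    import torch
    import checkpoint_engine.distributed_nccl as dist  # use distributed_hccl as dist on Ascend NPU

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Illustrative device pick for the CUDA case; the NPU path would bind an NPU device instead.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Explicit host/port instead of an init_method URL, matching the ps.py call in PATCH 11.
    dist.init_process_group(
        host=os.environ["MASTER_ADDR"],
        port=int(os.environ["MASTER_PORT"]),
        rank=rank,
        world_size=world_size,
        timeout=timedelta(seconds=300),
    )

    payload = torch.ones(4, device="cuda")

    # new_group() is collective over the default communicator and takes global rank ids;
    # non-member ranks get a falsy handle back, which is how ps.py guards its sub-group use.
    sub_ranks = [0, 1]
    group = dist.new_group(sub_ranks)
    if rank in sub_ranks:
        # src is a global rank id; broadcast() remaps it onto the sub-communicator internally.
        dist.broadcast(payload, src=sub_ranks[0], group=group)

    # Collectives on the default group; op values reuse torch.distributed.ReduceOp.
    dist.all_reduce(payload, op=torch.distributed.ReduceOp.SUM)
    dist.barrier()

    if group:
        dist.destroy_process_group(group)
    dist.destroy_process_group()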