Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,27 @@ def _ibv_get_device_list() -> list[str]:
return devices


def _get_rdma_devices() -> list[str]:
    """
    Return the RDMA device names this process may use.

    Resolution order:
      1. ``PS_P2P_STORE_RDMA_DEVICES``, split on commas, when it is set and non-empty.
      2. The result of ``_parse_NCCL_IB_HCA`` applied to ``NCCL_IB_HCA`` (or ``""`` when
         unset) and the list from ``_ibv_get_device_list()``.
      3. The full list from ``_ibv_get_device_list()`` when step 2 yields nothing.
    """
    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
    if devices_str:
        return devices_str.split(",")
    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
    hca = os.getenv("NCCL_IB_HCA", None)
    return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()


def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    """
    Pick the RDMA device for ``local_rank`` using contiguous block sharing.

    Only the first ``k`` devices are used, where ``k`` is the largest value such that
    ``k <= len(devices)``, ``k <= gpu_count`` and ``gpu_count % k == 0``. Each of those
    devices then serves ``gpu_count // k`` consecutive local ranks, e.g. with 8 GPUs and
    2 devices, ranks 0-3 share ``devices[0]`` and ranks 4-7 share ``devices[1]``.

    Args:
        local_rank: Rank of this GPU on the node, in ``[0, gpu_count)``.
        gpu_count: Number of GPUs on the node; must be positive.
        devices: Candidate RDMA device names, in preference order.

    Returns:
        The device name assigned to ``local_rank``.

    Raises:
        RuntimeError: If ``devices`` is empty.
        ValueError: If ``gpu_count`` is not positive or ``local_rank`` is out of range.
    """
    if not devices:
        raise RuntimeError("No RDMA devices found")
    if gpu_count <= 0:
        # Without this guard the block size below would be 0 and the final
        # division would fail with an unhelpful ZeroDivisionError.
        raise ValueError(f"gpu_count must be positive, got {gpu_count}")
    if not 0 <= local_rank < gpu_count:
        # A negative rank would otherwise wrap around via negative indexing and
        # silently select the wrong device.
        raise ValueError(f"local_rank must be in [0, {gpu_count}), got {local_rank}")

    # Find the largest usable device count k that evenly divides gpu_count.
    # k == 1 always divides, so in the worst case every GPU shares devices[0].
    k = min(len(devices), gpu_count)
    while gpu_count % k != 0:
        k -= 1

    ranks_per_device = gpu_count // k
    return devices[local_rank // ranks_per_device]

def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
"""
Expand Down