Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
"""
if not devices:
raise RuntimeError("no rdma devices found")
assert len(devices) <= gpu_count, (
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
)
assert gpu_count % len(devices) == 0, (
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
)
return devices[local_rank // (gpu_count // len(devices))]
try:
assert len(devices) <= gpu_count, (
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
)
assert gpu_count % len(devices) == 0, (
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
)
return devices[local_rank // (gpu_count // len(devices))]
except AssertionError:
logger.error(
"Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
"The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
"The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
)
raise


def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
Expand Down