diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index ef637c7..8b080cd 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -305,13 +305,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> """ if not devices: raise RuntimeError("no rdma devices found") - assert len(devices) <= gpu_count, ( - f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}" - ) - assert gpu_count % len(devices) == 0, ( - f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}" - ) - return devices[local_rank // (gpu_count // len(devices))] + try: + assert len(devices) <= gpu_count, ( + f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}" + ) + assert gpu_count % len(devices) == 0, ( + f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}" + ) + return devices[local_rank // (gpu_count // len(devices))] + except AssertionError: + logger.error( + "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices." + "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices." + "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'." + ) + raise def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: