|
2 | 2 |
|
3 | 3 | import argparse |
4 | 4 | import concurrent.futures |
| 5 | +import ctypes |
5 | 6 | import os |
6 | 7 | import pickle |
7 | 8 | import random |
@@ -269,47 +270,48 @@ def _get_ip() -> str: |
269 | 270 | return socket.gethostbyname(socket.gethostname()) |
270 | 271 |
|
271 | 272 |
|
| 273 | +def _ibv_get_device_list() -> list[str]: |
| 274 | + lib = ctypes.CDLL("libibverbs.so.1") |
| 275 | + lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices |
| 276 | + lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device ** |
| 277 | + |
| 278 | + lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)] |
| 279 | + lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device * |
| 280 | + lib.ibv_get_device_name.restype = ctypes.c_char_p # const char * |
| 281 | + |
| 282 | + num = ctypes.c_int() |
| 283 | + dev_array = lib.ibv_get_device_list(ctypes.byref(num)) |
| 284 | + if not dev_array or num.value <= 0: |
| 285 | + return [] |
| 286 | + |
| 287 | + devices = [] |
| 288 | + for i in range(num.value): |
| 289 | + dev_ptr = dev_array[i] # struct ibv_device * |
| 290 | + name = lib.ibv_get_device_name(dev_ptr) # const char * |
| 291 | + devices.append(name.decode()) |
| 292 | + lib.ibv_free_device_list(dev_array) |
| 293 | + return devices |
| 294 | + |
| 295 | + |
272 | 296 | def _get_rdma_devices() -> list[str]: |
273 | 297 | """ |
274 | | - use script like below to get RDMA devices, if NCCL_IB_HCA has multiple values, just return |
275 | | - ```bash |
276 | | - pushd /sys/class/infiniband/ > /dev/null; |
277 | | - for i in mlx5_*; do cat "$i"/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo "$i" ; done; |
278 | | - popd > /dev/null; |
279 | | - ``` |
| 298 | + use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return |
280 | 299 | """ |
281 | 300 | devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES") |
282 | 301 | if devices_str: |
283 | 302 | return devices_str.split(",") |
284 | 303 | # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices |
285 | 304 | hca = os.getenv("NCCL_IB_HCA", None) |
286 | 305 | if hca: |
287 | | - l = hca.split(",") # noqa: E741 |
288 | | - if len(l) > 1: |
| 306 | + hca_list = hca.split(",") |
| 307 | + if len(hca_list) > 1: |
289 | 308 | # if NCCL_IB_HCA has multiple values, just return |
290 | | - return l |
| 309 | + return hca_list |
291 | 310 | else: |
292 | | - hca = l[0] |
293 | | - basepath = "/sys/class/infiniband/" |
294 | | - port_path = "ports/1/gid_attrs/types" |
295 | | - devices = [] |
296 | | - for device in sorted(os.listdir(basepath)): |
297 | | - if hca is not None and hca not in device: |
298 | | - continue |
299 | | - path = os.path.join(basepath, device, port_path) |
300 | | - if not os.path.exists(path) or not os.path.isdir(path): |
301 | | - continue |
302 | | - for port in os.listdir(path): |
303 | | - try: |
304 | | - with open(os.path.join(path, port)) as f: |
305 | | - content = f.read() |
306 | | - if "v" in content: |
307 | | - print(f"found rdma device {device} in port {port}: {content.strip()}") |
308 | | - devices.append(device) |
309 | | - break |
310 | | - except Exception: # noqa: BLE001,S110 |
311 | | - pass |
312 | | - return devices |
| 311 | + hca = hca_list[0] |
| 312 | + return [ |
| 313 | + device for device in sorted(_ibv_get_device_list()) if hca is not None and hca in device |
| 314 | + ] |
313 | 315 |
|
314 | 316 |
|
315 | 317 | def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str: |
|
0 commit comments