Skip to content

Commit 109efc0

Browse files
feat: use ibv_get_device_list to get rdma devices instead of getting from file (#19)
1 parent 3def1a2 commit 109efc0

File tree

1 file changed

+32
-30
lines changed

1 file changed

+32
-30
lines changed

checkpoint_engine/ps.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import concurrent.futures
5+
import ctypes
56
import os
67
import pickle
78
import random
@@ -269,47 +270,48 @@ def _get_ip() -> str:
269270
return socket.gethostbyname(socket.gethostname())
270271

271272

273+
def _ibv_get_device_list() -> list[str]:
    """
    list the names of the RDMA devices reported by libibverbs

    Loads ``libibverbs.so.1`` through ctypes and calls ``ibv_get_device_list`` /
    ``ibv_get_device_name`` directly.

    Returns:
        Device names in the order libibverbs reports them; an empty list when
        libibverbs reports no device.

    Raises:
        OSError: if ``libibverbs.so.1`` cannot be loaded.
    """
    lib = ctypes.CDLL("libibverbs.so.1")
    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **

    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    lib.ibv_free_device_list.restype = None  # void
    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *

    num = ctypes.c_int()
    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
    if not dev_array:
        # NULL pointer: nothing was allocated, so there is nothing to free
        return []

    # the array is owned by libibverbs; always hand it back, even when it is
    # empty or when reading a name fails
    try:
        devices = []
        for i in range(num.value):
            dev_ptr = dev_array[i]  # struct ibv_device *
            name = lib.ibv_get_device_name(dev_ptr)  # const char *, None when NULL
            if name is not None:
                devices.append(name.decode())
        return devices
    finally:
        lib.ibv_free_device_list(dev_array)
294+
295+
272296
def _get_rdma_devices() -> list[str]:
    """
    use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return

    Resolution order:
      1. PS_P2P_STORE_RDMA_DEVICES: if set, its comma separated values are returned as is.
      2. NCCL_IB_HCA with more than one comma separated value: the values are returned as is.
      3. NCCL_IB_HCA with a single value: the sorted devices reported by libibverbs whose
         name contains that value.
      4. neither variable set: all devices reported by libibverbs, sorted by name.

    NOTE(review): the NCCL_IB_HCA value is used as a plain substring filter; NCCL's `^` / `=`
    prefixes and `:port` suffixes are not interpreted here.
    """
    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
    if devices_str:
        return devices_str.split(",")
    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
    hca = os.getenv("NCCL_IB_HCA", None)
    if hca:
        hca_list = hca.split(",")
        if len(hca_list) > 1:
            # if NCCL_IB_HCA has multiple values, just return
            return hca_list
        hca = hca_list[0]
    # no filter (NCCL_IB_HCA unset) must keep every device, not drop them all
    return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
313315

314316

315317
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:

0 commit comments

Comments
 (0)