Skip to content

Commit bbc83db

Browse files
authored
A more reasonable way to obtain RDMA devices (#36)
* feat: NCCLIBHCAParser class added, supporting exact match, exclude, and port specifications for RDMA devices. https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8
1 parent 2ef05a4 commit bbc83db

File tree

4 files changed

+302
-8
lines changed

4 files changed

+302
-8
lines changed

.github/workflows/cpu-tests.yml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
name: CPU Tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
types: [opened, synchronize, reopened]
8+
9+
10+
permissions:
11+
contents: read
12+
13+
jobs:
14+
build:
15+
runs-on: ubuntu-latest
16+
steps:
17+
- name: Checkout code
18+
uses: actions/checkout@v4
19+
- name: Set up Python
20+
uses: actions/setup-python@v3
21+
with:
22+
python-version: "3.10"
23+
- name: Install dependencies
24+
run: |
25+
python -m pip install --upgrade pip
26+
pip install pytest
27+
pip install .[p2p]
28+
- name: Do CPU tests with pytest
29+
run: |
30+
pytest -v -m "not gpu" tests/

checkpoint_engine/ps.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,7 @@ def _get_rdma_devices() -> list[str]:
303303
return devices_str.split(",")
304304
# if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
305305
hca = os.getenv("NCCL_IB_HCA", None)
306-
if hca:
307-
hca_list = hca.split(",")
308-
if len(hca_list) > 1:
309-
# if NCCL_IB_HCA has multiple values, just return
310-
return hca_list
311-
else:
312-
hca = hca_list[0]
313-
return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
306+
return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
314307

315308

316309
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
@@ -328,6 +321,75 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
328321
return devices[local_rank // (gpu_count // len(devices))]
329322

330323

324+
def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
325+
"""
326+
The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
327+
The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
328+
329+
The list is comma-separated; port numbers are NOT supported yet.
330+
An optional prefix '^' indicates the list is an exclude list.
331+
A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
332+
Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
333+
334+
Examples:
335+
- `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
336+
- `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
337+
- `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
338+
- `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
339+
"""
340+
max_hcas = 32
341+
if not value or value.strip() == "":
342+
return available_devices[:max_hcas]
343+
344+
value = value.strip()
345+
result = []
346+
is_exclude = value.startswith("^")
347+
if is_exclude:
348+
value = value.removeprefix("^")
349+
is_exact_match = value.startswith("=")
350+
if is_exact_match:
351+
value = value.removeprefix("=")
352+
353+
device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
354+
355+
result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
356+
if is_exclude:
357+
result = [dev for dev in available_devices if dev not in result]
358+
if len(result) > max_hcas:
359+
result = result[:max_hcas]
360+
361+
logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
362+
363+
return result
364+
365+
366+
def _resolve_device_specs(
367+
device_specs: list[str], is_exact_match: bool, available_devices: list[str]
368+
) -> list[str]:
369+
devices = set()
370+
for spec in device_specs:
371+
parts = spec.split(":", 1)
372+
device_name = parts[0].strip()
373+
# HACK: mooncake transfer engine does not support port specification yet, so we ignore it
374+
# port = parts[1].strip() if len(parts) > 1 else None
375+
base_devices = (
376+
[device_name]
377+
if device_name in available_devices
378+
else []
379+
if is_exact_match
380+
else [dev for dev in available_devices if dev.startswith(device_name)]
381+
)
382+
383+
if not base_devices:
384+
logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
385+
continue
386+
387+
for base_dev in base_devices:
388+
devices.add(base_dev)
389+
390+
return sorted(devices)
391+
392+
331393
def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
332394
class TPMeta(BaseModel):
333395
concat_dim: int

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,8 @@ inline-quotes = "double"
158158

159159
[tool.ruff.lint.flake8-tidy-imports]
160160
ban-relative-imports = "all"
161+
162+
[tool.pytest.ini_options]
163+
markers = [
164+
"gpu: marks tests as GPU test (deselect with '-m \"not gpu\"')",
165+
]

tests/test_rdma_parser.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import os
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
from checkpoint_engine.ps import (
7+
_get_my_rdma_device,
8+
_get_rdma_devices,
9+
_ibv_get_device_list,
10+
_parse_NCCL_IB_HCA,
11+
)
12+
13+
14+
@pytest.fixture
15+
def mock_available_devices() -> list[str]:
16+
"""Provide mock available device list"""
17+
return ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]
18+
19+
20+
def test_detect_ibv_list():
21+
"""Test detection of _ibv_get_device_list function"""
22+
# Skip this test if no real infiniband devices exist
23+
if not os.path.exists("/sys/class/infiniband"):
24+
pytest.skip("No infiniband devices found on system")
25+
26+
real_ibv_list = sorted(os.listdir("/sys/class/infiniband"))
27+
if real_ibv_list:
28+
devices = _ibv_get_device_list()
29+
assert isinstance(devices, list)
30+
31+
32+
def test_parse_max_hcas_limit():
33+
"""Test maximum HCA quantity limit"""
34+
# Create mock data with more than 32 devices
35+
many_devices = [f"device_{i}" for i in range(50)]
36+
result = _parse_NCCL_IB_HCA("", many_devices)
37+
assert len(result) == 32
38+
assert result == many_devices[:32]
39+
40+
41+
def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]):
42+
"""Test _get_rdma_devices with no environment variables"""
43+
with (
44+
patch.dict(os.environ, clear=True),
45+
patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices),
46+
):
47+
devices = _get_rdma_devices()
48+
assert sorted(devices) == sorted(mock_available_devices)
49+
50+
51+
@pytest.mark.parametrize(
52+
"input_value,expected",
53+
[
54+
pytest.param("", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="empty string"),
55+
pytest.param(" \t\n ", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="whitespace"),
56+
pytest.param("None", [], id="None string"),
57+
pytest.param("^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret"),
58+
pytest.param("^=", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret-equals"),
59+
pytest.param("=^", [], id="equals-caret"),
60+
pytest.param("^^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="double-caret"),
61+
pytest.param("=", [], id="equals"),
62+
pytest.param("==", [], id="double-equals"),
63+
],
64+
)
65+
def test_parse_basic_cases(
66+
input_value: str, expected: list[str], mock_available_devices: list[str]
67+
):
68+
"""Test basic parsing cases: empty string, whitespace, None"""
69+
result = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
70+
assert result == expected
71+
72+
73+
@pytest.mark.parametrize(
74+
"input_value,expected",
75+
[
76+
# prefix
77+
("mlx5_0", ["mlx5_0"]),
78+
("mlx5", ["mlx5_0", "mlx5_1"]),
79+
# exact match
80+
("=mlx5_0", ["mlx5_0"]),
81+
("=mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
82+
# ignore ports, whitespace and duplicated commas
83+
("mlx5_0:1,mlx5_1:2", ["mlx5_0", "mlx5_1"]),
84+
("mlx5_0:1,mlx5_1", ["mlx5_0", "mlx5_1"]),
85+
(" mlx5_0 , mlx5_1 ", ["mlx5_0", "mlx5_1"]),
86+
("mlx5_0,,mlx5_1", ["mlx5_0", "mlx5_1"]),
87+
# exclusion
88+
("^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]),
89+
("^mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]),
90+
("^mlx5", ["mlx4_0", "mlx4_1"]),
91+
("^=mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]),
92+
("^=mlx4", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
93+
],
94+
)
95+
def test_parse_various_patterns(
96+
input_value: str, expected: list[str], mock_available_devices: list[str]
97+
):
98+
"""Test various parsing patterns"""
99+
result = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
100+
assert result == expected
101+
102+
103+
@pytest.mark.parametrize(
104+
"input_value,expected_result,expected_warning",
105+
[
106+
("=mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=True."),
107+
("mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=False."),
108+
(
109+
"^mlx5_100",
110+
["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"],
111+
"No RDMA device match device_name='mlx5_100' where is_exact_match=False.",
112+
),
113+
("mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=False."),
114+
("=mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=True."),
115+
],
116+
)
117+
def test_parse_exact_match_with_nonexistent_device(
118+
input_value: str,
119+
expected_result: list[str],
120+
expected_warning: str,
121+
mock_available_devices: list[str],
122+
):
123+
"""Test exact matching with non-existent device"""
124+
with patch("checkpoint_engine.ps.logger") as mock_logger:
125+
result = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
126+
assert result == expected_result
127+
mock_logger.warning.assert_called_once_with(expected_warning)
128+
129+
130+
@pytest.mark.parametrize(
131+
"env_var_name,env_var_value,expected_devices",
132+
[
133+
("PS_P2P_STORE_RDMA_DEVICES", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
134+
("NCCL_IB_HCA", "mlx5", ["mlx5_0", "mlx5_1"]),
135+
("NCCL_IB_HCA", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
136+
("NCCL_IB_HCA", "^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]),
137+
("NCCL_IB_HCA", "mlx6", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
138+
("NCCL_IB_HCA", "", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
139+
],
140+
)
141+
def test_get_rdma_devices_with_env_vars(
142+
env_var_name: str,
143+
env_var_value: str,
144+
expected_devices: list[str],
145+
mock_available_devices: list[str],
146+
):
147+
"""Test _get_rdma_devices with various environment variables"""
148+
env_dict = {env_var_name: env_var_value}
149+
with (
150+
patch.dict(os.environ, env_dict),
151+
patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices),
152+
):
153+
devices = _get_rdma_devices()
154+
assert sorted(devices) == sorted(expected_devices)
155+
156+
157+
@pytest.mark.parametrize(
158+
"local_rank,gpu_count,expected_device",
159+
[
160+
(0, 4, "mlx5_0"),
161+
(3, 4, "mlx5_3"),
162+
(4, 8, "mlx5_2"),
163+
(7, 8, "mlx5_3"),
164+
],
165+
)
166+
def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_device: str):
167+
"""Test _get_my_rdma_device with basic allocation"""
168+
# Use fewer devices to match the GPU count constraint
169+
devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"]
170+
device = _get_my_rdma_device(local_rank, gpu_count, devices)
171+
assert device == expected_device
172+
173+
174+
@pytest.mark.parametrize(
175+
"local_rank,gpu_count,devices,error",
176+
[
177+
(
178+
0,
179+
4,
180+
["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3", "mlx5_4"],
181+
AssertionError,
182+
), # Too many devices
183+
(
184+
0,
185+
8,
186+
["mlx5_0", "mlx5_1", "mlx5_2"],
187+
AssertionError,
188+
), # GPU count not divisible by device count
189+
(0, 8, [], RuntimeError), # No devices
190+
],
191+
)
192+
def test_get_my_rdma_device_invalid_config(
193+
local_rank: int, gpu_count: int, devices: list[str], error: type
194+
):
195+
"""Test _get_my_rdma_device with invalid configuration"""
196+
with pytest.raises(error):
197+
_get_my_rdma_device(local_rank, gpu_count, devices)

0 commit comments

Comments
 (0)