Skip to content

Commit 056479b

Browse files
committed
Create class for gpu_smi, move nvidia smi
1 parent cdcfae7 commit 056479b

File tree

4 files changed

+73
-61
lines changed

4 files changed

+73
-61
lines changed

lisa/tools/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
YumConfigManager,
1616
)
1717

18-
from .amdsmi import AmdSmi
1918
from .aria import Aria
2019
from .b4 import B4
2120
from .blkid import Blkid
@@ -56,6 +55,7 @@
5655
NvidiaCudaDriver,
5756
NvidiaGridDriver,
5857
)
58+
from .gpu_smi import AmdSmi, GpuSmi, NvidiaSmi
5959
from .grub_config import GrubConfig
6060
from .hibernation_setup import HibernationSetup
6161
from .hostname import Hostname
@@ -99,7 +99,6 @@
9999
from .ntp import Ntp
100100
from .ntpstat import Ntpstat
101101
from .ntttcp import Ntttcp
102-
from .nvidiasmi import NvidiaSmi
103102
from .nvmecli import Nvmecli
104103
from .openssl import OpenSSL
105104
from .parted import Parted
@@ -185,6 +184,7 @@
185184
"Git",
186185
"GpuDriver",
187186
"GpuDriverInstaller",
187+
"GpuSmi",
188188
"GrubConfig",
189189
"Ip",
190190
"IpInfo",

lisa/tools/gpu_drivers.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
from abc import abstractmethod
66
from pathlib import PurePosixPath
7-
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union
7+
from typing import Any, List, Optional, Type
88

99
from lisa.base_tools import Sed, Uname, Wget
1010
from lisa.executable import Tool
@@ -16,21 +16,17 @@
1616
Redhat,
1717
Ubuntu,
1818
)
19-
from lisa.tools.amdsmi import AmdSmi
2019

2120
# Import tools directly from their modules to avoid circular import.
2221
# lisa.tools.__init__.py imports from this file (gpu_drivers.py), so we cannot
2322
# import from lisa.tools package directly. Instead, import from individual modules.
2423
from lisa.tools.df import Df
2524
from lisa.tools.echo import Echo
25+
from lisa.tools.gpu_smi import GpuSmi
2626
from lisa.tools.mkdir import Mkdir
2727
from lisa.tools.modprobe import Modprobe
28-
from lisa.tools.nvidiasmi import NvidiaSmi
2928
from lisa.util import LisaException, MissingPackagesException, SkippedException
3029

31-
if TYPE_CHECKING:
32-
from lisa.node import Node
33-
3430

3531
class GpuDriver(Tool):
3632
"""
@@ -47,7 +43,7 @@ class GpuDriver(Tool):
4743
"""
4844

4945
_driver_class: Type["GpuDriverInstaller"]
50-
_smi_class: Type[Union[AmdSmi, NvidiaSmi]]
46+
_smi_class: Type[GpuSmi]
5147

5248
@property
5349
def command(self) -> str:
@@ -56,9 +52,8 @@ def command(self) -> str:
5652
@classmethod
5753
def create(
5854
cls,
59-
node: "Node",
55+
node: Any,
6056
*args: Any,
61-
driver_class: Type["GpuDriverInstaller"],
6257
**kwargs: Any,
6358
) -> "GpuDriver":
6459
"""
@@ -72,6 +67,11 @@ def create(
7267
Returns:
7368
GpuDriver instance configured for the specified driver
7469
"""
70+
driver_class: Type["GpuDriverInstaller"] = kwargs.pop("driver_class", None)
71+
assert driver_class is not None, (
72+
"driver_class parameter is required when creating GpuDriver. "
73+
"Use node.tools.create(GpuDriver, driver_class=AmdGpuDriver) or similar."
74+
)
7575

7676
instance = cls(node)
7777
instance._driver_class = driver_class
@@ -84,7 +84,7 @@ def get_gpu_count(self) -> int:
8484
"""
8585
Get GPU count using the appropriate monitoring tool.
8686
"""
87-
smi_tool: Union[AmdSmi, NvidiaSmi] = self.node.tools[self._smi_class]
87+
smi_tool: GpuSmi = self.node.tools[self._smi_class]
8888
return smi_tool.get_gpu_count()
8989

9090
def install_driver(self) -> None:
@@ -102,7 +102,7 @@ class GpuDriverInstaller(Tool):
102102

103103
@classmethod
104104
@abstractmethod
105-
def smi(cls) -> Type[Union[AmdSmi, NvidiaSmi]]:
105+
def smi(cls) -> Type[GpuSmi]:
106106
"""Return the monitoring tool class for this driver"""
107107
raise NotImplementedError
108108

@@ -185,8 +185,10 @@ class NvidiaGridDriver(GpuDriverInstaller):
185185
}
186186

187187
@classmethod
188-
def smi(cls) -> Type[NvidiaSmi]:
188+
def smi(cls) -> Type[GpuSmi]:
189189
"""Return the monitoring tool class for NVIDIA GRID driver"""
190+
from lisa.tools.gpu_smi import NvidiaSmi
191+
190192
return NvidiaSmi
191193

192194
@property
@@ -303,8 +305,10 @@ class NvidiaCudaDriver(GpuDriverInstaller):
303305
DEFAULT_CUDA_VERSION = "10.1.243-1"
304306

305307
@classmethod
306-
def smi(cls) -> Type[NvidiaSmi]:
308+
def smi(cls) -> Type[GpuSmi]:
307309
"""Return the monitoring tool class for NVIDIA CUDA driver"""
310+
from lisa.tools.gpu_smi import NvidiaSmi
311+
308312
return NvidiaSmi
309313

310314
@property
@@ -552,8 +556,10 @@ class AmdGpuDriver(GpuDriverInstaller):
552556
ROCM_BUILD = "70001"
553557

554558
@classmethod
555-
def smi(cls) -> Type[AmdSmi]:
559+
def smi(cls) -> Type[GpuSmi]:
556560
"""Return the monitoring tool class for AMD GPU driver"""
561+
from lisa.tools.gpu_smi import AmdSmi
562+
557563
return AmdSmi
558564

559565
@property

lisa/tools/amdsmi.py renamed to lisa/tools/gpu_smi.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,62 @@
22
# Licensed under the MIT license.
33

44
import re
5+
from abc import abstractmethod
56
from typing import List, Type
67

78
from lisa.executable import Tool
89
from lisa.util import LisaException, find_groups_in_lines
910

1011

11-
class AmdSmi(Tool):
12+
class GpuSmi(Tool):
13+
"""
14+
Base class for GPU monitoring tools (nvidia-smi, amd-smi, etc.).
15+
"""
16+
17+
@abstractmethod
18+
def get_gpu_count(self) -> int:
19+
"""Get the number of GPUs detected by the SMI tool"""
20+
raise NotImplementedError
21+
22+
23+
class NvidiaSmi(GpuSmi):
24+
# tuple of gpu device names and their device id pattern
25+
# e.g. Tesla GPU device has device id "47505500-0001-0000-3130-444531303244"
26+
# A10-4Q device id "56475055-0002-0000-3130-444532323336"
27+
gpu_devices = (
28+
("Tesla", "47505500", 0),
29+
("A100", "44450000", 6),
30+
("H100", "44453233", 0),
31+
("A10-4Q", "56475055", 0),
32+
("A10-8Q", "3e810200", 0),
33+
("GB200", "42333130", 0),
34+
)
35+
36+
@property
37+
def command(self) -> str:
38+
return "nvidia-smi"
39+
40+
@property
41+
def can_install(self) -> bool:
42+
return False
43+
44+
def get_gpu_count(self) -> int:
45+
result = self.run("-L")
46+
if result.exit_code != 0 or (result.exit_code == 0 and result.stdout == ""):
47+
result = self.run("-L", sudo=True)
48+
if result.exit_code != 0 or (result.exit_code == 0 and result.stdout == ""):
49+
raise LisaException(
50+
f"nvidia-smi command exited with exit_code {result.exit_code}"
51+
)
52+
gpu_types = [x[0] for x in self.gpu_devices]
53+
device_count = 0
54+
for gpu_type in gpu_types:
55+
device_count += result.stdout.count(gpu_type)
56+
57+
return device_count
58+
59+
60+
class AmdSmi(GpuSmi):
1261
# Pattern to match GPU entries in amd-smi list output
1362
# Example:
1463
# GPU: 0
@@ -25,7 +74,7 @@ def command(self) -> str:
2574

2675
@property
2776
def can_install(self) -> bool:
28-
return False
77+
return True
2978

3079
@property
3180
def dependencies(self) -> List[Type[Tool]]:

lisa/tools/nvidiasmi.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)