44import re
55from abc import abstractmethod
66from pathlib import PurePosixPath
7- from typing import TYPE_CHECKING , Any , List , Optional , Type , Union
7+ from typing import Any , List , Optional , Type
88
99from lisa .base_tools import Sed , Uname , Wget
1010from lisa .executable import Tool
1616 Redhat ,
1717 Ubuntu ,
1818)
19- from lisa .tools .amdsmi import AmdSmi
2019
2120# Import tools directly from their modules to avoid circular import.
2221# lisa.tools.__init__.py imports from this file (gpu_drivers.py), so we cannot
2322# import from lisa.tools package directly. Instead, import from individual modules.
2423from lisa .tools .df import Df
2524from lisa .tools .echo import Echo
25+ from lisa .tools .gpu_smi import GpuSmi
2626from lisa .tools .mkdir import Mkdir
2727from lisa .tools .modprobe import Modprobe
28- from lisa .tools .nvidiasmi import NvidiaSmi
2928from lisa .util import LisaException , MissingPackagesException , SkippedException
3029
31- if TYPE_CHECKING :
32- from lisa .node import Node
33-
3430
3531class GpuDriver (Tool ):
3632 """
@@ -47,7 +43,7 @@ class GpuDriver(Tool):
4743 """
4844
4945 _driver_class : Type ["GpuDriverInstaller" ]
50- _smi_class : Type [Union [ AmdSmi , NvidiaSmi ] ]
46+ _smi_class : Type [GpuSmi ]
5147
5248 @property
5349 def command (self ) -> str :
@@ -56,9 +52,8 @@ def command(self) -> str:
5652 @classmethod
5753 def create (
5854 cls ,
59- node : "Node" ,
55+ node : Any ,
6056 * args : Any ,
61- driver_class : Type ["GpuDriverInstaller" ],
6257 ** kwargs : Any ,
6358 ) -> "GpuDriver" :
6459 """
@@ -72,6 +67,11 @@ def create(
7267 Returns:
7368 GpuDriver instance configured for the specified driver
7469 """
70+ driver_class : Type ["GpuDriverInstaller" ] = kwargs .pop ("driver_class" , None )
71+ assert driver_class is not None , (
72+ "driver_class parameter is required when creating GpuDriver. "
73+ "Use node.tools.create(GpuDriver, driver_class=AmdGpuDriver) or similar."
74+ )
7575
7676 instance = cls (node )
7777 instance ._driver_class = driver_class
@@ -84,7 +84,7 @@ def get_gpu_count(self) -> int:
8484 """
8585 Get GPU count using the appropriate monitoring tool.
8686 """
87- smi_tool : Union [ AmdSmi , NvidiaSmi ] = self .node .tools [self ._smi_class ]
87+ smi_tool : GpuSmi = self .node .tools [self ._smi_class ]
8888 return smi_tool .get_gpu_count ()
8989
9090 def install_driver (self ) -> None :
@@ -102,7 +102,7 @@ class GpuDriverInstaller(Tool):
102102
103103 @classmethod
104104 @abstractmethod
105- def smi (cls ) -> Type [Union [ AmdSmi , NvidiaSmi ] ]:
105+ def smi (cls ) -> Type [GpuSmi ]:
106106 """Return the monitoring tool class for this driver"""
107107 raise NotImplementedError
108108
@@ -185,8 +185,10 @@ class NvidiaGridDriver(GpuDriverInstaller):
185185 }
186186
187187 @classmethod
188- def smi (cls ) -> Type [NvidiaSmi ]:
188+ def smi (cls ) -> Type [GpuSmi ]:
189189 """Return the monitoring tool class for NVIDIA GRID driver"""
190+ from lisa .tools .gpu_smi import NvidiaSmi
191+
190192 return NvidiaSmi
191193
192194 @property
@@ -303,8 +305,10 @@ class NvidiaCudaDriver(GpuDriverInstaller):
303305 DEFAULT_CUDA_VERSION = "10.1.243-1"
304306
305307 @classmethod
306- def smi (cls ) -> Type [NvidiaSmi ]:
308+ def smi (cls ) -> Type [GpuSmi ]:
307309 """Return the monitoring tool class for NVIDIA CUDA driver"""
310+ from lisa .tools .gpu_smi import NvidiaSmi
311+
308312 return NvidiaSmi
309313
310314 @property
@@ -552,8 +556,10 @@ class AmdGpuDriver(GpuDriverInstaller):
552556 ROCM_BUILD = "70001"
553557
554558 @classmethod
555- def smi (cls ) -> Type [AmdSmi ]:
559+ def smi (cls ) -> Type [GpuSmi ]:
556560 """Return the monitoring tool class for AMD GPU driver"""
561+ from lisa .tools .gpu_smi import AmdSmi
562+
557563 return AmdSmi
558564
559565 @property
0 commit comments