Skip to content

Commit 347778f

Browse files
committed
Remove driver_name and other review comments
1 parent 315488f commit 347778f

File tree

2 files changed

+16
-37
lines changed

2 files changed

+16
-37
lines changed

lisa/microsoft/testsuites/gpu/gpusuite.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,9 @@ def _check_driver_installed(node: Node, log: Logger) -> None:
330330

331331
lspci_gpucount = gpu.get_gpu_count_with_lspci()
332332

333-
# Get the ComputeSDK type for the supported GPU type
334333
compute_sdk = _get_supported_driver(node)
335334

336-
# Create GPU driver instance using virtual tool pattern
337-
gpu_driver = node.tools.create(GpuDriver, compute_sdk=compute_sdk)
335+
gpu_driver: GpuDriver = node.tools.get(GpuDriver, compute_sdk=compute_sdk)
338336
driver_gpucount = gpu_driver.get_gpu_count()
339337

340338
assert_that(lspci_gpucount).described_as(
@@ -470,7 +468,7 @@ def __install_driver_using_sdk(node: Node, log: Logger, log_path: Path) -> None:
470468
log.debug(f"LisDriver is not installed. It might not be required. {e}")
471469

472470
compute_sdk = _get_supported_driver(node)
473-
_ = node.tools.create(GpuDriver, compute_sdk=compute_sdk)
471+
_ = node.tools.get(GpuDriver, compute_sdk=compute_sdk)
474472

475473
log.debug("GPU driver installed. Rebooting to load driver.")
476474
reboot_tool = node.tools[Reboot]

lisa/tools/gpu_drivers.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Licensed under the MIT license.
33

44
import re
5-
from abc import abstractmethod
65
from enum import Enum
76
from pathlib import PurePosixPath
87
from typing import Any, List, Optional, Type
@@ -86,15 +85,13 @@ def create(
8685
f"Must be one of {list(ComputeSDK)}"
8786
)
8887

89-
gpu_driver_factory = Factory[GpuDriver](
90-
GpuDriver # type: ignore[type-abstract]
91-
)
88+
gpu_driver_factory = Factory[GpuDriver](GpuDriver)
9289

93-
driver_class = gpu_driver_factory.create_by_type_name(
90+
driver = gpu_driver_factory.create_by_type_name(
9491
compute_sdk, node=node, **kwargs
9592
)
96-
assert isinstance(driver_class, GpuDriver)
97-
return driver_class
93+
assert isinstance(driver, GpuDriver)
94+
return driver
9895

9996
@classmethod
10097
def type_name(cls) -> str:
@@ -118,17 +115,10 @@ def get_gpu_count(self) -> int:
118115
return smi_tool.get_gpu_count()
119116

120117
@classmethod
121-
@abstractmethod
122118
def smi(cls) -> Type[GpuSmi]:
123119
"""Return the smi tool class for this driver"""
124120
raise NotImplementedError
125121

126-
@property
127-
@abstractmethod
128-
def driver_name(self) -> str:
129-
"""Return the human-readable driver name (e.g., 'NVIDIA GRID', 'NVIDIA CUDA')"""
130-
raise NotImplementedError
131-
132122
def get_version(self) -> str:
133123
"""Get the currently installed driver version"""
134124
result = self.node.execute(f"{self.command} --version", shell=True, sudo=True)
@@ -149,23 +139,23 @@ def _install(self) -> bool:
149139
3. Reboot
150140
4. Verify installation
151141
"""
152-
self._log.info(f"Starting {self.driver_name} installation")
142+
compute_sdk = self.__class__.type_name()
143+
self._log.info(f"Starting {compute_sdk} driver installation")
153144

154145
self._install_dependencies()
155146
self._install_driver()
156-
self._log.info(f"{self.driver_name} installation completed successfully")
147+
self._log.info(f"{compute_sdk} driver installation completed successfully")
157148

158149
from lisa.tools.reboot import Reboot
159150

160151
reboot_tool = self.node.tools[Reboot]
161152
reboot_tool.reboot()
162153

163154
version = self.get_version()
164-
self._log.info(f"Installed {self.driver_name} \n {version}")
155+
self._log.info(f"Installed {compute_sdk} driver \n {version}")
165156

166157
return True
167158

168-
@abstractmethod
169159
def _install_driver(self) -> None:
170160
"""
171161
Install the actual GPU driver.
@@ -203,10 +193,6 @@ def smi(cls) -> Type[GpuSmi]:
203193

204194
return NvidiaSmi
205195

206-
@property
207-
def driver_name(self) -> str:
208-
return "NVIDIA GRID"
209-
210196
@property
211197
def command(self) -> str:
212198
return "nvidia-smi"
@@ -261,7 +247,8 @@ def _install_dependencies(self) -> None:
261247
if not dependencies:
262248
return
263249

264-
self._log.debug(f"Installing {self.driver_name} dependencies: {dependencies}")
250+
compute_sdk = self.__class__.type_name()
251+
self._log.debug(f"Installing {compute_sdk} dependencies: {dependencies}")
265252

266253
assert isinstance(
267254
self.node.os, Posix
@@ -326,10 +313,6 @@ def smi(cls) -> Type[GpuSmi]:
326313

327314
return NvidiaSmi
328315

329-
@property
330-
def driver_name(self) -> str:
331-
return "NVIDIA CUDA"
332-
333316
@property
334317
def command(self) -> str:
335318
return "nvidia-smi"
@@ -393,7 +376,8 @@ def _install_dependencies(self) -> None:
393376
if not dependencies:
394377
return
395378

396-
self._log.debug(f"Installing {self.driver_name} dependencies: {dependencies}")
379+
compute_sdk = self.__class__.type_name()
380+
self._log.debug(f"Installing {compute_sdk} dependencies: {dependencies}")
397381

398382
assert isinstance(
399383
self.node.os, Posix
@@ -580,10 +564,6 @@ def smi(cls) -> Type[GpuSmi]:
580564

581565
return AmdSmi
582566

583-
@property
584-
def driver_name(self) -> str:
585-
return "AMD GPU (ROCm)"
586-
587567
@property
588568
def command(self) -> str:
589569
return "amd-smi"
@@ -621,7 +601,8 @@ def _install_dependencies(self) -> None:
621601
"python3-wheel",
622602
]
623603

624-
self._log.debug(f"Installing {self.driver_name} dependencies: {dependencies}")
604+
compute_sdk = self.__class__.type_name()
605+
self._log.debug(f"Installing {compute_sdk} dependencies: {dependencies}")
625606

626607
assert isinstance(
627608
self.node.os, Posix

0 commit comments

Comments
 (0)