Skip to content

Commit 86355c1

Browse files
committed
Remove GPUDriverInstaller
1 parent 056479b commit 86355c1

File tree

3 files changed

+136
-103
lines changed

3 files changed

+136
-103
lines changed

lisa/microsoft/testsuites/gpu/gpusuite.py

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@
3030
)
3131
from lisa.sut_orchestrator.azure.features import AzureExtension
3232
from lisa.tools import Lspci, Mkdir, Modprobe, Reboot, Tar, Wget
33-
from lisa.tools.dmesg import Dmesg
3433
from lisa.tools.gpu_drivers import (
3534
AmdGpuDriver,
3635
GpuDriver,
37-
GpuDriverInstaller,
3836
NvidiaCudaDriver,
3937
NvidiaGridDriver,
4038
)
@@ -338,21 +336,10 @@ def _check_driver_installed(node: Node, log: Logger) -> None:
338336

339337
lspci_gpucount = gpu.get_gpu_count_with_lspci()
340338

341-
# Get supported driver types from GPU feature
342-
driver_type = gpu.get_supported_driver()
339+
# Get the driver class for the supported GPU type
340+
driver_class = _get_driver_class(node)
343341

344-
# Map ComputeSDK to driver installer class
345-
driver_class: Type[GpuDriverInstaller]
346-
if driver_type == ComputeSDK.AMD:
347-
driver_class = AmdGpuDriver
348-
elif driver_type == ComputeSDK.CUDA:
349-
driver_class = NvidiaCudaDriver
350-
elif driver_type == ComputeSDK.GRID:
351-
driver_class = NvidiaGridDriver
352-
else:
353-
raise SkippedException(f"Unsupported driver type: {driver_type}")
354-
355-
# Create GpuDriver with the specific driver installer class
342+
# Create GPU driver instance using virtual tool pattern
356343
gpu_driver = node.tools.create(GpuDriver, driver_class=driver_class)
357344
driver_gpucount = gpu_driver.get_gpu_count()
358345

@@ -364,6 +351,29 @@ def _check_driver_installed(node: Node, log: Logger) -> None:
364351
log.info(f"GPU driver validated successfully with {driver_gpucount} GPUs")
365352

366353

354+
def _get_driver_class(node: Node) -> Type[GpuDriver]:
    """
    Pick the GPU driver tool class that matches the node's supported SDK.

    The supported SDK is read from the node's Gpu feature through
    get_supported_driver().

    Returns:
        AmdGpuDriver for ComputeSDK.AMD, NvidiaCudaDriver for
        ComputeSDK.CUDA, or NvidiaGridDriver for ComputeSDK.GRID.

    Raises:
        SkippedException: If the reported driver type matches none of the
            SDKs listed above.
    """
    driver_type = node.features[Gpu].get_supported_driver()

    # Candidates are compared with == in this order: AMD, CUDA, GRID.
    sdk_to_driver = (
        (ComputeSDK.AMD, AmdGpuDriver),
        (ComputeSDK.CUDA, NvidiaCudaDriver),
        (ComputeSDK.GRID, NvidiaGridDriver),
    )
    for sdk, candidate in sdk_to_driver:
        if driver_type == sdk:
            return candidate

    raise SkippedException(f"Unsupported driver type: {driver_type}")
375+
376+
367377
def _install_cudnn(node: Node, log: Logger, install_path: str) -> None:
368378
wget = node.tools[Wget]
369379

@@ -429,36 +439,7 @@ def _install_driver(node: Node, log_path: Path, log: Logger) -> None:
429439
).stdout.split("\n")
430440
__remove_sources_added_by_extension(node, sources_before, sources_after)
431441

432-
# Install LIS driver if required (for older kernels)
433-
try:
434-
from lisa.tools import LisDriver
435-
436-
node.tools[LisDriver]
437-
except Exception as e:
438-
log.debug(f"LisDriver is not installed. It might not be required. {e}")
439-
440-
# Get supported driver types from GPU feature
441-
# TODO: Move 'get_supported_driver' to GpuDriver, it should detect the
442-
# device and driver using lspci instead of relying on the GPU feature.
443-
driver_type = gpu_feature.get_supported_driver()
444-
445-
driver_class: Type[GpuDriverInstaller]
446-
if driver_type == ComputeSDK.AMD:
447-
driver_class = AmdGpuDriver
448-
elif driver_type == ComputeSDK.CUDA:
449-
driver_class = NvidiaCudaDriver
450-
elif driver_type == ComputeSDK.GRID:
451-
driver_class = NvidiaGridDriver
452-
else:
453-
raise SkippedException(f"Unsupported driver type: {driver_type}")
454-
455-
# Create GpuDriver with the specific driver installer class
456-
gpu_driver = node.tools.create(GpuDriver, driver_class=driver_class)
457-
gpu_driver.install_driver()
458-
459-
log.debug("GPU driver installed")
460-
dmesg_tool = node.tools[Dmesg]
461-
dmesg_tool.check_kernel_errors()
442+
__install_driver_using_sdk(node, log, log_path)
462443

463444

464445
def _gpu_provision_check(min_pci_count: int, node: Node, log: Logger) -> None:
@@ -485,3 +466,35 @@ def __remove_sources_added_by_extension(
485466
rm_sources = [source for source in sources_after if source not in sources_before]
486467
for source in rm_sources:
487468
node.execute(f"rm /etc/apt/sources.list.d/{source}", sudo=True)
469+
470+
471+
def __install_driver_using_sdk(node: Node, log: Logger, log_path: Path) -> None:
    """
    Install the GPU driver for the node's supported SDK, then reboot.

    Steps, in order:
        1. Best-effort lookup of the LisDriver tool (intended for older
           kernels); any failure is logged at debug level and ignored.
        2. Resolve the driver class with _get_driver_class(node).
        3. Create the GpuDriver tool with that driver class.
        4. Reboot the node through the Reboot tool, checking for a panic.

    Args:
        node: The node to install the driver on.
        log: Logger for progress and debug messages.
        log_path: Path handed to Reboot.reboot_and_check_panic.
    """
    # LIS driver may be needed on older kernels. The import lives inside the
    # try block so that an import failure is tolerated like any other error.
    try:
        from lisa.tools import LisDriver

        node.tools[LisDriver]
    except Exception as e:
        log.debug(f"LisDriver is not installed. It might not be required. {e}")

    # TODO: Move 'get_supported_driver' to GpuDriver, it should detect the
    # device and driver using lspci instead of relying on the GPU feature.
    selected_driver = _get_driver_class(node)

    # Creating the tool is what triggers the driver installation (virtual
    # tool pattern); the created instance itself is not needed afterwards.
    node.tools.create(GpuDriver, driver_class=selected_driver)

    log.debug("GPU driver installed. Rebooting to load driver.")
    node.tools[Reboot].reboot_and_check_panic(log_path)

lisa/tools/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,7 @@
4848
from .gcc import Gcc
4949
from .gdb import Gdb
5050
from .git import Git
51-
from .gpu_drivers import (
52-
AmdGpuDriver,
53-
GpuDriver,
54-
GpuDriverInstaller,
55-
NvidiaCudaDriver,
56-
NvidiaGridDriver,
57-
)
51+
from .gpu_drivers import AmdGpuDriver, GpuDriver, NvidiaCudaDriver, NvidiaGridDriver
5852
from .gpu_smi import AmdSmi, GpuSmi, NvidiaSmi
5953
from .grub_config import GrubConfig
6054
from .hibernation_setup import HibernationSetup
@@ -183,7 +177,6 @@
183177
"Gdb",
184178
"Git",
185179
"GpuDriver",
186-
"GpuDriverInstaller",
187180
"GpuSmi",
188181
"GrubConfig",
189182
"Ip",

0 commit comments

Comments
 (0)