3030)
3131from lisa .sut_orchestrator .azure .features import AzureExtension
3232from lisa .tools import Lspci , Mkdir , Modprobe , Reboot , Tar , Wget
33- from lisa .tools .dmesg import Dmesg
3433from lisa .tools .gpu_drivers import (
3534 AmdGpuDriver ,
3635 GpuDriver ,
37- GpuDriverInstaller ,
3836 NvidiaCudaDriver ,
3937 NvidiaGridDriver ,
4038)
@@ -338,21 +336,10 @@ def _check_driver_installed(node: Node, log: Logger) -> None:
338336
339337 lspci_gpucount = gpu .get_gpu_count_with_lspci ()
340338
341- # Get supported driver types from GPU feature
342- driver_type = gpu . get_supported_driver ( )
339+ # Get the driver class for the supported GPU type
340+ driver_class = _get_driver_class ( node )
343341
344- # Map ComputeSDK to driver installer class
345- driver_class : Type [GpuDriverInstaller ]
346- if driver_type == ComputeSDK .AMD :
347- driver_class = AmdGpuDriver
348- elif driver_type == ComputeSDK .CUDA :
349- driver_class = NvidiaCudaDriver
350- elif driver_type == ComputeSDK .GRID :
351- driver_class = NvidiaGridDriver
352- else :
353- raise SkippedException (f"Unsupported driver type: { driver_type } " )
354-
355- # Create GpuDriver with the specific driver installer class
342+ # Create GPU driver instance using virtual tool pattern
356343 gpu_driver = node .tools .create (GpuDriver , driver_class = driver_class )
357344 driver_gpucount = gpu_driver .get_gpu_count ()
358345
@@ -364,6 +351,29 @@ def _check_driver_installed(node: Node, log: Logger) -> None:
364351 log .info (f"GPU driver validated successfully with { driver_gpucount } GPUs" )
365352
366353
def _get_driver_class(node: Node) -> Type[GpuDriver]:
    """
    Pick the GPU driver class matching the SDK reported by the Gpu feature.

    Args:
        node: Node whose Gpu feature is queried through get_supported_driver().

    Returns:
        AmdGpuDriver for ComputeSDK.AMD, NvidiaCudaDriver for ComputeSDK.CUDA,
        or NvidiaGridDriver for ComputeSDK.GRID.

    Raises:
        SkippedException: If the reported driver type is none of the above.
    """
    driver_type = node.features[Gpu].get_supported_driver()

    # Ordered (sdk, driver class) pairs, checked with the same equality
    # comparison and in the same order: AMD, CUDA, GRID.
    sdk_to_driver = (
        (ComputeSDK.AMD, AmdGpuDriver),
        (ComputeSDK.CUDA, NvidiaCudaDriver),
        (ComputeSDK.GRID, NvidiaGridDriver),
    )
    for known_sdk, matching_driver in sdk_to_driver:
        if driver_type == known_sdk:
            return matching_driver

    # Nothing matched: skip the test rather than fail it.
    raise SkippedException(f"Unsupported driver type: { driver_type } ")
375+
376+
367377def _install_cudnn (node : Node , log : Logger , install_path : str ) -> None :
368378 wget = node .tools [Wget ]
369379
@@ -429,36 +439,7 @@ def _install_driver(node: Node, log_path: Path, log: Logger) -> None:
429439 ).stdout .split ("\n " )
430440 __remove_sources_added_by_extension (node , sources_before , sources_after )
431441
432- # Install LIS driver if required (for older kernels)
433- try :
434- from lisa .tools import LisDriver
435-
436- node .tools [LisDriver ]
437- except Exception as e :
438- log .debug (f"LisDriver is not installed. It might not be required. { e } " )
439-
440- # Get supported driver types from GPU feature
441- # TODO: Move 'get_supported_driver' to GpuDriver, it should detect the
442- # device and driver using lspci instead of relying on the GPU feature.
443- driver_type = gpu_feature .get_supported_driver ()
444-
445- driver_class : Type [GpuDriverInstaller ]
446- if driver_type == ComputeSDK .AMD :
447- driver_class = AmdGpuDriver
448- elif driver_type == ComputeSDK .CUDA :
449- driver_class = NvidiaCudaDriver
450- elif driver_type == ComputeSDK .GRID :
451- driver_class = NvidiaGridDriver
452- else :
453- raise SkippedException (f"Unsupported driver type: { driver_type } " )
454-
455- # Create GpuDriver with the specific driver installer class
456- gpu_driver = node .tools .create (GpuDriver , driver_class = driver_class )
457- gpu_driver .install_driver ()
458-
459- log .debug ("GPU driver installed" )
460- dmesg_tool = node .tools [Dmesg ]
461- dmesg_tool .check_kernel_errors ()
442+ __install_driver_using_sdk (node , log , log_path )
462443
463444
464445def _gpu_provision_check (min_pci_count : int , node : Node , log : Logger ) -> None :
@@ -485,3 +466,35 @@ def __remove_sources_added_by_extension(
485466 rm_sources = [source for source in sources_after if source not in sources_before ]
486467 for source in rm_sources :
487468 node .execute (f"rm /etc/apt/sources.list.d/{ source } " , sudo = True )
469+
470+
def __install_driver_using_sdk(node: Node, log: Logger, log_path: Path) -> None:
    """
    Install the GPU driver matching the node's supported SDK, then reboot.

    Steps, in order:
    1. Best-effort lookup of the LisDriver tool (for older kernels); any
       failure is only logged at debug level and never aborts the install.
    2. Resolve the driver class through _get_driver_class, which raises
       SkippedException for an unsupported driver type.
    3. Create the GpuDriver tool with that driver class.
    4. Reboot through the Reboot tool's reboot_and_check_panic, passing
       log_path.
    """
    # The LIS driver may not be needed on this node, so a failed import or
    # tool lookup is tolerated and merely recorded in the debug log.
    try:
        from lisa.tools import LisDriver

        node.tools[LisDriver]
    except Exception as e:
        log.debug(f"LisDriver is not installed. It might not be required. { e } ")

    # TODO: Move 'get_supported_driver' to GpuDriver, it should detect the
    # device and driver using lspci instead of relying on the GPU feature.
    selected_driver = _get_driver_class(node)

    # The created tool instance is not needed afterwards; per the original
    # note, creating it is expected to trigger the driver installation.
    node.tools.create(GpuDriver, driver_class=selected_driver)

    log.debug("GPU driver installed. Rebooting to load driver.")
    node.tools[Reboot].reboot_and_check_panic(log_path)
0 commit comments