33
44import re
55from abc import abstractmethod
6- from typing import Any , List , Optional
6+ from pathlib import PurePosixPath
7+ from typing import Any , List , Optional , Type , Union
78
89from lisa .base_tools import Wget
10+ from lisa .base_tools .cat import Cat
11+ from lisa .base_tools .sed import Sed
912from lisa .base_tools .uname import Uname
1013from lisa .executable import Tool
11- from lisa .features import Gpu
12- from lisa .features .gpu import ComputeSDK
1314from lisa .operating_system import (
1415 CBLMariner ,
1516 CpuArchitecture ,
2021)
2122from lisa .tools import Df
2223from lisa .tools .amdsmi import AmdSmi
24+ from lisa .tools .echo import Echo
25+ from lisa .tools .mkdir import Mkdir
26+ from lisa .tools .modprobe import Modprobe
2327from lisa .tools .nvidiasmi import NvidiaSmi
2428from lisa .util import LisaException , MissingPackagesException , SkippedException
2529
@@ -39,59 +43,65 @@ def can_install(self) -> bool:
3943
4044 def _initialize (self , * args : Any , ** kwargs : Any ) -> None :
4145 """
42- Determine which GPU vendor is present and set the appropriate monitoring tool.
43- This runs once when the tool is first accessed.
46+ Determine which GPU vendor is present and cache the monitoring tool type.
47+ Does not instantiate the monitoring tool yet, since the driver
48+ may not be installed yet.
4449 """
50+ from lisa .features import Gpu
51+ from lisa .features .gpu import ComputeSDK
52+
4553 gpu = self .node .features [Gpu ]
46- supported_drivers = gpu .get_supported_driver ()
54+ self . _supported_drivers = gpu .get_supported_driver ()
4755
48- # Determine GPU vendor type and set the appropriate monitoring tool
49- if ComputeSDK . AMD in supported_drivers :
50- self . _monitoring_tool = self .node . tools [ AmdSmi ]
51- self ._supported_drivers = supported_drivers
56+ # Determine GPU vendor type and store the tool class (not instance)
57+ self . _monitoring_tool_class : Type [ Union [ AmdSmi , NvidiaSmi ]]
58+ if ComputeSDK . AMD in self ._supported_drivers :
59+ self ._monitoring_tool_class = AmdSmi
5260 self ._log .debug (
53- f"GPU vendor detected: AMD (supported drivers: { supported_drivers } )"
61+ f"GPU vendor detected: AMD "
62+ f"(supported drivers: { self ._supported_drivers } )"
5463 )
5564 elif (
56- ComputeSDK .CUDA in supported_drivers or ComputeSDK .GRID in supported_drivers
65+ ComputeSDK .CUDA in self ._supported_drivers
66+ or ComputeSDK .GRID in self ._supported_drivers
5767 ):
58- self ._monitoring_tool = self .node .tools [NvidiaSmi ]
59- self ._supported_drivers = supported_drivers
68+ self ._monitoring_tool_class = NvidiaSmi
6069 self ._log .debug (
61- f"GPU vendor detected: NVIDIA (supported drivers: { supported_drivers } )"
70+ f"GPU vendor detected: NVIDIA "
71+ f"(supported drivers: { self ._supported_drivers } )"
6272 )
6373 else :
6474 raise SkippedException (
65- f"No supported GPU driver type found: { supported_drivers } "
75+ f"No supported GPU driver type found: { self . _supported_drivers } "
6676 )
6777
6878 def get_gpu_count (self ) -> int :
69- return self ._monitoring_tool .get_gpu_count ()
79+ """
80+ Get GPU count using the appropriate monitoring tool.
81+ """
82+ monitoring_tool = self .node .tools [self ._monitoring_tool_class ]
83+ return monitoring_tool .get_gpu_count ()
7084
7185 def install_driver (self ) -> None :
86+ from lisa .features .gpu import ComputeSDK
87+
88+ # The if-else ordering is based on priority.
89+ # Only one driver type will be installed.
7290 for driver_type in self ._supported_drivers :
7391 if driver_type == ComputeSDK .GRID :
7492 self .node .log .info ("Installing NVIDIA GRID driver" )
75- from lisa .tools .gpu_drivers import NvidiaGridDriver
76-
7793 _ = self .node .tools [NvidiaGridDriver ]
94+ return
7895 elif driver_type == ComputeSDK .CUDA :
7996 self .node .log .info ("Installing NVIDIA CUDA driver" )
80- from lisa .tools .gpu_drivers import NvidiaCudaDriver
81-
8297 _ = self .node .tools [NvidiaCudaDriver ]
98+ return
8399 elif driver_type == ComputeSDK .AMD :
84100 self .node .log .info ("Installing AMD GPU driver" )
85- from lisa .tools .gpu_drivers import AmdGpuDriver
86-
87101 _ = self .node .tools [AmdGpuDriver ]
88- else :
89- raise LisaException (f"Unsupported driver type: '{ driver_type } '" )
102+ return
90103
91- self .node .log .debug (
92- f"{ self ._supported_drivers } driver installed. "
93- "Reboot required to load driver."
94- )
104+ raise LisaException (f"Unsupported driver type: '{ self ._supported_drivers } '" )
95105
96106
97107class GpuDriverInstaller (Tool ):
@@ -181,6 +191,11 @@ def _install(self) -> bool:
181191 self ._install_driver ()
182192 self ._log .info (f"{ self .driver_name } installation completed successfully" )
183193
194+ from lisa .tools .reboot import Reboot
195+
196+ reboot_tool = self .node .tools [Reboot ]
197+ reboot_tool .reboot ()
198+
184199 version = self .get_version ()
185200 self ._log .info (f"Installed { self .driver_name } \n { version } " )
186201
@@ -686,21 +701,72 @@ def _install_driver(self) -> None:
686701 expected_exit_code_failure_message = "amdgpu kernel module not found" ,
687702 )
688703
704+ # Remove amdgpu from deny-list if present
705+ # Azure VMs may have amdgpu blacklisted by default
706+ self ._log .info ("Checking for amdgpu deny-list entries" )
707+ modprobe_dir = "/etc/modprobe.d"
708+
709+ # Search for files containing "blacklist amdgpu" in /etc/modprobe.d/
710+ # Using grep to find which file(s) contain the blacklist entry
711+ search_result = self .node .execute (
712+ f"grep -l 'blacklist amdgpu' { modprobe_dir } /*.conf 2>/dev/null || true" ,
713+ sudo = True ,
714+ shell = True ,
715+ )
716+
717+ if search_result .stdout .strip ():
718+ # Found file(s) with blacklist entry
719+ denylist_files = search_result .stdout .strip ().split ("\n " )
720+
721+ for denylist_file in denylist_files :
722+ denylist_file = denylist_file .strip ()
723+ if not denylist_file :
724+ continue
725+
726+ self ._log .info (f"Removing amdgpu blacklist from { denylist_file } " )
727+
728+ # Use Sed tool to comment out the blacklist line
729+ sed = self .node .tools [Sed ]
730+ sed .substitute (
731+ regexp = "^blacklist amdgpu" ,
732+ replacement = "# blacklist amdgpu" ,
733+ file = denylist_file ,
734+ sudo = True ,
735+ )
736+ else :
737+ self ._log .debug ("No amdgpu deny-list entries found in /etc/modprobe.d/" )
738+
739+ # Load the amdgpu kernel module
740+ self ._log .info ("Loading amdgpu kernel module" )
741+ modprobe = self .node .tools [Modprobe ]
742+ modprobe .load ("amdgpu" )
743+
744+ # Configure module to load on boot using /etc/modules-load.d/
745+ # This is the modern systemd method for auto-loading kernel modules
746+ self ._log .info ("Configuring amdgpu module to load on boot" )
747+
748+ # Create the modules-load.d configuration file
749+ modules_load_dir = "/etc/modules-load.d"
750+ amdgpu_conf = f"{ modules_load_dir } /amdgpu.conf"
751+
752+ # Ensure directory exists
753+ mkdir = self .node .tools [Mkdir ]
754+ mkdir .create_directory (modules_load_dir , sudo = True )
755+
756+ # Write amdgpu module name to the config file
757+ echo = self .node .tools [Echo ]
758+ echo .write_to_file (
759+ "amdgpu" ,
760+ PurePosixPath (amdgpu_conf ),
761+ sudo = True ,
762+ ignore_error = False ,
763+ )
764+
689765 # Clean up package cache to free disk space
690766 self ._log .debug ("Cleaning up package cache to free disk space" )
691767 self .node .os .clean_package_cache ()
692768
693- # Check final disk space
694- root_partition = df_tool .get_partition_by_mountpoint ("/" , force_run = True )
695- if root_partition :
696- available_gb = root_partition .available_blocks / (1024 * 1024 )
697- self ._log .debug (
698- f"Disk space after installation: / partition has "
699- f"{ available_gb :.2f} GB available "
700- f"({ root_partition .percentage_blocks_used } % used)"
701- )
702-
703769 self ._log .info (
704770 "Successfully installed AMD GPU (ROCm) driver. "
705- "Reboot required to load the driver ."
771+ "Module configured to load on boot. Reboot required ."
706772 )
0 commit comments