22# Licensed under the MIT license.
33
44import re
5- from abc import abstractmethod
65from enum import Enum
76from pathlib import PurePosixPath
87from typing import Any , List , Optional , Type
@@ -86,15 +85,13 @@ def create(
8685 f"Must be one of { list (ComputeSDK )} "
8786 )
8887
89- gpu_driver_factory = Factory [GpuDriver ](
90- GpuDriver # type: ignore[type-abstract]
91- )
88+ gpu_driver_factory = Factory [GpuDriver ](GpuDriver )
9289
93- driver_class = gpu_driver_factory .create_by_type_name (
90+ driver = gpu_driver_factory .create_by_type_name (
9491 compute_sdk , node = node , ** kwargs
9592 )
96- assert isinstance (driver_class , GpuDriver )
97- return driver_class
93+ assert isinstance (driver , GpuDriver )
94+ return driver
9895
9996 @classmethod
10097 def type_name (cls ) -> str :
@@ -118,17 +115,10 @@ def get_gpu_count(self) -> int:
118115 return smi_tool .get_gpu_count ()
119116
120117 @classmethod
121- @abstractmethod
122118 def smi (cls ) -> Type [GpuSmi ]:
123119 """Return the smi tool class for this driver"""
124120 raise NotImplementedError
125121
126- @property
127- @abstractmethod
128- def driver_name (self ) -> str :
129- """Return the human-readable driver name (e.g., 'NVIDIA GRID', 'NVIDIA CUDA')"""
130- raise NotImplementedError
131-
132122 def get_version (self ) -> str :
133123 """Get the currently installed driver version"""
134124 result = self .node .execute (f"{ self .command } --version" , shell = True , sudo = True )
@@ -149,23 +139,23 @@ def _install(self) -> bool:
149139 3. Reboot
150140 4. Verify installation
151141 """
152- self ._log .info (f"Starting { self .driver_name } installation" )
142+ compute_sdk = self .__class__ .type_name ()
143+ self ._log .info (f"Starting { compute_sdk } driver installation" )
153144
154145 self ._install_dependencies ()
155146 self ._install_driver ()
156- self ._log .info (f"{ self . driver_name } installation completed successfully" )
147+ self ._log .info (f"{ compute_sdk } driver installation completed successfully" )
157148
158149 from lisa .tools .reboot import Reboot
159150
160151 reboot_tool = self .node .tools [Reboot ]
161152 reboot_tool .reboot ()
162153
163154 version = self .get_version ()
164- self ._log .info (f"Installed { self . driver_name } \n { version } " )
155+ self ._log .info (f"Installed { compute_sdk } driver \n { version } " )
165156
166157 return True
167158
168- @abstractmethod
169159 def _install_driver (self ) -> None :
170160 """
171161 Install the actual GPU driver.
@@ -203,10 +193,6 @@ def smi(cls) -> Type[GpuSmi]:
203193
204194 return NvidiaSmi
205195
206- @property
207- def driver_name (self ) -> str :
208- return "NVIDIA GRID"
209-
210196 @property
211197 def command (self ) -> str :
212198 return "nvidia-smi"
@@ -261,7 +247,8 @@ def _install_dependencies(self) -> None:
261247 if not dependencies :
262248 return
263249
264- self ._log .debug (f"Installing { self .driver_name } dependencies: { dependencies } " )
250+ compute_sdk = self .__class__ .type_name ()
251+ self ._log .debug (f"Installing { compute_sdk } dependencies: { dependencies } " )
265252
266253 assert isinstance (
267254 self .node .os , Posix
@@ -326,10 +313,6 @@ def smi(cls) -> Type[GpuSmi]:
326313
327314 return NvidiaSmi
328315
329- @property
330- def driver_name (self ) -> str :
331- return "NVIDIA CUDA"
332-
333316 @property
334317 def command (self ) -> str :
335318 return "nvidia-smi"
@@ -393,7 +376,8 @@ def _install_dependencies(self) -> None:
393376 if not dependencies :
394377 return
395378
396- self ._log .debug (f"Installing { self .driver_name } dependencies: { dependencies } " )
379+ compute_sdk = self .__class__ .type_name ()
380+ self ._log .debug (f"Installing { compute_sdk } dependencies: { dependencies } " )
397381
398382 assert isinstance (
399383 self .node .os , Posix
@@ -580,10 +564,6 @@ def smi(cls) -> Type[GpuSmi]:
580564
581565 return AmdSmi
582566
583- @property
584- def driver_name (self ) -> str :
585- return "AMD GPU (ROCm)"
586-
587567 @property
588568 def command (self ) -> str :
589569 return "amd-smi"
@@ -621,7 +601,8 @@ def _install_dependencies(self) -> None:
621601 "python3-wheel" ,
622602 ]
623603
624- self ._log .debug (f"Installing { self .driver_name } dependencies: { dependencies } " )
604+ compute_sdk = self .__class__ .type_name ()
605+ self ._log .debug (f"Installing { compute_sdk } dependencies: { dependencies } " )
625606
626607 assert isinstance (
627608 self .node .os , Posix
0 commit comments