@@ -28,20 +28,19 @@ import (
2828 "github.com/google/uuid"
2929
3030 "github.com/NVIDIA/vgpu-device-manager/internal/nvlib"
31- "github.com/NVIDIA/go-nvlib/pkg/nvpci"
3231 "github.com/NVIDIA/vgpu-device-manager/pkg/types"
3332)
3433
const (
	// HostPCIDevicesRoot is the host's sysfs PCI devices directory as seen
	// through the /host mount inside the container. Per-device and
	// per-virtual-function paths are built by appending to this root, so it
	// must not contain any whitespace.
	HostPCIDevicesRoot = "/host/sys/bus/pci/devices"
)
3837
3938// Manager represents a set of functions for managing vGPU configurations on a node
4039type Manager interface {
4140 GetVGPUConfig (gpu int ) (types.VGPUConfig , error )
4241 SetVGPUConfig (gpu int , config types.VGPUConfig ) error
4342 ClearVGPUConfig (gpu int ) error
44- IsUbuntu2404 () (bool , error )
43+ IsVFIOEnabled () (bool , error )
4544 GetVGPUConfigforVFIO (gpu int ) (types.VGPUConfig , error )
4645 SetVGPUConfigforVFIO (gpu int , config types.VGPUConfig ) error
4746}
@@ -57,62 +56,42 @@ func NewNvlibVGPUConfigManager() Manager {
5756 return & nvlibVGPUConfigManager {nvlib .New ()}
5857}
5958
60- func (m * nvlibVGPUConfigManager ) GetAllNvidiaGPUDevices () ([]* nvpci.NvidiaPCIDevice , error ) {
61- var nvdevices []* nvpci.NvidiaPCIDevice
62- deviceDirs , err := os .ReadDir (PCIDevicesRoot )
63- if err != nil {
64- return nil , fmt .Errorf ("unable to read parent PCI bus devices: %v" , err )
65- }
66- for _ , deviceDir := range deviceDirs {
67- deviceAddress := deviceDir .Name ()
68- nvdevice , err := m .nvlib .Nvpci .GetGPUByPciBusID (deviceAddress )
69- if err != nil || nvdevice == nil {
70- continue
71- }
72- if nvdevice .IsGPU () {
73- nvdevices = append (nvdevices , nvdevice )
74- }
75- }
76- return nvdevices , nil
77- }
78-
7959func (m * nvlibVGPUConfigManager ) GetVGPUConfigforVFIO (gpu int ) (types.VGPUConfig , error ) {
8060 nvdevice , err := m .nvlib .Nvpci .GetGPUByIndex (gpu )
8161 if err != nil {
8262 return nil , fmt .Errorf ("unable to get GPU by index %d: %v" , gpu , err )
8363 }
84- GPUDevices , err := m .nvlib .Nvpci .GetGPUs ()
85- if err != nil {
86- return nil , fmt .Errorf ("unable to get all NVIDIA GPU devices: %v" , err )
87- }
88-
8964 vgpuConfig := types.VGPUConfig {}
90- for _ , device := range GPUDevices {
91- if device .Address == nvdevice .Address {
92- VFnum := 0
93- totalVF := int (nvdevice .SriovInfo .PhysicalFunction .TotalVFs )
94- for VFnum < totalVF {
95- VFAddr := PCIDevicesRoot + "/" + device .Address + "/virtfn" + strconv .Itoa (VFnum ) + "/nvidia"
96- if _ , err := os .Stat (VFAddr ); err == nil {
97- VGPUTypeNumberBytes , err := os .ReadFile (VFAddr + "/current_vgpu_type" )
98- if err != nil {
99- return nil , fmt .Errorf ("unable to read current vGPU type: %v" , err )
100- }
101- VGPUTypeNumber , err := strconv .Atoi (string (VGPUTypeNumberBytes ))
102- if err != nil {
103- return nil , fmt .Errorf ("unable to convert current vGPU type to int: %v" , err )
104- }
105- VGPUTypeName , err := m .getVGPUTypeNameforVFIO (VFAddr + "/creatable_vgpu_types" , VGPUTypeNumber )
106- if err != nil {
107- return nil , fmt .Errorf ("unable to get vGPU type name: %v" , err )
108- }
109- vgpuConfig [VGPUTypeName ]++
110- }
111- VFnum ++
112- }
65+ VFnum := 0
66+ if nvdevice .SriovInfo .PhysicalFunction == nil {
67+ return vgpuConfig , nil
68+ }
69+ totalVF := int (nvdevice .SriovInfo .PhysicalFunction .TotalVFs )
70+ for VFnum < totalVF {
71+ VFAddr := HostPCIDevicesRoot + "/" + nvdevice .Address + "/virtfn" + strconv .Itoa (VFnum ) + "/nvidia"
72+ if _ , err := os .Stat (VFAddr ); err != nil {
73+ VFnum ++
74+ continue
75+ }
76+ VGPUTypeNumberBytes , err := os .ReadFile (VFAddr + "/current_vgpu_type" )
77+ if err != nil {
78+ return nil , fmt .Errorf ("unable to read current vGPU type: %v" , err )
79+ }
80+ VGPUTypeNumber , err := strconv .Atoi (strings .TrimSpace (string (VGPUTypeNumberBytes )))
81+ if err != nil {
82+ return nil , fmt .Errorf ("unable to convert current vGPU type to int: %v" , err )
11383 }
84+ if VGPUTypeNumber == 0 {
85+ VFnum ++
86+ continue
87+ }
88+ VGPUTypeName , err := m .getVGPUTypeNameforVFIO (VFAddr + "/creatable_vgpu_types" , VGPUTypeNumber )
89+ if err != nil {
90+ return nil , fmt .Errorf ("unable to get vGPU type name: %v" , err )
91+ }
92+ vgpuConfig [VGPUTypeName ]++
93+ VFnum ++
11494 }
115-
11695 return vgpuConfig , nil
11796}
11897
@@ -139,11 +118,6 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfigforVFIO(gpu int, config types.VGPU
139118 return fmt .Errorf ("GPU at index %d not found in available NVIDIA devices" , gpu )
140119 }
141120
142- err = m .ClearVGPUConfig (gpu )
143- if err != nil {
144- return fmt .Errorf ("error clearing VGPUConfig: %v" , err )
145- }
146-
147121 cmd := exec .Command ("chroot" , "/host" , "/run/nvidia/driver/usr/lib/nvidia/sriov-manage" , "-e" , nvdevice .Address )
148122 output , err := cmd .CombinedOutput ()
149123 if err != nil {
@@ -154,7 +128,7 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfigforVFIO(gpu int, config types.VGPU
154128 remainingToCreate := val
155129 VFnum := 0
156130 for remainingToCreate > 0 {
157- VFAddr := PCIDevicesRoot + "/" + nvdevice .Address + "/virtfn" + strconv .Itoa (VFnum ) + "/nvidia"
131+ VFAddr := HostPCIDevicesRoot + "/" + nvdevice .Address + "/virtfn" + strconv .Itoa (VFnum ) + "/nvidia"
158132 number , err := m .getVGPUTypeNumberforVFIO (VFAddr + "/creatable_vgpu_types" , key )
159133 if err != nil {
160134 return fmt .Errorf ("unable to get vGPU type number: %v" , err )
@@ -216,18 +190,24 @@ func (m *nvlibVGPUConfigManager) getVGPUTypeNumberforVFIO(filePath string, vgpuT
216190 return 0 , fmt .Errorf ("vGPU type %s not found in file %s" , vgpuTypeName , filePath )
217191}
218192
219- func (m * nvlibVGPUConfigManager ) IsUbuntu2404 () (bool , error ) {
220- // Read from the host's /etc/os-release (mounted at /host in the container)
193+ func (m * nvlibVGPUConfigManager ) IsVFIOEnabled () (bool , error ) {
194+ VFIOdistributions := map [string ]string {
195+ "ubuntu" : "24.04" ,
196+ "rhel" : "10" ,
197+ }
198+ // Read from the host's /etc/os-release (mounted at /host in the container)
221199 data , err := os .ReadFile ("/host/etc/os-release" )
222200 if err != nil {
223201 return false , fmt .Errorf ("unable to read host OS release info: %v" , err )
224202 }
225203
226204 content := string (data )
227- isUbuntu := strings .Contains (content , "ID=ubuntu" )
228- is2404 := strings .Contains (content , `VERSION_ID="24.04"` )
229-
230- return isUbuntu && is2404 , nil
205+ for distribution , version := range VFIOdistributions {
206+ if strings .Contains (content , distribution ) && strings .Contains (content , version ) {
207+ return true , nil
208+ }
209+ }
210+ return false , nil
231211}
232212
233213// GetVGPUConfig gets the 'VGPUConfig' currently applied to a GPU at a particular index
0 commit comments