Skip to content

Commit 3f31828

Browse files
committed
updated setVGPUConfig
Signed-off-by: Arjun <[email protected]>
1 parent b17d50a commit 3f31828

File tree

3 files changed

+49
-69
lines changed

3 files changed

+49
-69
lines changed

cmd/nvidia-vgpu-dm/apply/config.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ import (
2929
func VGPUConfig(c *Context) error {
3030
return assert.WalkSelectedVGPUConfigForEachGPU(c.VGPUConfig, func(vc *v1.VGPUConfigSpec, i int, d types.DeviceID) error {
3131
configManager := vgpu.NewNvlibVGPUConfigManager()
32-
IsUbuntu2404, err := configManager.IsUbuntu2404()
32+
IsVFIOEnabled, err := configManager.IsVFIOEnabled()
3333
if err != nil {
3434
return fmt.Errorf("error checking if Ubuntu 24.04: %v", err)
3535
}
3636

3737
var current types.VGPUConfig
38-
if IsUbuntu2404 {
38+
if IsVFIOEnabled {
3939
current, err = configManager.GetVGPUConfigforVFIO(i)
4040
if err != nil {
4141
return fmt.Errorf("error getting VGPU config for VFIO: %v", err)
@@ -53,7 +53,7 @@ func VGPUConfig(c *Context) error {
5353
}
5454

5555
log.Debugf(" Updating vGPU config: %v", vc.VGPUDevices)
56-
if IsUbuntu2404 {
56+
if IsVFIOEnabled {
5757
err = configManager.SetVGPUConfigforVFIO(i, vc.VGPUDevices)
5858
if err != nil {
5959
return fmt.Errorf("error setting VGPU config for VFIO: %v", err)

cmd/nvidia-vgpu-dm/assert/config.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ func VGPUConfig(c *Context) error {
3737
matched := make([]bool, len(gpus))
3838
err = WalkSelectedVGPUConfigForEachGPU(c.VGPUConfig, func(vc *v1.VGPUConfigSpec, i int, d types.DeviceID) error {
3939
configManager := vgpu.NewNvlibVGPUConfigManager()
40-
IsUbuntu2404, err := configManager.IsUbuntu2404()
40+
IsVFIOEnabled, err := configManager.IsVFIOEnabled()
4141
if err != nil {
42-
return fmt.Errorf("error checking if Ubuntu 24.04: %v", err)
42+
return fmt.Errorf("error checking if VFIO is enabled: %v", err)
4343
}
4444

4545
var current types.VGPUConfig
46-
if IsUbuntu2404 {
46+
if IsVFIOEnabled {
4747
current, err = configManager.GetVGPUConfigforVFIO(i)
4848
if err != nil {
4949
return fmt.Errorf("error getting VGPU config for VFIO: %v", err)

pkg/vgpu/config.go

Lines changed: 43 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,19 @@ import (
2828
"github.com/google/uuid"
2929

3030
"github.com/NVIDIA/vgpu-device-manager/internal/nvlib"
31-
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
3231
"github.com/NVIDIA/vgpu-device-manager/pkg/types"
3332
)
3433

3534
const (
36-
PCIDevicesRoot = "/sys/bus/pci/devices"
35+
// HostPCIDevicesRoot is the sysfs directory listing PCI devices, addressed
// through the "/host" prefix. Per-VF paths are built from it as
// <root>/<GPU address>/virtfn<N>/nvidia.
// NOTE(review): presumably "/host" is where the host filesystem is mounted
// inside the container — confirm against the deployment manifest.
HostPCIDevicesRoot = "/host/sys/bus/pci/devices"
3736
)
3837

3938
// Manager represents a set of functions for managing vGPU configurations on a node
4039
type Manager interface {
4140
// GetVGPUConfig gets the 'VGPUConfig' currently applied to the GPU at
// index gpu.
GetVGPUConfig(gpu int) (types.VGPUConfig, error)
4241
// SetVGPUConfig presumably applies config to the GPU at index gpu; the
// implementation is not visible here — TODO confirm.
SetVGPUConfig(gpu int, config types.VGPUConfig) error
4342
// ClearVGPUConfig presumably removes the vGPU configuration from the GPU
// at index gpu; the implementation is not visible here — TODO confirm.
ClearVGPUConfig(gpu int) error
44-
IsUbuntu2404() (bool, error)
43+
// IsVFIOEnabled reports, based on the host's /etc/os-release, whether the
// host distribution is one for which callers use the *forVFIO methods
// below instead of GetVGPUConfig/SetVGPUConfig.
IsVFIOEnabled() (bool, error)
4544
// GetVGPUConfigforVFIO builds the 'VGPUConfig' of the GPU at index gpu by
// reading the current vGPU type of each of its virtual functions from
// sysfs.
GetVGPUConfigforVFIO(gpu int) (types.VGPUConfig, error)
4645
// SetVGPUConfigforVFIO is the VFIO counterpart of SetVGPUConfig.
// NOTE(review): only part of the implementation is visible; it runs
// sriov-manage for the GPU and then walks its virtual functions.
SetVGPUConfigforVFIO(gpu int, config types.VGPUConfig) error
4746
}
@@ -57,62 +56,42 @@ func NewNvlibVGPUConfigManager() Manager {
5756
return &nvlibVGPUConfigManager{nvlib.New()}
5857
}
5958

60-
func (m *nvlibVGPUConfigManager) GetAllNvidiaGPUDevices() ([]*nvpci.NvidiaPCIDevice, error) {
61-
var nvdevices []*nvpci.NvidiaPCIDevice
62-
deviceDirs, err := os.ReadDir(PCIDevicesRoot)
63-
if err != nil {
64-
return nil, fmt.Errorf("unable to read parent PCI bus devices: %v", err)
65-
}
66-
for _, deviceDir := range deviceDirs {
67-
deviceAddress := deviceDir.Name()
68-
nvdevice, err := m.nvlib.Nvpci.GetGPUByPciBusID(deviceAddress)
69-
if err != nil || nvdevice == nil {
70-
continue
71-
}
72-
if nvdevice.IsGPU() {
73-
nvdevices = append(nvdevices, nvdevice)
74-
}
75-
}
76-
return nvdevices, nil
77-
}
78-
7959
func (m *nvlibVGPUConfigManager) GetVGPUConfigforVFIO(gpu int) (types.VGPUConfig, error) {
8060
nvdevice, err := m.nvlib.Nvpci.GetGPUByIndex(gpu)
8161
if err != nil {
8262
return nil, fmt.Errorf("unable to get GPU by index %d: %v", gpu, err)
8363
}
84-
GPUDevices, err := m.nvlib.Nvpci.GetGPUs()
85-
if err != nil {
86-
return nil, fmt.Errorf("unable to get all NVIDIA GPU devices: %v", err)
87-
}
88-
8964
vgpuConfig := types.VGPUConfig{}
90-
for _, device := range GPUDevices {
91-
if device.Address == nvdevice.Address {
92-
VFnum := 0
93-
totalVF := int(nvdevice.SriovInfo.PhysicalFunction.TotalVFs)
94-
for VFnum < totalVF {
95-
VFAddr := PCIDevicesRoot + "/" + device.Address + "/virtfn" + strconv.Itoa(VFnum) + "/nvidia"
96-
if _, err := os.Stat(VFAddr); err == nil {
97-
VGPUTypeNumberBytes, err := os.ReadFile(VFAddr + "/current_vgpu_type")
98-
if err != nil {
99-
return nil, fmt.Errorf("unable to read current vGPU type: %v", err)
100-
}
101-
VGPUTypeNumber, err := strconv.Atoi(string(VGPUTypeNumberBytes))
102-
if err != nil {
103-
return nil, fmt.Errorf("unable to convert current vGPU type to int: %v", err)
104-
}
105-
VGPUTypeName, err := m.getVGPUTypeNameforVFIO(VFAddr + "/creatable_vgpu_types", VGPUTypeNumber)
106-
if err != nil {
107-
return nil, fmt.Errorf("unable to get vGPU type name: %v", err)
108-
}
109-
vgpuConfig[VGPUTypeName]++
110-
}
111-
VFnum++
112-
}
65+
VFnum := 0
66+
if nvdevice.SriovInfo.PhysicalFunction == nil {
67+
return vgpuConfig, nil
68+
}
69+
totalVF := int(nvdevice.SriovInfo.PhysicalFunction.TotalVFs)
70+
for VFnum < totalVF {
71+
VFAddr := HostPCIDevicesRoot + "/" + nvdevice.Address + "/virtfn" + strconv.Itoa(VFnum) + "/nvidia"
72+
if _, err := os.Stat(VFAddr); err != nil {
73+
VFnum++
74+
continue
75+
}
76+
VGPUTypeNumberBytes, err := os.ReadFile(VFAddr + "/current_vgpu_type")
77+
if err != nil {
78+
return nil, fmt.Errorf("unable to read current vGPU type: %v", err)
79+
}
80+
VGPUTypeNumber, err := strconv.Atoi(strings.TrimSpace(string(VGPUTypeNumberBytes)))
81+
if err != nil {
82+
return nil, fmt.Errorf("unable to convert current vGPU type to int: %v", err)
11383
}
84+
if VGPUTypeNumber == 0 {
85+
VFnum++
86+
continue
87+
}
88+
VGPUTypeName, err := m.getVGPUTypeNameforVFIO(VFAddr + "/creatable_vgpu_types", VGPUTypeNumber)
89+
if err != nil {
90+
return nil, fmt.Errorf("unable to get vGPU type name: %v", err)
91+
}
92+
vgpuConfig[VGPUTypeName]++
93+
VFnum++
11494
}
115-
11695
return vgpuConfig, nil
11796
}
11897

@@ -139,11 +118,6 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfigforVFIO(gpu int, config types.VGPU
139118
return fmt.Errorf("GPU at index %d not found in available NVIDIA devices", gpu)
140119
}
141120

142-
err = m.ClearVGPUConfig(gpu)
143-
if err != nil {
144-
return fmt.Errorf("error clearing VGPUConfig: %v", err)
145-
}
146-
147121
cmd := exec.Command("chroot", "/host", "/run/nvidia/driver/usr/lib/nvidia/sriov-manage", "-e", nvdevice.Address)
148122
output, err := cmd.CombinedOutput()
149123
if err != nil {
@@ -154,7 +128,7 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfigforVFIO(gpu int, config types.VGPU
154128
remainingToCreate := val
155129
VFnum := 0
156130
for remainingToCreate > 0 {
157-
VFAddr := PCIDevicesRoot + "/" + nvdevice.Address + "/virtfn" + strconv.Itoa(VFnum) + "/nvidia"
131+
VFAddr := HostPCIDevicesRoot + "/" + nvdevice.Address + "/virtfn" + strconv.Itoa(VFnum) + "/nvidia"
158132
number, err := m.getVGPUTypeNumberforVFIO(VFAddr + "/creatable_vgpu_types", key)
159133
if err != nil {
160134
return fmt.Errorf("unable to get vGPU type number: %v", err)
@@ -216,18 +190,24 @@ func (m *nvlibVGPUConfigManager) getVGPUTypeNumberforVFIO(filePath string, vgpuT
216190
return 0, fmt.Errorf("vGPU type %s not found in file %s", vgpuTypeName, filePath)
217191
}
218192

219-
func (m *nvlibVGPUConfigManager) IsUbuntu2404() (bool, error) {
220-
// Read from the host's /etc/os-release (mounted at /host in the container)
193+
func (m *nvlibVGPUConfigManager) IsVFIOEnabled() (bool, error) {
194+
VFIOdistributions := map[string]string{
195+
"ubuntu": "24.04",
196+
"rhel": "10",
197+
}
198+
// Read from the host's /etc/os-release (mounted at /host in the container)
221199
data, err := os.ReadFile("/host/etc/os-release")
222200
if err != nil {
223201
return false, fmt.Errorf("unable to read host OS release info: %v", err)
224202
}
225203

226204
content := string(data)
227-
isUbuntu := strings.Contains(content, "ID=ubuntu")
228-
is2404 := strings.Contains(content, `VERSION_ID="24.04"`)
229-
230-
return isUbuntu && is2404, nil
205+
for distribution, version := range VFIOdistributions {
206+
if strings.Contains(content, distribution) && strings.Contains(content, version) {
207+
return true, nil
208+
}
209+
}
210+
return false, nil
231211
}
232212

233213
// GetVGPUConfig gets the 'VGPUConfig' currently applied to a GPU at a particular index

0 commit comments

Comments
 (0)