Skip to content

Commit

Permalink
Make nvmllib a requried argument to devicelib
Browse files Browse the repository at this point in the history
In oder to ensure consistent usage, we add an explicit
argument for an nvml.Interface implementation to the
device.New constructor.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
  • Loading branch information
elezar committed May 27, 2024
1 parent 7604335 commit 67ba04e
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 20 deletions.
18 changes: 5 additions & 13 deletions pkg/nvlib/device/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Interface interface {
}

type devicelib struct {
nvml nvml.Interface
nvmllib nvml.Interface
skippedDevices map[string]struct{}
verifySymbols *bool
migProfiles []MigProfile
Expand All @@ -47,14 +47,13 @@ type devicelib struct {
var _ Interface = &devicelib{}

// New creates a new instance of the 'device' interface.
func New(opts ...Option) Interface {
d := &devicelib{}
func New(nvmllib nvml.Interface, opts ...Option) Interface {
d := &devicelib{
nvmllib: nvmllib,
}
for _, opt := range opts {
opt(d)
}
if d.nvml == nil {
d.nvml = nvml.New()
}
if d.verifySymbols == nil {
verify := true
d.verifySymbols = &verify
Expand All @@ -68,13 +67,6 @@ func New(opts ...Option) Interface {
return d
}

// WithNvml provides an Option to set the NVML library used by the 'device' interface.
func WithNvml(nvml nvml.Interface) Option {
return func(d *devicelib) {
d.nvml = nvml
}
}

// WithVerifySymbols provides an option to toggle whether to verify select symbols exist in dynamic libraries before calling them.
func WithVerifySymbols(verify bool) Option {
return func(d *devicelib) {
Expand Down
8 changes: 4 additions & 4 deletions pkg/nvlib/device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {

// NewDeviceByUUID builds a new Device from a UUID.
func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}
Expand Down Expand Up @@ -334,13 +334,13 @@ func (d *device) isSkipped() (bool, error) {

// VisitDevices visits each top-level device and invokes a callback function for it.
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
count, ret := d.nvml.DeviceGetCount()
count, ret := d.nvmllib.DeviceGetCount()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device count: %v", ret)
}

for i := 0; i < count; i++ {
device, ret := d.nvml.DeviceGetHandleByIndex(i)
device, ret := d.nvmllib.DeviceGetHandleByIndex(i)
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
}
Expand Down Expand Up @@ -469,5 +469,5 @@ func (d *devicelib) hasSymbol(symbol string) bool {
return true
}

return d.nvml.Extensions().LookupSymbol(symbol) == nil
return d.nvmllib.Extensions().LookupSymbol(symbol) == nil
}
2 changes: 1 addition & 1 deletion pkg/nvlib/device/mig_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {

// NewMigDeviceByUUID builds a new MigDevice from a UUID.
func (d *devicelib) NewMigDeviceByUUID(uuid string) (MigDevice, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/nvlib/device/mig_profile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func newMockDeviceLib() Interface {
},
}

return New(WithNvml(mockNvml), WithVerifySymbols(false))
return New(mockNvml, WithVerifySymbols(false))
}

func TestParseMigProfile(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/nvlib/info/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func New(opts ...Option) Interface {
)
}
if o.devicelib == nil {
o.devicelib = device.New(device.WithNvml(o.nvmllib))
o.devicelib = device.New(o.nvmllib)
}
if o.platform == "" {
o.platform = PlatformAuto
Expand Down

0 comments on commit 67ba04e

Please sign in to comment.