Skip to content

Commit

Permalink
Merge pull request #39 from elezar/make-nvmllib-requried
Browse files Browse the repository at this point in the history
Make nvmllib a requried argument to devicelib
  • Loading branch information
elezar authored May 27, 2024
2 parents 7604335 + 67ba04e commit 4663406
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 20 deletions.
18 changes: 5 additions & 13 deletions pkg/nvlib/device/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Interface interface {
}

type devicelib struct {
nvml nvml.Interface
nvmllib nvml.Interface
skippedDevices map[string]struct{}
verifySymbols *bool
migProfiles []MigProfile
Expand All @@ -47,14 +47,13 @@ type devicelib struct {
var _ Interface = &devicelib{}

// New creates a new instance of the 'device' interface.
func New(opts ...Option) Interface {
d := &devicelib{}
func New(nvmllib nvml.Interface, opts ...Option) Interface {
d := &devicelib{
nvmllib: nvmllib,
}
for _, opt := range opts {
opt(d)
}
if d.nvml == nil {
d.nvml = nvml.New()
}
if d.verifySymbols == nil {
verify := true
d.verifySymbols = &verify
Expand All @@ -68,13 +67,6 @@ func New(opts ...Option) Interface {
return d
}

// WithNvml provides an Option to set the NVML library used by the 'device' interface.
func WithNvml(nvml nvml.Interface) Option {
return func(d *devicelib) {
d.nvml = nvml
}
}

// WithVerifySymbols provides an option to toggle whether to verify select symbols exist in dynamic libraries before calling them.
func WithVerifySymbols(verify bool) Option {
return func(d *devicelib) {
Expand Down
8 changes: 4 additions & 4 deletions pkg/nvlib/device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {

// NewDeviceByUUID builds a new Device from a UUID.
func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}
Expand Down Expand Up @@ -334,13 +334,13 @@ func (d *device) isSkipped() (bool, error) {

// VisitDevices visits each top-level device and invokes a callback function for it.
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
count, ret := d.nvml.DeviceGetCount()
count, ret := d.nvmllib.DeviceGetCount()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device count: %v", ret)
}

for i := 0; i < count; i++ {
device, ret := d.nvml.DeviceGetHandleByIndex(i)
device, ret := d.nvmllib.DeviceGetHandleByIndex(i)
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
}
Expand Down Expand Up @@ -469,5 +469,5 @@ func (d *devicelib) hasSymbol(symbol string) bool {
return true
}

return d.nvml.Extensions().LookupSymbol(symbol) == nil
return d.nvmllib.Extensions().LookupSymbol(symbol) == nil
}
2 changes: 1 addition & 1 deletion pkg/nvlib/device/mig_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {

// NewMigDeviceByUUID builds a new MigDevice from a UUID.
func (d *devicelib) NewMigDeviceByUUID(uuid string) (MigDevice, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/nvlib/device/mig_profile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func newMockDeviceLib() Interface {
},
}

return New(WithNvml(mockNvml), WithVerifySymbols(false))
return New(mockNvml, WithVerifySymbols(false))
}

func TestParseMigProfile(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/nvlib/info/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func New(opts ...Option) Interface {
)
}
if o.devicelib == nil {
o.devicelib = device.New(device.WithNvml(o.nvmllib))
o.devicelib = device.New(o.nvmllib)
}
if o.platform == "" {
o.platform = PlatformAuto
Expand Down

0 comments on commit 4663406

Please sign in to comment.