Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nvidia: support disabling the nvidia plugin #8353

Merged
merged 2 commits into from
Jul 21, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions devices/gpu/nvidia/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ const (
// notAvailable value is returned to nomad server in case some properties were
// undetected by nvml driver
notAvailable = "N/A"
)

const (
// Nvidia-container-runtime environment variable names
NvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES"
)
Expand Down Expand Up @@ -59,6 +57,10 @@ var (

// configSpec is the specification of the plugin's configuration
configSpec = hclspec.NewObject(map[string]*hclspec.Spec{
"enabled": hclspec.NewDefault(
hclspec.NewAttr("enabled", "bool", false),
hclspec.NewLiteral("true"),
),
"ignored_gpu_ids": hclspec.NewDefault(
hclspec.NewAttr("ignored_gpu_ids", "list(string)", false),
hclspec.NewLiteral("[]"),
Expand All @@ -68,16 +70,22 @@ var (
hclspec.NewLiteral("\"1m\""),
),
})

errDeviceNotEnabled = fmt.Errorf("Nvidia device is not enabled")
)

// Config contains configuration information for the plugin.
type Config struct {
Enabled bool `codec:"enabled"`
IgnoredGPUIDs []string `codec:"ignored_gpu_ids"`
FingerprintPeriod string `codec:"fingerprint_period"`
}

// NvidiaDevice contains all plugin specific data
type NvidiaDevice struct {
// enabled indicates whether the plugin should be enabled
enabled bool

// nvmlClient is used to get data from nvidia
nvmlClient nvml.NvmlClient

Expand Down Expand Up @@ -133,6 +141,8 @@ func (d *NvidiaDevice) SetConfig(cfg *base.Config) error {
}
}

d.enabled = config.Enabled

for _, ignoredGPUId := range config.IgnoredGPUIDs {
d.ignoredGPUIDs[ignoredGPUId] = struct{}{}
}
Expand All @@ -149,6 +159,10 @@ func (d *NvidiaDevice) SetConfig(cfg *base.Config) error {
// Fingerprint streams detected devices. If device changes are detected or the
// devices health changes, messages will be emitted.
func (d *NvidiaDevice) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) {
if !d.enabled {
return nil, errDeviceNotEnabled
notnoop marked this conversation as resolved.
Show resolved Hide resolved
}

outCh := make(chan *device.FingerprintResponse)
go d.fingerprint(ctx, outCh)
return outCh, nil
Expand All @@ -169,6 +183,10 @@ func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation
if len(deviceIDs) == 0 {
return &device.ContainerReservation{}, nil
}
if !d.enabled {
return nil, errDeviceNotEnabled
}

// Due to the asynchronous nature of NvidiaPlugin, there is a possibility
// of race condition
//
Expand Down Expand Up @@ -202,6 +220,10 @@ func (d *NvidiaDevice) Reserve(deviceIDs []string) (*device.ContainerReservation

// Stats streams statistics for the detected devices.
func (d *NvidiaDevice) Stats(ctx context.Context, interval time.Duration) (<-chan *device.StatsResponse, error) {
if !d.enabled {
return nil, errDeviceNotEnabled
}

outCh := make(chan *device.StatsResponse)
go d.stats(ctx, outCh, interval)
return outCh, nil
Expand Down
46 changes: 36 additions & 10 deletions devices/gpu/nvidia/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) {
}

func TestReserve(t *testing.T) {
for _, testCase := range []struct {
cases := []struct {
Name string
ExpectedReservation *device.ContainerReservation
ExpectedError error
Expand All @@ -47,7 +47,8 @@ func TestReserve(t *testing.T) {
"UUID3",
},
Device: &NvidiaDevice{
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Expand All @@ -66,7 +67,8 @@ func TestReserve(t *testing.T) {
devices: map[string]struct{}{
"UUID3": {},
},
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Expand All @@ -88,7 +90,8 @@ func TestReserve(t *testing.T) {
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
{
Expand All @@ -102,13 +105,36 @@ func TestReserve(t *testing.T) {
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
logger: hclog.NewNullLogger(),
enabled: true,
},
},
} {
actualReservation, actualError := testCase.Device.Reserve(testCase.RequestedIDs)
req := require.New(t)
req.Equal(testCase.ExpectedReservation, actualReservation)
req.Equal(testCase.ExpectedError, actualError)
{
Name: "Device is disabled",
ExpectedReservation: nil,
ExpectedError: errDeviceNotEnabled,
RequestedIDs: []string{
"UUID1",
"UUID2",
"UUID3",
},
Device: &NvidiaDevice{
devices: map[string]struct{}{
"UUID1": {},
"UUID2": {},
"UUID3": {},
},
logger: hclog.NewNullLogger(),
enabled: false,
},
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
actualReservation, actualError := c.Device.Reserve(c.RequestedIDs)
require.Equal(t, c.ExpectedReservation, actualReservation)
require.Equal(t, c.ExpectedError, actualError)
})
}
}