diff --git a/cmd/csi-attacher/main.go b/cmd/csi-attacher/main.go index 447e0444a8..f96380df11 100644 --- a/cmd/csi-attacher/main.go +++ b/cmd/csi-attacher/main.go @@ -103,21 +103,31 @@ func main() { os.Exit(1) } - // Find out if the driver supports attach/detach. - supportsAttach, err := csiConn.SupportsControllerPublish(ctx) + supportsService, err := csiConn.SupportsPluginControllerService(ctx) if err != nil { glog.Error(err.Error()) os.Exit(1) } - if !supportsAttach { + if !supportsService { handler = controller.NewTrivialHandler(clientset) - glog.V(2).Infof("CSI driver does not support ControllerPublishUnpublish, using trivial handler") + glog.V(2).Infof("CSI driver does not support Plugin Controller Service, using trivial handler") } else { - pvLister := factory.Core().V1().PersistentVolumes().Lister() - nodeLister := factory.Core().V1().Nodes().Lister() - vaLister := factory.Storage().V1alpha1().VolumeAttachments().Lister() - handler = controller.NewCSIHandler(clientset, attacher, csiConn, pvLister, nodeLister, vaLister) - glog.V(2).Infof("CSI driver supports ControllerPublishUnpublish, using real CSI handler") + // Find out if the driver supports attach/detach. + supportsAttach, err := csiConn.SupportsControllerPublish(ctx) + if err != nil { + glog.Error(err.Error()) + os.Exit(1) + } + if supportsAttach { + pvLister := factory.Core().V1().PersistentVolumes().Lister() + nodeLister := factory.Core().V1().Nodes().Lister() + vaLister := factory.Storage().V1alpha1().VolumeAttachments().Lister() + handler = controller.NewCSIHandler(clientset, attacher, csiConn, pvLister, nodeLister, vaLister) + glog.V(2).Infof("CSI driver supports ControllerPublishUnpublish, using real CSI handler") + } else { + handler = controller.NewTrivialHandler(clientset) + glog.V(2).Infof("CSI driver does not support ControllerPublishUnpublish, using trivial handler") + } } } diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index fc5bb98677..91c4555caf 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -42,6 +42,10 @@ type CSIConnection interface { // PUBLISH_UNPUBLISH_VOLUME in ControllerGetCapabilities() gRPC call. SupportsControllerPublish(ctx context.Context) (bool, error) + // SupportsPluginControllerService return true if the CSI driver reports + // CONTROLLER_SERVICE in GetPluginCapabilities() gRPC call. + SupportsPluginControllerService(ctx context.Context) (bool, error) + // Attach given volume to given node. Returns PublishVolumeInfo. Note that // "detached" is returned on error and means that the volume is for sure // detached from the node. "false" means that the volume may be either @@ -164,6 +168,30 @@ func (c *csiConnection) SupportsControllerPublish(ctx context.Context) (bool, er return false, nil } +func (c *csiConnection) SupportsPluginControllerService(ctx context.Context) (bool, error) { + client := csi.NewIdentityClient(c.conn) + req := csi.GetPluginCapabilitiesRequest{} + + rsp, err := client.GetPluginCapabilities(ctx, &req) + if err != nil { + return false, err + } + caps := rsp.GetCapabilities() + for _, cap := range caps { + if cap == nil { + continue + } + service := cap.GetService() + if service == nil { + continue + } + if service.GetType() == csi.PluginCapability_Service_CONTROLLER_SERVICE { + return true, nil + } + } + return false, nil +} + func (c *csiConnection) Attach(ctx context.Context, volumeID string, readOnly bool, nodeID string, caps *csi.VolumeCapability, attributes map[string]string) (metadata map[string]string, detached bool, err error) { client := csi.NewControllerClient(c.conn) diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go index af2dfe653a..15247f413f 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/connection/connection_test.go @@ -222,6 +222,107 @@ func TestSupportsControllerPublish(t *testing.T) { } } +func TestSupportsPluginControllerService(t *testing.T) { + tests := []struct { + name string + output *csi.GetPluginCapabilitiesResponse + injectError bool + expectError bool + }{ + { + name: "success", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, + }, + }, + }, + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_UNKNOWN, + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "no controller service", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_UNKNOWN, + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "empty capability", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{ + { + Type: nil, + }, + }, + }, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.GetPluginCapabilitiesResponse{ + Capabilities: []*csi.PluginCapability{}, + }, + expectError: false, + }, + } + + mockController, driver, identityServer, _, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + + in := &csi.GetPluginCapabilitiesRequest{} + + out := test.output + var injectedErr error = nil + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + // Setup expectation + identityServer.EXPECT().GetPluginCapabilities(gomock.Any(), in).Return(out, injectedErr).Times(1) + + _, err = csiConn.SupportsPluginControllerService(context.Background()) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + } +} + func TestAttach(t *testing.T) { defaultVolumeID := "myname" defaultNodeID := "MyNodeID" diff --git a/pkg/controller/framework_test.go b/pkg/controller/framework_test.go index cb69237bd5..d1593bd2f8 100644 --- a/pkg/controller/framework_test.go +++ b/pkg/controller/framework_test.go @@ -324,6 +324,10 @@ func (f *fakeCSIConnection) GetDriverName(ctx context.Context) (string, error) { return "", fmt.Errorf("Not implemented") } +func (f *fakeCSIConnection) SupportsPluginControllerService(ctx context.Context) (bool, error) { + return false, fmt.Errorf("Not implemented") +} + func (f *fakeCSIConnection) SupportsControllerPublish(ctx context.Context) (bool, error) { return false, fmt.Errorf("Not implemented") }