From 7798ce6c33b0ec6288606d7e72987e204d1daeee Mon Sep 17 00:00:00 2001 From: Jan Safranek Date: Tue, 5 Mar 2019 16:41:53 +0100 Subject: [PATCH] Cache driver capabilities --- cmd/csi-provisioner/csi-provisioner.go | 19 +- pkg/controller/controller.go | 161 ++++-------- pkg/controller/controller_test.go | 325 +++---------------------- 3 files changed, 105 insertions(+), 400 deletions(-) diff --git a/cmd/csi-provisioner/csi-provisioner.go b/cmd/csi-provisioner/csi-provisioner.go index 9630d7b12c..964175cf8e 100644 --- a/cmd/csi-provisioner/csi-provisioner.go +++ b/cmd/csi-provisioner/csi-provisioner.go @@ -25,10 +25,9 @@ import ( "strings" "time" - "k8s.io/klog" - flag "github.com/spf13/pflag" + "github.com/container-storage-interface/spec/lib/go/csi" ctrl "github.com/kubernetes-csi/external-provisioner/pkg/controller" snapclientset "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned" "github.com/kubernetes-sigs/sig-storage-lib-external-provisioner/controller" @@ -39,6 +38,7 @@ import ( "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/util/workqueue" + "k8s.io/klog" utilfeature "k8s.io/apiserver/pkg/util/feature" utilflag "k8s.io/apiserver/pkg/util/flag" @@ -149,13 +149,26 @@ func init() { } klog.V(2).Infof("Detected CSI driver %s", provisionerName) + pluginCapabilities, controllerCapabilities, err := ctrl.GetDriverCapabilities(grpcClient, *operationTimeout) + if err != nil { + klog.Fatalf("Error getting CSI driver capabilities: %s", err) + } + + if !pluginCapabilities[csi.PluginCapability_Service_CONTROLLER_SERVICE] { + klog.Fatalf("CSI driver does not support dynamic provisioning: plugin CONTROLLER_SERVICE capability is not reported") + } + + if !controllerCapabilities[csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME] { + klog.Fatalf("CSI driver does not support dynamic provisioning: controller CREATE_DELETE_VOLUME capability is not reported") + } + // Generate a unique ID for this provisioner timeStamp := time.Now().UnixNano() / int64(time.Millisecond) identity := strconv.FormatInt(timeStamp, 10) + "-" + strconv.Itoa(rand.Intn(10000)) + "-" + provisionerName // Create the provisioner: it implements the Provisioner interface expected by // the controller - csiProvisioner := ctrl.NewCSIProvisioner(clientset, csiAPIClient, *operationTimeout, identity, *volumeNamePrefix, *volumeNameUUIDLength, grpcClient, snapClient, provisionerName) + csiProvisioner := ctrl.NewCSIProvisioner(clientset, csiAPIClient, *operationTimeout, identity, *volumeNamePrefix, *volumeNameUUIDLength, grpcClient, snapClient, provisionerName, pluginCapabilities, controllerCapabilities) provisionController = controller.NewProvisionController( clientset, provisionerName, diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 170ccf03f8..3d74b29054 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -144,26 +144,21 @@ var ( // CSIProvisioner struct type csiProvisioner struct { - client kubernetes.Interface - csiClient csi.ControllerClient - csiAPIClient csiclientset.Interface - grpcClient *grpc.ClientConn - snapshotClient snapclientset.Interface - timeout time.Duration - identity string - volumeNamePrefix string - volumeNameUUIDLength int - config *rest.Config - driverName string + client kubernetes.Interface + csiClient csi.ControllerClient + csiAPIClient csiclientset.Interface + grpcClient *grpc.ClientConn + snapshotClient snapclientset.Interface + timeout time.Duration + identity string + volumeNamePrefix string + volumeNameUUIDLength int + config *rest.Config + driverName string + pluginCapabilities connection.PluginCapabilitySet + controllerCapabilities connection.ControllerCapabilitySet } -const ( - PluginCapability_CONTROLLER_SERVICE = iota - PluginCapability_ACCESSIBILITY_CONSTRAINTS - ControllerCapability_CREATE_DELETE_VOLUME - ControllerCapability_CREATE_DELETE_SNAPSHOT -) - var _ controller.Provisioner = &csiProvisioner{} var _ controller.BlockProvisioner = &csiProvisioner{} @@ -198,47 +193,23 @@ func GetDriverName(conn *grpc.ClientConn, timeout time.Duration) (string, error) return connection.GetDriverName(ctx, conn) } -func getDriverCapabilities(conn *grpc.ClientConn, timeout time.Duration) (sets.Int, error) { - pluginCaps, err := getPluginCapabilities(conn, timeout) +func GetDriverCapabilities(conn *grpc.ClientConn, timeout time.Duration) (connection.PluginCapabilitySet, connection.ControllerCapabilitySet, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + pluginCapabilities, err := connection.GetPluginCapabilities(ctx, conn) if err != nil { - return nil, err + return nil, nil, err } - controllerCaps, err := getControllerCapabilities(conn, timeout) + /* Each CSI operation gets its own timeout / context */ + ctx, cancel = context.WithTimeout(context.Background(), timeout) + defer cancel() + controllerCapabilities, err := connection.GetControllerCapabilities(ctx, conn) if err != nil { - return nil, err + return nil, nil, err } - capabilities := make(sets.Int) - for cap := range pluginCaps { - switch cap { - case csi.PluginCapability_Service_CONTROLLER_SERVICE: - capabilities.Insert(PluginCapability_CONTROLLER_SERVICE) - case csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS: - capabilities.Insert(PluginCapability_ACCESSIBILITY_CONSTRAINTS) - } - } - for cap := range controllerCaps { - switch cap { - case csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: - capabilities.Insert(ControllerCapability_CREATE_DELETE_VOLUME) - case csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT: - capabilities.Insert(ControllerCapability_CREATE_DELETE_SNAPSHOT) - } - } - return capabilities, nil -} - -func getPluginCapabilities(conn *grpc.ClientConn, timeout time.Duration) (connection.PluginCapabilitySet, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return connection.GetPluginCapabilities(ctx, conn) -} - -func getControllerCapabilities(conn *grpc.ClientConn, timeout time.Duration) (connection.ControllerCapabilitySet, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return connection.GetControllerCapabilities(ctx, conn) + return pluginCapabilities, controllerCapabilities, nil } // NewCSIProvisioner creates new CSI provisioner @@ -250,55 +221,28 @@ func NewCSIProvisioner(client kubernetes.Interface, volumeNameUUIDLength int, grpcClient *grpc.ClientConn, snapshotClient snapclientset.Interface, - driverName string) controller.Provisioner { + driverName string, + pluginCapabilities connection.PluginCapabilitySet, + controllerCapabilities connection.ControllerCapabilitySet) controller.Provisioner { csiClient := csi.NewControllerClient(grpcClient) provisioner := &csiProvisioner{ - client: client, - grpcClient: grpcClient, - csiClient: csiClient, - csiAPIClient: csiAPIClient, - snapshotClient: snapshotClient, - timeout: connectionTimeout, - identity: identity, - volumeNamePrefix: volumeNamePrefix, - volumeNameUUIDLength: volumeNameUUIDLength, - driverName: driverName, + client: client, + grpcClient: grpcClient, + csiClient: csiClient, + csiAPIClient: csiAPIClient, + snapshotClient: snapshotClient, + timeout: connectionTimeout, + identity: identity, + volumeNamePrefix: volumeNamePrefix, + volumeNameUUIDLength: volumeNameUUIDLength, + driverName: driverName, + pluginCapabilities: pluginCapabilities, + controllerCapabilities: controllerCapabilities, } return provisioner } -// This function get called before any attempt to communicate with the driver. -// Before initiating Create/Delete API calls provisioner checks if Capabilities: -// PluginControllerService, ControllerCreateVolume sre supported and gets the driver name. -func checkDriverCapabilities(grpcClient *grpc.ClientConn, timeout time.Duration, needSnapshotSupport bool) (sets.Int, error) { - capabilities, err := getDriverCapabilities(grpcClient, timeout) - if err != nil { - return nil, fmt.Errorf("failed to get capabilities: %v", err) - } - - if !capabilities.Has(PluginCapability_CONTROLLER_SERVICE) { - return nil, fmt.Errorf("no plugin controller service support detected") - } - - if !capabilities.Has(ControllerCapability_CREATE_DELETE_VOLUME) { - return nil, fmt.Errorf("no create/delete volume support detected") - } - - // If PVC.Spec.DataSource is not nil, it indicates the request is to create volume - // from snapshot and therefore we should check for snapshot support; - // otherwise we don't need to check for snapshot support. - if needSnapshotSupport { - // Check whether plugin supports create snapshot - // If not, create volume from snapshot cannot proceed - if !capabilities.Has(ControllerCapability_CREATE_DELETE_SNAPSHOT) { - return nil, fmt.Errorf("no create/delete snapshot support detected. Cannot create volume from snapshot") - } - } - - return capabilities, nil -} - func makeVolumeName(prefix, pvcUID string, volumeNameUUIDLength int) (string, error) { // create persistent name based on a volumeNamePrefix and volumeNameUUIDLength // of PVC's UID @@ -386,12 +330,13 @@ func (p *csiProvisioner) Provision(options controller.VolumeOptions) (*v1.Persis if *(options.PVC.Spec.DataSource.APIGroup) != snapshotAPIGroup { return nil, fmt.Errorf("the PVC source does not belong to the right APIGroup. Expected %s, Got %s", snapshotAPIGroup, *(options.PVC.Spec.DataSource.APIGroup)) } + + // Snapshot support is requested, check it + if !p.controllerCapabilities[csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT] { + return nil, fmt.Errorf("no create/delete snapshot support detected. Cannot create volume from snapshot") + } needSnapshotSupport = true } - capabilities, err := checkDriverCapabilities(p.grpcClient, p.timeout, needSnapshotSupport) - if err != nil { - return nil, err - } pvName, err := makeVolumeName(p.volumeNamePrefix, fmt.Sprintf("%s", options.PVC.ObjectMeta.UID), p.volumeNameUUIDLength) if err != nil { @@ -443,8 +388,7 @@ func (p *csiProvisioner) Provision(options controller.VolumeOptions) (*v1.Persis req.VolumeContentSource = volumeContentSource } - if capabilities.Has(PluginCapability_ACCESSIBILITY_CONSTRAINTS) && - utilfeature.DefaultFeatureGate.Enabled(features.Topology) { + if p.supportsTopology() { requirements, err := GenerateAccessibilityRequirements( p.client, p.csiAPIClient, @@ -549,8 +493,7 @@ func (p *csiProvisioner) Provision(options controller.VolumeOptions) (*v1.Persis }, } - if capabilities.Has(PluginCapability_ACCESSIBILITY_CONSTRAINTS) && - utilfeature.DefaultFeatureGate.Enabled(features.Topology) { + if p.supportsTopology() { pv.Spec.NodeAffinity = GenerateVolumeNodeAffinity(rep.Volume.AccessibleTopology) } @@ -568,6 +511,11 @@ func (p *csiProvisioner) Provision(options controller.VolumeOptions) (*v1.Persis return pv, nil } +func (p *csiProvisioner) supportsTopology() bool { + return p.pluginCapabilities[csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS] && + utilfeature.DefaultFeatureGate.Enabled(features.Topology) +} + func removePrefixedParameters(param map[string]string) (map[string]string, error) { newParam := map[string]string{} for k, v := range param { @@ -656,11 +604,6 @@ func (p *csiProvisioner) Delete(volume *v1.PersistentVolume) error { } volumeId := p.volumeHandleToId(volume.Spec.CSI.VolumeHandle) - _, err := checkDriverCapabilities(p.grpcClient, p.timeout, false) - if err != nil { - return err - } - req := csi.DeleteVolumeRequest{ VolumeId: volumeId, } @@ -685,7 +628,7 @@ func (p *csiProvisioner) Delete(volume *v1.PersistentVolume) error { ctx, cancel := context.WithTimeout(context.Background(), p.timeout) defer cancel() - _, err = p.csiClient.DeleteVolume(ctx, &req) + _, err := p.csiClient.DeleteVolume(ctx, &req) return err } diff --git a/pkg/controller/controller_test.go b/pkg/controller/controller_test.go index a066ccb2dc..7e2121470a 100644 --- a/pkg/controller/controller_test.go +++ b/pkg/controller/controller_test.go @@ -18,7 +18,6 @@ package controller import ( "context" - "errors" "fmt" "io/ioutil" "os" @@ -37,8 +36,8 @@ import ( "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned/fake" "github.com/kubernetes-sigs/sig-storage-lib-external-provisioner/controller" "google.golang.org/grpc" + "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1" - v1 "k8s.io/api/core/v1" storage "k8s.io/api/storage/v1beta1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -267,170 +266,6 @@ func TestStripPrefixedCSIParams(t *testing.T) { } } -func TestGetDriverCapabilities(t *testing.T) { - type testcase struct { - name string - pluginCapabilities []*csi.PluginCapability_Service_Type - controllerCapabilities []*csi.ControllerServiceCapability_RPC_Type - injectPluginError bool - injectControllerError bool - expectError bool - } - tests := []testcase{{}} - - // Generate test cases by creating all possible combination of capabilities - for capName, capValue := range csi.PluginCapability_Service_Type_value { - cap := csi.PluginCapability_Service_Type(capValue) - var newTests []testcase - for _, test := range tests { - newTest := testcase{ - name: fmt.Sprintf("%s,Plugin_%s", test.name, capName), - } - copy(newTest.pluginCapabilities, append(test.pluginCapabilities, &cap)) - copy(newTest.controllerCapabilities, test.controllerCapabilities) - newTests = append(newTests, newTest) - } - tests = newTests - } - for capName, capValue := range csi.ControllerServiceCapability_RPC_Type_value { - cap := csi.ControllerServiceCapability_RPC_Type(capValue) - var newTests []testcase - for _, test := range tests { - newTest := testcase{ - name: fmt.Sprintf("%s,Plugin_%s", test.name, capName), - } - copy(newTest.pluginCapabilities, test.pluginCapabilities) - copy(newTest.controllerCapabilities, append(test.controllerCapabilities, &cap)) - newTests = append(newTests, newTest) - } - tests = newTests - } - - // nil capabilities tests - dummyPluginCap := csi.PluginCapability_Service_CONTROLLER_SERVICE - dummyControllerCap := csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME - tests = append(tests, []testcase{ - { - name: "plugin capabilities with nil entries", - pluginCapabilities: []*csi.PluginCapability_Service_Type{nil}, - controllerCapabilities: []*csi.ControllerServiceCapability_RPC_Type{&dummyControllerCap}, - }, - { - name: "controller capabilities with nil entries", - pluginCapabilities: []*csi.PluginCapability_Service_Type{&dummyPluginCap}, - controllerCapabilities: []*csi.ControllerServiceCapability_RPC_Type{nil}, - }, - }...) - - // gRPC errors - tests = append(tests, []testcase{ - { - name: "plugin capabilities call with gRPC error", - pluginCapabilities: []*csi.PluginCapability_Service_Type{&dummyPluginCap}, - controllerCapabilities: []*csi.ControllerServiceCapability_RPC_Type{&dummyControllerCap}, - injectPluginError: true, - expectError: true, - }, - { - name: "controller capabilities call with gRPC error", - pluginCapabilities: []*csi.PluginCapability_Service_Type{&dummyPluginCap}, - controllerCapabilities: []*csi.ControllerServiceCapability_RPC_Type{&dummyControllerCap}, - injectControllerError: true, - expectError: true, - }, - }...) - - tmpdir := tempDir(t) - defer os.RemoveAll(tmpdir) - mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir) - if err != nil { - t.Fatal(err) - } - defer mockController.Finish() - defer driver.Stop() - for _, test := range tests { - - var injectedPluginErr, injectedControllerErr error - if test.injectPluginError { - injectedPluginErr = fmt.Errorf("mock error") - } - if test.injectControllerError { - injectedControllerErr = fmt.Errorf("mock error") - } - - var pluginCaps []*csi.PluginCapability - for _, cap := range test.pluginCapabilities { - var c *csi.PluginCapability - if cap == nil { - c = &csi.PluginCapability{Type: nil} - } else { - c = &csi.PluginCapability{ - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: *cap, - }, - }, - } - } - pluginCaps = append(pluginCaps, c) - } - pluginResponse := &csi.GetPluginCapabilitiesResponse{Capabilities: pluginCaps} - - var controllerCaps []*csi.ControllerServiceCapability - for _, cap := range test.controllerCapabilities { - var c *csi.ControllerServiceCapability - if cap == nil { - c = &csi.ControllerServiceCapability{Type: nil} - } else { - c = &csi.ControllerServiceCapability{ - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: *cap, - }, - }, - } - } - controllerCaps = append(controllerCaps, c) - } - controllerResponse := &csi.ControllerGetCapabilitiesResponse{Capabilities: controllerCaps} - - identityServer.EXPECT().GetPluginCapabilities(gomock.Any(), &csi.GetPluginCapabilitiesRequest{}).Return(pluginResponse, injectedPluginErr).Times(1) - controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), &csi.ControllerGetCapabilitiesRequest{}).Return(controllerResponse, injectedControllerErr).MinTimes(0).MaxTimes(1) - - capabilities, err := getDriverCapabilities(csiConn.conn, timeout) - if err != nil && !test.expectError { - t.Errorf("test %q failed with error: %v\n", test.name, err) - } - if err == nil { - ok := true - for _, cap := range test.pluginCapabilities { - if cap != nil { - switch *cap { - case csi.PluginCapability_Service_CONTROLLER_SERVICE: - ok = ok && capabilities.Has(PluginCapability_CONTROLLER_SERVICE) - case csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS: - ok = ok && capabilities.Has(PluginCapability_ACCESSIBILITY_CONSTRAINTS) - } - } - } - for _, cap := range test.controllerCapabilities { - if cap != nil { - switch *cap { - case csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: - ok = ok && capabilities.Has(ControllerCapability_CREATE_DELETE_VOLUME) - case csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT: - ok = ok && capabilities.Has(ControllerCapability_CREATE_DELETE_SNAPSHOT) - } - } - } - - if !ok { - t.Errorf("test %q: missing capabilities", test.name) - } - } - } -} - func TestGetDriverName(t *testing.T) { tests := []struct { name string @@ -548,14 +383,15 @@ func TestCreateDriverReturnsInvalidCapacityDuringProvision(t *testing.T) { tmpdir := tempDir(t) defer os.RemoveAll(tmpdir) - mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir) + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t, tmpdir) if err != nil { t.Fatal(err) } defer mockController.Finish() defer driver.Stop() - csiProvisioner := NewCSIProvisioner(nil, nil, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName) + pluginCaps, controllerCaps := provisionCapabilities() + csiProvisioner := NewCSIProvisioner(nil, nil, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName, pluginCaps, controllerCaps) // Requested PVC with requestedBytes storage opts := controller.VolumeOptions{ @@ -574,7 +410,6 @@ func TestCreateDriverReturnsInvalidCapacityDuringProvision(t *testing.T) { } // Set up Mocks - provisionMockServerSetupExpectations(identityServer, controllerServer) controllerServer.EXPECT().CreateVolume(gomock.Any(), gomock.Any()).Return(out, nil).Times(1) // Since capacity returned by driver is invalid, we expect the provision call to clean up the volume controllerServer.EXPECT().DeleteVolume(gomock.Any(), &csi.DeleteVolumeRequest{ @@ -590,95 +425,30 @@ func TestCreateDriverReturnsInvalidCapacityDuringProvision(t *testing.T) { t.Logf("Provision encountered an error: %v, expected: create volume capacity less than requested capacity", err) } -func provisionMockServerSetupExpectations(identityServer *driver.MockIdentityServer, controllerServer *driver.MockControllerServer) { - identityServer.EXPECT().GetPluginCapabilities(gomock.Any(), gomock.Any()).Return(&csi.GetPluginCapabilitiesResponse{ - Capabilities: []*csi.PluginCapability{ - { - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, - }, - }, - }, - }, - }, nil).Times(1) - controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), gomock.Any()).Return(&csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - }, - }, nil).Times(1) +func provisionCapabilities() (connection.PluginCapabilitySet, connection.ControllerCapabilitySet) { + return connection.PluginCapabilitySet{ + csi.PluginCapability_Service_CONTROLLER_SERVICE: true, + }, connection.ControllerCapabilitySet{ + csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: true, + } } -// provisionFromSnapshotMockServerSetupExpectations mocks plugin and controller capabilities reported -// by a CSI plugin that supports the snapshot feature -func provisionFromSnapshotMockServerSetupExpectations(identityServer *driver.MockIdentityServer, controllerServer *driver.MockControllerServer) { - identityServer.EXPECT().GetPluginCapabilities(gomock.Any(), gomock.Any()).Return(&csi.GetPluginCapabilitiesResponse{ - Capabilities: []*csi.PluginCapability{ - { - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, - }, - }, - }, - }, - }, nil).Times(1) - controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), gomock.Any()).Return(&csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, - }, - }, - }, - }, - }, nil).Times(1) +func provisionFromSnapshotCapabilities() (connection.PluginCapabilitySet, connection.ControllerCapabilitySet) { + return connection.PluginCapabilitySet{ + csi.PluginCapability_Service_CONTROLLER_SERVICE: true, + }, connection.ControllerCapabilitySet{ + csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: true, + csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT: true, + } } -func provisionWithTopologyMockServerSetupExpectations(identityServer *driver.MockIdentityServer, controllerServer *driver.MockControllerServer) { - identityServer.EXPECT().GetPluginCapabilities(gomock.Any(), gomock.Any()).Return(&csi.GetPluginCapabilitiesResponse{ - Capabilities: []*csi.PluginCapability{ - { - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, - }, - }, - }, - { - Type: &csi.PluginCapability_Service_{ - Service: &csi.PluginCapability_Service{ - Type: csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS, - }, - }, - }, - }, - }, nil).Times(1) - controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), gomock.Any()).Return(&csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - }, - }, nil).Times(1) +func provisionWithTopologyCapabilities() (connection.PluginCapabilitySet, connection.ControllerCapabilitySet) { + return connection.PluginCapabilitySet{ + csi.PluginCapability_Service_CONTROLLER_SERVICE: true, + csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS: true, + }, connection.ControllerCapabilitySet{ + csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME: true, + } } // Minimal PVC required for tests to function @@ -859,7 +629,6 @@ func TestGetSecretReference(t *testing.T) { type provisioningTestcase struct { volOpts controller.VolumeOptions notNilSelector bool - driverNotReady bool makeVolumeNameErr bool getSecretRefErr bool getCredentialsErr bool @@ -1241,14 +1010,6 @@ func TestProvision(t *testing.T) { notNilSelector: true, expectErr: true, }, - "fail driver not ready": { - volOpts: controller.VolumeOptions{ - PVName: "test-name", - PVC: createFakePVC(requestedBytes), - }, - driverNotReady: true, - expectErr: true, - }, "fail to make volume name": { volOpts: controller.VolumeOptions{ PVName: "test-name", @@ -1372,7 +1133,7 @@ func runProvisionTest(t *testing.T, k string, tc provisioningTestcase, requested tmpdir := tempDir(t) defer os.RemoveAll(tmpdir) - mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir) + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t, tmpdir) if err != nil { t.Fatal(err) } @@ -1402,7 +1163,8 @@ func runProvisionTest(t *testing.T, k string, tc provisioningTestcase, requested clientSet = fakeclientset.NewSimpleClientset() } - csiProvisioner := NewCSIProvisioner(clientSet, nil, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName) + pluginCaps, controllerCaps := provisionCapabilities() + csiProvisioner := NewCSIProvisioner(clientSet, nil, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName, pluginCaps, controllerCaps) out := &csi.CreateVolumeResponse{ Volume: &csi.Volume{ @@ -1422,29 +1184,21 @@ func runProvisionTest(t *testing.T, k string, tc provisioningTestcase, requested if tc.notNilSelector { tc.volOpts.PVC.Spec.Selector = &metav1.LabelSelector{} - } else if tc.driverNotReady { - identityServer.EXPECT().GetPluginCapabilities(gomock.Any(), gomock.Any()).Return(nil, errors.New("driver not ready")).Times(1) } else if tc.makeVolumeNameErr { tc.volOpts.PVC.ObjectMeta.UID = "" - provisionMockServerSetupExpectations(identityServer, controllerServer) } else if tc.getSecretRefErr { tc.volOpts.Parameters[provisionerSecretNameKey] = "" - provisionMockServerSetupExpectations(identityServer, controllerServer) } else if tc.getCredentialsErr { tc.volOpts.Parameters[provisionerSecretNameKey] = "secretx" tc.volOpts.Parameters[provisionerSecretNamespaceKey] = "default" - provisionMockServerSetupExpectations(identityServer, controllerServer) } else if tc.volWithLessCap { out.Volume.CapacityBytes = int64(80) - provisionMockServerSetupExpectations(identityServer, controllerServer) controllerServer.EXPECT().CreateVolume(gomock.Any(), gomock.Any()).Return(out, nil).Times(1) controllerServer.EXPECT().DeleteVolume(gomock.Any(), gomock.Any()).Return(&csi.DeleteVolumeResponse{}, nil).Times(1) } else if tc.expectCreateVolDo != nil { - provisionMockServerSetupExpectations(identityServer, controllerServer) controllerServer.EXPECT().CreateVolume(gomock.Any(), gomock.Any()).Do(tc.expectCreateVolDo).Return(out, nil).Times(1) } else { // Setup regular mock call expectations. - provisionMockServerSetupExpectations(identityServer, controllerServer) if !tc.expectErr { controllerServer.EXPECT().CreateVolume(gomock.Any(), gomock.Any()).Return(out, nil).Times(1) } @@ -1737,7 +1491,7 @@ func TestProvisionFromSnapshot(t *testing.T) { tmpdir := tempDir(t) defer os.RemoveAll(tmpdir) - mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir) + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t, tmpdir) if err != nil { t.Fatal(err) } @@ -1759,7 +1513,8 @@ func TestProvisionFromSnapshot(t *testing.T) { return true, content, nil }) - csiProvisioner := NewCSIProvisioner(clientSet, nil, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, client, driverName) + pluginCaps, controllerCaps := provisionFromSnapshotCapabilities() + csiProvisioner := NewCSIProvisioner(clientSet, nil, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, client, driverName, pluginCaps, controllerCaps) out := &csi.CreateVolumeResponse{ Volume: &csi.Volume{ @@ -1768,13 +1523,7 @@ func TestProvisionFromSnapshot(t *testing.T) { }, } - // Setup mock call expectations. If tc.wrongDataSource is false, DataSource is valid - // and the controller will proceed to check whether the plugin supports snapshot. - // So in this case, we need the plugin to report snapshot support capabilities; - // Otherwise, the controller will fail the operation so it won't check the capabilities. - if tc.wrongDataSource == false { - provisionFromSnapshotMockServerSetupExpectations(identityServer, controllerServer) - } + // Setup mock call expectations. // If tc.restoredVolSizeSmall is true, or tc.wrongDataSource is true, or // tc.snapshotStatusReady is false, create volume from snapshot operation will fail // early and therefore CreateVolume is not expected to be called. @@ -1848,7 +1597,7 @@ func TestProvisionWithTopology(t *testing.T) { tmpdir := tempDir(t) defer os.RemoveAll(tmpdir) - mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir) + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t, tmpdir) if err != nil { t.Fatal(err) } @@ -1857,7 +1606,8 @@ func TestProvisionWithTopology(t *testing.T) { clientSet := fakeclientset.NewSimpleClientset() csiClientSet := fakecsiclientset.NewSimpleClientset() - csiProvisioner := NewCSIProvisioner(clientSet, csiClientSet, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName) + pluginCaps, controllerCaps := provisionWithTopologyCapabilities() + csiProvisioner := NewCSIProvisioner(clientSet, csiClientSet, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName, pluginCaps, controllerCaps) out := &csi.CreateVolumeResponse{ Volume: &csi.Volume{ @@ -1867,7 +1617,6 @@ func TestProvisionWithTopology(t *testing.T) { }, } - provisionWithTopologyMockServerSetupExpectations(identityServer, controllerServer) controllerServer.EXPECT().CreateVolume(gomock.Any(), gomock.Any()).Return(out, nil).Times(1) pv, err := csiProvisioner.Provision(controller.VolumeOptions{ @@ -1889,7 +1638,7 @@ func TestProvisionWithMountOptions(t *testing.T) { tmpdir := tempDir(t) defer os.RemoveAll(tmpdir) - mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t, tmpdir) + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t, tmpdir) if err != nil { t.Fatal(err) } @@ -1898,7 +1647,8 @@ func TestProvisionWithMountOptions(t *testing.T) { clientSet := fakeclientset.NewSimpleClientset() csiClientSet := fakecsiclientset.NewSimpleClientset() - csiProvisioner := NewCSIProvisioner(clientSet, csiClientSet, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName) + pluginCaps, controllerCaps := provisionCapabilities() + csiProvisioner := NewCSIProvisioner(clientSet, csiClientSet, 5*time.Second, "test-provisioner", "test", 5, csiConn.conn, nil, driverName, pluginCaps, controllerCaps) out := &csi.CreateVolumeResponse{ Volume: &csi.Volume{ @@ -1907,7 +1657,6 @@ func TestProvisionWithMountOptions(t *testing.T) { }, } - provisionWithTopologyMockServerSetupExpectations(identityServer, controllerServer) controllerServer.EXPECT().CreateVolume(gomock.Any(), gomock.Any()).Return(out, nil).Times(1) pv, err := csiProvisioner.Provision(controller.VolumeOptions{