diff --git a/pkg/azurefile/controllerserver.go b/pkg/azurefile/controllerserver.go index 41336552a4..9f882d9bb8 100644 --- a/pkg/azurefile/controllerserver.go +++ b/pkg/azurefile/controllerserver.go @@ -312,6 +312,10 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) return nil, status.Errorf(codes.InvalidArgument, "fsType(%s) is not supported with protocol(%s)", fsType, protocol) } + if protocol == smb && sku == "" { + return nil, status.Errorf(codes.InvalidArgument, "%s/%s parameter is required with protocol(%s)", skuNameField, storageAccountField, protocol) + } + enableHTTPSTrafficOnly := true shareProtocol := storage.EnabledProtocolsSMB createPrivateEndpoint := false diff --git a/pkg/azurefile/controllerserver_test.go b/pkg/azurefile/controllerserver_test.go index d9edb3077a..e450bc5dab 100644 --- a/pkg/azurefile/controllerserver_test.go +++ b/pkg/azurefile/controllerserver_test.go @@ -1266,6 +1266,90 @@ func TestCreateVolume(t *testing.T) { } }, }, + { + name: "Valid smb request", + testFunc: func(t *testing.T) { + name := "baz" + sku := "sku" + kind := "StorageV2" + location := "centralus" + value := "foo bar" + accounts := []storage.Account{ + {Name: &name, Sku: &storage.Sku{Name: storage.SkuName(sku)}, Kind: storage.Kind(kind), Location: &location}, + } + keys := storage.AccountListKeysResult{ + Keys: &[]storage.AccountKey{ + {Value: &value}, + }, + } + + allParam := map[string]string{ + skuNameField: "premium", + storageAccountTypeField: "stoacctype", + locationField: "loc", + storageAccountField: "stoacc", + resourceGroupField: "rg", + shareNameField: "", + diskNameField: "diskname.vhd", + fsTypeField: "", + storeAccountKeyField: "storeaccountkey", + secretNamespaceField: "default", + mountPermissionsField: "0755", + accountQuotaField: "1000", + protocolField: smb, + } + + req := &csi.CreateVolumeRequest{ + Name: "random-vol-name-valid-request", + VolumeCapabilities: stdVolCap, + CapacityRange: lessThanPremCapRange, + Parameters: allParam, + } + + d := NewFakeDriver() + d.cloud = &azure.Cloud{} + d.cloud.KubeClient = fake.NewSimpleClientset() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockFileClient := mockfileclient.NewMockInterface(ctrl) + d.cloud.FileClient = mockFileClient + + mockStorageAccountsClient := mockstorageaccountclient.NewMockInterface(ctrl) + d.cloud.StorageAccountClient = mockStorageAccountsClient + + mockFileClient.EXPECT().WithSubscriptionID(gomock.Any()).Return(mockFileClient).AnyTimes() + mockFileClient.EXPECT().CreateFileShare(context.TODO(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(storage.FileShare{FileShareProperties: &storage.FileShareProperties{ShareQuota: nil}}, nil).AnyTimes() + mockStorageAccountsClient.EXPECT().ListKeys(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(keys, nil).AnyTimes() + mockStorageAccountsClient.EXPECT().ListByResourceGroup(gomock.Any(), gomock.Any(), gomock.Any()).Return(accounts, nil).AnyTimes() + mockStorageAccountsClient.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockFileClient.EXPECT().GetFileShare(context.TODO(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(storage.FileShare{FileShareProperties: &storage.FileShareProperties{ShareQuota: &fakeShareQuota}}, nil).AnyTimes() + + _, err := d.CreateVolume(ctx, req) + if !reflect.DeepEqual(err, nil) { + t.Errorf("Unexpected error: %v", err) + } + }, + }, + { + name: "Missing storage account type with smb protocol", + testFunc: func(t *testing.T) { + req := &csi.CreateVolumeRequest{ + Name: "vol-smb-proto", + Parameters: map[string]string{protocolField: smb}, + VolumeCapabilities: stdVolCap, + } + + d := NewFakeDriver() + + expectedErr := status.Error(codes.InvalidArgument, fmt.Sprintf("%s/%s parameter is required with protocol(%s)", skuNameField, storageAccountField, smb)) + _, err := d.CreateVolume(ctx, req) + if !reflect.DeepEqual(err, expectedErr) { + t.Errorf("Unexpected error: %v", err) + } + }, + }, } for _, tc := range testCases {