From 3ed013d7479ce8468ad62c57b508f478c4abcb11 Mon Sep 17 00:00:00 2001 From: Fan Shang Xiang Date: Fri, 15 Dec 2023 16:12:25 +0800 Subject: [PATCH] refactor:Adopt dependency injection pattern in grpc server --- go.mod | 2 +- pkg/azuredisk/azuredisk.go | 54 ++++++++++-- pkg/azuredisk/azuredisk_test.go | 38 ++++++++- pkg/azuredisk/azuredisk_v2.go | 43 ++++++++-- pkg/azuredisk/fake_azuredisk.go | 3 + pkg/azuredisk/fake_azuredisk_v2.go | 3 + pkg/azurediskplugin/main.go | 8 +- pkg/csi-common/server.go | 128 ----------------------------- pkg/csi-common/server_test.go | 75 ----------------- pkg/csi-common/utils.go | 30 ++++++- pkg/csi-common/utils_test.go | 49 ++++++++++- test/e2e/suite_test.go | 8 +- 12 files changed, 212 insertions(+), 229 deletions(-) delete mode 100644 pkg/csi-common/server.go delete mode 100644 pkg/csi-common/server_test.go diff --git a/go.mod b/go.mod index 7c6705085f..597e2e2851 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0 go.opentelemetry.io/otel/sdk v1.21.0 golang.org/x/net v0.19.0 + golang.org/x/sync v0.5.0 google.golang.org/grpc v1.60.1 google.golang.org/protobuf v1.31.0 k8s.io/api v0.29.0 @@ -131,7 +132,6 @@ require ( golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 // indirect golang.org/x/mod v0.14.0 // indirect golang.org/x/oauth2 v0.15.0 // indirect - golang.org/x/sync v0.5.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/pkg/azuredisk/azuredisk.go b/pkg/azuredisk/azuredisk.go index 7c8d6b4f68..a6795e4a1c 100644 --- a/pkg/azuredisk/azuredisk.go +++ b/pkg/azuredisk/azuredisk.go @@ -18,6 +18,7 @@ package azuredisk import ( "context" + "errors" "fmt" "reflect" "strconv" @@ -26,7 +27,9 @@ import ( "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/container-storage-interface/spec/lib/go/csi" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -78,6 +81,9 @@ type DriverOptions struct { EnableOtelTracing bool WaitForSnapshotReady bool CheckDiskLUNCollision bool + Kubeconfig string + Endpoint string + DisableAVSetNodes bool } // CSIDriver defines the interface for a CSI driver. @@ -86,7 +92,7 @@ type CSIDriver interface { csi.NodeServer csi.IdentityServer - Run(endpoint, kubeconfig string, disableAVSetNodes, testMode bool) + Run(ctx context.Context) error } type hostUtil interface { @@ -127,6 +133,8 @@ type DriverCore struct { enableOtelTracing bool shouldWaitForSnapshotReady bool checkDiskLUNCollision bool + endpoint string + disableAVSetNodes bool } // Driver is the v1 implementation of the Azure Disk CSI Driver. @@ -169,6 +177,9 @@ func newDriverV1(options *DriverOptions) *Driver { driver.enableOtelTracing = options.EnableOtelTracing driver.shouldWaitForSnapshotReady = options.WaitForSnapshotReady driver.checkDiskLUNCollision = options.CheckDiskLUNCollision + driver.endpoint = options.Endpoint + driver.disableAVSetNodes = options.DisableAVSetNodes + driver.kubeconfig = options.Kubeconfig driver.volumeLocks = volumehelper.NewVolumeLocks() driver.ioHandler = azureutils.NewOSIOHandler() driver.hostUtil = hostutil.NewHostUtil() @@ -186,7 +197,7 @@ func newDriverV1(options *DriverOptions) *Driver { } // Run driver initialization -func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock bool) { +func (d *Driver) Run(ctx context.Context) error { versionMeta, err := GetVersionYAML(d.Name) if err != nil { klog.Fatalf("%v", err) @@ -196,13 +207,12 @@ func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock userAgent := GetUserAgent(d.Name, d.customUserAgent, d.userAgentSuffix) klog.V(2).Infof("driver userAgent: %s", userAgent) - cloud, err := azureutils.GetCloudProvider(context.Background(), kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace, + cloud, err := azureutils.GetCloudProvider(context.Background(), d.kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace, userAgent, d.allowEmptyCloudConfig, d.enableTrafficManager, d.trafficManagerPort) if err != nil { klog.Fatalf("failed to get Azure Cloud Provider, error: %v", err) } d.cloud = cloud - d.kubeconfig = kubeconfig if d.cloud != nil { if d.vmType != "" { @@ -221,7 +231,7 @@ func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock d.cloud.DisableAvailabilitySetNodes = false } - if d.cloud.VMType == azurecloudconsts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && disableAVSetNodes { + if d.cloud.VMType == azurecloudconsts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && d.disableAVSetNodes { klog.V(2).Infof("DisableAvailabilitySetNodes for controller since current VMType is vmss") d.cloud.DisableAvailabilitySetNodes = true } @@ -283,11 +293,39 @@ func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock csi.NodeServiceCapability_RPC_GET_VOLUME_STATS, csi.NodeServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER, }) + grpcInterceptor := grpc.UnaryInterceptor(csicommon.LogGRPC) + if d.enableOtelTracing { + grpcInterceptor = grpc.ChainUnaryInterceptor(csicommon.LogGRPC, otelgrpc.UnaryServerInterceptor()) + } + opts := []grpc.ServerOption{ + grpcInterceptor, + } + + if d.enableOtelTracing { + opts = append(opts, grpc.StatsHandler(otelgrpc.NewServerHandler())) + } + + s := grpc.NewServer(opts...) + csi.RegisterIdentityServer(s, d) + csi.RegisterControllerServer(s, d) + csi.RegisterNodeServer(s, d) - s := csicommon.NewNonBlockingGRPCServer() + go func() { + //graceful shutdown + <-ctx.Done() + s.GracefulStop() + }() // Driver d act as IdentityServer, ControllerServer and NodeServer - s.Start(endpoint, d, d, d, testingMock, d.enableOtelTracing) - s.Wait() + listener, err := csicommon.Listen(ctx, d.endpoint) + if err != nil { + klog.Fatalf("failed to listen to endpoint, error: %v", err) + } + err = s.Serve(listener) + if errors.Is(err, grpc.ErrServerStopped) { + klog.Infof("gRPC server stopped serving") + return nil + } + return err } func (d *Driver) isGetDiskThrottled() bool { diff --git a/pkg/azuredisk/azuredisk_test.go b/pkg/azuredisk/azuredisk_test.go index a80e3210bf..6f38ba5ac4 100644 --- a/pkg/azuredisk/azuredisk_test.go +++ b/pkg/azuredisk/azuredisk_test.go @@ -29,6 +29,7 @@ import ( "github.com/Azure/go-autorest/autorest/date" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" "google.golang.org/grpc/status" "k8s.io/apimachinery/pkg/types" clientset "k8s.io/client-go/kubernetes" @@ -107,7 +108,14 @@ func TestRun(t *testing.T) { cntl := gomock.NewController(t) defer cntl.Finish() d, _ := NewFakeDriver(cntl) - d.Run("tcp://127.0.0.1:0", "", true, true) + ctx, cancelFn := context.WithCancel(context.Background()) + var routines errgroup.Group + routines.Go(func() error { return d.Run(ctx) }) + time.Sleep(time.Millisecond * 500) + cancelFn() + time.Sleep(time.Millisecond * 500) + err := routines.Wait() + assert.Nil(t, err) }, }, { @@ -116,7 +124,14 @@ func TestRun(t *testing.T) { cntl := gomock.NewController(t) defer cntl.Finish() d, _ := NewFakeDriver(cntl) - d.Run("tcp://127.0.0.1:0", "", true, true) + ctx, cancelFn := context.WithCancel(context.Background()) + var routines errgroup.Group + routines.Go(func() error { return d.Run(ctx) }) + time.Sleep(time.Millisecond * 500) + cancelFn() + time.Sleep(time.Millisecond * 500) + err := routines.Wait() + assert.Nil(t, err) }, }, { @@ -139,7 +154,14 @@ func TestRun(t *testing.T) { d, _ := NewFakeDriver(cntl) d.setCloud(&azure.Cloud{}) d.setNodeID("") - d.Run("tcp://127.0.0.1:0", "", true, true) + ctx, cancelFn := context.WithCancel(context.Background()) + var routines errgroup.Group + routines.Go(func() error { return d.Run(ctx) }) + time.Sleep(time.Millisecond * 500) + cancelFn() + time.Sleep(time.Millisecond * 500) + err := routines.Wait() + assert.Nil(t, err) }, }, { @@ -165,8 +187,16 @@ func TestRun(t *testing.T) { EnablePerfOptimization: true, VMSSCacheTTLInSeconds: 10, VMType: "vmss", + Endpoint: "tcp://127.0.0.1:0", }) - d.Run("tcp://127.0.0.1:0", "", true, true) + ctx, cancelFn := context.WithCancel(context.Background()) + var routines errgroup.Group + routines.Go(func() error { return d.Run(ctx) }) + time.Sleep(time.Millisecond * 500) + cancelFn() + time.Sleep(time.Millisecond * 500) + err := routines.Wait() + assert.Nil(t, err) }, }, } diff --git a/pkg/azuredisk/azuredisk_v2.go b/pkg/azuredisk/azuredisk_v2.go index 391c44256d..1ae2f9c72c 100644 --- a/pkg/azuredisk/azuredisk_v2.go +++ b/pkg/azuredisk/azuredisk_v2.go @@ -21,13 +21,16 @@ package azuredisk import ( "context" + "errors" "flag" "fmt" "reflect" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/container-storage-interface/spec/lib/go/csi" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -78,13 +81,16 @@ func newDriverV2(options *DriverOptions) *DriverV2 { driver.enableOtelTracing = options.EnableOtelTracing driver.ioHandler = azureutils.NewOSIOHandler() driver.hostUtil = hostutil.NewHostUtil() + driver.kubeconfig = options.Kubeconfig + driver.disableAVSetNodes = options.DisableAVSetNodes + driver.endpoint = options.Endpoint topologyKey = fmt.Sprintf("topology.%s/zone", driver.Name) return &driver } // Run driver initialization -func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock bool) { +func (d *DriverV2) Run(ctx context.Context) error { versionMeta, err := GetVersionYAML(d.Name) if err != nil { klog.Fatalf("%v", err) @@ -94,7 +100,7 @@ func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMo userAgent := GetUserAgent(d.Name, d.customUserAgent, d.userAgentSuffix) klog.V(2).Infof("driver userAgent: %s", userAgent) - cloud, err := azureutils.GetCloudProvider(context.Background(), kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace, + cloud, err := azureutils.GetCloudProvider(context.Background(), d.kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace, userAgent, d.allowEmptyCloudConfig, d.enableTrafficManager, d.trafficManagerPort) if err != nil { klog.Fatalf("failed to get Azure Cloud Provider, error: %v", err) @@ -118,7 +124,7 @@ func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMo d.cloud.DisableAvailabilitySetNodes = false } - if d.cloud.VMType == consts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && disableAVSetNodes { + if d.cloud.VMType == consts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && d.disableAVSetNodes { klog.V(2).Infof("DisableAvailabilitySetNodes for controller since current VMType is vmss") d.cloud.DisableAvailabilitySetNodes = true } @@ -165,11 +171,34 @@ func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMo csi.NodeServiceCapability_RPC_GET_VOLUME_STATS, csi.NodeServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER, }) - - s := csicommon.NewNonBlockingGRPCServer() + grpcInterceptor := grpc.UnaryInterceptor(csicommon.LogGRPC) + if d.enableOtelTracing { + grpcInterceptor = grpc.ChainUnaryInterceptor(csicommon.LogGRPC, otelgrpc.UnaryServerInterceptor()) + } + opts := []grpc.ServerOption{ + grpcInterceptor, + } + s := grpc.NewServer(opts...) + csi.RegisterIdentityServer(s, d) + csi.RegisterControllerServer(s, d) + csi.RegisterNodeServer(s, d) + + go func() { + //graceful shutdown + <-ctx.Done() + s.GracefulStop() + }() // Driver d act as IdentityServer, ControllerServer and NodeServer - s.Start(endpoint, d, d, d, testingMock, d.enableOtelTracing) - s.Wait() + listener, err := csicommon.Listen(ctx, d.endpoint) + if err != nil { + klog.Fatalf("failed to listen to endpoint, error: %v", err) + } + err = s.Serve(listener) + if errors.Is(err, grpc.ErrServerStopped) { + klog.Infof("gRPC server stopped serving") + return nil + } + return err } func (d *DriverV2) checkDiskExists(ctx context.Context, diskURI string) (*compute.Disk, error) { diff --git a/pkg/azuredisk/fake_azuredisk.go b/pkg/azuredisk/fake_azuredisk.go index 46465124d7..03a519fd88 100644 --- a/pkg/azuredisk/fake_azuredisk.go +++ b/pkg/azuredisk/fake_azuredisk.go @@ -114,6 +114,9 @@ func newFakeDriverV1(ctrl *gomock.Controller) (*fakeDriverV1, error) { driver.useCSIProxyGAInterface = true driver.allowEmptyCloudConfig = true driver.shouldWaitForSnapshotReady = true + driver.endpoint = "tcp://127.0.0.1:0" + driver.disableAVSetNodes = true + driver.kubeconfig = "" driver.cloud = azure.GetTestCloud(ctrl) mounter, err := mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface) diff --git a/pkg/azuredisk/fake_azuredisk_v2.go b/pkg/azuredisk/fake_azuredisk_v2.go index d1a54978c9..f3b6173a1d 100644 --- a/pkg/azuredisk/fake_azuredisk_v2.go +++ b/pkg/azuredisk/fake_azuredisk_v2.go @@ -64,6 +64,9 @@ func newFakeDriverV2(ctrl *gomock.Controller) (*fakeDriverV2, error) { driver.hostUtil = azureutils.NewFakeHostUtil() driver.useCSIProxyGAInterface = true driver.allowEmptyCloudConfig = true + driver.endpoint = "tcp://127.0.0.1:0" + driver.disableAVSetNodes = true + driver.kubeconfig = "" driver.cloud = azure.GetTestCloud(ctrl) mounter, err := mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface) diff --git a/pkg/azurediskplugin/main.go b/pkg/azurediskplugin/main.go index 1f4e801b27..d8eaea742d 100644 --- a/pkg/azurediskplugin/main.go +++ b/pkg/azurediskplugin/main.go @@ -135,13 +135,17 @@ func handle() { EnableOtelTracing: *enableOtelTracing, WaitForSnapshotReady: *waitForSnapshotReady, CheckDiskLUNCollision: *checkDiskLUNCollision, + Endpoint: *endpoint, + Kubeconfig: *kubeconfig, + DisableAVSetNodes: *disableAVSetNodes, } driver := azuredisk.NewDriver(&driverOptions) if driver == nil { klog.Fatalln("Failed to initialize azuredisk CSI Driver") } - testingMock := false - driver.Run(*endpoint, *kubeconfig, *disableAVSetNodes, testingMock) + if err := driver.Run(context.Background()); err != nil { + klog.Fatalf("Failed to run azuredisk CSI Driver: %v", err) + } } func exportMetrics() { diff --git a/pkg/csi-common/server.go b/pkg/csi-common/server.go deleted file mode 100644 index b0e426ffa7..0000000000 --- a/pkg/csi-common/server.go +++ /dev/null @@ -1,128 +0,0 @@ -/* -Copyright 2017 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package csicommon - -import ( - "net" - "os" - "runtime" - "sync" - "time" - - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" - "k8s.io/klog/v2" - - "github.com/container-storage-interface/spec/lib/go/csi" -) - -// Defines Non blocking GRPC server interfaces -type NonBlockingGRPCServer interface { - // Start services at the endpoint - Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode, enableOtelTracing bool) - // Waits for the service to stop - Wait() - // Stops the service gracefully - Stop() - // Stops the service forcefully - ForceStop() -} - -func NewNonBlockingGRPCServer() NonBlockingGRPCServer { - return &nonBlockingGRPCServer{} -} - -// NonBlocking server -type nonBlockingGRPCServer struct { - wg sync.WaitGroup - server *grpc.Server -} - -func (s *nonBlockingGRPCServer) Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode, enableOtelTracing bool) { - s.wg.Add(1) - go s.serve(endpoint, ids, cs, ns, testMode, enableOtelTracing) -} - -func (s *nonBlockingGRPCServer) Wait() { - s.wg.Wait() -} - -func (s *nonBlockingGRPCServer) Stop() { - s.server.GracefulStop() -} - -func (s *nonBlockingGRPCServer) ForceStop() { - s.server.Stop() -} - -func (s *nonBlockingGRPCServer) serve(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, testMode, enableOtelTracing bool) { - proto, addr, err := ParseEndpoint(endpoint) - if err != nil { - klog.Fatal(err.Error()) - } - - if proto == "unix" { - if runtime.GOOS != "windows" { - addr = "/" + addr - } - if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { - klog.Fatalf("Failed to remove %s, error: %s", addr, err.Error()) - } - } - - listener, err := net.Listen(proto, addr) - if err != nil { - klog.Fatalf("Failed to listen: %v", err) - } - - grpcInterceptor := grpc.UnaryInterceptor(logGRPC) - if enableOtelTracing { - grpcInterceptor = grpc.ChainUnaryInterceptor(logGRPC, otelgrpc.UnaryServerInterceptor()) - } - - opts := []grpc.ServerOption{ - grpcInterceptor, - } - - server := grpc.NewServer(opts...) - s.server = server - - if ids != nil { - csi.RegisterIdentityServer(server, ids) - } - if cs != nil { - csi.RegisterControllerServer(server, cs) - } - if ns != nil { - csi.RegisterNodeServer(server, ns) - } - // Used to stop the server while running tests - if testMode { - s.wg.Done() - go func() { - // make sure Serve() is called - s.wg.Wait() - time.Sleep(time.Millisecond * 1000) - s.server.GracefulStop() - }() - } - - klog.Infof("Listening for connections on address: %#v", listener.Addr()) - if err := server.Serve(listener); err != nil { - klog.Errorf("Listening for connections on address: %#v, error: %v", listener.Addr(), err) - } -} diff --git a/pkg/csi-common/server_test.go b/pkg/csi-common/server_test.go deleted file mode 100644 index 09678052bc..0000000000 --- a/pkg/csi-common/server_test.go +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package csicommon - -import ( - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "google.golang.org/grpc" -) - -func TestNewNonBlockingGRPCServer(t *testing.T) { - s := NewNonBlockingGRPCServer() - assert.NotNil(t, s) -} - -func TestStart(_ *testing.T) { - s := NewNonBlockingGRPCServer() - // sleep a while to avoid race condition in unit test - time.Sleep(time.Millisecond * 500) - s.Start("tcp://127.0.0.1:0", nil, nil, nil, true, false) - time.Sleep(time.Millisecond * 500) -} - -func TestStartWithOtelTracing(_ *testing.T) { - s := NewNonBlockingGRPCServer() - // sleep a while to avoid race condition in unit test - time.Sleep(time.Millisecond * 500) - s.Start("tcp://127.0.0.1:0", nil, nil, nil, true, true) - time.Sleep(time.Millisecond * 500) -} - -func TestServe(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.wg = sync.WaitGroup{} - //need to add one here as the actual also requires one. - s.wg.Add(1) - s.serve("tcp://127.0.0.1:0", nil, nil, nil, true, false) -} - -func TestWait(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.wg = sync.WaitGroup{} - s.Wait() -} - -func TestStop(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.Stop() -} - -func TestForceStop(_ *testing.T) { - s := nonBlockingGRPCServer{} - s.server = grpc.NewServer() - s.ForceStop() -} diff --git a/pkg/csi-common/utils.go b/pkg/csi-common/utils.go index 18550642ac..42fa0cf3fc 100644 --- a/pkg/csi-common/utils.go +++ b/pkg/csi-common/utils.go @@ -18,6 +18,9 @@ package csicommon import ( "fmt" + "net" + "os" + "runtime" "strings" "golang.org/x/net/context" @@ -38,6 +41,31 @@ func ParseEndpoint(ep string) (string, string, error) { return "", "", fmt.Errorf("Invalid endpoint: %v", ep) } +func Listen(ctx context.Context, endpoint string) (net.Listener, error) { + proto, addr, err := ParseEndpoint(endpoint) + if err != nil { + klog.Errorf(err.Error()) + return nil, err + } + + if proto == "unix" { + if runtime.GOOS != "windows" { + addr = "/" + addr + } + if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { + klog.Errorf("Failed to remove %s, error: %s", addr, err.Error()) + return nil, err + } + } + listenConfig := net.ListenConfig{} + listener, err := listenConfig.Listen(ctx, proto, addr) + if err != nil { + klog.Errorf("Failed to listen: %v", err) + return nil, err + } + return listener, nil +} + func NewVolumeCapabilityAccessMode(mode csi.VolumeCapability_AccessMode_Mode) *csi.VolumeCapability_AccessMode { return &csi.VolumeCapability_AccessMode{Mode: mode} } @@ -72,7 +100,7 @@ func getLogLevel(method string) int32 { return 2 } -func logGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { +func LogGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { level := klog.Level(getLogLevel(info.FullMethod)) klog.V(level).Infof("GRPC call: %s", info.FullMethod) klog.V(level).Infof("GRPC request: %s", protosanitizer.StripSecrets(req)) diff --git a/pkg/csi-common/utils_test.go b/pkg/csi-common/utils_test.go index b152cda5e4..122537b99a 100644 --- a/pkg/csi-common/utils_test.go +++ b/pkg/csi-common/utils_test.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "flag" + "os" "testing" "github.com/container-storage-interface/spec/lib/go/csi" @@ -133,7 +134,7 @@ func TestLogGRPC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // EXECUTE - _, _ = logGRPC(context.Background(), test.req, &info, handler) + _, _ = LogGRPC(context.Background(), test.req, &info, handler) klog.Flush() // ASSERT @@ -232,3 +233,49 @@ func TestGetLogLevel(t *testing.T) { } } } + +func TestListen(t *testing.T) { + tests := []struct { + name string + endpoint string + filePath string + wantErr bool + }{ + { + name: "unix socket", + endpoint: "unix:///tmp/csi.sock", + filePath: "/tmp/csi.sock", + wantErr: false, + }, + { + name: "tcp socket", + endpoint: "tcp://127.0.0.1:0", + wantErr: false, + }, + { + name: "invalid endpoint", + endpoint: "invalid://", + wantErr: true, + }, + { + name: "invalid unix socket", + endpoint: "unix://does/not/exist", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Listen(context.Background(), tt.endpoint) + if (err != nil) != tt.wantErr { + t.Errorf("Listen() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + got.Close() + if tt.filePath != "" { + os.Remove(tt.filePath) + } + } + }) + } +} diff --git a/test/e2e/suite_test.go b/test/e2e/suite_test.go index f190f3d48c..a5e5d14c92 100644 --- a/test/e2e/suite_test.go +++ b/test/e2e/suite_test.go @@ -17,6 +17,7 @@ limitations under the License. package e2e import ( + "context" "flag" "fmt" "log" @@ -152,12 +153,15 @@ var _ = ginkgo.BeforeSuite(func(ctx ginkgo.SpecContext) { DriverName: consts.DefaultDriverName, VolumeAttachLimit: 16, EnablePerfOptimization: false, + Kubeconfig: os.Getenv(kubeconfigEnvVar), + Endpoint: fmt.Sprintf("unix:///tmp/csi-%s.sock", string(uuid.NewUUID())), } azurediskDriver = azuredisk.NewDriver(&driverOptions) - kubeconfig := os.Getenv(kubeconfigEnvVar) + go func() { os.Setenv("AZURE_CREDENTIAL_FILE", credentials.TempAzureCredentialFilePath) - azurediskDriver.Run(fmt.Sprintf("unix:///tmp/csi-%s.sock", string(uuid.NewUUID())), kubeconfig, false, false) + err := azurediskDriver.Run(context.Background()) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) }() } })