Skip to content

Commit

Permalink
refactor:Adopt dependency injection pattern in grpc server
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinForReal committed Dec 21, 2023
1 parent 4018fb1 commit 3ed013d
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 229 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0
go.opentelemetry.io/otel/sdk v1.21.0
golang.org/x/net v0.19.0
golang.org/x/sync v0.5.0
google.golang.org/grpc v1.60.1
google.golang.org/protobuf v1.31.0
k8s.io/api v0.29.0
Expand Down Expand Up @@ -131,7 +132,6 @@ require (
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/oauth2 v0.15.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/term v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
Expand Down
54 changes: 46 additions & 8 deletions pkg/azuredisk/azuredisk.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package azuredisk

import (
"context"
"errors"
"fmt"
"reflect"
"strconv"
Expand All @@ -26,7 +27,9 @@ import (

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/container-storage-interface/spec/lib/go/csi"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -78,6 +81,9 @@ type DriverOptions struct {
EnableOtelTracing bool
WaitForSnapshotReady bool
CheckDiskLUNCollision bool
Kubeconfig string
Endpoint string
DisableAVSetNodes bool
}

// CSIDriver defines the interface for a CSI driver.
Expand All @@ -86,7 +92,7 @@ type CSIDriver interface {
csi.NodeServer
csi.IdentityServer

Run(endpoint, kubeconfig string, disableAVSetNodes, testMode bool)
Run(ctx context.Context) error
}

type hostUtil interface {
Expand Down Expand Up @@ -127,6 +133,8 @@ type DriverCore struct {
enableOtelTracing bool
shouldWaitForSnapshotReady bool
checkDiskLUNCollision bool
endpoint string
disableAVSetNodes bool
}

// Driver is the v1 implementation of the Azure Disk CSI Driver.
Expand Down Expand Up @@ -169,6 +177,9 @@ func newDriverV1(options *DriverOptions) *Driver {
driver.enableOtelTracing = options.EnableOtelTracing
driver.shouldWaitForSnapshotReady = options.WaitForSnapshotReady
driver.checkDiskLUNCollision = options.CheckDiskLUNCollision
driver.endpoint = options.Endpoint
driver.disableAVSetNodes = options.DisableAVSetNodes
driver.kubeconfig = options.Kubeconfig
driver.volumeLocks = volumehelper.NewVolumeLocks()
driver.ioHandler = azureutils.NewOSIOHandler()
driver.hostUtil = hostutil.NewHostUtil()
Expand All @@ -186,7 +197,7 @@ func newDriverV1(options *DriverOptions) *Driver {
}

// Run driver initialization
func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock bool) {
func (d *Driver) Run(ctx context.Context) error {
versionMeta, err := GetVersionYAML(d.Name)
if err != nil {
klog.Fatalf("%v", err)
Expand All @@ -196,13 +207,12 @@ func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock
userAgent := GetUserAgent(d.Name, d.customUserAgent, d.userAgentSuffix)
klog.V(2).Infof("driver userAgent: %s", userAgent)

cloud, err := azureutils.GetCloudProvider(context.Background(), kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace,
cloud, err := azureutils.GetCloudProvider(context.Background(), d.kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace,
userAgent, d.allowEmptyCloudConfig, d.enableTrafficManager, d.trafficManagerPort)
if err != nil {
klog.Fatalf("failed to get Azure Cloud Provider, error: %v", err)
}
d.cloud = cloud
d.kubeconfig = kubeconfig

if d.cloud != nil {
if d.vmType != "" {
Expand All @@ -221,7 +231,7 @@ func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock
d.cloud.DisableAvailabilitySetNodes = false
}

if d.cloud.VMType == azurecloudconsts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && disableAVSetNodes {
if d.cloud.VMType == azurecloudconsts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && d.disableAVSetNodes {
klog.V(2).Infof("DisableAvailabilitySetNodes for controller since current VMType is vmss")
d.cloud.DisableAvailabilitySetNodes = true
}
Expand Down Expand Up @@ -283,11 +293,39 @@ func (d *Driver) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock
csi.NodeServiceCapability_RPC_GET_VOLUME_STATS,
csi.NodeServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER,
})
grpcInterceptor := grpc.UnaryInterceptor(csicommon.LogGRPC)
if d.enableOtelTracing {
grpcInterceptor = grpc.ChainUnaryInterceptor(csicommon.LogGRPC, otelgrpc.UnaryServerInterceptor())
}
opts := []grpc.ServerOption{
grpcInterceptor,
}

if d.enableOtelTracing {
opts = append(opts, grpc.StatsHandler(otelgrpc.NewServerHandler()))
}

s := grpc.NewServer(opts...)
csi.RegisterIdentityServer(s, d)
csi.RegisterControllerServer(s, d)
csi.RegisterNodeServer(s, d)

s := csicommon.NewNonBlockingGRPCServer()
go func() {
//graceful shutdown
<-ctx.Done()
s.GracefulStop()
}()
// Driver d act as IdentityServer, ControllerServer and NodeServer
s.Start(endpoint, d, d, d, testingMock, d.enableOtelTracing)
s.Wait()
listener, err := csicommon.Listen(ctx, d.endpoint)
if err != nil {
klog.Fatalf("failed to listen to endpoint, error: %v", err)
}
err = s.Serve(listener)
if errors.Is(err, grpc.ErrServerStopped) {
klog.Infof("gRPC server stopped serving")
return nil
}
return err
}

func (d *Driver) isGetDiskThrottled() bool {
Expand Down
38 changes: 34 additions & 4 deletions pkg/azuredisk/azuredisk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/Azure/go-autorest/autorest/date"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/status"
"k8s.io/apimachinery/pkg/types"
clientset "k8s.io/client-go/kubernetes"
Expand Down Expand Up @@ -107,7 +108,14 @@ func TestRun(t *testing.T) {
cntl := gomock.NewController(t)
defer cntl.Finish()
d, _ := NewFakeDriver(cntl)
d.Run("tcp://127.0.0.1:0", "", true, true)
ctx, cancelFn := context.WithCancel(context.Background())
var routines errgroup.Group
routines.Go(func() error { return d.Run(ctx) })
time.Sleep(time.Millisecond * 500)
cancelFn()
time.Sleep(time.Millisecond * 500)
err := routines.Wait()
assert.Nil(t, err)
},
},
{
Expand All @@ -116,7 +124,14 @@ func TestRun(t *testing.T) {
cntl := gomock.NewController(t)
defer cntl.Finish()
d, _ := NewFakeDriver(cntl)
d.Run("tcp://127.0.0.1:0", "", true, true)
ctx, cancelFn := context.WithCancel(context.Background())
var routines errgroup.Group
routines.Go(func() error { return d.Run(ctx) })
time.Sleep(time.Millisecond * 500)
cancelFn()
time.Sleep(time.Millisecond * 500)
err := routines.Wait()
assert.Nil(t, err)
},
},
{
Expand All @@ -139,7 +154,14 @@ func TestRun(t *testing.T) {
d, _ := NewFakeDriver(cntl)
d.setCloud(&azure.Cloud{})
d.setNodeID("")
d.Run("tcp://127.0.0.1:0", "", true, true)
ctx, cancelFn := context.WithCancel(context.Background())
var routines errgroup.Group
routines.Go(func() error { return d.Run(ctx) })
time.Sleep(time.Millisecond * 500)
cancelFn()
time.Sleep(time.Millisecond * 500)
err := routines.Wait()
assert.Nil(t, err)
},
},
{
Expand All @@ -165,8 +187,16 @@ func TestRun(t *testing.T) {
EnablePerfOptimization: true,
VMSSCacheTTLInSeconds: 10,
VMType: "vmss",
Endpoint: "tcp://127.0.0.1:0",
})
d.Run("tcp://127.0.0.1:0", "", true, true)
ctx, cancelFn := context.WithCancel(context.Background())
var routines errgroup.Group
routines.Go(func() error { return d.Run(ctx) })
time.Sleep(time.Millisecond * 500)
cancelFn()
time.Sleep(time.Millisecond * 500)
err := routines.Wait()
assert.Nil(t, err)
},
},
}
Expand Down
43 changes: 36 additions & 7 deletions pkg/azuredisk/azuredisk_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ package azuredisk

import (
"context"
"errors"
"flag"
"fmt"
"reflect"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/container-storage-interface/spec/lib/go/csi"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -78,13 +81,16 @@ func newDriverV2(options *DriverOptions) *DriverV2 {
driver.enableOtelTracing = options.EnableOtelTracing
driver.ioHandler = azureutils.NewOSIOHandler()
driver.hostUtil = hostutil.NewHostUtil()
driver.kubeconfig = options.Kubeconfig
driver.disableAVSetNodes = options.DisableAVSetNodes
driver.endpoint = options.Endpoint

topologyKey = fmt.Sprintf("topology.%s/zone", driver.Name)
return &driver
}

// Run driver initialization
func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMock bool) {
func (d *DriverV2) Run(ctx context.Context) error {
versionMeta, err := GetVersionYAML(d.Name)
if err != nil {
klog.Fatalf("%v", err)
Expand All @@ -94,7 +100,7 @@ func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMo
userAgent := GetUserAgent(d.Name, d.customUserAgent, d.userAgentSuffix)
klog.V(2).Infof("driver userAgent: %s", userAgent)

cloud, err := azureutils.GetCloudProvider(context.Background(), kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace,
cloud, err := azureutils.GetCloudProvider(context.Background(), d.kubeconfig, d.cloudConfigSecretName, d.cloudConfigSecretNamespace,
userAgent, d.allowEmptyCloudConfig, d.enableTrafficManager, d.trafficManagerPort)
if err != nil {
klog.Fatalf("failed to get Azure Cloud Provider, error: %v", err)
Expand All @@ -118,7 +124,7 @@ func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMo
d.cloud.DisableAvailabilitySetNodes = false
}

if d.cloud.VMType == consts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && disableAVSetNodes {
if d.cloud.VMType == consts.VMTypeVMSS && !d.cloud.DisableAvailabilitySetNodes && d.disableAVSetNodes {
klog.V(2).Infof("DisableAvailabilitySetNodes for controller since current VMType is vmss")
d.cloud.DisableAvailabilitySetNodes = true
}
Expand Down Expand Up @@ -165,11 +171,34 @@ func (d *DriverV2) Run(endpoint, kubeconfig string, disableAVSetNodes, testingMo
csi.NodeServiceCapability_RPC_GET_VOLUME_STATS,
csi.NodeServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER,
})

s := csicommon.NewNonBlockingGRPCServer()
grpcInterceptor := grpc.UnaryInterceptor(csicommon.LogGRPC)
if d.enableOtelTracing {
grpcInterceptor = grpc.ChainUnaryInterceptor(csicommon.LogGRPC, otelgrpc.UnaryServerInterceptor())
}
opts := []grpc.ServerOption{
grpcInterceptor,
}
s := grpc.NewServer(opts...)
csi.RegisterIdentityServer(s, d)
csi.RegisterControllerServer(s, d)
csi.RegisterNodeServer(s, d)

go func() {
//graceful shutdown
<-ctx.Done()
s.GracefulStop()
}()
// Driver d act as IdentityServer, ControllerServer and NodeServer
s.Start(endpoint, d, d, d, testingMock, d.enableOtelTracing)
s.Wait()
listener, err := csicommon.Listen(ctx, d.endpoint)
if err != nil {
klog.Fatalf("failed to listen to endpoint, error: %v", err)
}
err = s.Serve(listener)
if errors.Is(err, grpc.ErrServerStopped) {
klog.Infof("gRPC server stopped serving")
return nil
}
return err
}

func (d *DriverV2) checkDiskExists(ctx context.Context, diskURI string) (*compute.Disk, error) {
Expand Down
3 changes: 3 additions & 0 deletions pkg/azuredisk/fake_azuredisk.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ func newFakeDriverV1(ctrl *gomock.Controller) (*fakeDriverV1, error) {
driver.useCSIProxyGAInterface = true
driver.allowEmptyCloudConfig = true
driver.shouldWaitForSnapshotReady = true
driver.endpoint = "tcp://127.0.0.1:0"
driver.disableAVSetNodes = true
driver.kubeconfig = ""

driver.cloud = azure.GetTestCloud(ctrl)
mounter, err := mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface)
Expand Down
3 changes: 3 additions & 0 deletions pkg/azuredisk/fake_azuredisk_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ func newFakeDriverV2(ctrl *gomock.Controller) (*fakeDriverV2, error) {
driver.hostUtil = azureutils.NewFakeHostUtil()
driver.useCSIProxyGAInterface = true
driver.allowEmptyCloudConfig = true
driver.endpoint = "tcp://127.0.0.1:0"
driver.disableAVSetNodes = true
driver.kubeconfig = ""

driver.cloud = azure.GetTestCloud(ctrl)
mounter, err := mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface)
Expand Down
8 changes: 6 additions & 2 deletions pkg/azurediskplugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,17 @@ func handle() {
EnableOtelTracing: *enableOtelTracing,
WaitForSnapshotReady: *waitForSnapshotReady,
CheckDiskLUNCollision: *checkDiskLUNCollision,
Endpoint: *endpoint,
Kubeconfig: *kubeconfig,
DisableAVSetNodes: *disableAVSetNodes,
}
driver := azuredisk.NewDriver(&driverOptions)
if driver == nil {
klog.Fatalln("Failed to initialize azuredisk CSI Driver")
}
testingMock := false
driver.Run(*endpoint, *kubeconfig, *disableAVSetNodes, testingMock)
if err := driver.Run(context.Background()); err != nil {
klog.Fatalf("Failed to run azuredisk CSI Driver: %v", err)
}
}

func exportMetrics() {
Expand Down
Loading

0 comments on commit 3ed013d

Please sign in to comment.