diff --git a/go.mod b/go.mod index f849bca60f..0077c0f417 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 github.com/nats-io/nats-streaming-server v0.17.0 github.com/nats-io/stan.go v0.6.0 - github.com/networkservicemesh/api v0.0.0-20210218170701-1a72f1cba074 + github.com/networkservicemesh/api v0.0.0-20210305165706-bcfdc8d78700 github.com/open-policy-agent/opa v0.16.1 github.com/opentracing/opentracing-go v1.1.0 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index c189f3fc0f..a5cb076a6c 100644 --- a/go.sum +++ b/go.sum @@ -156,8 +156,8 @@ github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/nats-io/stan.go v0.6.0 h1:26IJPeykh88d8KVLT4jJCIxCyUBOC5/IQup8oWD/QYY= github.com/nats-io/stan.go v0.6.0/go.mod h1:eIcD5bi3pqbHT/xIIvXMwvzXYElgouBvaVRftaE+eac= -github.com/networkservicemesh/api v0.0.0-20210218170701-1a72f1cba074 h1:lMU+bavS8l0vKZKtCYutUFtTaU5jzTEA7bD/s843XYU= -github.com/networkservicemesh/api v0.0.0-20210218170701-1a72f1cba074/go.mod h1:qvxdY1Zt4QTtiG+uH1XmjpegeHjlt5Jj4A8iK55iJPI= +github.com/networkservicemesh/api v0.0.0-20210305165706-bcfdc8d78700 h1:c4M5DLI0L3IMx56Gqnt6kQ4SAF0tRCu0thxH2gmTxCE= +github.com/networkservicemesh/api v0.0.0-20210305165706-bcfdc8d78700/go.mod h1:qvxdY1Zt4QTtiG+uH1XmjpegeHjlt5Jj4A8iK55iJPI= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= diff --git a/pkg/networkservice/common/timeout/timer_map.gen.go b/pkg/networkservice/common/timeout/close_timer_map.gen.go similarity index 58% rename from pkg/networkservice/common/timeout/timer_map.gen.go rename to pkg/networkservice/common/timeout/close_timer_map.gen.go index 6763ec3d47..18e1209e26 100644 --- a/pkg/networkservice/common/timeout/timer_map.gen.go +++ b/pkg/networkservice/common/timeout/close_timer_map.gen.go @@ -1,59 +1,58 @@ -// Code generated by "-output timer_map.gen.go -type timerMap -output timer_map.gen.go -type timerMap"; DO NOT EDIT. +// Code generated by "-output close_timer_map.gen.go -type closeTimerMap -output close_timer_map.gen.go -type closeTimerMap"; DO NOT EDIT. package timeout import ( "sync" // Used by sync.Map. - "time" ) // Generate code that will fail if the constants change value. func _() { - // An "cannot convert timerMap literal (type timerMap) to type sync.Map" compiler error signifies that the base type have changed. + // An "cannot convert closeTimerMap literal (type closeTimerMap) to type sync.Map" compiler error signifies that the base type have changed. // Re-run the go-syncmap command to generate them again. - _ = (sync.Map)(timerMap{}) + _ = (sync.Map)(closeTimerMap{}) } -var _nil_timerMap_time_Timer_value = func() (val *time.Timer) { return }() +var _nil_closeTimerMap_closeTimer_value = func() (val *closeTimer) { return }() // Load returns the value stored in the map for a key, or nil if no // value is present. // The ok result indicates whether value was found in the map. -func (m *timerMap) Load(key string) (*time.Timer, bool) { +func (m *closeTimerMap) Load(key string) (*closeTimer, bool) { value, ok := (*sync.Map)(m).Load(key) if value == nil { - return _nil_timerMap_time_Timer_value, ok + return _nil_closeTimerMap_closeTimer_value, ok } - return value.(*time.Timer), ok + return value.(*closeTimer), ok } // Store sets the value for a key. -func (m *timerMap) Store(key string, value *time.Timer) { +func (m *closeTimerMap) Store(key string, value *closeTimer) { (*sync.Map)(m).Store(key, value) } // LoadOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. // The loaded result is true if the value was loaded, false if stored. -func (m *timerMap) LoadOrStore(key string, value *time.Timer) (*time.Timer, bool) { +func (m *closeTimerMap) LoadOrStore(key string, value *closeTimer) (*closeTimer, bool) { actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) if actual == nil { - return _nil_timerMap_time_Timer_value, loaded + return _nil_closeTimerMap_closeTimer_value, loaded } - return actual.(*time.Timer), loaded + return actual.(*closeTimer), loaded } // LoadAndDelete deletes the value for a key, returning the previous value if any. // The loaded result reports whether the key was present. -func (m *timerMap) LoadAndDelete(key string) (value *time.Timer, loaded bool) { +func (m *closeTimerMap) LoadAndDelete(key string) (value *closeTimer, loaded bool) { actual, loaded := (*sync.Map)(m).LoadAndDelete(key) if actual == nil { - return _nil_timerMap_time_Timer_value, loaded + return _nil_closeTimerMap_closeTimer_value, loaded } - return actual.(*time.Timer), loaded + return actual.(*closeTimer), loaded } // Delete deletes the value for a key. -func (m *timerMap) Delete(key string) { +func (m *closeTimerMap) Delete(key string) { (*sync.Map)(m).Delete(key) } @@ -67,8 +66,8 @@ func (m *timerMap) Delete(key string) { // // Range may be O(N) with the number of elements in the map even if f returns // false after a constant number of calls. -func (m *timerMap) Range(f func(key string, value *time.Timer) bool) { +func (m *closeTimerMap) Range(f func(key string, value *closeTimer) bool) { (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(*time.Timer)) + return f(key.(string), value.(*closeTimer)) }) } diff --git a/pkg/networkservice/common/timeout/gen.go b/pkg/networkservice/common/timeout/gen.go index ad87936fb5..36df705e28 100644 --- a/pkg/networkservice/common/timeout/gen.go +++ b/pkg/networkservice/common/timeout/gen.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -20,6 +20,6 @@ import ( "sync" ) -//go:generate go-syncmap -output timer_map.gen.go -type timerMap +//go:generate go-syncmap -output close_timer_map.gen.go -type closeTimerMap -type timerMap sync.Map +type closeTimerMap sync.Map diff --git a/pkg/networkservice/common/timeout/server.go b/pkg/networkservice/common/timeout/server.go index 7fc25072f5..7c695ea2af 100644 --- a/pkg/networkservice/common/timeout/server.go +++ b/pkg/networkservice/common/timeout/server.go @@ -23,7 +23,6 @@ import ( "context" "time" - "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" @@ -36,7 +35,12 @@ import ( type timeoutServer struct { ctx context.Context - timers timerMap + timers closeTimerMap +} + +type closeTimer struct { + expirationTime time.Time + timer *time.Timer } // NewServer - creates a new NetworkServiceServer chain element that implements timeout of expired connections @@ -49,84 +53,76 @@ func NewServer(ctx context.Context) networkservice.NetworkServiceServer { } } -func (t *timeoutServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - logger := log.FromContext(ctx).WithField("timeoutServer", "request") +func (s *timeoutServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + if err := s.validateRequest(ctx, request); err != nil { + return nil, err + } connID := request.GetConnection().GetId() - conn, err := next.Server(ctx).Request(ctx, request) - if err != nil { - return nil, err - } + t, loaded := s.timers.Load(connID) + stopped := loaded && t.timer.Stop() - if timer, ok := t.timers.LoadAndDelete(connID); ok { - if !timer.Stop() { - // Even if we failed to stop the timer, we should execute. It does mean that the timeout action - // is waiting on `executor.AsyncExec()` until we will finish. - // Since timer is being deleted under the `executor.AsyncExec()` this can't be a situation when - // the Request is executing after the timeout Close. Such case cannot be distinguished with the - // first-request case. - logger.Warnf("connection has been timed out, re requesting: %v", connID) - } - } + expirationTime := request.GetConnection().GetPrevPathSegment().GetExpires().AsTime().Local() - timer, err := t.createTimer(ctx, conn.Clone()) + conn, err := next.Server(ctx).Request(ctx, request) if err != nil { - if _, closeErr := next.Server(ctx).Close(ctx, conn); closeErr != nil { - err = errors.Wrapf(err, "error attempting to close failed connection %v: %+v", connID, closeErr) + if stopped { + t.timer.Reset(time.Until(t.expirationTime)) } return nil, err } - t.timers.Store(connID, timer) + s.timers.Store(connID, s.newTimer(ctx, expirationTime, conn.Clone())) return conn, nil } -func (t *timeoutServer) createTimer(ctx context.Context, conn *networkservice.Connection) (*time.Timer, error) { - logger := log.FromContext(ctx).WithField("timeoutServer", "createTimer") - - executor := serialize.GetExecutor(ctx) - if executor == nil { - return nil, errors.New("no executor provided") +func (s *timeoutServer) validateRequest(ctx context.Context, request *networkservice.NetworkServiceRequest) error { + if request.GetConnection().GetPrevPathSegment().GetExpires() == nil { + return errors.Errorf("expiration for prev path segment cannot be nil. conn: %+v", request.GetConnection()) } - - if conn.GetPrevPathSegment().GetExpires() == nil { - return nil, errors.Errorf("expiration for prev path segment cannot be nil. conn: %+v", conn) + if serialize.GetExecutor(ctx) == nil { + return errors.New("no executor provided") } - expireTime, err := ptypes.Timestamp(conn.GetPrevPathSegment().GetExpires()) - if err != nil { - return nil, err + return nil +} + +func (s *timeoutServer) newTimer(ctx context.Context, expirationTime time.Time, conn *networkservice.Connection) *closeTimer { + logger := log.FromContext(ctx).WithField("timeoutServer", "newTimer") + + tPtr := new(*closeTimer) + *tPtr = &closeTimer{ + expirationTime: expirationTime, + timer: time.AfterFunc(time.Until(expirationTime), func() { + <-serialize.GetExecutor(ctx).AsyncExec(func() { + if t, ok := s.timers.LoadAndDelete(conn.GetId()); !ok || t != *tPtr { + // this timer has been stopped + return + } + + closeCtx, cancel := context.WithCancel(s.ctx) + defer cancel() + + if _, err := next.Server(ctx).Close(closeCtx, conn); err != nil { + logger.Errorf("failed to close timed out connection: %s %s", conn.GetId(), err.Error()) + } + }) + }), } - conn = conn.Clone() - - timerPtr := new(*time.Timer) - *timerPtr = time.AfterFunc(time.Until(expireTime), func() { - <-executor.AsyncExec(func() { - if timer, _ := t.timers.Load(conn.GetId()); timer != *timerPtr { - logger.Warnf("timer has been already stopped: %v", conn.GetId()) - return - } - t.timers.Delete(conn.GetId()) - if _, err := next.Server(ctx).Close(t.ctx, conn); err != nil { - logger.Errorf("failed to close timed out connection: %v %+v", conn.GetId(), err) - } - }) - }) - - return *timerPtr, nil + return *tPtr } -func (t *timeoutServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { - logger := log.FromContext(ctx).WithField("timeoutServer", "close") +func (s *timeoutServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { + logger := log.FromContext(ctx).WithField("timeoutServer", "Close") - timer, ok := t.timers.LoadAndDelete(conn.GetId()) + t, ok := s.timers.LoadAndDelete(conn.GetId()) if !ok { - logger.Warnf("connection has been already closed: %v", conn.GetId()) + logger.Warnf("connection has been already closed: %s", conn.GetId()) return new(empty.Empty), nil } - timer.Stop() + t.timer.Stop() return next.Server(ctx).Close(ctx, conn) } diff --git a/pkg/networkservice/common/timeout/server_test.go b/pkg/networkservice/common/timeout/server_test.go index dc849857d1..cae7d48b12 100644 --- a/pkg/networkservice/common/timeout/server_test.go +++ b/pkg/networkservice/common/timeout/server_test.go @@ -23,8 +23,10 @@ import ( "time" "github.com/golang/protobuf/ptypes/empty" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "google.golang.org/grpc/credentials" "github.com/networkservicemesh/api/pkg/api/networkservice" @@ -32,40 +34,44 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/mechanisms" "github.com/networkservicemesh/sdk/pkg/networkservice/common/mechanisms/kernel" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/null" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/refresh" "github.com/networkservicemesh/sdk/pkg/networkservice/common/serialize" "github.com/networkservicemesh/sdk/pkg/networkservice/common/timeout" "github.com/networkservicemesh/sdk/pkg/networkservice/common/updatepath" "github.com/networkservicemesh/sdk/pkg/networkservice/common/updatetoken" "github.com/networkservicemesh/sdk/pkg/networkservice/core/adapters" - "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" ) const ( - clientName = "client" - serverName = "server" - tokenTimeout = 100 * time.Millisecond - waitFor = 10 * tokenTimeout - tick = 10 * time.Millisecond - serverID = "server-id" - parallelCount = 1000 + clientName = "client" + serverName = "server" + tokenTimeout = 100 * time.Millisecond + waitFor = 10 * tokenTimeout + tick = tokenTimeout / 10 ) -func testClient(ctx context.Context, server networkservice.NetworkServiceServer, duration time.Duration) networkservice.NetworkServiceClient { - return chain.NewNetworkServiceClient( +func testClient( + ctx context.Context, + client networkservice.NetworkServiceClient, + server networkservice.NetworkServiceServer, + duration time.Duration, +) networkservice.NetworkServiceClient { + return next.NewNetworkServiceClient( updatepath.NewClient(clientName), - adapters.NewServerToClient(updatetoken.NewServer(func(_ credentials.AuthInfo) (string, time.Time, error) { - return "token", time.Now().Add(duration), nil - })), - kernel.NewClient(), + serialize.NewClient(), + client, adapters.NewServerToClient( - chain.NewNetworkServiceServer( + next.NewNetworkServiceServer( + updatetoken.NewServer(func(_ credentials.AuthInfo) (string, time.Time, error) { + return "token", time.Now().Add(duration), nil + }), + new(remoteServer), // <-- GRPC invocation updatepath.NewServer(serverName), serialize.NewServer(), timeout.NewServer(ctx), - mechanisms.NewServer(map[string]networkservice.NetworkServiceServer{ - kernelmech.MECHANISM: server, - }), + server, ), ), ) @@ -77,20 +83,34 @@ func TestTimeoutServer_Request(t *testing.T) { connServer := newConnectionsServer(t) - _, err := testClient(ctx, connServer, tokenTimeout).Request(ctx, &networkservice.NetworkServiceRequest{}) + client := testClient(ctx, + kernel.NewClient(), + mechanisms.NewServer(map[string]networkservice.NetworkServiceServer{ + kernelmech.MECHANISM: connServer, + }), + tokenTimeout, + ) + + _, err := client.Request(ctx, &networkservice.NetworkServiceRequest{}) require.NoError(t, err) require.Condition(t, connServer.validator(1, 0)) require.Eventually(t, connServer.validator(0, 1), waitFor, tick) } -func TestTimeoutServer_Close_BeforeTimeout(t *testing.T) { +func TestTimeoutServer_CloseBeforeTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() connServer := newConnectionsServer(t) - client := testClient(ctx, connServer, tokenTimeout) + client := testClient(ctx, + kernel.NewClient(), + mechanisms.NewServer(map[string]networkservice.NetworkServiceServer{ + kernelmech.MECHANISM: connServer, + }), + tokenTimeout, + ) conn, err := client.Request(ctx, &networkservice.NetworkServiceRequest{}) require.NoError(t, err) @@ -104,13 +124,19 @@ func TestTimeoutServer_Close_BeforeTimeout(t *testing.T) { <-time.After(waitFor) } -func TestTimeoutServer_Close_AfterTimeout(t *testing.T) { +func TestTimeoutServer_CloseAfterTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() connServer := newConnectionsServer(t) - client := testClient(ctx, connServer, tokenTimeout) + client := testClient(ctx, + kernel.NewClient(), + mechanisms.NewServer(map[string]networkservice.NetworkServiceServer{ + kernelmech.MECHANISM: connServer, + }), + tokenTimeout, + ) conn, err := client.Request(ctx, &networkservice.NetworkServiceRequest{}) require.NoError(t, err) @@ -134,7 +160,7 @@ func stressTestRequest() *networkservice.NetworkServiceRequest { }, { Name: serverName, - Id: serverID, + Id: "server-id", }, }, }, @@ -148,11 +174,11 @@ func TestTimeoutServer_StressTest(t *testing.T) { connServer := newConnectionsServer(t) - client := testClient(ctx, connServer, 0) + client := testClient(ctx, null.NewClient(), connServer, 0) - wg := new(sync.WaitGroup) - wg.Add(parallelCount) - for i := 0; i < parallelCount; i++ { + var wg sync.WaitGroup + for i := 0; i < 1000; i++ { + wg.Add(1) go func() { defer wg.Done() conn, err := client.Request(ctx, stressTestRequest()) @@ -164,6 +190,45 @@ func TestTimeoutServer_StressTest(t *testing.T) { wg.Wait() } +func TestTimeoutServer_RefreshFailure(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connServer := newConnectionsServer(t) + + client := testClient( + ctx, + refresh.NewClient(ctx), + next.NewNetworkServiceServer( + newFailureServer(1, -1), + connServer, + ), + tokenTimeout, + ) + + conn, err := client.Request(ctx, &networkservice.NetworkServiceRequest{}) + require.NoError(t, err) + require.Condition(t, connServer.validator(1, 0)) + + require.Eventually(t, connServer.validator(0, 1), waitFor, tick) + + _, err = client.Close(ctx, conn) + require.NoError(t, err) + require.Condition(t, connServer.validator(0, 1)) +} + +type remoteServer struct{} + +func (s *remoteServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + return next.Server(ctx).Request(ctx, request.Clone()) +} + +func (s *remoteServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { + return next.Server(ctx).Close(ctx, conn.Clone()) +} + type connectionsServer struct { t *testing.T lock sync.Mutex @@ -202,9 +267,7 @@ func (s *connectionsServer) validator(open, closed int) func() bool { func (s *connectionsServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { s.lock.Lock() - connID := request.GetConnection().GetId() - - s.connections[connID] = true + s.connections[request.GetConnection().GetId()] = true s.lock.Unlock() @@ -214,15 +277,41 @@ func (s *connectionsServer) Request(ctx context.Context, request *networkservice func (s *connectionsServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { s.lock.Lock() - connID := conn.GetId() - - if !s.connections[connID] { - assert.Fail(s.t, "closing not opened connection: %v", connID) + if !s.connections[conn.GetId()] { + assert.Fail(s.t, "closing not opened connection: %v", conn.GetId()) } else { - s.connections[connID] = false + s.connections[conn.GetId()] = false } s.lock.Unlock() return next.Server(ctx).Close(ctx, conn) } + +type failureServer struct { + count int + failureTimes []int +} + +func newFailureServer(failureTimes ...int) *failureServer { + return &failureServer{ + failureTimes: failureTimes, + } +} + +func (s *failureServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + defer func() { s.count++ }() + for _, failureTime := range s.failureTimes { + if failureTime > s.count { + break + } + if failureTime == s.count || failureTime == -1 { + return nil, errors.New("failure") + } + } + return next.Server(ctx).Request(ctx, request) +} + +func (s *failureServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { + return next.Server(ctx).Close(ctx, conn) +}