Skip to content

Commit

Permalink
[sdk#1028] Fix Timeout, Expire to handle failed Close, Unregister (#1030
Browse files Browse the repository at this point in the history
)

* Fix timeout to handle failed Close

Signed-off-by: Vladimir Popov <vladimir.popov@xored.com>

* Fix expire to handle failed Unregister

Signed-off-by: Vladimir Popov <vladimir.popov@xored.com>

* Add registry injecterror package

Signed-off-by: Vladimir Popov <vladimir.popov@xored.com>

* Fix expire tests with injecterror

Signed-off-by: Vladimir Popov <vladimir.popov@xored.com>
  • Loading branch information
Vladimir Popov authored Jul 20, 2021
1 parent b086a10 commit 0602401
Show file tree
Hide file tree
Showing 10 changed files with 430 additions and 81 deletions.
11 changes: 8 additions & 3 deletions pkg/networkservice/common/timeout/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ func (s *timeoutServer) Request(ctx context.Context, request *networkservice.Net
func (s *timeoutServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) {
logger := log.FromContext(ctx).WithField("timeoutServer", "Close")

if !s.expireManager.Delete(conn.GetId()) {
if s.expireManager.Stop(conn.GetId()) {
if _, err := next.Server(ctx).Close(ctx, conn); err != nil {
s.expireManager.Start(conn.GetId())
return nil, err
}
s.expireManager.Delete(conn.GetId())
} else {
logger.Warnf("connection has been already closed: %s", conn.GetId())
return new(empty.Empty), nil
}

return next.Server(ctx).Close(ctx, conn)
return new(empty.Empty), nil
}
35 changes: 35 additions & 0 deletions pkg/networkservice/common/timeout/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,41 @@ func TestTimeoutServer_RefreshFailure(t *testing.T) {
require.Condition(t, connServer.validator(0, 1))
}

func TestTimeoutServer_CloseFailure(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

clockMock := clockmock.New(ctx)
ctx = clock.WithClock(ctx, clockMock)

connServer := newConnectionsServer(t)

client := testClient(
ctx,
null.NewClient(),
next.NewNetworkServiceServer(
injecterror.NewServer(
injecterror.WithRequestErrorTimes(),
injecterror.WithCloseErrorTimes(0)),
connServer,
),
tokenTimeout,
)

conn, err := client.Request(ctx, &networkservice.NetworkServiceRequest{})
require.NoError(t, err)
require.Condition(t, connServer.validator(1, 0))

_, err = client.Close(ctx, conn)
require.Error(t, err)
require.Condition(t, connServer.validator(1, 0))

clockMock.Add(tokenTimeout)
require.Eventually(t, connServer.validator(0, 1), testWait, testTick)
}

type remoteServer struct{}

func (s *remoteServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
Expand Down
11 changes: 8 additions & 3 deletions pkg/registry/common/expire/nse_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,15 @@ func (s *expireNSEServer) Find(query *registry.NetworkServiceEndpointQuery, serv
func (s *expireNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) {
logger := log.FromContext(ctx).WithField("expireNSEServer", "Unregister")

if !s.expireManager.Delete(nse.Name) {
if s.expireManager.Stop(nse.Name) {
if _, err := next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse); err != nil {
s.expireManager.Start(nse.Name)
return nil, err
}
s.expireManager.Delete(nse.Name)
} else {
logger.Warnf("endpoint has been already unregistered: %s", nse.Name)
return new(empty.Empty), nil
}

return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse)
return new(empty.Empty), nil
}
166 changes: 93 additions & 73 deletions pkg/registry/common/expire/nse_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"time"

"github.com/golang/protobuf/ptypes/empty"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/protobuf/types/known/emptypb"
Expand All @@ -40,6 +39,7 @@ import (
"github.com/networkservicemesh/sdk/pkg/registry/core/adapters"
"github.com/networkservicemesh/sdk/pkg/registry/core/next"
"github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checknse"
"github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injecterror"
"github.com/networkservicemesh/sdk/pkg/tools/clock"
"github.com/networkservicemesh/sdk/pkg/tools/clockmock"
)
Expand All @@ -51,6 +51,26 @@ const (
testTick = testWait / 100
)

func find(ctx context.Context, c registry.NetworkServiceEndpointRegistryClient) (nses []*registry.NetworkServiceEndpoint, err error) {
stream, err := c.Find(ctx, &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint),
})
if err != nil {
return nil, err
}

var nse *registry.NetworkServiceEndpoint
for nse, err = stream.Recv(); err == nil; nse, err = stream.Recv() {
nses = append(nses, nse)
}

if err != io.EOF {
return nil, err
}

return nses, nil
}

func TestExpireNSEServer_ShouldCorrectlySetExpirationTime_InRemoteCase(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

Expand Down Expand Up @@ -99,17 +119,10 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromInput_AndWork(t *testing

require.Equal(t, clockMock.Until(resp.ExpirationTime.AsTime()), expireTimeout/2)

c := adapters.NetworkServiceEndpointServerToClient(mem)

clockMock.Add(expireTimeout / 2)
require.Eventually(t, func() bool {
stream, err := c.Find(ctx, &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint),
})
require.NoError(t, err)

_, err = stream.Recv()
return err == io.EOF
nses, err := find(ctx, adapters.NetworkServiceEndpointServerToClient(mem))
return err == nil && len(nses) == 0
}, testWait, testTick)
}

Expand Down Expand Up @@ -161,35 +174,29 @@ func TestExpireNSEServer_ShouldRemoveNSEAfterExpirationTime(t *testing.T) {

c := adapters.NetworkServiceEndpointServerToClient(mem)

stream, err := c.Find(ctx, &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint),
})
require.NoError(t, err)

nse, err := stream.Recv()
nses, err := find(ctx, c)
require.NoError(t, err)
require.Equal(t, nseName, nse.Name)
require.Len(t, nses, 1)
require.Equal(t, nseName, nses[0].Name)

clockMock.Add(expireTimeout)
require.Eventually(t, func() bool {
stream, err = c.Find(ctx, &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint),
})
require.NoError(t, err)

_, err = stream.Recv()
return err == io.EOF
nses, err = find(ctx, c)
return err == nil && len(nses) == 0
}, testWait, testTick)
}

func TestExpireNSEServer_DataRace(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

mem := memory.NewNetworkServiceEndpointRegistryServer()

s := next.NewNetworkServiceEndpointRegistryServer(
serialize.NewNetworkServiceEndpointRegistryServer(),
expire.NewNetworkServiceEndpointRegistryServer(context.Background(), 0),
expire.NewNetworkServiceEndpointRegistryServer(ctx, 0),
localbypass.NewNetworkServiceEndpointRegistryServer("tcp://0.0.0.0"),
mem,
)
Expand All @@ -202,16 +209,9 @@ func TestExpireNSEServer_DataRace(t *testing.T) {
require.NoError(t, err)
}

c := adapters.NetworkServiceEndpointServerToClient(mem)

require.Eventually(t, func() bool {
stream, err := c.Find(context.Background(), &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint),
})
require.NoError(t, err)

_, err = stream.Recv()
return err == io.EOF
nses, err := find(context.Background(), adapters.NetworkServiceEndpointServerToClient(mem))
return err == nil && len(nses) == 0
}, testWait, testTick)
}

Expand All @@ -231,7 +231,11 @@ func TestExpireNSEServer_RefreshFailure(t *testing.T) {
new(remoteNSEServer), // <-- GRPC invocation
serialize.NewNetworkServiceEndpointRegistryServer(),
expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout),
newFailureNSEServer(1, -1),
injecterror.NewNetworkServiceEndpointRegistryServer(
injecterror.WithRegisterErrorTimes(1, -1),
injecterror.WithFindErrorTimes(),
injecterror.WithUnregisterErrorTimes(),
),
memory.NewNetworkServiceEndpointRegistryServer(),
)),
)
Expand All @@ -241,13 +245,52 @@ func TestExpireNSEServer_RefreshFailure(t *testing.T) {

clockMock.Add(expireTimeout)
require.Eventually(t, func() bool {
stream, err := c.Find(ctx, &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint),
})
require.NoError(t, err)
nses, err := find(ctx, c)
return err == nil && len(nses) == 0
}, testWait, testTick)
}

func TestExpireNSEServer_UnregisterFailure(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

_, err = stream.Recv()
return err == io.EOF
clockMock := clockmock.New(ctx)
ctx = clock.WithClock(ctx, clockMock)

mem := memory.NewNetworkServiceEndpointRegistryServer()

s := next.NewNetworkServiceEndpointRegistryServer(
serialize.NewNetworkServiceEndpointRegistryServer(),
expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout),
injecterror.NewNetworkServiceEndpointRegistryServer(
injecterror.WithRegisterErrorTimes(),
injecterror.WithFindErrorTimes(),
injecterror.WithUnregisterErrorTimes(0),
),
mem,
)

nse, err := s.Register(ctx, &registry.NetworkServiceEndpoint{
Name: nseName,
})
require.NoError(t, err)

_, err = s.Unregister(ctx, nse)
require.Error(t, err)

c := adapters.NetworkServiceEndpointServerToClient(mem)

nses, err := find(ctx, c)
require.NoError(t, err)
require.Len(t, nses, 1)
require.Equal(t, nseName, nses[0].Name)

clockMock.Add(expireTimeout)
require.Eventually(t, func() bool {
nses, err = find(ctx, c)
return err == nil && len(nses) == 0
}, testWait, testTick)
}

Expand Down Expand Up @@ -295,47 +338,24 @@ type remoteNSEServer struct {
}

func (s *remoteNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse.Clone())
}

func (s *remoteNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error {
if err := server.Context().Err(); err != nil {
return err
}
return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server)
}

func (s *remoteNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) {
return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse.Clone())
}

type failureNSEServer struct {
count int
failureTimes []int
}

func newFailureNSEServer(failureTimes ...int) *failureNSEServer {
return &failureNSEServer{
failureTimes: failureTimes,
}
}

func (s *failureNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) {
defer func() { s.count++ }()
for _, failureTime := range s.failureTimes {
if failureTime > s.count {
break
}
if failureTime == s.count || failureTime == -1 {
return nil, errors.New("failure")
}
if err := ctx.Err(); err != nil {
return nil, err
}
return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse)
}

func (s *failureNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error {
return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server)
}

func (s *failureNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) {
return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse)
return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse.Clone())
}

type unregisterNSEServer struct {
Expand Down
18 changes: 18 additions & 0 deletions pkg/registry/utils/inject/injecterror/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) 2021 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package injecterror provides chain elements returning given error on Register, Unregister on given times
package injecterror
42 changes: 42 additions & 0 deletions pkg/registry/utils/inject/injecterror/error_supplier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2020-2021 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package injecterror

type errorSupplier struct {
err error
count int
errorTimes []int
}

// supply returns an error or nil depending on errorTimes
// * [0, 2, 3] - will return an error on 0, 2, 3 times
// * [-1] - will return an error on all requests
// * [1, 4, -1] - will return an error on 0 time and on all times starting from 4
func (e *errorSupplier) supply() error {
defer func() { e.count++ }()

for _, errorTime := range e.errorTimes {
if errorTime > e.count {
break
}
if errorTime == e.count || errorTime == -1 {
return e.err
}
}

return nil
}
Loading

0 comments on commit 0602401

Please sign in to comment.