Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: nsmgr should keep fowarder while conn is fine #1479

Merged
merged 2 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions pkg/networkservice/chains/nsmgr/select_forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ package nsmgr_test
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/networkservicemesh/api/pkg/api/registry"

nsclient "github.com/networkservicemesh/sdk/pkg/networkservice/chains/client"
"github.com/networkservicemesh/sdk/pkg/networkservice/chains/nsmgr"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/heal"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/count"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/inject/injecterror"
"github.com/networkservicemesh/sdk/pkg/tools/sandbox"
Expand Down Expand Up @@ -141,26 +144,26 @@ func Test_DiscoverForwarder_ChangeForwarderOnClose(t *testing.T) {
require.Equal(t, skipCount+1, counter.UniqueRequests())
require.Equal(t, skipCount+1, counter.Requests())

selectedFwd := conn.GetPath().GetPathSegments()[2].Name
selectedForwarder := conn.GetPath().GetPathSegments()[2].Name

requestsCount := counter.Requests()
for i := 0; i < reselectCount; i++ {
_, err = nsc.Close(ctx, conn)
require.NoError(t, err)

// check that we select a different forwarder
selectedFwd = conn.GetPath().GetPathSegments()[2].Name
selectedForwarder = conn.GetPath().GetPathSegments()[2].Name
request.Connection = conn
conn, err = nsc.Request(ctx, request.Clone())
require.NoError(t, err)
require.Equal(t, skipCount+1, counter.UniqueRequests())
require.Equal(t, requestsCount+3, counter.Requests())
requestsCount = counter.Requests()
if selectedFwd != conn.GetPath().GetPathSegments()[2].Name {
if selectedForwarder != conn.GetPath().GetPathSegments()[2].Name {
break
}
}
require.NotEqual(t, selectedFwd, conn.GetPath().GetPathSegments()[2].Name)
require.NotEqual(t, selectedForwarder, conn.GetPath().GetPathSegments()[2].Name)
}

func Test_DiscoverForwarder_ChangeForwarderOnDeath_LostHeal(t *testing.T) {
Expand Down Expand Up @@ -210,9 +213,9 @@ func Test_DiscoverForwarder_ChangeForwarderOnDeath_LostHeal(t *testing.T) {
require.Equal(t, 1, counter.UniqueRequests())
require.Equal(t, 1, counter.Requests())

selectedFwd := conn.GetPath().GetPathSegments()[2].Name
selectedForwarder := conn.GetPath().GetPathSegments()[2].Name

domain.Nodes[0].Forwarders[selectedFwd].Cancel()
domain.Nodes[0].Forwarders[selectedForwarder].Cancel()

require.Eventually(t, checkSecondRequestsReceived(counter.Requests), timeout, tick)
require.Equal(t, 1, counter.UniqueRequests())
Expand All @@ -226,7 +229,7 @@ func Test_DiscoverForwarder_ChangeForwarderOnDeath_LostHeal(t *testing.T) {
require.Equal(t, 1, counter.UniqueRequests())
require.Equal(t, 3, counter.Requests())
require.Equal(t, 1, counter.Closes())
require.NotEqual(t, selectedFwd, conn.GetPath().GetPathSegments()[2].Name)
require.NotEqual(t, selectedForwarder, conn.GetPath().GetPathSegments()[2].Name)
}

func Test_DiscoverForwarder_ChangeRemoteForwarderOnDeath(t *testing.T) {
Expand Down Expand Up @@ -281,11 +284,11 @@ func Test_DiscoverForwarder_ChangeRemoteForwarderOnDeath(t *testing.T) {
require.Equal(t, 1, counter.UniqueRequests())
require.Equal(t, 1, counter.Requests())

selectedFwd := conn.GetPath().GetPathSegments()[4].Name
selectedForwarder := conn.GetPath().GetPathSegments()[4].Name

domain.Registry.Restart()

domain.Nodes[1].Forwarders[selectedFwd].Cancel()
domain.Nodes[1].Forwarders[selectedForwarder].Cancel()

require.Eventually(t, checkSecondRequestsReceived(counter.Requests), timeout, tick)
require.Equal(t, 1, counter.UniqueRequests())
Expand All @@ -299,5 +302,91 @@ func Test_DiscoverForwarder_ChangeRemoteForwarderOnDeath(t *testing.T) {
require.Equal(t, 1, counter.UniqueRequests())
require.Equal(t, 3, counter.Requests())
require.Equal(t, 1, counter.Closes())
require.NotEqual(t, selectedFwd, conn.GetPath().GetPathSegments()[4].Name)
require.NotEqual(t, selectedForwarder, conn.GetPath().GetPathSegments()[4].Name)
}

func Test_DiscoverForwarder_Should_KeepSelectedForwarderWhileConnectionIsFine(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })
ctx, cancel := context.WithTimeout(context.Background(), timeout)

defer cancel()
domain := sandbox.NewBuilder(ctx, t).
SetNodesCount(1).
SetNSMgrProxySupplier(nil).
SetRegistryProxySupplier(nil).
SetNodeSetup(func(ctx context.Context, node *sandbox.Node, _ int) {
node.NewNSMgr(ctx, "nsmgr", nil, sandbox.GenerateTestToken, nsmgr.NewServer)
}).
Build()

const fwdCount = 10
for i := 0; i < fwdCount; i++ {
domain.Nodes[0].NewForwarder(ctx, &registry.NetworkServiceEndpoint{
Name: sandbox.UniqueName("forwarder-" + fmt.Sprint(i)),
NetworkServiceNames: []string{"forwarder"},
}, sandbox.GenerateTestToken)
}

nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken)

nsReg := defaultRegistryService(t.Name())
nsReg, err := nsRegistryClient.Register(ctx, nsReg)
require.NoError(t, err)

nseReg := defaultRegistryEndpoint(nsReg.Name)

counter := new(count.Server)
domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter)

request := defaultRequest(nsReg.Name)

var livenessValue atomic.Value
livenessValue.Store(true)

var selectedForwarder string

var livenessChecker = func(deadlineCtx context.Context, conn *networkservice.Connection) bool {
if v := livenessValue.Load().(bool); !v {
return conn.GetPath().GetPathSegments()[2].Name != selectedForwarder
}
return true
}

nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken,
nsclient.WithHealClient(heal.NewClient(ctx,
heal.WithLivenessCheck(livenessChecker))))

conn, err := nsc.Request(ctx, request.Clone())
require.NoError(t, err)
require.Equal(t, 1, counter.UniqueRequests())
require.Equal(t, 1, counter.Requests())

selectedForwarder = conn.GetPath().GetPathSegments()[2].Name

domain.Registry.Restart()

domain.Nodes[0].NSMgr.Restart()

require.Eventually(t, checkSecondRequestsReceived(counter.Requests), timeout, tick)
require.Equal(t, 1, counter.UniqueRequests())
require.Equal(t, 2, counter.Requests())
require.Equal(t, 0, counter.Closes())

request.Connection = conn
conn, err = nsc.Request(ctx, request.Clone())
require.NoError(t, err)
require.Equal(t, 1, counter.UniqueRequests())
require.Equal(t, 0, counter.Closes())
require.Equal(t, selectedForwarder, conn.GetPath().GetPathSegments()[2].Name)

// datapath is down
livenessValue.Store(false)
domain.Nodes[0].Forwarders[selectedForwarder].Cancel()

request.Connection = conn
conn, err = nsc.Request(ctx, request.Clone())
require.NoError(t, err)
require.Equal(t, 1, counter.UniqueRequests())
require.Greater(t, counter.Closes(), 0)
require.NotEqual(t, selectedForwarder, conn.GetPath().GetPathSegments()[2].Name)
}
36 changes: 36 additions & 0 deletions pkg/networkservice/common/discoverforwarder/metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package discoverforwarder

import (
"context"

"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
)

type forwarderNameMetadataKey struct{}

func loadForwarderName(ctx context.Context) string {
if v, ok := metadata.Map(ctx, false).Load(forwarderNameMetadataKey{}); ok {
return v.(string)
}
return ""
}

func storeForwarderName(ctx context.Context, name string) {
metadata.Map(ctx, false).Store(forwarderNameMetadataKey{}, name)
}
43 changes: 26 additions & 17 deletions pkg/networkservice/common/discoverforwarder/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,8 @@ func NewServer(nsClient registry.NetworkServiceRegistryClient, nseClient registr
return result
}

func (d *discoverForwarderServer) forwarderName(conn *networkservice.Connection) string {
var segments = conn.GetPath().GetPathSegments()
if pathIndex := int(conn.GetPath().Index); len(conn.GetPath().PathSegments) > pathIndex+1 {
return segments[pathIndex+1].Name
}
return ""
}

func (d *discoverForwarderServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
var forwarderName = d.forwarderName(request.GetConnection())
var forwarderName = loadForwarderName(ctx)
var logger = log.FromContext(ctx).WithField("discoverForwarderServer", "request")

ns, err := d.discoverNetworkService(ctx, request.GetConnection().GetNetworkService(), request.GetConnection().GetPayload())
Expand All @@ -87,7 +79,8 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks
NetworkServiceNames: []string{
d.forwarderServiceName,
},
Url: d.nsmgrURL,
Name: forwarderName,
Url: d.nsmgrURL,
},
})
if err != nil {
Expand All @@ -97,15 +90,20 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks

nses := d.matchForwarders(request.Connection.GetLabels(), ns, registry.ReadNetworkServiceEndpointList(stream))
if len(nses) == 0 {
if forwarderName != "" {
return nil, errors.Errorf("forwarder %v is not available", forwarderName)
}
return nil, errors.New("no candidates found")
}

segments := request.Connection.GetPath().GetPathSegments()
if pathIndex := int(request.Connection.GetPath().Index); len(segments) > pathIndex+1 {
for i, candidate := range nses {
if candidate.Name == forwarderName {
nses[0], nses[i] = nses[i], nses[0]
break
if forwarderName == "" {
segments := request.Connection.GetPath().GetPathSegments()
if pathIndex := int(request.Connection.GetPath().Index); len(segments) > pathIndex+1 {
for i, candidate := range nses {
if candidate.Name == segments[pathIndex+1].GetName() {
nses[0], nses[i] = nses[i], nses[0]
break
}
}
}
}
Expand All @@ -123,6 +121,9 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks

resp, err := next.Server(ctx).Request(clienturlctx.WithClientURL(ctx, u), request.Clone())
if err == nil {
if forwarderName == "" {
storeForwarderName(ctx, candidate.GetName())
}
return resp, nil
}
logger.Errorf("forwarder=%v url=%v returned error=%v", candidate.Name, candidate.Url, err.Error())
Expand All @@ -136,7 +137,15 @@ func (d *discoverForwarderServer) Close(ctx context.Context, conn *networkservic
// Unlike Request, Close method should always call next element in chain
// to make sure we clear resources in the current app.

var forwarderName = d.forwarderName(conn)
var forwarderName = loadForwarderName(ctx)

if forwarderName == "" {
segments := conn.GetPath().GetPathSegments()
if pathIndex := int(conn.GetPath().Index); len(segments) > pathIndex+1 {
forwarderName = segments[pathIndex+1].GetName()
}
}

var logger = log.FromContext(ctx).WithField("discoverForwarderServer", "request")
if forwarderName == "" {
logger.Error("connection doesn't have forwarder")
Expand Down
1 change: 0 additions & 1 deletion pkg/networkservice/common/heal/eventloop.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ func (cev *eventLoop) waitForEvents() (canceled, reselect bool) {
return true, false
}
cev.logger.Warnf("Data plane is down")
reselect = true
cev.healingStartedCh <- true
return false, true
case <-cev.chainCtx.Done():
Expand Down
Loading