From 39ef8926c679d5cb03b5ee011c99e3300d6acc60 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Wed, 21 Sep 2022 17:06:52 +0700 Subject: [PATCH] Don't restore eventFactory in case the connection has already been closed/unregistered Signed-off-by: Artem Glazychev --- pkg/networkservice/common/begin/client.go | 4 +- .../common/begin/event_factory.go | 6 +- .../common/begin/event_factory_client_test.go | 132 +++++++++++++++++ .../common/begin/event_factory_server_test.go | 129 +++++++++++++++++ pkg/networkservice/common/begin/server.go | 6 +- pkg/registry/common/begin/ns_client.go | 4 +- pkg/registry/common/begin/ns_server.go | 4 +- pkg/registry/common/begin/nse_client.go | 4 +- .../begin/nse_event_factory_client_test.go | 134 ++++++++++++++++++ .../begin/nse_event_factory_server_test.go | 131 +++++++++++++++++ pkg/registry/common/begin/nse_server.go | 4 +- 11 files changed, 542 insertions(+), 16 deletions(-) create mode 100644 pkg/networkservice/common/begin/event_factory_client_test.go create mode 100644 pkg/networkservice/common/begin/event_factory_server_test.go create mode 100644 pkg/registry/common/begin/nse_event_factory_client_test.go create mode 100644 pkg/registry/common/begin/nse_event_factory_server_test.go diff --git a/pkg/networkservice/common/begin/client.go b/pkg/networkservice/common/begin/client.go index 48d8803c6..a0978408a 100644 --- a/pkg/networkservice/common/begin/client.go +++ b/pkg/networkservice/common/begin/client.go @@ -59,7 +59,7 @@ func (b *beginClient) Request(ctx context.Context, request *networkservice.Netwo <-eventFactoryClient.executor.AsyncExec(func() { // If the eventFactory has changed, usually because the connection has been Closed and re-established // go back to the beginning and try again. - currentEventFactoryClient, _ := b.LoadOrStore(request.GetConnection().GetId(), eventFactoryClient) + currentEventFactoryClient, _ := b.Load(request.GetConnection().GetId()) if currentEventFactoryClient != eventFactoryClient { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient") conn, err = b.Request(ctx, request, opts...) @@ -103,7 +103,7 @@ func (b *beginClient) Close(ctx context.Context, conn *networkservice.Connection } // If this isn't the connection we started with, do nothing - currentEventFactoryClient, _ := b.LoadOrStore(conn.GetId(), eventFactoryClient) + currentEventFactoryClient, _ := b.Load(conn.GetId()) if currentEventFactoryClient != eventFactoryClient { return } diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index a4335fa4f..a1db68fcf 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -60,7 +60,7 @@ func newEventFactoryClient(ctx context.Context, afterClose func(), opts ...grpc. client: next.Client(ctx), opts: opts, } - ctxFunc := postpone.Context(ctx) + ctxFunc := postpone.ContextWithValues(ctx) f.ctxFunc = func() (context.Context, context.CancelFunc) { eventCtx, cancel := ctxFunc() return withEventFactory(eventCtx, f), cancel @@ -155,7 +155,7 @@ func newEventFactoryServer(ctx context.Context, afterClose func()) *eventFactory f := &eventFactoryServer{ server: next.Server(ctx), } - ctxFunc := postpone.Context(ctx) + ctxFunc := postpone.ContextWithValues(ctx) f.ctxFunc = func() (context.Context, context.CancelFunc) { eventCtx, cancel := ctxFunc() return withEventFactory(eventCtx, f), cancel diff --git a/pkg/networkservice/common/begin/event_factory_client_test.go b/pkg/networkservice/common/begin/event_factory_client_test.go new file mode 100644 index 000000000..3b6484a17 --- /dev/null +++ b/pkg/networkservice/common/begin/event_factory_client_test.go @@ -0,0 +1,132 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/networkservicemesh/api/pkg/api/networkservice" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" +) + +// This test reproduces the situation when Close and Request were called at the same time +// nolint:dupl +func TestRefreshDuringClose_Client(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxCl := &checkContextClient{t: t} + eventFactoryCl := &eventFactoryClient{ch: syncChan} + client := chain.NewNetworkServiceClient( + begin.NewClient(), + checkCtxCl, + eventFactoryCl, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxCl.setExpectedValue("value_1") + + // Do Request with this context + request := testRequest("1") + conn, err := client.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh Request + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxCl.setExpectedValue("value_2") + request.Connection = conn.Clone() + + // Call Close from eventFactory + eventFactoryCl.callClose() + <-syncChan + + // Call refresh (should be called at the same time as Close) + conn, err = client.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryCl.callRefresh() + <-syncChan +} + +type eventFactoryClient struct { + ctx context.Context + ch chan<- struct{} +} + +func (s *eventFactoryClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) { + s.ctx = ctx + return next.Client(ctx).Request(ctx, request, opts...) +} + +func (s *eventFactoryClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*emptypb.Empty, error) { + // Wait to be sure that rerequest was called + time.Sleep(time.Millisecond * 100) + return next.Client(ctx).Close(ctx, conn, opts...) +} + +func (s *eventFactoryClient) callClose() { + eventFactory := begin.FromContext(s.ctx) + go func() { + s.ch <- struct{}{} + eventFactory.Close() + }() +} + +func (s *eventFactoryClient) callRefresh() { + eventFactory := begin.FromContext(s.ctx) + go func() { + s.ch <- struct{}{} + eventFactory.Request() + }() +} + +type contextKey struct{} + +type checkContextClient struct { + t *testing.T + expectedValue string +} + +func (c *checkContextClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) { + assert.Equal(c.t, c.expectedValue, ctx.Value(contextKey{})) + return next.Client(ctx).Request(ctx, request, opts...) +} + +func (c *checkContextClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*emptypb.Empty, error) { + return next.Client(ctx).Close(ctx, conn, opts...) +} + +func (c *checkContextClient) setExpectedValue(value string) { + c.expectedValue = value +} diff --git a/pkg/networkservice/common/begin/event_factory_server_test.go b/pkg/networkservice/common/begin/event_factory_server_test.go new file mode 100644 index 000000000..d4ad8e9e5 --- /dev/null +++ b/pkg/networkservice/common/begin/event_factory_server_test.go @@ -0,0 +1,129 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/networkservicemesh/api/pkg/api/networkservice" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" +) + +// This test reproduces the situation when Close and Request were called at the same time +// nolint:dupl +func TestRefreshDuringClose_Server(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxServ := &checkContextServer{t: t} + eventFactoryServ := &eventFactoryServer{ch: syncChan} + server := chain.NewNetworkServiceServer( + begin.NewServer(), + checkCtxServ, + eventFactoryServ, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxServ.setExpectedValue("value_1") + + // Do Request with this context + request := testRequest("1") + conn, err := server.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh Request + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxServ.setExpectedValue("value_2") + request.Connection = conn.Clone() + + // Call Close from eventFactory + eventFactoryServ.callClose() + <-syncChan + + // Call refresh (should be called at the same time as Close) + conn, err = server.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryServ.callRefresh() + <-syncChan +} + +type eventFactoryServer struct { + ctx context.Context + ch chan<- struct{} +} + +func (e *eventFactoryServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + e.ctx = ctx + return next.Server(ctx).Request(ctx, request) +} + +func (e *eventFactoryServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) { + // Wait to be sure that rerequest was called + time.Sleep(time.Millisecond * 100) + return next.Server(ctx).Close(ctx, conn) +} + +func (e *eventFactoryServer) callClose() { + eventFactory := begin.FromContext(e.ctx) + go func() { + e.ch <- struct{}{} + eventFactory.Close() + }() +} + +func (e *eventFactoryServer) callRefresh() { + eventFactory := begin.FromContext(e.ctx) + go func() { + e.ch <- struct{}{} + eventFactory.Request() + }() +} + +type checkContextServer struct { + t *testing.T + expectedValue string +} + +func (c *checkContextServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + assert.Equal(c.t, c.expectedValue, ctx.Value(contextKey{})) + return next.Server(ctx).Request(ctx, request) +} + +func (c *checkContextServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) { + return next.Server(ctx).Close(ctx, conn) +} + +func (c *checkContextServer) setExpectedValue(value string) { + c.expectedValue = value +} diff --git a/pkg/networkservice/common/begin/server.go b/pkg/networkservice/common/begin/server.go index 4cb49028d..790916c1d 100644 --- a/pkg/networkservice/common/begin/server.go +++ b/pkg/networkservice/common/begin/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2022 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -55,7 +55,7 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo ), ) <-eventFactoryServer.executor.AsyncExec(func() { - currentEventFactoryServer, _ := b.LoadOrStore(request.GetConnection().GetId(), eventFactoryServer) + currentEventFactoryServer, _ := b.Load(request.GetConnection().GetId()) if currentEventFactoryServer != eventFactoryServer { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer") conn, err = b.Request(ctx, request) @@ -93,7 +93,7 @@ func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection if eventFactoryServer.state != established || eventFactoryServer.request == nil { return } - currentServerClient, _ := b.LoadOrStore(conn.GetId(), eventFactoryServer) + currentServerClient, _ := b.Load(conn.GetId()) if currentServerClient != eventFactoryServer { return } diff --git a/pkg/registry/common/begin/ns_client.go b/pkg/registry/common/begin/ns_client.go index 2853bbed9..fd35ee9a5 100644 --- a/pkg/registry/common/begin/ns_client.go +++ b/pkg/registry/common/begin/ns_client.go @@ -55,7 +55,7 @@ func (b *beginNSClient) Register(ctx context.Context, in *registry.NetworkServic <-eventFactoryClient.executor.AsyncExec(func() { // If the eventFactory has changed, usually because the connection has been Closed and re-established // go back to the beginning and try again. - currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + currentEventFactoryClient, _ := b.Load(id) if currentEventFactoryClient != eventFactoryClient { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient") resp, err = b.Register(ctx, in, opts...) @@ -101,7 +101,7 @@ func (b *beginNSClient) Unregister(ctx context.Context, in *registry.NetworkServ } // If this isn't the connection we started with, do nothing - currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + currentEventFactoryClient, _ := b.Load(id) if currentEventFactoryClient != eventFactoryClient { return } diff --git a/pkg/registry/common/begin/ns_server.go b/pkg/registry/common/begin/ns_server.go index 711c59696..55e35b26b 100644 --- a/pkg/registry/common/begin/ns_server.go +++ b/pkg/registry/common/begin/ns_server.go @@ -54,7 +54,7 @@ func (b *beginNSServer) Register(ctx context.Context, in *registry.NetworkServic var err error <-eventFactoryServer.executor.AsyncExec(func() { - currentEventFactoryServer, _ := b.LoadOrStore(id, eventFactoryServer) + currentEventFactoryServer, _ := b.Load(id) if currentEventFactoryServer != eventFactoryServer { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer") resp, err = b.Register(ctx, in) @@ -96,7 +96,7 @@ func (b *beginNSServer) Unregister(ctx context.Context, in *registry.NetworkServ if eventFactoryServer.state != established || eventFactoryServer.registration == nil { return } - currentServerClient, _ := b.LoadOrStore(id, eventFactoryServer) + currentServerClient, _ := b.Load(id) if currentServerClient != eventFactoryServer { return } diff --git a/pkg/registry/common/begin/nse_client.go b/pkg/registry/common/begin/nse_client.go index 078887e68..a2c53c344 100644 --- a/pkg/registry/common/begin/nse_client.go +++ b/pkg/registry/common/begin/nse_client.go @@ -55,7 +55,7 @@ func (b *beginNSEClient) Register(ctx context.Context, in *registry.NetworkServi <-eventFactoryClient.executor.AsyncExec(func() { // If the eventFactory has changed, usually because the connection has been Closed and re-established // go back to the beginning and try again. - currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + currentEventFactoryClient, _ := b.Load(id) if currentEventFactoryClient != eventFactoryClient { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient") resp, err = b.Register(ctx, in, opts...) @@ -101,7 +101,7 @@ func (b *beginNSEClient) Unregister(ctx context.Context, in *registry.NetworkSer } // If this isn't the connection we started with, do nothing - currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient) + currentEventFactoryClient, _ := b.Load(id) if currentEventFactoryClient != eventFactoryClient { return } diff --git a/pkg/registry/common/begin/nse_event_factory_client_test.go b/pkg/registry/common/begin/nse_event_factory_client_test.go new file mode 100644 index 000000000..00e5be845 --- /dev/null +++ b/pkg/registry/common/begin/nse_event_factory_client_test.go @@ -0,0 +1,134 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "google.golang.org/grpc" +) + +// This test reproduces the situation when Unregister and Register were called at the same time +func TestRefreshDuringUnregister_Client(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxCl := &checkContextClient{t: t} + eventFactoryCl := &eventFactoryClient{ch: syncChan} + client := chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + checkCtxCl, + eventFactoryCl, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxCl.setExpectedValue("value_1") + + // Do Register with this context + nse := ®istry.NetworkServiceEndpoint{ + Name: "1", + } + conn, err := client.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxCl.setExpectedValue("value_2") + + // Call Unregister from eventFactory + eventFactoryCl.callUnregister() + <-syncChan + + // Call refresh (should be called at the same time as Unregister) + conn, err = client.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryCl.callRefresh() + <-syncChan +} + +type eventFactoryClient struct { + registry.NetworkServiceEndpointRegistryClient + ctx context.Context + ch chan<- struct{} +} + +func (e *eventFactoryClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + e.ctx = ctx + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (e *eventFactoryClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + // Wait to be sure that reregister was called + time.Sleep(time.Millisecond * 100) + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +func (e *eventFactoryClient) callUnregister() { + eventFactory := begin.FromContext(e.ctx) + go func() { + e.ch <- struct{}{} + eventFactory.Unregister() + }() +} + +func (e *eventFactoryClient) callRefresh() { + eventFactory := begin.FromContext(e.ctx) + go func() { + e.ch <- struct{}{} + eventFactory.Register() + }() +} + +type contextKey struct{} + +type checkContextClient struct { + registry.NetworkServiceEndpointRegistryClient + t *testing.T + expectedValue string +} + +func (c *checkContextClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + assert.Equal(c.t, c.expectedValue, ctx.Value(contextKey{})) + return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) +} + +func (c *checkContextClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) +} + +func (c *checkContextClient) setExpectedValue(value string) { + c.expectedValue = value +} diff --git a/pkg/registry/common/begin/nse_event_factory_server_test.go b/pkg/registry/common/begin/nse_event_factory_server_test.go new file mode 100644 index 000000000..ea980bbfe --- /dev/null +++ b/pkg/registry/common/begin/nse_event_factory_server_test.go @@ -0,0 +1,131 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "testing" + "time" + + "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/core/chain" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" +) + +// This test reproduces the situation when Unregister and Register were called at the same time +func TestRefreshDuringUnregister_Server(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxServ := &checkContextServer{t: t} + eventFactoryServ := &eventFactoryServer{ch: syncChan} + server := chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), + checkCtxServ, + eventFactoryServ, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxServ.setExpectedValue("value_1") + + // Do Register with this context + nse := ®istry.NetworkServiceEndpoint{ + Name: "1", + } + conn, err := server.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxServ.setExpectedValue("value_2") + + // Call Unregister from eventFactory + eventFactoryServ.callUnregister() + <-syncChan + + // Call refresh (should be called at the same time as Unregister) + conn, err = server.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryServ.callRefresh() + <-syncChan +} + +type eventFactoryServer struct { + registry.NetworkServiceEndpointRegistryServer + ctx context.Context + ch chan<- struct{} +} + +func (e *eventFactoryServer) Register(ctx context.Context, in *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + e.ctx = ctx + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, in) +} + +func (e *eventFactoryServer) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint) (*emptypb.Empty, error) { + // Wait to be sure that reregister was called + time.Sleep(time.Millisecond * 100) + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, in) +} + +func (e *eventFactoryServer) callUnregister() { + eventFactory := begin.FromContext(e.ctx) + go func() { + e.ch <- struct{}{} + eventFactory.Unregister() + }() +} + +func (e *eventFactoryServer) callRefresh() { + eventFactory := begin.FromContext(e.ctx) + go func() { + e.ch <- struct{}{} + eventFactory.Register() + }() +} + +type checkContextServer struct { + registry.NetworkServiceEndpointRegistryServer + t *testing.T + expectedValue string +} + +func (c *checkContextServer) Register(ctx context.Context, in *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + assert.Equal(c.t, c.expectedValue, ctx.Value(contextKey{})) + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, in) +} + +func (c *checkContextServer) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint) (*emptypb.Empty, error) { + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, in) +} + +func (c *checkContextServer) setExpectedValue(value string) { + c.expectedValue = value +} diff --git a/pkg/registry/common/begin/nse_server.go b/pkg/registry/common/begin/nse_server.go index bba97fb54..28221d05a 100644 --- a/pkg/registry/common/begin/nse_server.go +++ b/pkg/registry/common/begin/nse_server.go @@ -54,7 +54,7 @@ func (b *beginNSEServer) Register(ctx context.Context, in *registry.NetworkServi var err error <-eventFactoryServer.executor.AsyncExec(func() { - currentEventFactoryServer, _ := b.LoadOrStore(id, eventFactoryServer) + currentEventFactoryServer, _ := b.Load(id) if currentEventFactoryServer != eventFactoryServer { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer") resp, err = b.Register(ctx, in) @@ -96,7 +96,7 @@ func (b *beginNSEServer) Unregister(ctx context.Context, in *registry.NetworkSer if eventFactoryServer.state != established || eventFactoryServer.registration == nil { return } - currentServerClient, _ := b.LoadOrStore(id, eventFactoryServer) + currentServerClient, _ := b.Load(id) if currentServerClient != eventFactoryServer { return }