From 72cc8cfb72b94364ba82b94ae04ee56b83a75b33 Mon Sep 17 00:00:00 2001 From: Joseph Anttila Hall Date: Mon, 21 Aug 2023 10:49:57 -0700 Subject: [PATCH] proxy-server: Wrap Backend more completely. The goal is to fix nil deref at https://github.com/kubernetes-sigs/apiserver-network-proxy/issues/513, which was caused by ProxyServer.addBackend inconsistent return value. --- pkg/server/backend_manager.go | 71 ++++++++++++-------- pkg/server/backend_manager_test.go | 86 ++++++++++++------------- pkg/server/server.go | 58 ++++++----------- pkg/server/server_test.go | 13 ++-- proto/header/header.go | 2 +- tests/concurrent_client_request_test.go | 20 +++--- 6 files changed, 124 insertions(+), 126 deletions(-) diff --git a/pkg/server/backend_manager.go b/pkg/server/backend_manager.go index 77f620423..6a36914bf 100644 --- a/pkg/server/backend_manager.go +++ b/pkg/server/backend_manager.go @@ -25,6 +25,7 @@ import ( "sync" "time" + "google.golang.org/grpc/metadata" "k8s.io/klog/v2" commonmetrics "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/common/metrics" @@ -70,18 +71,21 @@ func GenProxyStrategiesFromStr(proxyStrategies string) ([]ProxyStrategy, error) return ps, nil } +// Backend abstracts a connected Konnectivity agent. +// +// In the only currently supported case (gRPC), it wraps an +// agent.AgentService_ConnectServer, provides synchronization and +// emits common stream metrics. type Backend interface { Send(p *client.Packet) error Recv() (*client.Packet, error) Context() context.Context + GetAgentIdentifiers() (header.Identifiers, error) } var _ Backend = &backend{} -var _ Backend = agent.AgentService_ConnectServer(nil) type backend struct { - // TODO: this is a multi-writer single-reader pattern, it's tricky to - // write it using channel. Let's worry about performance later. sendLock sync.Mutex recvLock sync.Mutex conn agent.AgentService_ConnectServer @@ -121,7 +125,24 @@ func (b *backend) Context() context.Context { return b.conn.Context() } -func newBackend(conn agent.AgentService_ConnectServer) *backend { +func (b *backend) GetAgentIdentifiers() (header.Identifiers, error) { + var agentIdentifiers header.Identifiers + md, ok := metadata.FromIncomingContext(b.Context()) + if !ok { + return agentIdentifiers, fmt.Errorf("failed to get metadata from context") + } + agentIDs := md.Get(header.AgentIdentifiers) + if len(agentIDs) > 1 { + return agentIdentifiers, fmt.Errorf("expected at most one set of agent IDs in the context, got %v", agentIDs) + } + if len(agentIDs) == 0 { + return agentIdentifiers, nil + } + + return header.GenAgentIdentifiers(agentIDs[0]) +} + +func NewBackend(conn agent.AgentService_ConnectServer) Backend { return &backend{conn: conn} } @@ -129,9 +150,9 @@ func newBackend(conn agent.AgentService_ConnectServer) *backend { // connections, i.e., get, add and remove type BackendStorage interface { // AddBackend adds a backend. - AddBackend(identifier string, idType header.IdentifierType, conn agent.AgentService_ConnectServer) Backend + AddBackend(identifier string, idType header.IdentifierType, backend Backend) // RemoveBackend removes a backend. - RemoveBackend(identifier string, idType header.IdentifierType, conn agent.AgentService_ConnectServer) + RemoveBackend(identifier string, idType header.IdentifierType, backend Backend) // NumBackends returns the number of backends. NumBackends() int } @@ -168,7 +189,7 @@ type DefaultBackendStorage struct { // For a given agent, ProxyServer prefers backends[agentID][0] to send // traffic, because backends[agentID][1:] are more likely to be closed // by the agent to deduplicate connections to the same server. - backends map[string][]*backend + backends map[string][]Backend // agentID is tracked in this slice to enable randomly picking an // agentID in the Backend() method. There is no reliable way to // randomly pick a key from a map (in this case, the backends) in @@ -198,7 +219,7 @@ func NewDefaultBackendStorage(idTypes []header.IdentifierType) *DefaultBackendSt // no agent ever successfully connects. metrics.Metrics.SetBackendCount(0) return &DefaultBackendStorage{ - backends: make(map[string][]*backend), + backends: make(map[string][]Backend), random: rand.New(rand.NewSource(time.Now().UnixNano())), idTypes: idTypes, } /* #nosec G404 */ @@ -214,42 +235,40 @@ func containIDType(idTypes []header.IdentifierType, idType header.IdentifierType } // AddBackend adds a backend. -func (s *DefaultBackendStorage) AddBackend(identifier string, idType header.IdentifierType, conn agent.AgentService_ConnectServer) Backend { +func (s *DefaultBackendStorage) AddBackend(identifier string, idType header.IdentifierType, backend Backend) { if !containIDType(s.idTypes, idType) { klog.V(4).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes}) - return nil + return } - klog.V(5).InfoS("Register backend for agent", "connection", conn, "agentID", identifier) + klog.V(5).InfoS("Register backend for agent", "agentID", identifier) s.mu.Lock() defer s.mu.Unlock() _, ok := s.backends[identifier] - addedBackend := newBackend(conn) if ok { - for _, v := range s.backends[identifier] { - if v.conn == conn { - klog.V(1).InfoS("This should not happen. Adding existing backend for agent", "connection", conn, "agentID", identifier) - return v + for _, b := range s.backends[identifier] { + if b == backend { + klog.V(1).InfoS("This should not happen. Adding existing backend for agent", "agentID", identifier) + return } } - s.backends[identifier] = append(s.backends[identifier], addedBackend) - return addedBackend + s.backends[identifier] = append(s.backends[identifier], backend) + return } - s.backends[identifier] = []*backend{addedBackend} + s.backends[identifier] = []Backend{backend} metrics.Metrics.SetBackendCount(len(s.backends)) s.agentIDs = append(s.agentIDs, identifier) if idType == header.DefaultRoute { s.defaultRouteAgentIDs = append(s.defaultRouteAgentIDs, identifier) } - return addedBackend } // RemoveBackend removes a backend. -func (s *DefaultBackendStorage) RemoveBackend(identifier string, idType header.IdentifierType, conn agent.AgentService_ConnectServer) { +func (s *DefaultBackendStorage) RemoveBackend(identifier string, idType header.IdentifierType, backend Backend) { if !containIDType(s.idTypes, idType) { klog.ErrorS(&ErrWrongIDType{idType, s.idTypes}, "fail to remove backend") return } - klog.V(5).InfoS("Remove connection for agent", "connection", conn, "identifier", identifier) + klog.V(5).InfoS("Remove connection for agent", "agentID", identifier) s.mu.Lock() defer s.mu.Unlock() backends, ok := s.backends[identifier] @@ -258,11 +277,11 @@ func (s *DefaultBackendStorage) RemoveBackend(identifier string, idType header.I return } var found bool - for i, c := range backends { - if c.conn == conn { + for i, b := range backends { + if b == backend { s.backends[identifier] = append(s.backends[identifier][:i], s.backends[identifier][i+1:]...) if i == 0 && len(s.backends[identifier]) != 0 { - klog.V(1).InfoS("This should not happen. Removed connection that is not the first connection", "connection", conn, "remainingConnections", s.backends[identifier]) + klog.V(1).InfoS("This should not happen. Removed connection that is not the first connection", "agentID", identifier) } found = true } @@ -286,7 +305,7 @@ func (s *DefaultBackendStorage) RemoveBackend(identifier string, idType header.I } } if !found { - klog.V(1).InfoS("Could not find connection matching identifier to remove", "connection", conn, "identifier", identifier) + klog.V(1).InfoS("Could not find connection matching identifier to remove", "agentID", identifier, "idType", idType) } metrics.Metrics.SetBackendCount(len(s.backends)) } diff --git a/pkg/server/backend_manager_test.go b/pkg/server/backend_manager_test.go index f79a3df77..82ae1874a 100644 --- a/pkg/server/backend_manager_test.go +++ b/pkg/server/backend_manager_test.go @@ -29,17 +29,17 @@ type fakeAgentServiceConnectServer struct { } func TestAddRemoveBackends(t *testing.T) { - conn1 := new(fakeAgentServiceConnectServer) - conn12 := new(fakeAgentServiceConnectServer) - conn2 := new(fakeAgentServiceConnectServer) - conn22 := new(fakeAgentServiceConnectServer) - conn3 := new(fakeAgentServiceConnectServer) + backend1 := NewBackend(new(fakeAgentServiceConnectServer)) + backend12 := NewBackend(new(fakeAgentServiceConnectServer)) + backend2 := NewBackend(new(fakeAgentServiceConnectServer)) + backend22 := NewBackend(new(fakeAgentServiceConnectServer)) + backend3 := NewBackend(new(fakeAgentServiceConnectServer)) p := NewDefaultBackendManager() - p.AddBackend("agent1", header.UID, conn1) - p.RemoveBackend("agent1", header.UID, conn1) - expectedBackends := make(map[string][]*backend) + p.AddBackend("agent1", header.UID, backend1) + p.RemoveBackend("agent1", header.UID, backend1) + expectedBackends := make(map[string][]Backend) expectedAgentIDs := []string{} if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { t.Errorf("expected %v, got %v", e, a) @@ -49,21 +49,21 @@ func TestAddRemoveBackends(t *testing.T) { } p = NewDefaultBackendManager() - p.AddBackend("agent1", header.UID, conn1) - p.AddBackend("agent1", header.UID, conn12) + p.AddBackend("agent1", header.UID, backend1) + p.AddBackend("agent1", header.UID, backend12) // Adding the same connection again should be a no-op. - p.AddBackend("agent1", header.UID, conn12) - p.AddBackend("agent2", header.UID, conn2) - p.AddBackend("agent2", header.UID, conn22) - p.AddBackend("agent3", header.UID, conn3) - p.RemoveBackend("agent2", header.UID, conn22) - p.RemoveBackend("agent2", header.UID, conn2) - p.RemoveBackend("agent1", header.UID, conn1) - // This is invalid. agent1 doesn't have conn3. This should be a no-op. - p.RemoveBackend("agent1", header.UID, conn3) - expectedBackends = map[string][]*backend{ - "agent1": {newBackend(conn12)}, - "agent3": {newBackend(conn3)}, + p.AddBackend("agent1", header.UID, backend12) + p.AddBackend("agent2", header.UID, backend2) + p.AddBackend("agent2", header.UID, backend22) + p.AddBackend("agent3", header.UID, backend3) + p.RemoveBackend("agent2", header.UID, backend22) + p.RemoveBackend("agent2", header.UID, backend2) + p.RemoveBackend("agent1", header.UID, backend1) + // This is invalid. agent1 doesn't have backend3. This should be a no-op. + p.RemoveBackend("agent1", header.UID, backend3) + expectedBackends = map[string][]Backend{ + "agent1": {backend12}, + "agent3": {backend3}, } expectedAgentIDs = []string{"agent1", "agent3"} if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { @@ -75,17 +75,17 @@ func TestAddRemoveBackends(t *testing.T) { } func TestAddRemoveBackendsWithDefaultRoute(t *testing.T) { - conn1 := new(fakeAgentServiceConnectServer) - conn12 := new(fakeAgentServiceConnectServer) - conn2 := new(fakeAgentServiceConnectServer) - conn22 := new(fakeAgentServiceConnectServer) - conn3 := new(fakeAgentServiceConnectServer) + backend1 := NewBackend(new(fakeAgentServiceConnectServer)) + backend12 := NewBackend(new(fakeAgentServiceConnectServer)) + backend2 := NewBackend(new(fakeAgentServiceConnectServer)) + backend22 := NewBackend(new(fakeAgentServiceConnectServer)) + backend3 := NewBackend(new(fakeAgentServiceConnectServer)) p := NewDefaultRouteBackendManager() - p.AddBackend("agent1", header.DefaultRoute, conn1) - p.RemoveBackend("agent1", header.DefaultRoute, conn1) - expectedBackends := make(map[string][]*backend) + p.AddBackend("agent1", header.DefaultRoute, backend1) + p.RemoveBackend("agent1", header.DefaultRoute, backend1) + expectedBackends := make(map[string][]Backend) expectedAgentIDs := []string{} if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { t.Errorf("expected %v, got %v", e, a) @@ -98,22 +98,22 @@ func TestAddRemoveBackendsWithDefaultRoute(t *testing.T) { } p = NewDefaultRouteBackendManager() - p.AddBackend("agent1", header.DefaultRoute, conn1) - p.AddBackend("agent1", header.DefaultRoute, conn12) + p.AddBackend("agent1", header.DefaultRoute, backend1) + p.AddBackend("agent1", header.DefaultRoute, backend12) // Adding the same connection again should be a no-op. - p.AddBackend("agent1", header.DefaultRoute, conn12) - p.AddBackend("agent2", header.DefaultRoute, conn2) - p.AddBackend("agent2", header.DefaultRoute, conn22) - p.AddBackend("agent3", header.DefaultRoute, conn3) - p.RemoveBackend("agent2", header.DefaultRoute, conn22) - p.RemoveBackend("agent2", header.DefaultRoute, conn2) - p.RemoveBackend("agent1", header.DefaultRoute, conn1) + p.AddBackend("agent1", header.DefaultRoute, backend12) + p.AddBackend("agent2", header.DefaultRoute, backend2) + p.AddBackend("agent2", header.DefaultRoute, backend22) + p.AddBackend("agent3", header.DefaultRoute, backend3) + p.RemoveBackend("agent2", header.DefaultRoute, backend22) + p.RemoveBackend("agent2", header.DefaultRoute, backend2) + p.RemoveBackend("agent1", header.DefaultRoute, backend1) // This is invalid. agent1 doesn't have conn3. This should be a no-op. - p.RemoveBackend("agent1", header.DefaultRoute, conn3) + p.RemoveBackend("agent1", header.DefaultRoute, backend3) - expectedBackends = map[string][]*backend{ - "agent1": {newBackend(conn12)}, - "agent3": {newBackend(conn3)}, + expectedBackends = map[string][]Backend{ + "agent1": {backend12}, + "agent3": {backend3}, } expectedDefaultRouteAgentIDs := []string{"agent1", "agent3"} diff --git a/pkg/server/server.go b/pkg/server/server.go index c1bab4660..851dd7ac6 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -255,79 +255,79 @@ func (s *ProxyServer) getBackend(reqHost string) (Backend, error) { return nil, &ErrNotFound{} } -func (s *ProxyServer) addBackend(agentID string, conn agent.AgentService_ConnectServer) (backend Backend) { +func (s *ProxyServer) addBackend(agentID string, backend Backend) { for i := 0; i < len(s.BackendManagers); i++ { switch s.BackendManagers[i].(type) { case *DestHostBackendManager: - agentIdentifiers, err := getAgentIdentifiers(conn) + agentIdentifiers, err := backend.GetAgentIdentifiers() if err != nil { klog.ErrorS(err, "fail to get the agent identifiers", "agentID", agentID) break } for _, ipv4 := range agentIdentifiers.IPv4 { klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv4) - s.BackendManagers[i].AddBackend(ipv4, header.IPv4, conn) + s.BackendManagers[i].AddBackend(ipv4, header.IPv4, backend) } for _, ipv6 := range agentIdentifiers.IPv6 { klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv6) - s.BackendManagers[i].AddBackend(ipv6, header.IPv6, conn) + s.BackendManagers[i].AddBackend(ipv6, header.IPv6, backend) } for _, host := range agentIdentifiers.Host { klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", host) - s.BackendManagers[i].AddBackend(host, header.Host, conn) + s.BackendManagers[i].AddBackend(host, header.Host, backend) } case *DefaultRouteBackendManager: - agentIdentifiers, err := getAgentIdentifiers(conn) + agentIdentifiers, err := backend.GetAgentIdentifiers() if err != nil { klog.ErrorS(err, "fail to get the agent identifiers", "agentID", agentID) break } if agentIdentifiers.DefaultRoute { klog.V(5).InfoS("Add the agent to DefaultRouteBackendManager", "agentID", agentID) - backend = s.BackendManagers[i].AddBackend(agentID, header.DefaultRoute, conn) + s.BackendManagers[i].AddBackend(agentID, header.DefaultRoute, backend) } default: klog.V(5).InfoS("Add the agent to DefaultBackendManager", "agentID", agentID) - backend = s.BackendManagers[i].AddBackend(agentID, header.UID, conn) + s.BackendManagers[i].AddBackend(agentID, header.UID, backend) } } return } -func (s *ProxyServer) removeBackend(agentID string, conn agent.AgentService_ConnectServer) { +func (s *ProxyServer) removeBackend(agentID string, backend Backend) { for _, bm := range s.BackendManagers { switch bm.(type) { case *DestHostBackendManager: - agentIdentifiers, err := getAgentIdentifiers(conn) + agentIdentifiers, err := backend.GetAgentIdentifiers() if err != nil { klog.ErrorS(err, "fail to get the agent identifiers", "agentID", agentID) break } for _, ipv4 := range agentIdentifiers.IPv4 { klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv4) - bm.RemoveBackend(ipv4, header.IPv4, conn) + bm.RemoveBackend(ipv4, header.IPv4, backend) } for _, ipv6 := range agentIdentifiers.IPv6 { klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv6) - bm.RemoveBackend(ipv6, header.IPv6, conn) + bm.RemoveBackend(ipv6, header.IPv6, backend) } for _, host := range agentIdentifiers.Host { klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", host) - bm.RemoveBackend(host, header.Host, conn) + bm.RemoveBackend(host, header.Host, backend) } case *DefaultRouteBackendManager: - agentIdentifiers, err := getAgentIdentifiers(conn) + agentIdentifiers, err := backend.GetAgentIdentifiers() if err != nil { klog.ErrorS(err, "fail to get the agent identifiers", "agentID", agentID) break } if agentIdentifiers.DefaultRoute { klog.V(5).InfoS("Remove the agent from the DefaultRouteBackendManager", "agentID", agentID) - bm.RemoveBackend(agentID, header.DefaultRoute, conn) + bm.RemoveBackend(agentID, header.DefaultRoute, backend) } default: klog.V(5).InfoS("Remove the agent from the DefaultBackendManager", "agentID", agentID) - bm.RemoveBackend(agentID, header.UID, conn) + bm.RemoveBackend(agentID, header.UID, backend) } } } @@ -706,27 +706,6 @@ func agentID(stream agent.AgentService_ConnectServer) (string, error) { return agentIDs[0], nil } -func getAgentIdentifiers(stream agent.AgentService_ConnectServer) (header.Identifiers, error) { - var agentIdentifiers header.Identifiers - md, ok := metadata.FromIncomingContext(stream.Context()) - if !ok { - return agentIdentifiers, fmt.Errorf("failed to get context") - } - agentIDs := md.Get(header.AgentIdentifiers) - if len(agentIDs) > 1 { - return agentIdentifiers, fmt.Errorf("expected at most one agent IP in the context, got %v", agentIDs) - } - if len(agentIDs) == 0 { - return agentIdentifiers, nil - } - - agentIdentifiers, err := header.GenAgentIdentifiers(agentIDs[0]) - if err != nil { - return agentIdentifiers, err - } - return agentIdentifiers, nil -} - func (s *ProxyServer) validateAuthToken(ctx context.Context, token string) (username string, err error) { trReq := &authv1.TokenReview{ Spec: authv1.TokenReviewSpec{ @@ -831,8 +810,9 @@ func (s *ProxyServer) Connect(stream agent.AgentService_ConnectServer) error { } klog.V(2).InfoS("Agent connected", "agentID", agentID, "serverID", s.serverID) - backend := s.addBackend(agentID, stream) - defer s.removeBackend(agentID, stream) + backend := NewBackend(stream) + s.addBackend(agentID, backend) + defer s.removeBackend(agentID, backend) recvCh := make(chan *client.Packet, xfrChannelSize) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index c687290a7..233b1ad70 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -359,7 +359,7 @@ func prepareFrontendConn(ctrl *gomock.Controller) *agentmock.MockAgentService_Co return frontendConn } -func prepareAgentConnMD(ctrl *gomock.Controller, proxyServer *ProxyServer) *agentmock.MockAgentService_ConnectServer { +func prepareAgentConnMD(ctrl *gomock.Controller, proxyServer *ProxyServer) (*agentmock.MockAgentService_ConnectServer, Backend) { // prepare the the connection to agent of proxy-server agentConn := agentmock.NewMockAgentService_ConnectServer(ctrl) agentID := uuid.New().String() @@ -372,8 +372,9 @@ func prepareAgentConnMD(ctrl *gomock.Controller, proxyServer *ProxyServer) *agen } agentConnCtx := metadata.NewIncomingContext(context.Background(), agentConnMD) agentConn.EXPECT().Context().Return(agentConnCtx).AnyTimes() - _ = proxyServer.addBackend(agentID, agentConn) - return agentConn + backend := NewBackend(agentConn) + proxyServer.addBackend(agentID, backend) + return agentConn, backend } func baseServerProxyTestWithoutBackend(t *testing.T, validate func(*agentmock.MockAgentService_ConnectServer)) { @@ -397,7 +398,7 @@ func baseServerProxyTestWithBackend(t *testing.T, validate func(*agentmock.MockA // prepare proxy server proxyServer := NewProxyServer(uuid.New().String(), []ProxyStrategy{ProxyStrategyDefault}, 1, &AgentTokenAuthenticationOptions{}) - agentConn := prepareAgentConnMD(ctrl, proxyServer) + agentConn, _ := prepareAgentConnMD(ctrl, proxyServer) validate(frontendConn, agentConn) @@ -604,14 +605,14 @@ func TestReadyBackendsMetric(t *testing.T) { p := NewProxyServer(uuid.New().String(), []ProxyStrategy{ProxyStrategyDefault}, 1, &AgentTokenAuthenticationOptions{}) assertReadyBackendsMetric(t, 0) - agentConn := prepareAgentConnMD(ctrl, p) + agentConn, backend := prepareAgentConnMD(ctrl, p) assertReadyBackendsMetric(t, 1) agentID, err := agentID(agentConn) if err != nil { t.Fatalf("Could not get agentID: %v", err) } - p.removeBackend(agentID, agentConn) + p.removeBackend(agentID, backend) assertReadyBackendsMetric(t, 0) } diff --git a/proto/header/header.go b/proto/header/header.go index 5594add24..5c3afb91f 100644 --- a/proto/header/header.go +++ b/proto/header/header.go @@ -61,7 +61,7 @@ const ( ) // GenAgentIdentifiers generates an Identifiers based on the input string, the -// input string should be a comma-seprated list with each item in the format +// input string should be a comma-separated list with each item in the format // of =
func GenAgentIdentifiers(addrs string) (Identifiers, error) { var agentIDs Identifiers diff --git a/tests/concurrent_client_request_test.go b/tests/concurrent_client_request_test.go index 972d42f4c..9a03e8eb9 100644 --- a/tests/concurrent_client_request_test.go +++ b/tests/concurrent_client_request_test.go @@ -30,7 +30,6 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/client" "sigs.k8s.io/apiserver-network-proxy/pkg/server" - "sigs.k8s.io/apiserver-network-proxy/proto/agent" "sigs.k8s.io/apiserver-network-proxy/proto/header" ) @@ -74,26 +73,25 @@ func getTestClient(front string, t *testing.T) *http.Client { // singleTimeManager makes sure that a backend only serves one request. type singleTimeManager struct { mu sync.Mutex - backends map[string]agent.AgentService_ConnectServer + backends map[string]server.Backend used map[string]struct{} } -func (s *singleTimeManager) AddBackend(agentID string, _ header.IdentifierType, conn agent.AgentService_ConnectServer) server.Backend { +func (s *singleTimeManager) AddBackend(agentID string, _ header.IdentifierType, backend server.Backend) { s.mu.Lock() defer s.mu.Unlock() - s.backends[agentID] = conn - return conn + s.backends[agentID] = backend } -func (s *singleTimeManager) RemoveBackend(agentID string, _ header.IdentifierType, conn agent.AgentService_ConnectServer) { +func (s *singleTimeManager) RemoveBackend(agentID string, _ header.IdentifierType, backend server.Backend) { s.mu.Lock() defer s.mu.Unlock() v, ok := s.backends[agentID] if !ok { panic(fmt.Errorf("no backends found for %s", agentID)) } - if v != conn { - panic(fmt.Errorf("recorded connection %v does not match conn %v", v, conn)) + if v != backend { + panic(fmt.Errorf("recorded backend %v does not match %v", &v, backend)) } delete(s.backends, agentID) } @@ -121,7 +119,7 @@ func (s *singleTimeManager) NumBackends() int { func newSingleTimeGetter(m *server.DefaultBackendManager) *singleTimeManager { return &singleTimeManager{ used: make(map[string]struct{}), - backends: make(map[string]agent.AgentService_ConnectServer), + backends: make(map[string]server.Backend), } } @@ -145,8 +143,8 @@ func TestConcurrentClientRequest(t *testing.T) { stopCh := make(chan struct{}) defer close(stopCh) // Run two agents - cs1 := runAgent(proxy.agent, stopCh) - cs2 := runAgent(proxy.agent, stopCh) + cs1 := runAgentWithID("a", proxy.agent, stopCh) + cs2 := runAgentWithID("b", proxy.agent, stopCh) waitForConnectedServerCount(t, 1, cs1) waitForConnectedServerCount(t, 1, cs2)