diff --git a/helper/pool/pool.go b/helper/pool/pool.go index 7b5d3b424878..7587cbc17a58 100644 --- a/helper/pool/pool.go +++ b/helper/pool/pool.go @@ -67,7 +67,7 @@ func (c *Conn) Close() error { } // getClient is used to get a cached or new client -func (c *Conn) getClient() (*StreamClient, error) { +func (c *Conn) getRPCClient() (*StreamClient, error) { // Check for cached client c.clientLock.Lock() front := c.clients.Front() @@ -85,6 +85,11 @@ func (c *Conn) getClient() (*StreamClient, error) { return nil, err } + if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil { + stream.Close() + return nil, err + } + // Create a client codec codec := NewClientCodec(stream) @@ -332,7 +337,7 @@ func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, } // Write the multiplex byte to set the mode - if _, err := conn.Write([]byte{byte(RpcMultiplex)}); err != nil { + if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil { conn.Close() return nil, err } @@ -390,7 +395,7 @@ func (p *ConnPool) releaseConn(conn *Conn) { } // getClient is used to get a usable client for an address and protocol version -func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) { +func (p *ConnPool) getRPCClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) { retries := 0 START: // Try to get a conn first @@ -400,7 +405,7 @@ START: } // Get a client - client, err := conn.getClient() + client, err := conn.getRPCClient() if err != nil { p.clearConn(conn) p.releaseConn(conn) @@ -415,10 +420,31 @@ START: return conn, client, nil } +// StreamingRPC is used to make an streaming RPC call. Callers must +// close the channel when done. +func (p *ConnPool) StreamingRPC(region string, addr net.Addr, version int) (net.Conn, error) { + conn, err := p.acquire(region, addr, version) + if err != nil { + return nil, fmt.Errorf("failed to get conn: %v", err) + } + + s, err := conn.session.Open() + if err != nil { + return nil, fmt.Errorf("failed to open a streaming connection: %v", err) + } + + if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil { + conn.Close() + return nil, err + } + + return s, nil +} + // RPC is used to make an RPC call to a remote host func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error { // Get a usable client - conn, sc, err := p.getClient(region, addr, version) + conn, sc, err := p.getRPCClient(region, addr, version) if err != nil { return fmt.Errorf("rpc error: %v", err) } diff --git a/nomad/rpc.go b/nomad/rpc.go index f3986bc257cf..29461eb40335 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -653,52 +653,19 @@ func (r *rpcHandler) getServer(region, serverID string) (*serverParts, error) { // initial handshake, returning the connection or an error. It is the callers // responsibility to close the connection if there is no returned error. func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn, error) { - // Try to dial the server - conn, err := net.DialTimeout("tcp", server.Addr.String(), 10*time.Second) + c, err := r.connPool.StreamingRPC(r.config.Region, server.Addr, server.MajorVersion) if err != nil { return nil, err } - // Cast to TCPConn - if tcp, ok := conn.(*net.TCPConn); ok { - tcp.SetKeepAlive(true) - tcp.SetNoDelay(true) - } - - return r.streamingRpcImpl(conn, server.Region, method) + return r.streamingRpcImpl(c, method) } // streamingRpcImpl takes a pre-established connection to a server and conducts // the handshake to establish a streaming RPC for the given method. If an error // is returned, the underlying connection has been closed. Otherwise it is // assumed that the connection has been hijacked by the RPC method. -func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) (net.Conn, error) { - // Check if TLS is enabled - r.tlsWrapLock.RLock() - tlsWrap := r.tlsWrap - r.tlsWrapLock.RUnlock() - - if tlsWrap != nil { - // Switch the connection into TLS mode - if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil { - conn.Close() - return nil, err - } - - // Wrap the connection in a TLS client - tlsConn, err := tlsWrap(region, conn) - if err != nil { - conn.Close() - return nil, err - } - conn = tlsConn - } - - // Write the multiplex byte to set the mode - if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { - conn.Close() - return nil, err - } +func (r *rpcHandler) streamingRpcImpl(conn net.Conn, method string) (net.Conn, error) { // Send the header encoder := codec.NewEncoder(conn, structs.MsgpackHandle) diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 647ef53c7104..8ac614d07011 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -530,7 +530,10 @@ func TestRPC_handleMultiplexV2(t *testing.T) { require.NotEmpty(l) // Make a streaming RPC - _, err = s.streamingRpcImpl(s2, s.Region(), "Bogus") + _, err = s2.Write([]byte{byte(pool.RpcStreaming)}) + require.Nil(err) + + _, err = s.streamingRpcImpl(s2, "Bogus") require.NotNil(err) require.Contains(err.Error(), "Bogus") require.True(structs.IsErrUnknownMethod(err))