From 70fd0e28b2889157c529d1e40c69b82b6d7837bc Mon Sep 17 00:00:00 2001 From: Ertan Deniz Date: Thu, 9 Nov 2023 09:50:31 +0100 Subject: [PATCH] feat: implement kip-368 for sasl reauthentication --- conn.go | 57 +++++++++++++++++++++++++++++++++------- dialer.go | 35 +++++++++++++++--------- protocol/conn.go | 19 ++++++++++---- saslauthenticate.go | 38 +++++++++++++++++++++++++++ saslauthenticate_test.go | 28 ++++++++++++++++++++ transport.go | 34 +++++++++++++++++++++--- 6 files changed, 181 insertions(+), 30 deletions(-) diff --git a/conn.go b/conn.go index 2b51afbd5..1cb66ccb7 100644 --- a/conn.go +++ b/conn.go @@ -42,6 +42,10 @@ type Conn struct { wbuf bufio.Writer wb writeBuffer + // sasl session + saslSessionDeadline time.Time + saslAuth func() error + // deadline management wdeadline connDeadline rdeadline connDeadline @@ -1363,6 +1367,11 @@ func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func } func (c *Conn) doRequest(d *connDeadline, write func(time.Time, int32) error) (id int32, err error) { + // KIP-368 + if !c.saslSessionDeadline.IsZero() && time.Now().After(c.saslSessionDeadline) { + c.saslAuth() + } + c.enter() c.wlock.Lock() c.correlationID++ @@ -1601,28 +1610,58 @@ func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) { // if we sent a v1 handshake, then we must encapsulate the authentication // request in a saslAuthenticateRequest. otherwise, we read and write raw // bytes. - version, err := c.negotiateVersion(saslHandshake, v0, v1) + handshakeVersion, err := c.negotiateVersion(saslHandshake, v0, v1) if err != nil { return nil, err } - if version == v1 { + if handshakeVersion == v1 { + authVersion, err := c.negotiateVersion(saslAuthenticate, v0, v1) + if err != nil { + return nil, err + } var request = saslAuthenticateRequestV0{Data: data} - var response saslAuthenticateResponseV0 + var errorCode int16 + var authData []byte - err := c.writeOperation( + err = c.writeOperation( func(deadline time.Time, id int32) error { - return c.writeRequest(saslAuthenticate, v0, id, request) + return c.writeRequest(saslAuthenticate, authVersion, id, request) }, func(deadline time.Time, size int) error { return expectZeroSize(func() (remain int, err error) { - return (&response).readFrom(&c.rbuf, size) + switch authVersion { + case v0: + var response saslAuthenticateResponseV0 + remain, err = (&response).readFrom(&c.rbuf, size) + if err != nil { + return remain, err + } + + errorCode = response.ErrorCode + authData = response.Data + case v1: + var response saslAuthenticateResponseV1 + remain, err = (&response).readFrom(&c.rbuf, size) + if err != nil { + return remain, err + } + + errorCode = response.ErrorCode + authData = response.Data + if response.SessionLifetimeMs > 0 { + // set sasl session deadline to %90 of session lifetime + c.saslSessionDeadline = time.Now().Add(time.Duration(float64(response.SessionLifetimeMs)*0.9) * time.Millisecond) + } + } + + return remain, err }()) }, ) - if err == nil && response.ErrorCode != 0 { - err = Error(response.ErrorCode) + if err == nil && errorCode != 0 { + err = Error(errorCode) } - return response.Data, err + return authData, err } // fall back to opaque bytes on the wire. the broker is expecting these if diff --git a/dialer.go b/dialer.go index 7786ed320..fbe44ee18 100644 --- a/dialer.go +++ b/dialer.go @@ -282,19 +282,27 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C conn := NewConnWith(c, connCfg) - if d.SASLMechanism != nil { - host, port, err := splitHostPortNumber(address) - if err != nil { - return nil, fmt.Errorf("could not determine host/port for SASL authentication: %w", err) - } - metadata := &sasl.Metadata{ - Host: host, - Port: port, - } - if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil { - _ = conn.Close() - return nil, fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err) + conn.saslAuth = func() error { + if d.SASLMechanism != nil { + host, port, err := splitHostPortNumber(address) + if err != nil { + return fmt.Errorf("could not determine host/port for SASL authentication: %w", err) + } + metadata := &sasl.Metadata{ + Host: host, + Port: port, + } + if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil { + _ = conn.Close() + return fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err) + } } + return nil + } + + err = conn.saslAuth() + if err != nil { + return nil, err } return conn, nil @@ -307,6 +315,9 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C // In case of error, this function *does not* close the connection. That is the // responsibility of the caller. func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { + // reset the SaslSessionDeadline before authenticating + conn.saslSessionDeadline = time.Time{} + if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil { return fmt.Errorf("SASL handshake failed: %w", err) } diff --git a/protocol/conn.go b/protocol/conn.go index d08a577f6..3405d7ff1 100644 --- a/protocol/conn.go +++ b/protocol/conn.go @@ -9,11 +9,12 @@ import ( ) type Conn struct { - buffer *bufio.Reader - conn net.Conn - clientID string - idgen int32 - versions atomic.Value // map[ApiKey]int16 + buffer *bufio.Reader + conn net.Conn + clientID string + idgen int32 + saslSessionDeadline time.Time + versions atomic.Value // map[ApiKey]int16 } func NewConn(conn net.Conn, clientID string) *Conn { @@ -68,6 +69,14 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } +func (c *Conn) SetSaslSessionDeadline(t time.Time) { + c.saslSessionDeadline = t +} + +func (c *Conn) GetSaslSessionDeadline() time.Time { + return c.saslSessionDeadline +} + func (c *Conn) SetVersions(versions map[ApiKey]int16) { connVersions := make(map[ApiKey]int16, len(versions)) diff --git a/saslauthenticate.go b/saslauthenticate.go index ad1292918..865c8d070 100644 --- a/saslauthenticate.go +++ b/saslauthenticate.go @@ -52,3 +52,41 @@ func (t *saslAuthenticateResponseV0) readFrom(r *bufio.Reader, sz int) (remain i } return } + +type saslAuthenticateResponseV1 struct { + // ErrorCode holds response error code + ErrorCode int16 + + ErrorMessage string + + Data []byte + + SessionLifetimeMs int64 +} + +func (t saslAuthenticateResponseV1) size() int32 { + return sizeofInt16(t.ErrorCode) + sizeofString(t.ErrorMessage) + sizeofBytes(t.Data) + sizeofInt64(t.SessionLifetimeMs) +} + +func (t saslAuthenticateResponseV1) writeTo(wb *writeBuffer) { + wb.writeInt16(t.ErrorCode) + wb.writeString(t.ErrorMessage) + wb.writeBytes(t.Data) + wb.writeInt64(t.SessionLifetimeMs) +} + +func (t *saslAuthenticateResponseV1) readFrom(r *bufio.Reader, sz int) (remain int, err error) { + if remain, err = readInt16(r, sz, &t.ErrorCode); err != nil { + return + } + if remain, err = readString(r, remain, &t.ErrorMessage); err != nil { + return + } + if remain, err = readBytes(r, remain, &t.Data); err != nil { + return + } + if remain, err = readInt64(r, remain, &t.SessionLifetimeMs); err != nil { + return + } + return +} diff --git a/saslauthenticate_test.go b/saslauthenticate_test.go index 89a33e3da..e5318bfd0 100644 --- a/saslauthenticate_test.go +++ b/saslauthenticate_test.go @@ -58,3 +58,31 @@ func TestSASLAuthenticateResponseV0(t *testing.T) { t.FailNow() } } + +func TestSASLAuthenticateResponseV1(t *testing.T) { + item := saslAuthenticateResponseV1{ + ErrorCode: 2, + ErrorMessage: "Message", + Data: []byte("bytes"), + SessionLifetimeMs: 300000, + } + + b := bytes.NewBuffer(nil) + w := &writeBuffer{w: b} + item.writeTo(w) + + var found saslAuthenticateResponseV1 + remain, err := (&found).readFrom(bufio.NewReader(b), b.Len()) + if err != nil { + t.Error(err) + t.FailNow() + } + if remain != 0 { + t.Errorf("expected 0 remain, got %v", remain) + t.FailNow() + } + if !reflect.DeepEqual(item, found) { + t.Error("expected item and found to be the same") + t.FailNow() + } +} diff --git a/transport.go b/transport.go index 685bdddb1..9bc9181f8 100644 --- a/transport.go +++ b/transport.go @@ -1274,6 +1274,22 @@ func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (R pprof.SetGoroutineLabels(ctx) defer pprof.SetGoroutineLabels(context.Background()) + // KIP-368 + var saslSessionDeadline = pc.GetSaslSessionDeadline() + if !saslSessionDeadline.IsZero() && time.Now().After(saslSessionDeadline) { + host, port, err := splitHostPortNumber(c.address) + if err != nil { + return nil, err + } + metadata := &sasl.Metadata{ + Host: host, + Port: port, + } + if err := authenticateSASL(sasl.WithMetadata(ctx, metadata), pc, c.group.pool.sasl); err != nil { + return nil, err + } + } + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { pc.SetDeadline(deadline) defer pc.SetDeadline(time.Time{}) @@ -1286,6 +1302,9 @@ func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (R // connection. If any step fails, this function returns with an error. A nil // error indicates successful authentication. func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mechanism) error { + // reset the SaslSessionDeadline before authenticating + pc.SetSaslSessionDeadline(time.Time{}) + if err := saslHandshakeRoundTrip(pc, mechanism.Name()); err != nil { return err } @@ -1296,7 +1315,7 @@ func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mec } for completed := false; !completed; { - challenge, err := saslAuthenticateRoundTrip(pc, state) + challenge, sessionLifetimeMs, err := saslAuthenticateRoundTrip(pc, state) if err != nil { if errors.Is(err, io.EOF) { // the broker may communicate a failed exchange by closing the @@ -1312,6 +1331,12 @@ func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mec if err != nil { return err } + + if sessionLifetimeMs > 0 { + // set sasl session deadline to %90 of session lifetime + var saslSessionDeadline = time.Now().Add(time.Duration(float64(sessionLifetimeMs)*0.9) * time.Millisecond) + pc.SetSaslSessionDeadline(saslSessionDeadline) + } } return nil @@ -1346,18 +1371,19 @@ func saslHandshakeRoundTrip(pc *protocol.Conn, mechanism string) error { // be immediately preceded by a successful saslHandshake. // // See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate -func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, error) { +func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, int64, error) { msg, err := pc.RoundTrip(&saslauthenticate.Request{ AuthBytes: data, }) if err != nil { - return nil, err + return nil, 0, err } res := msg.(*saslauthenticate.Response) if res.ErrorCode != 0 { err = makeError(res.ErrorCode, res.ErrorMessage) } - return res.AuthBytes, err + + return res.AuthBytes, res.SessionLifetimeMs, err } var _ RoundTripper = (*Transport)(nil)