Skip to content

Commit

Permalink
Merge pull request segmentio#1 from ertanden/feat/sasl-reauthentication
Browse files Browse the repository at this point in the history
feat: implement kip-368 for sasl reauthentication
  • Loading branch information
ertanden authored Nov 10, 2023
2 parents c6378c3 + 70fd0e2 commit daa69a4
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 30 deletions.
57 changes: 48 additions & 9 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ type Conn struct {
wbuf bufio.Writer
wb writeBuffer

// sasl session
saslSessionDeadline time.Time
saslAuth func() error

// deadline management
wdeadline connDeadline
rdeadline connDeadline
Expand Down Expand Up @@ -1363,6 +1367,11 @@ func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func
}

func (c *Conn) doRequest(d *connDeadline, write func(time.Time, int32) error) (id int32, err error) {
// KIP-368
if !c.saslSessionDeadline.IsZero() && time.Now().After(c.saslSessionDeadline) {
c.saslAuth()
}

c.enter()
c.wlock.Lock()
c.correlationID++
Expand Down Expand Up @@ -1601,28 +1610,58 @@ func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) {
// if we sent a v1 handshake, then we must encapsulate the authentication
// request in a saslAuthenticateRequest. otherwise, we read and write raw
// bytes.
version, err := c.negotiateVersion(saslHandshake, v0, v1)
handshakeVersion, err := c.negotiateVersion(saslHandshake, v0, v1)
if err != nil {
return nil, err
}
if version == v1 {
if handshakeVersion == v1 {
authVersion, err := c.negotiateVersion(saslAuthenticate, v0, v1)
if err != nil {
return nil, err
}
var request = saslAuthenticateRequestV0{Data: data}
var response saslAuthenticateResponseV0
var errorCode int16
var authData []byte

err := c.writeOperation(
err = c.writeOperation(
func(deadline time.Time, id int32) error {
return c.writeRequest(saslAuthenticate, v0, id, request)
return c.writeRequest(saslAuthenticate, authVersion, id, request)
},
func(deadline time.Time, size int) error {
return expectZeroSize(func() (remain int, err error) {
return (&response).readFrom(&c.rbuf, size)
switch authVersion {
case v0:
var response saslAuthenticateResponseV0
remain, err = (&response).readFrom(&c.rbuf, size)
if err != nil {
return remain, err
}

errorCode = response.ErrorCode
authData = response.Data
case v1:
var response saslAuthenticateResponseV1
remain, err = (&response).readFrom(&c.rbuf, size)
if err != nil {
return remain, err
}

errorCode = response.ErrorCode
authData = response.Data
if response.SessionLifetimeMs > 0 {
// set sasl session deadline to %90 of session lifetime
c.saslSessionDeadline = time.Now().Add(time.Duration(float64(response.SessionLifetimeMs)*0.9) * time.Millisecond)
}
}

return remain, err
}())
},
)
if err == nil && response.ErrorCode != 0 {
err = Error(response.ErrorCode)
if err == nil && errorCode != 0 {
err = Error(errorCode)
}
return response.Data, err
return authData, err
}

// fall back to opaque bytes on the wire. the broker is expecting these if
Expand Down
35 changes: 23 additions & 12 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,19 +282,27 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C

conn := NewConnWith(c, connCfg)

if d.SASLMechanism != nil {
host, port, err := splitHostPortNumber(address)
if err != nil {
return nil, fmt.Errorf("could not determine host/port for SASL authentication: %w", err)
}
metadata := &sasl.Metadata{
Host: host,
Port: port,
}
if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err)
conn.saslAuth = func() error {
if d.SASLMechanism != nil {
host, port, err := splitHostPortNumber(address)
if err != nil {
return fmt.Errorf("could not determine host/port for SASL authentication: %w", err)
}
metadata := &sasl.Metadata{
Host: host,
Port: port,
}
if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil {
_ = conn.Close()
return fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err)
}
}
return nil
}

err = conn.saslAuth()
if err != nil {
return nil, err
}

return conn, nil
Expand All @@ -307,6 +315,9 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C
// In case of error, this function *does not* close the connection. That is the
// responsibility of the caller.
func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
// reset the SaslSessionDeadline before authenticating
conn.saslSessionDeadline = time.Time{}

if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
return fmt.Errorf("SASL handshake failed: %w", err)
}
Expand Down
19 changes: 14 additions & 5 deletions protocol/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import (
)

type Conn struct {
buffer *bufio.Reader
conn net.Conn
clientID string
idgen int32
versions atomic.Value // map[ApiKey]int16
buffer *bufio.Reader
conn net.Conn
clientID string
idgen int32
saslSessionDeadline time.Time
versions atomic.Value // map[ApiKey]int16
}

func NewConn(conn net.Conn, clientID string) *Conn {
Expand Down Expand Up @@ -68,6 +69,14 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

func (c *Conn) SetSaslSessionDeadline(t time.Time) {
c.saslSessionDeadline = t
}

func (c *Conn) GetSaslSessionDeadline() time.Time {
return c.saslSessionDeadline
}

func (c *Conn) SetVersions(versions map[ApiKey]int16) {
connVersions := make(map[ApiKey]int16, len(versions))

Expand Down
38 changes: 38 additions & 0 deletions saslauthenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,41 @@ func (t *saslAuthenticateResponseV0) readFrom(r *bufio.Reader, sz int) (remain i
}
return
}

type saslAuthenticateResponseV1 struct {
// ErrorCode holds response error code
ErrorCode int16

ErrorMessage string

Data []byte

SessionLifetimeMs int64
}

func (t saslAuthenticateResponseV1) size() int32 {
return sizeofInt16(t.ErrorCode) + sizeofString(t.ErrorMessage) + sizeofBytes(t.Data) + sizeofInt64(t.SessionLifetimeMs)
}

func (t saslAuthenticateResponseV1) writeTo(wb *writeBuffer) {
wb.writeInt16(t.ErrorCode)
wb.writeString(t.ErrorMessage)
wb.writeBytes(t.Data)
wb.writeInt64(t.SessionLifetimeMs)
}

func (t *saslAuthenticateResponseV1) readFrom(r *bufio.Reader, sz int) (remain int, err error) {
if remain, err = readInt16(r, sz, &t.ErrorCode); err != nil {
return
}
if remain, err = readString(r, remain, &t.ErrorMessage); err != nil {
return
}
if remain, err = readBytes(r, remain, &t.Data); err != nil {
return
}
if remain, err = readInt64(r, remain, &t.SessionLifetimeMs); err != nil {
return
}
return
}
28 changes: 28 additions & 0 deletions saslauthenticate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,31 @@ func TestSASLAuthenticateResponseV0(t *testing.T) {
t.FailNow()
}
}

func TestSASLAuthenticateResponseV1(t *testing.T) {
item := saslAuthenticateResponseV1{
ErrorCode: 2,
ErrorMessage: "Message",
Data: []byte("bytes"),
SessionLifetimeMs: 300000,
}

b := bytes.NewBuffer(nil)
w := &writeBuffer{w: b}
item.writeTo(w)

var found saslAuthenticateResponseV1
remain, err := (&found).readFrom(bufio.NewReader(b), b.Len())
if err != nil {
t.Error(err)
t.FailNow()
}
if remain != 0 {
t.Errorf("expected 0 remain, got %v", remain)
t.FailNow()
}
if !reflect.DeepEqual(item, found) {
t.Error("expected item and found to be the same")
t.FailNow()
}
}
34 changes: 30 additions & 4 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,22 @@ func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (R
pprof.SetGoroutineLabels(ctx)
defer pprof.SetGoroutineLabels(context.Background())

// KIP-368
var saslSessionDeadline = pc.GetSaslSessionDeadline()
if !saslSessionDeadline.IsZero() && time.Now().After(saslSessionDeadline) {
host, port, err := splitHostPortNumber(c.address)
if err != nil {
return nil, err
}
metadata := &sasl.Metadata{
Host: host,
Port: port,
}
if err := authenticateSASL(sasl.WithMetadata(ctx, metadata), pc, c.group.pool.sasl); err != nil {
return nil, err
}
}

if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
pc.SetDeadline(deadline)
defer pc.SetDeadline(time.Time{})
Expand All @@ -1286,6 +1302,9 @@ func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (R
// connection. If any step fails, this function returns with an error. A nil
// error indicates successful authentication.
func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mechanism) error {
// reset the SaslSessionDeadline before authenticating
pc.SetSaslSessionDeadline(time.Time{})

if err := saslHandshakeRoundTrip(pc, mechanism.Name()); err != nil {
return err
}
Expand All @@ -1296,7 +1315,7 @@ func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mec
}

for completed := false; !completed; {
challenge, err := saslAuthenticateRoundTrip(pc, state)
challenge, sessionLifetimeMs, err := saslAuthenticateRoundTrip(pc, state)
if err != nil {
if errors.Is(err, io.EOF) {
// the broker may communicate a failed exchange by closing the
Expand All @@ -1312,6 +1331,12 @@ func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mec
if err != nil {
return err
}

if sessionLifetimeMs > 0 {
// set sasl session deadline to %90 of session lifetime
var saslSessionDeadline = time.Now().Add(time.Duration(float64(sessionLifetimeMs)*0.9) * time.Millisecond)
pc.SetSaslSessionDeadline(saslSessionDeadline)
}
}

return nil
Expand Down Expand Up @@ -1346,18 +1371,19 @@ func saslHandshakeRoundTrip(pc *protocol.Conn, mechanism string) error {
// be immediately preceded by a successful saslHandshake.
//
// See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate
func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, error) {
func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, int64, error) {
msg, err := pc.RoundTrip(&saslauthenticate.Request{
AuthBytes: data,
})
if err != nil {
return nil, err
return nil, 0, err
}
res := msg.(*saslauthenticate.Response)
if res.ErrorCode != 0 {
err = makeError(res.ErrorCode, res.ErrorMessage)
}
return res.AuthBytes, err

return res.AuthBytes, res.SessionLifetimeMs, err
}

var _ RoundTripper = (*Transport)(nil)

0 comments on commit daa69a4

Please sign in to comment.