Skip to content

Commit

Permalink
Add SCRAM authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed Apr 17, 2019
1 parent 5c96798 commit 5044e84
Show file tree
Hide file tree
Showing 6 changed files with 439 additions and 0 deletions.
255 changes: 255 additions & 0 deletions auth_scram.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
// SCRAM-SHA-256 authentication
//
// Resources:
// https://tools.ietf.org/html/rfc5802
// https://tools.ietf.org/html/rfc8265
// https://www.postgresql.org/docs/current/sasl-authentication.html
//
// Inspiration drawn from other implementations:
// https://github.com/lib/pq/pull/608
// https://github.com/lib/pq/pull/788
// https://github.com/lib/pq/pull/833
package pgx

import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strconv"

"github.com/jackc/pgx/pgproto3"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/text/secure/precis"
)

const clientNonceLen = 18

// Perform SCRAM authentication.
func (c *Conn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil {
return err
}

// Send client-first-message in a SASLInitialResponse
saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
if err != nil {
return err
}

// Receive server-first-message payload in a AuthenticationSASLContinue.
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
if err != nil {
return err
}
err = sc.recvServerFirstMessage(authMsg.SASLData)
if err != nil {
return err
}

// Send client-final-message in a SASLResponse
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
_, err = c.conn.Write(saslResponse.Encode(nil))
if err != nil {
return err
}

// Receive server-final-message payload in a AuthenticationSASLFinal.
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
if err != nil {
return err
}
return sc.recvServerFinalMessage(authMsg.SASLData)
}

func (c *Conn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
msg, err := c.rxMsg()
if err != nil {
return nil, err
}
authMsg, ok := msg.(*pgproto3.Authentication)
if !ok {
return nil, errors.New("unexpected message type")
}
if authMsg.Type != typ {
return nil, errors.New("unexpected auth type")
}

return authMsg, nil
}

type scramClient struct {
serverAuthMechanisms []string
password []byte
clientNonce []byte

clientFirstMessageBare []byte

serverFirstMessage []byte
clientAndServerNonce []byte
salt []byte
iterations int

saltedPassword []byte
authMessage []byte
}

func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
sc := &scramClient{
serverAuthMechanisms: serverAuthMechanisms,
}

// Ensure server supports SCRAM-SHA-256
hasScramSHA256 := false
for _, mech := range sc.serverAuthMechanisms {
if mech == "SCRAM-SHA-256" {
hasScramSHA256 = true
break
}
}
if !hasScramSHA256 {
return nil, errors.New("server does not support SCRAM-SHA-256")
}

// precis.OpaqueString is equivalent to SASLprep for password.
var err error
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
if err != nil {
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
sc.password = []byte(password)
}

buf := make([]byte, clientNonceLen)
_, err = rand.Read(buf)
if err != nil {
return nil, err
}
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
base64.RawStdEncoding.Encode(sc.clientNonce, buf)

return sc, nil
}

func (sc *scramClient) clientFirstMessage() []byte {
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
}

func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
sc.serverFirstMessage = serverFirstMessage
buf := serverFirstMessage
if !bytes.HasPrefix(buf, []byte("r=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
}
buf = buf[2:]

idx := bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
sc.clientAndServerNonce = buf[:idx]
buf = buf[idx+1:]

if !bytes.HasPrefix(buf, []byte("s=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
buf = buf[2:]

idx = bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
saltStr := buf[:idx]
buf = buf[idx+1:]

if !bytes.HasPrefix(buf, []byte("i=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
buf = buf[2:]
iterationsStr := buf

var err error
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
if err != nil {
return fmt.Errorf("invalid SCRAM salt received from server: %v", err)
}

sc.iterations, err = strconv.Atoi(string(iterationsStr))
if err != nil || sc.iterations <= 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr)
}

if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not start with client nonce")
}

if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not include server nonce")
}

return nil
}

func (sc *scramClient) clientFinalMessage() string {
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))

sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))

clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)

return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
}

func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
return errors.New("invalid SCRAM server-final-message received from server")
}

serverSignature := serverFinalMessage[2:]

if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
return errors.New("invalid SCRAM ServerSignature received from server")
}

return nil
}

func computeHMAC(key, msg []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(msg)
return mac.Sum(nil)
}

func computeClientProof(saltedPassword, authMessage []byte) []byte {
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
storedKey := sha256.Sum256(clientKey)
clientSignature := computeHMAC(storedKey[:], authMessage)

clientProof := make([]byte, len(clientSignature))
for i := 0; i < len(clientSignature); i++ {
clientProof[i] = clientKey[i] ^ clientSignature[i]
}

buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
base64.StdEncoding.Encode(buf, clientProof)
return buf
}

func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
serverSignature := computeHMAC(serverKey[:], authMessage)
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
base64.StdEncoding.Encode(buf, serverSignature)
return buf
}
2 changes: 2 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,8 @@ func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
case pgproto3.AuthTypeMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:]))
err = c.txPasswordMessage(digestedPassword)
case pgproto3.AuthTypeSASL:
err = c.scramAuth(msg.SASLAuthMechanisms)
default:
err = errors.New("Received unknown authentication message")
}
Expand Down
50 changes: 50 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,56 @@ func TestConnectWithMD5Password(t *testing.T) {
}
}

func TestConnectWithSCRAMPassword(t *testing.T) {
t.Parallel()

connString := os.Getenv("PGX_TEST_SCRAM_PASSWORD_CONN_STRING")
if connString == "" {
t.Skip("Skipping due to missing PGX_TEST_SCRAM_PASSWORD_CONN_STRING env var")
}

connConfig, err := pgx.ParseConnectionString(connString)
if err != nil {
t.Fatalf("Unable to parse config: %v", err)
}

conn, err := pgx.Connect(connConfig)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}

if _, present := conn.RuntimeParams["server_version"]; !present {
t.Error("Runtime parameters not stored")
}

if conn.PID() == 0 {
t.Error("Backend PID not stored")
}

var currentDB string
err = conn.QueryRow("select current_database()").Scan(&currentDB)
if err != nil {
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
}
if currentDB != defaultConnConfig.Database {
t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database)
}

var user string
err = conn.QueryRow("select current_user").Scan(&user)
if err != nil {
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
}
if user != defaultConnConfig.User {
t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User)
}

err = conn.Close()
if err != nil {
t.Fatal("Unable to close connection")
}
}

func TestConnectWithTLSFallback(t *testing.T) {
t.Parallel()

Expand Down
30 changes: 30 additions & 0 deletions pgproto3/authentication.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pgproto3

import (
"bytes"
"encoding/binary"

"github.com/jackc/pgx/pgio"
Expand All @@ -11,13 +12,22 @@ const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)

type Authentication struct {
Type uint32

// MD5Password fields
Salt [4]byte

// SASL fields
SASLAuthMechanisms []string

// SASLContinue and SASLFinal data
SASLData []byte
}

func (*Authentication) Backend() {}
Expand All @@ -30,6 +40,17 @@ func (dst *Authentication) Decode(src []byte) error {
case AuthTypeCleartextPassword:
case AuthTypeMD5Password:
copy(dst.Salt[:], src[4:8])
case AuthTypeSASL:
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
case AuthTypeSASLContinue, AuthTypeSASLFinal:
dst.SASLData = src[4:]
default:
return errors.Errorf("unknown authentication type: %d", dst.Type)
}
Expand All @@ -46,6 +67,15 @@ func (src *Authentication) Encode(dst []byte) []byte {
switch src.Type {
case AuthTypeMD5Password:
dst = append(dst, src.Salt[:]...)
case AuthTypeSASL:
for _, s := range src.SASLAuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
case AuthTypeSASLContinue:
dst = pgio.AppendInt32(dst, int32(len(src.SASLData)))
dst = append(dst, src.SASLData...)
}

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
Expand Down
Loading

0 comments on commit 5044e84

Please sign in to comment.