diff --git a/auth_scram.go b/auth_scram.go new file mode 100644 index 000000000..8ac8a82b5 --- /dev/null +++ b/auth_scram.go @@ -0,0 +1,255 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 +package pgx + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgx/pgproto3" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *Conn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-first-message payload in a AuthenticationSASLContinue. + authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + if err != nil { + return err + } + err = sc.recvServerFirstMessage(authMsg.SASLData) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + _, err = c.conn.Write(saslResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-final-message payload in a AuthenticationSASLFinal. + authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + if err != nil { + return err + } + return sc.recvServerFinalMessage(authMsg.SASLData) +} + +func (c *Conn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { + msg, err := c.rxMsg() + if err != nil { + return nil, err + } + authMsg, ok := msg.(*pgproto3.Authentication) + if !ok { + return nil, errors.New("unexpected message type") + } + if authMsg.Type != typ { + return nil, errors.New("unexpected auth type") + } + + return authMsg, nil +} + +type scramClient struct { + serverAuthMechanisms []string + password []byte + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = []byte(password) + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %v", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey[:], authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/conn.go b/conn.go index d500b1327..cb24748c7 100644 --- a/conn.go +++ b/conn.go @@ -1425,6 +1425,8 @@ func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { case pgproto3.AuthTypeMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) err = c.txPasswordMessage(digestedPassword) + case pgproto3.AuthTypeSASL: + err = c.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } diff --git a/conn_test.go b/conn_test.go index c745d392d..6ca00c6d5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -212,6 +212,56 @@ func TestConnectWithMD5Password(t *testing.T) { } } +func TestConnectWithSCRAMPassword(t *testing.T) { + t.Parallel() + + connString := os.Getenv("PGX_TEST_SCRAM_PASSWORD_CONN_STRING") + if connString == "" { + t.Skip("Skipping due to missing PGX_TEST_SCRAM_PASSWORD_CONN_STRING env var") + } + + connConfig, err := pgx.ParseConnectionString(connString) + if err != nil { + t.Fatalf("Unable to parse config: %v", err) + } + + conn, err := pgx.Connect(connConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithTLSFallback(t *testing.T) { t.Parallel() diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go index 77750b862..5f698d0c9 100644 --- a/pgproto3/authentication.go +++ b/pgproto3/authentication.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "github.com/jackc/pgx/pgio" @@ -11,6 +12,9 @@ const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 ) type Authentication struct { @@ -18,6 +22,12 @@ type Authentication struct { // MD5Password fields Salt [4]byte + + // SASL fields + SASLAuthMechanisms []string + + // SASLContinue and SASLFinal data + SASLData []byte } func (*Authentication) Backend() {} @@ -30,6 +40,17 @@ func (dst *Authentication) Decode(src []byte) error { case AuthTypeCleartextPassword: case AuthTypeMD5Password: copy(dst.Salt[:], src[4:8]) + case AuthTypeSASL: + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx > 0 { + dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + } + case AuthTypeSASLContinue, AuthTypeSASLFinal: + dst.SASLData = src[4:] default: return errors.Errorf("unknown authentication type: %d", dst.Type) } @@ -46,6 +67,15 @@ func (src *Authentication) Encode(dst []byte) []byte { switch src.Type { case AuthTypeMD5Password: dst = append(dst, src.Salt[:]...) + case AuthTypeSASL: + for _, s := range src.SASLAuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + case AuthTypeSASLContinue: + dst = pgio.AppendInt32(dst, int32(len(src.SASLData))) + dst = append(dst, src.SASLData...) } pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go new file mode 100644 index 000000000..f58a6b566 --- /dev/null +++ b/pgproto3/sasl_initial_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/jackc/pgx/pgio" +) + +type SASLInitialResponse struct { + AuthMechanism string + Data []byte +} + +func (*SASLInitialResponse) Frontend() {} + +func (dst *SASLInitialResponse) Decode(src []byte) error { + *dst = SASLInitialResponse{} + + rp := 0 + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return errors.New("invalid SASLInitialResponse") + } + + dst.AuthMechanism = string(src[rp:idx]) + rp = idx + 1 + + rp += 4 // The rest of the message is data so we can just skip the size + dst.Data = src[rp:] + + return nil +} + +func (src *SASLInitialResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, []byte(src.AuthMechanism)...) + dst = append(dst, 0) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *SASLInitialResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanism string + Data string + }{ + Type: "SASLInitialResponse", + AuthMechanism: src.AuthMechanism, + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go new file mode 100644 index 000000000..ed96686b4 --- /dev/null +++ b/pgproto3/sasl_response.go @@ -0,0 +1,38 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgx/pgio" +) + +type SASLResponse struct { + Data []byte +} + +func (*SASLResponse) Frontend() {} + +func (dst *SASLResponse) Decode(src []byte) error { + *dst = SASLResponse{Data: src} + return nil +} + +func (src *SASLResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + + dst = append(dst, src.Data...) + + return dst +} + +func (src *SASLResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "SASLResponse", + Data: hex.EncodeToString(src.Data), + }) +}