Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement EncryptedData struct #3302

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (s *Server) GetAuthorizationURL(ctx context.Context,
if err != nil {
return nil, status.Errorf(codes.Internal, "error encrypting redirect URL: %s", err)
}
redirectUrl = sql.NullString{Valid: true, String: encryptedRedirectUrl}
redirectUrl = sql.NullString{Valid: true, String: encryptedRedirectUrl.EncodedData}
}

// Insert the new session state into the database along with the user's project ID
Expand Down Expand Up @@ -246,7 +246,10 @@ func (s *Server) processOAuthCallback(ctx context.Context, w http.ResponseWriter
logger.BusinessRecord(ctx).ProviderID = p.ID

if stateData.RedirectUrl.Valid {
redirectUrl, err := s.cryptoEngine.DecryptString(stateData.RedirectUrl.String)
// TODO: get rid of this once we store the EncryptedData struct in
// the database.
encryptedData := mcrypto.NewBackwardsCompatibleEncryptedData(stateData.RedirectUrl.String)
redirectUrl, err := s.cryptoEngine.DecryptString(encryptedData)
if err != nil {
return fmt.Errorf("error decrypting redirect URL: %w", err)
}
Expand Down Expand Up @@ -325,7 +328,10 @@ func (s *Server) processAppCallback(ctx context.Context, w http.ResponseWriter,

// If we have a redirect URL, redirect the user, otherwise show a success page
if stateData.RedirectUrl.Valid {
redirectUrl, err := s.cryptoEngine.DecryptString(stateData.RedirectUrl.String)
// TODO: get rid of this once we store the EncryptedData struct in
// the database.
encryptedData := mcrypto.NewBackwardsCompatibleEncryptedData(stateData.RedirectUrl.String)
redirectUrl, err := s.cryptoEngine.DecryptString(encryptedData)
if err != nil {
return fmt.Errorf("error decrypting redirect URL: %w", err)
}
Expand Down Expand Up @@ -473,7 +479,7 @@ func (s *Server) StoreProviderToken(ctx context.Context,
_, err = s.store.UpsertAccessToken(ctx, db.UpsertAccessTokenParams{
ProjectID: projectID,
Provider: provider.Name,
EncryptedToken: encryptedToken,
EncryptedToken: encryptedToken.EncodedData,
OwnerFilter: owner,
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/controlplane/handlers_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func TestProviderCallback(t *testing.T) {
}
encryptedUrl := sql.NullString{
Valid: true,
String: encryptedUrlString,
String: encryptedUrlString.EncodedData,
}

tx := sql.Tx{}
Expand Down
41 changes: 13 additions & 28 deletions internal/crypto/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"

"golang.org/x/crypto/argon2"
Expand All @@ -29,36 +28,22 @@ import (

// EncryptionAlgorithm represents a crypto algorithm used by the Engine
type EncryptionAlgorithm interface {
Encrypt(data []byte) ([]byte, error)
Decrypt(data []byte) ([]byte, error)
Encrypt(data []byte, salt []byte) ([]byte, error)
Decrypt(data []byte, salt []byte) ([]byte, error)
}

const maxSize = 32 * 1024 * 1024

// In a real application, you should use a unique salt for
// each key and save it with the encrypted data.
var (
salt = []byte("somesalt")
errUnknownAlgorithm = errors.New("unexpected encryption algorithm")
)

// EncryptionAlgorithmType is an enum of supported encryption algorithms
type EncryptionAlgorithmType string

const (
// AESCFB is the AES-CFB algorithm
AESCFB EncryptionAlgorithmType = "aes-cfb"
// Aes256Cfb is the AES-256-CFB algorithm
Aes256Cfb EncryptionAlgorithmType = "aes-256-cfb"
)

// AlgorithmTypeFromString converts a string to an EncryptionAlgorithmType
// or returns errUnknownAlgorithm.
func AlgorithmTypeFromString(input string) (EncryptionAlgorithmType, error) {
// for backwards compatibility - default to AES-CFB if string is empty
if input == "" || input == string(AESCFB) {
return AESCFB, nil
}
return "", fmt.Errorf("%w: %s", errUnknownAlgorithm, input)
}
const maxSize = 32 * 1024 * 1024

// ErrUnknownAlgorithm is used when an incorrect algorithm name is used.
var ErrUnknownAlgorithm = errors.New("unexpected encryption algorithm")

func newAlgorithm(key []byte) EncryptionAlgorithm {
// TODO: Make the type of algorithm selectable
Expand All @@ -70,11 +55,11 @@ type aesCFBSAlgorithm struct {
}

// Encrypt encrypts a row of data.
func (a *aesCFBSAlgorithm) Encrypt(data []byte) ([]byte, error) {
func (a *aesCFBSAlgorithm) Encrypt(data []byte, salt []byte) ([]byte, error) {
if len(data) > maxSize {
return nil, status.Errorf(codes.InvalidArgument, "data is too large (>32MB)")
}
block, err := aes.NewCipher(a.deriveKey())
block, err := aes.NewCipher(a.deriveKey(salt))
if err != nil {
return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err)
}
Expand All @@ -93,8 +78,8 @@ func (a *aesCFBSAlgorithm) Encrypt(data []byte) ([]byte, error) {
}

// Decrypt decrypts a row of data.
func (a *aesCFBSAlgorithm) Decrypt(ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(a.deriveKey())
func (a *aesCFBSAlgorithm) Decrypt(ciphertext []byte, salt []byte) ([]byte, error) {
block, err := aes.NewCipher(a.deriveKey(salt))
if err != nil {
return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err)
}
Expand All @@ -110,6 +95,6 @@ func (a *aesCFBSAlgorithm) Decrypt(ciphertext []byte) ([]byte, error) {
}

// Function to derive a key from a passphrase using Argon2
func (a *aesCFBSAlgorithm) deriveKey() []byte {
func (a *aesCFBSAlgorithm) deriveKey(salt []byte) []byte {
return argon2.IDKey(a.encryptionKey, salt, 1, 64*1024, 4, 32)
}
105 changes: 64 additions & 41 deletions internal/crypto/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,25 @@ import (

// Engine provides all functions to encrypt and decrypt data
type Engine interface {
EncryptOAuthToken(token *oauth2.Token) (string, error)
DecryptOAuthToken(encToken string) (oauth2.Token, error)
EncryptString(data string) (string, error)
DecryptString(encData string) (string, error)
// EncryptOAuthToken takes an OAuth2 token, serializes to JSON and encrypts it.
EncryptOAuthToken(token *oauth2.Token) (EncryptedData, error)
// DecryptOAuthToken takes an OAuth2 token encrypted using EncryptOAuthToken and decrypts it.
DecryptOAuthToken(encryptedToken EncryptedData) (oauth2.Token, error)
// EncryptString encrypts a string.
EncryptString(data string) (EncryptedData, error)
// DecryptString decrypts a string encrypted with EncryptString.
DecryptString(encryptedString EncryptedData) (string, error)
}

var (
// TODO: get rid of this when we allow per-secret salting.
legacySalt = []byte("somesalt")
// ErrDecrypt is returned when we cannot decrypt a secret.
ErrDecrypt = errors.New("unable to decrypt")
// ErrEncrypt is returned when we cannot encrypt a secret.
ErrEncrypt = errors.New("unable to encrypt")
)

type engine struct {
algorithm EncryptionAlgorithm
}
Expand All @@ -60,69 +73,79 @@ func NewEngine(key []byte) Engine {
return &engine{algorithm: newAlgorithm(key)}
}

// EncryptOAuthToken encrypts an oauth token
func (e *engine) EncryptOAuthToken(token *oauth2.Token) (string, error) {
// Convert token to JSON
func (e *engine) EncryptOAuthToken(token *oauth2.Token) (EncryptedData, error) {
// Convert token to JSON.
jsonData, err := json.Marshal(token)
if err != nil {
return "", fmt.Errorf("unable to marshal token to json: %w", err)
return EncryptedData{}, fmt.Errorf("unable to marshal token to json: %w", err)
}
encrypted, err := e.algorithm.Encrypt(jsonData)

// Encrypt the JSON.
encrypted, err := e.encrypt(jsonData)
if err != nil {
return "", fmt.Errorf("unable to encrypt token: %w", err)
return EncryptedData{}, fmt.Errorf("unable to encrypt token: %w", err)
}
return base64.StdEncoding.EncodeToString(encrypted), nil
return encrypted, nil
}

// DecryptOAuthToken decrypts an encrypted oauth token
func (e *engine) DecryptOAuthToken(encToken string) (oauth2.Token, error) {
var decryptedToken oauth2.Token
func (e *engine) DecryptOAuthToken(encryptedToken EncryptedData) (result oauth2.Token, err error) {
// Decrypt the token.
token, err := e.decrypt(encryptedToken)
if err != nil {
return result, err
}

// base64 decode the token
decodeToken, err := base64.StdEncoding.DecodeString(encToken)
// Deserialize to token struct.
err = json.Unmarshal(token, &result)
if err != nil {
return decryptedToken, err
return result, err
}
return result, nil
}

// decrypt the token
token, err := e.algorithm.Decrypt(decodeToken)
func (e *engine) EncryptString(data string) (EncryptedData, error) {
encrypted, err := e.encrypt([]byte(data))
if err != nil {
return decryptedToken, err
return EncryptedData{}, err
}
return encrypted, nil
}

// serialise token *oauth.Token
err = json.Unmarshal(token, &decryptedToken)
func (e *engine) DecryptString(encryptedString EncryptedData) (string, error) {
decrypted, err := e.decrypt(encryptedString)
if err != nil {
return decryptedToken, err
return "", fmt.Errorf("%w: %w", ErrDecrypt, err)
}
return decryptedToken, nil
return string(decrypted), nil
}

// EncryptString encrypts a string
func (e *engine) EncryptString(data string) (string, error) {
encrypted, err := e.algorithm.Encrypt([]byte(data))
func (e *engine) encrypt(data []byte) (EncryptedData, error) {
encrypted, err := e.algorithm.Encrypt(data, legacySalt)
if err != nil {
return "", err
return EncryptedData{}, err
}

return base64.StdEncoding.EncodeToString(encrypted), nil
encoded := base64.StdEncoding.EncodeToString(encrypted)
// TODO:
// 1. when we support more than one algorithm, remove hard-coding.
// 2. Allow salt to be randomly generated per secret.
// 3. Set key version.
return NewBackwardsCompatibleEncryptedData(encoded), nil
}

// DecryptString decrypts an encrypted string
func (e *engine) DecryptString(encData string) (string, error) {
var decrypted string

// base64 decode the string
decodeToken, err := base64.StdEncoding.DecodeString(encData)
if err != nil {
return decrypted, err
func (e *engine) decrypt(data EncryptedData) ([]byte, error) {
// TODO: Select algorithm based on Algorithm field when we support
// more than one algorithm.
if data.Algorithm != Aes256Cfb {
return nil, fmt.Errorf("%w: %s", ErrUnknownAlgorithm, data.Algorithm)
}

// decrypt the string
token, err := e.algorithm.Decrypt(decodeToken)
// base64 decode the string
encrypted, err := base64.StdEncoding.DecodeString(data.EncodedData)
if err != nil {
return decrypted, err
return nil, err
}

return string(token), nil
// decrypt the data
return e.algorithm.Decrypt(encrypted, data.Salt)
}
25 changes: 13 additions & 12 deletions internal/crypto/mock/engine.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading