From d3e7cc8ec680333db8d83c1a5413eec2569cc8bc Mon Sep 17 00:00:00 2001 From: Spike Lu Date: Fri, 15 Nov 2024 22:24:14 -0800 Subject: [PATCH] add encryption --- .gitignore | 3 +- CHANGELOG.md | 8 ++ cmd/bricksllm/.env | 2 - cmd/bricksllm/config_local.json | 2 - cmd/bricksllm/main.go | 14 +-- internal/authenticator/authenticator.go | 50 ++++++++-- internal/config/config.go | 11 ++- internal/encryptor/encryptor.go | 120 ++++++++++++++++++++++++ internal/manager/provider_setting.go | 57 ++++++++++- 9 files changed, 237 insertions(+), 30 deletions(-) create mode 100644 internal/encryptor/encryptor.go diff --git a/.gitignore b/.gitignore index ba36687..0f708b6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ release_notes.md target .DS_STORE -.vscode/launch.json \ No newline at end of file +.vscode/launch.json +.env \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 0116c18..7b89999 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +## 1.39.0 - 2024-11-15 +### Added +- Added encryption integration + +### Changed +- Removed support for Redis TLS config + + ## 1.38.0 - 2024-11-09 ### Added - Added support for `claude-3-5-haiku` diff --git a/cmd/bricksllm/.env b/cmd/bricksllm/.env index f24859a..09ae4ee 100644 --- a/cmd/bricksllm/.env +++ b/cmd/bricksllm/.env @@ -7,8 +7,6 @@ POSTGRESQL_PASSWORD= POSTGRESQL_SSL_MODE=disable POSTGRESQL_PORT=5432 REDIS_HOSTS=localhost -REDIS_ENABLE_TLS=false -REDIS_INSECURE_SKIP_VERIFY=false REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD= diff --git a/cmd/bricksllm/config_local.json b/cmd/bricksllm/config_local.json index b1ec5bd..ae92a01 100644 --- a/cmd/bricksllm/config_local.json +++ b/cmd/bricksllm/config_local.json @@ -7,8 +7,6 @@ "postgresql_port": "5432", "redis_hosts": "localhost", "redis_port": "6379", - "redis_enable_tls": false, - "redis_insecure_skip_verify": false, "redis_username": "", "redis_password": "", "redis_read_time_out": "1s", diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go index 46166ce..e5d2b7a 100644 --- a/cmd/bricksllm/main.go +++ b/cmd/bricksllm/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/tls" "flag" "fmt" "os" @@ -13,6 +12,7 @@ import ( auth "github.com/bricks-cloud/bricksllm/internal/authenticator" "github.com/bricks-cloud/bricksllm/internal/cache" "github.com/bricks-cloud/bricksllm/internal/config" + "github.com/bricks-cloud/bricksllm/internal/encryptor" "github.com/bricks-cloud/bricksllm/internal/logger/zap" "github.com/bricks-cloud/bricksllm/internal/manager" "github.com/bricks-cloud/bricksllm/internal/message" @@ -182,12 +182,6 @@ func main() { DB: cfg.RedisDBStartIndex + dbIndex, } - if cfg.RedisEnableTLS { - options.TLSConfig = &tls.Config{ - InsecureSkipVerify: cfg.RedisInsecureSkipVerify, - } - } - return options } @@ -292,9 +286,11 @@ func main() { psCache := redisStorage.NewProviderSettingsCache(providerSettingsRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) keysCache := redisStorage.NewKeysCache(keysRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) + encryptor := encryptor.NewEncryptor(cfg.DecryptionEndpoint, cfg.EncryptionEndpoint, cfg.EnableEncrytion, cfg.EncryptionTimeout) + m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache) krm := manager.NewReportingManager(costStorage, store, store) - psm := manager.NewProviderSettingsManager(store, psCache) + psm := manager.NewProviderSettingsManager(store, psCache, encryptor) cpm := manager.NewCustomProvidersManager(store, cpMemStore) rm := manager.NewRouteManager(store, store, rMemStore, psm) pm := manager.NewPolicyManager(store, rMemStore) @@ -332,7 +328,7 @@ func main() { rec := recorder.NewRecorder(costStorage, userCostStorage, costLimitCache, userCostLimitCache, ce, store) rlm := manager.NewRateLimitManager(rateLimitCache, userRateLimitCache) - a := auth.NewAuthenticator(psm, m, rm, store) + a := auth.NewAuthenticator(psm, m, rm, store, encryptor) c := cache.NewCache(apiCache) diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go index 38c53b8..68959d3 100644 --- a/internal/authenticator/authenticator.go +++ b/internal/authenticator/authenticator.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "net/http" + "strconv" "strings" internal_errors "github.com/bricks-cloud/bricksllm/internal/errors" @@ -34,19 +35,26 @@ type keyStorage interface { GetKeyByHash(hash string) (*key.ResponseKey, error) } +type Decryptor interface { + Decrypt(input string, headers map[string]string) (string, error) + Enabled() bool +} + type Authenticator struct { - psm providerSettingsManager - kc keysCache - rm routesManager - ks keyStorage + psm providerSettingsManager + kc keysCache + rm routesManager + ks keyStorage + decryptor Decryptor } -func NewAuthenticator(psm providerSettingsManager, kc keysCache, rm routesManager, ks keyStorage) *Authenticator { +func NewAuthenticator(psm providerSettingsManager, kc keysCache, rm routesManager, ks keyStorage, decryptor Decryptor) *Authenticator { return &Authenticator{ - psm: psm, - kc: kc, - rm: rm, - ks: ks, + psm: psm, + kc: kc, + rm: rm, + ks: ks, + decryptor: decryptor, } } @@ -268,6 +276,30 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons used = selected[rand.Intn(len(selected))] } + if a.decryptor.Enabled() { + encryptedParam := "" + if used.Provider == "amazon" { + encryptedParam = used.Setting["awsSecretAccessKey"] + } else if len(used.Setting["apikey"]) != 0 { + encryptedParam = used.Setting["apikey"] + } + + if len(encryptedParam) != 0 { + decryptedSecret, err := a.decryptor.Decrypt(encryptedParam, map[string]string{"X-UPDATED-AT": strconv.FormatInt(used.UpdatedAt, 10)}) + if err == nil { + if used.Provider == "amazon" { + used.Setting["awsSecretAccessKey"] = decryptedSecret + } else { + used.Setting["apikey"] = decryptedSecret + } + } + + if err != nil { + fmt.Println(fmt.Printf("error when encrypting %v", err)) + } + } + } + err := rewriteHttpAuthHeader(req, used) if err != nil { return nil, nil, err diff --git a/internal/config/config.go b/internal/config/config.go index b1d3f49..ed55462 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" "path/filepath" "time" @@ -25,8 +26,6 @@ type Config struct { RedisPort string `koanf:"redis_port" env:"REDIS_PORT" envDefault:"6379"` RedisUsername string `koanf:"redis_username" env:"REDIS_USERNAME"` RedisPassword string `koanf:"redis_password" env:"REDIS_PASSWORD"` - RedisEnableTLS bool `koanf:"redis_enable_tls" env:"REDIS_ENABLE_TLS" envDefault:"false"` - RedisInsecureSkipVerify bool `koanf:"redis_insecure_skip_verify" env:"REDIS_INSECURE_SKIP_VERIFY" envDefault:"false"` RedisDBStartIndex int `koanf:"redis_db_start_index" env:"REDIS_DB_START_INDEX" envDefault:"0"` RedisReadTimeout time.Duration `koanf:"redis_read_time_out" env:"REDIS_READ_TIME_OUT" envDefault:"1s"` RedisWriteTimeout time.Duration `koanf:"redis_write_time_out" env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"` @@ -47,6 +46,10 @@ type Config struct { AmazonRequestTimeout time.Duration `koanf:"amazon_request_timeout" env:"AMAZON_REQUEST_TIMEOUT" envDefault:"5s"` AmazonConnectionTimeout time.Duration `koanf:"amazon_connection_timeout" env:"AMAZON_CONNECTION_TIMEOUT" envDefault:"10s"` RemoveUserAgent bool `koanf:"remove_user_agent" env:"REMOVE_USER_AGENT" envDefault:"false"` + EnableEncrytion bool `koanf:"enable_encryption" env:"ENABLE_ENCRYPTION" envDefault:"false"` + EncryptionEndpoint string `koanf:"encryption_endpoint" env:"ENCRYPTION_ENDPOINT"` + DecryptionEndpoint string `koanf:"decryption_endpoint" env:"DECRYPTION_ENDPOINT"` + EncryptionTimeout time.Duration `koanf:"encryption_timeout" env:"ENCRYPTION_TIMEOUT" envDefault:"5s"` } func prepareDotEnv(envFilePath string) error { @@ -82,6 +85,10 @@ func LoadConfig(log *zap.Logger) (*Config, error) { return nil, err } + if cfg.EnableEncrytion && len(cfg.EncryptionEndpoint) == 0 { + return nil, errors.New("encryption endpoint cannot be empty") + } + err = prepareDotEnv(".env") if err != nil { log.Sugar().Infof("error loading config from .env file: %v", err) diff --git a/internal/encryptor/encryptor.go b/internal/encryptor/encryptor.go new file mode 100644 index 0000000..eef5256 --- /dev/null +++ b/internal/encryptor/encryptor.go @@ -0,0 +1,120 @@ +package encryptor + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "time" +) + +type Encryptor struct { + decryptionURL string + encryptionURL string + enabled bool + client http.Client + timeout time.Duration +} + +type Secret struct { + Secret string `json:"secret"` +} + +type EncryptionResponse struct { + EncryptedSecret string `json:"encryptedSecret"` +} + +type DecryptionResponse struct { + DecryptedSecret string `json:"decryptedSecret"` +} + +func NewEncryptor(decryptionURL string, encryptionURL string, enabled bool, timeout time.Duration) Encryptor { + return Encryptor{ + decryptionURL: decryptionURL, + encryptionURL: encryptionURL, + client: http.Client{}, + enabled: enabled, + timeout: timeout, + } +} + +func (e Encryptor) Encrypt(input string, headers map[string]string) (string, error) { + data, err := json.Marshal(Secret{ + Secret: input, + }) + if err != nil { + return "", err + } + + ctx, cancel := context.WithTimeout(context.Background(), e.timeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.encryptionURL, bytes.NewBuffer(data)) + if err != nil { + return "", err + } + + for header, value := range headers { + req.Header.Add(header, value) + } + + res, err := e.client.Do(req) + if err != nil { + return "", err + } + + bytes, err := io.ReadAll(res.Body) + if err != nil { + return "", err + } + + encryptionResponse := EncryptionResponse{} + err = json.Unmarshal(bytes, &encryptionResponse) + if err != nil { + return "", err + } + + return encryptionResponse.EncryptedSecret, nil +} + +func (e Encryptor) Enabled() bool { + return e.enabled && len(e.decryptionURL) != 0 && len(e.encryptionURL) != 0 +} + +func (e Encryptor) Decrypt(input string, headers map[string]string) (string, error) { + data, err := json.Marshal(Secret{ + Secret: input, + }) + if err != nil { + return "", err + } + + ctx, cancel := context.WithTimeout(context.Background(), e.timeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.decryptionURL, bytes.NewBuffer(data)) + if err != nil { + return "", err + } + + for header, value := range headers { + req.Header.Add(header, value) + } + + res, err := e.client.Do(req) + if err != nil { + return "", err + } + + bytes, err := io.ReadAll(res.Body) + if err != nil { + return "", err + } + + decryptionSecret := DecryptionResponse{} + err = json.Unmarshal(bytes, &decryptionSecret) + if err != nil { + return "", err + } + + return decryptionSecret.DecryptedSecret, nil +} diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go index 5b246d1..8ed9bca 100644 --- a/internal/manager/provider_setting.go +++ b/internal/manager/provider_setting.go @@ -3,6 +3,7 @@ package manager import ( "encoding/json" "fmt" + "strconv" "strings" "time" @@ -27,15 +28,22 @@ type ProviderSettingsCache interface { Delete(pid string) error } +type Encryptor interface { + Encrypt(input string, headers map[string]string) (string, error) + Enabled() bool +} + type ProviderSettingsManager struct { - Storage ProviderSettingsStorage - Cache ProviderSettingsCache + Storage ProviderSettingsStorage + Cache ProviderSettingsCache + Encryptor Encryptor } -func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSettingsCache) *ProviderSettingsManager { +func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSettingsCache, encryptor Encryptor) *ProviderSettingsManager { return &ProviderSettingsManager{ - Storage: s, - Cache: cache, + Storage: s, + Cache: cache, + Encryptor: encryptor, } } @@ -118,6 +126,27 @@ func (m *ProviderSettingsManager) validateSettings(providerName string, setting return nil } +func (m *ProviderSettingsManager) EncryptParams(updatedAt int64, provider string, params map[string]string) (map[string]string, error) { + if provider == "amazon" { + encryted, err := m.Encryptor.Encrypt(params["awsSecretAccessKey"], map[string]string{"X-UPDATED-AT": strconv.FormatInt(updatedAt, 10)}) + if err != nil { + return nil, err + } + + params["awsSecretAccessKey"] = encryted + + } else if provider == "openai" || provider == "anthropic" || provider == "deepinfra" || provider == "azure" { + encryted, err := m.Encryptor.Encrypt(params["apikey"], map[string]string{"X-UPDATED-AT": strconv.FormatInt(updatedAt, 10)}) + if err != nil { + return nil, err + } + + params["apikey"] = encryted + } + + return params, nil +} + func (m *ProviderSettingsManager) CreateSetting(setting *provider.Setting) (*provider.Setting, error) { if len(setting.Provider) == 0 { return nil, internal_errors.NewValidationError("provider field cannot be empty") @@ -131,6 +160,15 @@ func (m *ProviderSettingsManager) CreateSetting(setting *provider.Setting) (*pro setting.CreatedAt = time.Now().Unix() setting.UpdatedAt = time.Now().Unix() + if m.Encryptor.Enabled() { + params, err := m.EncryptParams(setting.UpdatedAt, setting.Provider, setting.Setting) + if err != nil { + return nil, err + } + + setting.Setting = params + } + return m.Storage.CreateProviderSetting(setting) } @@ -164,6 +202,15 @@ func (m *ProviderSettingsManager) UpdateSetting(id string, setting *provider.Upd telemetry.Incr("bricksllm.provider_settings_manager.update_setting.delete_cache_error", nil, 1) } + if m.Encryptor.Enabled() { + params, err := m.EncryptParams(existing.UpdatedAt, existing.Provider, setting.Setting) + if err != nil { + return nil, err + } + + setting.Setting = params + } + return m.Storage.UpdateProviderSetting(id, setting) }