Skip to content

Commit

Permalink
fix: postgrest jwt mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
moshloop committed Aug 16, 2024
1 parent 79aa374 commit 6cf7c90
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 18 deletions.
70 changes: 70 additions & 0 deletions api/config.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
package api

import (
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"os"
)

var DefaultConfig = Config{
Postgrest: PostgrestConfig{
Version: "v10.0.0",
Expand All @@ -20,6 +28,49 @@ type Config struct {
LogName string
}

func PrintableSecret(secret string) string {
if len(secret) == 0 {
return "<nil>"
} else if len(secret) > 30 {
sum := md5.Sum([]byte(secret))
hash := hex.EncodeToString(sum[:])
return fmt.Sprintf("md5(%s),length=%d", hash[0:8], len(secret))
} else if len(secret) > 16 {
return fmt.Sprintf("%s****%s", secret[0:1], secret[len(secret)-2:])
} else if len(secret) > 10 {
return fmt.Sprintf("****%s", secret[len(secret)-1:])
}
return "****"
}

func readEnv(val string) string {
if v := os.Getenv(val); v != "" {
return v
}
return val
}

func (c Config) ReadEnv() Config {
clone := c
clone.ConnectionString = readEnv(clone.ConnectionString)
if clone.ConnectionString == "DB_URL" {
clone.ConnectionString = ""
}
clone.Schema = readEnv(clone.Schema)
clone.LogLevel = readEnv(clone.LogLevel)
clone.Postgrest = clone.Postgrest.ReadEnv()
return clone
}

func (c Config) String() string {
s := fmt.Sprintf("migrate=%v log=%v postgres=(%s)", !c.SkipMigrations, c.LogLevel, c.Postgrest.String())
if pgUrl, err := url.Parse(c.ConnectionString); err == nil {
s = fmt.Sprintf("url=%s ", pgUrl.Redacted()) + s
}

return s
}

type PostgrestConfig struct {
Port int
Disable bool
Expand All @@ -33,3 +84,22 @@ type PostgrestConfig struct {
// Limits payload size for accidental or malicious requests.
MaxRows int
}

func (p PostgrestConfig) ReadEnv() PostgrestConfig {
clone := p

clone.JWTSecret = readEnv(clone.JWTSecret)
if clone.JWTSecret == "PGRST_JWT_SECRET" {
clone.JWTSecret = ""
}
clone.LogLevel = readEnv(clone.LogLevel)
return clone
}

func (p PostgrestConfig) String() string {
return fmt.Sprintf("version:%v port=%d log-level=%v, jwt=%s",
p.Version,
p.Port,
p.LogLevel,
PrintableSecret(p.JWTSecret))
}
8 changes: 3 additions & 5 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,6 @@ func NewPgxPool(connection string) (*pgxpool.Pool, error) {
return nil, err
}

pgUrl, err := url.Parse(connection)
if err == nil {
logger.Infof("Connecting to %s", pgUrl.Redacted())
}

config, err := pgxpool.ParseConfig(connection)
if err != nil {
return nil, err
Expand Down Expand Up @@ -196,6 +191,9 @@ func InitDB(config api.Config) (*dutyContext.Context, error) {

// SetupDB runs migrations for the connection and returns a gorm.DB and a pgxpool.Pool
func SetupDB(config api.Config) (gormDB *gorm.DB, pgxpool *pgxpool.Pool, err error) {
config = config.ReadEnv()

logger.Infof("Connecting to %s", config)

pgxpool, err = NewPgxPool(config.ConnectionString)
if err != nil {
Expand Down
15 changes: 2 additions & 13 deletions start.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package duty

import (
"net/url"
"os"
"strconv"
"strings"

Expand All @@ -14,21 +13,13 @@ import (
"github.com/spf13/pflag"
)

func readFromEnv(v string) string {
val := os.Getenv(v)
if val != "" {
return val
}
return v
}

func BindPFlags(flags *pflag.FlagSet) {
flags.StringVar(&DefaultConfig.ConnectionString, "db", "DB_URL", "Connection string for the postgres database")
flags.StringVar(&DefaultConfig.Schema, "db-schema", "public", "Postgres schema")
flags.StringVar(&DefaultConfig.Postgrest.URL, "postgrest-uri", "http://localhost:3000", "URL for the PostgREST instance to use. If localhost is supplied, a PostgREST instance will be started")
flags.StringVar(&DefaultConfig.Postgrest.LogLevel, "postgrest-log-level", "info", "PostgREST log level")
flags.StringVar(&DefaultConfig.Postgrest.JWTSecret, "postgrest-jwt-secret", "PGRST_JWT_SECRET", "JWT Secret Token for PostgREST")
flags.BoolVar(&DefaultConfig.SkipMigrations, "skip-migrations", false, "Run database migrations")
flags.BoolVar(&DefaultConfig.SkipMigrations, "skip-migrations", false, "Skip database migrations")
flags.BoolVar(&DefaultConfig.Postgrest.Disable, "disable-postgrest", false, "Disable PostgREST. Deprecated (Use --postgrest-uri '' to disable PostgREST)")
flags.StringVar(&DefaultConfig.Postgrest.DBAnonRole, "postgrest-anon-role", "postgrest-api", "PostgREST anonymous role")
flags.IntVar(&DefaultConfig.Postgrest.MaxRows, "postgrest-max-rows", 2000, "A hard limit to the number of rows PostgREST will fetch")
Expand Down Expand Up @@ -71,6 +62,7 @@ func Start(name string, opts ...StartOption) (context.Context, func(), error) {
for _, opt := range opts {
config = opt(config)
}
config = config.ReadEnv()

if config.Postgrest.URL != "" && !config.Postgrest.Disable {
parsedURL, err := url.Parse(config.Postgrest.URL)
Expand All @@ -88,9 +80,6 @@ func Start(name string, opts ...StartOption) (context.Context, func(), error) {

stop := func() {}

config.ConnectionString = readFromEnv(config.ConnectionString)
config.Schema = readFromEnv(config.Schema)

var ctx context.Context
if c, err := InitDB(config); err != nil {
return context.Context{}, stop, err
Expand Down

0 comments on commit 6cf7c90

Please sign in to comment.