From 909380a49e60d75b6bcd8a7c409cde7cad1375aa Mon Sep 17 00:00:00 2001 From: moocss Date: Mon, 19 Aug 2024 11:11:30 +0800 Subject: [PATCH] add: auth, subscriptions, template --- authx/jwt/jwt.go | 137 +++++++++++ authx/jwt/jwt_test.go | 75 ++++++ authx/jwt/options.go | 86 +++++++ authx/jwt/store.go | 18 ++ authx/jwt/store/redis/redis.go | 53 ++++ authx/jwt/token.go | 37 +++ authx/types.go | 48 ++++ cmd/van/config/config.go | 2 +- cmd/van/go.mod | 39 +++ cmd/van/go.sum | 71 ++++++ cmd/van/internal/create/create.go | 25 ++ cmd/van/internal/create/model/model.go | 17 ++ cmd/van/internal/create/service/service.go | 17 ++ cmd/van/internal/new/new.go | 86 +++++++ cmd/van/internal/new/project.go | 91 +++++++ cmd/van/internal/rpc/rpc.go | 16 ++ cmd/van/internal/run/run.go | 16 ++ cmd/van/internal/upgrade/upgrade.go | 17 ++ cmd/van/main.go | 38 +++ db/gorm/wrapper.go | 6 + db/redis/wrapper.go | 6 + db/sqlx/wrapper.go | 6 + errors/stack.go | 21 -- errorsx/stack.go | 49 ++++ {errors => errx}/error.go | 17 +- {errors => errx}/error_test.go | 2 +- errx/ignore.go | 16 ++ errx/must.go | 24 ++ errx/result.go | 64 +++++ {errors => errx}/types.go | 2 +- go.mod | 5 +- go.sum | 4 + pkg/hook/hook.go | 126 ++++++++++ pkg/hook/hook_test.go | 179 ++++++++++++++ pkg/inflector/inflector.go | 85 +++++++ pkg/inflector/inflector_test.go | 134 ++++++++++ pkg/ptr/ptr_test.go | 17 ++ pkg/rand/rand.go | 64 +++++ pkg/rand/rand_test.go | 89 +++++++ pkg/types/datetime.go | 105 ++++++++ pkg/types/datetime_test.go | 205 +++++++++++++++ pkg/types/json_array.go | 52 ++++ pkg/types/json_array_test.go | 96 ++++++++ pkg/types/json_map.go | 67 +++++ pkg/types/json_map_test.go | 132 ++++++++++ pkg/types/json_raw.go | 83 +++++++ pkg/types/json_raw_test.go | 178 +++++++++++++ pkg/{kv => types}/kv.go | 2 +- pkg/{kv => types}/kv_test.go | 2 +- pkg/uuid/uuid.go | 24 ++ pkg/uuid/uuid_test.go | 55 +++++ store/store.go | 134 ++++++++++ store/store_test.go | 232 +++++++++++++++++ subscriptions/broker.go | 70 ++++++ subscriptions/broker_test.go | 93 +++++++ subscriptions/client.go | 274 +++++++++++++++++++++ subscriptions/client_test.go | 244 ++++++++++++++++++ template/registry.go | 141 +++++++++++ template/registry_test.go | 250 +++++++++++++++++++ template/renderer.go | 33 +++ template/renderer_test.go | 63 +++++ 61 files changed, 4312 insertions(+), 28 deletions(-) create mode 100644 authx/jwt/jwt.go create mode 100644 authx/jwt/jwt_test.go create mode 100644 authx/jwt/options.go create mode 100644 authx/jwt/store.go create mode 100644 authx/jwt/store/redis/redis.go create mode 100644 authx/jwt/token.go create mode 100644 authx/types.go create mode 100644 cmd/van/go.mod create mode 100644 cmd/van/go.sum create mode 100644 cmd/van/internal/create/model/model.go create mode 100644 cmd/van/internal/create/service/service.go create mode 100644 cmd/van/internal/new/new.go create mode 100644 cmd/van/internal/rpc/rpc.go create mode 100644 cmd/van/internal/run/run.go create mode 100644 cmd/van/internal/upgrade/upgrade.go create mode 100644 cmd/van/main.go delete mode 100644 errors/stack.go create mode 100644 errorsx/stack.go rename {errors => errx}/error.go (73%) rename {errors => errx}/error_test.go (95%) create mode 100644 errx/ignore.go create mode 100644 errx/must.go create mode 100644 errx/result.go rename {errors => errx}/types.go (99%) create mode 100644 pkg/hook/hook.go create mode 100644 pkg/hook/hook_test.go create mode 100644 pkg/inflector/inflector.go create mode 100644 pkg/inflector/inflector_test.go create mode 100644 pkg/rand/rand.go create mode 100644 pkg/rand/rand_test.go create mode 100644 pkg/types/datetime.go create mode 100644 pkg/types/datetime_test.go create mode 100644 pkg/types/json_array.go create mode 100644 pkg/types/json_array_test.go create mode 100644 pkg/types/json_map.go create mode 100644 pkg/types/json_map_test.go create mode 100644 pkg/types/json_raw.go create mode 100644 pkg/types/json_raw_test.go rename pkg/{kv => types}/kv.go (96%) rename pkg/{kv => types}/kv_test.go (96%) create mode 100644 pkg/uuid/uuid.go create mode 100644 pkg/uuid/uuid_test.go create mode 100644 store/store.go create mode 100644 store/store_test.go create mode 100644 subscriptions/broker.go create mode 100644 subscriptions/broker_test.go create mode 100644 subscriptions/client.go create mode 100644 subscriptions/client_test.go create mode 100644 template/registry.go create mode 100644 template/registry_test.go create mode 100644 template/renderer.go create mode 100644 template/renderer_test.go diff --git a/authx/jwt/jwt.go b/authx/jwt/jwt.go new file mode 100644 index 0000000..8dadb98 --- /dev/null +++ b/authx/jwt/jwt.go @@ -0,0 +1,137 @@ +package jwt + +import ( + "context" + "errors" + "time" + + "github.com/apus-run/van/authx" + "github.com/golang-jwt/jwt/v5" +) + +var ( + ErrTokenInvalid = errors.New("token is invalid") + ErrUnSupportSigningMethod = errors.New("wrong signing method") + ErrSignToken = errors.New("can not sign token. is the key correct") + ErrGetKey = errors.New("can not get key while signing token") +) + +// JwtAuth implement the authx.Authenticator interface. + +type JwtAuth struct { + *options + store Storer +} + +func NewJwtAuth(store Storer, opts ...Option) *JwtAuth { + options := Apply(opts...) + return &JwtAuth{ + options: options, + store: store, + } +} + +func (j *JwtAuth) Sign(ctx context.Context) (authx.Token, error) { + now := time.Now() + expiresAt := now.Add(j.expired) + + tokenString, err := j.GenerateToken(ctx) + if err != nil { + return nil, err + } + tokenInfo := &tokenInfo{ + Token: tokenString, + Type: j.tokenType, + ExpiresAt: expiresAt.Unix(), + } + return tokenInfo, nil +} + +func (j *JwtAuth) Destroy(ctx context.Context, refreshToken string) error { + claims, err := j.ParseClaims(ctx, refreshToken) + if err != nil { + return err + } + + // If storage is set, put the unexpired token in + store := func(store Storer) error { + expired := time.Until(claims.ExpiresAt.Time) + return store.Set(ctx, refreshToken, "1", expired) + } + return j.callStore(store) +} + +func (j *JwtAuth) ParseClaims(ctx context.Context, accessToken string) (*jwt.RegisteredClaims, error) { + if accessToken == "" { + return nil, ErrTokenInvalid + } + + token, err := j.ParseToken(ctx, accessToken) + if err != nil { + return nil, err + } + + store := func(store Storer) error { + exists, err := store.Check(ctx, accessToken) + if err != nil { + return err + } + + if exists { + return ErrTokenInvalid + } + + return nil + } + + if err := j.callStore(store); err != nil { + return nil, err + } + + return token.Claims.(*jwt.RegisteredClaims), nil +} + +func (j *JwtAuth) ParseToken(ctx context.Context, accessToken string) (token *jwt.Token, err error) { + if j.claims != nil { + token, err = jwt.ParseWithClaims(accessToken, j.claims(), j.keyfunc) + } else { + token, err = jwt.Parse(accessToken, j.keyfunc) + } + + // 过期的, 伪造的, 都可以认为是无效token + if err != nil || !token.Valid { + return nil, ErrTokenInvalid + } + + if token.Method != j.signingMethod { + return nil, ErrUnSupportSigningMethod + } + + return token, nil +} + +func (j *JwtAuth) GenerateToken(ctx context.Context) (string, error) { + token := jwt.NewWithClaims(j.signingMethod, j.claims()) + if j.tokenHeader != nil { + for k, v := range j.tokenHeader { + token.Header[k] = v + } + } + key, err := j.keyfunc(token) + if err != nil { + return "", ErrGetKey + } + tokenStr, err := token.SignedString(key) + if err != nil { + return "", ErrSignToken + } + + return tokenStr, nil +} + +func (j *JwtAuth) callStore(fn func(Storer) error) error { + if store := j.store; store != nil { + return fn(store) + } + return nil +} diff --git a/authx/jwt/jwt_test.go b/authx/jwt/jwt_test.go new file mode 100644 index 0000000..8a7e04e --- /dev/null +++ b/authx/jwt/jwt_test.go @@ -0,0 +1,75 @@ +package jwt + +import ( + "context" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/apus-run/van/authx/jwt/store/redis" +) + +type CustomClaims struct { + UserID uint64 + + // UserAgent 增强安全性,防止token被盗用 + UserAgent string + + jwt.RegisteredClaims +} + +func TestGenerateToken(t *testing.T) { + +} + +func TestParseToken(t *testing.T) { + +} + +func TestNewJwtAuth(t *testing.T) { + headers := make(map[string]any) + headers["kid"] = "8b5228a5-b3d2-4165-aaac-58a052629846" + now := time.Now() + expiresAt := now.Add(2 * time.Hour) + opts := []Option{ + + WithTokenHeader(headers), + WithExpired(2 * time.Hour), + WithKeyfunc(func(token *jwt.Token) (any, error) { + // Verify that the signing method is HMAC. + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, ErrTokenInvalid + } + return []byte("moyn8y9abnd7q4zkq2m73yw8tu9j5ixm"), nil + }), + WithClaims(func() jwt.Claims { + return &CustomClaims{ + UserID: 1, + UserAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.83 Safari/537.36", + RegisteredClaims: jwt.RegisteredClaims{ + // Issuer = iss,令牌颁发者。它表示该令牌是由谁创建的 + Issuer: "", + // IssuedAt = iat,令牌颁发时的时间戳。它表示令牌是何时被创建的 + IssuedAt: jwt.NewNumericDate(now), + // ExpiresAt = exp,令牌的过期时间戳。它表示令牌将在何时过期 + ExpiresAt: jwt.NewNumericDate(expiresAt), + // NotBefore = nbf,令牌的生效时的时间戳。它表示令牌从什么时候开始生效 + NotBefore: jwt.NewNumericDate(now), + // Subject = sub,令牌的主体。它表示该令牌是关于谁的 + Subject: "", + }, + } + }), + } + + opts = append(opts, WithSigningMethod(jwt.SigningMethodHS256)) + + store := redis.NewStore(nil, "authx") + + j, err := NewJwtAuth(store, opts...).Sign(context.Background()) + if err != nil { + t.Fatal(err) + } + t.Log(j.GetToken()) +} diff --git a/authx/jwt/options.go b/authx/jwt/options.go new file mode 100644 index 0000000..b75b3c8 --- /dev/null +++ b/authx/jwt/options.go @@ -0,0 +1,86 @@ +package jwt + +import ( + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + // defaultKey holds the default key used to sign a jwt token. + defaultKey = "authx::jwt(#)9527" +) + +// Option is jwt option. +type Option func(*options) + +// Parser is a jwt parser +type options struct { + signingMethod jwt.SigningMethod + claims func() jwt.Claims + tokenHeader map[string]any + + expired time.Duration + keyfunc jwt.Keyfunc + tokenType string +} + +// DefaultOptions . +func DefaultOptions() *options { + return &options{ + tokenType: "Bearer", + expired: 2 * time.Hour, + signingMethod: jwt.SigningMethodHS256, + keyfunc: func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, ErrTokenInvalid + } + return []byte(defaultKey), nil + }, + } +} + +func Apply(opts ...Option) *options { + options := DefaultOptions() + for _, opt := range opts { + opt(options) + } + return options +} + +// WithSigningMethod with signing method option. +func WithSigningMethod(method jwt.SigningMethod) Option { + return func(o *options) { + o.signingMethod = method + } +} + +// WithClaims with customer claim +// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems +// If you use it in Client, f only needs to return a single object to provide performance +func WithClaims(f func() jwt.Claims) Option { + return func(o *options) { + o.claims = f + } +} + +// WithTokenHeader withe customer tokenHeader for client side +func WithTokenHeader(header map[string]any) Option { + return func(o *options) { + o.tokenHeader = header + } +} + +// WithKeyfunc set the callback function for verifying the key. +func WithKeyfunc(keyFunc jwt.Keyfunc) Option { + return func(o *options) { + o.keyfunc = keyFunc + } +} + +// WithExpired set the token expiration time (in seconds, default 2h). +func WithExpired(expired time.Duration) Option { + return func(o *options) { + o.expired = expired + } +} diff --git a/authx/jwt/store.go b/authx/jwt/store.go new file mode 100644 index 0000000..a6d0b94 --- /dev/null +++ b/authx/jwt/store.go @@ -0,0 +1,18 @@ +package jwt + +import ( + "context" + "time" +) + +// Storer token storage interface. +type Storer interface { + // Set Store token data and specify expiration time. + Set(ctx context.Context, accessToken string, val any, expiration time.Duration) error + + // Delete token data from storage. + Delete(ctx context.Context, accessToken string) (bool, error) + + // Check if token exists. + Check(ctx context.Context, accessToken string) (bool, error) +} diff --git a/authx/jwt/store/redis/redis.go b/authx/jwt/store/redis/redis.go new file mode 100644 index 0000000..41146e5 --- /dev/null +++ b/authx/jwt/store/redis/redis.go @@ -0,0 +1,53 @@ +package redis + +import ( + "context" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +// Store redis storage. +type Store struct { + client redis.Cmdable + + prefix string +} + +// NewStore create an *Store instance to handle token storage, deletion, and checking. +func NewStore(client redis.Cmdable, prefix string) *Store { + return &Store{client: client, prefix: prefix} +} + +// Set call the Redis client to set a key-value pair with an +// expiration time, where the key name format is . +func (s *Store) Set(ctx context.Context, accessToken string, val any, expiration time.Duration) error { + cmd := s.client.Set(ctx, s.key(accessToken), val, expiration) + return cmd.Err() +} + +// Delete delete the specified JWT Token in Redis. +func (s *Store) Delete(ctx context.Context, accessToken string) (bool, error) { + cmd := s.client.Del(ctx, s.key(accessToken)) + if err := cmd.Err(); err != nil { + return false, err + } + return cmd.Val() > 0, nil +} + +// Check check if the specified JWT Token exists in Redis. +func (s *Store) Check(ctx context.Context, accessToken string) (bool, error) { + s.client.Get(ctx, s.key(accessToken)) + + cmd := s.client.Exists(ctx, s.key(accessToken)) + if err := cmd.Err(); err != nil { + return false, err + } + return cmd.Val() > 0, nil +} + +// wrapperKey is used to build the key name in Redis. +func (s *Store) key(key string) string { + return fmt.Sprintf("%s%s", s.prefix, key) +} diff --git a/authx/jwt/token.go b/authx/jwt/token.go new file mode 100644 index 0000000..67c8ce1 --- /dev/null +++ b/authx/jwt/token.go @@ -0,0 +1,37 @@ +package jwt + +import ( + "encoding/json" +) + +// tokenInfo contains token information. +type tokenInfo struct { + // Token string. + Token string `json:"token"` + + // Token type. + Type string `json:"type"` + + // Token expiration time + ExpiresAt int64 `json:"expiresAt"` +} + +func (t *tokenInfo) GetExpireAt() int64 { + return t.ExpiresAt +} + +func (t *tokenInfo) GetToken() string { + return t.Token +} + +func (t *tokenInfo) GetTokenType() string { + return t.Type +} + +func (t *tokenInfo) GetExpiresAt() int64 { + return t.ExpiresAt +} + +func (t *tokenInfo) EncodeToJSON() ([]byte, error) { + return json.Marshal(t) +} diff --git a/authx/types.go b/authx/types.go new file mode 100644 index 0000000..926918d --- /dev/null +++ b/authx/types.go @@ -0,0 +1,48 @@ +package authx + +import ( + "context" + + "github.com/golang-jwt/jwt/v5" + "golang.org/x/crypto/bcrypt" +) + +type Token interface { + // GetToken Get token string. + GetToken() string + // GetTokenType Get token type. + GetTokenType() string + // GetExpireAt Get token expiration timestamp. + GetExpireAt() int64 + // EncodeToJSON JSON encoding + EncodeToJSON() ([]byte, error) +} + +// Authenticator defines methods used for token processing. +type Authenticator interface { + // Sign is used to generate a token. + Sign(ctx context.Context, userID string) (Token, error) + + // Destroy is used to destroy a token. + Destroy(ctx context.Context, accessToken string) error + + // ParseClaims parse the token and return the claims. + ParseClaims(ctx context.Context, accessToken string) (*jwt.RegisteredClaims, error) + + // ParseToken is used to parse a token. + ParseToken(ctx context.Context, accessToken string) (*jwt.Token, error) + + // GenerateToken is used to generate a token. + GenerateToken(ctx context.Context) (string, error) +} + +// Encrypt encrypts the plain text with bcrypt. +func Encrypt(source string) (string, error) { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(source), bcrypt.DefaultCost) + return string(hashedBytes), err +} + +// Compare compares the encrypted text with the plain text if it's the same. +func Compare(hashedPassword, password string) error { + return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) +} diff --git a/cmd/van/config/config.go b/cmd/van/config/config.go index 2f7e3db..14c4e3c 100644 --- a/cmd/van/config/config.go +++ b/cmd/van/config/config.go @@ -1,4 +1,4 @@ -package main +package config const ( Release = "v0.2.0" diff --git a/cmd/van/go.mod b/cmd/van/go.mod new file mode 100644 index 0000000..2e26025 --- /dev/null +++ b/cmd/van/go.mod @@ -0,0 +1,39 @@ +module github.com/apus-run/van/cmd/van + +go 1.22 + +require ( + github.com/charmbracelet/huh v0.5.1 + github.com/spf13/cobra v1.8.1 +) + +require ( + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/catppuccin/go v0.2.0 // indirect + github.com/charmbracelet/bubbles v0.18.0 // indirect + github.com/charmbracelet/bubbletea v0.26.4 // indirect + github.com/charmbracelet/lipgloss v0.11.0 // indirect + github.com/charmbracelet/x/ansi v0.1.2 // indirect + github.com/charmbracelet/x/exp/strings v0.0.0-20240617190524-788ec55faed1 // indirect + github.com/charmbracelet/x/input v0.1.2 // indirect + github.com/charmbracelet/x/term v0.1.1 // indirect + github.com/charmbracelet/x/windows v0.1.2 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.15.2 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.21.0 // indirect + golang.org/x/text v0.16.0 // indirect +) diff --git a/cmd/van/go.sum b/cmd/van/go.sum new file mode 100644 index 0000000..b465cff --- /dev/null +++ b/cmd/van/go.sum @@ -0,0 +1,71 @@ +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/catppuccin/go v0.2.0 h1:ktBeIrIP42b/8FGiScP9sgrWOss3lw0Z5SktRoithGA= +github.com/catppuccin/go v0.2.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= +github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0= +github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw= +github.com/charmbracelet/bubbletea v0.26.4 h1:2gDkkzLZaTjMl/dQBpNVtnvcCxsh/FCkimep7FC9c40= +github.com/charmbracelet/bubbletea v0.26.4/go.mod h1:P+r+RRA5qtI1DOHNFn0otoNwB4rn+zNAzSj/EXz6xU0= +github.com/charmbracelet/huh v0.5.1 h1:t5j6g9sMjAE2a9AQuc4lNL7pf/0X4WdHiiMGkL8v/aM= +github.com/charmbracelet/huh v0.5.1/go.mod h1:gs7b2brpzXkY0PBWUqJrlzvOowTCL0vNAR6OTItc+kA= +github.com/charmbracelet/lipgloss v0.11.0 h1:UoAcbQ6Qml8hDwSWs0Y1cB5TEQuZkDPH/ZqwWWYTG4g= +github.com/charmbracelet/lipgloss v0.11.0/go.mod h1:1UdRTH9gYgpcdNN5oBtjbu/IzNKtzVtb7sqN1t9LNn8= +github.com/charmbracelet/x/ansi v0.1.2 h1:6+LR39uG8DE6zAmbu023YlqjJHkYXDF1z36ZwzO4xZY= +github.com/charmbracelet/x/ansi v0.1.2/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= +github.com/charmbracelet/x/exp/strings v0.0.0-20240617190524-788ec55faed1 h1:VZIQzjwFE0EamzG2v8HfemeisB8X02Tl0BZBnJ0PeU8= +github.com/charmbracelet/x/exp/strings v0.0.0-20240617190524-788ec55faed1/go.mod h1:pBhA0ybfXv6hDjQUZ7hk1lVxBiUbupdw5R31yPUViVQ= +github.com/charmbracelet/x/exp/term v0.0.0-20240524151031-ff83003bf67a h1:k/s6UoOSVynWiw7PlclyGO2VdVs5ZLbMIHiGp4shFZE= +github.com/charmbracelet/x/exp/term v0.0.0-20240524151031-ff83003bf67a/go.mod h1:YBotIGhfoWhHDlnUpJMkjebGV2pdGRCn1Y4/Nk/vVcU= +github.com/charmbracelet/x/input v0.1.2 h1:QJAZr33eOhDowkkEQ24rsJy4Llxlm+fRDf/cQrmqJa0= +github.com/charmbracelet/x/input v0.1.2/go.mod h1:LGBim0maUY4Pitjn/4fHnuXb4KirU3DODsyuHuXdOyA= +github.com/charmbracelet/x/term v0.1.1 h1:3cosVAiPOig+EV4X9U+3LDgtwwAoEzJjNdwbXDjF6yI= +github.com/charmbracelet/x/term v0.1.1/go.mod h1:wB1fHt5ECsu3mXYusyzcngVWWlu1KKUmmLhfgr/Flxw= +github.com/charmbracelet/x/windows v0.1.2 h1:Iumiwq2G+BRmgoayww/qfcvof7W/3uLoelhxojXlRWg= +github.com/charmbracelet/x/windows v0.1.2/go.mod h1:GLEO/l+lizvFDBPLIOk+49gdX49L9YWMB5t+DZd0jkQ= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= +github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/cmd/van/internal/create/create.go b/cmd/van/internal/create/create.go index ef4f218..cb2e508 100644 --- a/cmd/van/internal/create/create.go +++ b/cmd/van/internal/create/create.go @@ -1 +1,26 @@ package create + +import ( + "github.com/spf13/cobra" + + "github.com/apus-run/van/cmd/van/internal/create/model" + "github.com/apus-run/van/cmd/van/internal/create/service" +) + +// Cmd represents the new command. +var Cmd = &cobra.Command{ + Use: "new", + Short: "Create new ", + Long: "Generate the new files.", + Run: func(cmd *cobra.Command, args []string) { + err := cmd.Help() + if err != nil { + return + } + }, +} + +func init() { + Cmd.AddCommand(model.Cmd) + Cmd.AddCommand(service.Cmd) +} diff --git a/cmd/van/internal/create/model/model.go b/cmd/van/internal/create/model/model.go new file mode 100644 index 0000000..8c5e333 --- /dev/null +++ b/cmd/van/internal/create/model/model.go @@ -0,0 +1,17 @@ +package model + +import ( + "github.com/spf13/cobra" +) + +// Cmd represents the new command. +var Cmd = &cobra.Command{ + Use: "new", + Short: "Create a model", + Long: "Create a model using the repository template. Example: van new helloworld", + Run: run, +} + +func run(cmd *cobra.Command, args []string) { + +} diff --git a/cmd/van/internal/create/service/service.go b/cmd/van/internal/create/service/service.go new file mode 100644 index 0000000..879953b --- /dev/null +++ b/cmd/van/internal/create/service/service.go @@ -0,0 +1,17 @@ +package service + +import ( + "github.com/spf13/cobra" +) + +// Cmd represents the new command. +var Cmd = &cobra.Command{ + Use: "new", + Short: "Create a service template", + Long: "Create a service project using the repository template. Example: van new helloworld", + Run: run, +} + +func run(cmd *cobra.Command, args []string) { + +} diff --git a/cmd/van/internal/new/new.go b/cmd/van/internal/new/new.go new file mode 100644 index 0000000..7d0e5cc --- /dev/null +++ b/cmd/van/internal/new/new.go @@ -0,0 +1,86 @@ +package new + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + "github.com/charmbracelet/huh" + "github.com/spf13/cobra" +) + +// Cmd represents the new command. +var Cmd = &cobra.Command{ + Use: "new", + Aliases: []string{"create", "init"}, + Example: "van new ", + Short: "Create a new project.", + Long: "create a new project with van layout.", + Run: run, +} + +var ( + repoURL string + timeout string +) + +func init() { + timeout = "60s" + Cmd.Flags().StringVarP(&repoURL, "repo-url", "r", repoURL, "layout repo") + Cmd.Flags().StringVarP(&timeout, "timeout", "t", timeout, "time out") +} + +func run(_ *cobra.Command, args []string) { + wd, err := os.Getwd() + if err != nil { + panic(err) + } + t, err := time.ParseDuration(timeout) + if err != nil { + panic(err) + } + ctx, cancel := context.WithTimeout(context.Background(), t) + defer cancel() + + name := "" + if len(args) == 0 { + err := huh.NewInput(). + Title("What is your project name?"). + Description("project name."). + Prompt("🚚"). + Value(&name). + Validate(func(name string) error { + if name == "" { + return errors.New("The project name cannot be empty!") + } + return nil + }).Run() + + if err != nil { + return + } + } else { + name = args[0] + } + p := NewProject(name) + done := make(chan error, 1) + go func() { + done <- p.New(ctx, wd, repoURL) + }() + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + fmt.Fprint(os.Stderr, "\033[31mERROR: project creation timed out\033[m\n") + } else { + fmt.Fprintf(os.Stderr, "\033[31mERROR: failed to create project(%s)\033[m\n", ctx.Err().Error()) + } + case err = <-done: + if err != nil { + fmt.Fprintf(os.Stderr, "\033[31mERROR: Failed to create project(%s)\033[m\n", err.Error()) + } + } + +} diff --git a/cmd/van/internal/new/project.go b/cmd/van/internal/new/project.go index 3031951..bffcd2b 100644 --- a/cmd/van/internal/new/project.go +++ b/cmd/van/internal/new/project.go @@ -1 +1,92 @@ package new + +import ( + "context" + "fmt" + "os" + "os/exec" + "path" + + "github.com/charmbracelet/huh" + + "github.com/apus-run/van/cmd/van/config" +) + +// Project is a project template. +type Project struct { + Name string +} + +func NewProject(name string) *Project { + return &Project{ + Name: name, + } +} + +func (p *Project) New(ctx context.Context, dir string, layout string) error { + to := path.Join(dir, p.Name) + if _, err := os.Stat(to); !os.IsNotExist(err) { + fmt.Printf("🚫 %s already exists\n", p.Name) + override := false + e := huh.NewConfirm(). + Title("📂 Do you want to override the folder ?"). + Description("Delete the existing folder and create the project."). + Affirmative("Yes!"). + Negative("No."). + Value(&override).Run() + if e != nil { + return e + } + if !override { + return err + } + e = os.RemoveAll(to) + if e != nil { + fmt.Println("remove old project error: ", err) + return e + } + } + + repo := "" + if layout == "" { + selected := "" + err := huh.NewSelect[string](). + Title("Please select a layout:"). + Options( + huh.NewOptions("Basic", "Advanced", "Multiple")..., + ).Value(&selected).Run() + if err != nil { + return err + } + + switch selected { + case "Basic": + repo = config.RepoBase + case "Advanced": + repo = config.RepoAdvanced + case "Multiple": + repo = config.RepoMultiple + default: + repo = config.RepoBase + } + + err = os.RemoveAll(p.Name) + if err != nil { + fmt.Println("remove old project error: ", err) + return err + } + } else { + repo = layout + } + + fmt.Printf("🚀 Creating service %s, layout repo is %s, please wait a moment.\n\n", p.Name, repo) + + fmt.Printf("git clone %s\n", repo) + cmd := exec.Command("git", "clone", repo, p.Name) + _, err := cmd.CombinedOutput() + if err != nil { + fmt.Printf("git clone %s error: %s\n", repo, err) + return err + } + return true, nil +} diff --git a/cmd/van/internal/rpc/rpc.go b/cmd/van/internal/rpc/rpc.go new file mode 100644 index 0000000..bb96f05 --- /dev/null +++ b/cmd/van/internal/rpc/rpc.go @@ -0,0 +1,16 @@ +package rpc + +import ( + "github.com/spf13/cobra" +) + +var Cmd = &cobra.Command{ + Use: "rpc", + Short: "Rpc project", + Long: "Rpc project. Example: van rpc", + Run: Run, +} + +func Run(cmd *cobra.Command, args []string) { + +} diff --git a/cmd/van/internal/run/run.go b/cmd/van/internal/run/run.go new file mode 100644 index 0000000..f43c2ba --- /dev/null +++ b/cmd/van/internal/run/run.go @@ -0,0 +1,16 @@ +package run + +import ( + "github.com/spf13/cobra" +) + +var Cmd = &cobra.Command{ + Use: "run", + Short: "Run project", + Long: "Run project. Example: van run", + Run: Run, +} + +func Run(cmd *cobra.Command, args []string) { + +} diff --git a/cmd/van/internal/upgrade/upgrade.go b/cmd/van/internal/upgrade/upgrade.go new file mode 100644 index 0000000..e303f10 --- /dev/null +++ b/cmd/van/internal/upgrade/upgrade.go @@ -0,0 +1,17 @@ +package upgrade + +import ( + "github.com/spf13/cobra" +) + +var Cmd = &cobra.Command{ + Use: "upgrade", + Short: "Upgrade the van tools", + Long: "Upgrade the van tools. Example: van upgrade", + Run: Run, +} + +// Run upgrade the van tools. +func Run(cmd *cobra.Command, args []string) { + +} diff --git a/cmd/van/main.go b/cmd/van/main.go new file mode 100644 index 0000000..c653f70 --- /dev/null +++ b/cmd/van/main.go @@ -0,0 +1,38 @@ +package main + +import ( + "fmt" + "log" + + "github.com/spf13/cobra" + + "github.com/apus-run/van/cmd/van/config" + "github.com/apus-run/van/cmd/van/internal/new" +) + +var RootCmd = &cobra.Command{ + Use: "van", + Example: "van [flags]", + Short: "Van CLI", + Long: "+-------------------------------------------+\n| █████ █████ |\n| ░░███ ░░███ |\n| ░███ ░███ ██████ ████████ |\n| ░███ ░███ ░░░░░███ ░░███░░███ |\n| ░░███ ███ ███████ ░███ ░███ |\n| ░░░█████░ ███░░███ ░███ ░███ |\n| ░░███ ░░████████ ████ █████ |\n| ░░░ ░░░░░░░░ ░░░░ ░░░░░ |\n+-------------------------------------------+\nVan: 一个轻量级的Golang应用搭建脚手架", + Version: fmt.Sprintf("\n__ __ \n\\ \\ / / __ _ _ __ \n \\ \\ / / / _` || '_ \\ \n \\ V / | (_| || | | |\n \\_/ \\__,_||_| |_|\n \nVan %s - Copyright (c) 2024-2026 Van\nReleased under the MIT License.\n\n", config.Release), + SilenceUsage: true, + Run: func(cmd *cobra.Command, args []string) { + err := cmd.Help() + if err != nil { + return + } + }, +} + +func init() { + // RootCmd.AddCommand(rpc.Cmd) + RootCmd.AddCommand(new.Cmd) + // RootCmd.AddCommand(run.Cmd) + // RootCmd.AddCommand(upgrade.Cmd) +} +func main() { + if err := RootCmd.Execute(); err != nil { + log.Fatal(err) + } +} diff --git a/db/gorm/wrapper.go b/db/gorm/wrapper.go index acebd94..da688d3 100644 --- a/db/gorm/wrapper.go +++ b/db/gorm/wrapper.go @@ -23,6 +23,12 @@ type Database interface { CloseDB(ctx context.Context, options ...Option) error } +// Transaction 事物接口 +type Transaction interface { + // Execute 执行一个事务方法,func为一个需要保证事务完整性的业务方法 + Execute(ctx context.Context, fn func(ctx context.Context) error) error +} + type Helper struct { lock *sync.RWMutex group *singleflight.Group diff --git a/db/redis/wrapper.go b/db/redis/wrapper.go index a6d2681..d3aa0e8 100644 --- a/db/redis/wrapper.go +++ b/db/redis/wrapper.go @@ -29,6 +29,12 @@ type Database interface { CloseDB(ctx context.Context, options ...Option) error } +// Transaction 事物接口 +type Transaction interface { + // Execute 执行一个事务方法,func为一个需要保证事务完整性的业务方法 + Execute(ctx context.Context, fn func(ctx context.Context) error) error +} + type Helper struct { lock *sync.RWMutex group *singleflight.Group diff --git a/db/sqlx/wrapper.go b/db/sqlx/wrapper.go index 0840b9e..5a23453 100644 --- a/db/sqlx/wrapper.go +++ b/db/sqlx/wrapper.go @@ -20,6 +20,12 @@ type Database interface { CloseDB(ctx context.Context, options ...Option) error } +// Transaction 事物接口 +type Transaction interface { + // Execute 执行一个事务方法,func为一个需要保证事务完整性的业务方法 + Execute(ctx context.Context, fn func(ctx context.Context) error) error +} + type Helper struct { lock *sync.RWMutex group *singleflight.Group diff --git a/errors/stack.go b/errors/stack.go deleted file mode 100644 index c73c3bf..0000000 --- a/errors/stack.go +++ /dev/null @@ -1,21 +0,0 @@ -package errors - -import ( - "bytes" - "fmt" - "runtime" -) - -// LogStack return call function stack info from start stack to end stack. -// if end is a positive number, return all call function stack. -func LogStack(start, end int) string { - stack := bytes.Buffer{} - for i := start; i < end || end <= 0; i++ { - pc, str, line, _ := runtime.Caller(i) - if line == 0 { - break - } - stack.WriteString(fmt.Sprintf("%s:%d %s\n", str, line, runtime.FuncForPC(pc).Name())) - } - return stack.String() -} diff --git a/errorsx/stack.go b/errorsx/stack.go new file mode 100644 index 0000000..0dd991f --- /dev/null +++ b/errorsx/stack.go @@ -0,0 +1,49 @@ +package errorsx + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/pkg/errors" +) + +// StackTrace handle the error carrying the stack +type StackTrace errors.StackTrace + +func (st StackTrace) MarshalJSON() ([]byte, error) { + var stacks []string + for _, frame := range st { + f, err := frame.MarshalText() + if err != nil { + return nil, err + } + stacks = append(stacks, string(f)) + } + return json.Marshal(stacks) +} + +func (st StackTrace) Format(s fmt.State, verb rune) { + _, err := io.WriteString(s, fmt.Sprintf("%+v", st)) + if err != nil { + return + } +} + +// ErrorStackTrace format error with stack +func ErrorStackTrace(err error) StackTrace { + if v, ok := err.(interface { + StackTrace() errors.StackTrace + }); ok { + return StackTrace(v.StackTrace()) + } + return nil +} + +// IsStackTrace check if error implements the stack +func IsStackTrace(err error) bool { + _, ok := err.(interface { + StackTrace() errors.StackTrace + }) + return ok +} diff --git a/errors/error.go b/errx/error.go similarity index 73% rename from errors/error.go rename to errx/error.go index 0832922..9dbfd04 100644 --- a/errors/error.go +++ b/errx/error.go @@ -1,8 +1,9 @@ -package errors +package errx import ( "bytes" "fmt" + "runtime" "strings" ) @@ -65,3 +66,17 @@ func (e *Error) Format(state fmt.State, verb rune) { fmt.Fprintf(state, e.Message) } } + +// LogStack return call function stack info from start stack to end stack. +// if end is a positive number, return all call function stack. +func LogStack(start, end int) string { + stack := bytes.Buffer{} + for i := start; i < end || end <= 0; i++ { + pc, str, line, _ := runtime.Caller(i) + if line == 0 { + break + } + stack.WriteString(fmt.Sprintf("%s:%d %s\n", str, line, runtime.FuncForPC(pc).Name())) + } + return stack.String() +} diff --git a/errors/error_test.go b/errx/error_test.go similarity index 95% rename from errors/error_test.go rename to errx/error_test.go index 977eefb..6fa37a1 100644 --- a/errors/error_test.go +++ b/errx/error_test.go @@ -1,4 +1,4 @@ -package errors +package errx import ( "fmt" diff --git a/errx/ignore.go b/errx/ignore.go new file mode 100644 index 0000000..6b9b25c --- /dev/null +++ b/errx/ignore.go @@ -0,0 +1,16 @@ +package errx + +/* +Ignore is a helper that wraps a call to a function returning value and error +and ignores if the error is non-nil. +*/ +func Ignore[T any](val T, err error) T { + return val +} + +/* +IgnoreRuntime is a runtime version of errorsx.Ignore(). +*/ +func IgnoreRuntime(val any, err error) any { + return val +} diff --git a/errx/must.go b/errx/must.go new file mode 100644 index 0000000..d454a63 --- /dev/null +++ b/errx/must.go @@ -0,0 +1,24 @@ +package errx + +/* +Must is a helper that wraps a call to a function returning value and error +and panics if the error is non-nil. +*/ +func Must[T any](val T, err error) T { + if err != nil { + panic(err) + } + + return val +} + +/* +MustRuntime is a runtime version of errorsx.Must(). +*/ +func MustRuntime(val any, err error) any { + if err != nil { + panic(err) + } + + return val +} diff --git a/errx/result.go b/errx/result.go new file mode 100644 index 0000000..fbda360 --- /dev/null +++ b/errx/result.go @@ -0,0 +1,64 @@ +package errx + +// Result is a generic type that wraps a value and an error. +// It also provides a set of methods to resolve/unwrap underlying value +// with a specific value/error handling. +type Result[T any] struct { + value T + error error +} + +// Simple value getters + +// Error getter +func (r Result[T]) Error() error { + return r.error +} + +// Value getter +func (r Result[T]) Value() T { + return r.value +} + +// Advanced value getters + +// ValueOr returns the value if no error is present, +// otherwise returns the provided default value +func (r Result[T]) ValueOr(def T) T { + if r.error != nil { + return def + } + + return r.value +} + +// Must returns the value if no error is present, +// otherwise panics with the error +func (r Result[T]) Must() T { + return Must(r.value, r.error) +} + +// Case handlers + +// Then executes the provided function if no error is present +func (r Result[T]) Then(fn func(T)) Result[T] { + if r.error == nil { + fn(r.value) + } + + return r +} + +// Catch executes the provided function if an error is present +func (r Result[T]) Catch(fn func(error)) Result[T] { + if r.error != nil { + fn(r.error) + } + + return r +} + +// Wrap wraps a value and an error into a Result +func Wrap[T any](val T, err error) Result[T] { + return Result[T]{value: val, error: err} +} diff --git a/errors/types.go b/errx/types.go similarity index 99% rename from errors/types.go rename to errx/types.go index 4ec7de6..4a4b296 100644 --- a/errors/types.go +++ b/errx/types.go @@ -1,4 +1,4 @@ -package errors +package errx // BadRequest new BadRequest error func BadRequest(reason string) *Error { diff --git a/go.mod b/go.mod index 46b26c1..51c8929 100644 --- a/go.mod +++ b/go.mod @@ -12,13 +12,17 @@ require ( github.com/fsnotify/fsnotify v1.7.0 github.com/gin-gonic/gin v1.10.0 github.com/go-sql-driver/mysql v1.8.1 + github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang/mock v1.6.0 github.com/google/uuid v1.6.0 github.com/iancoleman/strcase v0.3.0 github.com/jmoiron/sqlx v1.4.0 + github.com/lithammer/shortuuid/v4 v4.0.0 github.com/mattn/go-sqlite3 v1.14.22 + github.com/oklog/ulid v1.3.1 github.com/pkg/errors v0.9.1 github.com/redis/go-redis/v9 v9.5.4 + github.com/spf13/cast v1.6.0 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 @@ -86,7 +90,6 @@ require ( github.com/shopspring/decimal v1.4.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index 52d39ed..bef21b7 100644 --- a/go.sum +++ b/go.sum @@ -163,6 +163,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lithammer/shortuuid/v4 v4.0.0 h1:QRbbVkfgNippHOS8PXDkti4NaWeyYfcBTHtw7k08o4c= +github.com/lithammer/shortuuid/v4 v4.0.0/go.mod h1:Zs8puNcrvf2rV9rTH51ZLLcj7ZXqQI3lv67aw4KiB1Y= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -182,6 +184,8 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU= github.com/paulmach/orb v0.11.1/go.mod h1:5mULz1xQfs3bmQm63QEJA6lNGujuRafwA5S/EnuLaLU= github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKfDlhJCDw2gY= diff --git a/pkg/hook/hook.go b/pkg/hook/hook.go new file mode 100644 index 0000000..abc0949 --- /dev/null +++ b/pkg/hook/hook.go @@ -0,0 +1,126 @@ +package hook + +import ( + "errors" + "fmt" + "sync" + + "github.com/apus-run/van/pkg/rand" +) + +var StopPropagation = errors.New("Event hook propagation stopped") + +// Handler defines a hook handler function. +type Handler[T any] func(e T) error + +// handlerPair defines a pair of string id and Handler. +type handlerPair[T any] struct { + id string + handler Handler[T] +} + +// Hook defines a concurrent safe structure for handling event hooks +// (aka. callbacks propagation). +type Hook[T any] struct { + mux sync.RWMutex + handlers []*handlerPair[T] +} + +// PreAdd registers a new handler to the hook by prepending it to the existing queue. +// +// Returns an autogenerated hook id that could be used later to remove the hook with Hook.Remove(id). +func (h *Hook[T]) PreAdd(fn Handler[T]) string { + h.mux.Lock() + defer h.mux.Unlock() + + id := generateHookId() + + // minimize allocations by shifting the slice + h.handlers = append(h.handlers, nil) + copy(h.handlers[1:], h.handlers) + h.handlers[0] = &handlerPair[T]{id, fn} + + return id +} + +// Add registers a new handler to the hook by appending it to the existing queue. +// +// Returns an autogenerated hook id that could be used later to remove the hook with Hook.Remove(id). +func (h *Hook[T]) Add(fn Handler[T]) string { + h.mux.Lock() + defer h.mux.Unlock() + + id := generateHookId() + + h.handlers = append(h.handlers, &handlerPair[T]{id, fn}) + + return id +} + +// Remove removes a single hook handler by its id. +func (h *Hook[T]) Remove(id string) { + h.mux.Lock() + defer h.mux.Unlock() + + for i := len(h.handlers) - 1; i >= 0; i-- { + if h.handlers[i].id == id { + h.handlers = append(h.handlers[:i], h.handlers[i+1:]...) + return + } + } +} + +// RemoveAll removes all registered handlers. +func (h *Hook[T]) RemoveAll() { + h.mux.Lock() + defer h.mux.Unlock() + + h.handlers = nil +} + +// Trigger executes all registered hook handlers one by one +// with the specified `data` as an argument. +// +// Optionally, this method allows also to register additional one off +// handlers that will be temporary appended to the handlers queue. +// +// The execution stops when: +// - hook.StopPropagation is returned in one of the handlers +// - any non-nil error is returned in one of the handlers +func (h *Hook[T]) Trigger(data T, oneOffHandlers ...Handler[T]) error { + h.mux.RLock() + + handlers := make([]*handlerPair[T], 0, len(h.handlers)+len(oneOffHandlers)) + handlers = append(handlers, h.handlers...) + + // append the one off handlers + for i, oneOff := range oneOffHandlers { + handlers = append(handlers, &handlerPair[T]{ + id: fmt.Sprintf("@%d", i), + handler: oneOff, + }) + } + + // unlock is not deferred to avoid deadlocks in case Trigger + // is called recursively by the handlers + h.mux.RUnlock() + + for _, item := range handlers { + err := item.handler(data) + if err == nil { + continue + } + + if errors.Is(err, StopPropagation) { + return nil + } + + return err + } + + return nil +} + +func generateHookId() string { + return rand.PseudorandomString(8) +} diff --git a/pkg/hook/hook_test.go b/pkg/hook/hook_test.go new file mode 100644 index 0000000..a3ab688 --- /dev/null +++ b/pkg/hook/hook_test.go @@ -0,0 +1,179 @@ +package hook + +import ( + "errors" + "testing" +) + +func TestHookAddAndPreAdd(t *testing.T) { + h := Hook[int]{} + + if total := len(h.handlers); total != 0 { + t.Fatalf("Expected no handlers, found %d", total) + } + + triggerSequence := "" + + f1 := func(data int) error { triggerSequence += "f1"; return nil } + f2 := func(data int) error { triggerSequence += "f2"; return nil } + f3 := func(data int) error { triggerSequence += "f3"; return nil } + f4 := func(data int) error { triggerSequence += "f4"; return nil } + + h.Add(f1) + h.Add(f2) + h.PreAdd(f3) + h.PreAdd(f4) + h.Trigger(1) + + if total := len(h.handlers); total != 4 { + t.Fatalf("Expected %d handlers, found %d", 4, total) + } + + expectedTriggerSequence := "f4f3f1f2" + + if triggerSequence != expectedTriggerSequence { + t.Fatalf("Expected trigger sequence %s, got %s", expectedTriggerSequence, triggerSequence) + } +} + +func TestHookRemove(t *testing.T) { + h := Hook[int]{} + + h1Called := false + h2Called := false + + id1 := h.Add(func(data int) error { h1Called = true; return nil }) + h.Add(func(data int) error { h2Called = true; return nil }) + + h.Remove("missing") // should do nothing and not panic + + if total := len(h.handlers); total != 2 { + t.Fatalf("Expected %d handlers, got %d", 2, total) + } + + h.Remove(id1) + + if total := len(h.handlers); total != 1 { + t.Fatalf("Expected %d handlers, got %d", 1, total) + } + + if err := h.Trigger(1); err != nil { + t.Fatal(err) + } + + if h1Called { + t.Fatalf("Expected hook 1 to be removed and not called") + } + + if !h2Called { + t.Fatalf("Expected hook 2 to be called") + } +} + +func TestHookRemoveAll(t *testing.T) { + h := Hook[int]{} + + h.RemoveAll() // should do nothing and not panic + + h.Add(func(data int) error { return nil }) + h.Add(func(data int) error { return nil }) + + if total := len(h.handlers); total != 2 { + t.Fatalf("Expected 2 handlers before RemoveAll, found %d", total) + } + + h.RemoveAll() + + if total := len(h.handlers); total != 0 { + t.Fatalf("Expected no handlers after RemoveAll, found %d", total) + } +} + +func TestHookTrigger(t *testing.T) { + err1 := errors.New("demo") + err2 := errors.New("demo") + + scenarios := []struct { + handlers []Handler[int] + expectedError error + }{ + { + []Handler[int]{ + func(data int) error { return nil }, + func(data int) error { return nil }, + }, + nil, + }, + { + []Handler[int]{ + func(data int) error { return nil }, + func(data int) error { return err1 }, + func(data int) error { return err2 }, + }, + err1, + }, + } + + for i, scenario := range scenarios { + h := Hook[int]{} + for _, handler := range scenario.handlers { + h.Add(handler) + } + result := h.Trigger(1) + if result != scenario.expectedError { + t.Fatalf("(%d) Expected %v, got %v", i, scenario.expectedError, result) + } + } +} + +func TestHookTriggerStopPropagation(t *testing.T) { + called1 := false + f1 := func(data int) error { called1 = true; return nil } + + called2 := false + f2 := func(data int) error { called2 = true; return nil } + + called3 := false + f3 := func(data int) error { called3 = true; return nil } + + called4 := false + f4 := func(data int) error { called4 = true; return StopPropagation } + + called5 := false + f5 := func(data int) error { called5 = true; return nil } + + called6 := false + f6 := func(data int) error { called6 = true; return nil } + + h := Hook[int]{} + h.Add(f1) + h.Add(f2) + + result := h.Trigger(123, f3, f4, f5, f6) + + if result != nil { + t.Fatalf("Expected nil after StopPropagation, got %v", result) + } + + // ensure that the trigger handler were not persisted + if total := len(h.handlers); total != 2 { + t.Fatalf("Expected 2 handlers, found %d", total) + } + + scenarios := []struct { + called bool + expected bool + }{ + {called1, true}, + {called2, true}, + {called3, true}, + {called4, true}, // StopPropagation + {called5, false}, + {called6, false}, + } + for i, scenario := range scenarios { + if scenario.called != scenario.expected { + t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, scenario.called) + } + } +} diff --git a/pkg/inflector/inflector.go b/pkg/inflector/inflector.go new file mode 100644 index 0000000..3fa1ab1 --- /dev/null +++ b/pkg/inflector/inflector.go @@ -0,0 +1,85 @@ +package inflector + +import ( + "regexp" + "strings" + "unicode" +) + +var columnifyRemoveRegex = regexp.MustCompile(`[^\w\.\*\-\_\@\#]+`) +var snakecaseSplitRegex = regexp.MustCompile(`[\W_]+`) + +// UcFirst converts the first character of a string into uppercase. +func UcFirst(str string) string { + if str == "" { + return "" + } + + s := []rune(str) + + return string(unicode.ToUpper(s[0])) + string(s[1:]) +} + +// Columnify strips invalid db identifier characters. +func Columnify(str string) string { + return columnifyRemoveRegex.ReplaceAllString(str, "") +} + +// Sentenize converts and normalizes string into a sentence. +func Sentenize(str string) string { + str = strings.TrimSpace(str) + if str == "" { + return "" + } + + str = UcFirst(str) + + lastChar := str[len(str)-1:] + if lastChar != "." && lastChar != "?" && lastChar != "!" { + return str + "." + } + + return str +} + +// Sanitize sanitizes `str` by removing all characters satisfying `removePattern`. +// Returns an error if the pattern is not valid regex string. +func Sanitize(str string, removePattern string) (string, error) { + exp, err := regexp.Compile(removePattern) + if err != nil { + return "", err + } + + return exp.ReplaceAllString(str, ""), nil +} + +// Snakecase removes all non word characters and converts any english text into a snakecase. +// "ABBREVIATIONS" are preserved, eg. "myTestDB" will become "my_test_db". +func Snakecase(str string) string { + var result strings.Builder + + // split at any non word character and underscore + words := snakecaseSplitRegex.Split(str, -1) + + for _, word := range words { + if word == "" { + continue + } + + if result.Len() > 0 { + result.WriteString("_") + } + + for i, c := range word { + if unicode.IsUpper(c) && i > 0 && + // is not a following uppercase character + !unicode.IsUpper(rune(word[i-1])) { + result.WriteString("_") + } + + result.WriteRune(c) + } + } + + return strings.ToLower(result.String()) +} diff --git a/pkg/inflector/inflector_test.go b/pkg/inflector/inflector_test.go new file mode 100644 index 0000000..80a7ae2 --- /dev/null +++ b/pkg/inflector/inflector_test.go @@ -0,0 +1,134 @@ +package inflector_test + +import ( + "testing" + + "github.com/apus-run/van/pkg/inflector" +) + +func TestUcFirst(t *testing.T) { + scenarios := []struct { + val string + expected string + }{ + {"", ""}, + {" ", " "}, + {"Test", "Test"}, + {"test", "Test"}, + {"test test2", "Test test2"}, + } + + for i, scenario := range scenarios { + if result := inflector.UcFirst(scenario.val); result != scenario.expected { + t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result) + } + } +} + +func TestColumnify(t *testing.T) { + scenarios := []struct { + val string + expected string + }{ + {"", ""}, + {" ", ""}, + {"123", "123"}, + {"Test.", "Test."}, + {" test ", "test"}, + {"test1.test2", "test1.test2"}, + {"@test!abc", "@testabc"}, + {"#test?abc", "#testabc"}, + {"123test(123)#", "123test123#"}, + {"test1--test2", "test1--test2"}, + } + + for i, scenario := range scenarios { + if result := inflector.Columnify(scenario.val); result != scenario.expected { + t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result) + } + } +} + +func TestSentenize(t *testing.T) { + scenarios := []struct { + val string + expected string + }{ + {"", ""}, + {" ", ""}, + {".", "."}, + {"?", "?"}, + {"!", "!"}, + {"Test", "Test."}, + {" test ", "Test."}, + {"hello world", "Hello world."}, + {"hello world.", "Hello world."}, + {"hello world!", "Hello world!"}, + {"hello world?", "Hello world?"}, + } + + for i, scenario := range scenarios { + if result := inflector.Sentenize(scenario.val); result != scenario.expected { + t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result) + } + } +} + +func TestSanitize(t *testing.T) { + scenarios := []struct { + val string + pattern string + expected string + expectErr bool + }{ + {"", ``, "", false}, + {" ", ``, " ", false}, + {" ", ` `, "", false}, + {"", `[A-Z]`, "", false}, + {"abcABC", `[A-Z]`, "abc", false}, + {"abcABC", `[A-Z`, "", true}, // invalid pattern + } + + for i, scenario := range scenarios { + result, err := inflector.Sanitize(scenario.val, scenario.pattern) + hasErr := err != nil + + if scenario.expectErr != hasErr { + if scenario.expectErr { + t.Errorf("(%d) Expected error, got nil", i) + } else { + t.Errorf("(%d) Didn't expect error, got", err) + } + } + + if result != scenario.expected { + t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result) + } + } +} + +func TestSnakecase(t *testing.T) { + scenarios := []struct { + val string + expected string + }{ + {"", ""}, + {" ", ""}, + {"!@#$%^", ""}, + {"...", ""}, + {"_", ""}, + {"John Doe", "john_doe"}, + {"John_Doe", "john_doe"}, + {".a!b@c#d$e%123. ", "a_b_c_d_e_123"}, + {"HelloWorld", "hello_world"}, + {"HelloWorld1HelloWorld2", "hello_world1_hello_world2"}, + {"TEST", "test"}, + {"testABR", "test_abr"}, + } + + for i, scenario := range scenarios { + if result := inflector.Snakecase(scenario.val); result != scenario.expected { + t.Errorf("(%d) Expected %q, got %q", i, scenario.expected, result) + } + } +} diff --git a/pkg/ptr/ptr_test.go b/pkg/ptr/ptr_test.go index 69a4ef0..46b4c4c 100644 --- a/pkg/ptr/ptr_test.go +++ b/pkg/ptr/ptr_test.go @@ -8,6 +8,23 @@ import ( "github.com/stretchr/testify/assert" ) +func TestPointer(t *testing.T) { + s1 := ToPtr("") + if s1 == nil || *s1 != "" { + t.Fatalf("Expected empty string pointer, got %#v", s1) + } + + s2 := ToPtr("test") + if s2 == nil || *s2 != "test" { + t.Fatalf("Expected 'test' string pointer, got %#v", s2) + } + + s3 := ToPtr(123) + if s3 == nil || *s3 != 123 { + t.Fatalf("Expected 123 string pointer, got %#v", s3) + } +} + func TestToPtr(t *testing.T) { t.Run("int", func(t *testing.T) { i := 1 diff --git a/pkg/rand/rand.go b/pkg/rand/rand.go new file mode 100644 index 0000000..3fc0099 --- /dev/null +++ b/pkg/rand/rand.go @@ -0,0 +1,64 @@ +package rand + +import ( + cryptoRand "crypto/rand" + "math/big" + mathRand "math/rand" + "time" +) + +const defaultRandomAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + +func init() { + mathRand.New(mathRand.NewSource(time.Now().UnixNano())) +} + +// RandomString generates a cryptographically random string with the specified length. +// +// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding. +func RandomString(length int) string { + return RandomStringWithAlphabet(length, defaultRandomAlphabet) +} + +// RandomStringWithAlphabet generates a cryptographically random string +// with the specified length and characters set. +// +// It panics if for some reason rand.Int returns a non-nil error. +func RandomStringWithAlphabet(length int, alphabet string) string { + b := make([]byte, length) + max := big.NewInt(int64(len(alphabet))) + + for i := range b { + n, err := cryptoRand.Int(cryptoRand.Reader, max) + if err != nil { + panic(err) + } + b[i] = alphabet[n.Int64()] + } + + return string(b) +} + +// PseudorandomString generates a pseudorandom string with the specified length. +// +// The generated string matches [A-Za-z0-9]+ and it's transparent to URL-encoding. +// +// For a cryptographically random string (but a little bit slower) use RandomString instead. +func PseudorandomString(length int) string { + return PseudorandomStringWithAlphabet(length, defaultRandomAlphabet) +} + +// PseudorandomStringWithAlphabet generates a pseudorandom string +// with the specified length and characters set. +// +// For a cryptographically random (but a little bit slower) use RandomStringWithAlphabet instead. +func PseudorandomStringWithAlphabet(length int, alphabet string) string { + b := make([]byte, length) + max := len(alphabet) + + for i := range b { + b[i] = alphabet[mathRand.Intn(max)] + } + + return string(b) +} diff --git a/pkg/rand/rand_test.go b/pkg/rand/rand_test.go new file mode 100644 index 0000000..240c545 --- /dev/null +++ b/pkg/rand/rand_test.go @@ -0,0 +1,89 @@ +package rand_test + +import ( + "regexp" + "testing" + + rand "github.com/apus-run/van/pkg/rand" +) + +func TestRandomString(t *testing.T) { + testRandomString(t, rand.RandomString) +} + +func TestRandomStringWithAlphabet(t *testing.T) { + testRandomStringWithAlphabet(t, rand.RandomStringWithAlphabet) +} + +func TestPseudorandomString(t *testing.T) { + testRandomString(t, rand.PseudorandomString) +} + +func TestPseudorandomStringWithAlphabet(t *testing.T) { + testRandomStringWithAlphabet(t, rand.PseudorandomStringWithAlphabet) +} + +// ------------------------------------------------------------------- + +func testRandomStringWithAlphabet(t *testing.T, randomFunc func(n int, alphabet string) string) { + scenarios := []struct { + alphabet string + expectPattern string + }{ + {"0123456789_", `[0-9_]+`}, + {"abcdef123", `[abcdef123]+`}, + {"!@#$%^&*()", `[\!\@\#\$\%\^\&\*\(\)]+`}, + } + + for i, s := range scenarios { + generated := make([]string, 0, 1000) + length := 10 + + for j := 0; j < 1000; j++ { + result := randomFunc(length, s.alphabet) + + if len(result) != length { + t.Fatalf("(%d:%d) Expected the length of the string to be %d, got %d", i, j, length, len(result)) + } + + reg := regexp.MustCompile(s.expectPattern) + if match := reg.MatchString(result); !match { + t.Fatalf("(%d:%d) The generated string should have only %s characters, got %q", i, j, s.expectPattern, result) + } + + for _, str := range generated { + if str == result { + t.Fatalf("(%d:%d) Repeating random string - found %q in %q", i, j, result, generated) + } + } + + generated = append(generated, result) + } + } +} + +func testRandomString(t *testing.T, randomFunc func(n int) string) { + generated := make([]string, 0, 1000) + reg := regexp.MustCompile(`[a-zA-Z0-9]+`) + length := 10 + + for i := 0; i < 1000; i++ { + result := randomFunc(length) + + if len(result) != length { + t.Fatalf("(%d) Expected the length of the string to be %d, got %d", i, length, len(result)) + } + + if match := reg.MatchString(result); !match { + t.Fatalf("(%d) The generated string should have only [a-zA-Z0-9]+ characters, got %q", i, result) + } + + for _, str := range generated { + if str == result { + t.Fatalf("(%d) Repeating random string - found %q in \n%v", i, result, generated) + } + } + + generated = append(generated, result) + } +} diff --git a/pkg/types/datetime.go b/pkg/types/datetime.go new file mode 100644 index 0000000..5db96fc --- /dev/null +++ b/pkg/types/datetime.go @@ -0,0 +1,105 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "time" + + "github.com/spf13/cast" +) + +// DefaultDateLayout specifies the default app date strings layout. +const DefaultDateLayout = "2006-01-02 15:04:05.000Z" + +// NowDateTime returns new DateTime instance with the current local time. +func NowDateTime() DateTime { + return DateTime{t: time.Now()} +} + +// ParseDateTime creates a new DateTime from the provided value +// (could be [cast.ToTime] supported string, [time.Time], etc.). +func ParseDateTime(value any) (DateTime, error) { + d := DateTime{} + err := d.Scan(value) + return d, err +} + +// DateTime represents a [time.Time] instance in UTC that is wrapped +// and serialized using the app default date layout. +type DateTime struct { + t time.Time +} + +// Time returns the internal [time.Time] instance. +func (d DateTime) Time() time.Time { + return d.t +} + +// IsZero checks whether the current DateTime instance has zero time value. +func (d DateTime) IsZero() bool { + return d.Time().IsZero() +} + +// String serializes the current DateTime instance into a formatted +// UTC date string. +// +// The zero value is serialized to an empty string. +func (d DateTime) String() string { + t := d.Time() + if t.IsZero() { + return "" + } + return t.UTC().Format(DefaultDateLayout) +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (d DateTime) MarshalJSON() ([]byte, error) { + return []byte(`"` + d.String() + `"`), nil +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (d *DateTime) UnmarshalJSON(b []byte) error { + var raw string + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + return d.Scan(raw) +} + +// Value implements the [driver.Valuer] interface. +func (d DateTime) Value() (driver.Value, error) { + return d.String(), nil +} + +// Scan implements [sql.Scanner] interface to scan the provided value +// into the current DateTime instance. +func (d *DateTime) Scan(value any) error { + switch v := value.(type) { + case time.Time: + d.t = v + case DateTime: + d.t = v.Time() + case string: + if v == "" { + d.t = time.Time{} + } else { + t, err := time.Parse(DefaultDateLayout, v) + if err != nil { + // check for other common date layouts + t = cast.ToTime(v) + } + d.t = t + } + case int, int64, int32, uint, uint64, uint32: + d.t = cast.ToTime(v) + default: + str := cast.ToString(v) + if str == "" { + d.t = time.Time{} + } else { + d.t = cast.ToTime(str) + } + } + + return nil +} diff --git a/pkg/types/datetime_test.go b/pkg/types/datetime_test.go new file mode 100644 index 0000000..02bc220 --- /dev/null +++ b/pkg/types/datetime_test.go @@ -0,0 +1,205 @@ +package types_test + +import ( + "strings" + "testing" + "time" + + "github.com/apus-run/van/pkg/types" +) + +func TestNowDateTime(t *testing.T) { + now := time.Now().UTC().Format("2006-01-02 15:04:05") // without ms part for test consistency + dt := types.NowDateTime() + + if !strings.Contains(dt.String(), now) { + t.Fatalf("Expected %q, got %q", now, dt.String()) + } +} + +func TestParseDateTime(t *testing.T) { + nowTime := time.Now().UTC() + nowDateTime, _ := types.ParseDateTime(nowTime) + nowStr := nowTime.Format(types.DefaultDateLayout) + + scenarios := []struct { + value any + expected string + }{ + {nil, ""}, + {"", ""}, + {"invalid", ""}, + {nowDateTime, nowStr}, + {nowTime, nowStr}, + {1641024040, "2022-01-01 08:00:40.000Z"}, + {int32(1641024040), "2022-01-01 08:00:40.000Z"}, + {int64(1641024040), "2022-01-01 08:00:40.000Z"}, + {uint(1641024040), "2022-01-01 08:00:40.000Z"}, + {uint64(1641024040), "2022-01-01 08:00:40.000Z"}, + {uint32(1641024040), "2022-01-01 08:00:40.000Z"}, + {"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"}, + } + + for i, s := range scenarios { + dt, err := types.ParseDateTime(s.value) + if err != nil { + t.Errorf("(%d) Failed to parse %v: %v", i, s.value, err) + continue + } + + if dt.String() != s.expected { + t.Errorf("(%d) Expected %q, got %q", i, s.expected, dt.String()) + } + } +} + +func TestDateTimeTime(t *testing.T) { + str := "2022-01-01 11:23:45.678Z" + + expected, err := time.Parse(types.DefaultDateLayout, str) + if err != nil { + t.Fatal(err) + } + + dt, err := types.ParseDateTime(str) + if err != nil { + t.Fatal(err) + } + + result := dt.Time() + + if !expected.Equal(result) { + t.Errorf("Expected time %v, got %v", expected, result) + } +} + +func TestDateTimeIsZero(t *testing.T) { + dt0 := types.DateTime{} + if !dt0.IsZero() { + t.Fatalf("Expected zero datatime, got %v", dt0) + } + + dt1 := types.NowDateTime() + if dt1.IsZero() { + t.Fatalf("Expected non-zero datatime, got %v", dt1) + } +} + +func TestDateTimeString(t *testing.T) { + dt0 := types.DateTime{} + if dt0.String() != "" { + t.Fatalf("Expected empty string for zer datetime, got %q", dt0.String()) + } + + expected := "2022-01-01 11:23:45.678Z" + dt1, _ := types.ParseDateTime(expected) + if dt1.String() != expected { + t.Fatalf("Expected %q, got %v", expected, dt1) + } +} + +func TestDateTimeMarshalJSON(t *testing.T) { + scenarios := []struct { + date string + expected string + }{ + {"", `""`}, + {"2022-01-01 11:23:45.678", `"2022-01-01 11:23:45.678Z"`}, + } + + for i, s := range scenarios { + dt, err := types.ParseDateTime(s.date) + if err != nil { + t.Errorf("(%d) %v", i, err) + } + + result, err := dt.MarshalJSON() + if err != nil { + t.Errorf("(%d) %v", i, err) + } + + if string(result) != s.expected { + t.Errorf("(%d) Expected %q, got %q", i, s.expected, string(result)) + } + } +} + +func TestDateTimeUnmarshalJSON(t *testing.T) { + scenarios := []struct { + date string + expected string + }{ + {"", ""}, + {"invalid_json", ""}, + {"'123'", ""}, + {"2022-01-01 11:23:45.678", ""}, + {`"2022-01-01 11:23:45.678"`, "2022-01-01 11:23:45.678Z"}, + } + + for i, s := range scenarios { + dt := types.DateTime{} + dt.UnmarshalJSON([]byte(s.date)) + + if dt.String() != s.expected { + t.Errorf("(%d) Expected %q, got %q", i, s.expected, dt.String()) + } + } +} + +func TestDateTimeValue(t *testing.T) { + scenarios := []struct { + value any + expected string + }{ + {"", ""}, + {"invalid", ""}, + {1641024040, "2022-01-01 08:00:40.000Z"}, + {"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"}, + {types.NowDateTime(), types.NowDateTime().String()}, + } + + for i, s := range scenarios { + dt, _ := types.ParseDateTime(s.value) + result, err := dt.Value() + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + + if result != s.expected { + t.Errorf("(%d) Expected %q, got %q", i, s.expected, result) + } + } +} + +func TestDateTimeScan(t *testing.T) { + now := time.Now().UTC().Format("2006-01-02 15:04:05") // without ms part for test consistency + + scenarios := []struct { + value any + expected string + }{ + {nil, ""}, + {"", ""}, + {"invalid", ""}, + {types.NowDateTime(), now}, + {time.Now(), now}, + {1.0, ""}, + {1641024040, "2022-01-01 08:00:40.000Z"}, + {"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"}, + } + + for i, s := range scenarios { + dt := types.DateTime{} + + err := dt.Scan(s.value) + if err != nil { + t.Errorf("(%d) Failed to parse %v: %v", i, s.value, err) + continue + } + + if !strings.Contains(dt.String(), s.expected) { + t.Errorf("(%d) Expected %q, got %q", i, s.expected, dt.String()) + } + } +} diff --git a/pkg/types/json_array.go b/pkg/types/json_array.go new file mode 100644 index 0000000..b06f116 --- /dev/null +++ b/pkg/types/json_array.go @@ -0,0 +1,52 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +// JsonArray defines a slice that is safe for json and db read/write. +type JsonArray[T any] []T + +// internal alias to prevent recursion during marshalization. +type jsonArrayAlias[T any] JsonArray[T] + +// MarshalJSON implements the [json.Marshaler] interface. +func (m JsonArray[T]) MarshalJSON() ([]byte, error) { + // initialize an empty map to ensure that `[]` is returned as json + if m == nil { + m = JsonArray[T]{} + } + + return json.Marshal(jsonArrayAlias[T](m)) +} + +// Value implements the [driver.Valuer] interface. +func (m JsonArray[T]) Value() (driver.Value, error) { + data, err := json.Marshal(m) + + return string(data), err +} + +// Scan implements [sql.Scanner] interface to scan the provided value +// into the current JsonArray[T] instance. +func (m *JsonArray[T]) Scan(value any) error { + var data []byte + switch v := value.(type) { + case nil: + // no cast needed + case []byte: + data = v + case string: + data = []byte(v) + default: + return fmt.Errorf("failed to unmarshal JsonArray value: %q", value) + } + + if len(data) == 0 { + data = []byte("[]") + } + + return json.Unmarshal(data, m) +} diff --git a/pkg/types/json_array_test.go b/pkg/types/json_array_test.go new file mode 100644 index 0000000..8168326 --- /dev/null +++ b/pkg/types/json_array_test.go @@ -0,0 +1,96 @@ +package types_test + +import ( + "database/sql/driver" + "encoding/json" + "testing" + + "github.com/apus-run/van/pkg/types" +) + +func TestJsonArrayMarshalJSON(t *testing.T) { + scenarios := []struct { + json json.Marshaler + expected string + }{ + {new(types.JsonArray[any]), "[]"}, + {types.JsonArray[any]{}, `[]`}, + {types.JsonArray[int]{1, 2, 3}, `[1,2,3]`}, + {types.JsonArray[string]{"test1", "test2", "test3"}, `["test1","test2","test3"]`}, + {types.JsonArray[any]{1, "test"}, `[1,"test"]`}, + } + + for i, s := range scenarios { + result, err := s.json.MarshalJSON() + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + if string(result) != s.expected { + t.Errorf("(%d) Expected %s, got %s", i, s.expected, string(result)) + } + } +} + +func TestJsonArrayValue(t *testing.T) { + scenarios := []struct { + json driver.Valuer + expected driver.Value + }{ + {new(types.JsonArray[any]), `[]`}, + {types.JsonArray[any]{}, `[]`}, + {types.JsonArray[int]{1, 2, 3}, `[1,2,3]`}, + {types.JsonArray[string]{"test1", "test2", "test3"}, `["test1","test2","test3"]`}, + {types.JsonArray[any]{1, "test"}, `[1,"test"]`}, + } + + for i, s := range scenarios { + result, err := s.json.Value() + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + if result != s.expected { + t.Errorf("(%d) Expected %s, got %v", i, s.expected, result) + } + } +} + +func TestJsonArrayScan(t *testing.T) { + scenarios := []struct { + value any + expectError bool + expectJson string + }{ + {``, false, `[]`}, + {[]byte{}, false, `[]`}, + {nil, false, `[]`}, + {123, true, `[]`}, + {`""`, true, `[]`}, + {`invalid_json`, true, `[]`}, + {`"test"`, true, `[]`}, + {`1,2,3`, true, `[]`}, + {`[1, 2, 3`, true, `[]`}, + {`[1, 2, 3]`, false, `[1,2,3]`}, + {[]byte(`[1, 2, 3]`), false, `[1,2,3]`}, + {`[1, "test"]`, false, `[1,"test"]`}, + {`[]`, false, `[]`}, + } + + for i, s := range scenarios { + arr := types.JsonArray[any]{} + scanErr := arr.Scan(s.value) + + hasErr := scanErr != nil + if hasErr != s.expectError { + t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, scanErr) + continue + } + + result, _ := arr.MarshalJSON() + + if string(result) != s.expectJson { + t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result)) + } + } +} diff --git a/pkg/types/json_map.go b/pkg/types/json_map.go new file mode 100644 index 0000000..a94ad22 --- /dev/null +++ b/pkg/types/json_map.go @@ -0,0 +1,67 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +// JsonMap defines a map that is safe for json and db read/write. +type JsonMap map[string]any + +// MarshalJSON implements the [json.Marshaler] interface. +func (m JsonMap) MarshalJSON() ([]byte, error) { + type alias JsonMap // prevent recursion + + // initialize an empty map to ensure that `{}` is returned as json + if m == nil { + m = JsonMap{} + } + + return json.Marshal(alias(m)) +} + +// Get retrieves a single value from the current JsonMap. +// +// This helper was added primarily to assist the goja integration since custom map types +// don't have direct access to the map keys (https://pkg.go.dev/github.com/dop251/goja#hdr-Maps_with_methods). +func (m JsonMap) Get(key string) any { + return m[key] +} + +// Set sets a single value in the current JsonMap. +// +// This helper was added primarily to assist the goja integration since custom map types +// don't have direct access to the map keys (https://pkg.go.dev/github.com/dop251/goja#hdr-Maps_with_methods). +func (m JsonMap) Set(key string, value any) { + m[key] = value +} + +// Value implements the [driver.Valuer] interface. +func (m JsonMap) Value() (driver.Value, error) { + data, err := json.Marshal(m) + + return string(data), err +} + +// Scan implements [sql.Scanner] interface to scan the provided value +// into the current `JsonMap` instance. +func (m *JsonMap) Scan(value any) error { + var data []byte + switch v := value.(type) { + case nil: + // no cast needed + case []byte: + data = v + case string: + data = []byte(v) + default: + return fmt.Errorf("failed to unmarshal JsonMap value: %q", value) + } + + if len(data) == 0 { + data = []byte("{}") + } + + return json.Unmarshal(data, m) +} diff --git a/pkg/types/json_map_test.go b/pkg/types/json_map_test.go new file mode 100644 index 0000000..a21ad70 --- /dev/null +++ b/pkg/types/json_map_test.go @@ -0,0 +1,132 @@ +package types_test + +import ( + "database/sql/driver" + "testing" + + "github.com/apus-run/van/pkg/types" +) + +func TestJsonMapMarshalJSON(t *testing.T) { + scenarios := []struct { + json types.JsonMap + expected string + }{ + {nil, "{}"}, + {types.JsonMap{}, `{}`}, + {types.JsonMap{"test1": 123, "test2": "lorem"}, `{"test1":123,"test2":"lorem"}`}, + {types.JsonMap{"test": []int{1, 2, 3}}, `{"test":[1,2,3]}`}, + } + + for i, s := range scenarios { + result, err := s.json.MarshalJSON() + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + if string(result) != s.expected { + t.Errorf("(%d) Expected %s, got %s", i, s.expected, string(result)) + } + } +} + +func TestJsonMapGet(t *testing.T) { + scenarios := []struct { + json types.JsonMap + key string + expected any + }{ + {nil, "test", nil}, + {types.JsonMap{"test": 123}, "test", 123}, + {types.JsonMap{"test": 123}, "missing", nil}, + } + + for i, s := range scenarios { + result := s.json.Get(s.key) + if result != s.expected { + t.Errorf("(%d) Expected %s, got %v", i, s.expected, result) + } + } +} + +func TestJsonMapSet(t *testing.T) { + scenarios := []struct { + key string + value any + }{ + {"a", nil}, + {"a", 123}, + {"b", "test"}, + } + + for i, s := range scenarios { + j := types.JsonMap{} + + j.Set(s.key, s.value) + + if v := j[s.key]; v != s.value { + t.Errorf("(%d) Expected %s, got %v", i, s.value, v) + } + } +} + +func TestJsonMapValue(t *testing.T) { + scenarios := []struct { + json types.JsonMap + expected driver.Value + }{ + {nil, `{}`}, + {types.JsonMap{}, `{}`}, + {types.JsonMap{"test1": 123, "test2": "lorem"}, `{"test1":123,"test2":"lorem"}`}, + {types.JsonMap{"test": []int{1, 2, 3}}, `{"test":[1,2,3]}`}, + } + + for i, s := range scenarios { + result, err := s.json.Value() + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + if result != s.expected { + t.Errorf("(%d) Expected %s, got %v", i, s.expected, result) + } + } +} + +func TestJsonArrayMapScan(t *testing.T) { + scenarios := []struct { + value any + expectError bool + expectJson string + }{ + {``, false, `{}`}, + {nil, false, `{}`}, + {[]byte{}, false, `{}`}, + {`{}`, false, `{}`}, + {123, true, `{}`}, + {`""`, true, `{}`}, + {`invalid_json`, true, `{}`}, + {`"test"`, true, `{}`}, + {`1,2,3`, true, `{}`}, + {`{"test": 1`, true, `{}`}, + {`{"test": 1}`, false, `{"test":1}`}, + {[]byte(`{"test": 1}`), false, `{"test":1}`}, + } + + for i, s := range scenarios { + arr := types.JsonMap{} + scanErr := arr.Scan(s.value) + + hasErr := scanErr != nil + if hasErr != s.expectError { + t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, scanErr) + continue + } + + result, _ := arr.MarshalJSON() + + if string(result) != s.expectJson { + t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result)) + } + } +} diff --git a/pkg/types/json_raw.go b/pkg/types/json_raw.go new file mode 100644 index 0000000..670f299 --- /dev/null +++ b/pkg/types/json_raw.go @@ -0,0 +1,83 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" + "errors" +) + +// JsonRaw defines a json value type that is safe for db read/write. +type JsonRaw []byte + +// ParseJsonRaw creates a new JsonRaw instance from the provided value +// (could be JsonRaw, int, float, string, []byte, etc.). +func ParseJsonRaw(value any) (JsonRaw, error) { + result := JsonRaw{} + err := result.Scan(value) + return result, err +} + +// String returns the current JsonRaw instance as a json encoded string. +func (j JsonRaw) String() string { + return string(j) +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (j JsonRaw) MarshalJSON() ([]byte, error) { + if len(j) == 0 { + return []byte("null"), nil + } + + return j, nil +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (j *JsonRaw) UnmarshalJSON(b []byte) error { + if j == nil { + return errors.New("JsonRaw: UnmarshalJSON on nil pointer") + } + + *j = append((*j)[0:0], b...) + + return nil +} + +// Value implements the [driver.Valuer] interface. +func (j JsonRaw) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + + return j.String(), nil +} + +// Scan implements [sql.Scanner] interface to scan the provided value +// into the current JsonRaw instance. +func (j *JsonRaw) Scan(value any) error { + var data []byte + + switch v := value.(type) { + case nil: + // no cast is needed + case []byte: + if len(v) != 0 { + data = v + } + case string: + if v != "" { + data = []byte(v) + } + case JsonRaw: + if len(v) != 0 { + data = []byte(v) + } + default: + bytes, err := json.Marshal(v) + if err != nil { + return err + } + data = bytes + } + + return j.UnmarshalJSON(data) +} diff --git a/pkg/types/json_raw_test.go b/pkg/types/json_raw_test.go new file mode 100644 index 0000000..37fb3c0 --- /dev/null +++ b/pkg/types/json_raw_test.go @@ -0,0 +1,178 @@ +package types_test + +import ( + "database/sql/driver" + "testing" + + "github.com/apus-run/van/pkg/types" +) + +func TestParseJsonRaw(t *testing.T) { + scenarios := []struct { + value any + expectError bool + expectJson string + }{ + {nil, false, `null`}, + {``, false, `null`}, + {[]byte{}, false, `null`}, + {types.JsonRaw{}, false, `null`}, + {`{}`, false, `{}`}, + {`[]`, false, `[]`}, + {123, false, `123`}, + {`""`, false, `""`}, + {`test`, false, `test`}, + {`{"invalid"`, false, `{"invalid"`}, // treated as a byte casted string + {`{"test":1}`, false, `{"test":1}`}, + {[]byte(`[1,2,3]`), false, `[1,2,3]`}, + {[]int{1, 2, 3}, false, `[1,2,3]`}, + {map[string]int{"test": 1}, false, `{"test":1}`}, + } + + for i, s := range scenarios { + raw, parseErr := types.ParseJsonRaw(s.value) + hasErr := parseErr != nil + if hasErr != s.expectError { + t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, parseErr) + continue + } + + result, _ := raw.MarshalJSON() + + if string(result) != s.expectJson { + t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result)) + } + } +} + +func TestJsonRawString(t *testing.T) { + scenarios := []struct { + json types.JsonRaw + expected string + }{ + {nil, ``}, + {types.JsonRaw{}, ``}, + {types.JsonRaw([]byte(`123`)), `123`}, + {types.JsonRaw(`{"demo":123}`), `{"demo":123}`}, + } + + for i, s := range scenarios { + result := s.json.String() + if result != s.expected { + t.Errorf("(%d) Expected %q, got %q", i, s.expected, result) + } + } +} + +func TestJsonRawMarshalJSON(t *testing.T) { + scenarios := []struct { + json types.JsonRaw + expected string + }{ + {nil, `null`}, + {types.JsonRaw{}, `null`}, + {types.JsonRaw([]byte(`123`)), `123`}, + {types.JsonRaw(`{"demo":123}`), `{"demo":123}`}, + } + + for i, s := range scenarios { + result, err := s.json.MarshalJSON() + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + + if string(result) != s.expected { + t.Errorf("(%d) Expected %q, got %q", i, s.expected, string(result)) + } + } +} + +func TestJsonRawUnmarshalJSON(t *testing.T) { + scenarios := []struct { + json []byte + expectString string + }{ + {nil, ""}, + {[]byte{0, 1, 2}, "\x00\x01\x02"}, + {[]byte("123"), "123"}, + {[]byte("test"), "test"}, + {[]byte(`{"test":123}`), `{"test":123}`}, + } + + for i, s := range scenarios { + raw := types.JsonRaw{} + err := raw.UnmarshalJSON(s.json) + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + + if raw.String() != s.expectString { + t.Errorf("(%d) Expected %q, got %q", i, s.expectString, raw.String()) + } + } +} + +func TestJsonRawValue(t *testing.T) { + scenarios := []struct { + json types.JsonRaw + expected driver.Value + }{ + {nil, nil}, + {types.JsonRaw{}, nil}, + {types.JsonRaw(``), nil}, + {types.JsonRaw(`test`), `test`}, + } + + for i, s := range scenarios { + result, err := s.json.Value() + if err != nil { + t.Errorf("(%d) %v", i, err) + continue + } + if result != s.expected { + t.Errorf("(%d) Expected %s, got %v", i, s.expected, result) + } + } +} + +func TestJsonRawScan(t *testing.T) { + scenarios := []struct { + value any + expectError bool + expectJson string + }{ + {nil, false, `null`}, + {``, false, `null`}, + {[]byte{}, false, `null`}, + {types.JsonRaw{}, false, `null`}, + {types.JsonRaw(`test`), false, `test`}, + {`{}`, false, `{}`}, + {`[]`, false, `[]`}, + {123, false, `123`}, + {`""`, false, `""`}, + {`test`, false, `test`}, + {`{"invalid"`, false, `{"invalid"`}, // treated as a byte casted string + {`{"test":1}`, false, `{"test":1}`}, + {[]byte(`[1,2,3]`), false, `[1,2,3]`}, + {[]int{1, 2, 3}, false, `[1,2,3]`}, + {map[string]int{"test": 1}, false, `{"test":1}`}, + } + + for i, s := range scenarios { + raw := types.JsonRaw{} + scanErr := raw.Scan(s.value) + hasErr := scanErr != nil + if hasErr != s.expectError { + t.Errorf("(%d) Expected %v, got %v (%v)", i, s.expectError, hasErr, scanErr) + continue + } + + result, _ := raw.MarshalJSON() + + if string(result) != s.expectJson { + t.Errorf("(%d) Expected %s, got %v", i, s.expectJson, string(result)) + } + } +} diff --git a/pkg/kv/kv.go b/pkg/types/kv.go similarity index 96% rename from pkg/kv/kv.go rename to pkg/types/kv.go index 39ad360..eea2e9f 100644 --- a/pkg/kv/kv.go +++ b/pkg/types/kv.go @@ -1,4 +1,4 @@ -package kv +package types import ( "database/sql/driver" diff --git a/pkg/kv/kv_test.go b/pkg/types/kv_test.go similarity index 96% rename from pkg/kv/kv_test.go rename to pkg/types/kv_test.go index df0814b..64b6d5e 100644 --- a/pkg/kv/kv_test.go +++ b/pkg/types/kv_test.go @@ -1,4 +1,4 @@ -package kv +package types import ( "strconv" diff --git a/pkg/uuid/uuid.go b/pkg/uuid/uuid.go new file mode 100644 index 0000000..7254dfe --- /dev/null +++ b/pkg/uuid/uuid.go @@ -0,0 +1,24 @@ +package utils + +import ( + "crypto/rand" + + "github.com/google/uuid" + "github.com/lithammer/shortuuid/v4" + "github.com/oklog/ulid" +) + +// NewUUID returns a new UUID Version 4. +func NewUUID() string { + return uuid.New().String() +} + +// NewShortUUID returns a new short UUID. +func NewShortUUID() string { + return shortuuid.New() +} + +// NewULID returns a new ULID. +func NewULID() string { + return ulid.MustNew(ulid.Now(), rand.Reader).String() +} diff --git a/pkg/uuid/uuid_test.go b/pkg/uuid/uuid_test.go new file mode 100644 index 0000000..59a4ad1 --- /dev/null +++ b/pkg/uuid/uuid_test.go @@ -0,0 +1,55 @@ +package utils + +import ( + "sync" + "testing" +) + +func testuUniqness(t *testing.T, genFunc func() string) { + producers := 100 + uuidsPerProducer := 10000 + + if testing.Short() { + producers = 10 + uuidsPerProducer = 1000 + } + + uuidsCount := producers * uuidsPerProducer + + uuids := make(chan string, uuidsCount) + allGenerated := sync.WaitGroup{} + allGenerated.Add(producers) + + for i := 0; i < producers; i++ { + go func() { + for j := 0; j < uuidsPerProducer; j++ { + uuids <- genFunc() + } + allGenerated.Done() + }() + } + + uniqueUUIDs := make(map[string]struct{}, uuidsCount) + + allGenerated.Wait() + close(uuids) + + for uuid := range uuids { + if _, ok := uniqueUUIDs[uuid]; ok { + t.Error(uuid, " has duplicate") + } + uniqueUUIDs[uuid] = struct{}{} + } +} + +func TestUUID(t *testing.T) { + testuUniqness(t, NewUUID) +} + +func TestShortUUID(t *testing.T) { + testuUniqness(t, NewShortUUID) +} + +func TestULID(t *testing.T) { + testuUniqness(t, NewULID) +} diff --git a/store/store.go b/store/store.go new file mode 100644 index 0000000..9dd8c36 --- /dev/null +++ b/store/store.go @@ -0,0 +1,134 @@ +package store + +import "sync" + +// Store defines a concurrent safe in memory key-value data store. +type Store[T any] struct { + data map[string]T + mux sync.RWMutex +} + +// New creates a new Store[T] instance with a shallow copy of the provided data (if any). +func New[T any](data map[string]T) *Store[T] { + s := &Store[T]{} + + s.Reset(data) + + return s +} + +// Reset clears the store and replaces the store data with a +// shallow copy of the provided newData. +func (s *Store[T]) Reset(newData map[string]T) { + s.mux.Lock() + defer s.mux.Unlock() + + if len(newData) > 0 { + s.data = make(map[string]T, len(newData)) + for k, v := range newData { + s.data[k] = v + } + } else { + s.data = make(map[string]T) + } +} + +// Length returns the current number of elements in the store. +func (s *Store[T]) Length() int { + s.mux.RLock() + defer s.mux.RUnlock() + + return len(s.data) +} + +// RemoveAll removes all the existing store entries. +func (s *Store[T]) RemoveAll() { + s.mux.Lock() + defer s.mux.Unlock() + + s.data = make(map[string]T) +} + +// Remove removes a single entry from the store. +// +// Remove does nothing if key doesn't exist in the store. +func (s *Store[T]) Remove(key string) { + s.mux.Lock() + defer s.mux.Unlock() + + delete(s.data, key) +} + +// Has checks if element with the specified key exist or not. +func (s *Store[T]) Has(key string) bool { + s.mux.RLock() + defer s.mux.RUnlock() + + _, ok := s.data[key] + + return ok +} + +// Get returns a single element value from the store. +// +// If key is not set, the zero T value is returned. +func (s *Store[T]) Get(key string) T { + s.mux.RLock() + defer s.mux.RUnlock() + + return s.data[key] +} + +// GetAll returns a shallow copy of the current store data. +func (s *Store[T]) GetAll() map[string]T { + s.mux.RLock() + defer s.mux.RUnlock() + + var clone = make(map[string]T, len(s.data)) + + for k, v := range s.data { + clone[k] = v + } + + return clone +} + +// Set sets (or overwrite if already exist) a new value for key. +func (s *Store[T]) Set(key string, value T) { + s.mux.Lock() + defer s.mux.Unlock() + + if s.data == nil { + s.data = make(map[string]T) + } + + s.data[key] = value +} + +// SetIfLessThanLimit sets (or overwrite if already exist) a new value for key. +// +// This method is similar to Set() but **it will skip adding new elements** +// to the store if the store length has reached the specified limit. +// false is returned if maxAllowedElements limit is reached. +func (s *Store[T]) SetIfLessThanLimit(key string, value T, maxAllowedElements int) bool { + s.mux.Lock() + defer s.mux.Unlock() + + // init map if not already + if s.data == nil { + s.data = make(map[string]T) + } + + // check for existing item + _, ok := s.data[key] + + if !ok && len(s.data) >= maxAllowedElements { + // cannot add more items + return false + } + + // add/overwrite item + s.data[key] = value + + return true +} diff --git a/store/store_test.go b/store/store_test.go new file mode 100644 index 0000000..c8b753e --- /dev/null +++ b/store/store_test.go @@ -0,0 +1,232 @@ +package store_test + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/apus-run/van/store" +) + +func TestNew(t *testing.T) { + data := map[string]int{"test1": 1, "test2": 2} + originalRawData, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + s := store.New(data) + s.Set("test3", 3) // add 1 item + s.Remove("test1") // remove 1 item + + // check if data was shallow copied + rawData, _ := json.Marshal(data) + if !bytes.Equal(originalRawData, rawData) { + t.Fatalf("Expected data \n%s, \ngot \n%s", originalRawData, rawData) + } + + if s.Has("test1") { + t.Fatalf("Expected test1 to be deleted, got %v", s.Get("test1")) + } + + if v := s.Get("test2"); v != 2 { + t.Fatalf("Expected test2 to be %v, got %v", 2, v) + } + + if v := s.Get("test3"); v != 3 { + t.Fatalf("Expected test3 to be %v, got %v", 3, v) + } +} + +func TestReset(t *testing.T) { + s := store.New(map[string]int{"test1": 1}) + + data := map[string]int{"test2": 2} + originalRawData, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + s.Reset(data) + s.Set("test3", 3) + + // check if data was shallow copied + rawData, _ := json.Marshal(data) + if !bytes.Equal(originalRawData, rawData) { + t.Fatalf("Expected data \n%s, \ngot \n%s", originalRawData, rawData) + } + + if s.Has("test1") { + t.Fatalf("Expected test1 to be deleted, got %v", s.Get("test1")) + } + + if v := s.Get("test2"); v != 2 { + t.Fatalf("Expected test2 to be %v, got %v", 2, v) + } + + if v := s.Get("test3"); v != 3 { + t.Fatalf("Expected test3 to be %v, got %v", 3, v) + } +} + +func TestLength(t *testing.T) { + s := store.New(map[string]int{"test1": 1}) + s.Set("test2", 2) + + if v := s.Length(); v != 2 { + t.Fatalf("Expected length %d, got %d", 2, v) + } +} + +func TestRemoveAll(t *testing.T) { + s := store.New(map[string]bool{"test1": true, "test2": true}) + + keys := []string{"test1", "test2"} + + s.RemoveAll() + + for i, key := range keys { + if s.Has(key) { + t.Errorf("(%d) Expected %q to be removed", i, key) + } + } +} + +func TestRemove(t *testing.T) { + s := store.New(map[string]bool{"test": true}) + + keys := []string{"test", "missing"} + + for i, key := range keys { + s.Remove(key) + if s.Has(key) { + t.Errorf("(%d) Expected %q to be removed", i, key) + } + } +} + +func TestHas(t *testing.T) { + s := store.New(map[string]int{"test1": 0, "test2": 1}) + + scenarios := []struct { + key string + exist bool + }{ + {"test1", true}, + {"test2", true}, + {"missing", false}, + } + + for i, scenario := range scenarios { + exist := s.Has(scenario.key) + if exist != scenario.exist { + t.Errorf("(%d) Expected %v, got %v", i, scenario.exist, exist) + } + } +} + +func TestGet(t *testing.T) { + s := store.New(map[string]int{"test1": 0, "test2": 1}) + + scenarios := []struct { + key string + expect int + }{ + {"test1", 0}, + {"test2", 1}, + {"missing", 0}, // should auto fallback to the zero value + } + + for i, scenario := range scenarios { + val := s.Get(scenario.key) + if val != scenario.expect { + t.Errorf("(%d) Expected %v, got %v", i, scenario.expect, val) + } + } +} + +func TestGetAll(t *testing.T) { + data := map[string]int{ + "a": 1, + "b": 2, + } + + s := store.New(data) + + // fetch and delete each key to make sure that it was shallow copied + result := s.GetAll() + for k := range result { + delete(result, k) + } + + // refetch again + result = s.GetAll() + + if len(result) != len(data) { + t.Fatalf("Expected %d, got %d items", len(data), len(result)) + } + + for k := range result { + if result[k] != data[k] { + t.Fatalf("Expected %s to be %v, got %v", k, data[k], result[k]) + } + } +} + +func TestSet(t *testing.T) { + s := store.Store[int]{} + + data := map[string]int{"test1": 0, "test2": 1, "test3": 3} + + // set values + for k, v := range data { + s.Set(k, v) + } + + // verify that the values are set + for k, v := range data { + if !s.Has(k) { + t.Errorf("Expected key %q", k) + } + + val := s.Get(k) + if val != v { + t.Errorf("Expected %v, got %v for key %q", v, val, k) + } + } +} + +func TestSetIfLessThanLimit(t *testing.T) { + s := store.Store[int]{} + + limit := 2 + + // set values + scenarios := []struct { + key string + value int + expected bool + }{ + {"test1", 1, true}, + {"test2", 2, true}, + {"test3", 3, false}, + {"test2", 4, true}, // overwrite + } + + for i, scenario := range scenarios { + result := s.SetIfLessThanLimit(scenario.key, scenario.value, limit) + + if result != scenario.expected { + t.Errorf("(%d) Expected result %v, got %v", i, scenario.expected, result) + } + + if !scenario.expected && s.Has(scenario.key) { + t.Errorf("(%d) Expected key %q to not be set", i, scenario.key) + } + + val := s.Get(scenario.key) + if scenario.expected && val != scenario.value { + t.Errorf("(%d) Expected value %v, got %v", i, scenario.value, val) + } + } +} diff --git a/subscriptions/broker.go b/subscriptions/broker.go new file mode 100644 index 0000000..296efad --- /dev/null +++ b/subscriptions/broker.go @@ -0,0 +1,70 @@ +package subscriptions + +import ( + "fmt" + "sync" +) + +// Broker defines a struct for managing subscriptions clients. +type Broker struct { + clients map[string]Client + mux sync.RWMutex +} + +// NewBroker initializes and returns a new Broker instance. +func NewBroker() *Broker { + return &Broker{ + clients: make(map[string]Client), + } +} + +// Clients returns a shallow copy of all registered clients indexed +// with their connection id. +func (b *Broker) Clients() map[string]Client { + b.mux.RLock() + defer b.mux.RUnlock() + + copy := make(map[string]Client, len(b.clients)) + + for id, c := range b.clients { + copy[id] = c + } + + return copy +} + +// ClientById finds a registered client by its id. +// +// Returns non-nil error when client with clientId is not registered. +func (b *Broker) ClientById(clientId string) (Client, error) { + b.mux.RLock() + defer b.mux.RUnlock() + + client, ok := b.clients[clientId] + if !ok { + return nil, fmt.Errorf("No client associated with connection ID %q", clientId) + } + + return client, nil +} + +// Register adds a new client to the broker instance. +func (b *Broker) Register(client Client) { + b.mux.Lock() + defer b.mux.Unlock() + + b.clients[client.Id()] = client +} + +// Unregister removes a single client by its id. +// +// If client with clientId doesn't exist, this method does nothing. +func (b *Broker) Unregister(clientId string) { + b.mux.Lock() + defer b.mux.Unlock() + + if client, ok := b.clients[clientId]; ok { + client.Discard() + delete(b.clients, clientId) + } +} diff --git a/subscriptions/broker_test.go b/subscriptions/broker_test.go new file mode 100644 index 0000000..26a3e80 --- /dev/null +++ b/subscriptions/broker_test.go @@ -0,0 +1,93 @@ +package subscriptions_test + +import ( + "testing" + + "github.com/apus-run/van/subscriptions" +) + +func TestNewBroker(t *testing.T) { + b := subscriptions.NewBroker() + + if b.Clients() == nil { + t.Fatal("Expected clients map to be initialized") + } +} + +func TestClients(t *testing.T) { + b := subscriptions.NewBroker() + + if total := len(b.Clients()); total != 0 { + t.Fatalf("Expected no clients, got %v", total) + } + + b.Register(subscriptions.NewDefaultClient()) + b.Register(subscriptions.NewDefaultClient()) + + // check if it is a shallow copy + clients := b.Clients() + for k := range clients { + delete(clients, k) + } + + // should return a new copy + if total := len(b.Clients()); total != 2 { + t.Fatalf("Expected 2 clients, got %v", total) + } +} + +func TestClientById(t *testing.T) { + b := subscriptions.NewBroker() + + clientA := subscriptions.NewDefaultClient() + clientB := subscriptions.NewDefaultClient() + b.Register(clientA) + b.Register(clientB) + + resultClient, err := b.ClientById(clientA.Id()) + if err != nil { + t.Fatalf("Expected client with id %s, got error %v", clientA.Id(), err) + } + if resultClient.Id() != clientA.Id() { + t.Fatalf("Expected client %s, got %s", clientA.Id(), resultClient.Id()) + } + + if c, err := b.ClientById("missing"); err == nil { + t.Fatalf("Expected error, found client %v", c) + } +} + +func TestRegister(t *testing.T) { + b := subscriptions.NewBroker() + + client := subscriptions.NewDefaultClient() + b.Register(client) + + if _, err := b.ClientById(client.Id()); err != nil { + t.Fatalf("Expected client with id %s, got error %v", client.Id(), err) + } +} + +func TestUnregister(t *testing.T) { + b := subscriptions.NewBroker() + + clientA := subscriptions.NewDefaultClient() + clientB := subscriptions.NewDefaultClient() + b.Register(clientA) + b.Register(clientB) + + if _, err := b.ClientById(clientA.Id()); err != nil { + t.Fatalf("Expected client with id %s, got error %v", clientA.Id(), err) + } + + b.Unregister(clientA.Id()) + + if c, err := b.ClientById(clientA.Id()); err == nil { + t.Fatalf("Expected error, found client %v", c) + } + + // clientB shouldn't have been removed + if _, err := b.ClientById(clientB.Id()); err != nil { + t.Fatalf("Expected client with id %s, got error %v", clientB.Id(), err) + } +} diff --git a/subscriptions/client.go b/subscriptions/client.go new file mode 100644 index 0000000..e6ce696 --- /dev/null +++ b/subscriptions/client.go @@ -0,0 +1,274 @@ +package subscriptions + +import ( + "encoding/json" + "net/url" + "strings" + "sync" + + "github.com/spf13/cast" + + "github.com/apus-run/van/pkg/inflector" + "github.com/apus-run/van/pkg/rand" +) + +const optionsParam = "options" + +// Message defines a client's channel data. +type Message struct { + Name string `json:"name"` + Data []byte `json:"data"` +} + +// SubscriptionOptions defines the request options (query params, headers, etc.) +// for a single subscription topic. +type SubscriptionOptions struct { + // @todo after the requests handling refactoring consider + // changing to map[string]string or map[string][]string + + Query map[string]any `json:"query"` + Headers map[string]any `json:"headers"` +} + +// Client is an interface for a generic subscription client. +type Client interface { + // Id Returns the unique id of the client. + Id() string + + // Channel returns the client's communication channel. + Channel() chan Message + + // Subscriptions returns a shallow copy of the the client subscriptions matching the prefixes. + // If no prefix is specified, returns all subscriptions. + Subscriptions(prefixes ...string) map[string]SubscriptionOptions + + // Subscribe subscribes the client to the provided subscriptions list. + // + // Each subscription can also have "options" (json serialized SubscriptionOptions) as query parameter. + // + // Example: + // + // Subscribe( + // "subscriptionA", + // `subscriptionB?options={"query":{"a":1},"headers":{"x_token":"abc"}}`, + // ) + Subscribe(subs ...string) + + // Unsubscribe unsubscribes the client from the provided subscriptions list. + Unsubscribe(subs ...string) + + // HasSubscription checks if the client is subscribed to `sub`. + HasSubscription(sub string) bool + + // Set stores any value to the client's context. + Set(key string, value any) + + // Unset removes a single value from the client's context. + Unset(key string) + + // Get retrieves the key value from the client's context. + Get(key string) any + + // Discard marks the client as "discarded", meaning that it + // shouldn't be used anymore for sending new messages. + // + // It is safe to call Discard() multiple times. + Discard() + + // IsDiscarded indicates whether the client has been "discarded" + // and should no longer be used. + IsDiscarded() bool + + // Send sends the specified message to the client's channel (if not discarded). + Send(m Message) +} + +// ensures that DefaultClient satisfies the Client interface +var _ Client = (*DefaultClient)(nil) + +// DefaultClient defines a generic subscription client. +type DefaultClient struct { + store map[string]any + subscriptions map[string]SubscriptionOptions + channel chan Message + id string + mux sync.RWMutex + isDiscarded bool +} + +// NewDefaultClient creates and returns a new DefaultClient instance. +func NewDefaultClient() *DefaultClient { + return &DefaultClient{ + id: rand.RandomString(40), + store: map[string]any{}, + channel: make(chan Message), + subscriptions: map[string]SubscriptionOptions{}, + } +} + +// Id implements the [Client.Id] interface method. +func (c *DefaultClient) Id() string { + c.mux.RLock() + defer c.mux.RUnlock() + + return c.id +} + +// Channel implements the [Client.Channel] interface method. +func (c *DefaultClient) Channel() chan Message { + c.mux.RLock() + defer c.mux.RUnlock() + + return c.channel +} + +// Subscriptions implements the [Client.Subscriptions] interface method. +// +// It returns a shallow copy of the the client subscriptions matching the prefixes. +// If no prefix is specified, returns all subscriptions. +func (c *DefaultClient) Subscriptions(prefixes ...string) map[string]SubscriptionOptions { + c.mux.RLock() + defer c.mux.RUnlock() + + // no prefix -> return copy of all subscriptions + if len(prefixes) == 0 { + result := make(map[string]SubscriptionOptions, len(c.subscriptions)) + + for s, options := range c.subscriptions { + result[s] = options + } + + return result + } + + result := make(map[string]SubscriptionOptions) + + for _, prefix := range prefixes { + for s, options := range c.subscriptions { + // "?" ensures that the options query start character is always there + // so that it can be used as an end separator when looking only for the main subscription topic + if strings.HasPrefix(s+"?", prefix) { + result[s] = options + } + } + } + + return result +} + +// Subscribe implements the [Client.Subscribe] interface method. +// +// Empty subscriptions (aka. "") are ignored. +func (c *DefaultClient) Subscribe(subs ...string) { + c.mux.Lock() + defer c.mux.Unlock() + + for _, s := range subs { + if s == "" { + continue // skip empty + } + + // extract subscription options (if any) + options := SubscriptionOptions{} + u, err := url.Parse(s) + if err == nil { + rawOptions := u.Query().Get(optionsParam) + if rawOptions != "" { + json.Unmarshal([]byte(rawOptions), &options) + } + } + + // normalize query + // (currently only single string values are supported for consistency with the default routes handling) + for k, v := range options.Query { + options.Query[k] = cast.ToString(v) + } + + // normalize headers name and values, eg. "X-Token" is converted to "x_token" + // (currently only single string values are supported for consistency with the default routes handling) + for k, v := range options.Headers { + delete(options.Headers, k) + options.Headers[inflector.Snakecase(k)] = cast.ToString(v) + } + + c.subscriptions[s] = options + } +} + +// Unsubscribe implements the [Client.Unsubscribe] interface method. +// +// If subs is not set, this method removes all registered client's subscriptions. +func (c *DefaultClient) Unsubscribe(subs ...string) { + c.mux.Lock() + defer c.mux.Unlock() + + if len(subs) > 0 { + for _, s := range subs { + delete(c.subscriptions, s) + } + } else { + // unsubscribe all + for s := range c.subscriptions { + delete(c.subscriptions, s) + } + } +} + +// HasSubscription implements the [Client.HasSubscription] interface method. +func (c *DefaultClient) HasSubscription(sub string) bool { + c.mux.RLock() + defer c.mux.RUnlock() + + _, ok := c.subscriptions[sub] + + return ok +} + +// Get implements the [Client.Get] interface method. +func (c *DefaultClient) Get(key string) any { + c.mux.RLock() + defer c.mux.RUnlock() + + return c.store[key] +} + +// Set implements the [Client.Set] interface method. +func (c *DefaultClient) Set(key string, value any) { + c.mux.Lock() + defer c.mux.Unlock() + + c.store[key] = value +} + +// Unset implements the [Client.Unset] interface method. +func (c *DefaultClient) Unset(key string) { + c.mux.Lock() + defer c.mux.Unlock() + + delete(c.store, key) +} + +// Discard implements the [Client.Discard] interface method. +func (c *DefaultClient) Discard() { + c.mux.Lock() + defer c.mux.Unlock() + + c.isDiscarded = true +} + +// IsDiscarded implements the [Client.IsDiscarded] interface method. +func (c *DefaultClient) IsDiscarded() bool { + c.mux.RLock() + defer c.mux.RUnlock() + + return c.isDiscarded +} + +// Send sends the specified message to the client's channel (if not discarded). +func (c *DefaultClient) Send(m Message) { + if c.IsDiscarded() { + return + } + + c.Channel() <- m +} diff --git a/subscriptions/client_test.go b/subscriptions/client_test.go new file mode 100644 index 0000000..cff028a --- /dev/null +++ b/subscriptions/client_test.go @@ -0,0 +1,244 @@ +package subscriptions_test + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/apus-run/van/subscriptions" +) + +func TestNewDefaultClient(t *testing.T) { + c := subscriptions.NewDefaultClient() + + if c.Channel() == nil { + t.Errorf("Expected channel to be initialized") + } + + if c.Subscriptions() == nil { + t.Errorf("Expected subscriptions map to be initialized") + } + + if c.Id() == "" { + t.Errorf("Expected unique id to be set") + } +} + +func TestId(t *testing.T) { + clients := []*subscriptions.DefaultClient{ + subscriptions.NewDefaultClient(), + subscriptions.NewDefaultClient(), + subscriptions.NewDefaultClient(), + subscriptions.NewDefaultClient(), + } + + ids := map[string]struct{}{} + for i, c := range clients { + // check uniqueness + if _, ok := ids[c.Id()]; ok { + t.Errorf("(%d) Expected unique id, got %v", i, c.Id()) + } else { + ids[c.Id()] = struct{}{} + } + + // check length + if len(c.Id()) != 40 { + t.Errorf("(%d) Expected unique id to have 40 chars length, got %v", i, c.Id()) + } + } +} + +func TestChannel(t *testing.T) { + c := subscriptions.NewDefaultClient() + + if c.Channel() == nil { + t.Fatalf("Expected channel to be initialized, got") + } +} + +func TestSubscriptions(t *testing.T) { + c := subscriptions.NewDefaultClient() + + if len(c.Subscriptions()) != 0 { + t.Fatalf("Expected subscriptions to be empty") + } + + c.Subscribe("sub1", "sub11", "sub2") + + scenarios := []struct { + prefixes []string + expected []string + }{ + {nil, []string{"sub1", "sub11", "sub2"}}, + {[]string{"missing"}, nil}, + {[]string{"sub1"}, []string{"sub1", "sub11"}}, + {[]string{"sub2"}, []string{"sub2"}}, // with extra query start char + } + + for _, s := range scenarios { + t.Run(strings.Join(s.prefixes, ","), func(t *testing.T) { + subs := c.Subscriptions(s.prefixes...) + + if len(subs) != len(s.expected) { + t.Fatalf("Expected %d subscriptions, got %d", len(s.expected), len(subs)) + } + + for _, s := range s.expected { + if _, ok := subs[s]; !ok { + t.Fatalf("Missing subscription %q in \n%v", s, subs) + } + } + }) + } +} + +func TestSubscribe(t *testing.T) { + c := subscriptions.NewDefaultClient() + + subs := []string{"", "sub1", "sub2", "sub3"} + expected := []string{"sub1", "sub2", "sub3"} + + c.Subscribe(subs...) // empty string should be skipped + + if len(c.Subscriptions()) != 3 { + t.Fatalf("Expected 3 subscriptions, got %v", c.Subscriptions()) + } + + for i, s := range expected { + if !c.HasSubscription(s) { + t.Errorf("(%d) Expected sub %s", i, s) + } + } +} + +func TestSubscribeOptions(t *testing.T) { + c := subscriptions.NewDefaultClient() + + sub1 := "test1" + sub2 := `test2?options={"query":{"name":123},"headers":{"X-Token":456}}` + + c.Subscribe(sub1, sub2) + + subs := c.Subscriptions() + + scenarios := []struct { + name string + expectedOptions string + }{ + {sub1, `{"query":null,"headers":null}`}, + {sub2, `{"query":{"name":"123"},"headers":{"x_token":"456"}}`}, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + options, ok := subs[s.name] + if !ok { + t.Fatalf("Missing subscription \n%q \nin \n%v", s.name, subs) + } + + rawBytes, err := json.Marshal(options) + if err != nil { + t.Fatal(err) + } + rawStr := string(rawBytes) + + if rawStr != s.expectedOptions { + t.Fatalf("Expected options \n%v \ngot \n%v", s.expectedOptions, rawStr) + } + }) + } +} + +func TestUnsubscribe(t *testing.T) { + c := subscriptions.NewDefaultClient() + + c.Subscribe("sub1", "sub2", "sub3") + + c.Unsubscribe("sub1") + + if c.HasSubscription("sub1") { + t.Fatalf("Expected sub1 to be removed") + } + + c.Unsubscribe( /* all */ ) + if len(c.Subscriptions()) != 0 { + t.Fatalf("Expected all subscriptions to be removed, got %v", c.Subscriptions()) + } +} + +func TestHasSubscription(t *testing.T) { + c := subscriptions.NewDefaultClient() + + if c.HasSubscription("missing") { + t.Error("Expected false, got true") + } + + c.Subscribe("sub") + + if !c.HasSubscription("sub") { + t.Error("Expected true, got false") + } +} + +func TestSetAndGet(t *testing.T) { + c := subscriptions.NewDefaultClient() + + c.Set("demo", 1) + + result, _ := c.Get("demo").(int) + + if result != 1 { + t.Errorf("Expected 1, got %v", result) + } +} + +func TestDiscard(t *testing.T) { + c := subscriptions.NewDefaultClient() + + if v := c.IsDiscarded(); v { + t.Fatal("Expected false, got true") + } + + c.Discard() + + if v := c.IsDiscarded(); !v { + t.Fatal("Expected true, got false") + } +} + +func TestSend(t *testing.T) { + c := subscriptions.NewDefaultClient() + + received := []string{} + go func() { + for m := range c.Channel() { + received = append(received, m.Name) + } + }() + + c.Send(subscriptions.Message{Name: "m1"}) + c.Send(subscriptions.Message{Name: "m2"}) + c.Discard() + c.Send(subscriptions.Message{Name: "m3"}) + c.Send(subscriptions.Message{Name: "m4"}) + time.Sleep(5 * time.Millisecond) + + expected := []string{"m1", "m2"} + + if len(received) != len(expected) { + t.Fatalf("Expected %d messages, got %d", len(expected), len(received)) + } + for _, name := range expected { + var exists bool + for _, n := range received { + if n == name { + exists = true + break + } + } + if !exists { + t.Fatalf("Missing expected %q message, got %v", name, received) + } + } +} diff --git a/template/registry.go b/template/registry.go new file mode 100644 index 0000000..96b24d3 --- /dev/null +++ b/template/registry.go @@ -0,0 +1,141 @@ +// Package template is a thin wrapper around the standard html/template +// and text/template packages that implements a convenient registry to +// load and cache templates on the fly concurrently. +// +// It was created to assist the JSVM plugin HTML rendering, but could be used in other Go code. +// +// Example: +// +// registry := template.NewRegistry() +// +// html1, err := registry.LoadFiles( +// // the files set wil be parsed only once and then cached +// "layout.html", +// "content.html", +// ).Render(map[string]any{"name": "John"}) +// +// html2, err := registry.LoadFiles( +// // reuse the already parsed and cached files set +// "layout.html", +// "content.html", +// ).Render(map[string]any{"name": "Jane"}) +package template + +import ( + "fmt" + "html/template" + "io/fs" + "path/filepath" + "strings" + + "github.com/apus-run/van/store" +) + +// NewRegistry creates and initializes a new templates registry with +// some defaults (eg. global "raw" template function for unescaped HTML). +// +// Use the Registry.Load* methods to load templates into the registry. +func NewRegistry() *Registry { + return &Registry{ + cache: store.New[*Renderer](nil), + funcs: template.FuncMap{ + "raw": func(str string) template.HTML { + return template.HTML(str) + }, + }, + } +} + +// Registry defines a templates registry that is safe to be used by multiple goroutines. +// +// Use the Registry.Load* methods to load templates into the registry. +type Registry struct { + cache *store.Store[*Renderer] + funcs template.FuncMap +} + +// AddFuncs registers new global template functions. +// +// The key of each map entry is the function name that will be used in the templates. +// If a function with the map entry name already exists it will be replaced with the new one. +// +// The value of each map entry is a function that must have either a +// single return value, or two return values of which the second has type error. +// +// Example: +// +// r.AddFuncs(map[string]any{ +// "toUpper": func(str string) string { +// return strings.ToUppser(str) +// }, +// ... +// }) +func (r *Registry) AddFuncs(funcs map[string]any) *Registry { + for name, f := range funcs { + r.funcs[name] = f + } + + return r +} + +// LoadFiles caches (if not already) the specified filenames set as a +// single template and returns a ready to use Renderer instance. +// +// There must be at least 1 filename specified. +func (r *Registry) LoadFiles(filenames ...string) *Renderer { + key := strings.Join(filenames, ",") + + found := r.cache.Get(key) + + if found == nil { + // parse and cache + tpl, err := template.New(filepath.Base(filenames[0])).Funcs(r.funcs).ParseFiles(filenames...) + found = &Renderer{template: tpl, parseError: err} + r.cache.Set(key, found) + } + + return found +} + +// LoadString caches (if not already) the specified inline string as a +// single template and returns a ready to use Renderer instance. +func (r *Registry) LoadString(text string) *Renderer { + found := r.cache.Get(text) + + if found == nil { + // parse and cache (using the text as key) + tpl, err := template.New("").Funcs(r.funcs).Parse(text) + found = &Renderer{template: tpl, parseError: err} + r.cache.Set(text, found) + } + + return found +} + +// LoadFS caches (if not already) the specified fs and globPatterns +// pair as single template and returns a ready to use Renderer instance. +// +// There must be at least 1 file matching the provided globPattern(s) +// (note that most file names serves as glob patterns matching themselves). +func (r *Registry) LoadFS(fsys fs.FS, globPatterns ...string) *Renderer { + key := fmt.Sprintf("%v%v", fsys, globPatterns) + + found := r.cache.Get(key) + + if found == nil { + // find the first file to use as template name (it is required when specifying Funcs) + var firstFilename string + if len(globPatterns) > 0 { + list, _ := fs.Glob(fsys, globPatterns[0]) + if len(list) > 0 { + firstFilename = filepath.Base(list[0]) + } + } + + tpl, err := template.New(firstFilename).Funcs(r.funcs).ParseFS(fsys, globPatterns...) + found = &Renderer{template: tpl, parseError: err} + r.cache.Set(key, found) + } + + return found +} diff --git a/template/registry_test.go b/template/registry_test.go new file mode 100644 index 0000000..c037fe3 --- /dev/null +++ b/template/registry_test.go @@ -0,0 +1,250 @@ +package template + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +func checkRegistryFuncs(t *testing.T, r *Registry, expectedFuncs ...string) { + if v := len(r.funcs); v != len(expectedFuncs) { + t.Fatalf("Expected total %d funcs, got %d", len(expectedFuncs), v) + } + + for _, name := range expectedFuncs { + if _, ok := r.funcs[name]; !ok { + t.Fatalf("Missing %q func", name) + } + } +} + +func TestNewRegistry(t *testing.T) { + r := NewRegistry() + + if r.cache == nil { + t.Fatalf("Expected cache store to be initialized, got nil") + } + + if v := r.cache.Length(); v != 0 { + t.Fatalf("Expected cache store length to be 0, got %d", v) + } + + checkRegistryFuncs(t, r, "raw") +} + +func TestRegistryAddFuncs(t *testing.T) { + r := NewRegistry() + + r.AddFuncs(map[string]any{ + "test": func(a string) string { return a + "-TEST" }, + }) + + checkRegistryFuncs(t, r, "raw", "test") + + result, err := r.LoadString(`{{.|test}}`).Render("example") + if err != nil { + t.Fatalf("Unexpected Render() error, got %v", err) + } + + expected := "example-TEST" + if result != expected { + t.Fatalf("Expected Render() result %q, got %q", expected, result) + } +} + +func TestRegistryLoadFiles(t *testing.T) { + r := NewRegistry() + + t.Run("invalid or missing files", func(t *testing.T) { + r.LoadFiles("file1.missing", "file2.missing") + + key := "file1.missing,file2.missing" + renderer := r.cache.Get(key) + + if renderer == nil { + t.Fatal("Expected renderer to be initialized even if invalid, got nil") + } + + if renderer.template != nil { + t.Fatalf("Expected renderer template to be nil, got %v", renderer.template) + } + + if renderer.parseError == nil { + t.Fatalf("Expected renderer parseError to be set, got nil") + } + }) + + t.Run("valid files", func(t *testing.T) { + // create test templates + dir, err := os.MkdirTemp(os.TempDir(), "template_test") + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "base.html"), []byte(`Base:{{template "content" .}}`), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "content.html"), []byte(`{{define "content"}}Content:{{.|raw}}{{end}}`), 0644); err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + files := []string{filepath.Join(dir, "base.html"), filepath.Join(dir, "content.html")} + + r.LoadFiles(files...) + + renderer := r.cache.Get(strings.Join(files, ",")) + + if renderer == nil { + t.Fatal("Expected renderer to be initialized even if invalid, got nil") + } + + if renderer.template == nil { + t.Fatal("Expected renderer template to be set, got nil") + } + + if renderer.parseError != nil { + t.Fatalf("Expected renderer parseError to be nil, got %v", renderer.parseError) + } + + result, err := renderer.Render("

123

") + if err != nil { + t.Fatalf("Unexpected Render() error, got %v", err) + } + + expected := "Base:Content:

123

" + if result != expected { + t.Fatalf("Expected Render() result %q, got %q", expected, result) + } + }) +} + +func TestRegistryLoadString(t *testing.T) { + r := NewRegistry() + + t.Run("invalid template string", func(t *testing.T) { + txt := `test {{define "content"}}` + + r.LoadString(txt) + + renderer := r.cache.Get(txt) + + if renderer == nil { + t.Fatal("Expected renderer to be initialized even if invalid, got nil") + } + + if renderer.template != nil { + t.Fatalf("Expected renderer template to be nil, got %v", renderer.template) + } + + if renderer.parseError == nil { + t.Fatalf("Expected renderer parseError to be set, got nil") + } + }) + + t.Run("valid template string", func(t *testing.T) { + txt := `test {{.|raw}}` + + r.LoadString(txt) + + renderer := r.cache.Get(txt) + + if renderer == nil { + t.Fatal("Expected renderer to be initialized even if invalid, got nil") + } + + if renderer.template == nil { + t.Fatal("Expected renderer template to be set, got nil") + } + + if renderer.parseError != nil { + t.Fatalf("Expected renderer parseError to be nil, got %v", renderer.parseError) + } + + result, err := renderer.Render("

123

") + if err != nil { + t.Fatalf("Unexpected Render() error, got %v", err) + } + + expected := "test

123

" + if result != expected { + t.Fatalf("Expected Render() result %q, got %q", expected, result) + } + }) +} + +func TestRegistryLoadFS(t *testing.T) { + r := NewRegistry() + + t.Run("invalid fs", func(t *testing.T) { + fs := os.DirFS("__missing__") + + files := []string{"missing1", "missing2"} + + key := fmt.Sprintf("%v%v", fs, files) + + r.LoadFS(fs, files...) + + renderer := r.cache.Get(key) + + if renderer == nil { + t.Fatal("Expected renderer to be initialized even if invalid, got nil") + } + + if renderer.template != nil { + t.Fatalf("Expected renderer template to be nil, got %v", renderer.template) + } + + if renderer.parseError == nil { + t.Fatalf("Expected renderer parseError to be set, got nil") + } + }) + + t.Run("valid fs", func(t *testing.T) { + // create test templates + dir, err := os.MkdirTemp(os.TempDir(), "template_test2") + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "base.html"), []byte(`Base:{{template "content" .}}`), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "content.html"), []byte(`{{define "content"}}Content:{{.|raw}}{{end}}`), 0644); err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + fs := os.DirFS(dir) + + files := []string{"base.html", "content.html"} + + key := fmt.Sprintf("%v%v", fs, files) + + r.LoadFS(fs, files...) + + renderer := r.cache.Get(key) + + if renderer == nil { + t.Fatal("Expected renderer to be initialized even if invalid, got nil") + } + + if renderer.template == nil { + t.Fatal("Expected renderer template to be set, got nil") + } + + if renderer.parseError != nil { + t.Fatalf("Expected renderer parseError to be nil, got %v", renderer.parseError) + } + + result, err := renderer.Render("

123

") + if err != nil { + t.Fatalf("Unexpected Render() error, got %v", err) + } + + expected := "Base:Content:

123

" + if result != expected { + t.Fatalf("Expected Render() result %q, got %q", expected, result) + } + }) +} diff --git a/template/renderer.go b/template/renderer.go new file mode 100644 index 0000000..7a2d85d --- /dev/null +++ b/template/renderer.go @@ -0,0 +1,33 @@ +package template + +import ( + "bytes" + "errors" + "html/template" +) + +// Renderer defines a single parsed template. +type Renderer struct { + template *template.Template + parseError error +} + +// Render executes the template with the specified data as the dot object +// and returns the result as plain string. +func (r *Renderer) Render(data any) (string, error) { + if r.parseError != nil { + return "", r.parseError + } + + if r.template == nil { + return "", errors.New("invalid or nil template") + } + + buf := new(bytes.Buffer) + + if err := r.template.Execute(buf, data); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/template/renderer_test.go b/template/renderer_test.go new file mode 100644 index 0000000..7c75111 --- /dev/null +++ b/template/renderer_test.go @@ -0,0 +1,63 @@ +package template + +import ( + "errors" + "html/template" + "testing" +) + +func TestRendererRender(t *testing.T) { + tpl, _ := template.New("").Parse("Hello {{.Name}}!") + tpl.Option("missingkey=error") // enforce execute errors + + scenarios := map[string]struct { + renderer *Renderer + data any + expectedHasErr bool + expectedResult string + }{ + "with nil template": { + &Renderer{}, + nil, + true, + "", + }, + "with parse error": { + &Renderer{ + template: tpl, + parseError: errors.New("test"), + }, + nil, + true, + "", + }, + "with execute error": { + &Renderer{template: tpl}, + nil, + true, + "", + }, + "no error": { + &Renderer{template: tpl}, + struct{ Name string }{"world"}, + false, + "Hello world!", + }, + } + + for name, s := range scenarios { + t.Run(name, func(t *testing.T) { + result, err := s.renderer.Render(s.data) + + hasErr := err != nil + + if s.expectedHasErr != hasErr { + t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectedHasErr, hasErr, err) + } + + if s.expectedResult != result { + t.Fatalf("Expected result %v, got %v", s.expectedResult, result) + } + }) + } +}