diff --git a/driver/registry_default.go b/driver/registry_default.go index 36d324a9b02..c474d098f73 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -414,7 +414,11 @@ func (m *RegistryDefault) Hasher() hash.Hasher { func (m *RegistryDefault) PasswordValidator() password2.Validator { if m.passwordValidator == nil { - m.passwordValidator = password2.NewDefaultPasswordValidatorStrategy(m) + var err error + m.passwordValidator, err = password2.NewDefaultPasswordValidatorStrategy(m) + if err != nil { + m.Logger().WithError(err).Fatal("could not initialize DefaultPasswordValidator") + } } return m.passwordValidator } diff --git a/go.mod b/go.mod index 4cf33e84031..b2feabd10c0 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/cortesi/modd v0.0.0-20210323234521-b35eddab86cc github.com/davecgh/go-spew v1.1.1 github.com/davidrjonas/semver-cli v0.0.0-20190116233701-ee19a9a0dda6 + github.com/dgraph-io/ristretto v0.1.0 github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499 github.com/fatih/color v1.13.0 github.com/form3tech-oss/jwt-go v3.2.3+incompatible diff --git a/selfservice/strategy/password/validator.go b/selfservice/strategy/password/validator.go index ad8088b0580..ba7d71f3734 100644 --- a/selfservice/strategy/password/validator.go +++ b/selfservice/strategy/password/validator.go @@ -10,10 +10,10 @@ import ( "net/http" "strconv" "strings" - "sync" "time" "github.com/arbovm/levenshtein" + "github.com/dgraph-io/ristretto" "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" @@ -22,6 +22,8 @@ import ( "github.com/ory/x/httpx" ) +const hashCacheItemTTL = time.Hour + // Validator implements a validation strategy for passwords. One example is that the password // has to have at least 6 characters and at least one lower and one uppercase password. type Validator interface { @@ -49,10 +51,9 @@ var ErrUnexpectedStatusCode = errors.New("unexpected status code") // [haveibeenpwnd](https://haveibeenpwned.com/API/v2#SearchingPwnedPasswordsByRange) service to check if the // password has been breached in a previous data leak using k-anonymity. type DefaultPasswordValidator struct { - sync.RWMutex reg validatorDependencies Client *retryablehttp.Client - hashes map[string]int64 + hashes *ristretto.Cache minIdentifierPasswordDist int maxIdentifierPasswordSubstrThreshold float32 @@ -62,12 +63,22 @@ type validatorDependencies interface { config.Provider } -func NewDefaultPasswordValidatorStrategy(reg validatorDependencies) *DefaultPasswordValidator { +func NewDefaultPasswordValidatorStrategy(reg validatorDependencies) (*DefaultPasswordValidator, error) { + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: 10 * 10000, + MaxCost: 60 * 10000, // BCrypt hash size is 60 bytes + BufferItems: 64, + IgnoreInternalCost: true, + }) + // sanity check - this should never happen unless above configuration variables are invalid + if err != nil { + return nil, errors.Wrap(err, "error while setting up validator cache") + } return &DefaultPasswordValidator{ Client: httpx.NewResilientClient(httpx.ResilientClientWithConnectionTimeout(time.Second)), reg: reg, - hashes: map[string]int64{}, - minIdentifierPasswordDist: 5, maxIdentifierPasswordSubstrThreshold: 0.5} + hashes: cache, + minIdentifierPasswordDist: 5, maxIdentifierPasswordSubstrThreshold: 0.5}, nil } func b20(src []byte) string { @@ -109,9 +120,7 @@ func (s *DefaultPasswordValidator) fetch(hpw []byte, apiDNSName string) error { return errors.Wrapf(ErrUnexpectedStatusCode, "%d", res.StatusCode) } - s.Lock() - s.hashes[b20(hpw)] = 0 - s.Unlock() + s.hashes.SetWithTTL(b20(hpw), 0, 1, hashCacheItemTTL) sc := bufio.NewScanner(res.Body) for sc.Scan() { @@ -130,9 +139,7 @@ func (s *DefaultPasswordValidator) fetch(hpw []byte, apiDNSName string) error { } } - s.Lock() - s.hashes[(prefix + result[0])] = count - s.Unlock() + s.hashes.SetWithTTL(prefix+result[0], count, 1, hashCacheItemTTL) } if err := sc.Err(); err != nil { @@ -169,10 +176,7 @@ func (s *DefaultPasswordValidator) Validate(ctx context.Context, identifier, pas } hpw := h.Sum(nil) - s.RLock() - c, ok := s.hashes[b20(hpw)] - s.RUnlock() - + c, ok := s.hashes.Get(b20(hpw)) if !ok { err := s.fetch(hpw, passwordPolicyConfig.HaveIBeenPwnedHost) if (errors.Is(err, ErrNetworkFailure) || errors.Is(err, ErrUnexpectedStatusCode)) && passwordPolicyConfig.IgnoreNetworkErrors { @@ -184,7 +188,8 @@ func (s *DefaultPasswordValidator) Validate(ctx context.Context, identifier, pas return s.Validate(ctx, identifier, password) } - if c > int64(s.reg.Config(ctx).PasswordPolicyConfig().MaxBreaches) { + v, ok := c.(int64) + if ok && v > int64(s.reg.Config(ctx).PasswordPolicyConfig().MaxBreaches) { return errors.New("the password has been found in data breaches and must no longer be used") } diff --git a/selfservice/strategy/password/validator_test.go b/selfservice/strategy/password/validator_test.go index 13577fad81c..1430ca00d76 100644 --- a/selfservice/strategy/password/validator_test.go +++ b/selfservice/strategy/password/validator_test.go @@ -28,7 +28,7 @@ func TestDefaultPasswordValidationStrategy(t *testing.T) { t.Run("default strategy", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - s := password.NewDefaultPasswordValidatorStrategy(reg) + s, _ := password.NewDefaultPasswordValidatorStrategy(reg) for k, tc := range []struct { id string pw string @@ -80,7 +80,7 @@ func TestDefaultPasswordValidationStrategy(t *testing.T) { t.Run("failure cases", func(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) - s := password.NewDefaultPasswordValidatorStrategy(reg) + s, _ := password.NewDefaultPasswordValidatorStrategy(reg) fakeClient := NewFakeHTTPClient() s.Client = httpx.NewResilientClient(httpx.ResilientClientWithClient(&fakeClient.Client), httpx.ResilientClientWithMaxRetry(1), httpx.ResilientClientWithConnectionTimeout(time.Millisecond), httpx.ResilientClientWithMaxRetryWait(time.Millisecond)) @@ -117,7 +117,7 @@ func TestDefaultPasswordValidationStrategy(t *testing.T) { t.Run("max breaches", func(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) - s := password.NewDefaultPasswordValidatorStrategy(reg) + s, _ := password.NewDefaultPasswordValidatorStrategy(reg) fakeClient := NewFakeHTTPClient() s.Client = httpx.NewResilientClient(httpx.ResilientClientWithClient(&fakeClient.Client), httpx.ResilientClientWithMaxRetry(1), httpx.ResilientClientWithConnectionTimeout(time.Millisecond)) @@ -194,7 +194,7 @@ func TestChangeHaveIBeenPwnedValidationHost(t *testing.T) { testServer.StartTLS() testServerURL, _ := url.Parse(testServer.URL) conf, reg := internal.NewFastRegistryWithMocks(t) - s := password.NewDefaultPasswordValidatorStrategy(reg) + s, _ := password.NewDefaultPasswordValidatorStrategy(reg) conf.MustSet(config.ViperKeyPasswordHaveIBeenPwnedHost, testServerURL.Host) fakeClient := NewFakeHTTPClient() @@ -211,7 +211,7 @@ func TestChangeHaveIBeenPwnedValidationHost(t *testing.T) { func TestDisableHaveIBeenPwnedValidationHost(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) - s := password.NewDefaultPasswordValidatorStrategy(reg) + s, _ := password.NewDefaultPasswordValidatorStrategy(reg) conf.MustSet(config.ViperKeyPasswordHaveIBeenPwnedEnabled, false) fakeClient := NewFakeHTTPClient() @@ -225,7 +225,7 @@ func TestDisableHaveIBeenPwnedValidationHost(t *testing.T) { func TestChangeMinPasswordLength(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) - s := password.NewDefaultPasswordValidatorStrategy(reg) + s, _ := password.NewDefaultPasswordValidatorStrategy(reg) conf.MustSet(config.ViperKeyPasswordMinLength, 10) t.Run("case=should not fail if password is longer than min length", func(t *testing.T) { @@ -239,7 +239,7 @@ func TestChangeMinPasswordLength(t *testing.T) { func TestChangeIdentifierSimilarityCheckEnabled(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) - s := password.NewDefaultPasswordValidatorStrategy(reg) + s, _ := password.NewDefaultPasswordValidatorStrategy(reg) t.Run("case=should not fail if password is similar to identifier", func(t *testing.T) { conf.MustSet(config.ViperKeyPasswordIdentifierSimilarityCheckEnabled, false)