-
-
Notifications
You must be signed in to change notification settings - Fork 961
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: unreliable HIBP caching strategy #2468
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ package password | |
import ( | ||
"bufio" | ||
"context" | ||
stderrs "errors" | ||
|
||
/* #nosec G505 sha1 is used for k-anonymity */ | ||
"crypto/sha1" | ||
|
@@ -37,9 +38,12 @@ type ValidationProvider interface { | |
PasswordValidator() Validator | ||
} | ||
|
||
var _ Validator = new(DefaultPasswordValidator) | ||
var ErrNetworkFailure = errors.New("unable to check if password has been leaked because an unexpected network error occurred") | ||
var ErrUnexpectedStatusCode = errors.New("unexpected status code") | ||
var ( | ||
_ Validator = new(DefaultPasswordValidator) | ||
ErrNetworkFailure = stderrs.New("unable to check if password has been leaked because an unexpected network error occurred") | ||
ErrUnexpectedStatusCode = stderrs.New("unexpected status code") | ||
ErrTooManyBreaches = stderrs.New("the password has been found in data breaches and must no longer be used") | ||
) | ||
|
||
// DefaultPasswordValidator implements Validator. It is based on best | ||
// practices as defined in the following blog posts: | ||
|
@@ -107,20 +111,20 @@ func lcsLength(a, b string) int { | |
return greatestLength | ||
} | ||
|
||
func (s *DefaultPasswordValidator) fetch(hpw []byte, apiDNSName string) error { | ||
func (s *DefaultPasswordValidator) fetch(hpw []byte, apiDNSName string) (int64, error) { | ||
prefix := fmt.Sprintf("%X", hpw)[0:5] | ||
loc := fmt.Sprintf("https://%s/range/%s", apiDNSName, prefix) | ||
res, err := s.Client.Get(loc) | ||
if err != nil { | ||
return errors.Wrapf(ErrNetworkFailure, "%s", err) | ||
return 0, errors.Wrapf(ErrNetworkFailure, "%s", err) | ||
} | ||
defer res.Body.Close() | ||
|
||
if res.StatusCode != http.StatusOK { | ||
return errors.Wrapf(ErrUnexpectedStatusCode, "%d", res.StatusCode) | ||
return 0, errors.Wrapf(ErrUnexpectedStatusCode, "%d", res.StatusCode) | ||
} | ||
|
||
s.hashes.SetWithTTL(b20(hpw), 0, 1, hashCacheItemTTL) | ||
var thisCount int64 | ||
|
||
sc := bufio.NewScanner(res.Body) | ||
for sc.Scan() { | ||
|
@@ -135,18 +139,22 @@ func (s *DefaultPasswordValidator) fetch(hpw []byte, apiDNSName string) error { | |
if len(result) == 2 { | ||
count, err = strconv.ParseInt(result[1], 10, 64) | ||
if err != nil { | ||
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected password hash to contain a count formatted as int but got: %s", result[1])) | ||
return 0, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected password hash to contain a count formatted as int but got: %s", result[1])) | ||
} | ||
} | ||
|
||
s.hashes.SetWithTTL(prefix+result[0], count, 1, hashCacheItemTTL) | ||
if prefix+result[0] == b20(hpw) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is this check good for? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
thisCount = count | ||
} | ||
} | ||
|
||
if err := sc.Err(); err != nil { | ||
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to initialize string scanner: %s", err)) | ||
return 0, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to initialize string scanner: %s", err)) | ||
} | ||
|
||
return nil | ||
s.hashes.SetWithTTL(b20(hpw), thisCount, 1, hashCacheItemTTL) | ||
return thisCount, nil | ||
} | ||
|
||
func (s *DefaultPasswordValidator) Validate(ctx context.Context, identifier, password string) error { | ||
|
@@ -178,19 +186,18 @@ func (s *DefaultPasswordValidator) Validate(ctx context.Context, identifier, pas | |
|
||
c, ok := s.hashes.Get(b20(hpw)) | ||
if !ok { | ||
err := s.fetch(hpw, passwordPolicyConfig.HaveIBeenPwnedHost) | ||
var err error | ||
c, err = s.fetch(hpw, passwordPolicyConfig.HaveIBeenPwnedHost) | ||
if (errors.Is(err, ErrNetworkFailure) || errors.Is(err, ErrUnexpectedStatusCode)) && passwordPolicyConfig.IgnoreNetworkErrors { | ||
return nil | ||
} else if err != nil { | ||
return err | ||
} | ||
|
||
return s.Validate(ctx, identifier, password) | ||
} | ||
|
||
v, ok := c.(int64) | ||
if ok && v > int64(s.reg.Config(ctx).PasswordPolicyConfig().MaxBreaches) { | ||
return errors.New("the password has been found in data breaches and must no longer be used") | ||
return errors.WithStack(ErrTooManyBreaches) | ||
} | ||
|
||
return nil | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ package password_test | |
import ( | ||
"bytes" | ||
"context" | ||
"crypto/rand" | ||
"crypto/sha1" | ||
"errors" | ||
"fmt" | ||
"io/ioutil" | ||
|
@@ -12,6 +14,10 @@ import ( | |
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
|
||
"github.com/ory/herodot" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/ory/x/httpx" | ||
|
@@ -117,73 +123,111 @@ func TestDefaultPasswordValidationStrategy(t *testing.T) { | |
|
||
t.Run("max breaches", func(t *testing.T) { | ||
conf, reg := internal.NewFastRegistryWithMocks(t) | ||
s, _ := password.NewDefaultPasswordValidatorStrategy(reg) | ||
s, err := password.NewDefaultPasswordValidatorStrategy(reg) | ||
require.NoError(t, err) | ||
|
||
hibpResp := make(chan string, 1) | ||
fakeClient := NewFakeHTTPClient() | ||
fakeClient.responder = func(req *http.Request) (*http.Response, error) { | ||
buffer := bytes.NewBufferString(<-hibpResp) | ||
return &http.Response{ | ||
StatusCode: http.StatusOK, | ||
Body: ioutil.NopCloser(buffer), | ||
ContentLength: int64(buffer.Len()), | ||
Request: req, | ||
}, nil | ||
} | ||
s.Client = httpx.NewResilientClient(httpx.ResilientClientWithClient(&fakeClient.Client), httpx.ResilientClientWithMaxRetry(1), httpx.ResilientClientWithConnectionTimeout(time.Millisecond)) | ||
|
||
var hashPw = func(t *testing.T, pw string) string { | ||
/* #nosec G401 sha1 is used for k-anonymity */ | ||
h := sha1.New() | ||
_, err := h.Write([]byte(pw)) | ||
require.NoError(t, err) | ||
hpw := h.Sum(nil) | ||
return fmt.Sprintf("%X", hpw)[5:] | ||
} | ||
randomPassword := func(t *testing.T) string { | ||
pw := make([]byte, 10) | ||
_, err := rand.Read(pw) | ||
require.NoError(t, err) | ||
return fmt.Sprintf("%x", pw) | ||
} | ||
|
||
conf.MustSet(config.ViperKeyPasswordMaxBreaches, 5) | ||
for _, tc := range []struct { | ||
cs string | ||
pw string | ||
res string | ||
pass bool | ||
name string | ||
res func(t *testing.T, hash string) string | ||
expectErr error | ||
}{ | ||
{ | ||
cs: "contains invalid data which is ignored", | ||
pw: "lufsokpugo", | ||
res: "0225BDB8F106B1B4A5DF4C31B80AC695874:2\ninvalid", | ||
pass: true, | ||
name: "contains invalid data which is ignored", | ||
res: func(t *testing.T, hash string) string { | ||
return fmt.Sprintf("%s:2\ninvalid", hash) | ||
}, | ||
}, | ||
{ | ||
cs: "is missing a colon", | ||
pw: "lufsokpugo", | ||
res: "0225BDB8F106B1B4A5DF4C31B80AC695874", | ||
pass: true, | ||
name: "is missing a colon", | ||
res: func(t *testing.T, hash string) string { | ||
return hash | ||
}, | ||
}, | ||
{ | ||
cs: "contains invalid hash count", | ||
pw: "gimekvizec", | ||
res: "0248B3D6077106761CC84F4B9CF680C6D84:text\n1A34C526A9D14832C6ACFEAE90261ED78F8:2", | ||
pass: false, | ||
name: "contains invalid hash count", | ||
res: func(t *testing.T, hash string) string { | ||
return fmt.Sprintf("%s:text\n%s:2", hashPw(t, randomPassword(t)), hash) | ||
}, | ||
expectErr: herodot.ErrInternalServerError, | ||
}, | ||
{ | ||
cs: "is missing hash count", | ||
pw: "bofulosasm", | ||
res: "1D29CF237A57F6FEA8F29E8D907DCF1EBBA\n026364A8EE59DEDCF9E2DC80B9D7BAB7389:2", | ||
pass: true, | ||
name: "is missing hash count", | ||
res: func(t *testing.T, hash string) string { | ||
return fmt.Sprintf("%s\n%s:2", hash, hashPw(t, randomPassword(t))) | ||
}, | ||
}, | ||
{ | ||
cs: "response contains no matches", | ||
pw: "lizrafakha", | ||
res: "0D6CF6289C9CA71B47D2167EB7FE89690E7:57", | ||
pass: true, | ||
name: "response contains no matches", | ||
res: func(t *testing.T, hash string) string { | ||
return fmt.Sprintf("%s:57", hashPw(t, randomPassword(t))) | ||
}, | ||
}, | ||
{ | ||
cs: "contains less than maxBreachesThreshold", | ||
pw: "tafpabdopa", | ||
res: fmt.Sprintf("280915F3B572F94217D86F1D63BED53F66A:%d\n0F76A7D21E7C3E653E98236897AD7888937:%d", conf.PasswordPolicyConfig().MaxBreaches, conf.PasswordPolicyConfig().MaxBreaches+1), | ||
pass: true, | ||
name: "contains less than maxBreachesThreshold", | ||
res: func(t *testing.T, hash string) string { | ||
return fmt.Sprintf( | ||
"%s:%d\n%s:%d", | ||
hash, | ||
conf.PasswordPolicyConfig().MaxBreaches, | ||
hashPw(t, randomPassword(t)), | ||
conf.PasswordPolicyConfig().MaxBreaches+1, | ||
) | ||
}, | ||
}, | ||
{ | ||
cs: "contains more than maxBreachesThreshold", | ||
pw: "hicudsumla", | ||
res: fmt.Sprintf("5656812AA72561AAA6663E486A46D5711BE:%d", conf.PasswordPolicyConfig().MaxBreaches+1), | ||
pass: false, | ||
name: "contains more than maxBreachesThreshold", | ||
res: func(t *testing.T, hash string) string { | ||
return fmt.Sprintf("%s:%d", hash, conf.PasswordPolicyConfig().MaxBreaches+1) | ||
}, | ||
expectErr: password.ErrTooManyBreaches, | ||
}, | ||
} { | ||
fakeClient.RespondWith(http.StatusOK, tc.res) | ||
format := "case=should not fail if response %s" | ||
if !tc.pass { | ||
format = "case=should fail if response %s" | ||
} | ||
t.Run(fmt.Sprintf(format, tc.cs), func(t *testing.T) { | ||
err := s.Validate(context.Background(), "", tc.pw) | ||
if tc.pass { | ||
require.NoError(t, err) | ||
} else { | ||
require.Error(t, err) | ||
} | ||
t.Run(fmt.Sprintf("case=%s/expected err=%s", tc.name, tc.expectErr), func(t *testing.T) { | ||
pw := randomPassword(t) | ||
hash := hashPw(t, pw) | ||
hibpResp <- tc.res(t, hash) | ||
|
||
err := s.Validate(context.Background(), "", pw) | ||
assert.ErrorIs(t, err, tc.expectErr) | ||
}) | ||
|
||
// verify the fetch was done, i.e. channel is empty | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We had some fishy cases, therefore we ensure here the fetch was done. |
||
select { | ||
case r := <-hibpResp: | ||
t.Logf("expected the validate step to fetch the response, but I still got %s", r) | ||
t.FailNow() | ||
default: | ||
// continue | ||
} | ||
} | ||
}) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returning the count here does not rely on the cache having the value in the recursive run.