Skip to content

Commit

Permalink
fix: improve tests and pop adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Oct 20, 2020
1 parent c4438b0 commit 1354611
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 138 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ require (
github.com/go-openapi/strfmt v0.19.4
github.com/go-openapi/swag v0.19.7
github.com/go-openapi/validate v0.19.6
github.com/go-sql-driver/mysql v1.5.0
github.com/go-swagger/go-swagger v0.22.1-0.20200306221957-4aad3a5f78b8
github.com/gobuffalo/packr v1.24.0 // indirect
github.com/gobuffalo/packr/v2 v2.8.0
Expand Down
3 changes: 3 additions & 0 deletions jwk/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"net/http/httptest"
"testing"

"github.com/ory/hydra/jwk"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
jose "gopkg.in/square/go-jose.v2"
Expand All @@ -45,6 +47,7 @@ func TestHandlerWellKnown(t *testing.T) {
viper.Set(configuration.ViperKeyWellKnownKeys, []string{x.OpenIDConnectKeyName, x.OpenIDConnectKeyName})

router := x.NewRouterPublic()
var testGenerator = &jwk.RS256Generator{}
IDKS, _ := testGenerator.Generate("test-id", "sig")

h := reg.KeyHandler()
Expand Down
67 changes: 0 additions & 67 deletions jwk/manager_test.go

This file was deleted.

120 changes: 60 additions & 60 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,55 +57,6 @@ func (r OAuth2RequestSQL) TableName() string {
return "hydra_oauth2_" + string(r.Table)
}

func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (*fosite.Request, error) {
sess := r.Session
if !gjson.ValidBytes(sess) {
var err error
sess, err = p.r.KeyCipher().Decrypt(string(sess))
if err != nil {
return nil, errors.WithStack(err)
}
}

if session != nil {
if err := json.Unmarshal(sess, session); err != nil {
return nil, errors.WithStack(err)
}
} else {
p.l.Debugf("Got an empty session in toRequest")
}

c, err := p.GetClient(ctx, r.Client)
if err != nil {
return nil, err
}

val, err := url.ParseQuery(r.Form)
if err != nil {
return nil, errors.WithStack(err)
}

return &fosite.Request{
ID: r.Request,
RequestedAt: r.RequestedAt,
Client: c,
RequestedScope: stringsx.Splitx(r.Scopes, "|"),
GrantedScope: stringsx.Splitx(r.GrantedScope, "|"),
RequestedAudience: stringsx.Splitx(r.RequestedAudience, "|"),
GrantedAudience: stringsx.Splitx(r.GrantedAudience, "|"),
Form: val,
Session: session,
}, nil
}

// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk).
func (p *Persister) hashSignature(signature string, table tableName) string {
if table == sqlTableAccess && p.config.IsUsingJWTAsAccessTokens() {
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
}
return signature
}

func (p *Persister) sqlSchemaFromRequest(rawSignature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) {
subject := ""
if r.GetSession() == nil {
Expand Down Expand Up @@ -155,13 +106,53 @@ func (p *Persister) sqlSchemaFromRequest(rawSignature string, r fosite.Requester
}, nil
}

func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (*oauth2.BlacklistedJTI, error) {
jti := oauth2.NewBlacklistedJTI(j, time.Time{})
return jti, sqlcon.HandleError(p.Connection(ctx).Find(jti, jti.ID))
func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (*fosite.Request, error) {
sess := r.Session
if !gjson.ValidBytes(sess) {
var err error
sess, err = p.r.KeyCipher().Decrypt(string(sess))
if err != nil {
return nil, errors.WithStack(err)
}
}

if session != nil {
if err := json.Unmarshal(sess, session); err != nil {
return nil, errors.WithStack(err)
}
} else {
p.l.Debugf("Got an empty session in toRequest")
}

c, err := p.GetClient(ctx, r.Client)
if err != nil {
return nil, err
}

val, err := url.ParseQuery(r.Form)
if err != nil {
return nil, errors.WithStack(err)
}

return &fosite.Request{
ID: r.Request,
RequestedAt: r.RequestedAt,
Client: c,
RequestedScope: stringsx.Splitx(r.Scopes, "|"),
GrantedScope: stringsx.Splitx(r.GrantedScope, "|"),
RequestedAudience: stringsx.Splitx(r.RequestedAudience, "|"),
GrantedAudience: stringsx.Splitx(r.GrantedAudience, "|"),
Form: val,
Session: session,
}, nil
}

func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) error {
return sqlcon.HandleError(p.Connection(ctx).Create(jti))
// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk).
func (p *Persister) hashSignature(signature string, table tableName) string {
if table == sqlTableAccess && p.config.IsUsingJWTAsAccessTokens() {
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
}
return signature
}

func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) error {
Expand Down Expand Up @@ -203,17 +194,27 @@ func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp t
})
}

func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (*oauth2.BlacklistedJTI, error) {
jti := oauth2.NewBlacklistedJTI(j, time.Time{})
return jti, sqlcon.HandleError(p.Connection(ctx).Find(jti, jti.ID))
}

func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) error {
return sqlcon.HandleError(p.Connection(ctx).Create(jti))
}

func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) error {
req, err := p.sqlSchemaFromRequest(signature, requester, table)
if err != nil {
return err
}

err = sqlcon.HandleError(p.Connection(ctx).Create(req))
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
if err := sqlcon.HandleError(p.Connection(ctx).Create(req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
} else if err != nil {
return err
}
return err
return nil
}

func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (fosite.Requester, error) {
Expand Down Expand Up @@ -259,13 +260,12 @@ func (p *Persister) revokeSession(ctx context.Context, id string, table tableNam
if err := p.Connection(ctx).RawQuery(
fmt.Sprintf("DELETE FROM %s WHERE request_id=?", OAuth2RequestSQL{Table: table}.TableName()),
id,
).Exec(); err == sql.ErrNoRows {
).Exec(); errors.Is(err, sql.ErrNoRows) {
return errors.WithStack(fosite.ErrNotFound)
} else if err := sqlcon.HandleError(err); err != nil {
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
} else if strings.Contains(err.Error(), "Error 1213") {
p.l.Infof("got error 1213: %+v", err)
} else if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock?
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
}
return err
Expand Down
41 changes: 31 additions & 10 deletions persistence/sql/persister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/ory/hydra/jwk"

"github.com/ory/hydra/client"
"github.com/ory/hydra/consent"
"github.com/ory/hydra/driver"
Expand All @@ -25,20 +29,37 @@ func TestManagers(t *testing.T) {
}

for k, m := range registries {
t.Run("package=client", func(t *testing.T) {
t.Run("case=create-get-update-delete", func(t *testing.T) {
t.Run(fmt.Sprintf("db=%s", k), client.TestHelperCreateGetUpdateDeleteClient(k, m.ClientManager()))
})
t.Run("package=client/manager="+k, func(t *testing.T) {
t.Run("case=create-get-update-delete", client.TestHelperCreateGetUpdateDeleteClient(k, m.ClientManager()))

t.Run("case=autogenerate-key", func(t *testing.T) {
t.Run(fmt.Sprintf("db=%s", k), client.TestHelperClientAutoGenerateKey(k, m.ClientManager()))
})
t.Run("case=autogenerate-key", client.TestHelperClientAutoGenerateKey(k, m.ClientManager()))

t.Run("case=auth-client", func(t *testing.T) {
t.Run(fmt.Sprintf("db=%s", k), client.TestHelperClientAuthenticate(k, m.ClientManager()))
})
t.Run("case=auth-client", client.TestHelperClientAuthenticate(k, m.ClientManager()))
})

t.Run("package=consent/manager="+k, consent.ManagerTests(m.ConsentManager(), m.ClientManager(), m.OAuth2Storage()))

t.Run("package=jwk/manager="+k, func(t *testing.T) {
var testGenerator = &jwk.RS256Generator{}

t.Run("TestManagerKey", func(t *testing.T) {
ks, err := testGenerator.Generate("TestManagerKey", "sig")
require.NoError(t, err)

for name, r := range registries {
t.Run(fmt.Sprintf("case=%s", name), jwk.TestHelperManagerKey(r.KeyManager(), ks, "TestManagerKey"))
}
})

t.Run("TestManagerKeySet", func(t *testing.T) {
ks, err := testGenerator.Generate("TestManagerKeySet", "sig")
require.NoError(t, err)
ks.Key("private")

for name, r := range registries {
t.Run(fmt.Sprintf("case=%s", name), jwk.TestHelperManagerKeySet(r.KeyManager(), ks, "TestManagerKeySet"))
}
})
})
}
}

0 comments on commit 1354611

Please sign in to comment.