From f3020d30fbf3f5e260e8930e20dff5f799f01036 Mon Sep 17 00:00:00 2001 From: Ajay Kelkar Date: Fri, 10 Feb 2023 17:29:19 +0100 Subject: [PATCH] refactor: persister query and setup more test data --- identity/test/pool.go | 43 ++++++++--- persistence/sql/persister_identity.go | 107 ++++++++++++++------------ 2 files changed, 90 insertions(+), 60 deletions(-) diff --git a/identity/test/pool.go b/identity/test/pool.go index 63cf4a45d90..07975852aeb 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -594,21 +594,46 @@ func TestPool(ctx context.Context, conf *config.Config, p interface { }) t.Run("case=find identity by its credentials identifier", func(t *testing.T) { - expected := passwordIdentity("", "find-identity-by-identifier@ory.sh") - expected.Traits = identity.Traits(`{}`) + var expectedIdentifiers []string + var expectedIdentities []*identity.Identity - require.NoError(t, p.CreateIdentity(ctx, expected)) - createdIDs = append(createdIDs, expected.ID) + for _, c := range []identity.CredentialsType{ + identity.CredentialsTypePassword, + identity.CredentialsTypeOIDC, + identity.CredentialsTypeWebAuthn, + } { + identityIdentifier := fmt.Sprintf("find-identity-by-identifier-%s@ory.sh", c) + expected := identity.NewIdentity("") + expected.SetCredentials(c, identity.Credentials{Type: c, Identifiers: []string{identityIdentifier}, Config: sqlxx.JSONRawMessage(`{}`)}) - actual, err := p.FindByCredentialsIdentifier(ctx, "find-identity-by-identifier@ory.sh") - require.NoError(t, err) + require.NoError(t, p.CreateIdentity(ctx, expected)) + createdIDs = append(createdIDs, expected.ID) + expectedIdentifiers = append(expectedIdentifiers, identityIdentifier) + expectedIdentities = append(expectedIdentities, expected) + } - expected.Credentials = nil - assertEqual(t, expected, actual) + for c, ct := range []identity.CredentialsType{ + identity.CredentialsTypePassword, + identity.CredentialsTypeOIDC, + identity.CredentialsTypeWebAuthn, + } { + t.Run(ct.String(), func(t *testing.T) { + actual, err := p.FindByCredentialsIdentifier(ctx, expectedIdentifiers[c]) + require.NoError(t, err) + + expected := expectedIdentities[c] + assert.EqualValues(t, expected.Credentials[ct].ID, actual.Credentials[ct].ID) + assert.EqualValues(t, expected.Credentials[ct].Identifiers, actual.Credentials[ct].Identifiers) + + expected.Credentials = nil + actual.Credentials = nil + assertEqual(t, expected, actual) + }) + } t.Run("not if on another network", func(t *testing.T) { _, p := testhelpers.NewNetwork(t, ctx, p) - _, err := p.FindByCredentialsIdentifier(ctx, "find-identity-by-identifier@ory.sh") + _, err := p.FindByCredentialsIdentifier(ctx, "find-identity-by-identifier-password@ory.sh") require.ErrorIs(t, err, sqlcon.ErrNoRows) }) }) diff --git a/persistence/sql/persister_identity.go b/persistence/sql/persister_identity.go index 2ab44685412..8d6fc15ba5f 100644 --- a/persistence/sql/persister_identity.go +++ b/persistence/sql/persister_identity.go @@ -87,31 +87,7 @@ func (p *Persister) FindByCredentialsIdentifier(ctx context.Context, match strin ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindByCredentialsIdentifier") defer span.End() - nid := p.NetworkID(ctx) - - var find struct { - IdentityID uuid.UUID `db:"identity_id"` - } - - // #nosec G201 - if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT - ic.identity_id -FROM %s ic - INNER JOIN %s ici on ic.id = ici.identity_credential_id -WHERE ici.identifier = ? - AND ic.nid = ? - AND ici.nid = ?`, - "identity_credentials", - "identity_credential_identifiers", - ), - match, - nid, - nid, - ).First(&find); err != nil { - return nil, sqlcon.HandleError(err) - } - - i, err := p.GetIdentity(ctx, find.IdentityID, identity.ExpandDefault) + i, err := p.findIdentityByIdentifier(ctx, nil, match) if err != nil { return nil, err } @@ -123,17 +99,32 @@ func (p *Persister) FindByCredentialsTypeAndIdentifier(ctx context.Context, ct i ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindByCredentialsTypeAndIdentifier") defer span.End() - nid := p.NetworkID(ctx) + // Force case-insensitivity and trimming for identifiers + match = p.normalizeIdentifier(ct, match) + + i, err := p.findIdentityByIdentifier(ctx, &ct, match) + if err != nil { + return nil, nil, err + } + + creds, ok := i.GetCredentials(ct) + if !ok { + return nil, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type \"%s\". This is a bug in the code.", ct)) + } + return i.CopyWithoutCredentials(), creds, nil +} + +func (p *Persister) findIdentityByIdentifier(ctx context.Context, ct *identity.CredentialsType, match string) (*identity.Identity, error) { var find struct { IdentityID uuid.UUID `db:"identity_id"` } - // Force case-insensitivity and trimming for identifiers - match = p.normalizeIdentifier(ct, match) + nid := p.NetworkID(ctx) - // #nosec G201 - if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT + if ct != nil { + // #nosec G201 + if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT ic.identity_id FROM %s ic INNER JOIN %s ict on ic.identity_credential_type_id = ict.id @@ -142,33 +133,47 @@ WHERE ici.identifier = ? AND ic.nid = ? AND ici.nid = ? AND ict.name = ?`, - "identity_credentials", - "identity_credential_types", - "identity_credential_identifiers", - ), - match, - nid, - nid, - ct, - ).First(&find); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil, sqlcon.HandleError(err) // herodot.ErrNotFound.WithTrace(err).WithReasonf(`No identity matching credentials identifier "%s" could be found.`, match) - } + "identity_credentials", + "identity_credential_types", + "identity_credential_identifiers", + ), + match, + nid, + nid, + ct, + ).First(&find); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, sqlcon.HandleError(err) // herodot.ErrNotFound.WithTrace(err).WithReasonf(`No identity matching credentials identifier "%s" could be found.`, match) + } - return nil, nil, sqlcon.HandleError(err) + return nil, sqlcon.HandleError(err) + } + } else { + // #nosec G201 + if err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(`SELECT + ic.identity_id +FROM %s ic + INNER JOIN %s ici on ic.id = ici.identity_credential_id +WHERE ici.identifier = ? + AND ic.nid = ? + AND ici.nid = ?`, + "identity_credentials", + "identity_credential_identifiers", + ), + match, + nid, + nid, + ).First(&find); err != nil { + return nil, sqlcon.HandleError(err) + } } - i, err := p.GetIdentityConfidential(ctx, find.IdentityID) + i, err := p.GetIdentity(ctx, find.IdentityID, identity.ExpandEverything) if err != nil { - return nil, nil, err - } - - creds, ok := i.GetCredentials(ct) - if !ok { - return nil, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type \"%s\". This is a bug in the code.", ct)) + return nil, err } - return i.CopyWithoutCredentials(), creds, nil + return i, nil } func (p *Persister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (*identity.CredentialsTypeTable, error) {