Skip to content

Commit

Permalink
fix: remove redundant queries to get session (#1204)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* We're querying for the session redundantly when we already have the
sessionId. Every time we query for a session, it fetches all the amr
claims tied to the session too.
* There's no need to query for the entire session when we already have
the session id, since the `auth.mfa_amr_claims` table already has a
foreign-key constraint on the `auth.sessions.id` column, the insert will
fail if the given `sessionId` doesn't exist
  • Loading branch information
kangmingtay authored Aug 1, 2023
1 parent 1802ff3 commit 669ce97
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 17 deletions.
15 changes: 3 additions & 12 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u
return internalServerError("Database error granting user").WithInternalError(terr)
}

session, terr := models.FindSessionByID(tx, *refreshToken.SessionId, false)
if terr != nil {
return terr
}
terr = models.AddClaimToSession(tx, session, authenticationMethod)
terr = models.AddClaimToSession(tx, *refreshToken.SessionId, authenticationMethod)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -361,15 +357,10 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err)
}
err = tx.Transaction(func(tx *storage.Connection) error {
session, terr := models.FindSessionByID(tx, sessionId, false)
if terr != nil {
if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil {
return terr
}
terr = models.AddClaimToSession(tx, session, authenticationMethod)
if terr != nil {
return terr
}
session, terr = models.FindSessionByID(tx, sessionId, false)
session, terr := models.FindSessionByID(tx, sessionId, false)
if terr != nil {
return terr
}
Expand Down
4 changes: 2 additions & 2 deletions internal/models/amr.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ func (AMRClaim) TableName() string {
return tableName
}

func AddClaimToSession(tx *storage.Connection, session *Session, authenticationMethod AuthenticationMethod) error {
func AddClaimToSession(tx *storage.Connection, sessionId uuid.UUID, authenticationMethod AuthenticationMethod) error {
id := uuid.Must(uuid.NewV4())

currentTime := time.Now()
return tx.RawQuery("INSERT INTO "+(&pop.Model{Value: AMRClaim{}}).TableName()+
`(id, session_id, created_at, updated_at, authentication_method) values (?, ?, ?, ?, ?)
ON CONFLICT ON CONSTRAINT mfa_amr_claims_session_id_authentication_method_pkey
DO UPDATE SET updated_at = ?;`, id, session.ID, currentTime, currentTime, authenticationMethod.String(), currentTime).Exec()
DO UPDATE SET updated_at = ?;`, id, sessionId, currentTime, currentTime, authenticationMethod.String(), currentTime).Exec()
}

func (a *AMRClaim) GetAuthenticationMethod() string {
Expand Down
6 changes: 3 additions & 3 deletions internal/models/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
session.UserID = u.ID
require.NoError(ts.T(), ts.db.Create(session))

err = AddClaimToSession(ts.db, session, PasswordGrant)
err = AddClaimToSession(ts.db, session.ID, PasswordGrant)
require.NoError(ts.T(), err)

firstClaimAddedTime := time.Now()
err = AddClaimToSession(ts.db, session, TOTPSignIn)
err = AddClaimToSession(ts.db, session.ID, TOTPSignIn)
require.NoError(ts.T(), err)
session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)
Expand All @@ -77,7 +77,7 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))

err = AddClaimToSession(ts.db, session, TOTPSignIn)
err = AddClaimToSession(ts.db, session.ID, TOTPSignIn)
require.NoError(ts.T(), err)

session, err = FindSessionByID(ts.db, session.ID, false)
Expand Down

0 comments on commit 669ce97

Please sign in to comment.