Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor request params to use generics #1464

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 7 additions & 18 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,12 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
}

func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {
params := AdminUserParams{}

body, err := getBodyBytes(r)
if err != nil {
return nil, badRequestError("Could not read body").WithInternalError(err)
params := &AdminUserParams{}
if err := retrieveRequestParams(r, params); err != nil {
return nil, err
}

if err := json.Unmarshal(body, &params); err != nil {
return nil, badRequestError("Could not decode admin user params: %v", err)
}

return &params, nil
return params, nil
}

// adminUsers responds with a list of all users in a given audience
Expand Down Expand Up @@ -565,16 +559,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
user := getUser(ctx)
adminUser := getAdminUser(ctx)
params := &adminUserUpdateFactorParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read factor update params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

err = a.db.Transaction(func(tx *storage.Connection) error {
err := a.db.Transaction(func(tx *storage.Connection) error {
if params.FriendlyName != "" {
if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil {
return terr
Expand Down
4 changes: 2 additions & 2 deletions internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
return forbiddenError("Signups not allowed for this instance")
}

params, err := retrieveSignupParams(r)
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
params.Aud = aud
Expand Down
4 changes: 2 additions & 2 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})
r.Post("/", func(w http.ResponseWriter, r *http.Request) error {
params, err := retrieveSignupParams(r)
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
if params.Email == "" && params.Phone == "" {
Expand Down
38 changes: 38 additions & 0 deletions internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,41 @@ func isStringInSlice(checkValue string, list []string) bool {
func getBodyBytes(req *http.Request) ([]byte, error) {
return utilities.GetBodyBytes(req)
}

type RequestParams interface {
AdminUserParams |
CreateSSOProviderParams |
EnrollFactorParams |
GenerateLinkParams |
IdTokenGrantParams |
InviteParams |
OtpParams |
PKCEGrantParams |
PasswordGrantParams |
RecoverParams |
RefreshTokenGrantParams |
ResendConfirmationParams |
SignupParams |
SingleSignOnParams |
SmsParams |
UserUpdateParams |
VerifyFactorParams |
VerifyParams |
adminUserUpdateFactorParams |
struct {
Email string `json:"email"`
Phone string `json:"phone"`
}
}

// retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided
func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error {
body, err := getBodyBytes(r)
if err != nil {
return internalServerError("Could not read body into byte slice").WithInternalError(err)
}
if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read request body: %v", err)
}
return nil
}
12 changes: 3 additions & 9 deletions internal/api/invite.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"

"github.com/fatih/structs"
Expand All @@ -24,16 +23,11 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {
config := a.config
adminUser := getAdminUser(ctx)
params := &InviteParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read Invite params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

var err error
params.Email, err = validateEmail(params.Email)
if err != nil {
return err
Expand Down
12 changes: 3 additions & 9 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -49,16 +48,11 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
mailer := a.Mailer(ctx)
adminUser := getAdminUser(ctx)
params := &GenerateLinkParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not parse JSON: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

var err error
params.Email, err = validateEmail(params.Email)
if err != nil {
return err
Expand Down
24 changes: 6 additions & 18 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -66,15 +65,10 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error {
config := a.config

params := &EnrollFactorParams{}
issuer := ""
body, err := getBodyBytes(r)
if err != nil {
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("invalid body: unable to parse JSON").WithInternalError(err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}
issuer := ""

if params.FactorType != models.TOTP {
return badRequestError("factor_type needs to be totp")
Expand Down Expand Up @@ -206,16 +200,10 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
config := a.config

params := &VerifyFactorParams{}
currentIP := utilities.GetIPAddress(r)

body, err := getBodyBytes(r)
if err != nil {
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("invalid body: unable to parse JSON").WithInternalError(err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}
currentIP := utilities.GetIPAddress(r)

if !factor.IsOwnedBy(user) {
return internalServerError(InvalidFactorOwnerErrorMessage)
Expand Down
7 changes: 1 addition & 6 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,12 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {

if shouldRateLimitEmail || shouldRateLimitPhone {
if req.Method == "PUT" || req.Method == "POST" {
bodyBytes, err := getBodyBytes(req)
if err != nil {
return c, internalServerError("Error invalid request body").WithInternalError(err)
}

var requestBody struct {
Email string `json:"email"`
Phone string `json:"phone"`
}

if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
if err := retrieveRequestParams(req, &requestBody); err != nil {
return c, badRequestError("Error invalid request body").WithInternalError(err)
}

Expand Down
16 changes: 3 additions & 13 deletions internal/api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,10 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error {
params.Data = make(map[string]interface{})
}

body, err := getBodyBytes(r)
if err != nil {
if err := retrieveRequestParams(r, params); err != nil {
return err
}

if err = json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read verification params: %v", err)
}

if err := params.Validate(); err != nil {
return err
}
Expand Down Expand Up @@ -115,15 +110,10 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
var err error

params := &SmsParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read sms otp params: %v", err)
}
// For backwards compatibility, we default to SMS if params Channel is not specified
if params.Phone != "" && params.Channel == "" {
params.Channel = sms_provider.SMSProvider
Expand Down
12 changes: 3 additions & 9 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"errors"
"net/http"

Expand Down Expand Up @@ -37,14 +36,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)
config := a.config
params := &RecoverParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read verification params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

flowType := getFlowFromChallenge(params.CodeChallenge)
Expand All @@ -53,6 +46,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}

var user *models.User
var err error
aud := a.requestAud(ctx, r)

user, err = models.FindUserByEmailAndAudience(db, params.Email, aud)
Expand Down
12 changes: 3 additions & 9 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"errors"
"net/http"
"time"
Expand Down Expand Up @@ -68,21 +67,16 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)
config := a.config
params := &ResendConfirmationParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

if err := params.Validate(config); err != nil {
return err
}

var user *models.User
var err error
aud := a.requestAud(ctx, r)
if params.Email != "" {
user, err = models.FindUserByEmailAndAudience(db, params.Email, aud)
Expand Down
18 changes: 3 additions & 15 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
Expand Down Expand Up @@ -108,18 +107,6 @@ func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err
return user, nil
}

func retrieveSignupParams(r *http.Request) (*SignupParams, error) {
params := &SignupParams{}
body, err := getBodyBytes(r)
if err != nil {
return nil, internalServerError("Could not read body").WithInternalError(err)
}
if err := json.Unmarshal(body, params); err != nil {
return nil, badRequestError("Could not read Signup params: %v", err)
}
return params, nil
}

// Signup is the endpoint for registering a new user
func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
Expand All @@ -130,8 +117,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return forbiddenError("Signups not allowed for this instance")
}

params, err := retrieveSignupParams(r)
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand All @@ -142,6 +129,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}

var codeChallengeMethod models.CodeChallengeMethod
var err error
flowType := getFlowFromChallenge(params.CodeChallenge)

if isPKCEFlow(flowType) {
Expand Down
Loading
Loading