Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor request params to use generics #1464

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,12 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
}

func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {
params := AdminUserParams{}

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &AdminUserParams{})
if err != nil {
return nil, badRequestError("Could not read body").WithInternalError(err)
return nil, err
}

if err := json.Unmarshal(body, &params); err != nil {
return nil, badRequestError("Could not decode admin user params: %v", err)
}

return &params, nil
return params, nil
}

// adminUsers responds with a list of all users in a given audience
Expand Down Expand Up @@ -564,14 +558,9 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
factor := getFactor(ctx)
user := getUser(ctx)
adminUser := getAdminUser(ctx)
params := &adminUserUpdateFactorParams{}
body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &adminUserUpdateFactorParams{})
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read factor update params: %v", err)
return err
}

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand Down
11 changes: 11 additions & 0 deletions internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,14 @@ func isStringInSlice(checkValue string, list []string) bool {
func getBodyBytes(req *http.Request) ([]byte, error) {
return utilities.GetBodyBytes(req)
}

func retrieveRequestParams[A any](r *http.Request, params *A) (*A, error) {
body, err := getBodyBytes(r)
if err != nil {
return nil, internalServerError("Could not read body into byte slice").WithInternalError(err)
}
if err := json.Unmarshal(body, params); err != nil {
return nil, badRequestError("Could not read request body: %v", err)
}
return params, nil
}
11 changes: 2 additions & 9 deletions internal/api/invite.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"

"github.com/fatih/structs"
Expand All @@ -23,15 +22,9 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)
config := a.config
adminUser := getAdminUser(ctx)
params := &InviteParams{}

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &InviteParams{})
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read Invite params: %v", err)
return err
}

params.Email, err = validateEmail(params.Email)
Expand Down
11 changes: 2 additions & 9 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -48,15 +47,9 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
config := a.config
mailer := a.Mailer(ctx)
adminUser := getAdminUser(ctx)
params := &GenerateLinkParams{}

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &GenerateLinkParams{})
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not parse JSON: %v", err)
return err
}

params.Email, err = validateEmail(params.Email)
Expand Down
24 changes: 6 additions & 18 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -65,16 +64,11 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error {
session := getSession(ctx)
config := a.config

params := &EnrollFactorParams{}
issuer := ""
body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &EnrollFactorParams{})
if err != nil {
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("invalid body: unable to parse JSON").WithInternalError(err)
return err
}
issuer := ""

if params.FactorType != models.TOTP {
return badRequestError("factor_type needs to be totp")
Expand Down Expand Up @@ -205,17 +199,11 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
factor := getFactor(ctx)
config := a.config

params := &VerifyFactorParams{}
currentIP := utilities.GetIPAddress(r)

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &VerifyFactorParams{})
if err != nil {
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("invalid body: unable to parse JSON").WithInternalError(err)
return err
}
currentIP := utilities.GetIPAddress(r)

if !factor.IsOwnedBy(user) {
return internalServerError(InvalidFactorOwnerErrorMessage)
Expand Down
12 changes: 4 additions & 8 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,18 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {

if shouldRateLimitEmail || shouldRateLimitPhone {
if req.Method == "PUT" || req.Method == "POST" {
bodyBytes, err := getBodyBytes(req)
if err != nil {
return c, internalServerError("Error invalid request body").WithInternalError(err)
}

var requestBody struct {
Email string `json:"email"`
Phone string `json:"phone"`
}

if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
params, err := retrieveRequestParams(req, &requestBody)
if err != nil {
return c, badRequestError("Error invalid request body").WithInternalError(err)
}

if shouldRateLimitEmail {
if requestBody.Email != "" {
if params.Email != "" {
if err := tollbooth.LimitByKeys(emailLimiter, []string{"email_functions"}); err != nil {
emailRateLimitCounter.Add(
req.Context(),
Expand All @@ -123,7 +119,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
}

if shouldRateLimitPhone {
if requestBody.Phone != "" {
if params.Phone != "" {
if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil {
return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded")
}
Expand Down
16 changes: 4 additions & 12 deletions internal/api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,19 @@ func (p *SmsParams) Validate(smsProvider string) error {

// Otp returns the MagicLink or SmsOtp handler based on the request body params
func (a *API) Otp(w http.ResponseWriter, r *http.Request) error {
var err error
params := &OtpParams{
CreateUser: true,
}
if params.Data == nil {
params.Data = make(map[string]interface{})
}

body, err := getBodyBytes(r)
params, err = retrieveRequestParams(r, params)
if err != nil {
return err
}

if err = json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read verification params: %v", err)
}

if err := params.Validate(); err != nil {
return err
}
Expand Down Expand Up @@ -114,16 +111,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
}
var err error

params := &SmsParams{}

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &SmsParams{})
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return err
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read sms otp params: %v", err)
}
// For backwards compatibility, we default to SMS if params Channel is not specified
if params.Phone != "" && params.Channel == "" {
params.Channel = sms_provider.SMSProvider
Expand Down
11 changes: 2 additions & 9 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"errors"
"net/http"

Expand Down Expand Up @@ -36,15 +35,9 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
params := &RecoverParams{}

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &RecoverParams{})
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read verification params: %v", err)
return err
}

flowType := getFlowFromChallenge(params.CodeChallenge)
Expand Down
11 changes: 2 additions & 9 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"errors"
"net/http"
"time"
Expand Down Expand Up @@ -67,15 +66,9 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config
params := &ResendConfirmationParams{}

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &ResendConfirmationParams{})
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read params: %v", err)
return err
}

if err := params.Validate(config); err != nil {
Expand Down
9 changes: 2 additions & 7 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
Expand Down Expand Up @@ -109,13 +108,9 @@ func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err
}

func retrieveSignupParams(r *http.Request) (*SignupParams, error) {
params := &SignupParams{}
body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &SignupParams{})
if err != nil {
return nil, internalServerError("Could not read body").WithInternalError(err)
}
if err := json.Unmarshal(body, params); err != nil {
return nil, badRequestError("Could not read Signup params: %v", err)
return nil, err
}
return params, nil
}
Expand Down
11 changes: 2 additions & 9 deletions internal/api/sso.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"

"github.com/crewjam/saml"
Expand Down Expand Up @@ -41,15 +40,9 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &SingleSignOnParams{})
if err != nil {
return internalServerError("Unable to read request body").WithInternalError(err)
}

var params SingleSignOnParams

if err := json.Unmarshal(body, &params); err != nil {
return badRequestError("Unable to parse request body as JSON").WithInternalError(err)
return err
}

hasProviderID := false
Expand Down
19 changes: 4 additions & 15 deletions internal/api/ssoadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -184,14 +183,9 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er
ctx := r.Context()
db := a.db.WithContext(ctx)

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &CreateSSOProviderParams{})
if err != nil {
return internalServerError("Unable to read request body").WithInternalError(err)
}

var params CreateSSOProviderParams
if err := json.Unmarshal(body, &params); err != nil {
return badRequestError("Unable to parse JSON").WithInternalError(err)
return err
}

if err := params.validate(false /* <- forUpdate */); err != nil {
Expand Down Expand Up @@ -264,14 +258,9 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er
ctx := r.Context()
db := a.db.WithContext(ctx)

body, err := getBodyBytes(r)
params, err := retrieveRequestParams(r, &CreateSSOProviderParams{})
if err != nil {
return internalServerError("Unable to read request body").WithInternalError(err)
}

var params CreateSSOProviderParams
if err := json.Unmarshal(body, &params); err != nil {
return badRequestError("Unable to parse JSON").WithInternalError(err)
return err
}

if err := params.validate(true /* <- forUpdate */); err != nil {
Expand Down
Loading
Loading