Skip to content

Commit

Permalink
refactor: generic openapi authenticator
Browse files Browse the repository at this point in the history
  • Loading branch information
katallaxie committed May 6, 2024
1 parent bda82d8 commit ff11740
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 71 deletions.
19 changes: 18 additions & 1 deletion authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ const (
authzPrincipial contextKey = iota
authzObject
authzAction
authzAPIKey
authzContext
)

// Unimplemented is the default implementation.
Expand All @@ -74,6 +74,23 @@ func (u *Unimplemented) Allowed(_ context.Context, _ AuthzPrincipal, _ AuthzObje
return false, nil
}

var _ AuthzChecker = (*Fake)(nil)

// Fake is a fake authz checker.
type Fake struct {
allowd bool
}

// NewFake returns a new Fake authz checker.
func NewFake(allowed bool) *Fake {
return &Fake{allowd: allowed}
}

// Allowed returns true if the principal is allowed to perform the action on the object.
func (f *Fake) Allowed(_ context.Context, _ AuthzPrincipal, _ AuthzObject, _ AuthzAction) (bool, error) {
return f.allowd, nil
}

// Config ...
type Config struct {
// Next defines a function to skip this middleware when returned true.
Expand Down
51 changes: 51 additions & 0 deletions authz_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package authz

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

func TestUnimplemented(t *testing.T) {
t.Parallel()

checker := &Unimplemented{}
require.NotNil(t, checker)

allowed, err := checker.Allowed(context.TODO(), "principal", "object", "action")
require.NoError(t, err)
require.False(t, allowed)
}

func TestFakeChecker(t *testing.T) {
t.Parallel()

tests := []struct {
name string
allowed bool
expected bool
}{
{
name: "allowed",
allowed: true,
expected: true,
},
{
name: "not allowed",
allowed: false,
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
checker := NewFake(tt.allowed)
require.NotNil(t, checker)

allowed, err := checker.Allowed(context.TODO(), "principal", "object", "action")
require.NoError(t, err)
require.Equal(t, tt.expected, allowed)
})
}
}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module github.com/zeiss/fiber-authz

go 1.21.6
go 1.21.9

toolchain go1.22.2

require (
Expand Down
146 changes: 77 additions & 69 deletions openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,45 @@ import (
"context"
"errors"
"fmt"
"net/http"

"github.com/getkin/kin-openapi/openapi3filter"
"github.com/gofiber/fiber/v2"
middleware "github.com/oapi-codegen/fiber-middleware"
"gorm.io/gorm"
)

// ErrNoAuthzContext is the error returned when the context is not found.
var ErrNoAuthzContext = errors.New("no authz context")

// AuthzContext is the type of the context key.
type AuthzContext struct {
Principal AuthzPrincipal
Object AuthzObject
Action AuthzAction
}

// NewAuthzContext is the constructor for the AuthzContext.
func NewAuthzContext(principal AuthzPrincipal, object AuthzObject, action AuthzAction) AuthzContext {
return AuthzContext{
Principal: principal,
Object: object,
Action: action,
}
}

// AuthzExtractor is the interface that wraps the Extract method.
type AuthzExtractor func(c *fiber.Ctx) (AuthzPrincipal, AuthzObject, AuthzAction, error)

// DefaultAuthzExtractor is the default authz extractor.
func DefaultAuthzExtractor(c *fiber.Ctx) (AuthzPrincipal, AuthzObject, AuthzAction, error) {
return AuthzNoPrincipial, AuthzNoObject, AuthzNoAction, nil
}

// OpenAPIAuthenticatorOpts are the OpenAPI authenticator options.
type OpenAPIAuthenticatorOpts struct {
PathParam string
Checker AuthzChecker
AuthzPrincipalResolver AuthzPrincipalResolver
AuthzObjectResolver AuthzObjectResolver
AuthzActionResolver AuthzActionResolver
AuthzChecker AuthzChecker
}

// Conigure the OpenAPI authenticator.
Expand All @@ -31,22 +58,38 @@ type OpenAPIAuthenticatorOpt func(*OpenAPIAuthenticatorOpts)
// OpenAPIAuthenticatorDefaultOpts are the default OpenAPI authenticator options.
func OpenAPIAuthenticatorDefaultOpts() OpenAPIAuthenticatorOpts {
return OpenAPIAuthenticatorOpts{
PathParam: "teamId",
Checker: NewNoop(),
AuthzChecker: NewNoop(),
AuthzPrincipalResolver: NewNoopPrincipalResolver(),
AuthzObjectResolver: NewNoopObjectResolver(),
AuthzActionResolver: NewNoopActionResolver(),
}
}

// WithPathParam sets the path parameter.
func WithPathParam(param string) OpenAPIAuthenticatorOpt {
// WithAuthzPrincipalResolver sets the authz extractor.
func WithAuthzPrincipalResolver(resolver AuthzPrincipalResolver) OpenAPIAuthenticatorOpt {
return func(opts *OpenAPIAuthenticatorOpts) {
opts.PathParam = param
opts.AuthzPrincipalResolver = resolver
}
}

// WithChecker sets the authz checker.
func WithChecker(checker AuthzChecker) OpenAPIAuthenticatorOpt {
// WithAuthzObjectResolver sets the authz extractor.
func WithAuthzObjectResolver(resolver AuthzObjectResolver) OpenAPIAuthenticatorOpt {
return func(opts *OpenAPIAuthenticatorOpts) {
opts.Checker = checker
opts.AuthzObjectResolver = resolver
}
}

// WithAuthzActionResolver sets the authz extractor.
func WithAuthzActionResolver(resolver AuthzActionResolver) OpenAPIAuthenticatorOpt {
return func(opts *OpenAPIAuthenticatorOpts) {
opts.AuthzActionResolver = resolver
}
}

// WithAuthzChecker sets the authz checker.
func WithAuthzChecker(checker AuthzChecker) OpenAPIAuthenticatorOpt {
return func(opts *OpenAPIAuthenticatorOpts) {
opts.AuthzChecker = checker
}
}

Expand All @@ -63,41 +106,40 @@ func NewOpenAPIErrorHandler() middleware.ErrorHandler {
// NewOpenAPIAuthenticator creates a new OpenAPI authenticator.
func NewOpenAPIAuthenticator(opts ...OpenAPIAuthenticatorOpt) openapi3filter.AuthenticationFunc {
return func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
opt := OpenAPIAuthenticatorDefaultOpts()
opt.Conigure(opts...)
options := OpenAPIAuthenticatorDefaultOpts()
options.Conigure(opts...)

c := middleware.GetFiberContext(ctx)

teamId, ok := input.RequestValidationInput.PathParams[opt.PathParam]
if !ok {
return fiber.NewError(fiber.StatusBadRequest, fmt.Errorf("missing path parameter %s", opt.PathParam).Error())
principal, err := options.AuthzPrincipalResolver.Resolve(c)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, fmt.Errorf("error resolving principal: %w", err).Error())
}

key, err := GetAPIKeyFromRequest(input.RequestValidationInput.Request)
object, err := options.AuthzObjectResolver.Resolve(c)
if err != nil {
return err
return fiber.NewError(fiber.StatusBadRequest, fmt.Errorf("error resolving object: %w", err).Error())
}

err = validate.Var(key, "required,uuid")
action, err := options.AuthzActionResolver.Resolve(c)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "Invalid API key")
return fiber.NewError(fiber.StatusBadRequest, fmt.Errorf("error resolving action: %w", err).Error())
}

allowed := len(input.Scopes) == 0
if len(input.Scopes) > 0 {
allowed, err = opt.Checker.Allowed(ctx, AuthzPrincipal(key), AuthzObject(teamId), AuthzAction(input.Scopes[0]))
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "Internal Server Error")
}
allowed, err := options.AuthzChecker.Allowed(ctx, principal, object, action)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, "internal server error")
}

if !allowed {
return fiber.NewError(fiber.StatusForbidden, "Forbidden")
return fiber.NewError(fiber.StatusForbidden, "forbidden")
}

// Create a new context with the API key.
// Create a new context for the authz context.
usrCtx := c.UserContext()
authCtx := context.WithValue(usrCtx, authzAPIKey, key)

authzCtx := NewAuthzContext(principal, object, action)
authCtx := context.WithValue(usrCtx, authzContext, authzCtx)

// nolint: contextcheck
c.SetUserContext(authCtx)
Expand All @@ -106,47 +148,13 @@ func NewOpenAPIAuthenticator(opts ...OpenAPIAuthenticatorOpt) openapi3filter.Aut
}
}

// GetAPIKeyFromContext extracts the API key from the context.
func GetAPIKeyFromContext(ctx context.Context) (string, error) {
key := ctx.Value(authzAPIKey)
// GetAuthzContext extracts the AuthzContext from the context.
func GetAuthzContext(ctx context.Context) (AuthzContext, error) {
key := ctx.Value(authzContext)

if key == nil {
return "", errors.New("API key not found")
}

return key.(string), nil
}

// GetAPIKeyFromRequest is a fake implementation of the API key extractor.
func GetAPIKeyFromRequest(req *http.Request) (string, error) {
return req.Header.Get("x-api-key"), nil
}

var _ AuthzChecker = (*apiKey)(nil)

type apiKey struct {
db *gorm.DB
}

// NewAPIKey returns a new API key authenticator.
func NewAPIKey(db *gorm.DB) *apiKey {
return &apiKey{
db: db,
}
}

// Allowed is a method that returns true if the principal is allowed to perform the action on the user.
func (t *apiKey) Allowed(ctx context.Context, principal AuthzPrincipal, object AuthzObject, action AuthzAction) (bool, error) {
var allowed int64

err := t.db.Raw("SELECT COUNT(1) FROM vw_api_key_team_permissions WHERE key_id = ? AND team_id = (?) AND permission = ?", principal, object, action).Count(&allowed).Error
if err != nil {
return false, err
}

if allowed > 0 {
return true, nil
return AuthzContext{}, ErrNoAuthzContext
}

return false, nil
return key.(AuthzContext), nil
}

0 comments on commit ff11740

Please sign in to comment.