From 82827c2bf1d62ccce85980c21253843245f61218 Mon Sep 17 00:00:00 2001 From: Dhruv Bhoot Date: Thu, 22 Feb 2024 16:29:18 -0800 Subject: [PATCH] Allow a origin validation function with context --- config.go | 42 +++++++++++++++++++++++++----------------- cors.go | 10 ++++++++-- cors_test.go | 2 ++ 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/config.go b/config.go index 427cfc0..ef24187 100644 --- a/config.go +++ b/config.go @@ -8,14 +8,15 @@ import ( ) type cors struct { - allowAllOrigins bool - allowCredentials bool - allowOriginFunc func(string) bool - allowOrigins []string - normalHeaders http.Header - preflightHeaders http.Header - wildcardOrigins [][]string - optionsResponseStatusCode int + allowAllOrigins bool + allowCredentials bool + allowOriginFunc func(string) bool + allowOriginWithContextFunc func(*gin.Context, string) bool + allowOrigins []string + normalHeaders http.Header + preflightHeaders http.Header + wildcardOrigins [][]string + optionsResponseStatusCode int } var ( @@ -54,14 +55,15 @@ func newCors(config Config) *cors { } return &cors{ - allowOriginFunc: config.AllowOriginFunc, - allowAllOrigins: config.AllowAllOrigins, - allowCredentials: config.AllowCredentials, - allowOrigins: normalize(config.AllowOrigins), - normalHeaders: generateNormalHeaders(config), - preflightHeaders: generatePreflightHeaders(config), - wildcardOrigins: config.parseWildcardRules(), - optionsResponseStatusCode: config.OptionsResponseStatusCode, + allowOriginFunc: config.AllowOriginFunc, + allowOriginWithContextFunc: config.AllowOriginWithContextFunc, + allowAllOrigins: config.AllowAllOrigins, + allowCredentials: config.AllowCredentials, + allowOrigins: normalize(config.AllowOrigins), + normalHeaders: generateNormalHeaders(config), + preflightHeaders: generatePreflightHeaders(config), + wildcardOrigins: config.parseWildcardRules(), + optionsResponseStatusCode: config.OptionsResponseStatusCode, } } @@ -79,7 +81,13 @@ func (cors *cors) applyCors(c *gin.Context) { return } - if !cors.validateOrigin(origin) { + if cors.allowOriginWithContextFunc != nil { + if !cors.allowOriginWithContextFunc(c, origin) { + c.AbortWithStatus(http.StatusForbidden) + return + } + + } else if !cors.validateOrigin(origin) { c.AbortWithStatus(http.StatusForbidden) return } diff --git a/cors.go b/cors.go index b325222..b0f3bec 100644 --- a/cors.go +++ b/cors.go @@ -22,6 +22,9 @@ type Config struct { // set, the content of AllowOrigins is ignored. AllowOriginFunc func(origin string) bool + // The same as AllowOriginFunc but allows access to the entire request context + AllowOriginWithContextFunc func(c *gin.Context, origin string) bool + // AllowMethods is a list of methods the client is allowed to use with // cross-domain requests. Default value is simple methods (GET, POST, PUT, PATCH, DELETE, HEAD, and OPTIONS) AllowMethods []string @@ -102,12 +105,15 @@ func (c Config) validateAllowedSchemas(origin string) bool { // Validate is check configuration of user defined. func (c Config) Validate() error { - if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) { + if c.AllowAllOrigins && (c.AllowOriginFunc != nil || c.AllowOriginWithContextFunc != nil || len(c.AllowOrigins) > 0) { return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowOrigins is not needed") } - if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 { + if !c.AllowAllOrigins && c.AllowOriginFunc == nil && c.AllowOriginWithContextFunc == nil && len(c.AllowOrigins) == 0 { return errors.New("conflict settings: all origins disabled") } + if c.AllowOriginFunc != nil && c.AllowOriginWithContextFunc != nil { + return errors.New("conflict settings: Both original validation functions are defined") + } for _, origin := range c.AllowOrigins { if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) { return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ",")) diff --git a/cors_test.go b/cors_test.go index c87d60a..2800238 100644 --- a/cors_test.go +++ b/cors_test.go @@ -205,6 +205,8 @@ func TestGeneratePreflightHeaders_MaxAge(t *testing.T) { } func TestValidateOrigin(t *testing.T) { + // review the below for adding a testing context + //https://pkg.go.dev/github.com/gin-gonic/gin#CreateTestContextOnly cors := newCors(Config{ AllowAllOrigins: true, })