From 8a11cca6bbb0945d985100eddc3de3c6059d2b5b Mon Sep 17 00:00:00 2001 From: Kory Prince Date: Mon, 13 Nov 2023 17:29:12 -0600 Subject: [PATCH 1/5] Add configurable CORS policy in OpenIDProvider --- pkg/op/op.go | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/pkg/op/op.go b/pkg/op/op.go index ba36c617..268c3c92 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -97,9 +97,17 @@ type OpenIDProvider interface { type HttpInterceptor func(http.Handler) http.Handler +type corsOptioner interface { + CORSOptions() cors.Options +} + func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { router := chi.NewRouter() - router.Use(cors.New(defaultCORSOptions).Handler) + if co, ok := o.(corsOptioner); ok { + router.Use(cors.New(co.CORSOptions()).Handler) + } else { + router.Use(cors.New(defaultCORSOptions).Handler) + } router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) @@ -224,6 +232,7 @@ func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (Is storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), + corsOpts: defaultCORSOptions, logger: slog.Default(), } @@ -268,6 +277,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + corsOpts cors.Options logger *slog.Logger } @@ -427,6 +437,10 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) CORSOptions() cors.Options { + return o.corsOpts +} + func (o *Provider) Logger() *slog.Logger { return o.logger } @@ -587,6 +601,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +func WithCORSOptions(opts cors.Options) Option { + return func(o *Provider) error { + o.corsOpts = opts + return nil + } +} + // WithLogger lets a logger other than slog.Default(). // // EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 From 2a038f8746473bf3ba98696c587d42572c62d73f Mon Sep 17 00:00:00 2001 From: Kory Prince Date: Mon, 13 Nov 2023 18:10:05 -0600 Subject: [PATCH 2/5] Add configurable CORS policy to Server --- pkg/op/server_http.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 6d379c63..6fd2a290 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -29,16 +29,16 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) server: server, endpoints: endpoints, decoder: decoder, + corsOpts: defaultCORSOptions, logger: slog.Default(), } - ws.router.Use(cors.New(defaultCORSOptions).Handler) for _, option := range options { option(ws) } ws.createRouter() - return ws + return cors.New(ws.corsOpts).Handler(ws) } type ServerOption func(s *webServer) @@ -66,6 +66,13 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption { } } +// WithServerCORSOptions sets the CORS policy for the Server's router. +func WithServerCORSOptions(opts cors.Options) ServerOption { + return func(s *webServer) { + s.corsOpts = opts + } +} + // WithFallbackLogger overrides the fallback logger, which // is used when no logger was found in the context. // Defaults to [slog.Default]. @@ -80,6 +87,7 @@ type webServer struct { router *chi.Mux endpoints Endpoints decoder httphelper.Decoder + corsOpts cors.Options logger *slog.Logger } From cd3cb17b2c9caffdcdb8d9c134cf9c66b7e1008d Mon Sep 17 00:00:00 2001 From: Kory Prince Date: Mon, 13 Nov 2023 18:10:15 -0600 Subject: [PATCH 3/5] remove duplicated CORS middleware --- pkg/op/op.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/op/op.go b/pkg/op/op.go index 268c3c92..1a7ccdb2 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -624,6 +624,6 @@ func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handle for i := len(interceptors) - 1; i >= 0; i-- { handler = interceptors[i](handler) } - return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler)) + return issuerInterceptor.Handler(handler) } } From 37e01449e0cf479d90c85aaa8ccc7fd452130fb1 Mon Sep 17 00:00:00 2001 From: Kory Prince Date: Mon, 13 Nov 2023 18:31:39 -0600 Subject: [PATCH 4/5] Allow nil CORS policy to be set to disable CORS middleware --- pkg/op/op.go | 14 ++++++++------ pkg/op/server_http.go | 11 +++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/pkg/op/op.go b/pkg/op/op.go index 1a7ccdb2..939ebf85 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -98,13 +98,15 @@ type OpenIDProvider interface { type HttpInterceptor func(http.Handler) http.Handler type corsOptioner interface { - CORSOptions() cors.Options + CORSOptions() *cors.Options } func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { router := chi.NewRouter() if co, ok := o.(corsOptioner); ok { - router.Use(cors.New(co.CORSOptions()).Handler) + if opts := co.CORSOptions(); opts != nil { + router.Use(cors.New(*opts).Handler) + } } else { router.Use(cors.New(defaultCORSOptions).Handler) } @@ -232,7 +234,7 @@ func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (Is storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), - corsOpts: defaultCORSOptions, + corsOpts: &defaultCORSOptions, logger: slog.Default(), } @@ -277,7 +279,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt - corsOpts cors.Options + corsOpts *cors.Options logger *slog.Logger } @@ -437,7 +439,7 @@ func (o *Provider) Probes() []ProbesFn { } } -func (o *Provider) CORSOptions() cors.Options { +func (o *Provider) CORSOptions() *cors.Options { return o.corsOpts } @@ -601,7 +603,7 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } -func WithCORSOptions(opts cors.Options) Option { +func WithCORSOptions(opts *cors.Options) Option { return func(o *Provider) error { o.corsOpts = opts return nil diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 6fd2a290..34a322fe 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -29,7 +29,7 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) server: server, endpoints: endpoints, decoder: decoder, - corsOpts: defaultCORSOptions, + corsOpts: &defaultCORSOptions, logger: slog.Default(), } @@ -38,7 +38,10 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) } ws.createRouter() - return cors.New(ws.corsOpts).Handler(ws) + if ws.corsOpts != nil { + return cors.New(*ws.corsOpts).Handler(ws) + } + return ws } type ServerOption func(s *webServer) @@ -67,7 +70,7 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption { } // WithServerCORSOptions sets the CORS policy for the Server's router. -func WithServerCORSOptions(opts cors.Options) ServerOption { +func WithServerCORSOptions(opts *cors.Options) ServerOption { return func(s *webServer) { s.corsOpts = opts } @@ -87,7 +90,7 @@ type webServer struct { router *chi.Mux endpoints Endpoints decoder httphelper.Decoder - corsOpts cors.Options + corsOpts *cors.Options logger *slog.Logger } From 5d7cb2c1a2346d36319f1c3a6293d25d8bb3a643 Mon Sep 17 00:00:00 2001 From: Kory Prince Date: Thu, 16 Nov 2023 17:49:13 -0600 Subject: [PATCH 5/5] create a separate handler on webServer so type assertion works in tests --- pkg/op/server_http.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 34a322fe..2220e448 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -38,8 +38,9 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) } ws.createRouter() + ws.handler = ws.router if ws.corsOpts != nil { - return cors.New(*ws.corsOpts).Handler(ws) + ws.handler = cors.New(*ws.corsOpts).Handler(ws.router) } return ws } @@ -88,6 +89,7 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption { type webServer struct { server Server router *chi.Mux + handler http.Handler endpoints Endpoints decoder httphelper.Decoder corsOpts *cors.Options @@ -95,7 +97,7 @@ type webServer struct { } func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.router.ServeHTTP(w, r) + s.handler.ServeHTTP(w, r) } func (s *webServer) getLogger(ctx context.Context) *slog.Logger {