Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: cleanup the code for CORS handling #1959

Merged
merged 3 commits into from
Aug 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 9 additions & 27 deletions driver/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ func OAuth2AwareCORSMiddleware(iface string, reg Registry, conf configuration.Pr
return h
}
}

corsOptions := conf.CORSOptions(iface)

var alwaysAllow bool = len(corsOptions.AllowedOrigins) == 0
var patterns []glob.Glob
for _, o := range corsOptions.AllowedOrigins {
if o == "*" {
alwaysAllow = true
}
// if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator.
// This way g := glob.Compile("http://**") g.Match("http://google.com") returns true.
if splittedO := strings.Split(o, "://"); len(splittedO) != 1 && splittedO[1] == "*" {
Expand All @@ -54,20 +58,10 @@ func OAuth2AwareCORSMiddleware(iface string, reg Registry, conf configuration.Pr
if err != nil {
reg.Logger().WithError(err).Fatalf("Unable to parse cors origin: %s", o)
}
patterns = append(patterns, g)
}

var alwaysAllow bool
for _, o := range corsOptions.AllowedOrigins {
if o == "*" {
alwaysAllow = true
break
}
patterns = append(patterns, g)
}

if len(corsOptions.AllowedOrigins) == 0 {
alwaysAllow = true
}

options := cors.Options{
AllowedOrigins: corsOptions.AllowedOrigins,
Expand Down Expand Up @@ -111,27 +105,15 @@ func OAuth2AwareCORSMiddleware(iface string, reg Registry, conf configuration.Pr
return false
}

if alwaysAllow {
return true
}

for _, p := range cl.AllowedCORSOrigins {
if p == "*" {
for _, o := range cl.AllowedCORSOrigins {
if o == "*" {
return true
}
}

var clientPatterns []glob.Glob
for _, o := range cl.AllowedCORSOrigins {
g, err := glob.Compile(strings.ToLower(o), '.')
if err != nil {
return false
}
clientPatterns = append(patterns, g)
}

for _, p := range clientPatterns {
if p.Match(origin) {
if(g.Match(origin)){
return true
}
}
Expand Down
52 changes: 43 additions & 9 deletions driver/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
. "github.com/ory/hydra/driver"
"github.com/ory/hydra/internal"
"github.com/ory/hydra/oauth2"

"github.com/ory/hydra/x"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -68,69 +68,77 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vOmJhcg=="}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo", "bar"))}},
expectHeader: http.Header{"Vary": {"Origin"}},
},
{
d: "should reject when basic auth client exists but origin not allowed",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vLTI6YmFy"}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-2", "bar"))}},
expectHeader: http.Header{"Vary": {"Origin"}},
},
{
d: "should accept when basic auth client exists and origin allowed",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vLTM6YmFy"}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-3", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed per client",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-4", Secret: "bar", AllowedCORSOrigins: []string{"http://*.foobar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {"Basic Zm9vLTQ6YmFy"}},
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-4", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foo.foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with full wildcard) is allowed globally",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"*"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-5", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"*"}, "Authorization": {"Basic Zm9vLTU6YmFy"}},
header: http.Header{"Origin": {"*"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-5", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"*"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed globally",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://*.foobar.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-6", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {"Basic Zm9vLTY6YmFy"}},
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-6", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foo.foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept when basic auth client exists and origin (with full wildcard) allowed per client",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-7", Secret: "bar", AllowedCORSOrigins: []string{"*"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Basic Zm9vLTc6YmFy"}},
header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-7", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should fail when token introspection fails",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
},
code: http.StatusNotImplemented,
Expand All @@ -140,6 +148,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
{
d: "should work when token introspection returns a session",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
sess := oauth2.NewSession("foo-9")
sess.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour))
Expand All @@ -160,17 +169,41 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
{
d: "should accept any allowed specified origin protocol",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-11", Secret: "bar", AllowedCORSOrigins: []string{"*"}})
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://*", "https://*"})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {"Basic Zm9vLTQ6YmFy"}},
header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-11", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foo.foobar.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept client origin when basic auth client exists and origin is set at the client as well as the server",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://**.example.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-12", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://myapp.example.biz"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-12", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://myapp.example.biz"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
{
d: "should accept server origin when basic auth client exists and origin is set at the client as well as the server",
prep: func() {
viper.Set("serve.public.cors.enabled", true)
viper.Set("serve.public.cors.allowed_origins", []string{"http://**.example.com"})
r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo-13", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}})
},
code: http.StatusNotImplemented,
header: http.Header{"Origin": {"http://client-app.example.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-13", "bar"))}},
expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://client-app.example.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the bug that was reported would be reproduced like this (checking against the origin domain, not the server one):


 		{
 			d: "should accept when basic auth client exists and origin is set at the client as well as the server",
 			prep: func() {
 				viper.Set("serve.public.cors.enabled", true)
 				viper.Set("serve.public.cors.allowed_origins", []string{"http://**.example.com"})
 				r.ClientManager().CreateClient(context.Background(), &client.Client{ID: "foo8", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}})
 			},
 			code:         http.StatusNotImplemented,
 			header:       http.Header{"Origin": {"http://myapp.example.biz"}, "Authorization": {"Basic Zm9vLTU6YmFy"}},
 			expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://client-app.example.com"}, "Access-Control-Expose-Headers": []string{"Content-Type"}, "Vary": []string{"Origin"}},
 		},

Copy link
Contributor Author

@harsimranmaan harsimranmaan Jul 31, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lemme give it a try. I expect the Allow-Origin to be set to http://myapp.example.biz.

Copy link
Contributor Author

@harsimranmaan harsimranmaan Jul 31, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aeneasr I could not repro the defect. Added the new test-case. Also refactored auth headers(in test cases) for readability.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, in that case I don't think there's any more to do! And you are correct, the expected value was a copy/paste mistake on my end!

} {
t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) {
if tc.prep != nil {
viper.Reset()
tc.prep()
}

Expand All @@ -189,4 +222,5 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) {
assert.EqualValues(t, tc.expectHeader, res.Header())
})
}

}