Skip to content

Commit

Permalink
fix tests and logic
Browse files Browse the repository at this point in the history
  • Loading branch information
CaptainStandby committed Mar 23, 2023
1 parent 73b5f13 commit f97913d
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
{
"attributes": {
"disabled": false,
"name": "link",
"name": "unlink",
"node_type": "input",
"type": "submit",
"value": "github"
Expand All @@ -111,8 +111,8 @@
"context": {
"provider": "github"
},
"id": 1050002,
"text": "Link github",
"id": 1050003,
"text": "Unlink github",
"type": "info"
}
},
Expand All @@ -139,5 +139,27 @@
}
},
"type": "input"
},
{
"attributes": {
"disabled": false,
"name": "unlink",
"node_type": "input",
"type": "submit",
"value": "ory"
},
"group": "oidc",
"messages": [],
"meta": {
"label": {
"context": {
"provider": "ory"
},
"id": 1050003,
"text": "Unlink ory",
"type": "info"
}
},
"type": "input"
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"type": "input",
"group": "oidc",
"attributes": {
"name": "link",
"name": "unlink",
"type": "submit",
"value": "github",
"disabled": false,
Expand All @@ -109,8 +109,8 @@
"messages": [],
"meta": {
"label": {
"id": 1050002,
"text": "Link github",
"id": 1050003,
"text": "Unlink github",
"type": "info",
"context": {
"provider": "github"
Expand Down Expand Up @@ -139,5 +139,27 @@
}
}
}
},
{
"type": "input",
"group": "oidc",
"attributes": {
"name": "unlink",
"type": "submit",
"value": "ory",
"disabled": false,
"node_type": "input"
},
"messages": [],
"meta": {
"label": {
"id": 1050003,
"text": "Unlink ory",
"type": "info",
"context": {
"provider": "ory"
}
}
}
}
]
124 changes: 74 additions & 50 deletions selfservice/strategy/oidc/provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,58 +117,82 @@ type ConfigurationCollection struct {
}

func (c ConfigurationCollection) Provider(id string, reg dependencies) (Provider, error) {
for k := range c.Providers {
p := c.Providers[k]
if p.ID == id {
var providerNames []string
var addProviderName = func(pn string) string {
providerNames = append(providerNames, pn)
return pn
}
// !!! WARNING !!!
//
// If you add a provider here, please also add a test to
// provider_private_net_test.go
var providers = map[string]func(config Configuration) Provider{
"generic": func(c Configuration) Provider {
return NewProviderGenericOIDC(&c, reg)
},
"google": func(c Configuration) Provider {
return NewProviderGoogle(&c, reg)
},
"github": func(c Configuration) Provider {
return NewProviderGitHub(&c, reg)
},
"github-app": func(c Configuration) Provider {
return NewProviderGitHubApp(&c, reg)
},
"gitlab": func(c Configuration) Provider {
return NewProviderGitLab(&c, reg)
},
"microsoft": func(c Configuration) Provider {
return NewProviderMicrosoft(&c, reg)
},
"discord": func(c Configuration) Provider {
return NewProviderDiscord(&c, reg)
},
"slack": func(c Configuration) Provider {
return NewProviderSlack(&c, reg)
},
"facebook": func(c Configuration) Provider {
return NewProviderFacebook(&c, reg)
},
"auth0": func(c Configuration) Provider {
return NewProviderAuth0(&c, reg)
},
"vk": func(c Configuration) Provider {
return NewProviderVK(&c, reg)
},
"yandex": func(c Configuration) Provider {
return NewProviderYandex(&c, reg)
},
"apple": func(c Configuration) Provider {
return NewProviderApple(&c, reg)
},
"spotify": func(c Configuration) Provider {
return NewProviderSpotify(&c, reg)
},
"netid": func(c Configuration) Provider {
return NewProviderNetID(&c, reg)
},
"dingtalk": func(c Configuration) Provider {
return NewProviderDingTalk(&c, reg)
},
"linkedin": func(c Configuration) Provider {
return NewProviderLinkedIn(&c, reg)
},
"patreon": func(c Configuration) Provider {
return NewProviderPatreon(&c, reg)
},
}
providerNames := func() []string {
var names []string
for pn := range providers {
names = append(names, pn)
}
return names
}

// !!! WARNING !!!
//
// If you add a provider here, please also add a test to
// provider_private_net_test.go
switch p.Provider {
case addProviderName("generic"):
return NewProviderGenericOIDC(&p, reg), nil
case addProviderName("google"):
return NewProviderGoogle(&p, reg), nil
case addProviderName("github"):
return NewProviderGitHub(&p, reg), nil
case addProviderName("github-app"):
return NewProviderGitHubApp(&p, reg), nil
case addProviderName("gitlab"):
return NewProviderGitLab(&p, reg), nil
case addProviderName("microsoft"):
return NewProviderMicrosoft(&p, reg), nil
case addProviderName("discord"):
return NewProviderDiscord(&p, reg), nil
case addProviderName("slack"):
return NewProviderSlack(&p, reg), nil
case addProviderName("facebook"):
return NewProviderFacebook(&p, reg), nil
case addProviderName("auth0"):
return NewProviderAuth0(&p, reg), nil
case addProviderName("vk"):
return NewProviderVK(&p, reg), nil
case addProviderName("yandex"):
return NewProviderYandex(&p, reg), nil
case addProviderName("apple"):
return NewProviderApple(&p, reg), nil
case addProviderName("spotify"):
return NewProviderSpotify(&p, reg), nil
case addProviderName("netid"):
return NewProviderNetID(&p, reg), nil
case addProviderName("dingtalk"):
return NewProviderDingTalk(&p, reg), nil
case addProviderName("linkedin"):
return NewProviderLinkedIn(&p, reg), nil
case addProviderName("patreon"):
return NewProviderPatreon(&p, reg), nil
// find provider config for id
for _, p := range c.Providers {
if p.ID == id {
f, ok := providers[p.Provider]
if !ok {
return nil, errors.Errorf("provider type %s is not supported, supported are: %v", p.Provider, providerNames())
}
return nil, errors.Errorf("provider type %s is not supported, supported are: %v", p.Provider, providerNames)
return f(p), nil
}
}
return nil, errors.WithStack(herodot.ErrNotFound.WithReasonf(`OpenID Connect Provider "%s" is unknown or has not been configured`, id))
Expand Down
41 changes: 26 additions & 15 deletions selfservice/strategy/oidc/strategy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ var UnknownConnectionValidationError = &jsonschema.ValidationError{
Message: "can not unlink non-existing OpenID Connect connection", InstancePtr: "#/"}
var ConnectionExistValidationError = &jsonschema.ValidationError{
Message: "can not link unknown or already existing OpenID Connect connection", InstancePtr: "#/"}
var UnlinkAllFirstFactorConnectionsError = &jsonschema.ValidationError{
Message: "can not unlink OpenID Connect connection because it is the last remaining first factor credential", InstancePtr: "#/"}

func (s *Strategy) RegisterSettingsRoutes(router *x.RouterPublic) {}

Expand Down Expand Up @@ -87,21 +89,12 @@ func (s *Strategy) linkedProviders(ctx context.Context, r *http.Request, conf *C
return nil, errors.WithStack(err)
}

count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(ctx, confidential)
if err != nil {
return nil, err
}

if count < 2 {
// This means that we're able to remove a connection because it is the last configured credential. If it is
// removed, the identity is no longer able to sign in.
return nil, nil
}

var result []Provider
for _, p := range available.Providers {
prov, err := conf.Provider(p.Provider, s.d)
if err != nil {
if errors.Is(err, herodot.ErrNotFound) {
continue
} else if err != nil {
return nil, err
}
result = append(result, prov)
Expand Down Expand Up @@ -172,8 +165,17 @@ func (s *Strategy) PopulateSettingsMethod(r *http.Request, id *identity.Identity
sr.UI.GetNodes().Append(NewLinkNode(l.Config().ID))
}

for _, l := range linked {
sr.UI.GetNodes().Append(NewUnlinkNode(l.Config().ID))
count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), confidential)
if err != nil {
return err
}

if count > 1 {
// This means that we're able to remove a connection because it is the last configured credential. If it is
// removed, the identity is no longer able to sign in.
for _, l := range linked {
sr.UI.GetNodes().Append(NewUnlinkNode(l.Config().ID))
}
}

return nil
Expand Down Expand Up @@ -466,7 +468,16 @@ func (s *Strategy) unlinkProvider(w http.ResponseWriter, r *http.Request, ctxUpd
var cc identity.CredentialsOIDC
creds, err := i.ParseCredentials(s.ID(), &cc)
if err != nil {
return s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(UnknownConnectionValidationError))
return s.handleSettingsError(w, r, ctxUpdate, p, err)
}

count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(r.Context(), i)
if err != nil {
return s.handleSettingsError(w, r, ctxUpdate, p, err)
}

if count < 2 {
return s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(UnlinkAllFirstFactorConnectionsError))
}

var found bool
Expand Down
15 changes: 7 additions & 8 deletions selfservice/strategy/oidc/strategy_settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func TestSettingsStrategy(t *testing.T) {
return
}

var unlinkInvalid = func(agent, provider string) func(t *testing.T) {
var unlinkInvalid = func(agent, provider, errorMessage string) func(t *testing.T) {
return func(t *testing.T) {
body, res, req := unlink(t, agent, provider)

Expand All @@ -267,27 +267,26 @@ func TestSettingsStrategy(t *testing.T) {

// The original options to link google and github are still there
t.Run("flow=fetch", func(t *testing.T) {
snapshotx.SnapshotTExcept(t, req.Ui.Nodes, []string{"0.attributes.value", "1.attributes.value"})
snapshotx.SnapshotT(t, req.Ui.Nodes, snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
})

t.Run("flow=json", func(t *testing.T) {
snapshotx.SnapshotTExcept(t, json.RawMessage(gjson.GetBytes(body, `ui.nodes`).Raw), []string{"0.attributes.value", "1.attributes.value"})
snapshotx.SnapshotT(t, json.RawMessage(gjson.GetBytes(body, `ui.nodes`).Raw), snapshotx.ExceptPaths("0.attributes.value", "1.attributes.value"))
})

assert.Contains(t, gjson.GetBytes(body, "ui.action").String(), publicTS.URL+settings.RouteSubmitFlow+"?flow=")
assert.Contains(t, gjson.GetBytes(body, `ui.messages.0.text`).String(),
"can not unlink non-existing OpenID Connect")
assert.Contains(t, gjson.GetBytes(body, `ui.messages.0.text`).String(), errorMessage)
}
}

t.Run("case=should not be able to unlink the last remaining connection",
unlinkInvalid("oryer", "ory"))
unlinkInvalid("oryer", "ory", "can not unlink OpenID Connect connection because it is the last remaining first factor credential"))

t.Run("case=should not be able to unlink an non-existing connection",
unlinkInvalid("oryer", "i-do-not-exist"))
unlinkInvalid("githuber", "i-do-not-exist", "can not unlink non-existing OpenID Connect connection"))

t.Run("case=should not be able to unlink a connection not yet linked",
unlinkInvalid("githuber", "google"))
unlinkInvalid("githuber", "google", "can not unlink non-existing OpenID Connect connection"))

t.Run("case=should unlink a connection", func(t *testing.T) {
agent, provider := "githuber", "github"
Expand Down
Loading

0 comments on commit f97913d

Please sign in to comment.