Skip to content

Commit

Permalink
fix: resolve issue with ambiguous optional selectors (#1495)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssoroka authored Apr 8, 2022
1 parent e482d50 commit b30c8ac
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 26 deletions.
2 changes: 1 addition & 1 deletion internal/access/access_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func ListAccessKeys(c *gin.Context, identityID uid.ID, name string) ([]models.Ac
return nil, err
}

return data.ListAccessKeys(db, data.ByIssuedFor(identityID), data.ByName(name))
return data.ListAccessKeys(db, data.ByOptionalIssuedFor(identityID), data.ByOptionalName(name))
}

func CreateAccessKey(c *gin.Context, accessKey *models.AccessKey, identityID uid.ID) (body string, err error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/access/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func ListDestinations(c *gin.Context, uniqueID, name string) ([]models.Destinati
return nil, err
}

return data.ListDestinations(db, data.ByUniqueID(uniqueID), data.ByName(name))
return data.ListDestinations(db, data.ByOptionalUniqueID(uniqueID), data.ByOptionalName(name))
}

func DeleteDestination(c *gin.Context, id uid.ID) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/access/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func ListGrants(c *gin.Context, subject uid.PolymorphicID, resource string, priv
return nil, err
}

return data.ListGrants(db, data.BySubject(subject), data.ByResource(resource), data.ByPrivilege(privilege), data.NotCreatedBy(models.CreatedBySystem))
return data.ListGrants(db, data.ByOptionalSubject(subject), data.ByOptionalResource(resource), data.ByOptionalPrivilege(privilege), data.NotCreatedBy(models.CreatedBySystem))
}

func ListIdentityGrants(c *gin.Context, identityID uid.ID) ([]models.Grant, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/access/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func ListGroups(c *gin.Context, name string) ([]models.Group, error) {
return nil, err
}

return data.ListGroups(db, data.ByName(name))
return data.ListGroups(db, data.ByOptionalName(name))
}

func CreateGroup(c *gin.Context, group *models.Group) error {
Expand Down
2 changes: 1 addition & 1 deletion internal/access/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func GetProvider(c *gin.Context, id uid.ID) (*models.Provider, error) {
func ListProviders(c *gin.Context, name string) ([]models.Provider, error) {
db := getDB(c)

return data.ListProviders(db, data.ByName(name))
return data.ListProviders(db, data.ByOptionalName(name))
}

func SaveProvider(c *gin.Context, provider *models.Provider) error {
Expand Down
20 changes: 16 additions & 4 deletions internal/server/data/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ func GetGrant(db *gorm.DB, selectors ...SelectorFunc) (*models.Grant, error) {

func ListIdentityGrants(db *gorm.DB, userID uid.ID) (result []models.Grant, err error) {
polymorphicID := uid.NewIdentityPolymorphicID(userID)
return ListGrants(db, BySubject(polymorphicID), NotCreatedBy(models.CreatedBySystem))
return ListGrants(db, ByOptionalSubject(polymorphicID), NotCreatedBy(models.CreatedBySystem))
}

func ListGroupGrants(db *gorm.DB, groupID uid.ID) (result []models.Grant, err error) {
polymorphicID := uid.NewGroupPolymorphicID(groupID)
return ListGrants(db, BySubject(polymorphicID), NotCreatedBy(models.CreatedBySystem))
return ListGrants(db, ByOptionalSubject(polymorphicID), NotCreatedBy(models.CreatedBySystem))
}

func ListGrants(db *gorm.DB, selectors ...SelectorFunc) ([]models.Grant, error) {
Expand All @@ -56,7 +56,7 @@ func DeleteGrants(db *gorm.DB, selectors ...SelectorFunc) error {
return deleteAll[models.Grant](db, ByIDs(ids))
}

func ByPrivilege(s string) SelectorFunc {
func ByOptionalPrivilege(s string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if s == "" {
return db
Expand All @@ -66,7 +66,13 @@ func ByPrivilege(s string) SelectorFunc {
}
}

func ByResource(s string) SelectorFunc {
func ByPrivilege(s string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("privilege = ?", s)
}
}

func ByOptionalResource(s string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if s == "" {
return db
Expand All @@ -75,3 +81,9 @@ func ByResource(s string) SelectorFunc {
return db.Where("resource = ?", s)
}
}

func ByResource(s string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("resource = ?", s)
}
}
6 changes: 3 additions & 3 deletions internal/server/data/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestListProviders(t *testing.T) {
assert.NilError(t, err)
assert.Equal(t, 2, len(providers))

providers, err = ListProviders(db, ByURL("dev.okta.com"))
providers, err = ListProviders(db, ByOptionalName("okta-development"))
assert.NilError(t, err)
assert.Equal(t, 1, len(providers))
}
Expand All @@ -104,10 +104,10 @@ func TestDeleteProviders(t *testing.T) {
assert.NilError(t, err)
assert.Equal(t, 2, len(providers))

err = DeleteProviders(db, ByURL("dev.okta.com"))
err = DeleteProviders(db, ByOptionalName("okta-development"))
assert.NilError(t, err)

_, err = GetProvider(db, ByURL("dev.okta.com"))
_, err = GetProvider(db, ByOptionalName("okta-development"))
assert.Error(t, err, "record not found")
}

Expand Down
32 changes: 18 additions & 14 deletions internal/server/data/selectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func ByNotIDs(ids []uid.ID) SelectorFunc {
}
}

func ByName(name string) SelectorFunc {
func ByOptionalName(name string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if len(name) > 0 {
return db.Where("name = ?", name)
Expand All @@ -37,7 +37,13 @@ func ByName(name string) SelectorFunc {
}
}

func ByUniqueID(nodeID string) SelectorFunc {
func ByName(name string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("name = ?", name)
}
}

func ByOptionalUniqueID(nodeID string) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if len(nodeID) > 0 {
return db.Where("unique_id = ?", nodeID)
Expand All @@ -49,10 +55,6 @@ func ByUniqueID(nodeID string) SelectorFunc {

func ByProviderID(id uid.ID) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if id == 0 {
return db
}

return db.Where("provider_id = ?", id)
}
}
Expand All @@ -63,27 +65,23 @@ func ByKeyID(key string) SelectorFunc {
}
}

func ByURL(url string) SelectorFunc {
func ByOptionalSubject(polymorphicID uid.PolymorphicID) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if len(url) == 0 {
if polymorphicID == "" {
return db
}

return db.Where("url = ?", url)
return db.Where("subject = ?", string(polymorphicID))
}
}

func BySubject(polymorphicID uid.PolymorphicID) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if polymorphicID == "" {
return db
}

return db.Where("subject = ?", string(polymorphicID))
}
}

func ByIssuedFor(id uid.ID) SelectorFunc {
func ByOptionalIssuedFor(id uid.ID) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
if id == 0 {
return db
Expand All @@ -93,6 +91,12 @@ func ByIssuedFor(id uid.ID) SelectorFunc {
}
}

func ByIssuedFor(id uid.ID) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("issued_for = ?", id)
}
}

func ByIdentityID(identityID uid.ID) SelectorFunc {
return func(db *gorm.DB) *gorm.DB {
return db.Where("identity_id = ?", identityID)
Expand Down

0 comments on commit b30c8ac

Please sign in to comment.