Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove gorm from models #3584

Merged
merged 3 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/access/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
const ResourceInfraAPI = "infra"

// RequireInfraRole checks that the identity in the context can perform an action on a resource based on their granted roles
func RequireInfraRole(c *gin.Context, oneOfRoles ...string) (data.GormTxn, error) {
func RequireInfraRole(c *gin.Context, oneOfRoles ...string) (*data.Transaction, error) {
rCtx := GetRequestContext(c)
if err := IsAuthorized(rCtx, oneOfRoles...); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/access/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestRequireInfraRole(t *testing.T) {
})
}

func grant(t *testing.T, db data.GormTxn, createdBy *models.Identity, subject uid.PolymorphicID, privilege, resource string) {
func grant(t *testing.T, db data.WriteTxn, createdBy *models.Identity, subject uid.PolymorphicID, privilege, resource string) {
err := data.CreateGrant(db, &models.Grant{
Subject: subject,
Privilege: privilege,
Expand Down
2 changes: 1 addition & 1 deletion internal/access/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func hasMinimumCount(min int, password string, check func(rune) bool) bool {
return count >= min
}

func checkPasswordRequirements(db data.GormTxn, password string) error {
func checkPasswordRequirements(db data.ReadTxn, password string) error {
settings, err := data.GetSettings(db)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion internal/access/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func logError(fn func() error, msg string) {
}
}

func userInGroup(db data.GormTxn, authnUserID uid.ID, groupID uid.ID) bool {
func userInGroup(db data.ReadTxn, authnUserID uid.ID, groupID uid.ID) bool {
groups, err := data.ListGroupIDsForUser(db, authnUserID)
if err != nil {
return false
Expand Down
2 changes: 1 addition & 1 deletion internal/access/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func DeleteGroup(c *gin.Context, id uid.ID) error {
return data.DeleteGroup(db, id)
}

func checkIdentitiesInList(db data.GormTxn, ids []uid.ID) ([]uid.ID, error) {
func checkIdentitiesInList(db data.ReadTxn, ids []uid.ID) ([]uid.ID, error) {
if len(ids) == 0 {
return ids, nil
}
Expand Down
12 changes: 6 additions & 6 deletions internal/server/authn/key_exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestKeyExchangeAuthentication(t *testing.T) {
db := setupDB(t)

type testCase struct {
setup func(t *testing.T, db data.GormTxn) (LoginMethod, time.Time)
setup func(t *testing.T, db data.WriteTxn) (LoginMethod, time.Time)
expectedErr string
expected func(t *testing.T, authnIdentity AuthenticatedIdentity)
}
Expand All @@ -27,7 +27,7 @@ func TestKeyExchangeAuthentication(t *testing.T) {

cases := map[string]testCase{
"InvalidAccessKeyCannotBeExchanged": {
setup: func(t *testing.T, db data.GormTxn) (LoginMethod, time.Time) {
setup: func(t *testing.T, db data.WriteTxn) (LoginMethod, time.Time) {
user := &models.Identity{Name: "goku@example.com"}
err := data.CreateIdentity(db, user)
assert.NilError(t, err)
Expand All @@ -39,7 +39,7 @@ func TestKeyExchangeAuthentication(t *testing.T) {
expectedErr: "could not get access key from database",
},
"ExpiredAccessKeyCannotBeExchanged": {
setup: func(t *testing.T, db data.GormTxn) (LoginMethod, time.Time) {
setup: func(t *testing.T, db data.WriteTxn) (LoginMethod, time.Time) {
user := &models.Identity{Name: "bulma@example.com"}
err := data.CreateIdentity(db, user)
assert.NilError(t, err)
Expand All @@ -59,7 +59,7 @@ func TestKeyExchangeAuthentication(t *testing.T) {
expectedErr: data.ErrAccessKeyExpired.Error(),
},
"AccessKeyCannotBeExchangedWhenUserNoLongerExists": {
setup: func(t *testing.T, db data.GormTxn) (LoginMethod, time.Time) {
setup: func(t *testing.T, db data.WriteTxn) (LoginMethod, time.Time) {
user := &models.Identity{Name: "notforlong@example.com"}
user.DeletedAt.Time = time.Now()
user.DeletedAt.Valid = true
Expand All @@ -81,7 +81,7 @@ func TestKeyExchangeAuthentication(t *testing.T) {
expectedErr: "user is not valid",
},
"AccessKeyCannotBeExchangedForLongerLived": {
setup: func(t *testing.T, db data.GormTxn) (LoginMethod, time.Time) {
setup: func(t *testing.T, db data.WriteTxn) (LoginMethod, time.Time) {
user := &models.Identity{Name: "krillin@example.com"}
err := data.CreateIdentity(db, user)
assert.NilError(t, err)
Expand All @@ -105,7 +105,7 @@ func TestKeyExchangeAuthentication(t *testing.T) {
},
},
"ValidAccessKeySuccess": {
setup: func(t *testing.T, db data.GormTxn) (LoginMethod, time.Time) {
setup: func(t *testing.T, db data.WriteTxn) (LoginMethod, time.Time) {
user := &models.Identity{Name: "cell@example.com"}
err := data.CreateIdentity(db, user)
assert.NilError(t, err)
Expand Down
12 changes: 6 additions & 6 deletions internal/server/authn/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
sessionExpiry := time.Now().Add(5 * time.Minute)

type testCase struct {
setup func(t *testing.T, db data.GormTxn) providers.OIDCClient
setup func(t *testing.T, db data.WriteTxn) providers.OIDCClient
expected func(t *testing.T, authnIdentity AuthenticatedIdentity)
}

testCases := map[string]testCase{
"NewUserNewGroups": {
setup: func(t *testing.T, db data.GormTxn) providers.OIDCClient {
setup: func(t *testing.T, db data.WriteTxn) providers.OIDCClient {
return &mockOIDCImplementation{
UserEmailResp: "newusernewgroups@example.com",
UserGroupsResp: []string{"Everyone", "developers"},
Expand All @@ -106,7 +106,7 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
},
},
"NewUserExistingGroups": {
setup: func(t *testing.T, db data.GormTxn) providers.OIDCClient {
setup: func(t *testing.T, db data.WriteTxn) providers.OIDCClient {
existingGroup1 := &models.Group{Name: "existing1"}
existingGroup2 := &models.Group{Name: "existing2"}

Expand Down Expand Up @@ -137,7 +137,7 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
},
},
"ExistingUserNewGroups": {
setup: func(t *testing.T, db data.GormTxn) providers.OIDCClient {
setup: func(t *testing.T, db data.WriteTxn) providers.OIDCClient {
err := data.CreateIdentity(db, &models.Identity{Name: "existingusernewgroups@example.com"})
assert.NilError(t, err)

Expand All @@ -162,7 +162,7 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
},
},
"ExistingUserExistingGroups": {
setup: func(t *testing.T, db data.GormTxn) providers.OIDCClient {
setup: func(t *testing.T, db data.WriteTxn) providers.OIDCClient {
err := data.CreateIdentity(db, &models.Identity{Name: "existinguserexistinggroups@example.com"})
assert.NilError(t, err)

Expand Down Expand Up @@ -193,7 +193,7 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
},
},
"ExistingUserGroupsWithNewGroups": {
setup: func(t *testing.T, db data.GormTxn) providers.OIDCClient {
setup: func(t *testing.T, db data.WriteTxn) providers.OIDCClient {
user := &models.Identity{Name: "eugwnw@example.com"}
err := data.CreateIdentity(db, user)
assert.NilError(t, err)
Expand Down
16 changes: 8 additions & 8 deletions internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ func (s Server) loadConfig(config Config) error {
return tx.Commit()
}

func (s Server) loadProviders(db data.GormTxn, providers []Provider) error {
func (s Server) loadProviders(db data.WriteTxn, providers []Provider) error {
keep := []uid.ID{}

for _, p := range providers {
Expand All @@ -685,7 +685,7 @@ func (s Server) loadProviders(db data.GormTxn, providers []Provider) error {
return nil
}

func (s Server) loadProvider(db data.GormTxn, input Provider) (*models.Provider, error) {
func (s Server) loadProvider(db data.WriteTxn, input Provider) (*models.Provider, error) {
// provider kind is an optional field
kind, err := models.ParseProviderKind(input.Kind)
if err != nil {
Expand Down Expand Up @@ -764,7 +764,7 @@ func (s Server) loadProvider(db data.GormTxn, input Provider) (*models.Provider,
return provider, nil
}

func (s Server) loadGrants(db data.GormTxn, grants []Grant) error {
func (s Server) loadGrants(db data.WriteTxn, grants []Grant) error {
keep := make([]uid.ID, 0, len(grants))

for _, g := range grants {
Expand All @@ -787,7 +787,7 @@ func (s Server) loadGrants(db data.GormTxn, grants []Grant) error {
return nil
}

func (Server) loadGrant(db data.GormTxn, input Grant) (*models.Grant, error) {
func (Server) loadGrant(db data.WriteTxn, input Grant) (*models.Grant, error) {
var id uid.PolymorphicID

switch {
Expand Down Expand Up @@ -863,7 +863,7 @@ func (Server) loadGrant(db data.GormTxn, input Grant) (*models.Grant, error) {
return grant, nil
}

func (s Server) loadUsers(db data.GormTxn, users []User) error {
func (s Server) loadUsers(db data.WriteTxn, users []User) error {
keep := make([]uid.ID, 0, len(users)+1)

for _, i := range users {
Expand All @@ -888,7 +888,7 @@ func (s Server) loadUsers(db data.GormTxn, users []User) error {
return nil
}

func (s Server) loadUser(db data.GormTxn, input User) (*models.Identity, error) {
func (s Server) loadUser(db data.WriteTxn, input User) (*models.Identity, error) {
identity, err := data.GetIdentity(db, data.GetIdentityOptions{ByName: input.Name})
if err != nil {
if !errors.Is(err, internal.ErrNotFound) {
Expand Down Expand Up @@ -929,7 +929,7 @@ func (s Server) loadUser(db data.GormTxn, input User) (*models.Identity, error)
return identity, nil
}

func (s Server) loadCredential(db data.GormTxn, identity *models.Identity, password string) error {
func (s Server) loadCredential(db data.WriteTxn, identity *models.Identity, password string) error {
if password == "" {
return nil
}
Expand Down Expand Up @@ -975,7 +975,7 @@ func (s Server) loadCredential(db data.GormTxn, identity *models.Identity, passw
return nil
}

func (s Server) loadAccessKey(db data.GormTxn, identity *models.Identity, key string) error {
func (s Server) loadAccessKey(db data.WriteTxn, identity *models.Identity, key string) error {
if key == "" {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/data/access_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func validateAccessKey(accessKey *models.AccessKey) error {
return nil
}

func CreateAccessKey(db GormTxn, accessKey *models.AccessKey) (body string, err error) {
func CreateAccessKey(db WriteTxn, accessKey *models.AccessKey) (body string, err error) {
// check if this is an access key being issued for identity provider scim
provider, err := GetProvider(db, GetProviderOptions{ByID: accessKey.IssuedFor})
if err != nil && !errors.Is(err, internal.ErrNotFound) {
Expand Down
6 changes: 3 additions & 3 deletions internal/server/data/access_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ var anyValidUID = cmp.Comparer(func(x, y uid.ID) bool {
// PostgreSQL only has microsecond precision
var cmpTimeWithDBPrecision = cmpopts.EquateApproxTime(time.Microsecond)

func createAccessKeyWithExtensionDeadline(t *testing.T, db GormTxn, ttl, extensionDeadline time.Duration) (string, *models.AccessKey) {
func createAccessKeyWithExtensionDeadline(t *testing.T, db WriteTxn, ttl, extensionDeadline time.Duration) (string, *models.AccessKey) {
identity := &models.Identity{Name: "Wall-E"}
err := CreateIdentity(db, identity)
assert.NilError(t, err)
Expand Down Expand Up @@ -268,7 +268,7 @@ func TestDeleteAccessKeys(t *testing.T) {
})
}

func createAccessKeys(t *testing.T, db GormTxn, keys ...*models.AccessKey) {
func createAccessKeys(t *testing.T, db WriteTxn, keys ...*models.AccessKey) {
t.Helper()
for i := range keys {
_, err := CreateAccessKey(db, keys[i])
Expand Down Expand Up @@ -514,7 +514,7 @@ func TestListAccessKeys(t *testing.T) {
})
}

func createTestAccessKey(t *testing.T, db GormTxn, sessionDuration time.Duration) (string, *models.AccessKey) {
func createTestAccessKey(t *testing.T, db WriteTxn, sessionDuration time.Duration) (string, *models.AccessKey) {
user := &models.Identity{Name: "tmp@infrahq.com"}
err := CreateIdentity(db, user)
assert.NilError(t, err)
Expand Down
42 changes: 1 addition & 41 deletions internal/server/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,6 @@ func (d *DB) OrganizationID() uid.ID {
return d.DefaultOrg.ID
}

func (d *DB) GormDB() *gorm.DB {
return d.DB
}

func (d *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Transaction, error) {
tx := d.DB.WithContext(ctx).Begin(opts)
if err := tx.Error; err != nil {
Expand All @@ -136,16 +132,6 @@ func (d *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Transaction, erro
return &Transaction{DB: tx, completed: new(atomic.Bool)}, nil
}

// GormTxn is used as a shim in preparation for removing gorm.
type GormTxn interface {
WriteTxn

// GormDB returns the underlying reference to the gorm.DB struct.
// Do not use this in new code! Instead, write SQL using the stdlib\
// interface of Query, QueryRow, and Exec.
GormDB() *gorm.DB
}

type Transaction struct {
*gorm.DB
orgID uid.ID
Expand Down Expand Up @@ -173,10 +159,6 @@ func (t *Transaction) QueryRow(query string, args ...any) *sql.Row {
return t.DB.Raw(query, args...).Row()
}

func (t *Transaction) GormDB() *gorm.DB {
return t.DB
}

// Rollback the transaction. If the transaction was already committed then do
// nothing.
func (t *Transaction) Rollback() error {
Expand Down Expand Up @@ -264,28 +246,6 @@ func initialize(db *DB) error {
return tx.Commit()
}

func get[T models.Modelable](tx GormTxn, selectors ...SelectorFunc) (*T, error) {
db := tx.GormDB()
for _, selector := range selectors {
db = selector(db)
}

result := new(T)
if isOrgMember(result) {
db = ByOrgID(tx.OrganizationID())(db)
}

if err := db.Model((*T)(nil)).First(result).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, internal.ErrNotFound
}

return nil, err
}

return result, nil
}

// setOrg checks if model is an organization member, and sets the organizationID
// from the transaction when it is an organization member.
func setOrg(tx ReadTxn, model any) {
Expand Down Expand Up @@ -423,7 +383,7 @@ func InfraProvider(tx ReadTxn) *models.Provider {

// InfraConnectorIdentity returns the connector identity for the organization set
// in the db context.
func InfraConnectorIdentity(db GormTxn) *models.Identity {
func InfraConnectorIdentity(db ReadTxn) *models.Identity {
connector, err := GetIdentity(db, GetIdentityOptions{ByName: models.InternalInfraConnectorIdentityName})
if err != nil {
logging.L.Panic().Err(err).Msg("failed to retrieve connector identity")
Expand Down
2 changes: 1 addition & 1 deletion internal/server/data/destination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func TestCountDestinationsByConnectedVersion(t *testing.T) {
})
}

func createDestinations(t *testing.T, tx GormTxn, destinations ...*models.Destination) {
func createDestinations(t *testing.T, tx WriteTxn, destinations ...*models.Destination) {
t.Helper()
for i := range destinations {
err := CreateDestination(tx, destinations[i])
Expand Down
2 changes: 1 addition & 1 deletion internal/server/data/deviceflowauthrequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type GetDeviceFlowAuthRequestOptions struct {
ByUserCode string
}

func GetDeviceFlowAuthRequest(tx GormTxn, opts GetDeviceFlowAuthRequestOptions) (*models.DeviceFlowAuthRequest, error) {
func GetDeviceFlowAuthRequest(tx ReadTxn, opts GetDeviceFlowAuthRequestOptions) (*models.DeviceFlowAuthRequest, error) {
if opts.ByDeviceCode == "" && opts.ByUserCode == "" && opts.ByID == 0 {
return nil, errors.New("must supply one of device_code, user_code, or id to GetDeviceFlowAuthRequest")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/data/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestCreateGroup(t *testing.T) {
})
}

func createGroups(t *testing.T, db GormTxn, groups ...*models.Group) {
func createGroups(t *testing.T, db WriteTxn, groups ...*models.Group) {
t.Helper()
for i := range groups {
err := CreateGroup(db, groups[i])
Expand Down
Loading