Skip to content

Commit

Permalink
Next round of db.DefaultContext refactor (#27089)
Browse files Browse the repository at this point in the history
Part of #27065
  • Loading branch information
JakobDev authored Sep 16, 2023
1 parent a1b2a11 commit f91dbbb
Show file tree
Hide file tree
Showing 90 changed files with 434 additions and 464 deletions.
8 changes: 4 additions & 4 deletions models/actions/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ func init() {
}

// GetSchedulesMapByIDs returns the schedules by given id slice.
func GetSchedulesMapByIDs(ids []int64) (map[int64]*ActionSchedule, error) {
func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) {
schedules := make(map[int64]*ActionSchedule, len(ids))
return schedules, db.GetEngine(db.DefaultContext).In("id", ids).Find(&schedules)
return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules)
}

// GetReposMapByIDs returns the repos by given id slice.
func GetReposMapByIDs(ids []int64) (map[int64]*repo_model.Repository, error) {
func GetReposMapByIDs(ctx context.Context, ids []int64) (map[int64]*repo_model.Repository, error) {
repos := make(map[int64]*repo_model.Repository, len(ids))
return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos)
return repos, db.GetEngine(ctx).In("id", ids).Find(&repos)
}

var cronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
Expand Down
8 changes: 4 additions & 4 deletions models/actions/schedule_spec_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ func (specs SpecList) GetScheduleIDs() []int64 {
return ids.Values()
}

func (specs SpecList) LoadSchedules() error {
func (specs SpecList) LoadSchedules(ctx context.Context) error {
scheduleIDs := specs.GetScheduleIDs()
schedules, err := GetSchedulesMapByIDs(scheduleIDs)
schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs)
if err != nil {
return err
}
Expand All @@ -34,7 +34,7 @@ func (specs SpecList) LoadSchedules() error {
}

repoIDs := specs.GetRepoIDs()
repos, err := GetReposMapByIDs(repoIDs)
repos, err := GetReposMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, erro
return nil, 0, err
}

if err := specs.LoadSchedules(); err != nil {
if err := specs.LoadSchedules(ctx); err != nil {
return nil, 0, err
}
return specs, total, nil
Expand Down
34 changes: 15 additions & 19 deletions models/admin/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,7 @@ type TranslatableMessage struct {
}

// LoadRepo loads repository of the task
func (task *Task) LoadRepo() error {
return task.loadRepo(db.DefaultContext)
}

func (task *Task) loadRepo(ctx context.Context) error {
func (task *Task) LoadRepo(ctx context.Context) error {
if task.Repo != nil {
return nil
}
Expand All @@ -70,13 +66,13 @@ func (task *Task) loadRepo(ctx context.Context) error {
}

// LoadDoer loads do user
func (task *Task) LoadDoer() error {
func (task *Task) LoadDoer(ctx context.Context) error {
if task.Doer != nil {
return nil
}

var doer user_model.User
has, err := db.GetEngine(db.DefaultContext).ID(task.DoerID).Get(&doer)
has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer)
if err != nil {
return err
} else if !has {
Expand All @@ -90,13 +86,13 @@ func (task *Task) LoadDoer() error {
}

// LoadOwner loads owner user
func (task *Task) LoadOwner() error {
func (task *Task) LoadOwner(ctx context.Context) error {
if task.Owner != nil {
return nil
}

var owner user_model.User
has, err := db.GetEngine(db.DefaultContext).ID(task.OwnerID).Get(&owner)
has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner)
if err != nil {
return err
} else if !has {
Expand All @@ -110,8 +106,8 @@ func (task *Task) LoadOwner() error {
}

// UpdateCols updates some columns
func (task *Task) UpdateCols(cols ...string) error {
_, err := db.GetEngine(db.DefaultContext).ID(task.ID).Cols(cols...).Update(task)
func (task *Task) UpdateCols(ctx context.Context, cols ...string) error {
_, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task)
return err
}

Expand Down Expand Up @@ -169,12 +165,12 @@ func (err ErrTaskDoesNotExist) Unwrap() error {
}

// GetMigratingTask returns the migrating task by repo's id
func GetMigratingTask(repoID int64) (*Task, error) {
func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) {
task := Task{
RepoID: repoID,
Type: structs.TaskTypeMigrateRepo,
}
has, err := db.GetEngine(db.DefaultContext).Get(&task)
has, err := db.GetEngine(ctx).Get(&task)
if err != nil {
return nil, err
} else if !has {
Expand All @@ -184,13 +180,13 @@ func GetMigratingTask(repoID int64) (*Task, error) {
}

// GetMigratingTaskByID returns the migrating task by repo's id
func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, error) {
func GetMigratingTaskByID(ctx context.Context, id, doerID int64) (*Task, *migration.MigrateOptions, error) {
task := Task{
ID: id,
DoerID: doerID,
Type: structs.TaskTypeMigrateRepo,
}
has, err := db.GetEngine(db.DefaultContext).Get(&task)
has, err := db.GetEngine(ctx).Get(&task)
if err != nil {
return nil, nil, err
} else if !has {
Expand All @@ -205,12 +201,12 @@ func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, e
}

// CreateTask creates a task on database
func CreateTask(task *Task) error {
return db.Insert(db.DefaultContext, task)
func CreateTask(ctx context.Context, task *Task) error {
return db.Insert(ctx, task)
}

// FinishMigrateTask updates database when migrate task finished
func FinishMigrateTask(task *Task) error {
func FinishMigrateTask(ctx context.Context, task *Task) error {
task.Status = structs.TaskStatusFinished
task.EndTime = timeutil.TimeStampNow()

Expand All @@ -231,6 +227,6 @@ func FinishMigrateTask(task *Task) error {
}
task.PayloadContent = string(confBytes)

_, err = db.GetEngine(db.DefaultContext).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
_, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
return err
}
29 changes: 15 additions & 14 deletions models/auth/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package auth

import (
"context"
"fmt"

"code.gitea.io/gitea/models/db"
Expand All @@ -22,21 +23,21 @@ func init() {
}

// UpdateSession updates the session with provided id
func UpdateSession(key string, data []byte) error {
_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
func UpdateSession(ctx context.Context, key string, data []byte) error {
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
Data: data,
Expiry: timeutil.TimeStampNow(),
})
return err
}

// ReadSession reads the data for the provided session
func ReadSession(key string) (*Session, error) {
func ReadSession(ctx context.Context, key string) (*Session, error) {
session := Session{
Key: key,
}

ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) {
}

// ExistSession checks if a session exists
func ExistSession(key string) (bool, error) {
func ExistSession(ctx context.Context, key string) (bool, error) {
session := Session{
Key: key,
}
return db.GetEngine(db.DefaultContext).Get(&session)
return db.GetEngine(ctx).Get(&session)
}

// DestroySession destroys a session
func DestroySession(key string) error {
_, err := db.GetEngine(db.DefaultContext).Delete(&Session{
func DestroySession(ctx context.Context, key string) error {
_, err := db.GetEngine(ctx).Delete(&Session{
Key: key,
})
return err
}

// RegenerateSession regenerates a session from the old id
func RegenerateSession(oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) {
}

// CountSessions returns the number of sessions
func CountSessions() (int64, error) {
return db.GetEngine(db.DefaultContext).Count(&Session{})
func CountSessions(ctx context.Context) (int64, error) {
return db.GetEngine(ctx).Count(&Session{})
}

// CleanupSessions cleans up expired sessions
func CleanupSessions(maxLifetime int64) error {
_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
return err
}
56 changes: 12 additions & 44 deletions models/auth/webauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string {
}

// UpdateSignCount will update the database value of SignCount
func (cred *WebAuthnCredential) UpdateSignCount() error {
return cred.updateSignCount(db.DefaultContext)
}

func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error {
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
return err
}
Expand Down Expand Up @@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential {
}

// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) {
return getWebAuthnCredentialsByUID(db.DefaultContext, uid)
}

func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
creds := make(WebAuthnCredentialList, 0)
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
}

// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) {
return existsWebAuthnCredentialsByUID(db.DefaultContext, uid)
}

func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) {
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}

// GetWebAuthnCredentialByName returns WebAuthn credential by id
func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByName(db.DefaultContext, uid, name)
}

func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
return nil, err
Expand All @@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*
}

// GetWebAuthnCredentialByID returns WebAuthn credential by id
func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByID(db.DefaultContext, id)
}

func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
return nil, err
Expand All @@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti
}

// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}

// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID)
}

func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
return nil, err
Expand All @@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b
}

// CreateCredential will create a new WebAuthnCredential from the given Credential
func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
return createCredential(db.DefaultContext, userID, name, cred)
}

func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
c := &WebAuthnCredential{
UserID: userID,
Name: name,
Expand All @@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba
}

// DeleteCredential will delete WebAuthnCredential
func DeleteCredential(id, userID int64) (bool, error) {
return deleteCredential(db.DefaultContext, id, userID)
}

func deleteCredential(ctx context.Context, id, userID int64) (bool, error) {
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
return had > 0, err
}

// WebAuthnCredentials implementns the webauthn.User interface
func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(userID)
func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit f91dbbb

Please sign in to comment.