Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Next round of db.DefaultContext refactor #27089

Merged
merged 6 commits into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions models/actions/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ func init() {
}

// GetSchedulesMapByIDs returns the schedules by given id slice.
func GetSchedulesMapByIDs(ids []int64) (map[int64]*ActionSchedule, error) {
func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) {
schedules := make(map[int64]*ActionSchedule, len(ids))
return schedules, db.GetEngine(db.DefaultContext).In("id", ids).Find(&schedules)
return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules)
}

// GetReposMapByIDs returns the repos by given id slice.
func GetReposMapByIDs(ids []int64) (map[int64]*repo_model.Repository, error) {
func GetReposMapByIDs(ctx context.Context, ids []int64) (map[int64]*repo_model.Repository, error) {
repos := make(map[int64]*repo_model.Repository, len(ids))
return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos)
return repos, db.GetEngine(ctx).In("id", ids).Find(&repos)
}

var cronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
Expand Down
8 changes: 4 additions & 4 deletions models/actions/schedule_spec_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ func (specs SpecList) GetScheduleIDs() []int64 {
return ids.Values()
}

func (specs SpecList) LoadSchedules() error {
func (specs SpecList) LoadSchedules(ctx context.Context) error {
scheduleIDs := specs.GetScheduleIDs()
schedules, err := GetSchedulesMapByIDs(scheduleIDs)
schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs)
if err != nil {
return err
}
Expand All @@ -34,7 +34,7 @@ func (specs SpecList) LoadSchedules() error {
}

repoIDs := specs.GetRepoIDs()
repos, err := GetReposMapByIDs(repoIDs)
repos, err := GetReposMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, erro
return nil, 0, err
}

if err := specs.LoadSchedules(); err != nil {
if err := specs.LoadSchedules(ctx); err != nil {
return nil, 0, err
}
return specs, total, nil
Expand Down
34 changes: 15 additions & 19 deletions models/admin/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,7 @@ type TranslatableMessage struct {
}

// LoadRepo loads repository of the task
func (task *Task) LoadRepo() error {
return task.loadRepo(db.DefaultContext)
}

func (task *Task) loadRepo(ctx context.Context) error {
func (task *Task) LoadRepo(ctx context.Context) error {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
if task.Repo != nil {
return nil
}
Expand All @@ -70,13 +66,13 @@ func (task *Task) loadRepo(ctx context.Context) error {
}

// LoadDoer loads do user
func (task *Task) LoadDoer() error {
func (task *Task) LoadDoer(ctx context.Context) error {
if task.Doer != nil {
return nil
}

var doer user_model.User
has, err := db.GetEngine(db.DefaultContext).ID(task.DoerID).Get(&doer)
has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer)
if err != nil {
return err
} else if !has {
Expand All @@ -90,13 +86,13 @@ func (task *Task) LoadDoer() error {
}

// LoadOwner loads owner user
func (task *Task) LoadOwner() error {
func (task *Task) LoadOwner(ctx context.Context) error {
if task.Owner != nil {
return nil
}

var owner user_model.User
has, err := db.GetEngine(db.DefaultContext).ID(task.OwnerID).Get(&owner)
has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner)
if err != nil {
return err
} else if !has {
Expand All @@ -110,8 +106,8 @@ func (task *Task) LoadOwner() error {
}

// UpdateCols updates some columns
func (task *Task) UpdateCols(cols ...string) error {
_, err := db.GetEngine(db.DefaultContext).ID(task.ID).Cols(cols...).Update(task)
func (task *Task) UpdateCols(ctx context.Context, cols ...string) error {
_, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task)
return err
}

Expand Down Expand Up @@ -169,12 +165,12 @@ func (err ErrTaskDoesNotExist) Unwrap() error {
}

// GetMigratingTask returns the migrating task by repo's id
func GetMigratingTask(repoID int64) (*Task, error) {
func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) {
task := Task{
RepoID: repoID,
Type: structs.TaskTypeMigrateRepo,
}
has, err := db.GetEngine(db.DefaultContext).Get(&task)
has, err := db.GetEngine(ctx).Get(&task)
if err != nil {
return nil, err
} else if !has {
Expand All @@ -184,13 +180,13 @@ func GetMigratingTask(repoID int64) (*Task, error) {
}

// GetMigratingTaskByID returns the migrating task by repo's id
func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, error) {
func GetMigratingTaskByID(ctx context.Context, id, doerID int64) (*Task, *migration.MigrateOptions, error) {
task := Task{
ID: id,
DoerID: doerID,
Type: structs.TaskTypeMigrateRepo,
}
has, err := db.GetEngine(db.DefaultContext).Get(&task)
has, err := db.GetEngine(ctx).Get(&task)
if err != nil {
return nil, nil, err
} else if !has {
Expand All @@ -205,12 +201,12 @@ func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, e
}

// CreateTask creates a task on database
func CreateTask(task *Task) error {
return db.Insert(db.DefaultContext, task)
func CreateTask(ctx context.Context, task *Task) error {
return db.Insert(ctx, task)
}

// FinishMigrateTask updates database when migrate task finished
func FinishMigrateTask(task *Task) error {
func FinishMigrateTask(ctx context.Context, task *Task) error {
task.Status = structs.TaskStatusFinished
task.EndTime = timeutil.TimeStampNow()

Expand All @@ -231,6 +227,6 @@ func FinishMigrateTask(task *Task) error {
}
task.PayloadContent = string(confBytes)

_, err = db.GetEngine(db.DefaultContext).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
_, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
return err
}
29 changes: 15 additions & 14 deletions models/auth/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package auth

import (
"context"
"fmt"

"code.gitea.io/gitea/models/db"
Expand All @@ -22,21 +23,21 @@ func init() {
}

// UpdateSession updates the session with provided id
func UpdateSession(key string, data []byte) error {
_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
func UpdateSession(ctx context.Context, key string, data []byte) error {
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
Data: data,
Expiry: timeutil.TimeStampNow(),
})
return err
}

// ReadSession reads the data for the provided session
func ReadSession(key string) (*Session, error) {
func ReadSession(ctx context.Context, key string) (*Session, error) {
session := Session{
Key: key,
}

ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) {
}

// ExistSession checks if a session exists
func ExistSession(key string) (bool, error) {
func ExistSession(ctx context.Context, key string) (bool, error) {
session := Session{
Key: key,
}
return db.GetEngine(db.DefaultContext).Get(&session)
return db.GetEngine(ctx).Get(&session)
}

// DestroySession destroys a session
func DestroySession(key string) error {
_, err := db.GetEngine(db.DefaultContext).Delete(&Session{
func DestroySession(ctx context.Context, key string) error {
_, err := db.GetEngine(ctx).Delete(&Session{
Key: key,
})
return err
}

// RegenerateSession regenerates a session from the old id
func RegenerateSession(oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) {
}

// CountSessions returns the number of sessions
func CountSessions() (int64, error) {
return db.GetEngine(db.DefaultContext).Count(&Session{})
func CountSessions(ctx context.Context) (int64, error) {
return db.GetEngine(ctx).Count(&Session{})
}

// CleanupSessions cleans up expired sessions
func CleanupSessions(maxLifetime int64) error {
_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
return err
}
56 changes: 12 additions & 44 deletions models/auth/webauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string {
}

// UpdateSignCount will update the database value of SignCount
func (cred *WebAuthnCredential) UpdateSignCount() error {
return cred.updateSignCount(db.DefaultContext)
}

func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error {
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
return err
}
Expand Down Expand Up @@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential {
}

// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) {
return getWebAuthnCredentialsByUID(db.DefaultContext, uid)
}

func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
creds := make(WebAuthnCredentialList, 0)
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
}

// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) {
return existsWebAuthnCredentialsByUID(db.DefaultContext, uid)
}

func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) {
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}

// GetWebAuthnCredentialByName returns WebAuthn credential by id
func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByName(db.DefaultContext, uid, name)
}

func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
return nil, err
Expand All @@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*
}

// GetWebAuthnCredentialByID returns WebAuthn credential by id
func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByID(db.DefaultContext, id)
}

func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
return nil, err
Expand All @@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti
}

// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}

// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID)
}

func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
return nil, err
Expand All @@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b
}

// CreateCredential will create a new WebAuthnCredential from the given Credential
func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
return createCredential(db.DefaultContext, userID, name, cred)
}

func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
c := &WebAuthnCredential{
UserID: userID,
Name: name,
Expand All @@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba
}

// DeleteCredential will delete WebAuthnCredential
func DeleteCredential(id, userID int64) (bool, error) {
return deleteCredential(db.DefaultContext, id, userID)
}

func deleteCredential(ctx context.Context, id, userID int64) (bool, error) {
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
JakobDev marked this conversation as resolved.
Show resolved Hide resolved
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
return had > 0, err
}

// WebAuthnCredentials implementns the webauthn.User interface
func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(userID)
func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
if err != nil {
return nil, err
}
Expand Down
Loading