Skip to content

Commit

Permalink
Replace more db.DefaultContext (#27628)
Browse files Browse the repository at this point in the history
Target #27065
  • Loading branch information
lunny committed Oct 15, 2023
1 parent 7480aac commit cddf245
Show file tree
Hide file tree
Showing 33 changed files with 99 additions and 85 deletions.
12 changes: 7 additions & 5 deletions contrib/fixtures/fixture_generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package main

import (
"context"
"fmt"
"os"
"path/filepath"
Expand All @@ -18,7 +19,7 @@ import (

var (
generators = []struct {
gen func() (string, error)
gen func(ctx context.Context) (string, error)
name string
}{
{
Expand All @@ -41,27 +42,28 @@ func main() {
fmt.Printf("PrepareTestDatabase: %+v\n", err)
os.Exit(1)
}
ctx := context.Background()
if len(os.Args) == 0 {
for _, r := range os.Args {
if err := generate(r); err != nil {
if err := generate(ctx, r); err != nil {
fmt.Printf("generate '%s': %+v\n", r, err)
os.Exit(1)
}
}
} else {
for _, g := range generators {
if err := generate(g.name); err != nil {
if err := generate(ctx, g.name); err != nil {
fmt.Printf("generate '%s': %+v\n", g.name, err)
os.Exit(1)
}
}
}
}

func generate(name string) error {
func generate(ctx context.Context, name string) error {
for _, g := range generators {
if g.name == name {
data, err := g.gen()
data, err := g.gen(ctx)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_authorized_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func appendAuthorizedKeysToFile(keys ...*PublicKey) error {
}

// RewriteAllPublicKeys removes any authorized key and rewrite all keys from database again.
// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function
// Note: db.GetEngine(ctx).Iterate does not get latest data after insert/delete, so we have to call this function
// outside any session scope independently.
func RewriteAllPublicKeys(ctx context.Context) error {
// Don't rewrite key if internal server
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_authorized_principals.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
const authorizedPrincipalsFile = "authorized_principals"

// RewriteAllPrincipalKeys removes any authorized principal and rewrite all keys from database again.
// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function
// Note: db.GetEngine(ctx).Iterate does not get latest data after insert/delete, so we have to call this function
// outside any session scope independently.
func RewriteAllPrincipalKeys(ctx context.Context) error {
// Don't rewrite key if internal server
Expand Down
11 changes: 6 additions & 5 deletions models/fixture_generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package models

import (
"context"
"fmt"
"strings"

Expand All @@ -14,23 +15,23 @@ import (

// GetYamlFixturesAccess returns a string containing the contents
// for the access table, as recalculated using repo.RecalculateAccesses()
func GetYamlFixturesAccess() (string, error) {
func GetYamlFixturesAccess(ctx context.Context) (string, error) {
repos := make([]*repo_model.Repository, 0, 50)
if err := db.GetEngine(db.DefaultContext).Find(&repos); err != nil {
if err := db.GetEngine(ctx).Find(&repos); err != nil {
return "", err
}

for _, repo := range repos {
repo.MustOwner(db.DefaultContext)
if err := access_model.RecalculateAccesses(db.DefaultContext, repo); err != nil {
repo.MustOwner(ctx)
if err := access_model.RecalculateAccesses(ctx, repo); err != nil {
return "", err
}
}

var b strings.Builder

accesses := make([]*access_model.Access, 0, 200)
if err := db.GetEngine(db.DefaultContext).OrderBy("user_id, repo_id").Find(&accesses); err != nil {
if err := db.GetEngine(ctx).OrderBy("user_id, repo_id").Find(&accesses); err != nil {
return "", err
}

Expand Down
8 changes: 5 additions & 3 deletions models/fixture_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
package models

import (
"context"
"os"
"path/filepath"
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"code.gitea.io/gitea/modules/util"

Expand All @@ -17,8 +19,8 @@ import (
func TestFixtureGeneration(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())

test := func(gen func() (string, error), name string) {
expected, err := gen()
test := func(ctx context.Context, gen func(ctx context.Context) (string, error), name string) {
expected, err := gen(ctx)
if !assert.NoError(t, err) {
return
}
Expand All @@ -31,5 +33,5 @@ func TestFixtureGeneration(t *testing.T) {
assert.EqualValues(t, expected, data, "Differences detected for %s", p)
}

test(GetYamlFixturesAccess, "access")
test(db.DefaultContext, GetYamlFixturesAccess, "access")
}
4 changes: 2 additions & 2 deletions models/org.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ func removeOrgUser(ctx context.Context, orgID, userID int64) error {
}

// RemoveOrgUser removes user from given organization.
func RemoveOrgUser(orgID, userID int64) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func RemoveOrgUser(ctx context.Context, orgID, userID int64) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
Expand Down
9 changes: 5 additions & 4 deletions models/org_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package models
import (
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/organization"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
Expand All @@ -20,15 +21,15 @@ func TestUser_RemoveMember(t *testing.T) {
// remove a user that is a member
unittest.AssertExistsAndLoadBean(t, &organization.OrgUser{UID: 4, OrgID: 3})
prevNumMembers := org.NumMembers
assert.NoError(t, RemoveOrgUser(org.ID, 4))
assert.NoError(t, RemoveOrgUser(db.DefaultContext, org.ID, 4))
unittest.AssertNotExistsBean(t, &organization.OrgUser{UID: 4, OrgID: 3})
org = unittest.AssertExistsAndLoadBean(t, &organization.Organization{ID: 3})
assert.Equal(t, prevNumMembers-1, org.NumMembers)

// remove a user that is not a member
unittest.AssertNotExistsBean(t, &organization.OrgUser{UID: 5, OrgID: 3})
prevNumMembers = org.NumMembers
assert.NoError(t, RemoveOrgUser(org.ID, 5))
assert.NoError(t, RemoveOrgUser(db.DefaultContext, org.ID, 5))
unittest.AssertNotExistsBean(t, &organization.OrgUser{UID: 5, OrgID: 3})
org = unittest.AssertExistsAndLoadBean(t, &organization.Organization{ID: 3})
assert.Equal(t, prevNumMembers, org.NumMembers)
Expand All @@ -44,15 +45,15 @@ func TestRemoveOrgUser(t *testing.T) {
if unittest.BeanExists(t, &organization.OrgUser{OrgID: orgID, UID: userID}) {
expectedNumMembers--
}
assert.NoError(t, RemoveOrgUser(orgID, userID))
assert.NoError(t, RemoveOrgUser(db.DefaultContext, orgID, userID))
unittest.AssertNotExistsBean(t, &organization.OrgUser{OrgID: orgID, UID: userID})
org = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: orgID})
assert.EqualValues(t, expectedNumMembers, org.NumMembers)
}
testSuccess(3, 4)
testSuccess(3, 4)

err := RemoveOrgUser(7, 5)
err := RemoveOrgUser(db.DefaultContext, 7, 5)
assert.Error(t, err)
assert.True(t, organization.IsErrLastOrgOwner(err))
unittest.AssertExistsAndLoadBean(t, &organization.OrgUser{OrgID: 7, UID: 5})
Expand Down
4 changes: 2 additions & 2 deletions models/repo/repo_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ func FindUserCodeAccessibleOwnerRepoIDs(ctx context.Context, ownerID int64, user
}

// GetUserRepositories returns a list of repositories of given user.
func GetUserRepositories(opts *SearchRepoOptions) (RepositoryList, int64, error) {
func GetUserRepositories(ctx context.Context, opts *SearchRepoOptions) (RepositoryList, int64, error) {
if len(opts.OrderBy) == 0 {
opts.OrderBy = "updated_unix DESC"
}
Expand All @@ -734,7 +734,7 @@ func GetUserRepositories(opts *SearchRepoOptions) (RepositoryList, int64, error)
cond = cond.And(builder.In("lower_name", opts.LowerNames))
}

sess := db.GetEngine(db.DefaultContext)
sess := db.GetEngine(ctx)

count, err := sess.Where(cond).Count(new(Repository))
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions models/repo/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ func watchRepoMode(ctx context.Context, watch Watch, mode WatchMode) (err error)
}

// WatchRepoMode watch repository in specific mode.
func WatchRepoMode(userID, repoID int64, mode WatchMode) (err error) {
func WatchRepoMode(ctx context.Context, userID, repoID int64, mode WatchMode) (err error) {
var watch Watch
if watch, err = GetWatch(db.DefaultContext, userID, repoID); err != nil {
if watch, err = GetWatch(ctx, userID, repoID); err != nil {
return err
}
return watchRepoMode(db.DefaultContext, watch, mode)
return watchRepoMode(ctx, watch, mode)
}

// WatchRepo watch or unwatch repository.
Expand Down
8 changes: 4 additions & 4 deletions models/repo/watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,18 @@ func TestWatchRepoMode(t *testing.T) {

unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 0)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeAuto))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeAuto))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 1)
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1, Mode: repo_model.WatchModeAuto}, 1)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeNormal))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeNormal))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 1)
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1, Mode: repo_model.WatchModeNormal}, 1)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeDont))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeDont))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 1)
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1, Mode: repo_model.WatchModeDont}, 1)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeNone))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeNone))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 0)
}
8 changes: 4 additions & 4 deletions models/system/appstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func init() {
}

// SaveAppStateContent saves the app state item to database
func SaveAppStateContent(key, content string) error {
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
func SaveAppStateContent(ctx context.Context, key, content string) error {
return db.WithTx(ctx, func(ctx context.Context) error {
eng := db.GetEngine(ctx)
// try to update existing row
res, err := eng.Exec("UPDATE app_state SET revision=revision+1, content=? WHERE id=?", content, key)
Expand All @@ -43,8 +43,8 @@ func SaveAppStateContent(key, content string) error {
}

// GetAppStateContent gets an app state from database
func GetAppStateContent(key string) (content string, err error) {
e := db.GetEngine(db.DefaultContext)
func GetAppStateContent(ctx context.Context, key string) (content string, err error) {
e := db.GetEngine(ctx)
appState := &AppState{ID: key}
has, err := e.Get(appState)
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions modules/system/appstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

package system

import "context"

// StateStore is the interface to get/set app state items
type StateStore interface {
Get(item StateItem) error
Set(item StateItem) error
Get(ctx context.Context, item StateItem) error
Set(ctx context.Context, item StateItem) error
}

// StateItem provides the name for a state item. the name will be used to generate filenames, etc
Expand Down
11 changes: 6 additions & 5 deletions modules/system/appstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package system
import (
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -40,25 +41,25 @@ func TestAppStateDB(t *testing.T) {
as := &DBStore{}

item1 := new(testItem1)
assert.NoError(t, as.Get(item1))
assert.NoError(t, as.Get(db.DefaultContext, item1))
assert.Equal(t, "", item1.Val1)
assert.EqualValues(t, 0, item1.Val2)

item1 = new(testItem1)
item1.Val1 = "a"
item1.Val2 = 2
assert.NoError(t, as.Set(item1))
assert.NoError(t, as.Set(db.DefaultContext, item1))

item2 := new(testItem2)
item2.K = "V"
assert.NoError(t, as.Set(item2))
assert.NoError(t, as.Set(db.DefaultContext, item2))

item1 = new(testItem1)
assert.NoError(t, as.Get(item1))
assert.NoError(t, as.Get(db.DefaultContext, item1))
assert.Equal(t, "a", item1.Val1)
assert.EqualValues(t, 2, item1.Val2)

item2 = new(testItem2)
assert.NoError(t, as.Get(item2))
assert.NoError(t, as.Get(db.DefaultContext, item2))
assert.Equal(t, "V", item2.K)
}
10 changes: 6 additions & 4 deletions modules/system/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package system

import (
"context"

"code.gitea.io/gitea/models/system"
"code.gitea.io/gitea/modules/json"

Expand All @@ -14,8 +16,8 @@ import (
type DBStore struct{}

// Get reads the state item
func (f *DBStore) Get(item StateItem) error {
content, err := system.GetAppStateContent(item.Name())
func (f *DBStore) Get(ctx context.Context, item StateItem) error {
content, err := system.GetAppStateContent(ctx, item.Name())
if err != nil {
return err
}
Expand All @@ -26,10 +28,10 @@ func (f *DBStore) Get(item StateItem) error {
}

// Set saves the state item
func (f *DBStore) Set(item StateItem) error {
func (f *DBStore) Set(ctx context.Context, item StateItem) error {
b, err := json.Marshal(item)
if err != nil {
return err
}
return system.SaveAppStateContent(item.Name(), util.BytesToReadOnlyString(b))
return system.SaveAppStateContent(ctx, item.Name(), util.BytesToReadOnlyString(b))
}
Loading

0 comments on commit cddf245

Please sign in to comment.