Skip to content

Commit

Permalink
Remove NewSession method from db.Engine interface (go-gitea#17577)
Browse files Browse the repository at this point in the history
* Remove NewSession method from db.Engine interface

* Fix bug

* Some improvements

* Fix bug

* Fix test

* Use XXXBean instead of XXXExample
  • Loading branch information
lunny authored and Stelios Malathouras committed Mar 28, 2022
1 parent 5edbb08 commit 52f17b9
Show file tree
Hide file tree
Showing 44 changed files with 550 additions and 570 deletions.
11 changes: 6 additions & 5 deletions models/branches.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,12 +614,13 @@ func FindRenamedBranch(repoID int64, from string) (branch *RenamedBranch, exist

// RenameBranch rename a branch
func (repo *Repository) RenameBranch(from, to string, gitAction func(isDefault bool) error) (err error) {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

sess := db.GetEngine(ctx)
// 1. update default branch if needed
isDefault := repo.DefaultBranch == from
if isDefault {
Expand Down Expand Up @@ -663,10 +664,10 @@ func (repo *Repository) RenameBranch(from, to string, gitAction func(isDefault b
From: from,
To: to,
}
_, err = sess.Insert(renamedBranch)
err = db.Insert(ctx, renamedBranch)
if err != nil {
return err
}

return sess.Commit()
return committer.Commit()
}
55 changes: 26 additions & 29 deletions models/db/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ package db

import (
"context"
"database/sql"

"code.gitea.io/gitea/modules/setting"

"xorm.io/builder"
"xorm.io/xorm"
)

// DefaultContext is the default context to run xorm queries in
Expand Down Expand Up @@ -44,15 +44,6 @@ func (ctx *Context) Engine() Engine {
return ctx.e
}

// NewSession returns a new session
func (ctx *Context) NewSession() *xorm.Session {
e, ok := ctx.e.(*xorm.Engine)
if ok {
return e.NewSession()
}
return nil
}

// Value shadows Value for context.Context but allows us to get ourselves and an Engined object
func (ctx *Context) Value(key interface{}) interface{} {
if key == EnginedContextKey {
Expand All @@ -64,7 +55,6 @@ func (ctx *Context) Value(key interface{}) interface{} {
// Engined structs provide an Engine
type Engined interface {
Engine() Engine
NewSession() *xorm.Session
}

// GetEngine will get a db Engine from this context or return an Engine restricted to this context
Expand All @@ -79,24 +69,6 @@ func GetEngine(ctx context.Context) Engine {
return x.Context(ctx)
}

// NewSession will get a db Session from this context or return a session restricted to this context
func NewSession(ctx context.Context) *xorm.Session {
if engined, ok := ctx.(Engined); ok {
return engined.NewSession()
}

enginedInterface := ctx.Value(EnginedContextKey)
if enginedInterface != nil {
sess := enginedInterface.(Engined).NewSession()
if sess != nil {
return sess.Context(ctx)
}
return nil
}

return x.NewSession().Context(ctx)
}

// Committer represents an interface to Commit or Close the Context
type Committer interface {
Commit() error
Expand Down Expand Up @@ -155,3 +127,28 @@ func Insert(ctx context.Context, beans ...interface{}) error {
_, err := GetEngine(ctx).Insert(beans...)
return err
}

// Exec executes a sql with args
func Exec(ctx context.Context, sqlAndArgs ...interface{}) (sql.Result, error) {
return GetEngine(ctx).Exec(sqlAndArgs...)
}

// GetByBean filled empty fields of the bean according non-empty fields to query in database.
func GetByBean(ctx context.Context, bean interface{}) (bool, error) {
return GetEngine(ctx).Get(bean)
}

// DeleteByBean deletes all records according non-empty fields of the bean as conditions.
func DeleteByBean(ctx context.Context, bean interface{}) (int64, error) {
return GetEngine(ctx).Delete(bean)
}

// CountByBean counts the number of database records according non-empty fields of the bean as conditions.
func CountByBean(ctx context.Context, bean interface{}) (int64, error) {
return GetEngine(ctx).Count(bean)
}

// TableName returns the table name according a bean object
func TableName(bean interface{}) string {
return x.TableName(bean)
}
1 change: 1 addition & 0 deletions models/db/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type Engine interface {
Asc(colNames ...string) *xorm.Session
Desc(colNames ...string) *xorm.Session
Limit(limit int, start ...int) *xorm.Session
NoAutoTime() *xorm.Session
SumInt(bean interface{}, columnName string) (res int64, err error)
Sync2(...interface{}) error
Select(string) *xorm.Session
Expand Down
2 changes: 1 addition & 1 deletion models/db/list_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func GetPaginatedSession(p Paginator) *xorm.Session {
}

// SetSessionPagination sets pagination for a database session
func SetSessionPagination(sess *xorm.Session, p Paginator) *xorm.Session {
func SetSessionPagination(sess Engine, p Paginator) *xorm.Session {
skip, take := p.GetSkipTake()

return sess.Limit(take, skip)
Expand Down
21 changes: 8 additions & 13 deletions models/issue.go
Original file line number Diff line number Diff line change
Expand Up @@ -1349,10 +1349,9 @@ func applyReviewRequestedCondition(sess *xorm.Session, reviewRequestedID int64)

// CountIssuesByRepo map from repoID to number of issues matching the options
func CountIssuesByRepo(opts *IssuesOptions) (map[int64]int64, error) {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
e := db.GetEngine(db.DefaultContext)

sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")

opts.setupSession(sess)

Expand All @@ -1377,10 +1376,9 @@ func CountIssuesByRepo(opts *IssuesOptions) (map[int64]int64, error) {
// GetRepoIDsForIssuesOptions find all repo ids for the given options
func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *User) ([]int64, error) {
repoIDs := make([]int64, 0, 5)
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
e := db.GetEngine(db.DefaultContext)

sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")

opts.setupSession(sess)

Expand All @@ -1397,10 +1395,9 @@ func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *User) ([]int64, error

// Issues returns a list of issues by given conditions.
func Issues(opts *IssuesOptions) ([]*Issue, error) {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
e := db.GetEngine(db.DefaultContext)

sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
opts.setupSession(sess)
sortIssuesSession(sess, opts.SortType, opts.PriorityRepoID)

Expand All @@ -1419,15 +1416,14 @@ func Issues(opts *IssuesOptions) ([]*Issue, error) {

// CountIssues number return of issues by given conditions.
func CountIssues(opts *IssuesOptions) (int64, error) {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
e := db.GetEngine(db.DefaultContext)

countsSlice := make([]*struct {
RepoID int64
Count int64
}, 0, 1)

sess.Select("COUNT(issue.id) AS count").Table("issue")
sess := e.Select("COUNT(issue.id) AS count").Table("issue")
sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
opts.setupSession(sess)
if err := sess.Find(&countsSlice); err != nil {
Expand Down Expand Up @@ -1901,7 +1897,6 @@ func UpdateIssueDeadline(issue *Issue, deadlineUnix timeutil.TimeStamp, doer *Us
if issue.DeadlineUnix == deadlineUnix {
return nil
}

ctx, committer, err := db.TxContext()
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions models/issue_comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1098,17 +1098,17 @@ func UpdateComment(c *Comment, doer *User) error {

// DeleteComment deletes the comment
func DeleteComment(comment *Comment) error {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

if err := deleteComment(sess, comment); err != nil {
if err := deleteComment(db.GetEngine(ctx), comment); err != nil {
return err
}

return sess.Commit()
return committer.Commit()
}

func deleteComment(e db.Engine, comment *Comment) error {
Expand Down
2 changes: 1 addition & 1 deletion models/issue_dependency.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func CreateIssueDependency(user *User, issue, dep *Issue) error {
return ErrCircularDependency{issue.ID, dep.ID}
}

if _, err := sess.Insert(&IssueDependency{
if err := db.Insert(ctx, &IssueDependency{
UserID: user.ID,
IssueID: issue.ID,
DependencyID: dep.ID,
Expand Down
10 changes: 6 additions & 4 deletions models/issue_label.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,13 @@ func DeleteLabel(id, labelID int64) error {
return err
}

sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

sess := db.GetEngine(ctx)

if label.BelongsToOrg() && label.OrgID != id {
return nil
Expand All @@ -323,7 +325,7 @@ func DeleteLabel(id, labelID int64) error {
return err
}

return sess.Commit()
return committer.Commit()
}

// getLabelByID returns a label by label id
Expand Down
54 changes: 30 additions & 24 deletions models/issue_milestone.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,22 @@ func (m *Milestone) State() api.StateType {

// NewMilestone creates new milestone of repository.
func NewMilestone(m *Milestone) (err error) {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

m.Name = strings.TrimSpace(m.Name)

if _, err = sess.Insert(m); err != nil {
if err = db.Insert(ctx, m); err != nil {
return err
}

if _, err = sess.Exec("UPDATE `repository` SET num_milestones = num_milestones + 1 WHERE id = ?", m.RepoID); err != nil {
if _, err = db.Exec(ctx, "UPDATE `repository` SET num_milestones = num_milestones + 1 WHERE id = ?", m.RepoID); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}

func getMilestoneByRepoID(e db.Engine, repoID, id int64) (*Milestone, error) {
Expand Down Expand Up @@ -150,11 +150,13 @@ func getMilestoneByID(e db.Engine, id int64) (*Milestone, error) {

// UpdateMilestone updates information of given milestone.
func UpdateMilestone(m *Milestone, oldIsClosed bool) error {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

sess := db.GetEngine(ctx)

if m.IsClosed && !oldIsClosed {
m.ClosedDateUnix = timeutil.TimeStampNow()
Expand All @@ -171,7 +173,7 @@ func UpdateMilestone(m *Milestone, oldIsClosed bool) error {
}
}

return sess.Commit()
return committer.Commit()
}

func updateMilestone(e db.Engine, m *Milestone) error {
Expand Down Expand Up @@ -207,11 +209,13 @@ func updateMilestoneCounters(e db.Engine, id int64) error {

// ChangeMilestoneStatusByRepoIDAndID changes a milestone open/closed status if the milestone ID is in the repo.
func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool) error {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

sess := db.GetEngine(ctx)

m := &Milestone{
ID: milestoneID,
Expand All @@ -229,22 +233,22 @@ func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool
return err
}

return sess.Commit()
return committer.Commit()
}

// ChangeMilestoneStatus changes the milestone open/closed status.
func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) {
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

if err := changeMilestoneStatus(sess, m, isClosed); err != nil {
if err := changeMilestoneStatus(db.GetEngine(ctx), m, isClosed); err != nil {
return err
}

return sess.Commit()
return committer.Commit()
}

func changeMilestoneStatus(e db.Engine, m *Milestone, isClosed bool) error {
Expand Down Expand Up @@ -335,11 +339,13 @@ func DeleteMilestoneByRepoID(repoID, id int64) error {
return err
}

sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()

sess := db.GetEngine(ctx)

if _, err = sess.ID(m.ID).Delete(new(Milestone)); err != nil {
return err
Expand All @@ -360,10 +366,10 @@ func DeleteMilestoneByRepoID(repoID, id int64) error {
return err
}

if _, err = sess.Exec("UPDATE `issue` SET milestone_id = 0 WHERE milestone_id = ?", m.ID); err != nil {
if _, err = db.Exec(ctx, "UPDATE `issue` SET milestone_id = 0 WHERE milestone_id = ?", m.ID); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}

// MilestoneList is a list of milestones offering additional functionality
Expand Down
Loading

0 comments on commit 52f17b9

Please sign in to comment.