Skip to content

Commit

Permalink
Merge pull request #162 from arikkfir/add-recursive-cte-support
Browse files Browse the repository at this point in the history
Fix #161: Add support for recursive CTEs
  • Loading branch information
huandu authored Aug 26, 2024
2 parents 2cc1f8c + 4b034a7 commit 8cd72ce
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
15 changes: 15 additions & 0 deletions cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ func With(tables ...*CTETableBuilder) *CTEBuilder {
return DefaultFlavor.NewCTEBuilder().With(tables...)
}

// WithRecursive creates a new recursive CTE builder with default flavor.
func WithRecursive(tables ...*CTETableBuilder) *CTEBuilder {
return DefaultFlavor.NewCTEBuilder().WithRecursive(tables...)
}

func newCTEBuilder() *CTEBuilder {
return &CTEBuilder{
args: &Args{},
Expand All @@ -22,6 +27,7 @@ func newCTEBuilder() *CTEBuilder {

// CTEBuilder is a CTE (Common Table Expression) builder.
type CTEBuilder struct {
recursive bool
tableNames []string
tableBuilderVars []string

Expand Down Expand Up @@ -49,6 +55,12 @@ func (cteb *CTEBuilder) With(tables ...*CTETableBuilder) *CTEBuilder {
return cteb
}

// WithRecursive sets the CTE name and columns and turns on the RECURSIVE keyword.
func (cteb *CTEBuilder) WithRecursive(tables ...*CTETableBuilder) *CTEBuilder {
cteb.With(tables...).recursive = true
return cteb
}

// Select creates a new SelectBuilder to build a SELECT statement using this CTE.
func (cteb *CTEBuilder) Select(col ...string) *SelectBuilder {
sb := cteb.args.Flavor.NewSelectBuilder()
Expand All @@ -73,6 +85,9 @@ func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}

if len(cteb.tableBuilderVars) > 0 {
buf.WriteLeadingString("WITH ")
if cteb.recursive {
buf.WriteString("RECURSIVE ")
}
buf.WriteStrings(cteb.tableBuilderVars, ", ")
}

Expand Down
46 changes: 46 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@ func ExampleWith() {
// WITH users (id, name) AS (SELECT id, name FROM users WHERE name IS NOT NULL), devices AS (SELECT device_id FROM devices) SELECT users.id, orders.id, devices.device_id FROM users, devices JOIN orders ON users.id = orders.user_id AND devices.device_id = orders.device_id
}

func ExampleWithRecursive() {
sb := WithRecursive(
CTETable("source_accounts", "id", "parent_id").As(
UnionAll(
Select("p.id", "p.parent_id").
From("accounts AS p").
Where("p.id = 2"), // Show orders for account 2 and all its child accounts
Select("c.id", "c.parent_id").
From("accounts AS c").
Join("source_accounts AS sa", "c.parent_id = sa.id"),
),
),
).Select("o.id", "o.date", "o.amount").
From("orders AS o").
Join("source_accounts", "o.account_id = source_accounts.id")

fmt.Println(sb)

// Output:
// WITH RECURSIVE source_accounts (id, parent_id) AS ((SELECT p.id, p.parent_id FROM accounts AS p WHERE p.id = 2) UNION ALL (SELECT c.id, c.parent_id FROM accounts AS c JOIN source_accounts AS sa ON c.parent_id = sa.id)) SELECT o.id, o.date, o.amount FROM orders AS o JOIN source_accounts ON o.account_id = source_accounts.id
}

func ExampleCTEBuilder() {
usersBuilder := Select("id", "name", "level").From("users")
usersBuilder.Where(
Expand Down Expand Up @@ -82,3 +104,27 @@ func TestCTEBuilder(t *testing.T) {
sql = ctetb.String()
a.Equal(sql, "/* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */")
}

func TestRecursiveCTEBuilder(t *testing.T) {
a := assert.New(t)
cteb := newCTEBuilder()
cteb.recursive = true
ctetb := newCTETableBuilder()
cteb.SQL("/* init */")
cteb.With(ctetb)
cteb.SQL("/* after with */")

ctetb.SQL("/* table init */")
ctetb.Table("t", "a", "b")
ctetb.SQL("/* after table */")

ctetb.As(Select("a", "b").From("t"))
ctetb.SQL("/* after table as */")

sql, args := cteb.Build()
a.Equal(sql, "/* init */ WITH RECURSIVE /* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */ /* after with */")
a.Assert(args == nil)

sql = ctetb.String()
a.Equal(sql, "/* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */")
}

0 comments on commit 8cd72ce

Please sign in to comment.