Skip to content

Commit

Permalink
Merge pull request #27 from kayrein/query_templates
Browse files Browse the repository at this point in the history
Query templates
  • Loading branch information
qustavo authored Nov 23, 2023
2 parents 6147875 + 138f85d commit 5d06b89
Show file tree
Hide file tree
Showing 6 changed files with 601 additions and 458 deletions.
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,27 @@ dot2, err := dotsql.LoadFromFile("queries2.sql")
dot := dotsql.Merge(dot1, dot2)
```

Text Interpolation
--
[text/template](https://pkg.go.dev/text/template)-style text interpolation is supported.

To use, call `.WithData(any)` on your dotsql instance to
create a new instance which passes those values into the templating library.

```sql
-- name: count-users
SELECT count(*) FROM users {{if .exclude_deleted}}WHERE deleted IS NULL{{end}}
```

```go
dotsql.WithData(map[string]any{"exclude_deleted": true}).Query(db, "count-users")
```

Embeding
--
To avoid distributing `sql` files alongside the binary file, you will need to use tools like
[gotic](https://github.com/qustavo/gotic) to embed / pack everything into one file.

TODO
--
- [ ] Enable text interpolation inside queries using `text/template`


SQLX
--
For [sqlx](https://github.com/jmoiron/sqlx) support check [dotsqlx](https://github.com/swithek/dotsqlx)
93 changes: 93 additions & 0 deletions compareCalls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package dotsql

import (
"context"
"reflect"
"testing"
)

type PrepareCalls []struct {
Query string
}

func comparePrepareCalls(t *testing.T, ff PrepareCalls, template string) bool {
t.Helper()
if len(ff) != 1 {
t.Errorf("prepare was expected to be called only once, but was called %d times", len(ff))
return false
} else if ff[0].Query != template {
t.Errorf("prepare was expected to be called with %q query, got %q", template, ff[0].Query)
return false
}
return true
}

type PrepareContextCalls []struct {
Ctx context.Context
Query string
}

func comparePrepareContextCalls(t *testing.T, ff PrepareContextCalls, ctx context.Context, template string) bool {
t.Helper()
if len(ff) != 1 {
t.Errorf("prepare was expected to be called only once, but was called %d times", len(ff))
return false
} else if ff[0].Query != template {
t.Errorf("prepare was expected to be called with %q query, got %q", template, ff[0].Query)
return false
} else if !reflect.DeepEqual(ff[0].Ctx, ctx) {
t.Error("prepare context does not match")
return false
}
return true
}

type QueryCalls []struct {
Query string
Args []interface{}
}

func compareCalls(t *testing.T, ff QueryCalls, command, template, testArg string) bool {
t.Helper()
if len(ff) != 1 {
t.Errorf("%s was expected to be called only once, but was called %d times", command, len(ff))
return false
} else if ff[0].Query != template {
t.Errorf("%s was expected to be called with %q query, got %q", command, template, ff[0].Query)
return false
} else if len(ff[0].Args) != 1 {
t.Errorf("%s was expected to be called with 1 argument, got %d", command, len(ff[0].Args))
return false
} else if !reflect.DeepEqual(ff[0].Args[0], testArg) {
t.Errorf("%s was expected to be called with %q argument, got %v", command, testArg, ff[0].Args[0])
return false
}
return true
}

type QueryContextCalls []struct {
Ctx context.Context
Query string
Args []interface{}
}

func compareContextCalls(t *testing.T, ff QueryContextCalls, ctx context.Context, command, template, testArg string) bool {
t.Helper()
if len(ff) != 1 {
t.Errorf("%s was expected to be called only once, but was called %d times", command, len(ff))
return false
} else if ff[0].Query != template {
t.Errorf("%s was expected to be called with %q query, got %q", command, template, ff[0].Query)
return false
} else if len(ff[0].Args) != 1 {
t.Errorf("%s was expected to be called with 1 argument, got %d", command, len(ff[0].Args))
return false
} else if !reflect.DeepEqual(ff[0].Args[0], testArg) {
t.Errorf("%s was expected to be called with %q argument, got %v", command, testArg, ff[0].Args[0])
return false
} else if !reflect.DeepEqual(ff[0].Ctx, ctx) {
t.Errorf("%s context does not match", command)
return false
}
return true
}
59 changes: 40 additions & 19 deletions dotsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"fmt"
"io"
"os"
"text/template"
)

// Preparer is an interface used by Prepare.
Expand Down Expand Up @@ -58,21 +59,34 @@ type ExecerContext interface {

// DotSql represents a dotSQL queries holder.
type DotSql struct {
queries map[string]string
queries map[string]*template.Template
data any
}

func (d DotSql) lookupQuery(name string) (query string, err error) {
query, ok := d.queries[name]
func (d DotSql) WithData(data any) DotSql {
return DotSql{queries: d.queries, data: data}
}

func (d DotSql) lookupQuery(name string, data any) (string, error) {
template, ok := d.queries[name]
if !ok {
err = fmt.Errorf("dotsql: '%s' could not be found", name)
return "", fmt.Errorf("dotsql: '%s' could not be found", name)
}
if template == nil {
return "", nil
}
buffer := bytes.NewBufferString("")
err := template.Execute(buffer, data)
if err != nil {
return "", fmt.Errorf("error parsing template: %w", err)
}

return
return buffer.String(), nil
}

// Prepare is a wrapper for database/sql's Prepare(), using dotsql named query.
func (d DotSql) Prepare(db Preparer, name string) (*sql.Stmt, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -82,7 +96,7 @@ func (d DotSql) Prepare(db Preparer, name string) (*sql.Stmt, error) {

// PrepareContext is a wrapper for database/sql's PrepareContext(), using dotsql named query.
func (d DotSql) PrepareContext(ctx context.Context, db PreparerContext, name string) (*sql.Stmt, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -92,7 +106,7 @@ func (d DotSql) PrepareContext(ctx context.Context, db PreparerContext, name str

// Query is a wrapper for database/sql's Query(), using dotsql named query.
func (d DotSql) Query(db Queryer, name string, args ...interface{}) (*sql.Rows, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -102,7 +116,7 @@ func (d DotSql) Query(db Queryer, name string, args ...interface{}) (*sql.Rows,

// QueryContext is a wrapper for database/sql's QueryContext(), using dotsql named query.
func (d DotSql) QueryContext(ctx context.Context, db QueryerContext, name string, args ...interface{}) (*sql.Rows, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -112,7 +126,7 @@ func (d DotSql) QueryContext(ctx context.Context, db QueryerContext, name string

// QueryRow is a wrapper for database/sql's QueryRow(), using dotsql named query.
func (d DotSql) QueryRow(db QueryRower, name string, args ...interface{}) (*sql.Row, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -122,7 +136,7 @@ func (d DotSql) QueryRow(db QueryRower, name string, args ...interface{}) (*sql.

// QueryRowContext is a wrapper for database/sql's QueryRowContext(), using dotsql named query.
func (d DotSql) QueryRowContext(ctx context.Context, db QueryRowerContext, name string, args ...interface{}) (*sql.Row, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -132,7 +146,7 @@ func (d DotSql) QueryRowContext(ctx context.Context, db QueryRowerContext, name

// Exec is a wrapper for database/sql's Exec(), using dotsql named query.
func (d DotSql) Exec(db Execer, name string, args ...interface{}) (sql.Result, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -142,7 +156,7 @@ func (d DotSql) Exec(db Execer, name string, args ...interface{}) (sql.Result, e

// ExecContext is a wrapper for database/sql's ExecContext(), using dotsql named query.
func (d DotSql) ExecContext(ctx context.Context, db ExecerContext, name string, args ...interface{}) (sql.Result, error) {
query, err := d.lookupQuery(name)
query, err := d.lookupQuery(name, d.data)
if err != nil {
return nil, err
}
Expand All @@ -152,11 +166,11 @@ func (d DotSql) ExecContext(ctx context.Context, db ExecerContext, name string,

// Raw returns the query, everything after the --name tag
func (d DotSql) Raw(name string) (string, error) {
return d.lookupQuery(name)
return d.lookupQuery(name, d.data)
}

// QueryMap returns a map[string]string of loaded queries
func (d DotSql) QueryMap() map[string]string {
func (d DotSql) QueryMap() map[string]*template.Template {
return d.queries
}

Expand All @@ -165,11 +179,18 @@ func Load(r io.Reader) (*DotSql, error) {
scanner := &Scanner{}
queries := scanner.Run(bufio.NewScanner(r))

dotsql := &DotSql{
queries: queries,
templates := make(map[string]*template.Template)
for k, v := range queries {
tmpl, err := template.New(k).Parse(v)
if err != nil {
return nil, err
}
templates[k] = tmpl
}

return dotsql, nil
return &DotSql{
queries: templates,
}, nil
}

// LoadFromFile imports SQL queries from the file.
Expand All @@ -193,7 +214,7 @@ func LoadFromString(sql string) (*DotSql, error) {
// It's in-order, so the last source will override queries with the same name
// in the previous arguments if any.
func Merge(dots ...*DotSql) *DotSql {
queries := make(map[string]string)
queries := make(map[string]*template.Template)

for _, dot := range dots {
for k, v := range dot.QueryMap() {
Expand Down
Loading

0 comments on commit 5d06b89

Please sign in to comment.