diff --git a/examples/batch/postgresql/batch.go b/examples/batch/postgresql/batch.go index e362268e82..aa2dbef34c 100644 --- a/examples/batch/postgresql/batch.go +++ b/examples/batch/postgresql/batch.go @@ -304,6 +304,71 @@ func (b *DeleteBookNamedSignBatchResults) Close() error { return b.br.Close() } +const getAuthorWithFirstBook = `-- name: GetAuthorWithFirstBook :batchone +SELECT books.book_id, books.author_id, books.isbn, books.book_type, books.title, books.year, books.available, books.tags, authors.author_id, authors.name, authors.biography +FROM authors +INNER JOIN books ON authors.author_id = books.author_id +WHERE authors.author_id = $1 +` + +type GetAuthorWithFirstBookBatchResults struct { + br pgx.BatchResults + tot int + closed bool +} + +type GetAuthorWithFirstBookRow struct { + Book Book `json:"book"` + Author Author `json:"author"` +} + +func (q *Queries) GetAuthorWithFirstBook(ctx context.Context, authorID []int32) *GetAuthorWithFirstBookBatchResults { + batch := &pgx.Batch{} + for _, a := range authorID { + vals := []interface{}{ + a, + } + batch.Queue(getAuthorWithFirstBook, vals...) + } + br := q.db.SendBatch(ctx, batch) + return &GetAuthorWithFirstBookBatchResults{br, len(authorID), false} +} + +func (b *GetAuthorWithFirstBookBatchResults) QueryRow(f func(int, GetAuthorWithFirstBookRow, error)) { + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var i GetAuthorWithFirstBookRow + if b.closed { + if f != nil { + f(t, i, ErrBatchAlreadyClosed) + } + continue + } + row := b.br.QueryRow() + err := row.Scan( + &i.Book.BookID, + &i.Book.AuthorID, + &i.Book.Isbn, + &i.Book.BookType, + &i.Book.Title, + &i.Book.Year, + &i.Book.Available, + &i.Book.Tags, + &i.Author.AuthorID, + &i.Author.Name, + &i.Author.Biography, + ) + if f != nil { + f(t, i, err) + } + } +} + +func (b *GetAuthorWithFirstBookBatchResults) Close() error { + b.closed = true + return b.br.Close() +} + const getBiography = `-- name: GetBiography :batchone SELECT biography FROM authors WHERE author_id = $1 diff --git a/examples/batch/postgresql/querier.go b/examples/batch/postgresql/querier.go index 8cacdbe09e..f36b72d94f 100644 --- a/examples/batch/postgresql/querier.go +++ b/examples/batch/postgresql/querier.go @@ -19,6 +19,7 @@ type Querier interface { DeleteBookNamedFunc(ctx context.Context, bookID []int32) *DeleteBookNamedFuncBatchResults DeleteBookNamedSign(ctx context.Context, bookID []int32) *DeleteBookNamedSignBatchResults GetAuthor(ctx context.Context, authorID int32) (Author, error) + GetAuthorWithFirstBook(ctx context.Context, authorID []int32) *GetAuthorWithFirstBookBatchResults GetBiography(ctx context.Context, authorID []int32) *GetBiographyBatchResults UpdateBook(ctx context.Context, arg []UpdateBookParams) *UpdateBookBatchResults } diff --git a/examples/batch/postgresql/query.sql b/examples/batch/postgresql/query.sql index 4e21f25285..3c56e046cf 100644 --- a/examples/batch/postgresql/query.sql +++ b/examples/batch/postgresql/query.sql @@ -54,3 +54,9 @@ WHERE book_id = $3; -- name: GetBiography :batchone SELECT biography FROM authors WHERE author_id = $1; + +-- name: GetAuthorWithFirstBook :batchone +SELECT sqlc.embed (books), sqlc.embed (authors) +FROM authors +INNER JOIN books ON authors.author_id = books.author_id +WHERE authors.author_id = $1; diff --git a/internal/sql/rewrite/embeds.go b/internal/sql/rewrite/embeds.go index 596c03be89..ce275505c4 100644 --- a/internal/sql/rewrite/embeds.go +++ b/internal/sql/rewrite/embeds.go @@ -2,6 +2,7 @@ package rewrite import ( "fmt" + "strings" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" @@ -9,14 +10,15 @@ import ( // Embed is an instance of `sqlc.embed(param)` type Embed struct { - Table *ast.TableName - param string - Node *ast.ColumnRef + Table *ast.TableName + param string + origin string + Node *ast.ColumnRef } // Orig string to replace func (e Embed) Orig() string { - return fmt.Sprintf("sqlc.embed(%s)", e.param) + return e.origin } // EmbedSet is a set of Embed instances @@ -60,10 +62,16 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) { }, } + // Get the origin string + const noSpaceLen = 11 + spaceCount := fun.Args.Items[0].Pos() - fun.Pos() - noSpaceLen + origin := fmt.Sprintf("sqlc.embed%s(%s)", strings.Repeat(" ", spaceCount), param) + embeds = append(embeds, &Embed{ - Table: &ast.TableName{Name: param}, - param: param, - Node: node, + Table: &ast.TableName{Name: param}, + param: param, + origin: origin, + Node: node, }) cr.Replace(node)