Skip to content

Commit

Permalink
Merge pull request #20 from muroon/prepare
Browse files Browse the repository at this point in the history
Prepared Statement
  • Loading branch information
muroon authored Oct 15, 2021
2 parents ddd2f4c + 3bf3031 commit 2ef1be7
Show file tree
Hide file tree
Showing 9 changed files with 552 additions and 5 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ Note
- Detailed explanation is described [here](doc/result_mode.md).
- [Usages of Result Mode](doc/result_mode.md#usages).

## Prepared Statements

You can use [Athena Prepared Statements](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html).
Click [here](doc/prepared_statements.md) for details on how to use.

## Testing

Athena doesn't have a local version and revolves around S3 so our tests are
Expand Down
70 changes: 68 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
// mode ctas
var ctasTable string
var afterDownload func() error
if isSelect && resultMode == ResultModeGzipDL {
if isCreatingCTASTable(isSelect, resultMode) {
// Create AS Select
ctasTable = fmt.Sprintf("tmp_ctas_%v", strings.Replace(uuid.NewV4().String(), "-", "", -1))
query = fmt.Sprintf("CREATE TABLE %s WITH (format='TEXTFILE') AS %s", ctasTable, query)
Expand Down Expand Up @@ -183,7 +183,69 @@ func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
panic("Athena doesn't support prepared statements")
return c.prepareContext(context.Background(), query)
}

func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
stmt, err := c.prepareContext(ctx, query)

select {
default:
case <-ctx.Done():
stmt.Close()
return nil, ctx.Err()
}

return stmt, err
}

func (c *conn) prepareContext(ctx context.Context, query string) (driver.Stmt, error) {
// resultMode
isSelect := isSelectQuery(query)
resultMode := c.resultMode
if rmode, ok := getResultMode(ctx); ok {
resultMode = rmode
}
if !isSelect {
resultMode = ResultModeAPI
}

// ctas
var ctasTable string
var afterDownload func() error
if isCreatingCTASTable(isSelect, resultMode) {
// Create AS Select
ctasTable = fmt.Sprintf("tmp_ctas_%v", strings.Replace(uuid.NewV4().String(), "-", "", -1))
query = fmt.Sprintf("CREATE TABLE %s WITH (format='TEXTFILE') AS %s", ctasTable, query)
afterDownload = c.dropCTASTable(ctx, ctasTable)
}

numInput := len(strings.Split(query, "?")) - 1

// prepare
prepareKey := fmt.Sprintf("tmp_prepare_%v", strings.Replace(uuid.NewV4().String(), "-", "", -1))
newQuery := fmt.Sprintf("PREPARE %s FROM %s", prepareKey, query)

queryID, err := c.startQuery(newQuery)
if err != nil {
return nil, err
}

if err := c.waitOnQuery(ctx, queryID); err != nil {
return nil, err
}

return &stmtAthena{
prepareKey: prepareKey,
numInput: numInput,
ctasTable: ctasTable,
afterDownload: afterDownload,
conn: c,
resultMode: resultMode,
}, nil
}

func (c *conn) Begin() (driver.Tx, error) {
Expand Down Expand Up @@ -226,3 +288,7 @@ func isSelectQuery(query string) bool {
func isCTASQuery(query string) bool {
return regexp.MustCompile(`(?i)^CREATE.+AS\s+SELECT`).Match([]byte(query))
}

func isCreatingCTASTable(isSelect bool, resultMode ResultMode) bool {
return isSelect && resultMode == ResultModeGzipDL
}
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const catalogContextKey string = "catalog_key"
var CatalogContextKey string = contextPrefix + catalogContextKey

// SetCatalog set catalog from context
func SetTimout(ctx context.Context, catalog string) context.Context {
func SetCatalog(ctx context.Context, catalog string) context.Context {
return context.WithValue(ctx, CatalogContextKey, catalog)
}

Expand Down
72 changes: 72 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package athena

import (
"context"
"testing"
)

func Test_getCatalog(t *testing.T) {
tests := []struct {
name string
ctx context.Context
want string
want1 bool
}{
{
name: "Default",
ctx: context.Background(),
want: "",
want1: false,
},
{
name: "SetCatalog",
ctx: SetCatalog(context.Background(), "test_catalog"),
want: "test_catalog",
want1: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := getCatalog(tt.ctx)
if got != tt.want {
t.Errorf("getCatalog() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("getCatalog() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

func Test_getTimeout(t *testing.T) {
tests := []struct {
name string
ctx context.Context
want uint
want1 bool
}{
{
name: "Default",
ctx: context.Background(),
want: 0,
want1: false,
},
{
name: "SetTimeout",
ctx: SetTimeout(context.Background(), 100),
want: 100,
want1: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := getTimeout(tt.ctx)
if got != tt.want {
t.Errorf("getTimeout() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("getTimeout() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
144 changes: 144 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,146 @@ func TestQuery(t *testing.T) {
require.Equal(t, 3, index+1, fmt.Sprintf("row count. resultMode:%v", resultMode))
}
}

func TestPrepare(t *testing.T) {
harness := setup(t, false)
defer harness.teardown()

data := []dummyRow{
{
SmallintType: 1,
IntType: 2,
BigintType: 3,
BooleanType: true,
FloatType: 3.1415928,
DoubleType: 3141592653589.793,
StringType: "some string",
TimestampType: athenaTimestamp(time.Date(2006, 1, 2, 3, 4, 11, 0, time.UTC)),
DateType: athenaDate(time.Date(2006, 1, 2, 0, 0, 0, 0, time.UTC)),
DecimalType: 1001,
},
{
SmallintType: 9,
IntType: 8,
BigintType: 0,
BooleanType: false,
FloatType: 3.1415930,
DoubleType: 3141592653589.79,
StringType: "another string",
TimestampType: athenaTimestamp(time.Date(2017, 12, 3, 1, 11, 12, 0, time.UTC)),
DateType: athenaDate(time.Date(2017, 12, 3, 0, 0, 0, 0, time.UTC)),
DecimalType: 0,
},
{
SmallintType: 9,
IntType: 8,
BigintType: 0,
BooleanType: false,
FloatType: 3.14159,
DoubleType: 3141592653589.8,
StringType: "123.456",
TimestampType: athenaTimestamp(time.Date(2017, 12, 3, 20, 11, 12, 0, time.UTC)),
DateType: athenaDate(time.Date(2017, 12, 3, 0, 0, 0, 0, time.UTC)),
DecimalType: 0.48,
},
}
harness.uploadData(data)

resultModes := []ResultMode{
ResultModeAPI,
ResultModeDL,
ResultModeGzipDL,
}

tests := []struct {
name string
sql string
params []interface{}
startFunc func(ctx context.Context) context.Context
endFunc func(ctx context.Context) context.Context
want dummyRow
}{
{
name: "NoInput",
sql: fmt.Sprintf("select * from %s order by intType limit 1", harness.table),
params: []interface{}{},
want: data[0],
},
{
name: "IntType",
sql: fmt.Sprintf("select * from %s where intType = ?", harness.table),
params: []interface{}{data[0].IntType},
want: data[0],
},
{
name: "StringType",
sql: fmt.Sprintf("select * from %s where stringType = ?", harness.table),
params: []interface{}{data[0].StringType},
want: data[0],
},
{
name: "FloatType",
sql: fmt.Sprintf("select * from %s where floattype = ?", harness.table),
params: []interface{}{data[0].FloatType},
want: data[0],
},
{
name: "DoubleType",
sql: fmt.Sprintf("select * from %s where doubletype = ?", harness.table),
params: []interface{}{data[0].DoubleType},
want: data[0],
},
}

for _, resultMode := range resultModes {
ctx := context.Background()
switch resultMode {
case ResultModeAPI:
ctx = SetAPIMode(ctx)
case ResultModeDL:
ctx = SetDLMode(ctx)
case ResultModeGzipDL:
ctx = SetGzipDLMode(ctx)
}

for _, test := range tests {
t.Run(fmt.Sprintf("ResultMode:%v/%s", resultMode, test.name), func(t *testing.T) {
if startFunc := test.startFunc; startFunc != nil {
ctx = startFunc(ctx)
}
if endFunc := test.startFunc; endFunc != nil {
defer func() {
ctx = endFunc(ctx)
}()
}

stmt, err := harness.prepare(ctx, test.sql)
defer func() {
err := stmt.Close()
require.NoError(t, err)
}()
require.NoError(t, err)

rows, err := stmt.QueryContext(ctx, test.params...)
defer rows.Close()
require.NoError(t, err)

var length int
for rows.Next() {
length++
var got dummyRow
err := rows.Scan(
&got.NullValue, &got.SmallintType, &got.IntType, &got.BigintType, &got.BooleanType, &got.FloatType, &got.DoubleType, &got.StringType, &got.TimestampType, &got.DateType, &got.DecimalType,
)
require.NoError(t, err)
assert.Equal(t, test.want, got, fmt.Sprintf("resultMode:%v, prepareIntType error", resultMode))
}
assert.Equal(t, 1, length)
})
}
}
}

func TestQueryForUsingWorkGroup(t *testing.T) {
resultModes := []ResultMode{
ResultModeAPI,
Expand Down Expand Up @@ -333,6 +473,10 @@ func (a *athenaHarness) mustQuery(ctx context.Context, sql string, args ...inter
return rows
}

func (a *athenaHarness) prepare(ctx context.Context, sql string) (*sql.Stmt, error) {
return a.db.PrepareContext(ctx, sql)
}

func (a *athenaHarness) uploadData(rows []dummyRow) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
Expand Down
53 changes: 53 additions & 0 deletions doc/prepared_statements.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Prepared Statements

You can use [Prepared Statements on Athena](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html).

## How to use

```
db, _ := sql.Open("athena", "db=default&output_location=s3://results")
// 1. Prepare
stmt, _ := db.PrepareContext(ctx, "SELECT url, code FROM cloudfront WHERE code = ?")
defer stmt.Close() // 3. Close
// 2. Execute
rows, _ := stmt.QueryContext(ctx, targetCode)
defer rows.Close()
for rows.Next() {
var url string
var code int
rows.Scan(&url, &code)
}
```

### 1. Prepare
- Run [PREPARE](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html#querying-with-prepared-statements-sql-statements) on Athena to create a Statement for use with Athena
- Create a Stmt object and keep statement_name inside
- The stmt object is valid until Close method is executed
- Result Mode
- Available under all Result Modes
- ResultMode can be specified in PrepareContext
```
rows, _ := stmt.PrepareContext(SetDLMode(ctx), sql) // Prepares a statement in DL Mode
```
### 2. Execute
- Run a prepared statement using [EXECUTE](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html#querying-with-prepared-statements-sql-statements) on Athena
- You can specify parameters
- Use QueryContext and ExecContext methods
### 3. Close
- Run [DEALLOCATE PREPARE](https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html#querying-with-prepared-statements-sql-statements) and delete the prepared statement
- Use the Close method of the Stmt object
## Examples
```
intParam := 1
stringParam := "string value"
stmt, _ := db.PrepareContext(ctx, "SELECT * FROM test_table WHERE int_column = ? AND string_column = ?")
rows, _ := stmt.QueryContext(ctx, intParam, stringParam)
```
execute `SELECT * FROM test_table WHERE int_column = 1 and string_column = 'string value'`
Loading

0 comments on commit 2ef1be7

Please sign in to comment.