Skip to content

Commit

Permalink
Implement ConnBeginTx as replacement for deprecated Begin (#1255)
Browse files Browse the repository at this point in the history
* implement missing method

* sprinkle assertions of implementations

* add note for viewers at home
  • Loading branch information
FelipeLema authored Mar 27, 2024
1 parent fe03b98 commit b3f481c
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions clickhouse_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error)
return nil, err
}

var _ driver.Connector = (*stdConnOpener)(nil)

func init() {
var debugf = func(format string, v ...any) {}
sql.Register("clickhouse", &stdDriver{debugf: debugf})
Expand Down Expand Up @@ -196,6 +198,12 @@ type stdDriver struct {
debugf func(format string, v ...any)
}

var _ driver.Conn = (*stdDriver)(nil)
var _ driver.ConnBeginTx = (*stdDriver)(nil)
var _ driver.ExecerContext = (*stdDriver)(nil)
var _ driver.QueryerContext = (*stdDriver)(nil)
var _ driver.ConnPrepareContext = (*stdDriver)(nil)

func (std *stdDriver) Open(dsn string) (_ driver.Conn, err error) {
var opt Options
if err := opt.fromDSN(dsn); err != nil {
Expand All @@ -211,6 +219,8 @@ func (std *stdDriver) Open(dsn string) (_ driver.Conn, err error) {
return (&stdConnOpener{opt: o, debugf: debugf}).Connect(context.Background())
}

var _ driver.Driver = (*stdDriver)(nil)

func (std *stdDriver) ResetSession(ctx context.Context) error {
if std.conn.isBad() {
std.debugf("Resetting session because connection is bad")
Expand All @@ -219,9 +229,16 @@ func (std *stdDriver) ResetSession(ctx context.Context) error {
return nil
}

var _ driver.SessionResetter = (*stdDriver)(nil)

func (std *stdDriver) Ping(ctx context.Context) error { return std.conn.ping(ctx) }

var _ driver.Pinger = (*stdDriver)(nil)

func (std *stdDriver) Begin() (driver.Tx, error) { return std, nil }
func (std *stdDriver) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return std, nil
}

func (std *stdDriver) Commit() error {
if std.commit == nil {
Expand All @@ -248,8 +265,12 @@ func (std *stdDriver) Rollback() error {
return nil
}

var _ driver.Tx = (*stdDriver)(nil)

func (std *stdDriver) CheckNamedValue(nv *driver.NamedValue) error { return nil }

var _ driver.NamedValueChecker = (*stdDriver)(nil)

func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if options := queryOptions(ctx); options.async.ok {
return driver.RowsAffected(0), std.conn.asyncInsert(ctx, query, options.async.wait, rebind(args)...)
Expand Down Expand Up @@ -340,7 +361,10 @@ func (s *stdBatch) ExecContext(ctx context.Context, args []driver.NamedValue) (d
return s.Exec(values)
}

var _ driver.StmtExecContext = (*stdBatch)(nil)

func (s *stdBatch) Query(args []driver.Value) (driver.Rows, error) {
// Note: not implementing driver.StmtQueryContext accordingly
return nil, errors.New("only Exec method supported in batch mode")
}

Expand All @@ -359,6 +383,8 @@ func (r *stdRows) ColumnTypeScanType(idx int) reflect.Type {
return r.rows.block.Columns[idx].ScanType()
}

var _ driver.RowsColumnTypeScanType = (*stdRows)(nil)

func (r *stdRows) ColumnTypeDatabaseTypeName(idx int) string {
return string(r.rows.block.Columns[idx].Type())
}
Expand All @@ -381,6 +407,12 @@ func (r *stdRows) ColumnTypePrecisionScale(idx int) (precision, scale int64, ok
return 0, 0, false
}

var _ driver.Rows = (*stdRows)(nil)
var _ driver.RowsNextResultSet = (*stdRows)(nil)
var _ driver.RowsColumnTypeDatabaseTypeName = (*stdRows)(nil)
var _ driver.RowsColumnTypeNullable = (*stdRows)(nil)
var _ driver.RowsColumnTypePrecisionScale = (*stdRows)(nil)

func (r *stdRows) Next(dest []driver.Value) error {
if len(r.rows.block.Columns) != len(dest) {
err := fmt.Errorf("expected %d destination arguments in Next, not %d", len(r.rows.block.Columns), len(dest))
Expand Down Expand Up @@ -429,10 +461,14 @@ func (r *stdRows) NextResultSet() error {
return nil
}

var _ driver.RowsNextResultSet = (*stdRows)(nil)

func (r *stdRows) Close() error {
err := r.rows.Close()
if err != nil {
r.debugf("Rows Close error: %v\n", err)
}
return err
}

var _ driver.Rows = (*stdRows)(nil)

0 comments on commit b3f481c

Please sign in to comment.