diff --git a/README.md b/README.md index 8295b3693f..c80b71a142 100644 --- a/README.md +++ b/README.md @@ -38,12 +38,14 @@ The client is tested against the currently [supported versions](https://github.c * Named and numeric placeholders support * LZ4/ZSTD compression support * External data +* [Query parameters](examples/std/query_parameters.go) Support for the ClickHouse protocol advanced features using `Context`: * Query ID * Quota Key * Settings +* [Query parameters](examples/clickhouse_api/query_parameters.go) * OpenTelemetry * Execution events: * Logs @@ -267,14 +269,16 @@ go get -u github.com/ClickHouse/clickhouse-go/v2 * [batch struct](examples/clickhouse_api/append_struct.go) * [columnar](examples/clickhouse_api/columnar_insert.go) * [scan struct](examples/clickhouse_api/scan_struct.go) -* [bind params](examples/clickhouse_api/bind.go) +* [query parameters](examples/clickhouse_api/query_parameters.go) (deprecated in favour of native query parameters) +* [bind params](examples/clickhouse_api/bind.go) (deprecated in favour of native query parameters) ### std `database/sql` interface * [batch](examples/std/batch.go) * [async insert](examples/std/async.go) * [open db](examples/std/connect.go) -* [bind params](examples/std/bind.go) +* [query parameters](examples/std/query_parameters.go) +* [bind params](examples/std/bind.go) (deprecated in favour of native query parameters) ## ClickHouse alternatives - ch-go diff --git a/conn.go b/conn.go index fa557634be..75eaf607cf 100644 --- a/conn.go +++ b/conn.go @@ -72,7 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er var ( connect = &connect{ - id: num, + id: num, opt: opt, conn: conn, debugf: debugf, @@ -91,6 +91,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil { return nil, err } + if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM { + if err := connect.sendAddendum(); err != nil { + return nil, err + } + } // warn only on the first connection in the pool if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) { @@ -103,7 +108,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp type connect struct { - id int + id int opt *Options conn net.Conn debugf func(format string, v ...interface{}) @@ -203,7 +208,7 @@ func (c *connect) sendData(block *proto.Block, name string) error { return err } for i := range block.Columns { - if err := block.EncodeColumn(c.buffer, i); err != nil { + if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil { return err } if len(c.buffer.Buf) >= c.maxCompressionBuffer { diff --git a/conn_exec.go b/conn_exec.go index 1a39da0d05..b7666214b8 100644 --- a/conn_exec.go +++ b/conn_exec.go @@ -19,13 +19,15 @@ package clickhouse import ( "context" + "github.com/ClickHouse/clickhouse-go/v2/lib/proto" "time" ) func (c *connect) exec(ctx context.Context, query string, args ...interface{}) error { var ( - options = queryOptions(ctx) - body, err = bind(c.server.Timezone, query, args...) + options = queryOptions(ctx) + queryParamsProtocolSupport = c.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS + body, err = bindQueryOrAppendParameters(queryParamsProtocolSupport, &options, query, c.server.Timezone, args...) ) if err != nil { return err diff --git a/conn_handshake.go b/conn_handshake.go index a476bdba09..ff3be24645 100644 --- a/conn_handshake.go +++ b/conn_handshake.go @@ -76,3 +76,11 @@ func (c *connect) handshake(database, username, password string) error { c.debugf("[handshake] <- %s", c.server) return nil } + +func (c *connect) sendAddendum() error { + if c.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY { + c.buffer.PutString("") // todo quota key support + } + + return c.flush() +} diff --git a/conn_http.go b/conn_http.go index 419691fe4c..3947a9cdcf 100644 --- a/conn_http.go +++ b/conn_http.go @@ -440,6 +440,9 @@ func (h *httpConnect) prepareRequest(ctx context.Context, reader io.Reader, opti } query.Set(key, fmt.Sprint(value)) } + for key, value := range options.parameters { + query.Set(fmt.Sprintf("param_%s", key), value) + } req.URL.RawQuery = query.Encode() } diff --git a/conn_http_exec.go b/conn_http_exec.go index 89afce8689..7778f9f337 100644 --- a/conn_http_exec.go +++ b/conn_http_exec.go @@ -25,13 +25,12 @@ import ( ) func (h *httpConnect) exec(ctx context.Context, query string, args ...interface{}) error { - query, err := bind(h.location, query, args...) + options := queryOptions(ctx) + query, err := bindQueryOrAppendParameters(true, &options, query, h.location, args...) if err != nil { return err } - options := queryOptions(ctx) - res, err := h.sendQuery(ctx, strings.NewReader(query), &options, h.headers) if res != nil { defer res.Body.Close() diff --git a/conn_http_query.go b/conn_http_query.go index 828fc58321..9d7d746aee 100644 --- a/conn_http_query.go +++ b/conn_http_query.go @@ -30,11 +30,11 @@ import ( // release is ignored, because http used by std with empty release function func (h *httpConnect) query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) { - query, err := bind(h.location, query, args...) + options := queryOptions(ctx) + query, err := bindQueryOrAppendParameters(true, &options, query, h.location, args...) if err != nil { return nil, err } - options := queryOptions(ctx) headers := make(map[string]string) switch h.compression { case CompressionZSTD, CompressionLZ4: diff --git a/conn_query.go b/conn_query.go index 9dfa9f7999..b673ca4a0d 100644 --- a/conn_query.go +++ b/conn_query.go @@ -26,9 +26,10 @@ import ( func (c *connect) query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) { var ( - options = queryOptions(ctx) - onProcess = options.onProcess() - body, err = bind(c.server.Timezone, query, args...) + options = queryOptions(ctx) + onProcess = options.onProcess() + queryParamsProtocolSupport = c.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS + body, err = bindQueryOrAppendParameters(queryParamsProtocolSupport, &options, query, c.server.Timezone, args...) ) if err != nil { diff --git a/conn_send_query.go b/conn_send_query.go index f2367b798a..df21b7ac87 100644 --- a/conn_send_query.go +++ b/conn_send_query.go @@ -34,6 +34,7 @@ func (c *connect) sendQuery(body string, o *QueryOptions) error { Compression: c.compression != CompressionNone, InitialAddress: c.conn.LocalAddr().String(), Settings: c.settings(o.settings), + Parameters: parametersToProtoParameters(o.parameters), } if err := q.Encode(c.buffer, c.revision); err != nil { return err @@ -48,3 +49,14 @@ func (c *connect) sendQuery(body string, o *QueryOptions) error { } return c.flush() } + +func parametersToProtoParameters(parameters Parameters) (s proto.Parameters) { + for k, v := range parameters { + s = append(s, proto.Parameter{ + Key: k, + Value: v, + }) + } + + return s +} diff --git a/context.go b/context.go index bb56c6e14c..11ba1a6c17 100644 --- a/context.go +++ b/context.go @@ -32,6 +32,7 @@ var _contextOptionKey = &QueryOptions{ } type Settings map[string]interface{} +type Parameters map[string]string type ( QueryOption func(*QueryOptions) error QueryOptions struct { @@ -49,6 +50,7 @@ type ( profileEvents func([]ProfileEvent) } settings Settings + parameters Parameters external []*ext.Table blockBufferSize uint8 } @@ -89,6 +91,13 @@ func WithSettings(settings Settings) QueryOption { } } +func WithParameters(params Parameters) QueryOption { + return func(o *QueryOptions) error { + o.parameters = params + return nil + } +} + func WithLogs(fn func(*Log)) QueryOption { return func(o *QueryOptions) error { o.events.logs = fn diff --git a/examples/clickhouse_api/main_test.go b/examples/clickhouse_api/main_test.go index 548b57458b..0037099e5c 100644 --- a/examples/clickhouse_api/main_test.go +++ b/examples/clickhouse_api/main_test.go @@ -173,6 +173,10 @@ func TestQueryRow(t *testing.T) { require.NoError(t, QueryRow()) } +func TestQueryWithParameters(t *testing.T) { + require.NoError(t, QueryWithParameters()) +} + func TestSelectStruct(t *testing.T) { require.NoError(t, SelectStruct()) } diff --git a/examples/clickhouse_api/query_parameters.go b/examples/clickhouse_api/query_parameters.go new file mode 100644 index 0000000000..8c0f0920a8 --- /dev/null +++ b/examples/clickhouse_api/query_parameters.go @@ -0,0 +1,54 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package clickhouse_api + +import ( + "context" + "fmt" + "github.com/ClickHouse/clickhouse-go/v2" + clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" +) + +func QueryWithParameters() error { + conn, err := GetNativeConnection(nil, nil, nil) + if err != nil { + return err + } + + if !clickhouse_tests.CheckMinServerServerVersion(conn, 22, 8, 0) { + return nil + } + + chCtx := clickhouse.Context(context.Background(), clickhouse.WithParameters(clickhouse.Parameters{ + "num": "42", + "str": "hello", + "array": "['a', 'b', 'c']", + })) + + row := conn.QueryRow(chCtx, "SELECT {num:UInt64} v, {str:String} s, {array:Array(String)} a") + var ( + col1 uint64 + col2 string + col3 []string + ) + if err := row.Scan(&col1, &col2, &col3); err != nil { + return err + } + fmt.Printf("row: col1=%d, col2=%s, col3=%s\n", col1, col2, col3) + return nil +} diff --git a/examples/std/main_test.go b/examples/std/main_test.go index e11be436ad..8e080ae9b4 100644 --- a/examples/std/main_test.go +++ b/examples/std/main_test.go @@ -99,6 +99,10 @@ func TestStdQueryRows(t *testing.T) { require.NoError(t, QueryRows()) } +func TestStdQueryWithParameters(t *testing.T) { + require.NoError(t, QueryWithParameters()) +} + func TestStdAsyncInsert(t *testing.T) { require.NoError(t, AsyncInsert()) } diff --git a/examples/std/query_parameters.go b/examples/std/query_parameters.go new file mode 100644 index 0000000000..1c0c25eacd --- /dev/null +++ b/examples/std/query_parameters.go @@ -0,0 +1,52 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package std + +import ( + "fmt" + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/ClickHouse/clickhouse-go/v2/tests/std" +) + +func QueryWithParameters() error { + conn, err := GetStdOpenDBConnection(clickhouse.Native, nil, nil, nil) + if err != nil { + return err + } + + if !std.CheckMinServerVersion(conn, 22, 8, 0) { + return nil + } + + row := conn.QueryRow( + "SELECT {num:UInt64} v, {str:String} s, {array:Array(String)} a", + clickhouse.Named("num", "42"), + clickhouse.Named("str", "hello"), + clickhouse.Named("array", "['a', 'b', 'c']"), + ) + var ( + col1 uint64 + col2 string + col3 []string + ) + if err := row.Scan(&col1, &col2, &col3); err != nil { + return err + } + fmt.Printf("row: col1=%d, col2=%s, col3=%s\n", col1, col2, col3) + return nil +} diff --git a/lib/proto/block.go b/lib/proto/block.go index 9015cd3685..57d829c463 100644 --- a/lib/proto/block.go +++ b/lib/proto/block.go @@ -139,11 +139,16 @@ func (b *Block) EncodeHeader(buffer *proto.Buffer, revision uint64) (err error) return nil } -func (b *Block) EncodeColumn(buffer *proto.Buffer, i int) (err error) { +func (b *Block) EncodeColumn(buffer *proto.Buffer, revision uint64, i int) (err error) { if i >= 0 && i < len(b.Columns) { c := b.Columns[i] buffer.PutString(c.Name()) buffer.PutString(string(c.Type())) + + if revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION { + buffer.PutBool(false) + } + if serialize, ok := c.(column.CustomSerialization); ok { if err := serialize.WriteStatePrefix(buffer); err != nil { return &BlockError{ @@ -167,7 +172,7 @@ func (b *Block) Encode(buffer *proto.Buffer, revision uint64) (err error) { return err } for i := range b.Columns { - if err := b.EncodeColumn(buffer, i); err != nil { + if err := b.EncodeColumn(buffer, revision, i); err != nil { return err } } @@ -213,6 +218,20 @@ func (b *Block) Decode(reader *proto.Reader, revision uint64) (err error) { if err != nil { return err } + + if revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION { + hasCustom, err := reader.Bool() + if err != nil { + return err + } + if hasCustom { + return &BlockError{ + Op: "Decode", + Err: errors.New(fmt.Sprintf("custom serialization for column %s. not supported", columnName)), + } + } + } + if numRows != 0 { if serialize, ok := c.(column.CustomSerialization); ok { if err := serialize.ReadStatePrefix(reader); err != nil { diff --git a/lib/proto/const.go b/lib/proto/const.go index 6eac3cdf67..1e7321b981 100644 --- a/lib/proto/const.go +++ b/lib/proto/const.go @@ -32,7 +32,11 @@ const ( DBMS_MIN_PROTOCOL_VERSION_WITH_INITIAL_QUERY_START_TIME = 54449 DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS = 54451 DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS = 54453 - DBMS_TCP_PROTOCOL_VERSION = DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS + DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION = 54454 + DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM = 54458 + DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY = 54458 + DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS = 54459 + DBMS_TCP_PROTOCOL_VERSION = DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS ) const ( diff --git a/lib/proto/query.go b/lib/proto/query.go index d70dbd4583..a3d5980b2e 100644 --- a/lib/proto/query.go +++ b/lib/proto/query.go @@ -23,6 +23,7 @@ import ( chproto "github.com/ClickHouse/ch-go/proto" "go.opentelemetry.io/otel/trace" "os" + "strings" ) var ( @@ -36,6 +37,7 @@ type Query struct { Body string QuotaKey string Settings Settings + Parameters Parameters Compression bool InitialUser string InitialAddress string @@ -61,6 +63,14 @@ func (q *Query) Encode(buffer *chproto.Buffer, revision uint64) error { buffer.PutBool(q.Compression) } buffer.PutString(q.Body) + + if revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS { + if err := q.Parameters.Encode(buffer, revision); err != nil { + return err + } + buffer.PutString("") /* empty string is a marker of the end of parameters */ + } + return nil } @@ -163,3 +173,27 @@ func (s *Setting) encode(buffer *chproto.Buffer, revision uint64) error { buffer.PutString(fmt.Sprint(s.Value)) return nil } + +type Parameters []Parameter + +type Parameter struct { + Key string + Value string +} + +func (s Parameters) Encode(buffer *chproto.Buffer, revision uint64) error { + for _, s := range s { + if err := s.encode(buffer, revision); err != nil { + return err + } + } + return nil +} + +func (s *Parameter) encode(buffer *chproto.Buffer, revision uint64) error { + buffer.PutString(s.Key) + buffer.PutUVarInt(uint64(0x02)) + buffer.PutString(fmt.Sprintf("'%v'", strings.ReplaceAll(s.Value, "'", "\\'"))) + + return nil +} diff --git a/query_parameters.go b/query_parameters.go new file mode 100644 index 0000000000..cba95848fa --- /dev/null +++ b/query_parameters.go @@ -0,0 +1,60 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package clickhouse + +import ( + "github.com/ClickHouse/clickhouse-go/v2/lib/driver" + "github.com/pkg/errors" + "regexp" + "time" +) + +var ( + ErrExpectedStringValueInNamedValueForQueryParameter = errors.New("expected string value in NamedValue for query parameter") + + hasQueryParamsRe = regexp.MustCompile("{.+:.+}") +) + +func bindQueryOrAppendParameters(paramsProtocolSupport bool, options *QueryOptions, query string, timezone *time.Location, args ...interface{}) (string, error) { + // prefer native query parameters over legacy bind if query parameters provided explicit + if len(options.parameters) > 0 { + return query, nil + } + + // validate if query contains a {:} syntax, so it's intentional use of query parameters + // parameter values will be loaded from `args ...interface{}` for compatibility + if paramsProtocolSupport && + len(args) > 0 && + hasQueryParamsRe.MatchString(query) { + options.parameters = make(Parameters, len(args)) + for _, a := range args { + if p, ok := a.(driver.NamedValue); ok { + if str, ok := p.Value.(string); ok { + options.parameters[p.Name] = str + continue + } + } + + return "", ErrExpectedStringValueInNamedValueForQueryParameter + } + + return query, nil + } + + return bind(timezone, query, args...) +} diff --git a/tests/query_parameters_test.go b/tests/query_parameters_test.go new file mode 100644 index 0000000000..f3f62e9837 --- /dev/null +++ b/tests/query_parameters_test.go @@ -0,0 +1,113 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tests + +import ( + "context" + "fmt" + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestQueryParameters(t *testing.T) { + ctx := context.Background() + + env, err := GetTestEnvironment(testSet) + require.NoError(t, err) + client, err := testClientWithDefaultSettings(env) + require.NoError(t, err) + defer client.Close() + + if !CheckMinServerServerVersion(client, 22, 8, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + + t.Run("with context parameters", func(t *testing.T) { + chCtx := clickhouse.Context(ctx, clickhouse.WithParameters(clickhouse.Parameters{ + "num": "42", + "str": "hello", + "array": "['a', 'b', 'c']", + })) + + var actualNum uint64 + var actualStr string + var actualArray []string + row := client.QueryRow(chCtx, "SELECT {num:UInt64} v, {str:String} s, {array:Array(String)} a") + require.NoError(t, row.Err()) + require.NoError(t, row.Scan(&actualNum, &actualStr, &actualArray)) + + assert.Equal(t, uint64(42), actualNum) + assert.Equal(t, "hello", actualStr) + assert.Equal(t, []string{"a", "b", "c"}, actualArray) + }) + + t.Run("with named arguments", func(t *testing.T) { + var actualNum uint64 + var actualStr string + row := client.QueryRow( + ctx, + "SELECT {num:UInt64}, {str:String}", + clickhouse.Named("num", "42"), + clickhouse.Named("str", "hello"), + ) + require.NoError(t, row.Err()) + require.NoError(t, row.Scan(&actualNum, &actualStr)) + + assert.Equal(t, uint64(42), actualNum) + assert.Equal(t, "hello", actualStr) + }) + + t.Run("named args with only strings supported", func(t *testing.T) { + row := client.QueryRow( + ctx, + "SELECT {num:UInt64}, {str:String}", + clickhouse.Named("num", 42), + clickhouse.Named("str", "hello"), + ) + require.ErrorIs(t, row.Err(), clickhouse.ErrExpectedStringValueInNamedValueForQueryParameter) + }) + + t.Run("unsupported arg type", func(t *testing.T) { + row := client.QueryRow( + ctx, + "SELECT {num:UInt64}, {str:String}", + 1234, + "String", + ) + require.ErrorIs(t, row.Err(), clickhouse.ErrExpectedStringValueInNamedValueForQueryParameter) + }) + + t.Run("with bind backwards compatibility", func(t *testing.T) { + var actualNum uint8 + var actualStr string + row := client.QueryRow( + ctx, + "SELECT @num, @str", + clickhouse.Named("num", 42), + clickhouse.Named("str", "hello"), + ) + require.NoError(t, row.Err()) + require.NoError(t, row.Scan(&actualNum, &actualStr)) + + assert.Equal(t, uint8(42), actualNum) + assert.Equal(t, "hello", actualStr) + }) +} diff --git a/tests/std/query_parameters_test.go b/tests/std/query_parameters_test.go new file mode 100644 index 0000000000..3a582da1e2 --- /dev/null +++ b/tests/std/query_parameters_test.go @@ -0,0 +1,96 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package std + +import ( + "fmt" + "github.com/ClickHouse/clickhouse-go/v2" + clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "strconv" + "testing" +) + +func TestQueryParameters(t *testing.T) { + env, err := GetStdTestEnvironment() + require.NoError(t, err) + require.NoError(t, err) + useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false")) + require.NoError(t, err) + connectionString := fmt.Sprintf("http://%s:%d?username=%s&password=%s&dial_timeout=200ms&max_execution_time=60", env.Host, env.HttpPort, env.Username, env.Password) + if useSSL { + connectionString = fmt.Sprintf("https://%s:%d?username=%s&password=%s&dial_timeout=200ms&max_execution_time=60&secure=true", env.Host, env.HttpsPort, env.Username, env.Password) + } + dsns := map[string]string{"Http": connectionString} + + for name, dsn := range dsns { + t.Run(fmt.Sprintf("%s Protocol", name), func(t *testing.T) { + conn, err := GetConnectionFromDSN(dsn) + require.NoError(t, err) + + t.Run("with named arguments", func(t *testing.T) { + var actualNum uint64 + var actualStr string + row := conn.QueryRow( + "SELECT {num:UInt64}, {str:String}", + clickhouse.Named("num", "42"), + clickhouse.Named("str", "hello"), + ) + require.NoError(t, row.Err()) + require.NoError(t, row.Scan(&actualNum, &actualStr)) + + assert.Equal(t, uint64(42), actualNum) + assert.Equal(t, "hello", actualStr) + }) + + t.Run("named args with only strings supported", func(t *testing.T) { + row := conn.QueryRow( + "SELECT {num:UInt64}, {str:String}", + clickhouse.Named("num", 42), + clickhouse.Named("str", "hello"), + ) + require.ErrorIs(t, row.Err(), clickhouse.ErrExpectedStringValueInNamedValueForQueryParameter) + }) + + t.Run("unsupported arg type", func(t *testing.T) { + row := conn.QueryRow( + "SELECT {num:UInt64}, {str:String}", + 1234, + "String", + ) + require.ErrorIs(t, row.Err(), clickhouse.ErrExpectedStringValueInNamedValueForQueryParameter) + }) + + t.Run("with bind backwards compatibility", func(t *testing.T) { + var actualNum uint8 + var actualStr string + row := conn.QueryRow( + "SELECT @num, @str", + clickhouse.Named("num", 42), + clickhouse.Named("str", "hello"), + ) + require.NoError(t, row.Err()) + require.NoError(t, row.Scan(&actualNum, &actualStr)) + + assert.Equal(t, uint8(42), actualNum) + assert.Equal(t, "hello", actualStr) + }) + }) + } +}