From 81825b1d9b85d98d1bc9814dd2e6fc5c20a94c88 Mon Sep 17 00:00:00 2001 From: YangKeao Date: Thu, 9 Nov 2023 18:42:42 +0800 Subject: [PATCH] This is an automated cherry-pick of #48237 Signed-off-by: ti-chi-bot --- pkg/param/BUILD.bazel | 16 + pkg/param/binary_params.go | 275 +++++++++++++++++ pkg/server/BUILD.bazel | 197 ++++++++++++ pkg/server/internal/parse/BUILD.bazel | 32 ++ pkg/server/internal/parse/parse.go | 173 +++++++++++ pkg/server/internal/parse/parse_test.go | 45 +++ pkg/session/BUILD.bazel | 169 +++++++++++ server/conn_stmt.go | 33 +- server/conn_stmt_params.go | 130 ++++++++ server/conn_stmt_params_test.go | 380 ++++++++++++++++++++++++ server/extension.go | 26 +- session/session.go | 81 +++++ 12 files changed, 1551 insertions(+), 6 deletions(-) create mode 100644 pkg/param/BUILD.bazel create mode 100644 pkg/param/binary_params.go create mode 100644 pkg/server/BUILD.bazel create mode 100644 pkg/server/internal/parse/BUILD.bazel create mode 100644 pkg/server/internal/parse/parse.go create mode 100644 pkg/server/internal/parse/parse_test.go create mode 100644 pkg/session/BUILD.bazel create mode 100644 server/conn_stmt_params.go create mode 100644 server/conn_stmt_params_test.go diff --git a/pkg/param/BUILD.bazel b/pkg/param/BUILD.bazel new file mode 100644 index 0000000000000..1b2204b5e7f45 --- /dev/null +++ b/pkg/param/BUILD.bazel @@ -0,0 +1,16 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "param", + srcs = ["binary_params.go"], + importpath = "github.com/pingcap/tidb/pkg/param", + visibility = ["//visibility:public"], + deps = [ + "//pkg/errno", + "//pkg/expression", + "//pkg/parser/mysql", + "//pkg/types", + "//pkg/util/dbterror", + "//pkg/util/hack", + ], +) diff --git a/pkg/param/binary_params.go b/pkg/param/binary_params.go new file mode 100644 index 0000000000000..ce016f5016918 --- /dev/null +++ b/pkg/param/binary_params.go @@ -0,0 +1,275 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package param + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/hack" +) + +var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType) + +// BinaryParam stores the information decoded from the binary protocol +// It can be further parsed into `expression.Expression` through the `ExecArgs` function in this package +type BinaryParam struct { + Tp byte + IsUnsigned bool + IsNull bool + Val []byte +} + +// ExecArgs parse execute arguments to datum slice. +func ExecArgs(typectx types.Context, binaryParams []BinaryParam) (params []expression.Expression, err error) { + var ( + tmp interface{} + ) + + params = make([]expression.Expression, len(binaryParams)) + args := make([]types.Datum, len(binaryParams)) + for i := 0; i < len(args); i++ { + tp := binaryParams[i].Tp + isUnsigned := binaryParams[i].IsUnsigned + + switch tp { + case mysql.TypeNull: + var nilDatum types.Datum + nilDatum.SetNull() + args[i] = nilDatum + continue + + case mysql.TypeTiny: + if isUnsigned { + args[i] = types.NewUintDatum(uint64(binaryParams[i].Val[0])) + } else { + args[i] = types.NewIntDatum(int64(int8(binaryParams[i].Val[0]))) + } + continue + + case mysql.TypeShort, mysql.TypeYear: + valU16 := binary.LittleEndian.Uint16(binaryParams[i].Val) + if isUnsigned { + args[i] = types.NewUintDatum(uint64(valU16)) + } else { + args[i] = types.NewIntDatum(int64(int16(valU16))) + } + continue + + case mysql.TypeInt24, mysql.TypeLong: + valU32 := binary.LittleEndian.Uint32(binaryParams[i].Val) + if isUnsigned { + args[i] = types.NewUintDatum(uint64(valU32)) + } else { + args[i] = types.NewIntDatum(int64(int32(valU32))) + } + continue + + case mysql.TypeLonglong: + valU64 := binary.LittleEndian.Uint64(binaryParams[i].Val) + if isUnsigned { + args[i] = types.NewUintDatum(valU64) + } else { + args[i] = types.NewIntDatum(int64(valU64)) + } + continue + + case mysql.TypeFloat: + args[i] = types.NewFloat32Datum(math.Float32frombits(binary.LittleEndian.Uint32(binaryParams[i].Val))) + continue + + case mysql.TypeDouble: + args[i] = types.NewFloat64Datum(math.Float64frombits(binary.LittleEndian.Uint64(binaryParams[i].Val))) + continue + + case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime: + switch len(binaryParams[i].Val) { + case 0: + tmp = types.ZeroDatetimeStr + case 4: + _, tmp = binaryDate(0, binaryParams[i].Val) + case 7: + _, tmp = binaryDateTime(0, binaryParams[i].Val) + case 11: + _, tmp = binaryTimestamp(0, binaryParams[i].Val) + case 13: + _, tmp = binaryTimestampWithTZ(0, binaryParams[i].Val) + default: + err = mysql.ErrMalformPacket + return + } + // TODO: generate the time datum directly + var parseTime func(types.Context, string) (types.Time, error) + switch tp { + case mysql.TypeDate: + parseTime = types.ParseDate + case mysql.TypeDatetime: + parseTime = types.ParseDatetime + case mysql.TypeTimestamp: + // To be compatible with MySQL, even the type of parameter is + // TypeTimestamp, the return type should also be `Datetime`. + parseTime = types.ParseDatetime + } + var time types.Time + time, err = parseTime(typectx, tmp.(string)) + err = typectx.HandleTruncate(err) + if err != nil { + return + } + args[i] = types.NewDatum(time) + continue + + case mysql.TypeDuration: + switch len(binaryParams[i].Val) { + case 0: + tmp = "0" + case 8: + isNegative := binaryParams[i].Val[0] + if isNegative > 1 { + err = mysql.ErrMalformPacket + return + } + _, tmp = binaryDuration(1, binaryParams[i].Val, isNegative) + case 12: + isNegative := binaryParams[i].Val[0] + if isNegative > 1 { + err = mysql.ErrMalformPacket + return + } + _, tmp = binaryDurationWithMS(1, binaryParams[i].Val, isNegative) + default: + err = mysql.ErrMalformPacket + return + } + // TODO: generate the duration datum directly + var dur types.Duration + dur, _, err = types.ParseDuration(typectx, tmp.(string), types.MaxFsp) + err = typectx.HandleTruncate(err) + if err != nil { + return + } + args[i] = types.NewDatum(dur) + continue + case mysql.TypeNewDecimal: + if binaryParams[i].IsNull { + args[i] = types.NewDecimalDatum(nil) + } else { + var dec types.MyDecimal + err = typectx.HandleTruncate(dec.FromString(binaryParams[i].Val)) + if err != nil { + return nil, err + } + args[i] = types.NewDecimalDatum(&dec) + } + continue + case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + if binaryParams[i].IsNull { + args[i] = types.NewBytesDatum(nil) + } else { + args[i] = types.NewBytesDatum(binaryParams[i].Val) + } + continue + case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, + mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit: + if !binaryParams[i].IsNull { + tmp = string(hack.String(binaryParams[i].Val)) + } else { + tmp = nil + } + args[i] = types.NewDatum(tmp) + continue + default: + err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp) + return + } + } + + for i := range params { + ft := new(types.FieldType) + types.InferParamTypeFromUnderlyingValue(args[i].GetValue(), ft) + params[i] = &expression.Constant{Value: args[i], RetType: ft} + } + return +} + +func binaryDate(pos int, paramValues []byte) (int, string) { + year := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) + pos += 2 + month := paramValues[pos] + pos++ + day := paramValues[pos] + pos++ + return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day) +} + +func binaryDateTime(pos int, paramValues []byte) (int, string) { + pos, date := binaryDate(pos, paramValues) + hour := paramValues[pos] + pos++ + minute := paramValues[pos] + pos++ + second := paramValues[pos] + pos++ + return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second) +} + +func binaryTimestamp(pos int, paramValues []byte) (int, string) { + pos, dateTime := binaryDateTime(pos, paramValues) + microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) + pos += 4 + return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond) +} + +func binaryTimestampWithTZ(pos int, paramValues []byte) (int, string) { + pos, timestamp := binaryTimestamp(pos, paramValues) + tzShiftInMin := int16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) + tzShiftHour := tzShiftInMin / 60 + tzShiftAbsMin := tzShiftInMin % 60 + if tzShiftAbsMin < 0 { + tzShiftAbsMin = -tzShiftAbsMin + } + pos += 2 + return pos, fmt.Sprintf("%s%+02d:%02d", timestamp, tzShiftHour, tzShiftAbsMin) +} + +func binaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) { + sign := "" + if isNegative == 1 { + sign = "-" + } + days := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) + pos += 4 + hours := paramValues[pos] + pos++ + minutes := paramValues[pos] + pos++ + seconds := paramValues[pos] + pos++ + return pos, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, hours, minutes, seconds) +} + +func binaryDurationWithMS(pos int, paramValues []byte, + isNegative uint8) (int, string) { + pos, dur := binaryDuration(pos, paramValues, isNegative) + microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) + pos += 4 + return pos, fmt.Sprintf("%s.%06d", dur, microSecond) +} diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel new file mode 100644 index 0000000000000..fbe3f826b0807 --- /dev/null +++ b/pkg/server/BUILD.bazel @@ -0,0 +1,197 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "server", + srcs = [ + "conn.go", + "conn_stmt.go", + "conn_stmt_params.go", + "driver.go", + "driver_tidb.go", + "extension.go", + "extract.go", + "http_handler.go", + "http_status.go", + "mock_conn.go", + "rpc_server.go", + "server.go", + "stat.go", + "tokenlimiter.go", + ], + importpath = "github.com/pingcap/tidb/pkg/server", + visibility = ["//visibility:public"], + deps = [ + "//pkg/autoid_service", + "//pkg/config", + "//pkg/domain", + "//pkg/domain/infosync", + "//pkg/errno", + "//pkg/executor", + "//pkg/executor/mppcoordmanager", + "//pkg/expression", + "//pkg/extension", + "//pkg/infoschema", + "//pkg/kv", + "//pkg/metrics", + "//pkg/param", + "//pkg/parser", + "//pkg/parser/ast", + "//pkg/parser/auth", + "//pkg/parser/charset", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/parser/terror", + "//pkg/planner/core", + "//pkg/plugin", + "//pkg/privilege", + "//pkg/privilege/conn", + "//pkg/privilege/privileges", + "//pkg/privilege/privileges/ldap", + "//pkg/server/err", + "//pkg/server/handler", + "//pkg/server/handler/extactorhandler", + "//pkg/server/handler/optimizor", + "//pkg/server/handler/tikvhandler", + "//pkg/server/handler/ttlhandler", + "//pkg/server/internal", + "//pkg/server/internal/column", + "//pkg/server/internal/dump", + "//pkg/server/internal/handshake", + "//pkg/server/internal/parse", + "//pkg/server/internal/resultset", + "//pkg/server/internal/util", + "//pkg/server/metrics", + "//pkg/session", + "//pkg/session/txninfo", + "//pkg/sessionctx", + "//pkg/sessionctx/sessionstates", + "//pkg/sessionctx/stmtctx", + "//pkg/sessionctx/variable", + "//pkg/sessiontxn", + "//pkg/statistics/handle", + "//pkg/store", + "//pkg/store/driver/error", + "//pkg/store/helper", + "//pkg/tablecodec", + "//pkg/types", + "//pkg/util", + "//pkg/util/arena", + "//pkg/util/chunk", + "//pkg/util/cpuprofile", + "//pkg/util/dbterror", + "//pkg/util/dbterror/exeerrors", + "//pkg/util/execdetails", + "//pkg/util/fastrand", + "//pkg/util/hack", + "//pkg/util/intest", + "//pkg/util/logutil", + "//pkg/util/memory", + "//pkg/util/printer", + "//pkg/util/sqlexec", + "//pkg/util/sqlkiller", + "//pkg/util/sys/linux", + "//pkg/util/timeutil", + "//pkg/util/tls", + "//pkg/util/topsql", + "//pkg/util/topsql/state", + "//pkg/util/topsql/stmtstats", + "//pkg/util/tracing", + "//pkg/util/versioninfo", + "@com_github_blacktear23_go_proxyprotocol//:go-proxyprotocol", + "@com_github_gorilla_mux//:mux", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_fn//:fn", + "@com_github_pingcap_kvproto//pkg/autoid", + "@com_github_pingcap_kvproto//pkg/coprocessor", + "@com_github_pingcap_kvproto//pkg/diagnosticspb", + "@com_github_pingcap_kvproto//pkg/mpp", + "@com_github_pingcap_kvproto//pkg/tikvpb", + "@com_github_pingcap_sysutil//:sysutil", + "@com_github_prometheus_client_golang//prometheus", + "@com_github_prometheus_client_golang//prometheus/promhttp", + "@com_github_soheilhy_cmux//:cmux", + "@com_github_stretchr_testify//require", + "@com_github_tiancaiamao_appdash//traceapp", + "@com_github_tikv_client_go_v2//util", + "@com_sourcegraph_sourcegraph_appdash_data//:appdash-data", + "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//channelz/service", + "@org_golang_google_grpc//keepalive", + "@org_golang_google_grpc//peer", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "server_test", + timeout = "short", + srcs = [ + "conn_stmt_params_test.go", + "conn_stmt_test.go", + "conn_test.go", + "driver_tidb_test.go", + "main_test.go", + "mock_conn_test.go", + "server_test.go", + "stat_test.go", + "tidb_library_test.go", + "tidb_test.go", + ], + data = glob(["testdata/**"]), + embed = [":server"], + flaky = True, + shard_count = 50, + deps = [ + "//pkg/config", + "//pkg/domain", + "//pkg/domain/infosync", + "//pkg/expression", + "//pkg/extension", + "//pkg/keyspace", + "//pkg/kv", + "//pkg/metrics", + "//pkg/param", + "//pkg/parser/ast", + "//pkg/parser/auth", + "//pkg/parser/charset", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/parser/terror", + "//pkg/server/internal", + "//pkg/server/internal/column", + "//pkg/server/internal/handshake", + "//pkg/server/internal/parse", + "//pkg/server/internal/testutil", + "//pkg/server/internal/util", + "//pkg/session", + "//pkg/sessionctx/variable", + "//pkg/sessiontxn", + "//pkg/store/mockstore", + "//pkg/store/mockstore/unistore", + "//pkg/testkit", + "//pkg/testkit/external", + "//pkg/testkit/testdata", + "//pkg/testkit/testmain", + "//pkg/testkit/testsetup", + "//pkg/types", + "//pkg/util", + "//pkg/util/arena", + "//pkg/util/chunk", + "//pkg/util/dbterror/exeerrors", + "//pkg/util/replayer", + "//pkg/util/sqlkiller", + "//pkg/util/syncutil", + "//pkg/util/topsql/state", + "@com_github_docker_go_units//:go-units", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/metapb", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//error", + "@com_github_tikv_client_go_v2//testutils", + "@com_github_tikv_client_go_v2//tikv", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/pkg/server/internal/parse/BUILD.bazel b/pkg/server/internal/parse/BUILD.bazel new file mode 100644 index 0000000000000..dfab3da1f7923 --- /dev/null +++ b/pkg/server/internal/parse/BUILD.bazel @@ -0,0 +1,32 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "parse", + srcs = ["parse.go"], + importpath = "github.com/pingcap/tidb/pkg/server/internal/parse", + visibility = ["//pkg/server:__subpackages__"], + deps = [ + "//pkg/parser/mysql", + "//pkg/server/internal/handshake", + "//pkg/server/internal/util", + "//pkg/util/logutil", + "@com_github_klauspost_compress//zstd", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "parse_test", + timeout = "short", + srcs = [ + "handshake_test.go", + "parse_test.go", + ], + embed = [":parse"], + flaky = True, + deps = [ + "//pkg/parser/mysql", + "//pkg/server/internal/handshake", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/server/internal/parse/parse.go b/pkg/server/internal/parse/parse.go new file mode 100644 index 0000000000000..de7571d2c287c --- /dev/null +++ b/pkg/server/internal/parse/parse.go @@ -0,0 +1,173 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package parse + +import ( + "bytes" + "context" + "encoding/binary" + + "github.com/klauspost/compress/zstd" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/server/internal/handshake" + util2 "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// maxFetchSize constants +const ( + maxFetchSize = 1024 +) + +// StmtFetchCmd parse COM_STMT_FETCH command +func StmtFetchCmd(data []byte) (stmtID uint32, fetchSize uint32, err error) { + if len(data) != 8 { + return 0, 0, mysql.ErrMalformPacket + } + // Please refer to https://dev.mysql.com/doc/internals/en/com-stmt-fetch.html + stmtID = binary.LittleEndian.Uint32(data[0:4]) + fetchSize = binary.LittleEndian.Uint32(data[4:8]) + if fetchSize > maxFetchSize { + fetchSize = maxFetchSize + } + return +} + +// HandshakeResponseHeader parses the common header of SSLRequest and Response41. +func HandshakeResponseHeader(ctx context.Context, packet *handshake.Response41, data []byte) (parsedBytes int, err error) { + // Ensure there are enough data to read: + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + if len(data) < 4+4+1+23 { + logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data)) + return 0, mysql.ErrMalformPacket + } + + offset := 0 + // capability + capability := binary.LittleEndian.Uint32(data[:4]) + packet.Capability = capability + offset += 4 + // skip max packet size + offset += 4 + // charset, skip, if you want to use another charset, use set names + packet.Collation = data[offset] + offset++ + // skip reserved 23[00] + offset += 23 + + return offset, nil +} + +// HandshakeResponseBody parse the HandshakeResponse (except the common header part). +func HandshakeResponseBody(ctx context.Context, packet *handshake.Response41, data []byte, offset int) (err error) { + defer func() { + // Check malformat packet cause out of range is disgusting, but don't panic! + if r := recover(); r != nil { + logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data)) + err = mysql.ErrMalformPacket + } + }() + // user name + packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)]) + offset += len(packet.User) + 1 + + if packet.Capability&mysql.ClientPluginAuthLenencClientData > 0 { + // MySQL client sets the wrong capability, it will set this bit even server doesn't + // support ClientPluginAuthLenencClientData. + // https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478 + if data[offset] == 0x1 { // No auth data + offset += 2 + } else { + num, null, off := util2.ParseLengthEncodedInt(data[offset:]) + offset += off + if !null { + packet.Auth = data[offset : offset+int(num)] + offset += int(num) + } + } + } else if packet.Capability&mysql.ClientSecureConnection > 0 { + // auth length and auth + authLen := int(data[offset]) + offset++ + packet.Auth = data[offset : offset+authLen] + offset += authLen + } else { + packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)] + offset += len(packet.Auth) + 1 + } + + if packet.Capability&mysql.ClientConnectWithDB > 0 { + if len(data[offset:]) > 0 { + idx := bytes.IndexByte(data[offset:], 0) + packet.DBName = string(data[offset : offset+idx]) + offset += idx + 1 + } + } + + if packet.Capability&mysql.ClientPluginAuth > 0 { + idx := bytes.IndexByte(data[offset:], 0) + s := offset + f := offset + idx + if s < f { // handle unexpected bad packets + packet.AuthPlugin = string(data[s:f]) + } + offset += idx + 1 + } + + if packet.Capability&mysql.ClientConnectAtts > 0 { + if len(data[offset:]) == 0 { + // Defend some ill-formated packet, connection attribute is not important and can be ignored. + return nil + } + if num, null, intOff := util2.ParseLengthEncodedInt(data[offset:]); !null { + offset += intOff // Length of variable length encoded integer itself in bytes + row := data[offset : offset+int(num)] + attrs, err := parseAttrs(row) + if err != nil { + logutil.Logger(ctx).Warn("parse attrs failed", zap.Error(err)) + return nil + } + packet.Attrs = attrs + offset += int(num) // Length of attributes + } + } + + if packet.Capability&mysql.ClientZstdCompressionAlgorithm > 0 { + packet.ZstdLevel = zstd.EncoderLevelFromZstd(int(data[offset])) + } + + return nil +} + +func parseAttrs(data []byte) (map[string]string, error) { + attrs := make(map[string]string) + pos := 0 + for pos < len(data) { + key, _, off, err := util2.ParseLengthEncodedBytes(data[pos:]) + if err != nil { + return attrs, err + } + pos += off + value, _, off, err := util2.ParseLengthEncodedBytes(data[pos:]) + if err != nil { + return attrs, err + } + pos += off + + attrs[string(key)] = string(value) + } + return attrs, nil +} diff --git a/pkg/server/internal/parse/parse_test.go b/pkg/server/internal/parse/parse_test.go new file mode 100644 index 0000000000000..cc44038cb50a5 --- /dev/null +++ b/pkg/server/internal/parse/parse_test.go @@ -0,0 +1,45 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package parse + +import ( + "testing" + + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/stretchr/testify/require" +) + +func TestParseStmtFetchCmd(t *testing.T) { + tests := []struct { + arg []byte + stmtID uint32 + fetchSize uint32 + err error + }{ + {[]byte{3, 0, 0, 0, 50, 0, 0, 0}, 3, 50, nil}, + {[]byte{5, 0, 0, 0, 232, 3, 0, 0}, 5, 1000, nil}, + {[]byte{5, 0, 0, 0, 0, 8, 0, 0}, 5, maxFetchSize, nil}, + {[]byte{5, 0, 0}, 0, 0, mysql.ErrMalformPacket}, + {[]byte{1, 0, 0, 0, 3, 2, 0, 0, 3, 5, 6}, 0, 0, mysql.ErrMalformPacket}, + {[]byte{}, 0, 0, mysql.ErrMalformPacket}, + } + + for _, tc := range tests { + stmtID, fetchSize, err := StmtFetchCmd(tc.arg) + require.Equal(t, tc.stmtID, stmtID) + require.Equal(t, tc.fetchSize, fetchSize) + require.Equal(t, tc.err, err) + } +} diff --git a/pkg/session/BUILD.bazel b/pkg/session/BUILD.bazel new file mode 100644 index 0000000000000..5a27a9785a923 --- /dev/null +++ b/pkg/session/BUILD.bazel @@ -0,0 +1,169 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "session", + srcs = [ + "advisory_locks.go", + "bootstrap.go", + "mock_bootstrap.go", + "nontransactional.go", + "session.go", + "sync_upgrade.go", + "testutil.go", #keep + "tidb.go", + "txn.go", + "txnmanager.go", + ], + importpath = "github.com/pingcap/tidb/pkg/session", + visibility = ["//visibility:public"], + deps = [ + "//pkg/bindinfo", + "//pkg/config", + "//pkg/ddl", + "//pkg/ddl/placement", + "//pkg/ddl/schematracker", + "//pkg/ddl/syncer", + "//pkg/domain", + "//pkg/domain/infosync", + "//pkg/errno", + "//pkg/executor", + "//pkg/expression", + "//pkg/extension", + "//pkg/extension/extensionimpl", + "//pkg/infoschema", + "//pkg/kv", + "//pkg/meta", + "//pkg/metrics", + "//pkg/owner", + "//pkg/param", + "//pkg/parser", + "//pkg/parser/ast", + "//pkg/parser/auth", + "//pkg/parser/charset", + "//pkg/parser/format", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/parser/opcode", + "//pkg/parser/terror", + "//pkg/planner", + "//pkg/planner/core", + "//pkg/plugin", + "//pkg/privilege", + "//pkg/privilege/conn", + "//pkg/privilege/privileges", + "//pkg/session/metrics", + "//pkg/session/txninfo", + "//pkg/sessionctx", + "//pkg/sessionctx/binloginfo", + "//pkg/sessionctx/sessionstates", + "//pkg/sessionctx/stmtctx", + "//pkg/sessionctx/variable", + "//pkg/sessiontxn", + "//pkg/sessiontxn/isolation", + "//pkg/sessiontxn/staleread", + "//pkg/statistics/handle/usage", + "//pkg/store/driver/error", + "//pkg/store/driver/txn", + "//pkg/store/helper", + "//pkg/store/mockstore", + "//pkg/table", + "//pkg/table/tables", + "//pkg/table/temptable", + "//pkg/tablecodec", + "//pkg/telemetry", + "//pkg/testkit/testenv", + "//pkg/timer/tablestore", + "//pkg/ttl/ttlworker", + "//pkg/types", + "//pkg/types/parser_driver", + "//pkg/util", + "//pkg/util/chunk", + "//pkg/util/collate", + "//pkg/util/dbterror", + "//pkg/util/dbterror/exeerrors", + "//pkg/util/execdetails", + "//pkg/util/intest", + "//pkg/util/kvcache", + "//pkg/util/logutil", + "//pkg/util/logutil/consistency", + "//pkg/util/memory", + "//pkg/util/parser", + "//pkg/util/sem", + "//pkg/util/sli", + "//pkg/util/sqlescape", + "//pkg/util/sqlexec", + "//pkg/util/syncutil", + "//pkg/util/tableutil", + "//pkg/util/timeutil", + "//pkg/util/topsql", + "//pkg/util/topsql/state", + "//pkg/util/topsql/stmtstats", + "//pkg/util/tracing", + "@com_github_ngaut_pools//:pools", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/kvrpcpb", + "@com_github_pingcap_tipb//go-binlog", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//error", + "@com_github_tikv_client_go_v2//kv", + "@com_github_tikv_client_go_v2//oracle", + "@com_github_tikv_client_go_v2//tikv", + "@com_github_tikv_client_go_v2//util", + "@io_etcd_go_etcd_client_v3//concurrency", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + "@org_uber_go_zap//zapcore", + ], +) + +go_test( + name = "session_test", + timeout = "moderate", + srcs = [ + "bench_test.go", + "bootstrap_test.go", + "index_usage_sync_lease_test.go", + "main_test.go", + "tidb_test.go", + ], + data = glob(["testdata/**"]), + embed = [":session"], + flaky = True, + race = "on", + shard_count = 50, + deps = [ + "//pkg/autoid_service", + "//pkg/bindinfo", + "//pkg/config", + "//pkg/domain", + "//pkg/executor", + "//pkg/expression", + "//pkg/kv", + "//pkg/meta", + "//pkg/parser/ast", + "//pkg/parser/auth", + "//pkg/planner/core", + "//pkg/sessionctx", + "//pkg/sessionctx/variable", + "//pkg/statistics", + "//pkg/store/mockstore", + "//pkg/tablecodec", + "//pkg/telemetry", + "//pkg/testkit/testmain", + "//pkg/testkit/testsetup", + "//pkg/types", + "//pkg/util", + "//pkg/util/benchdaily", + "//pkg/util/chunk", + "//pkg/util/logutil", + "//pkg/util/sqlexec", + "@com_github_pingcap_log//:log", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//tikv", + "@org_uber_go_atomic//:atomic", #keep + "@org_uber_go_goleak//:goleak", + "@org_uber_go_zap//:zap", + "@org_uber_go_zap//zapcore", + ], +) diff --git a/server/conn_stmt.go b/server/conn_stmt.go index eb9a53ece3958..5dbc8df9e5cc6 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -46,6 +46,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" +<<<<<<< HEAD:server/conn_stmt.go "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/ast" @@ -65,6 +66,28 @@ import ( "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/topsql" topsqlstate "github.com/pingcap/tidb/util/topsql/state" +======= + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/server/internal/dump" + "github.com/pingcap/tidb/pkg/server/internal/parse" + "github.com/pingcap/tidb/pkg/server/internal/resultset" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + storeerr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" +>>>>>>> 8ce2ad16961 (parse: fix the type of date/time parameters (#48237)):pkg/server/conn_stmt.go "github.com/tikv/client-go/v2/util" "go.uber.org/zap" ) @@ -181,7 +204,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e ) cc.initInputEncoder(ctx) numParams := stmt.NumParams() - args := make([]expression.Expression, numParams) + args := make([]param.BinaryParam, numParams) if numParams > 0 { nullBitmapLen := (numParams + 7) >> 3 if len(data) < (pos + nullBitmapLen + 1) { @@ -207,7 +230,11 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e paramValues = data[pos+1:] } +<<<<<<< HEAD:server/conn_stmt.go err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) +======= + err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) +>>>>>>> 8ce2ad16961 (parse: fix the type of date/time parameters (#48237)):pkg/server/conn_stmt.go // This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine) errReset := stmt.Reset() if errReset != nil { @@ -228,7 +255,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e return err } -func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []expression.Expression, useCursor bool) (err error) { +func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []param.BinaryParam, useCursor bool) (err error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) @@ -263,7 +290,7 @@ func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{} // The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried. // Currently the first return value is used to fallback to TiKV when TiFlash is down. -func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []expression.Expression, useCursor bool) (bool, error) { +func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []param.BinaryParam, useCursor bool) (bool, error) { vars := (&cc.ctx).GetSessionVars() prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID())) if err != nil { diff --git a/server/conn_stmt_params.go b/server/conn_stmt_params.go new file mode 100644 index 0000000000000..5b625c683faac --- /dev/null +++ b/server/conn_stmt_params.go @@ -0,0 +1,130 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + util2 "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror" +) + +var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType) + +// parseBinaryParams decodes the binary params according to the protocol +func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte, enc *util2.InputDecoder) (err error) { + pos := 0 + if enc == nil { + enc = util2.NewInputDecoder(charset.CharsetUTF8) + } + + for i := 0; i < len(params); i++ { + // if params had received via ComStmtSendLongData, use them directly. + // ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html + // see clientConn#handleStmtSendLongData + if boundParams[i] != nil { + params[i] = param.BinaryParam{ + Tp: mysql.TypeBlob, + Val: enc.DecodeInput(boundParams[i]), + } + continue + } + + // check nullBitMap to determine the NULL arguments. + // ref https://dev.mysql.com/doc/internals/en/com-stmt-execute.html + // notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData, + // so this check need place after boundParam's check. + if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { + var nilDatum types.Datum + nilDatum.SetNull() + params[i] = param.BinaryParam{ + Tp: mysql.TypeNull, + } + continue + } + + if (i<<1)+1 >= len(paramTypes) { + return mysql.ErrMalformPacket + } + + tp := paramTypes[i<<1] + isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 + isNull := false + + decodeWithDecoder := false + + var length uint64 + switch tp { + case mysql.TypeNull: + length = 0 + isNull = true + case mysql.TypeTiny: + length = 1 + case mysql.TypeShort, mysql.TypeYear: + length = 2 + case mysql.TypeInt24, mysql.TypeLong, mysql.TypeFloat: + length = 4 + case mysql.TypeLonglong, mysql.TypeDouble: + length = 8 + case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration: + if len(paramValues) < (pos + 1) { + err = mysql.ErrMalformPacket + return + } + length = uint64(paramValues[pos]) + pos++ + case mysql.TypeNewDecimal, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + if len(paramValues) < (pos + 1) { + err = mysql.ErrMalformPacket + return + } + var n int + length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:]) + pos += n + case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, + mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit: + if len(paramValues) < (pos + 1) { + err = mysql.ErrMalformPacket + return + } + var n int + length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:]) + pos += n + decodeWithDecoder = true + default: + err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp) + return + } + + if len(paramValues) < (pos + int(length)) { + err = mysql.ErrMalformPacket + return + } + params[i] = param.BinaryParam{ + Tp: tp, + IsUnsigned: isUnsigned, + IsNull: isNull, + Val: paramValues[pos : pos+int(length)], + } + if decodeWithDecoder { + params[i].Val = enc.DecodeInput(params[i].Val) + } + pos += int(length) + } + return +} diff --git a/server/conn_stmt_params_test.go b/server/conn_stmt_params_test.go new file mode 100644 index 0000000000000..beaa275692e52 --- /dev/null +++ b/server/conn_stmt_params_test.go @@ -0,0 +1,380 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "encoding/binary" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/server/internal/column" + "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/stretchr/testify/require" +) + +// decodeAndParse uses the `parseBinaryParams` and `parse.ExecArgs` to parse the params passed through binary protocol +// It helps to test the integration of these two functions +func decodeAndParse(typectx types.Context, args []expression.Expression, boundParams [][]byte, + nullBitmap, paramTypes, paramValues []byte, enc *util.InputDecoder) (err error) { + binParams := make([]param.BinaryParam, len(args)) + err = parseBinaryParams(binParams, boundParams, nullBitmap, paramTypes, paramValues, enc) + if err != nil { + return err + } + + parsedArgs, err := param.ExecArgs(typectx, binParams) + if err != nil { + return err + } + + for i := 0; i < len(args); i++ { + args[i] = parsedArgs[i] + } + return +} + +func TestParseExecArgs(t *testing.T) { + type args struct { + args []expression.Expression + boundParams [][]byte + nullBitmap []byte + paramTypes []byte + paramValues []byte + } + tests := []struct { + args args + err error + warn error + expect interface{} + }{ + // Tests for int overflow + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{1, 0}, + []byte{0xff}, + }, + nil, + nil, + int64(-1), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{2, 0}, + []byte{0xff, 0xff}, + }, + nil, + nil, + int64(-1), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{3, 0}, + []byte{0xff, 0xff, 0xff, 0xff}, + }, + nil, + nil, + int64(-1), + }, + // Tests for date/datetime/timestamp + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{12, 0}, + []byte{0x0b, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 19, 27, 30, 1), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{10, 0}, + []byte{0x04, 0xda, 0x07, 0x0a, 0x11}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 0, 0, 0, 0), mysql.TypeDate, types.DefaultFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x0b, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 19, 27, 30, 1), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x07, 0xda, 0x07, 0x0a, 0x11, 0x13, 0x1b, 0x1e}, + }, + nil, + nil, + types.NewTime(types.FromDate(2010, 10, 17, 19, 27, 30, 0), mysql.TypeDatetime, types.DefaultFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x0d, 0xdb, 0x07, 0x02, 0x03, 0x04, 0x05, 0x06, 0x40, 0xe2, 0x01, 0x00, 0xf2, 0x02}, + }, + nil, + nil, + types.NewTime(types.FromDate(2011, 02, 02, 15, 31, 06, 123456), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x0d, 0xdb, 0x07, 0x02, 0x03, 0x04, 0x05, 0x06, 0x40, 0xe2, 0x01, 0x00, 0x0e, 0xfd}, + }, + nil, + nil, + types.NewTime(types.FromDate(2011, 02, 03, 16, 39, 06, 123456), mysql.TypeDatetime, types.MaxFsp), + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{0x00}, + }, + nil, + nil, + types.NewTime(types.ZeroCoreTime, mysql.TypeDatetime, types.DefaultFsp), + }, + // Tests for time + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{0x0c, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e, 0x01, 0x00, 0x00, 0x00}, + }, + nil, + types.ErrTruncatedWrongVal, + types.Duration{Duration: types.MinTime, Fsp: types.MaxFsp}, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{0x08, 0x01, 0x78, 0x00, 0x00, 0x00, 0x13, 0x1b, 0x1e}, + }, + nil, + types.ErrTruncatedWrongVal, + types.Duration{Duration: types.MinTime, Fsp: types.MaxFsp}, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{0x00}, + }, + nil, + nil, + types.Duration{Duration: time.Duration(0), Fsp: types.MaxFsp}, + }, + // For error test + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{7, 0}, + []byte{10}, + }, + mysql.ErrMalformPacket, + nil, + nil, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{10}, + }, + mysql.ErrMalformPacket, + nil, + nil, + }, + { + args{ + expression.Args2Expressions4Test(1), + [][]byte{nil}, + []byte{0x0}, + []byte{11, 0}, + []byte{8, 2}, + }, + mysql.ErrMalformPacket, + nil, + nil, + }, + } + for _, tt := range tests { + var warn error + typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, func(err error) { + warn = err + }) + err := decodeAndParse(typectx, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) + require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err) + require.Truef(t, terror.ErrorEqual(warn, tt.warn), "warn %v", warn) + if err == nil { + require.Equal(t, tt.expect, tt.args.args[0].(*expression.Constant).Value.GetValue()) + } + } +} + +func TestParseExecArgsAndEncode(t *testing.T) { + dt := expression.Args2Expressions4Test(1) + err := decodeAndParse(types.DefaultStmtNoWarningContext, + dt, + [][]byte{nil}, + []byte{0x0}, + []byte{mysql.TypeVarchar, 0}, + []byte{4, 178, 226, 202, 212}, + util.NewInputDecoder("gbk")) + require.NoError(t, err) + require.Equal(t, "测试", dt[0].(*expression.Constant).Value.GetValue()) + + err = decodeAndParse(types.DefaultStmtNoWarningContext, + dt, + [][]byte{{178, 226, 202, 212}}, + []byte{0x0}, + []byte{mysql.TypeString, 0}, + []byte{}, + util.NewInputDecoder("gbk")) + require.NoError(t, err) + require.Equal(t, "测试", dt[0].(*expression.Constant).Value.GetString()) +} + +func buildDatetimeParam(year uint16, month uint8, day uint8, hour uint8, min uint8, sec uint8, msec uint32) []byte { + endian := binary.LittleEndian + + result := []byte{mysql.TypeDatetime, 0x0, 11} + result = endian.AppendUint16(result, year) + result = append(result, month) + result = append(result, day) + result = append(result, hour) + result = append(result, min) + result = append(result, sec) + result = endian.AppendUint32(result, msec) + return result +} + +func expectedDatetimeExecuteResult(t *testing.T, c *mockConn, time types.Time, warnCount int) []byte { + return getExpectOutput(t, c, func(conn *clientConn) { + var err error + + cols := []*column.Info{{ + Name: "t", + Table: "", + Type: mysql.TypeDatetime, + Charset: uint16(mysql.CharsetNameToID(charset.CharsetBin)), + Flag: uint16(mysql.NotNullFlag | mysql.BinaryFlag), + Decimal: 6, + ColumnLength: 26, + }} + require.NoError(t, conn.writeColumnInfo(cols)) + + chk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDatetime)}, 1) + chk.AppendTime(0, time) + data := make([]byte, 4) + data, err = column.DumpBinaryRow(data, cols, chk.GetRow(0), conn.rsEncoder) + require.NoError(t, err) + require.NoError(t, conn.writePacket(data)) + + for i := 0; i < warnCount; i++ { + conn.ctx.GetSessionVars().StmtCtx.AppendWarning(errors.New("any error")) + } + require.NoError(t, conn.writeEOF(context.Background(), mysql.ServerStatusAutocommit)) + }) +} + +func TestDateTimeTypes(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + c.capability = mysql.ClientProtocol41 | mysql.ClientDeprecateEOF + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + stmt, _, _, err := c.Context().Prepare("select ? as t") + require.NoError(t, err) + + expectedTimeDatum, err := types.ParseDatetime(types.DefaultStmtNoWarningContext, "2023-11-09 14:23:45.000100") + require.NoError(t, err) + expected := expectedDatetimeExecuteResult(t, c, expectedTimeDatum, 1) + + // execute the statement with datetime parameter + req := append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + 0x0, 0x1, 0x0, 0x0, 0x0, + 0x0, 0x1, + ) + req = append(req, buildDatetimeParam(2023, 11, 9, 14, 23, 45, 100)...) + out := c.GetOutput() + require.NoError(t, c.Dispatch(ctx, req)) + + require.Equal(t, expected, out.Bytes()) +} diff --git a/server/extension.go b/server/extension.go index c7cf018eb85a8..28dbc3fc67e9a 100644 --- a/server/extension.go +++ b/server/extension.go @@ -17,6 +17,7 @@ package server import ( "fmt" +<<<<<<< HEAD:server/extension.go "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/extension" "github.com/pingcap/tidb/parser" @@ -26,6 +27,17 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" +======= + "github.com/pingcap/tidb/pkg/extension" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" +>>>>>>> 8ce2ad16961 (parse: fix the type of date/time parameters (#48237)):pkg/server/extension.go ) func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) { @@ -54,7 +66,7 @@ func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) cc.extensions.OnConnectionEvent(tp, info) } -func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, err error, args ...expression.Expression) { +func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, err error, args ...param.BinaryParam) { if !cc.extensions.HasStmtEventListeners() { return } @@ -82,9 +94,17 @@ func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, er case PreparedStatement: info.executeStmtID = uint32(stmt.ID()) prepared, _ := sessVars.GetPreparedStmtByID(info.executeStmtID) + + // TODO: the `BinaryParam` is parsed two times: one in the `Execute` method and one here. It would be better to + // eliminate one of them by storing the parsed result. + typectx := ctx.GetSessionVars().StmtCtx.TypeCtx() + typectx = types.NewContext(typectx.Flags(), typectx.Location(), func(_ error) { + // ignore all warnings + }) + params, _ := param.ExecArgs(typectx, args) info.executeStmt = &ast.ExecuteStmt{ PrepStmt: prepared, - BinaryArgs: args, + BinaryArgs: params, } info.stmtNode = info.executeStmt case ast.StmtNode: @@ -112,7 +132,7 @@ func (cc *clientConn) onExtensionSQLParseFailed(sql string, err error) { }) } -func (cc *clientConn) onExtensionBinaryExecuteEnd(prep PreparedStatement, args []expression.Expression, stmtCtxValid bool, err error) { +func (cc *clientConn) onExtensionBinaryExecuteEnd(prep PreparedStatement, args []param.BinaryParam, stmtCtxValid bool, err error) { cc.onExtensionStmtEnd(prep, stmtCtxValid, err, args...) } diff --git a/session/session.go b/session/session.go index ef665e50c07e2..f34221b5b5e38 100644 --- a/session/session.go +++ b/session/session.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/kvrpcpb" +<<<<<<< HEAD:session/session.go "github.com/pingcap/tidb/bindinfo" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" @@ -104,6 +105,76 @@ import ( "github.com/pingcap/tidb/util/topsql" topsqlstate "github.com/pingcap/tidb/util/topsql/state" "github.com/pingcap/tidb/util/topsql/stmtstats" +======= + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/executor" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/extension" + "github.com/pingcap/tidb/pkg/extension/extensionimpl" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/plugin" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/privilege/conn" + "github.com/pingcap/tidb/pkg/privilege/privileges" + session_metrics "github.com/pingcap/tidb/pkg/session/metrics" + "github.com/pingcap/tidb/pkg/session/txninfo" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/statistics/handle/usage" + storeerr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/store/driver/txn" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/temptable" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/telemetry" + "github.com/pingcap/tidb/pkg/ttl/ttlworker" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/kvcache" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/logutil/consistency" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/sem" + "github.com/pingcap/tidb/pkg/util/sli" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/syncutil" + "github.com/pingcap/tidb/pkg/util/tableutil" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + "github.com/pingcap/tidb/pkg/util/tracing" +>>>>>>> 8ce2ad16961 (parse: fix the type of date/time parameters (#48237)):pkg/session/session.go "github.com/pingcap/tipb/go-binlog" tikverr "github.com/tikv/client-go/v2/error" tikvstore "github.com/tikv/client-go/v2/kv" @@ -2192,6 +2263,16 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex if err := executor.ResetContextOfStmt(s, stmtNode); err != nil { return nil, err } + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + if binParam, ok := execStmt.BinaryArgs.([]param.BinaryParam); ok { + args, err := param.ExecArgs(s.GetSessionVars().StmtCtx.TypeCtx(), binParam) + if err != nil { + return nil, err + } + execStmt.BinaryArgs = args + } + } + normalizedSQL, digest := s.sessionVars.StmtCtx.SQLDigest() if topsqlstate.TopSQLEnabled() { s.sessionVars.StmtCtx.IsSQLRegistered.Store(true)