From 68dd326e08eaf54f9a84c4a8c8cb7af44c5cba0f Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 10 Oct 2024 07:50:37 +0200 Subject: [PATCH 01/23] Support vector type --- common_test.go | 13 ++++ frame.go | 19 ++++++ helpers.go | 14 ++++ helpers_test.go | 10 +++ marshal.go | 170 ++++++++++++++++++++++++++++++++++++++++++++++++ metadata.go | 1 + vector_test.go | 76 ++++++++++++++++++++++ 7 files changed, 303 insertions(+) create mode 100644 vector_test.go diff --git a/common_test.go b/common_test.go index a5edb03c6..ae7f83f94 100644 --- a/common_test.go +++ b/common_test.go @@ -28,6 +28,7 @@ import ( "flag" "fmt" "log" + "math/rand" "net" "reflect" "strings" @@ -52,6 +53,10 @@ var ( flagCassVersion cassVersion ) +var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + +const randCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + func init() { flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") @@ -277,6 +282,14 @@ func assertTrue(t *testing.T, description string, value bool) { } } +func randomText(size int) string { + result := make([]byte, size) + for i := range result { + result[i] = randCharset[rand.Intn(len(randCharset))] + } + return string(result) +} + func assertEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if expected != actual { diff --git a/frame.go b/frame.go index d374ae574..4df219cc8 100644 --- a/frame.go +++ b/frame.go @@ -32,6 +32,7 @@ import ( "io/ioutil" "net" "runtime" + "strconv" "strings" "time" ) @@ -928,6 +929,24 @@ func (f *framer) readTypeInfo() TypeInfo { collection.Elem = f.readTypeInfo() return collection + case TypeCustom: + if strings.HasPrefix(simple.custom, VECTOR_TYPE) { + spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) + spec = spec[1 : len(spec)-1] // remove parenthesis + types := strings.Split(spec, ",") + // TODO(lantoniak): for now we use only simple subtypes + subType := NativeType{ + proto: f.proto, + typ: getApacheCassandraType(strings.TrimSpace(types[0])), + } + dim, _ := strconv.Atoi(strings.TrimSpace(types[1])) + vector := VectorType{ + NativeType: simple, + SubType: subType, + Dimensions: dim, + } + return vector + } } return simple diff --git a/helpers.go b/helpers.go index f2faee9e0..005148144 100644 --- a/helpers.go +++ b/helpers.go @@ -29,6 +29,7 @@ import ( "math/big" "net" "reflect" + "strconv" "strings" "time" @@ -200,6 +201,19 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { NativeType: NativeType{typ: TypeTuple}, Elems: types, } + } else if strings.HasPrefix(name, "vector<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) + subType := getCassandraType(names[0], logger) + dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + + return VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: subType, + Dimensions: dim, + } } else { return NativeType{ typ: getCassandraBaseType(name), diff --git a/helpers_test.go b/helpers_test.go index 67922ba5d..4622da361 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -223,6 +223,16 @@ func TestGetCassandraType(t *testing.T) { Elem: NativeType{typ: TypeDuration}, }, }, + { + "vector", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + }, } for _, test := range tests { diff --git a/marshal.go b/marshal.go index 4d0adb923..813b8e282 100644 --- a/marshal.go +++ b/marshal.go @@ -170,6 +170,11 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return marshalDate(info, value) case TypeDuration: return marshalDuration(info, value) + case TypeCustom: + switch info.(type) { + case VectorType: + return marshalVector(info.(VectorType), value) + } } // detect protocol 2 UDT @@ -274,6 +279,11 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { return unmarshalDate(info, data, value) case TypeDuration: return unmarshalDuration(info, data, value) + case TypeCustom: + switch info.(type) { + case VectorType: + return unmarshalVector(info.(VectorType), data, value) + } } // detect protocol 2 UDT @@ -1709,6 +1719,160 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } +func marshalVector(info VectorType, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + + for i := 0; i < n; i++ { + if isVectorVariableLengthType(info.SubType.Type()) { + elemSize := rv.Index(i).Len() + writeUnsignedVInt(buf, uint64(elemSize)) + } + item, err := Marshal(info.SubType, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + buf.Write(item) + } + return buf.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalVector(info VectorType, data []byte, value interface{}) error { + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return unmarshalErrorf("unmarshal vector: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + if k == reflect.Array { + if rv.Len() != info.Dimensions { + return unmarshalErrorf("unmarshal vector: array with wrong size") + } + } else { + rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions)) + } + elemSize := len(data) / info.Dimensions + for i := 0; i < info.Dimensions; i++ { + offset := 0 + if isVectorVariableLengthType(info.SubType.Type()) { + m, p, err := readUnsignedVint(data, 0) + if err != nil { + return err + } + elemSize = int(m) + offset = p + } + if offset > 0 { + data = data[offset:] + } + var unmarshalData []byte + if elemSize >= 0 { + if len(data) < elemSize { + return unmarshalErrorf("unmarshal vector: unexpected eof") + } + unmarshalData = data[:elemSize] + data = data[elemSize:] + } + err := Unmarshal(info.SubType, unmarshalData, rv.Index(i).Addr().Interface()) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, unmarshalData, err.Error()) + } + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func isVectorVariableLengthType(elemType Type) bool { + switch elemType { + case TypeVarchar, TypeAscii, TypeBlob, TypeText: + return true + // TODO(lantonia): double check list of variable vector types + //case TypeCounter: + // return true + //case TypeDuration, TypeDate, TypeTime: + // return true + //case TypeDecimal, TypeSmallInt, TypeTinyInt: + // return true + case TypeInet: + return true + } + return false +} + +func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { + numBytes := computeUnsignedVIntSize(v) + if numBytes <= 1 { + buf.WriteByte(byte(v)) + return + } + + numBytes = computeUnsignedVIntSize(v) + extraBytes := numBytes - 1 + var tmp = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + tmp[i] = byte(v) + v >>= 8 + } + tmp[0] |= byte(^(0xff >> uint(extraBytes))) + buf.Write(tmp) +} + +func readUnsignedVint(data []byte, start int) (uint64, int, error) { + if len(data) <= start { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[start] + if firstByte&0x80 == 0 { + return uint64(firstByte), start + 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < start+numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) + } + for i := start; i < start+numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return ret, start + numBytes + 1, nil +} + +func computeUnsignedVIntSize(v uint64) int { + lead0 := bits.LeadingZeros64(v) + return (639 - lead0*9) >> 6 +} + func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { mapInfo, ok := info.(CollectionType) if !ok { @@ -2523,6 +2687,12 @@ type CollectionType struct { Elem TypeInfo // only used for TypeMap, TypeList and TypeSet } +type VectorType struct { + NativeType + SubType TypeInfo + Dimensions int +} + func (t CollectionType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { diff --git a/metadata.go b/metadata.go index 6eb798f8a..ea962d553 100644 --- a/metadata.go +++ b/metadata.go @@ -1209,6 +1209,7 @@ const ( LIST_TYPE = "org.apache.cassandra.db.marshal.ListType" SET_TYPE = "org.apache.cassandra.db.marshal.SetType" MAP_TYPE = "org.apache.cassandra.db.marshal.MapType" + VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType" ) // represents a class specification in the type def AST diff --git a/vector_test.go b/vector_test.go new file mode 100644 index 000000000..00d6d48cb --- /dev/null +++ b/vector_test.go @@ -0,0 +1,76 @@ +//go:build all || cassandra +// +build all cassandra + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql + +import ( + "testing" +) + +func TestVector_Marshaler(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE gocql_test.vector_fixed(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE gocql_test.vector_variable(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + insertFixVec := []float32{8, 2.5, -5.0} + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, insertFixVec).Exec() + if err != nil { + t.Fatal(err) + } + var vf []float32 + err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&vf) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "fixed-size element size vector", insertFixVec, vf) + + longText := randomText(500) + insertVarVec := []string{"apache", "cassandra", longText, "gocql"} + err = session.Query("INSERT INTO vector_variable(id, vec) VALUES(?, ?)", 1, insertVarVec).Exec() + if err != nil { + t.Fatal(err) + } + var vv []string + err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&vv) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "variable-size element vector", insertVarVec, vv) +} From cf779eb3a771da77719a2c1224fb0454240cbad9 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 10 Oct 2024 09:44:38 +0200 Subject: [PATCH 02/23] Vector support tests --- marshal.go | 3 +++ vector_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/marshal.go b/marshal.go index 813b8e282..d80437fbc 100644 --- a/marshal.go +++ b/marshal.go @@ -1737,6 +1737,9 @@ func marshalVector(info VectorType, value interface{}) ([]byte, error) { case reflect.Slice, reflect.Array: buf := &bytes.Buffer{} n := rv.Len() + if n != info.Dimensions { + return nil, marshalErrorf("expected vector with %d dimensions, received %d", info.Dimensions, n) + } for i := 0; i < n; i++ { if isVectorVariableLengthType(info.SubType.Type()) { diff --git a/vector_test.go b/vector_test.go index 00d6d48cb..171436f18 100644 --- a/vector_test.go +++ b/vector_test.go @@ -28,6 +28,7 @@ package gocql import ( + "github.com/stretchr/testify/require" "testing" ) @@ -39,12 +40,12 @@ func TestVector_Marshaler(t *testing.T) { t.Skip("Vector types have been introduced in Cassandra 5.0") } - err := createTable(session, `CREATE TABLE gocql_test.vector_fixed(id int primary key, vec vector);`) + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector);`) if err != nil { t.Fatal(err) } - err = createTable(session, `CREATE TABLE gocql_test.vector_variable(id int primary key, vec vector);`) + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable(id int primary key, vec vector);`) if err != nil { t.Fatal(err) } @@ -59,7 +60,7 @@ func TestVector_Marshaler(t *testing.T) { if err != nil { t.Fatal(err) } - assertDeepEqual(t, "fixed-size element size vector", insertFixVec, vf) + assertDeepEqual(t, "fixed size element vector", insertFixVec, vf) longText := randomText(500) insertVarVec := []string{"apache", "cassandra", longText, "gocql"} @@ -72,5 +73,63 @@ func TestVector_Marshaler(t *testing.T) { if err != nil { t.Fatal(err) } - assertDeepEqual(t, "variable-size element vector", insertVarVec, vv) + assertDeepEqual(t, "variable size element vector", insertVarVec, vv) +} + +func TestVector_Empty(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed_null(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable_null(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var vf []float32 + err = session.Query("SELECT vec FROM vector_fixed_null WHERE id = ?", 1).Scan(&vf) + if err != nil { + t.Fatal(err) + } + assertTrue(t, "fixed size element vector is empty", vf == nil) + + err = session.Query("INSERT INTO vector_variable_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var vv []string + err = session.Query("SELECT vec FROM vector_variable_null WHERE id = ?", 1).Scan(&vv) + if err != nil { + t.Fatal(err) + } + assertTrue(t, "variable size element vector is empty", vv == nil) +} + +func TestVector_MissingDimension(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 2") } From 01c9d74e3db2f37b1b33507d14e812fa815f5e72 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 10 Oct 2024 16:20:42 +0200 Subject: [PATCH 03/23] Test vectors of simple types --- helpers.go | 2 + marshal.go | 20 ++++---- vector_test.go | 127 ++++++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 126 insertions(+), 23 deletions(-) diff --git a/helpers.go b/helpers.go index 005148144..f0806ff33 100644 --- a/helpers.go +++ b/helpers.go @@ -311,6 +311,8 @@ func getApacheCassandraType(class string) Type { return TypeTuple case "DurationType": return TypeDuration + case "SimpleDateType": + return TypeDate default: return TypeCustom } diff --git a/marshal.go b/marshal.go index d80437fbc..8abb9e820 100644 --- a/marshal.go +++ b/marshal.go @@ -1742,14 +1742,13 @@ func marshalVector(info VectorType, value interface{}) ([]byte, error) { } for i := 0; i < n; i++ { - if isVectorVariableLengthType(info.SubType.Type()) { - elemSize := rv.Index(i).Len() - writeUnsignedVInt(buf, uint64(elemSize)) - } item, err := Marshal(info.SubType, rv.Index(i).Interface()) if err != nil { return nil, err } + if isVectorVariableLengthType(info.SubType.Type()) { + writeUnsignedVInt(buf, uint64(len(item))) + } buf.Write(item) } return buf.Bytes(), nil @@ -1820,13 +1819,12 @@ func isVectorVariableLengthType(elemType Type) bool { switch elemType { case TypeVarchar, TypeAscii, TypeBlob, TypeText: return true - // TODO(lantonia): double check list of variable vector types - //case TypeCounter: - // return true - //case TypeDuration, TypeDate, TypeTime: - // return true - //case TypeDecimal, TypeSmallInt, TypeTinyInt: - // return true + case TypeCounter: + return true + case TypeDuration, TypeDate, TypeTime: + return true + case TypeDecimal, TypeSmallInt, TypeTinyInt, TypeVarint: + return true case TypeInet: return true } diff --git a/vector_test.go b/vector_test.go index 171436f18..171cfc9ac 100644 --- a/vector_test.go +++ b/vector_test.go @@ -28,8 +28,13 @@ package gocql import ( + "fmt" "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" + "net" + "reflect" "testing" + "time" ) func TestVector_Marshaler(t *testing.T) { @@ -55,12 +60,12 @@ func TestVector_Marshaler(t *testing.T) { if err != nil { t.Fatal(err) } - var vf []float32 - err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&vf) + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&selectFixVec) if err != nil { t.Fatal(err) } - assertDeepEqual(t, "fixed size element vector", insertFixVec, vf) + assertDeepEqual(t, "fixed size element vector", insertFixVec, selectFixVec) longText := randomText(500) insertVarVec := []string{"apache", "cassandra", longText, "gocql"} @@ -68,12 +73,107 @@ func TestVector_Marshaler(t *testing.T) { if err != nil { t.Fatal(err) } - var vv []string - err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&vv) + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&selectVarVec) if err != nil { t.Fatal(err) } - assertDeepEqual(t, "variable size element vector", insertVarVec, vv) + assertDeepEqual(t, "variable size element vector", insertVarVec, selectVarVec) +} + +func TestVector_Types(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + timestamp1, _ := time.Parse("2006-01-02", "2000-01-01") + timestamp2, _ := time.Parse("2006-01-02 15:04:05", "2024-01-01 10:31:45") + timestamp3, _ := time.Parse("2006-01-02 15:04:05.000", "2024-05-01 10:31:45.987") + + date1, _ := time.Parse("2006-01-02", "2000-01-01") + date2, _ := time.Parse("2006-01-02", "2022-03-14") + date3, _ := time.Parse("2006-01-02", "2024-12-31") + + time1, _ := time.Parse("15:04:05", "01:00:00") + time2, _ := time.Parse("15:04:05", "15:23:59") + time3, _ := time.Parse("15:04:05.000", "10:31:45.987") + + duration1 := Duration{0, 1, 1920000000000} + duration2 := Duration{1, 1, 1920000000000} + duration3 := Duration{31, 0, 60000000000} + + tests := []struct { + name string + cqlType Type + value interface{} + comparator func(interface{}, interface{}) + }{ + {name: "ascii", cqlType: TypeAscii, value: []string{"a", "1", "Z"}}, + // TODO(lantonia): Test vector of custom types + // TODO(lantonia): Test vector of list, maps, set types + {name: "bigint", cqlType: TypeBigInt, value: []int64{1, 2, 3}}, + {name: "blob", cqlType: TypeBlob, value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, + {name: "boolean", cqlType: TypeBoolean, value: []bool{true, false, true}}, + {name: "counter", cqlType: TypeCounter, value: []int64{5, 6, 7}}, + {name: "decimal", cqlType: TypeDecimal, value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, + {name: "double", cqlType: TypeDouble, value: []float64{0.1, -1.2, 3}}, + {name: "float", cqlType: TypeFloat, value: []float32{0.1, -1.2, 3}}, + {name: "int", cqlType: TypeInt, value: []int32{1, 2, 3}}, + {name: "text", cqlType: TypeText, value: []string{"a", "b", "c"}}, + {name: "timestamp", cqlType: TypeTimestamp, value: []time.Time{timestamp1, timestamp2, timestamp3}}, + {name: "uuid", cqlType: TypeUUID, value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, + {name: "varchar", cqlType: TypeVarchar, value: []string{"abc", "def", "ghi"}}, + {name: "varint", cqlType: TypeVarint, value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, + {name: "timeuuid", cqlType: TypeTimeUUID, value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, + { + name: "inet", + cqlType: TypeInet, + value: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(192, 168, 1, 1), net.IPv4(8, 8, 8, 8)}, + comparator: func(e interface{}, a interface{}) { + expected := e.([]net.IP) + actual := a.([]net.IP) + assertEqual(t, "vector size", len(expected), len(actual)) + for i, _ := range expected { + // TODO(lantoniak): Find a better way to compare IP addresses + assertEqual(t, "vector", expected[i].String(), actual[i].String()) + } + }, + }, + {name: "date", cqlType: TypeDate, value: []time.Time{date1, date2, date3}}, + {name: "time", cqlType: TypeTimestamp, value: []time.Time{time1, time2, time3}}, + {name: "smallint", cqlType: TypeSmallInt, value: []int16{127, 256, -1234}}, + {name: "tinyint", cqlType: TypeTinyInt, value: []int8{127, 9, -123}}, + {name: "duration", cqlType: TypeDuration, value: []Duration{duration1, duration2, duration3}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tableName := fmt.Sprintf("vector_%s", test.name) + err := createTable(session, fmt.Sprintf(`CREATE TABLE IF NOT EXISTS gocql_test.%s(id int primary key, vec vector<%s, 3>);`, tableName, test.cqlType)) + if err != nil { + t.Fatal(err) + } + + err = session.Query(fmt.Sprintf("INSERT INTO %s(id, vec) VALUES(?, ?)", tableName), 1, test.value).Exec() + if err != nil { + t.Fatal(err) + } + + v := reflect.New(reflect.TypeOf(test.value)) + err = session.Query(fmt.Sprintf("SELECT vec FROM %s WHERE id = ?", tableName), 1).Scan(v.Interface()) + if err != nil { + t.Fatal(err) + } + if test.comparator != nil { + test.comparator(test.value, v.Elem().Interface()) + } else { + assertDeepEqual(t, "vector", test.value, v.Elem().Interface()) + } + }) + } } func TestVector_Empty(t *testing.T) { @@ -98,23 +198,23 @@ func TestVector_Empty(t *testing.T) { if err != nil { t.Fatal(err) } - var vf []float32 - err = session.Query("SELECT vec FROM vector_fixed_null WHERE id = ?", 1).Scan(&vf) + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed_null WHERE id = ?", 1).Scan(&selectFixVec) if err != nil { t.Fatal(err) } - assertTrue(t, "fixed size element vector is empty", vf == nil) + assertTrue(t, "fixed size element vector is empty", selectFixVec == nil) err = session.Query("INSERT INTO vector_variable_null(id) VALUES(?)", 1).Exec() if err != nil { t.Fatal(err) } - var vv []string - err = session.Query("SELECT vec FROM vector_variable_null WHERE id = ?", 1).Scan(&vv) + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable_null WHERE id = ?", 1).Scan(&selectVarVec) if err != nil { t.Fatal(err) } - assertTrue(t, "variable size element vector is empty", vv == nil) + assertTrue(t, "variable size element vector is empty", selectVarVec == nil) } func TestVector_MissingDimension(t *testing.T) { @@ -132,4 +232,7 @@ func TestVector_MissingDimension(t *testing.T) { err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0}).Exec() require.Error(t, err, "expected vector with 3 dimensions, received 2") + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0, 1, 3}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 4") } From 20a1b538c3bdf2fc7087b736f1fa1ca6b3f439c2 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 11 Oct 2024 00:18:55 +0200 Subject: [PATCH 04/23] Support complex vectory subtypes --- frame.go | 12 ++++------ helpers.go | 62 +++++++++++++++++++++++++++++++++++-------------- helpers_test.go | 21 +++++++++++++++-- marshal.go | 2 ++ metadata.go | 18 +++++++------- vector_test.go | 59 +++++++++++++++++++++++++++------------------- 6 files changed, 115 insertions(+), 59 deletions(-) diff --git a/frame.go b/frame.go index 4df219cc8..3bbe9359d 100644 --- a/frame.go +++ b/frame.go @@ -933,13 +933,11 @@ func (f *framer) readTypeInfo() TypeInfo { if strings.HasPrefix(simple.custom, VECTOR_TYPE) { spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) spec = spec[1 : len(spec)-1] // remove parenthesis - types := strings.Split(spec, ",") - // TODO(lantoniak): for now we use only simple subtypes - subType := NativeType{ - proto: f.proto, - typ: getApacheCassandraType(strings.TrimSpace(types[0])), - } - dim, _ := strconv.Atoi(strings.TrimSpace(types[1])) + idx := strings.LastIndex(spec, ",") + typeStr := spec[:idx] + dimStr := spec[idx+1:] + subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{}) + dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) vector := VectorType{ NativeType: simple, SubType: subType, diff --git a/helpers.go b/helpers.go index f0806ff33..4e06e403b 100644 --- a/helpers.go +++ b/helpers.go @@ -163,51 +163,53 @@ func getCassandraBaseType(name string) Type { } } -func getCassandraType(name string, logger StdLogger) TypeInfo { +func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(name, "frozen<") { - return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) + return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), protoVer, logger) } else if strings.HasPrefix(name, "set<") { return CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), + NativeType: NativeType{typ: TypeSet, proto: protoVer}, + Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), protoVer, logger), } } else if strings.HasPrefix(name, "list<") { return CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), + NativeType: NativeType{typ: TypeList, proto: protoVer}, + Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), protoVer, logger), } } else if strings.HasPrefix(name, "map<") { names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) return NativeType{ - typ: TypeCustom, + proto: protoVer, + typ: TypeCustom, } } return CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: getCassandraType(names[0], logger), - Elem: getCassandraType(names[1], logger), + NativeType: NativeType{typ: TypeMap, proto: protoVer}, + Key: getCassandraType(names[0], protoVer, logger), + Elem: getCassandraType(names[1], protoVer, logger), } } else if strings.HasPrefix(name, "tuple<") { names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) types := make([]TypeInfo, len(names)) for i, name := range names { - types[i] = getCassandraType(name, logger) + types[i] = getCassandraType(name, protoVer, logger) } return TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: types, } } else if strings.HasPrefix(name, "vector<") { names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) - subType := getCassandraType(names[0], logger) + subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) return VectorType{ NativeType: NativeType{ + proto: protoVer, typ: TypeCustom, custom: VECTOR_TYPE, }, @@ -216,7 +218,8 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { } } else { return NativeType{ - typ: getCassandraBaseType(name), + proto: protoVer, + typ: getCassandraBaseType(name), } } } @@ -250,19 +253,36 @@ func splitCompositeTypes(name string) []string { } func apacheToCassandraType(t string) string { - t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) t = strings.Replace(t, "(", "<", -1) t = strings.Replace(t, ")", ">", -1) types := strings.FieldsFunc(t, func(r rune) bool { return r == '<' || r == '>' || r == ',' }) - for _, typ := range types { - t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) + for _, class := range types { + class = strings.TrimSpace(class) + if !isDigitsOnly(class) { + // vector types include dimension (digits) as second type parameter + act := getApacheCassandraType(class) + val := act.String() + if act == TypeCustom { + val = getApacheCassandraCustomSubType(class) + } + t = strings.Replace(t, class, val, -1) + } } // This is done so it exactly matches what Cassandra returns return strings.Replace(t, ",", ", ", -1) } +func isDigitsOnly(s string) bool { + for _, c := range s { + if c < '0' || c > '9' { + return false + } + } + return true +} + func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": @@ -318,6 +338,14 @@ func getApacheCassandraType(class string) Type { } } +func getApacheCassandraCustomSubType(class string) string { + switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { + case "VectorType": + return "vector" + } + return "custom" +} + func (r *RowData) rowMap(m map[string]interface{}) { for i, column := range r.Columns { val := dereference(r.Values[i]) diff --git a/helpers_test.go b/helpers_test.go index 4622da361..61b369d9a 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -30,7 +30,7 @@ import ( ) func TestGetCassandraType_Set(t *testing.T) { - typ := getCassandraType("set", &defaultLogger{}) + typ := getCassandraType("set", 4, &defaultLogger{}) set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) @@ -233,11 +233,28 @@ func TestGetCassandraType(t *testing.T) { Dimensions: 3, }, }, + { + "vector, 5>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + Dimensions: 5, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { - got := getCassandraType(test.input, &defaultLogger{}) + got := getCassandraType(test.input, 0, &defaultLogger{}) // TODO(zariel): define an equal method on the types? if !reflect.DeepEqual(got, test.exp) { diff --git a/marshal.go b/marshal.go index 8abb9e820..16f7a1c41 100644 --- a/marshal.go +++ b/marshal.go @@ -1827,6 +1827,8 @@ func isVectorVariableLengthType(elemType Type) bool { return true case TypeInet: return true + case TypeList, TypeSet, TypeMap: + return true } return false } diff --git a/metadata.go b/metadata.go index ea962d553..4cdf7d500 100644 --- a/metadata.go +++ b/metadata.go @@ -383,7 +383,7 @@ func compileMetadata( col := &columns[i] // decode the validator for TypeInfo and order if col.ClusteringOrder != "" { // Cassandra 3.x+ - col.Type = getCassandraType(col.Validator, logger) + col.Type = getCassandraType(col.Validator, byte(protoVersion), logger) col.Order = ASC if col.ClusteringOrder == "desc" { col.Order = DESC @@ -947,11 +947,11 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, return columns, nil } -func getTypeInfo(t string, logger StdLogger) TypeInfo { +func getTypeInfo(t string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(t, apacheCassandraTypePrefix) { t = apacheToCassandraType(t) } - return getCassandraType(t, logger) + return getCassandraType(t, protoVer, logger) } func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) { @@ -987,7 +987,7 @@ func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, er } view.FieldTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - view.FieldTypes[i] = getTypeInfo(argumentType, session.logger) + view.FieldTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } views = append(views, view) } @@ -1108,10 +1108,10 @@ func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMeta if err != nil { return nil, err } - function.ReturnType = getTypeInfo(returnType, session.logger) + function.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) function.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - function.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger) + function.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } functions = append(functions, function) } @@ -1165,11 +1165,11 @@ func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMe if err != nil { return nil, err } - aggregate.ReturnType = getTypeInfo(returnType, session.logger) - aggregate.StateType = getTypeInfo(stateType, session.logger) + aggregate.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) + aggregate.StateType = getTypeInfo(stateType, byte(session.cfg.ProtoVersion), session.logger) aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger) + aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } aggregates = append(aggregates, aggregate) } diff --git a/vector_test.go b/vector_test.go index 171cfc9ac..5e7aa9545 100644 --- a/vector_test.go +++ b/vector_test.go @@ -105,32 +105,38 @@ func TestVector_Types(t *testing.T) { duration2 := Duration{1, 1, 1920000000000} duration3 := Duration{31, 0, 60000000000} + map1 := make(map[string]int) + map1["a"] = 1 + map1["b"] = 2 + map1["c"] = 3 + map2 := make(map[string]int) + map2["abc"] = 123 + map3 := make(map[string]int) + tests := []struct { name string - cqlType Type + cqlType string value interface{} comparator func(interface{}, interface{}) }{ - {name: "ascii", cqlType: TypeAscii, value: []string{"a", "1", "Z"}}, - // TODO(lantonia): Test vector of custom types - // TODO(lantonia): Test vector of list, maps, set types - {name: "bigint", cqlType: TypeBigInt, value: []int64{1, 2, 3}}, - {name: "blob", cqlType: TypeBlob, value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, - {name: "boolean", cqlType: TypeBoolean, value: []bool{true, false, true}}, - {name: "counter", cqlType: TypeCounter, value: []int64{5, 6, 7}}, - {name: "decimal", cqlType: TypeDecimal, value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, - {name: "double", cqlType: TypeDouble, value: []float64{0.1, -1.2, 3}}, - {name: "float", cqlType: TypeFloat, value: []float32{0.1, -1.2, 3}}, - {name: "int", cqlType: TypeInt, value: []int32{1, 2, 3}}, - {name: "text", cqlType: TypeText, value: []string{"a", "b", "c"}}, - {name: "timestamp", cqlType: TypeTimestamp, value: []time.Time{timestamp1, timestamp2, timestamp3}}, - {name: "uuid", cqlType: TypeUUID, value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, - {name: "varchar", cqlType: TypeVarchar, value: []string{"abc", "def", "ghi"}}, - {name: "varint", cqlType: TypeVarint, value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, - {name: "timeuuid", cqlType: TypeTimeUUID, value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, + {name: "ascii", cqlType: TypeAscii.String(), value: []string{"a", "1", "Z"}}, + {name: "bigint", cqlType: TypeBigInt.String(), value: []int64{1, 2, 3}}, + {name: "blob", cqlType: TypeBlob.String(), value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, + {name: "boolean", cqlType: TypeBoolean.String(), value: []bool{true, false, true}}, + {name: "counter", cqlType: TypeCounter.String(), value: []int64{5, 6, 7}}, + {name: "decimal", cqlType: TypeDecimal.String(), value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, + {name: "double", cqlType: TypeDouble.String(), value: []float64{0.1, -1.2, 3}}, + {name: "float", cqlType: TypeFloat.String(), value: []float32{0.1, -1.2, 3}}, + {name: "int", cqlType: TypeInt.String(), value: []int32{1, 2, 3}}, + {name: "text", cqlType: TypeText.String(), value: []string{"a", "b", "c"}}, + {name: "timestamp", cqlType: TypeTimestamp.String(), value: []time.Time{timestamp1, timestamp2, timestamp3}}, + {name: "uuid", cqlType: TypeUUID.String(), value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, + {name: "varchar", cqlType: TypeVarchar.String(), value: []string{"abc", "def", "ghi"}}, + {name: "varint", cqlType: TypeVarint.String(), value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, + {name: "timeuuid", cqlType: TypeTimeUUID.String(), value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, { name: "inet", - cqlType: TypeInet, + cqlType: TypeInet.String(), value: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(192, 168, 1, 1), net.IPv4(8, 8, 8, 8)}, comparator: func(e interface{}, a interface{}) { expected := e.([]net.IP) @@ -142,11 +148,16 @@ func TestVector_Types(t *testing.T) { } }, }, - {name: "date", cqlType: TypeDate, value: []time.Time{date1, date2, date3}}, - {name: "time", cqlType: TypeTimestamp, value: []time.Time{time1, time2, time3}}, - {name: "smallint", cqlType: TypeSmallInt, value: []int16{127, 256, -1234}}, - {name: "tinyint", cqlType: TypeTinyInt, value: []int8{127, 9, -123}}, - {name: "duration", cqlType: TypeDuration, value: []Duration{duration1, duration2, duration3}}, + {name: "date", cqlType: TypeDate.String(), value: []time.Time{date1, date2, date3}}, + {name: "time", cqlType: TypeTimestamp.String(), value: []time.Time{time1, time2, time3}}, + {name: "smallint", cqlType: TypeSmallInt.String(), value: []int16{127, 256, -1234}}, + {name: "tinyint", cqlType: TypeTinyInt.String(), value: []int8{127, 9, -123}}, + {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}}, + // TODO(lantonia): Test vector of custom types + {name: "vector_vector_float", cqlType: "vector", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}}, + {name: "vector_set_text", cqlType: "set", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}}, + {name: "vector_list_int", cqlType: "list", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}}, + {name: "vector_map_text_int", cqlType: "map", value: []map[string]int{map1, map2, map3}}, } for _, test := range tests { From 3256fa36fc378e27917b71c290944cfef706f556 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 11 Oct 2024 06:42:26 +0200 Subject: [PATCH 05/23] Fix comparison in tests --- cassandra_test.go | 52 +++++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..5e51929cf 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -2245,15 +2245,16 @@ func TestViewMetadata(t *testing.T) { textType = TypeVarchar } + protoVer := byte(session.cfg.ProtoVersion) expectedView := ViewMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", "nationality", "weight", "height"}, FieldTypes: []TypeInfo{ - NativeType{typ: TypeTimestamp}, - NativeType{typ: textType}, - NativeType{typ: textType}, - NativeType{typ: textType}, + NativeType{typ: TypeTimestamp, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, }, } @@ -2351,18 +2352,19 @@ func TestAggregateMetadata(t *testing.T) { t.Fatal("expected two aggregates") } + protoVer := byte(session.cfg.ProtoVersion) expectedAggregrate := AggregateMetadata{ Keyspace: "gocql_test", Name: "average", - ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}}, + ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt, proto: protoVer}}, InitCond: "(0, 0)", - ReturnType: NativeType{typ: TypeDouble}, + ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, StateType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, stateFunc: "avgstate", @@ -2401,28 +2403,29 @@ func TestFunctionMetadata(t *testing.T) { avgState := functions[1] avgFinal := functions[0] + protoVer := byte(session.cfg.ProtoVersion) avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;" expectedAvgState := FunctionMetadata{ Keyspace: "gocql_test", Name: "avgstate", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, - NativeType{typ: TypeInt}, + NativeType{typ: TypeInt, proto: protoVer}, }, ArgumentNames: []string{"state", "val"}, ReturnType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, CalledOnNullInput: true, @@ -2439,16 +2442,16 @@ func TestFunctionMetadata(t *testing.T) { Name: "avgfinal", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, }, ArgumentNames: []string{"state"}, - ReturnType: NativeType{typ: TypeDouble}, + ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, CalledOnNullInput: true, Language: "java", Body: finalStateBody, @@ -2557,15 +2560,16 @@ func TestKeyspaceMetadata(t *testing.T) { if flagCassVersion.Before(3, 0, 0) { textType = TypeVarchar } + protoVer := byte(session.cfg.ProtoVersion) expectedType := UserTypeMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", "nationality", "weight", "height"}, FieldTypes: []TypeInfo{ - NativeType{typ: TypeTimestamp}, - NativeType{typ: textType}, - NativeType{typ: textType}, - NativeType{typ: textType}, + NativeType{typ: TypeTimestamp, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, }, } if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) { From 6304d8f93d05391d8e85f93964b13cbe12ddcce0 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 11 Oct 2024 10:24:17 +0200 Subject: [PATCH 06/23] Support vector of UDT types --- helpers.go | 37 +++++++++++++++++++++++++++++++++- marshal.go | 2 +- vector_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/helpers.go b/helpers.go index 4e06e403b..2ad923f3a 100644 --- a/helpers.go +++ b/helpers.go @@ -25,6 +25,7 @@ package gocql import ( + "encoding/hex" "fmt" "math/big" "net" @@ -202,6 +203,26 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: types, } + } else if strings.HasPrefix(name, "udt<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "udt<")) + fields := make([]UDTField, len(names)-2) + + for i := 2; i < len(names); i++ { + spec := strings.Split(names[i], ":") + fieldName, _ := hex.DecodeString(spec[0]) + fields[i-2] = UDTField{ + Name: string(fieldName), + Type: getTypeInfo(spec[1], protoVer, logger), + } + } + + udtName, _ := hex.DecodeString(names[1]) + return UDTTypeInfo{ + NativeType: NativeType{typ: TypeUDT, proto: protoVer}, + KeySpace: names[0], + Name: string(udtName), + Elements: fields, + } } else if strings.HasPrefix(name, "vector<") { names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) @@ -258,13 +279,25 @@ func apacheToCassandraType(t string) string { types := strings.FieldsFunc(t, func(r rune) bool { return r == '<' || r == '>' || r == ',' }) + skip := 0 for _, class := range types { class = strings.TrimSpace(class) if !isDigitsOnly(class) { // vector types include dimension (digits) as second type parameter + // UDT fields are represented in format {field id}:{class}, example 66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type + if skip > 0 { + skip -= 1 + continue + } + idx := strings.Index(class, ":") + class = class[idx+1:] act := getApacheCassandraType(class) val := act.String() - if act == TypeCustom { + switch act { + case TypeUDT: + val = "udt" + skip = 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type + case TypeCustom: val = getApacheCassandraCustomSubType(class) } t = strings.Replace(t, class, val, -1) @@ -333,6 +366,8 @@ func getApacheCassandraType(class string) Type { return TypeDuration case "SimpleDateType": return TypeDate + case "UserType": + return TypeUDT default: return TypeCustom } diff --git a/marshal.go b/marshal.go index 16f7a1c41..60b287085 100644 --- a/marshal.go +++ b/marshal.go @@ -1827,7 +1827,7 @@ func isVectorVariableLengthType(elemType Type) bool { return true case TypeInet: return true - case TypeList, TypeSet, TypeMap: + case TypeList, TypeSet, TypeMap, TypeUDT: return true } return false diff --git a/vector_test.go b/vector_test.go index 5e7aa9545..12df6fd68 100644 --- a/vector_test.go +++ b/vector_test.go @@ -37,6 +37,16 @@ import ( "time" ) +type person struct { + FirstName string `cql:"first_name"` + LastName string `cql:"last_name"` + Age int `cql:"age"` +} + +func (p person) String() string { + return fmt.Sprintf("Person{firstName: %s, lastName: %s, Age: %d}", p.FirstName, p.LastName, p.Age) +} + func TestVector_Marshaler(t *testing.T) { session := createSession(t) defer session.Close() @@ -187,6 +197,50 @@ func TestVector_Types(t *testing.T) { } } +func TestVector_MarshalerUDT(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TYPE gocql_test.person( + first_name text, + last_name text, + age int);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE gocql_test.vector_relatives( + id int, + couple vector, + primary key(id) + );`) + if err != nil { + t.Fatal(err) + } + + p1 := person{"Johny", "Bravo", 25} + p2 := person{"Capitan", "Planet", 5} + insVec := []person{p1, p2} + + err = session.Query("INSERT INTO vector_relatives(id, couple) VALUES(?, ?)", 1, insVec).Exec() + if err != nil { + t.Fatal(err) + } + + var selVec []person + + err = session.Query("SELECT couple FROM vector_relatives WHERE id = ?", 1).Scan(&selVec) + if err != nil { + t.Fatal(err) + } + + assertDeepEqual(t, "udt", &insVec, &selVec) +} + func TestVector_Empty(t *testing.T) { session := createSession(t) defer session.Close() From 4d2f0802d61316f482a7880d49842dcc59178cb3 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Fri, 11 Oct 2024 11:18:57 +0200 Subject: [PATCH 07/23] Support vector of UDT types --- metadata_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metadata_test.go b/metadata_test.go index 6e3633ccc..37cf44b3d 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -640,7 +640,7 @@ func TestTypeParser(t *testing.T) { assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)", - assertTypeInfo{Type: TypeCustom, Custom: "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)"}, + assertTypeInfo{Type: TypeUDT, Custom: ""}, ) assertParseNonCompositeType( t, From 907e69ebebbe6debc1647fb56625149d417ce41c Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 15 Oct 2024 18:31:29 +0200 Subject: [PATCH 08/23] Alternative to parse custom string --- frame.go | 24 ++++++++----- marshal.go | 15 +++++--- metadata.go | 89 ++++++++++++++++++++++++++++++++++++++++-------- metadata_test.go | 8 +++-- vector_test.go | 63 ++++++++++++++++++++++++++++++++++ 5 files changed, 169 insertions(+), 30 deletions(-) diff --git a/frame.go b/frame.go index 3bbe9359d..0917cb3b8 100644 --- a/frame.go +++ b/frame.go @@ -931,16 +931,24 @@ func (f *framer) readTypeInfo() TypeInfo { return collection case TypeCustom: if strings.HasPrefix(simple.custom, VECTOR_TYPE) { - spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) - spec = spec[1 : len(spec)-1] // remove parenthesis - idx := strings.LastIndex(spec, ",") - typeStr := spec[:idx] - dimStr := spec[idx+1:] - subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{}) - dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) + // TODO(lantoniak): There are currently two ways of parsing types in the driver. + // a) using getTypeInfo() + // b) using parseType() + // I think we could agree to use getTypeInfo() when parsing binary type definition + // and parseType() would be responsible for parsing "custom" string definition. + //spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) + //spec = spec[1 : len(spec)-1] // remove parenthesis + //idx := strings.LastIndex(spec, ",") + //typeStr := spec[:idx] + //dimStr := spec[idx+1:] + //subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{}) + //dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) + result := parseType(simple.custom, simple.proto, nopLogger{}) + dim, _ := strconv.Atoi(result.types[1].Custom()) vector := VectorType{ NativeType: simple, - SubType: subType, + //SubType: subType, + SubType: result.types[0], Dimensions: dim, } return vector diff --git a/marshal.go b/marshal.go index 60b287085..8f782d7b3 100644 --- a/marshal.go +++ b/marshal.go @@ -1746,7 +1746,7 @@ func marshalVector(info VectorType, value interface{}) ([]byte, error) { if err != nil { return nil, err } - if isVectorVariableLengthType(info.SubType.Type()) { + if isVectorVariableLengthType(info.SubType) { writeUnsignedVInt(buf, uint64(len(item))) } buf.Write(item) @@ -1786,7 +1786,7 @@ func unmarshalVector(info VectorType, data []byte, value interface{}) error { elemSize := len(data) / info.Dimensions for i := 0; i < info.Dimensions; i++ { offset := 0 - if isVectorVariableLengthType(info.SubType.Type()) { + if isVectorVariableLengthType(info.SubType) { m, p, err := readUnsignedVint(data, 0) if err != nil { return err @@ -1815,8 +1815,8 @@ func unmarshalVector(info VectorType, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func isVectorVariableLengthType(elemType Type) bool { - switch elemType { +func isVectorVariableLengthType(elemType TypeInfo) bool { + switch elemType.Type() { case TypeVarchar, TypeAscii, TypeBlob, TypeText: return true case TypeCounter: @@ -1829,6 +1829,13 @@ func isVectorVariableLengthType(elemType Type) bool { return true case TypeList, TypeSet, TypeMap, TypeUDT: return true + case TypeCustom: + switch elemType.(type) { + case VectorType: + vecType := elemType.(VectorType) + return isVectorVariableLengthType(vecType.SubType) + } + return true } return false } diff --git a/metadata.go b/metadata.go index 4cdf7d500..73178813d 100644 --- a/metadata.go +++ b/metadata.go @@ -389,7 +389,7 @@ func compileMetadata( col.Order = DESC } } else { - validatorParsed := parseType(col.Validator, logger) + validatorParsed := parseType(col.Validator, byte(protoVersion), logger) col.Type = validatorParsed.types[0] col.Order = ASC if validatorParsed.reversed[0] { @@ -411,9 +411,9 @@ func compileMetadata( } if protoVersion == protoVersion1 { - compileV1Metadata(tables, logger) + compileV1Metadata(tables, protoVersion, logger) } else { - compileV2Metadata(tables, logger) + compileV2Metadata(tables, protoVersion, logger) } } @@ -422,14 +422,14 @@ func compileMetadata( // column metadata as V2+ (because V1 doesn't support the "type" column in the // system.schema_columns table) so determining PartitionKey and ClusterColumns // is more complex. -func compileV1Metadata(tables []TableMetadata, logger StdLogger) { +func compileV1Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { for i := range tables { table := &tables[i] // decode the key validator - keyValidatorParsed := parseType(table.KeyValidator, logger) + keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger) // decode the comparator - comparatorParsed := parseType(table.Comparator, logger) + comparatorParsed := parseType(table.Comparator, byte(protoVer), logger) // the partition key length is the same as the number of types in the // key validator @@ -515,7 +515,7 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) { alias = table.ValueAlias } // decode the default validator - defaultValidatorParsed := parseType(table.DefaultValidator, logger) + defaultValidatorParsed := parseType(table.DefaultValidator, byte(protoVer), logger) column := &ColumnMetadata{ Keyspace: table.Keyspace, Table: table.Name, @@ -529,7 +529,7 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) { } // The simpler compile case for V2+ protocol -func compileV2Metadata(tables []TableMetadata, logger StdLogger) { +func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { for i := range tables { table := &tables[i] @@ -537,7 +537,7 @@ func compileV2Metadata(tables []TableMetadata, logger StdLogger) { table.ClusteringColumns = make([]*ColumnMetadata, clusteringColumnCount) if table.KeyValidator != "" { - keyValidatorParsed := parseType(table.KeyValidator, logger) + keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger) table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types)) } else { // Cassandra 3.x+ partitionKeyCount := componentColumnCountOfType(table.Columns, ColumnPartitionKey) @@ -1186,6 +1186,7 @@ type typeParser struct { input string index int logger StdLogger + proto byte } // the type definition parser result @@ -1197,8 +1198,8 @@ type typeParserResult struct { } // Parse the type definition used for validator and comparator schema data -func parseType(def string, logger StdLogger) typeParserResult { - parser := &typeParser{input: def, logger: logger} +func parseType(def string, protoVer byte, logger StdLogger) typeParserResult { + parser := &typeParser{input: def, proto: protoVer, logger: logger} return parser.parse() } @@ -1209,6 +1210,7 @@ const ( LIST_TYPE = "org.apache.cassandra.db.marshal.ListType" SET_TYPE = "org.apache.cassandra.db.marshal.SetType" MAP_TYPE = "org.apache.cassandra.db.marshal.MapType" + UDT_TYPE = "org.apache.cassandra.db.marshal.UserType" VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType" ) @@ -1218,6 +1220,7 @@ type typeParserClassNode struct { params []typeParserParamNode // this is the segment of the input string that defined this node input string + proto byte } // represents a class parameter in the type def AST @@ -1237,6 +1240,7 @@ func (t *typeParser) parse() typeParserResult { NativeType{ typ: TypeCustom, custom: t.input, + proto: t.proto, }, }, reversed: []bool{false}, @@ -1292,6 +1296,26 @@ func (t *typeParser) parse() typeParserResult { reversed: reversed, collections: collections, } + } else if strings.HasPrefix(ast.name, VECTOR_TYPE) { + count := len(ast.params) + + types := make([]TypeInfo, count) + reversed := make([]bool, count) + + for i, param := range ast.params[:count] { + class := param.class + reversed[i] = strings.HasPrefix(class.name, REVERSED_TYPE) + if reversed[i] { + class = class.params[0].class + } + types[i] = class.asTypeInfo() + } + + return typeParserResult{ + isComposite: true, + types: types, + reversed: reversed, + } } else { // not composite, so one type class := *ast @@ -1314,7 +1338,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[0].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeList, + typ: TypeList, + proto: class.proto, }, Elem: elem, } @@ -1323,7 +1348,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[0].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeSet, + typ: TypeSet, + proto: class.proto, }, Elem: elem, } @@ -1333,15 +1359,47 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[1].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeMap, + typ: TypeMap, + proto: class.proto, }, Key: key, Elem: elem, } } + if strings.HasPrefix(class.name, UDT_TYPE) { + udtName, _ := hex.DecodeString(class.params[1].class.name) + fields := make([]UDTField, len(class.params)-2) + for i := 2; i < len(class.params); i++ { + fieldName, _ := hex.DecodeString(*class.params[i].name) + fields[i-2] = UDTField{ + Name: string(fieldName), + Type: class.params[i].class.asTypeInfo(), + } + } + return UDTTypeInfo{ + NativeType: NativeType{ + typ: TypeUDT, + proto: class.proto, + }, + KeySpace: class.params[0].class.name, + Name: string(udtName), + Elements: fields, + } + } + if strings.HasPrefix(class.name, VECTOR_TYPE) { + dim, _ := strconv.Atoi(class.params[1].class.name) + return VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + proto: class.proto, + }, + SubType: class.params[0].class.asTypeInfo(), + Dimensions: dim, + } + } // must be a simple type or custom type - info := NativeType{typ: getApacheCassandraType(class.name)} + info := NativeType{typ: getApacheCassandraType(class.name), proto: class.proto} if info.typ == TypeCustom { // add the entire class definition info.custom = class.input @@ -1371,6 +1429,7 @@ func (t *typeParser) parseClassNode() (node *typeParserClassNode, ok bool) { name: name, params: params, input: t.input[startIndex:endIndex], + proto: t.proto, } return node, true } diff --git a/metadata_test.go b/metadata_test.go index 37cf44b3d..b9bd9f74a 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -636,12 +636,14 @@ func TestTypeParser(t *testing.T) { }, ) - // custom + // udt assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)", assertTypeInfo{Type: TypeUDT, Custom: ""}, ) + + // custom assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)", @@ -700,7 +702,7 @@ func assertParseNonCompositeType( ) { log := &defaultLogger{} - result := parseType(def, log) + result := parseType(def, 4, log) if len(result.reversed) != 1 { t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed)) } @@ -731,7 +733,7 @@ func assertParseCompositeType( ) { log := &defaultLogger{} - result := parseType(def, log) + result := parseType(def, 4, log) if len(result.reversed) != len(typesExpected) { t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed)) } diff --git a/vector_test.go b/vector_test.go index 12df6fd68..1206d27d0 100644 --- a/vector_test.go +++ b/vector_test.go @@ -165,6 +165,11 @@ func TestVector_Types(t *testing.T) { {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}}, // TODO(lantonia): Test vector of custom types {name: "vector_vector_float", cqlType: "vector", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}}, + {name: "vector_vector_set_float", cqlType: "vector, 5>", value: [][][]float32{ + {{1, 2}, {2, -1}, {3}, {0}, {-1.3}}, + {{2, 3}, {2, -1}, {3}, {0}, {-1.3}}, + {{1, 1000.0}, {0}, {}, {12, 14, 15, 16}, {-1.3}}, + }}, {name: "vector_set_text", cqlType: "set", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}}, {name: "vector_list_int", cqlType: "list", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}}, {name: "vector_map_text_int", cqlType: "map", value: []map[string]int{map1, map2, map3}}, @@ -301,3 +306,61 @@ func TestVector_MissingDimension(t *testing.T) { err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0, 1, 3}).Exec() require.Error(t, err, "expected vector with 3 dimensions, received 4") } + +func TestVector_SubTypeParsing(t *testing.T) { + tests := []struct { + name string + custom string + expected TypeInfo + }{ + {name: "text", custom: "org.apache.cassandra.db.marshal.UTF8Type", expected: NativeType{typ: TypeVarchar}}, + {name: "set_int", custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)", expected: CollectionType{NativeType{typ: TypeSet}, nil, NativeType{typ: TypeInt}}}, + { + name: "udt", + custom: "org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type)", + expected: UDTTypeInfo{ + NativeType{typ: TypeUDT}, + "gocql_test", + "person", + []UDTField{ + UDTField{"first_name", NativeType{typ: TypeVarchar}}, + UDTField{"last_name", NativeType{typ: TypeVarchar}}, + UDTField{"age", NativeType{typ: TypeInt}}, + }, + }, + }, + { + name: "vector_vector_inet", + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", + expected: VectorType{ + NativeType{typ: TypeCustom}, + VectorType{ + NativeType{typ: TypeCustom}, + NativeType{typ: TypeInet}, + 2, + }, + 3, + }, + }, + { + name: "map_int_vector_text", + custom: "org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 10))", + expected: CollectionType{ + NativeType{typ: TypeMap}, + NativeType{typ: TypeInt}, + VectorType{ + NativeType{typ: TypeCustom}, + NativeType{typ: TypeVarchar}, + 10, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subType := parseType(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom), 0, &defaultLogger{}) + assertDeepEqual(t, "vector", test.expected, subType.types[0]) + }) + } +} From a4d8a5b46275c2e3a5c506afceb9eda1e6dfc7ac Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 07:25:28 +0200 Subject: [PATCH 09/23] Support vector of tuples --- marshal.go | 2 +- metadata.go | 14 ++++++++++++++ vector_test.go | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/marshal.go b/marshal.go index 8f782d7b3..214752e6c 100644 --- a/marshal.go +++ b/marshal.go @@ -1827,7 +1827,7 @@ func isVectorVariableLengthType(elemType TypeInfo) bool { return true case TypeInet: return true - case TypeList, TypeSet, TypeMap, TypeUDT: + case TypeList, TypeSet, TypeMap, TypeUDT, TypeTuple: return true case TypeCustom: switch elemType.(type) { diff --git a/metadata.go b/metadata.go index 73178813d..6279b071f 100644 --- a/metadata.go +++ b/metadata.go @@ -1211,6 +1211,7 @@ const ( SET_TYPE = "org.apache.cassandra.db.marshal.SetType" MAP_TYPE = "org.apache.cassandra.db.marshal.MapType" UDT_TYPE = "org.apache.cassandra.db.marshal.UserType" + TUPLE_TYPE = "org.apache.cassandra.db.marshal.TupleType" VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType" ) @@ -1386,6 +1387,19 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { Elements: fields, } } + if strings.HasPrefix(class.name, TUPLE_TYPE) { + fields := make([]TypeInfo, len(class.params)) + for i := 0; i < len(class.params); i++ { + fields[i] = class.params[i].class.asTypeInfo() + } + return TupleTypeInfo{ + NativeType: NativeType{ + typ: TypeTuple, + proto: class.proto, + }, + Elems: fields, + } + } if strings.HasPrefix(class.name, VECTOR_TYPE) { dim, _ := strconv.Atoi(class.params[1].class.name) return VectorType{ diff --git a/vector_test.go b/vector_test.go index 1206d27d0..f782878e9 100644 --- a/vector_test.go +++ b/vector_test.go @@ -170,6 +170,8 @@ func TestVector_Types(t *testing.T) { {{2, 3}, {2, -1}, {3}, {0}, {-1.3}}, {{1, 1000.0}, {0}, {}, {12, 14, 15, 16}, {-1.3}}, }}, + {name: "vector_tuple_text_int_float", cqlType: "tuple", value: [][]interface{}{{"a", 1, float32(0.5)}, {"b", 2, float32(-1.2)}, {"c", 3, float32(0)}}}, + {name: "vector_tuple_text_list_text", cqlType: "tuple>", value: [][]interface{}{{"a", []string{"b", "c"}}, {"d", []string{"e", "f", "g"}}, {"h", []string{"i"}}}}, {name: "vector_set_text", cqlType: "set", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}}, {name: "vector_list_int", cqlType: "list", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}}, {name: "vector_map_text_int", cqlType: "map", value: []map[string]int{map1, map2, map3}}, @@ -329,6 +331,18 @@ func TestVector_SubTypeParsing(t *testing.T) { }, }, }, + { + name: "tuple", + custom: "org.apache.cassandra.db.marshal.TupleType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)", + expected: TupleTypeInfo{ + NativeType{typ: TypeTuple}, + []TypeInfo{ + NativeType{typ: TypeVarchar}, + NativeType{typ: TypeInt}, + NativeType{typ: TypeVarchar}, + }, + }, + }, { name: "vector_vector_inet", custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", From 8835ad8b24115a6020be498a757c3b60ab67883f Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 07:28:47 +0200 Subject: [PATCH 10/23] Apply review comments --- marshal.go | 1 - 1 file changed, 1 deletion(-) diff --git a/marshal.go b/marshal.go index 214752e6c..2fb7ccaa7 100644 --- a/marshal.go +++ b/marshal.go @@ -1847,7 +1847,6 @@ func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { return } - numBytes = computeUnsignedVIntSize(v) extraBytes := numBytes - 1 var tmp = make([]byte, numBytes) for i := extraBytes; i >= 0; i-- { From cafc8dc769049c57c08719127c65f03d06fefa41 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 09:12:54 +0200 Subject: [PATCH 11/23] Fix PIP installation --- .github/workflows/main.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d2044a0c0..2473632d9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -52,7 +52,10 @@ jobs: restore-keys: | ${{ runner.os }}-go- - name: Install CCM - run: pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" + run: | + python3 -m venv ~/venv + ~/venv/bin/pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" + echo "PATH=~/venv/bin:$PATH" >> $GITHUB_ENV - name: Start cassandra nodes run: | VERSION=${{ matrix.cassandra_version }} From 3bdc9f52770a07abedc31baab6d503917ed92a84 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 09:28:45 +0200 Subject: [PATCH 12/23] Fix PIP installation --- .github/workflows/main.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2473632d9..c8486c7ca 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -54,6 +54,7 @@ jobs: - name: Install CCM run: | python3 -m venv ~/venv + ~/venv/bin/pip install setuptools ~/venv/bin/pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" echo "PATH=~/venv/bin:$PATH" >> $GITHUB_ENV - name: Start cassandra nodes From 09a56886a5ceefcb641995d151f3c4569f6c8f04 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 12:50:25 +0200 Subject: [PATCH 13/23] Install JDK 11 for C* 4.x --- .github/workflows/main.yml | 40 +++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c8486c7ca..fb6c4dba3 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -51,6 +51,23 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('go.mod') }} restore-keys: | ${{ runner.os }}-go- + - name: Install Java + run: | + curl -s "https://get.sdkman.io" | bash + source "$HOME/.sdkman/bin/sdkman-init.sh" + echo "sdkman_auto_answer=true" >> ~/.sdkman/etc/config + # sdk list java + + sdk install java 11.0.24-zulu + echo "JAVA11_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + + sdk install java 17.0.12-zulu + echo "JAVA17_HOME=$JAVA_HOME_17_X64" >> $GITHUB_ENV + + # by default use JDK 11 + sdk default java 11.0.24-zulu + sdk use java 11.0.24-zulu + echo "JAVA_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV - name: Install CCM run: | python3 -m venv ~/venv @@ -139,8 +156,29 @@ jobs: - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} + - name: Install Java + run: | + curl -s "https://get.sdkman.io" | bash + source "$HOME/.sdkman/bin/sdkman-init.sh" + echo "sdkman_auto_answer=true" >> ~/.sdkman/etc/config + # sdk list java + + sdk install java 11.0.24-zulu + echo "JAVA11_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + + sdk install java 17.0.12-zulu + echo "JAVA17_HOME=$JAVA_HOME_17_X64" >> $GITHUB_ENV + + # by default use JDK 11 + sdk default java 11.0.24-zulu + sdk use java 11.0.24-zulu + echo "JAVA_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV - name: Install CCM - run: pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" + run: | + python3 -m venv ~/venv + ~/venv/bin/pip install setuptools + ~/venv/bin/pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" + echo "PATH=~/venv/bin:$PATH" >> $GITHUB_ENV - name: Start cassandra nodes run: | VERSION=${{ matrix.cassandra_version }} From 533f52e748df95b5c83fcf875792686a98294d52 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 13:05:00 +0200 Subject: [PATCH 14/23] Install JDK 11 for C* 4.x --- .github/workflows/main.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fb6c4dba3..282a22bc8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -68,6 +68,7 @@ jobs: sdk default java 11.0.24-zulu sdk use java 11.0.24-zulu echo "JAVA_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + echo "PATH=$PATH" >> $GITHUB_ENV - name: Install CCM run: | python3 -m venv ~/venv @@ -173,6 +174,7 @@ jobs: sdk default java 11.0.24-zulu sdk use java 11.0.24-zulu echo "JAVA_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + echo "PATH=$PATH" >> $GITHUB_ENV - name: Install CCM run: | python3 -m venv ~/venv From 3952e5429aafa6f5958fbaac8f42716c421045be Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 14:46:38 +0200 Subject: [PATCH 15/23] Trigger Build From dc288e7480c32453b3bbc9cb9e623bf772643783 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 16 Oct 2024 16:31:00 +0200 Subject: [PATCH 16/23] Tune Gossip --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 282a22bc8..cc74c727d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -117,7 +117,7 @@ jobs: ccm create test -v $VERSION -n 3 -d --vnodes --jvm_arg="-Xmx256m -XX:NewSize=100m" ccm updateconf "${conf[@]}" - export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" + export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler -Dcassandra.gossip_settle_min_wait_ms=1000 -Dcassandra.gossip_settle_interval_ms=500 -Dcassandra.gossip_settle_poll_success_required=2" ccm start --wait-for-binary-proto --verbose ccm status @@ -228,7 +228,7 @@ jobs: rm -rf $HOME/.ccm/test/node1/data/system_auth - export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" + export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler -Dcassandra.gossip_settle_min_wait_ms=1000 -Dcassandra.gossip_settle_interval_ms=500 -Dcassandra.gossip_settle_poll_success_required=2" ccm start --wait-for-binary-proto --verbose ccm status From 9205eff50ee1531e7b1c29e527d465ba47b4fab2 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 17 Oct 2024 11:25:24 +0200 Subject: [PATCH 17/23] Fix integration tests Github action --- .github/workflows/main.yml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cc74c727d..afa1e3755 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -74,9 +74,9 @@ jobs: python3 -m venv ~/venv ~/venv/bin/pip install setuptools ~/venv/bin/pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" - echo "PATH=~/venv/bin:$PATH" >> $GITHUB_ENV - name: Start cassandra nodes run: | + source ~/venv/bin/activate VERSION=${{ matrix.cassandra_version }} keypath="$(pwd)/testdata/pki" conf=( @@ -117,7 +117,7 @@ jobs: ccm create test -v $VERSION -n 3 -d --vnodes --jvm_arg="-Xmx256m -XX:NewSize=100m" ccm updateconf "${conf[@]}" - export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler -Dcassandra.gossip_settle_min_wait_ms=1000 -Dcassandra.gossip_settle_interval_ms=500 -Dcassandra.gossip_settle_poll_success_required=2" + export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" ccm start --wait-for-binary-proto --verbose ccm status @@ -129,6 +129,7 @@ jobs: echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV - name: Integration tests run: | + source ~/venv/bin/activate export JVM_EXTRA_OPTS="${{env.JVM_EXTRA_OPTS}}" go test -v -tags "${{ matrix.tags }} gocql_debug" -timeout=5m -race ${{ env.args }} - name: 'Save ccm logs' @@ -180,9 +181,9 @@ jobs: python3 -m venv ~/venv ~/venv/bin/pip install setuptools ~/venv/bin/pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" - echo "PATH=~/venv/bin:$PATH" >> $GITHUB_ENV - name: Start cassandra nodes run: | + source ~/venv/bin/activate VERSION=${{ matrix.cassandra_version }} keypath="$(pwd)/testdata/pki" conf=( @@ -228,7 +229,7 @@ jobs: rm -rf $HOME/.ccm/test/node1/data/system_auth - export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler -Dcassandra.gossip_settle_min_wait_ms=1000 -Dcassandra.gossip_settle_interval_ms=500 -Dcassandra.gossip_settle_poll_success_required=2" + export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" ccm start --wait-for-binary-proto --verbose ccm status @@ -241,5 +242,6 @@ jobs: sleep 30s - name: Integration tests run: | + source ~/venv/bin/activate export JVM_EXTRA_OPTS="${{env.JVM_EXTRA_OPTS}}" go test -v -run=TestAuthentication -tags "${{ matrix.tags }} gocql_debug" -timeout=15s -runauth ${{ env.args }} From 57c99b940d2f2661e82e55b679f7c50bc275581f Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Mon, 28 Oct 2024 13:43:00 +0100 Subject: [PATCH 18/23] Apply review comments --- frame.go | 22 +++++------ helpers.go | 100 +++++++++++++++++++++++------------------------ marshal.go | 6 ++- metadata_test.go | 17 ++++++-- vector_test.go | 20 +++++++--- 5 files changed, 95 insertions(+), 70 deletions(-) diff --git a/frame.go b/frame.go index 0917cb3b8..26af4edbb 100644 --- a/frame.go +++ b/frame.go @@ -936,19 +936,19 @@ func (f *framer) readTypeInfo() TypeInfo { // b) using parseType() // I think we could agree to use getTypeInfo() when parsing binary type definition // and parseType() would be responsible for parsing "custom" string definition. - //spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) - //spec = spec[1 : len(spec)-1] // remove parenthesis - //idx := strings.LastIndex(spec, ",") - //typeStr := spec[:idx] - //dimStr := spec[idx+1:] - //subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{}) - //dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) - result := parseType(simple.custom, simple.proto, nopLogger{}) - dim, _ := strconv.Atoi(result.types[1].Custom()) + spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) + spec = spec[1 : len(spec)-1] // remove parenthesis + idx := strings.LastIndex(spec, ",") + typeStr := spec[:idx] + dimStr := spec[idx+1:] + subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{}) + dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) + //result := parseType(simple.custom, simple.proto, nopLogger{}) + //dim, _ := strconv.Atoi(result.types[1].Custom()) vector := VectorType{ NativeType: simple, - //SubType: subType, - SubType: result.types[0], + SubType: subType, + //SubType: result.types[0], Dimensions: dim, } return vector diff --git a/helpers.go b/helpers.go index 2ad923f3a..a4a1efde2 100644 --- a/helpers.go +++ b/helpers.go @@ -164,30 +164,30 @@ func getCassandraBaseType(name string) Type { } } +// Parses short CQL type representation to internal data structures. +// Mapping of long Java-style type definition into short format is performed in +// apacheToCassandraType function. func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(name, "frozen<") { return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), protoVer, logger) } else if strings.HasPrefix(name, "set<") { return CollectionType{ - NativeType: NativeType{typ: TypeSet, proto: protoVer}, + NativeType: NewNativeType(protoVer, TypeSet), Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), protoVer, logger), } } else if strings.HasPrefix(name, "list<") { return CollectionType{ - NativeType: NativeType{typ: TypeList, proto: protoVer}, + NativeType: NewNativeType(protoVer, TypeList), Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), protoVer, logger), } } else if strings.HasPrefix(name, "map<") { names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NativeType{ - proto: protoVer, - typ: TypeCustom, - } + return NewNativeType(protoVer, TypeCustom) } return CollectionType{ - NativeType: NativeType{typ: TypeMap, proto: protoVer}, + NativeType: NewNativeType(protoVer, TypeMap), Key: getCassandraType(names[0], protoVer, logger), Elem: getCassandraType(names[1], protoVer, logger), } @@ -200,11 +200,29 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { } return TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple, proto: protoVer}, + NativeType: NewNativeType(protoVer, TypeTuple), Elems: types, } - } else if strings.HasPrefix(name, "udt<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "udt<")) + } else if strings.HasPrefix(name, "vector<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) + subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) + dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), + SubType: subType, + Dimensions: dim, + } + } else if strings.Index(name, "<") == -1 { + // basic type + return NativeType{ + proto: protoVer, + typ: getCassandraBaseType(name), + } + } else { + // udt + idx := strings.Index(name, "<") + names := splitCompositeTypes(name[idx+1 : len(name)-1]) fields := make([]UDTField, len(names)-2) for i := 2; i < len(names); i++ { @@ -218,30 +236,11 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { udtName, _ := hex.DecodeString(names[1]) return UDTTypeInfo{ - NativeType: NativeType{typ: TypeUDT, proto: protoVer}, + NativeType: NewNativeType(protoVer, TypeUDT), KeySpace: names[0], Name: string(udtName), Elements: fields, } - } else if strings.HasPrefix(name, "vector<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) - subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) - dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) - - return VectorType{ - NativeType: NativeType{ - proto: protoVer, - typ: TypeCustom, - custom: VECTOR_TYPE, - }, - SubType: subType, - Dimensions: dim, - } - } else { - return NativeType{ - proto: protoVer, - typ: getCassandraBaseType(name), - } } } @@ -273,35 +272,34 @@ func splitCompositeTypes(name string) []string { return parts } +// Convert long Java style type definition into the short CQL type names. func apacheToCassandraType(t string) string { t = strings.Replace(t, "(", "<", -1) t = strings.Replace(t, ")", ">", -1) types := strings.FieldsFunc(t, func(r rune) bool { return r == '<' || r == '>' || r == ',' }) - skip := 0 - for _, class := range types { - class = strings.TrimSpace(class) - if !isDigitsOnly(class) { - // vector types include dimension (digits) as second type parameter - // UDT fields are represented in format {field id}:{class}, example 66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type - if skip > 0 { - skip -= 1 - continue - } - idx := strings.Index(class, ":") - class = class[idx+1:] - act := getApacheCassandraType(class) - val := act.String() - switch act { - case TypeUDT: - val = "udt" - skip = 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type - case TypeCustom: + for i := 0; i < len(types); i++ { + class := strings.TrimSpace(types[i]) + // UDT fields are represented in format {field id}:{class}, example 66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type + // Do not override hex encoded field names + idx := strings.Index(class, ":") + class = class[idx+1:] + act := getApacheCassandraType(class) + val := act.String() + switch act { + case TypeUDT: + i += 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type + case TypeCustom: + if isDigitsOnly(class) { + // vector types include dimension (digits) as second type parameter + // getApacheCassandraType() returns "custom" by default, but we need to leave digits intact + val = class + } else { val = getApacheCassandraCustomSubType(class) } - t = strings.Replace(t, class, val, -1) } + t = strings.Replace(t, class, val, -1) } // This is done so it exactly matches what Cassandra returns return strings.Replace(t, ",", ", ", -1) @@ -373,6 +371,8 @@ func getApacheCassandraType(class string) Type { } } +// Dedicated function parsing known special subtypes of CQL custom type. +// Currently, only vectors are implemented as special custom subtype. func getApacheCassandraCustomSubType(class string) string { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "VectorType": diff --git a/marshal.go b/marshal.go index 2fb7ccaa7..cd8c7b32e 100644 --- a/marshal.go +++ b/marshal.go @@ -2649,7 +2649,11 @@ type NativeType struct { custom string // only used for TypeCustom } -func NewNativeType(proto byte, typ Type, custom string) NativeType { +func NewNativeType(proto byte, typ Type) NativeType { + return NativeType{proto, typ, ""} +} + +func NewCustomType(proto byte, typ Type, custom string) NativeType { return NativeType{proto, typ, custom} } diff --git a/metadata_test.go b/metadata_test.go index b9bd9f74a..78d94dd21 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -643,6 +643,17 @@ func TestTypeParser(t *testing.T) { assertTypeInfo{Type: TypeUDT, Custom: ""}, ) + // vector + assertParseCompositeType( + t, + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)", + []assertTypeInfo{ + {Type: TypeFloat}, + {Type: TypeCustom, Custom: "3"}, + }, + nil, + ) + // custom assertParseNonCompositeType( t, @@ -702,7 +713,7 @@ func assertParseNonCompositeType( ) { log := &defaultLogger{} - result := parseType(def, 4, log) + result := parseType(def, protoVersion4, log) if len(result.reversed) != 1 { t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed)) } @@ -733,7 +744,7 @@ func assertParseCompositeType( ) { log := &defaultLogger{} - result := parseType(def, 4, log) + result := parseType(def, protoVersion4, log) if len(result.reversed) != len(typesExpected) { t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed)) } @@ -749,7 +760,7 @@ func assertParseCompositeType( if !result.isComposite { t.Errorf("%s: Expected composite", def) } - if result.collections == nil { + if result.collections == nil && collectionsExpected != nil { t.Errorf("%s: Expected non-nil collections: %v", def, result.collections) } diff --git a/vector_test.go b/vector_test.go index f782878e9..1e8dbdc21 100644 --- a/vector_test.go +++ b/vector_test.go @@ -347,9 +347,9 @@ func TestVector_SubTypeParsing(t *testing.T) { name: "vector_vector_inet", custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", expected: VectorType{ - NativeType{typ: TypeCustom}, + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, VectorType{ - NativeType{typ: TypeCustom}, + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, NativeType{typ: TypeInet}, 2, }, @@ -363,7 +363,7 @@ func TestVector_SubTypeParsing(t *testing.T) { NativeType{typ: TypeMap}, NativeType{typ: TypeInt}, VectorType{ - NativeType{typ: TypeCustom}, + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, NativeType{typ: TypeVarchar}, 10, }, @@ -373,8 +373,18 @@ func TestVector_SubTypeParsing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - subType := parseType(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom), 0, &defaultLogger{}) - assertDeepEqual(t, "vector", test.expected, subType.types[0]) + f := newFramer(nil, 0) + f.writeShort(0) + f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom)) + parsedType := f.readTypeInfo() + require.IsType(t, parsedType, VectorType{}) + + // test first parsing method + vectorType := parsedType.(VectorType) + assertEqual(t, "dimensions", 2, vectorType.Dimensions) + assertDeepEqual(t, "vector", test.expected, vectorType.SubType) + //subType := parseType(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom), 0, &defaultLogger{}) + //assertDeepEqual(t, "vector", test.expected, subType.types[0]) }) } } From 41fed01147d52946ddbc6a6f5e832abf8046f3ee Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 29 Oct 2024 11:11:44 +0100 Subject: [PATCH 19/23] Apply review comments --- helpers.go | 30 ++++++++++-------------------- vector_test.go | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/helpers.go b/helpers.go index a4a1efde2..ee1e5a32f 100644 --- a/helpers.go +++ b/helpers.go @@ -285,19 +285,18 @@ func apacheToCassandraType(t string) string { // Do not override hex encoded field names idx := strings.Index(class, ":") class = class[idx+1:] - act := getApacheCassandraType(class) - val := act.String() - switch act { - case TypeUDT: - i += 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type - case TypeCustom: - if isDigitsOnly(class) { - // vector types include dimension (digits) as second type parameter - // getApacheCassandraType() returns "custom" by default, but we need to leave digits intact - val = class - } else { + val := "" + if strings.HasPrefix(class, apacheCassandraTypePrefix) { + act := getApacheCassandraType(class) + val = act.String() + switch act { + case TypeUDT: + i += 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type + case TypeCustom: val = getApacheCassandraCustomSubType(class) } + } else { + val = class } t = strings.Replace(t, class, val, -1) } @@ -305,15 +304,6 @@ func apacheToCassandraType(t string) string { return strings.Replace(t, ",", ", ", -1) } -func isDigitsOnly(s string) bool { - for _, c := range s { - if c < '0' || c > '9' { - return false - } - } - return true -} - func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": diff --git a/vector_test.go b/vector_test.go index 1e8dbdc21..edaee09cd 100644 --- a/vector_test.go +++ b/vector_test.go @@ -369,6 +369,23 @@ func TestVector_SubTypeParsing(t *testing.T) { }, }, }, + { + name: "set_map_vector_text_text", + custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 10),org.apache.cassandra.db.marshal.UTF8Type))", + expected: CollectionType{ + NativeType{typ: TypeSet}, + nil, + CollectionType{ + NativeType{typ: TypeMap}, + VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + NativeType{typ: TypeInt}, + 10, + }, + NativeType{typ: TypeVarchar}, + }, + }, + }, } for _, test := range tests { From e21b4967379928c111f7e70f04c005ef17dc0367 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 29 Oct 2024 17:52:15 +0100 Subject: [PATCH 20/23] Apply review comments --- frame.go | 10 +--- helpers.go | 125 +++++++++++++++++++++++++++++++++++------------ metadata.go | 64 ------------------------ metadata_test.go | 11 ----- vector_test.go | 4 -- 5 files changed, 95 insertions(+), 119 deletions(-) diff --git a/frame.go b/frame.go index 26af4edbb..12ae90692 100644 --- a/frame.go +++ b/frame.go @@ -931,24 +931,16 @@ func (f *framer) readTypeInfo() TypeInfo { return collection case TypeCustom: if strings.HasPrefix(simple.custom, VECTOR_TYPE) { - // TODO(lantoniak): There are currently two ways of parsing types in the driver. - // a) using getTypeInfo() - // b) using parseType() - // I think we could agree to use getTypeInfo() when parsing binary type definition - // and parseType() would be responsible for parsing "custom" string definition. spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) spec = spec[1 : len(spec)-1] // remove parenthesis idx := strings.LastIndex(spec, ",") typeStr := spec[:idx] dimStr := spec[idx+1:] - subType := getTypeInfo(strings.TrimSpace(typeStr), f.proto, nopLogger{}) + subType := getCassandraLongType(strings.TrimSpace(typeStr), f.proto, nopLogger{}) dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) - //result := parseType(simple.custom, simple.proto, nopLogger{}) - //dim, _ := strconv.Atoi(result.types[1].Custom()) vector := VectorType{ NativeType: simple, SubType: subType, - //SubType: result.types[0], Dimensions: dim, } return vector diff --git a/helpers.go b/helpers.go index ee1e5a32f..e22f5f126 100644 --- a/helpers.go +++ b/helpers.go @@ -164,6 +164,80 @@ func getCassandraBaseType(name string) Type { } } +// Parse long Java-style type definition to internal data structures. +func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo { + if strings.HasPrefix(name, SET_TYPE) { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeSet), + Elem: getCassandraLongType(strings.TrimPrefix(name[:len(name)-1], SET_TYPE+"("), protoVer, logger), + } + } else if strings.HasPrefix(name, LIST_TYPE) { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeList), + Elem: getCassandraLongType(strings.TrimPrefix(name[:len(name)-1], LIST_TYPE+"("), protoVer, logger), + } + } else if strings.HasPrefix(name, MAP_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], MAP_TYPE+"(")) + if len(names) != 2 { + logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) + return NewNativeType(protoVer, TypeCustom) + } + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeMap), + Key: getCassandraLongType(names[0], protoVer, logger), + Elem: getCassandraLongType(names[1], protoVer, logger), + } + } else if strings.HasPrefix(name, TUPLE_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], TUPLE_TYPE+"(")) + types := make([]TypeInfo, len(names)) + + for i, name := range names { + types[i] = getCassandraLongType(name, protoVer, logger) + } + + return TupleTypeInfo{ + NativeType: NewNativeType(protoVer, TypeTuple), + Elems: types, + } + } else if strings.HasPrefix(name, UDT_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], UDT_TYPE+"(")) + fields := make([]UDTField, len(names)-2) + + for i := 2; i < len(names); i++ { + spec := strings.Split(names[i], ":") + fieldName, _ := hex.DecodeString(spec[0]) + fields[i-2] = UDTField{ + Name: string(fieldName), + Type: getTypeInfo(spec[1], protoVer, logger), + } + } + + udtName, _ := hex.DecodeString(names[1]) + return UDTTypeInfo{ + NativeType: NewNativeType(protoVer, TypeUDT), + KeySpace: names[0], + Name: string(udtName), + Elements: fields, + } + } else if strings.HasPrefix(name, VECTOR_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], VECTOR_TYPE+"(")) + subType := getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger) + dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), + SubType: subType, + Dimensions: dim, + } + } else { + // basic type + return NativeType{ + proto: protoVer, + typ: getApacheCassandraType(name), + } + } +} + // Parses short CQL type representation to internal data structures. // Mapping of long Java-style type definition into short format is performed in // apacheToCassandraType function. @@ -181,7 +255,7 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), protoVer, logger), } } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + names := splitCQLCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) return NewNativeType(protoVer, TypeCustom) @@ -192,7 +266,7 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { Elem: getCassandraType(names[1], protoVer, logger), } } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + names := splitCQLCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) types := make([]TypeInfo, len(names)) for i, name := range names { @@ -204,7 +278,7 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { Elems: types, } } else if strings.HasPrefix(name, "vector<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) + names := splitCQLCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) @@ -213,40 +287,29 @@ func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { SubType: subType, Dimensions: dim, } - } else if strings.Index(name, "<") == -1 { - // basic type + } else { return NativeType{ proto: protoVer, typ: getCassandraBaseType(name), } - } else { - // udt - idx := strings.Index(name, "<") - names := splitCompositeTypes(name[idx+1 : len(name)-1]) - fields := make([]UDTField, len(names)-2) + } +} - for i := 2; i < len(names); i++ { - spec := strings.Split(names[i], ":") - fieldName, _ := hex.DecodeString(spec[0]) - fields[i-2] = UDTField{ - Name: string(fieldName), - Type: getTypeInfo(spec[1], protoVer, logger), - } - } +func splitCQLCompositeTypes(name string) []string { + return splitCompositeTypes(name, '<', '>') +} - udtName, _ := hex.DecodeString(names[1]) - return UDTTypeInfo{ - NativeType: NewNativeType(protoVer, TypeUDT), - KeySpace: names[0], - Name: string(udtName), - Elements: fields, - } - } +func splitJavaCompositeTypes(name string) []string { + return splitCompositeTypes(name, '(', ')') } -func splitCompositeTypes(name string) []string { - if !strings.Contains(name, "<") { - return strings.Split(name, ", ") +func splitCompositeTypes(name string, typeOpen int32, typeClose int32) []string { + if !strings.Contains(name, string(typeOpen)) { + parts := strings.Split(name, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts } var parts []string lessCount := 0 @@ -260,9 +323,9 @@ func splitCompositeTypes(name string) []string { continue } segment += string(char) - if char == '<' { + if char == typeOpen { lessCount++ - } else if char == '>' { + } else if char == typeClose { lessCount-- } } diff --git a/metadata.go b/metadata.go index 6279b071f..cc0f93dc2 100644 --- a/metadata.go +++ b/metadata.go @@ -1297,26 +1297,6 @@ func (t *typeParser) parse() typeParserResult { reversed: reversed, collections: collections, } - } else if strings.HasPrefix(ast.name, VECTOR_TYPE) { - count := len(ast.params) - - types := make([]TypeInfo, count) - reversed := make([]bool, count) - - for i, param := range ast.params[:count] { - class := param.class - reversed[i] = strings.HasPrefix(class.name, REVERSED_TYPE) - if reversed[i] { - class = class.params[0].class - } - types[i] = class.asTypeInfo() - } - - return typeParserResult{ - isComposite: true, - types: types, - reversed: reversed, - } } else { // not composite, so one type class := *ast @@ -1367,50 +1347,6 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { Elem: elem, } } - if strings.HasPrefix(class.name, UDT_TYPE) { - udtName, _ := hex.DecodeString(class.params[1].class.name) - fields := make([]UDTField, len(class.params)-2) - for i := 2; i < len(class.params); i++ { - fieldName, _ := hex.DecodeString(*class.params[i].name) - fields[i-2] = UDTField{ - Name: string(fieldName), - Type: class.params[i].class.asTypeInfo(), - } - } - return UDTTypeInfo{ - NativeType: NativeType{ - typ: TypeUDT, - proto: class.proto, - }, - KeySpace: class.params[0].class.name, - Name: string(udtName), - Elements: fields, - } - } - if strings.HasPrefix(class.name, TUPLE_TYPE) { - fields := make([]TypeInfo, len(class.params)) - for i := 0; i < len(class.params); i++ { - fields[i] = class.params[i].class.asTypeInfo() - } - return TupleTypeInfo{ - NativeType: NativeType{ - typ: TypeTuple, - proto: class.proto, - }, - Elems: fields, - } - } - if strings.HasPrefix(class.name, VECTOR_TYPE) { - dim, _ := strconv.Atoi(class.params[1].class.name) - return VectorType{ - NativeType: NativeType{ - typ: TypeCustom, - proto: class.proto, - }, - SubType: class.params[0].class.asTypeInfo(), - Dimensions: dim, - } - } // must be a simple type or custom type info := NativeType{typ: getApacheCassandraType(class.name), proto: class.proto} diff --git a/metadata_test.go b/metadata_test.go index 78d94dd21..6b3d11982 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -643,17 +643,6 @@ func TestTypeParser(t *testing.T) { assertTypeInfo{Type: TypeUDT, Custom: ""}, ) - // vector - assertParseCompositeType( - t, - "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)", - []assertTypeInfo{ - {Type: TypeFloat}, - {Type: TypeCustom, Custom: "3"}, - }, - nil, - ) - // custom assertParseNonCompositeType( t, diff --git a/vector_test.go b/vector_test.go index edaee09cd..88b4453ad 100644 --- a/vector_test.go +++ b/vector_test.go @@ -395,13 +395,9 @@ func TestVector_SubTypeParsing(t *testing.T) { f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom)) parsedType := f.readTypeInfo() require.IsType(t, parsedType, VectorType{}) - - // test first parsing method vectorType := parsedType.(VectorType) assertEqual(t, "dimensions", 2, vectorType.Dimensions) assertDeepEqual(t, "vector", test.expected, vectorType.SubType) - //subType := parseType(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom), 0, &defaultLogger{}) - //assertDeepEqual(t, "vector", test.expected, subType.types[0]) }) } } From 24daa933506ee5f6804e8ed9e17cf3c1f867b09c Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Tue, 29 Oct 2024 18:04:35 +0100 Subject: [PATCH 21/23] Revert unnecessary changes --- helpers.go | 33 +++------------------------------ helpers_test.go | 2 +- 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/helpers.go b/helpers.go index e22f5f126..e7c1084c8 100644 --- a/helpers.go +++ b/helpers.go @@ -337,31 +337,14 @@ func splitCompositeTypes(name string, typeOpen int32, typeClose int32) []string // Convert long Java style type definition into the short CQL type names. func apacheToCassandraType(t string) string { + t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) t = strings.Replace(t, "(", "<", -1) t = strings.Replace(t, ")", ">", -1) types := strings.FieldsFunc(t, func(r rune) bool { return r == '<' || r == '>' || r == ',' }) - for i := 0; i < len(types); i++ { - class := strings.TrimSpace(types[i]) - // UDT fields are represented in format {field id}:{class}, example 66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type - // Do not override hex encoded field names - idx := strings.Index(class, ":") - class = class[idx+1:] - val := "" - if strings.HasPrefix(class, apacheCassandraTypePrefix) { - act := getApacheCassandraType(class) - val = act.String() - switch act { - case TypeUDT: - i += 2 // skip next two parameters (keyspace and type ID), do not attempt to resolve their type - case TypeCustom: - val = getApacheCassandraCustomSubType(class) - } - } else { - val = class - } - t = strings.Replace(t, class, val, -1) + for _, typ := range types { + t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) } // This is done so it exactly matches what Cassandra returns return strings.Replace(t, ",", ", ", -1) @@ -424,16 +407,6 @@ func getApacheCassandraType(class string) Type { } } -// Dedicated function parsing known special subtypes of CQL custom type. -// Currently, only vectors are implemented as special custom subtype. -func getApacheCassandraCustomSubType(class string) string { - switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { - case "VectorType": - return "vector" - } - return "custom" -} - func (r *RowData) rowMap(m map[string]interface{}) { for i, column := range r.Columns { val := dereference(r.Values[i]) diff --git a/helpers_test.go b/helpers_test.go index 61b369d9a..50d6f2ad7 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -30,7 +30,7 @@ import ( ) func TestGetCassandraType_Set(t *testing.T) { - typ := getCassandraType("set", 4, &defaultLogger{}) + typ := getCassandraType("set", protoVersion4, &defaultLogger{}) set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) From ae2d2e452e1ed2b7ab5b542cc496dd2ca6f0382e Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 30 Oct 2024 17:04:27 +0100 Subject: [PATCH 22/23] Apply review comments --- helpers.go | 31 +++++++++---------------------- metadata.go | 2 +- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/helpers.go b/helpers.go index e7c1084c8..97f96f7de 100644 --- a/helpers.go +++ b/helpers.go @@ -164,7 +164,7 @@ func getCassandraBaseType(name string) Type { } } -// Parse long Java-style type definition to internal data structures. +// Parses long Java-style type definition to internal data structures. func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(name, SET_TYPE) { return CollectionType{ @@ -179,7 +179,7 @@ func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo } else if strings.HasPrefix(name, MAP_TYPE) { names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], MAP_TYPE+"(")) if len(names) != 2 { - logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) + logger.Printf("gocql: error parsing map type, it has %d subelements, expecting 2\n", len(names)) return NewNativeType(protoVer, TypeCustom) } return CollectionType{ @@ -208,7 +208,7 @@ func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo fieldName, _ := hex.DecodeString(spec[0]) fields[i-2] = UDTField{ Name: string(fieldName), - Type: getTypeInfo(spec[1], protoVer, logger), + Type: getCassandraLongType(spec[1], protoVer, logger), } } @@ -222,7 +222,11 @@ func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo } else if strings.HasPrefix(name, VECTOR_TYPE) { names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], VECTOR_TYPE+"(")) subType := getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger) - dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + dim, err := strconv.Atoi(strings.TrimSpace(names[1])) + if err != nil { + logger.Printf("gocql: error parsing vector dimensions: %v\n", err) + return NewNativeType(protoVer, TypeCustom) + } return VectorType{ NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), @@ -238,9 +242,7 @@ func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo } } -// Parses short CQL type representation to internal data structures. -// Mapping of long Java-style type definition into short format is performed in -// apacheToCassandraType function. +// Parses short CQL type representation (e.g. map) to internal data structures. func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(name, "frozen<") { return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), protoVer, logger) @@ -335,21 +337,6 @@ func splitCompositeTypes(name string, typeOpen int32, typeClose int32) []string return parts } -// Convert long Java style type definition into the short CQL type names. -func apacheToCassandraType(t string) string { - t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) - t = strings.Replace(t, "(", "<", -1) - t = strings.Replace(t, ")", ">", -1) - types := strings.FieldsFunc(t, func(r rune) bool { - return r == '<' || r == '>' || r == ',' - }) - for _, typ := range types { - t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) - } - // This is done so it exactly matches what Cassandra returns - return strings.Replace(t, ",", ", ", -1) -} - func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": diff --git a/metadata.go b/metadata.go index cc0f93dc2..2773bae30 100644 --- a/metadata.go +++ b/metadata.go @@ -949,7 +949,7 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, func getTypeInfo(t string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(t, apacheCassandraTypePrefix) { - t = apacheToCassandraType(t) + return getCassandraLongType(t, protoVer, logger) } return getCassandraType(t, protoVer, logger) } From f9e22c7f69a00c8ac83a766cf7f44dba1261e0e9 Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Thu, 31 Oct 2024 11:04:00 +0100 Subject: [PATCH 23/23] Apply review comments --- marshal.go | 2 +- vector_test.go | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/marshal.go b/marshal.go index cd8c7b32e..d2ea2978a 100644 --- a/marshal.go +++ b/marshal.go @@ -1778,7 +1778,7 @@ func unmarshalVector(info VectorType, data []byte, value interface{}) error { } if k == reflect.Array { if rv.Len() != info.Dimensions { - return unmarshalErrorf("unmarshal vector: array with wrong size") + return unmarshalErrorf("unmarshal vector: array of size %d cannot store vector of %d dimensions", rv.Len(), info.Dimensions) } } else { rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions)) diff --git a/vector_test.go b/vector_test.go index 88b4453ad..4e52a8856 100644 --- a/vector_test.go +++ b/vector_test.go @@ -153,8 +153,7 @@ func TestVector_Types(t *testing.T) { actual := a.([]net.IP) assertEqual(t, "vector size", len(expected), len(actual)) for i, _ := range expected { - // TODO(lantoniak): Find a better way to compare IP addresses - assertEqual(t, "vector", expected[i].String(), actual[i].String()) + assertTrue(t, "vector", expected[i].Equal(actual[i])) } }, }, @@ -163,7 +162,6 @@ func TestVector_Types(t *testing.T) { {name: "smallint", cqlType: TypeSmallInt.String(), value: []int16{127, 256, -1234}}, {name: "tinyint", cqlType: TypeTinyInt.String(), value: []int8{127, 9, -123}}, {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}}, - // TODO(lantonia): Test vector of custom types {name: "vector_vector_float", cqlType: "vector", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}}, {name: "vector_vector_set_float", cqlType: "vector, 5>", value: [][][]float32{ {{1, 2}, {2, -1}, {3}, {0}, {-1.3}},