diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d2044a0c0..afa1e3755 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -51,10 +51,32 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('go.mod') }} restore-keys: | ${{ runner.os }}-go- + - name: Install Java + run: | + curl -s "https://get.sdkman.io" | bash + source "$HOME/.sdkman/bin/sdkman-init.sh" + echo "sdkman_auto_answer=true" >> ~/.sdkman/etc/config + # sdk list java + + sdk install java 11.0.24-zulu + echo "JAVA11_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + + sdk install java 17.0.12-zulu + echo "JAVA17_HOME=$JAVA_HOME_17_X64" >> $GITHUB_ENV + + # by default use JDK 11 + sdk default java 11.0.24-zulu + sdk use java 11.0.24-zulu + echo "JAVA_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + echo "PATH=$PATH" >> $GITHUB_ENV - name: Install CCM - run: pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" + run: | + python3 -m venv ~/venv + ~/venv/bin/pip install setuptools + ~/venv/bin/pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" - name: Start cassandra nodes run: | + source ~/venv/bin/activate VERSION=${{ matrix.cassandra_version }} keypath="$(pwd)/testdata/pki" conf=( @@ -107,6 +129,7 @@ jobs: echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV - name: Integration tests run: | + source ~/venv/bin/activate export JVM_EXTRA_OPTS="${{env.JVM_EXTRA_OPTS}}" go test -v -tags "${{ matrix.tags }} gocql_debug" -timeout=5m -race ${{ env.args }} - name: 'Save ccm logs' @@ -135,10 +158,32 @@ jobs: - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} + - name: Install Java + run: | + curl -s "https://get.sdkman.io" | bash + source "$HOME/.sdkman/bin/sdkman-init.sh" + echo "sdkman_auto_answer=true" >> ~/.sdkman/etc/config + # sdk list java + + sdk install java 11.0.24-zulu + echo "JAVA11_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + + sdk install java 17.0.12-zulu + echo "JAVA17_HOME=$JAVA_HOME_17_X64" >> $GITHUB_ENV + + # by default use JDK 11 + sdk default java 11.0.24-zulu + sdk use java 11.0.24-zulu + echo "JAVA_HOME=$JAVA_HOME_11_X64" >> $GITHUB_ENV + echo "PATH=$PATH" >> $GITHUB_ENV - name: Install CCM - run: pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" + run: | + python3 -m venv ~/venv + ~/venv/bin/pip install setuptools + ~/venv/bin/pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" - name: Start cassandra nodes run: | + source ~/venv/bin/activate VERSION=${{ matrix.cassandra_version }} keypath="$(pwd)/testdata/pki" conf=( @@ -197,5 +242,6 @@ jobs: sleep 30s - name: Integration tests run: | + source ~/venv/bin/activate export JVM_EXTRA_OPTS="${{env.JVM_EXTRA_OPTS}}" go test -v -run=TestAuthentication -tags "${{ matrix.tags }} gocql_debug" -timeout=15s -runauth ${{ env.args }} diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..5e51929cf 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -2245,15 +2245,16 @@ func TestViewMetadata(t *testing.T) { textType = TypeVarchar } + protoVer := byte(session.cfg.ProtoVersion) expectedView := ViewMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", "nationality", "weight", "height"}, FieldTypes: []TypeInfo{ - NativeType{typ: TypeTimestamp}, - NativeType{typ: textType}, - NativeType{typ: textType}, - NativeType{typ: textType}, + NativeType{typ: TypeTimestamp, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, }, } @@ -2351,18 +2352,19 @@ func TestAggregateMetadata(t *testing.T) { t.Fatal("expected two aggregates") } + protoVer := byte(session.cfg.ProtoVersion) expectedAggregrate := AggregateMetadata{ Keyspace: "gocql_test", Name: "average", - ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}}, + ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt, proto: protoVer}}, InitCond: "(0, 0)", - ReturnType: NativeType{typ: TypeDouble}, + ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, StateType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, stateFunc: "avgstate", @@ -2401,28 +2403,29 @@ func TestFunctionMetadata(t *testing.T) { avgState := functions[1] avgFinal := functions[0] + protoVer := byte(session.cfg.ProtoVersion) avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;" expectedAvgState := FunctionMetadata{ Keyspace: "gocql_test", Name: "avgstate", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, - NativeType{typ: TypeInt}, + NativeType{typ: TypeInt, proto: protoVer}, }, ArgumentNames: []string{"state", "val"}, ReturnType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, CalledOnNullInput: true, @@ -2439,16 +2442,16 @@ func TestFunctionMetadata(t *testing.T) { Name: "avgfinal", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, }, ArgumentNames: []string{"state"}, - ReturnType: NativeType{typ: TypeDouble}, + ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, CalledOnNullInput: true, Language: "java", Body: finalStateBody, @@ -2557,15 +2560,16 @@ func TestKeyspaceMetadata(t *testing.T) { if flagCassVersion.Before(3, 0, 0) { textType = TypeVarchar } + protoVer := byte(session.cfg.ProtoVersion) expectedType := UserTypeMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", "nationality", "weight", "height"}, FieldTypes: []TypeInfo{ - NativeType{typ: TypeTimestamp}, - NativeType{typ: textType}, - NativeType{typ: textType}, - NativeType{typ: textType}, + NativeType{typ: TypeTimestamp, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, }, } if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) { diff --git a/common_test.go b/common_test.go index a5edb03c6..ae7f83f94 100644 --- a/common_test.go +++ b/common_test.go @@ -28,6 +28,7 @@ import ( "flag" "fmt" "log" + "math/rand" "net" "reflect" "strings" @@ -52,6 +53,10 @@ var ( flagCassVersion cassVersion ) +var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + +const randCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + func init() { flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") @@ -277,6 +282,14 @@ func assertTrue(t *testing.T, description string, value bool) { } } +func randomText(size int) string { + result := make([]byte, size) + for i := range result { + result[i] = randCharset[rand.Intn(len(randCharset))] + } + return string(result) +} + func assertEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if expected != actual { diff --git a/frame.go b/frame.go index d374ae574..12ae90692 100644 --- a/frame.go +++ b/frame.go @@ -32,6 +32,7 @@ import ( "io/ioutil" "net" "runtime" + "strconv" "strings" "time" ) @@ -928,6 +929,22 @@ func (f *framer) readTypeInfo() TypeInfo { collection.Elem = f.readTypeInfo() return collection + case TypeCustom: + if strings.HasPrefix(simple.custom, VECTOR_TYPE) { + spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) + spec = spec[1 : len(spec)-1] // remove parenthesis + idx := strings.LastIndex(spec, ",") + typeStr := spec[:idx] + dimStr := spec[idx+1:] + subType := getCassandraLongType(strings.TrimSpace(typeStr), f.proto, nopLogger{}) + dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) + vector := VectorType{ + NativeType: simple, + SubType: subType, + Dimensions: dim, + } + return vector + } } return simple diff --git a/helpers.go b/helpers.go index f2faee9e0..97f96f7de 100644 --- a/helpers.go +++ b/helpers.go @@ -25,10 +25,12 @@ package gocql import ( + "encoding/hex" "fmt" "math/big" "net" "reflect" + "strconv" "strings" "time" @@ -162,54 +164,154 @@ func getCassandraBaseType(name string) Type { } } -func getCassandraType(name string, logger StdLogger) TypeInfo { +// Parses long Java-style type definition to internal data structures. +func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo { + if strings.HasPrefix(name, SET_TYPE) { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeSet), + Elem: getCassandraLongType(strings.TrimPrefix(name[:len(name)-1], SET_TYPE+"("), protoVer, logger), + } + } else if strings.HasPrefix(name, LIST_TYPE) { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeList), + Elem: getCassandraLongType(strings.TrimPrefix(name[:len(name)-1], LIST_TYPE+"("), protoVer, logger), + } + } else if strings.HasPrefix(name, MAP_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], MAP_TYPE+"(")) + if len(names) != 2 { + logger.Printf("gocql: error parsing map type, it has %d subelements, expecting 2\n", len(names)) + return NewNativeType(protoVer, TypeCustom) + } + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeMap), + Key: getCassandraLongType(names[0], protoVer, logger), + Elem: getCassandraLongType(names[1], protoVer, logger), + } + } else if strings.HasPrefix(name, TUPLE_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], TUPLE_TYPE+"(")) + types := make([]TypeInfo, len(names)) + + for i, name := range names { + types[i] = getCassandraLongType(name, protoVer, logger) + } + + return TupleTypeInfo{ + NativeType: NewNativeType(protoVer, TypeTuple), + Elems: types, + } + } else if strings.HasPrefix(name, UDT_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], UDT_TYPE+"(")) + fields := make([]UDTField, len(names)-2) + + for i := 2; i < len(names); i++ { + spec := strings.Split(names[i], ":") + fieldName, _ := hex.DecodeString(spec[0]) + fields[i-2] = UDTField{ + Name: string(fieldName), + Type: getCassandraLongType(spec[1], protoVer, logger), + } + } + + udtName, _ := hex.DecodeString(names[1]) + return UDTTypeInfo{ + NativeType: NewNativeType(protoVer, TypeUDT), + KeySpace: names[0], + Name: string(udtName), + Elements: fields, + } + } else if strings.HasPrefix(name, VECTOR_TYPE) { + names := splitJavaCompositeTypes(strings.TrimPrefix(name[:len(name)-1], VECTOR_TYPE+"(")) + subType := getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger) + dim, err := strconv.Atoi(strings.TrimSpace(names[1])) + if err != nil { + logger.Printf("gocql: error parsing vector dimensions: %v\n", err) + return NewNativeType(protoVer, TypeCustom) + } + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), + SubType: subType, + Dimensions: dim, + } + } else { + // basic type + return NativeType{ + proto: protoVer, + typ: getApacheCassandraType(name), + } + } +} + +// Parses short CQL type representation (e.g. map) to internal data structures. +func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(name, "frozen<") { - return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) + return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), protoVer, logger) } else if strings.HasPrefix(name, "set<") { return CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), + NativeType: NewNativeType(protoVer, TypeSet), + Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), protoVer, logger), } } else if strings.HasPrefix(name, "list<") { return CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), + NativeType: NewNativeType(protoVer, TypeList), + Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), protoVer, logger), } } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + names := splitCQLCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NativeType{ - typ: TypeCustom, - } + return NewNativeType(protoVer, TypeCustom) } return CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: getCassandraType(names[0], logger), - Elem: getCassandraType(names[1], logger), + NativeType: NewNativeType(protoVer, TypeMap), + Key: getCassandraType(names[0], protoVer, logger), + Elem: getCassandraType(names[1], protoVer, logger), } } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + names := splitCQLCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) types := make([]TypeInfo, len(names)) for i, name := range names { - types[i] = getCassandraType(name, logger) + types[i] = getCassandraType(name, protoVer, logger) } return TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NewNativeType(protoVer, TypeTuple), Elems: types, } + } else if strings.HasPrefix(name, "vector<") { + names := splitCQLCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) + subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) + dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), + SubType: subType, + Dimensions: dim, + } } else { return NativeType{ - typ: getCassandraBaseType(name), + proto: protoVer, + typ: getCassandraBaseType(name), } } } -func splitCompositeTypes(name string) []string { - if !strings.Contains(name, "<") { - return strings.Split(name, ", ") +func splitCQLCompositeTypes(name string) []string { + return splitCompositeTypes(name, '<', '>') +} + +func splitJavaCompositeTypes(name string) []string { + return splitCompositeTypes(name, '(', ')') +} + +func splitCompositeTypes(name string, typeOpen int32, typeClose int32) []string { + if !strings.Contains(name, string(typeOpen)) { + parts := strings.Split(name, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts } var parts []string lessCount := 0 @@ -223,9 +325,9 @@ func splitCompositeTypes(name string) []string { continue } segment += string(char) - if char == '<' { + if char == typeOpen { lessCount++ - } else if char == '>' { + } else if char == typeClose { lessCount-- } } @@ -235,20 +337,6 @@ func splitCompositeTypes(name string) []string { return parts } -func apacheToCassandraType(t string) string { - t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) - t = strings.Replace(t, "(", "<", -1) - t = strings.Replace(t, ")", ">", -1) - types := strings.FieldsFunc(t, func(r rune) bool { - return r == '<' || r == '>' || r == ',' - }) - for _, typ := range types { - t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) - } - // This is done so it exactly matches what Cassandra returns - return strings.Replace(t, ",", ", ", -1) -} - func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": @@ -297,6 +385,10 @@ func getApacheCassandraType(class string) Type { return TypeTuple case "DurationType": return TypeDuration + case "SimpleDateType": + return TypeDate + case "UserType": + return TypeUDT default: return TypeCustom } diff --git a/helpers_test.go b/helpers_test.go index 67922ba5d..50d6f2ad7 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -30,7 +30,7 @@ import ( ) func TestGetCassandraType_Set(t *testing.T) { - typ := getCassandraType("set", &defaultLogger{}) + typ := getCassandraType("set", protoVersion4, &defaultLogger{}) set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) @@ -223,11 +223,38 @@ func TestGetCassandraType(t *testing.T) { Elem: NativeType{typ: TypeDuration}, }, }, + { + "vector", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + }, + { + "vector, 5>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + Dimensions: 5, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { - got := getCassandraType(test.input, &defaultLogger{}) + got := getCassandraType(test.input, 0, &defaultLogger{}) // TODO(zariel): define an equal method on the types? if !reflect.DeepEqual(got, test.exp) { diff --git a/marshal.go b/marshal.go index 4d0adb923..d2ea2978a 100644 --- a/marshal.go +++ b/marshal.go @@ -170,6 +170,11 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return marshalDate(info, value) case TypeDuration: return marshalDuration(info, value) + case TypeCustom: + switch info.(type) { + case VectorType: + return marshalVector(info.(VectorType), value) + } } // detect protocol 2 UDT @@ -274,6 +279,11 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { return unmarshalDate(info, data, value) case TypeDuration: return unmarshalDuration(info, data, value) + case TypeCustom: + switch info.(type) { + case VectorType: + return unmarshalVector(info.(VectorType), data, value) + } } // detect protocol 2 UDT @@ -1709,6 +1719,169 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } +func marshalVector(info VectorType, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + if n != info.Dimensions { + return nil, marshalErrorf("expected vector with %d dimensions, received %d", info.Dimensions, n) + } + + for i := 0; i < n; i++ { + item, err := Marshal(info.SubType, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + if isVectorVariableLengthType(info.SubType) { + writeUnsignedVInt(buf, uint64(len(item))) + } + buf.Write(item) + } + return buf.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalVector(info VectorType, data []byte, value interface{}) error { + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return unmarshalErrorf("unmarshal vector: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + if k == reflect.Array { + if rv.Len() != info.Dimensions { + return unmarshalErrorf("unmarshal vector: array of size %d cannot store vector of %d dimensions", rv.Len(), info.Dimensions) + } + } else { + rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions)) + } + elemSize := len(data) / info.Dimensions + for i := 0; i < info.Dimensions; i++ { + offset := 0 + if isVectorVariableLengthType(info.SubType) { + m, p, err := readUnsignedVint(data, 0) + if err != nil { + return err + } + elemSize = int(m) + offset = p + } + if offset > 0 { + data = data[offset:] + } + var unmarshalData []byte + if elemSize >= 0 { + if len(data) < elemSize { + return unmarshalErrorf("unmarshal vector: unexpected eof") + } + unmarshalData = data[:elemSize] + data = data[elemSize:] + } + err := Unmarshal(info.SubType, unmarshalData, rv.Index(i).Addr().Interface()) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, unmarshalData, err.Error()) + } + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func isVectorVariableLengthType(elemType TypeInfo) bool { + switch elemType.Type() { + case TypeVarchar, TypeAscii, TypeBlob, TypeText: + return true + case TypeCounter: + return true + case TypeDuration, TypeDate, TypeTime: + return true + case TypeDecimal, TypeSmallInt, TypeTinyInt, TypeVarint: + return true + case TypeInet: + return true + case TypeList, TypeSet, TypeMap, TypeUDT, TypeTuple: + return true + case TypeCustom: + switch elemType.(type) { + case VectorType: + vecType := elemType.(VectorType) + return isVectorVariableLengthType(vecType.SubType) + } + return true + } + return false +} + +func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { + numBytes := computeUnsignedVIntSize(v) + if numBytes <= 1 { + buf.WriteByte(byte(v)) + return + } + + extraBytes := numBytes - 1 + var tmp = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + tmp[i] = byte(v) + v >>= 8 + } + tmp[0] |= byte(^(0xff >> uint(extraBytes))) + buf.Write(tmp) +} + +func readUnsignedVint(data []byte, start int) (uint64, int, error) { + if len(data) <= start { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[start] + if firstByte&0x80 == 0 { + return uint64(firstByte), start + 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < start+numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) + } + for i := start; i < start+numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return ret, start + numBytes + 1, nil +} + +func computeUnsignedVIntSize(v uint64) int { + lead0 := bits.LeadingZeros64(v) + return (639 - lead0*9) >> 6 +} + func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { mapInfo, ok := info.(CollectionType) if !ok { @@ -2476,7 +2649,11 @@ type NativeType struct { custom string // only used for TypeCustom } -func NewNativeType(proto byte, typ Type, custom string) NativeType { +func NewNativeType(proto byte, typ Type) NativeType { + return NativeType{proto, typ, ""} +} + +func NewCustomType(proto byte, typ Type, custom string) NativeType { return NativeType{proto, typ, custom} } @@ -2523,6 +2700,12 @@ type CollectionType struct { Elem TypeInfo // only used for TypeMap, TypeList and TypeSet } +type VectorType struct { + NativeType + SubType TypeInfo + Dimensions int +} + func (t CollectionType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { diff --git a/metadata.go b/metadata.go index 6eb798f8a..2773bae30 100644 --- a/metadata.go +++ b/metadata.go @@ -383,13 +383,13 @@ func compileMetadata( col := &columns[i] // decode the validator for TypeInfo and order if col.ClusteringOrder != "" { // Cassandra 3.x+ - col.Type = getCassandraType(col.Validator, logger) + col.Type = getCassandraType(col.Validator, byte(protoVersion), logger) col.Order = ASC if col.ClusteringOrder == "desc" { col.Order = DESC } } else { - validatorParsed := parseType(col.Validator, logger) + validatorParsed := parseType(col.Validator, byte(protoVersion), logger) col.Type = validatorParsed.types[0] col.Order = ASC if validatorParsed.reversed[0] { @@ -411,9 +411,9 @@ func compileMetadata( } if protoVersion == protoVersion1 { - compileV1Metadata(tables, logger) + compileV1Metadata(tables, protoVersion, logger) } else { - compileV2Metadata(tables, logger) + compileV2Metadata(tables, protoVersion, logger) } } @@ -422,14 +422,14 @@ func compileMetadata( // column metadata as V2+ (because V1 doesn't support the "type" column in the // system.schema_columns table) so determining PartitionKey and ClusterColumns // is more complex. -func compileV1Metadata(tables []TableMetadata, logger StdLogger) { +func compileV1Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { for i := range tables { table := &tables[i] // decode the key validator - keyValidatorParsed := parseType(table.KeyValidator, logger) + keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger) // decode the comparator - comparatorParsed := parseType(table.Comparator, logger) + comparatorParsed := parseType(table.Comparator, byte(protoVer), logger) // the partition key length is the same as the number of types in the // key validator @@ -515,7 +515,7 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) { alias = table.ValueAlias } // decode the default validator - defaultValidatorParsed := parseType(table.DefaultValidator, logger) + defaultValidatorParsed := parseType(table.DefaultValidator, byte(protoVer), logger) column := &ColumnMetadata{ Keyspace: table.Keyspace, Table: table.Name, @@ -529,7 +529,7 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) { } // The simpler compile case for V2+ protocol -func compileV2Metadata(tables []TableMetadata, logger StdLogger) { +func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { for i := range tables { table := &tables[i] @@ -537,7 +537,7 @@ func compileV2Metadata(tables []TableMetadata, logger StdLogger) { table.ClusteringColumns = make([]*ColumnMetadata, clusteringColumnCount) if table.KeyValidator != "" { - keyValidatorParsed := parseType(table.KeyValidator, logger) + keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger) table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types)) } else { // Cassandra 3.x+ partitionKeyCount := componentColumnCountOfType(table.Columns, ColumnPartitionKey) @@ -947,11 +947,11 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, return columns, nil } -func getTypeInfo(t string, logger StdLogger) TypeInfo { +func getTypeInfo(t string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(t, apacheCassandraTypePrefix) { - t = apacheToCassandraType(t) + return getCassandraLongType(t, protoVer, logger) } - return getCassandraType(t, logger) + return getCassandraType(t, protoVer, logger) } func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) { @@ -987,7 +987,7 @@ func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, er } view.FieldTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - view.FieldTypes[i] = getTypeInfo(argumentType, session.logger) + view.FieldTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } views = append(views, view) } @@ -1108,10 +1108,10 @@ func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMeta if err != nil { return nil, err } - function.ReturnType = getTypeInfo(returnType, session.logger) + function.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) function.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - function.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger) + function.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } functions = append(functions, function) } @@ -1165,11 +1165,11 @@ func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMe if err != nil { return nil, err } - aggregate.ReturnType = getTypeInfo(returnType, session.logger) - aggregate.StateType = getTypeInfo(stateType, session.logger) + aggregate.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) + aggregate.StateType = getTypeInfo(stateType, byte(session.cfg.ProtoVersion), session.logger) aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger) + aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } aggregates = append(aggregates, aggregate) } @@ -1186,6 +1186,7 @@ type typeParser struct { input string index int logger StdLogger + proto byte } // the type definition parser result @@ -1197,8 +1198,8 @@ type typeParserResult struct { } // Parse the type definition used for validator and comparator schema data -func parseType(def string, logger StdLogger) typeParserResult { - parser := &typeParser{input: def, logger: logger} +func parseType(def string, protoVer byte, logger StdLogger) typeParserResult { + parser := &typeParser{input: def, proto: protoVer, logger: logger} return parser.parse() } @@ -1209,6 +1210,9 @@ const ( LIST_TYPE = "org.apache.cassandra.db.marshal.ListType" SET_TYPE = "org.apache.cassandra.db.marshal.SetType" MAP_TYPE = "org.apache.cassandra.db.marshal.MapType" + UDT_TYPE = "org.apache.cassandra.db.marshal.UserType" + TUPLE_TYPE = "org.apache.cassandra.db.marshal.TupleType" + VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType" ) // represents a class specification in the type def AST @@ -1217,6 +1221,7 @@ type typeParserClassNode struct { params []typeParserParamNode // this is the segment of the input string that defined this node input string + proto byte } // represents a class parameter in the type def AST @@ -1236,6 +1241,7 @@ func (t *typeParser) parse() typeParserResult { NativeType{ typ: TypeCustom, custom: t.input, + proto: t.proto, }, }, reversed: []bool{false}, @@ -1313,7 +1319,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[0].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeList, + typ: TypeList, + proto: class.proto, }, Elem: elem, } @@ -1322,7 +1329,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[0].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeSet, + typ: TypeSet, + proto: class.proto, }, Elem: elem, } @@ -1332,7 +1340,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[1].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeMap, + typ: TypeMap, + proto: class.proto, }, Key: key, Elem: elem, @@ -1340,7 +1349,7 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { } // must be a simple type or custom type - info := NativeType{typ: getApacheCassandraType(class.name)} + info := NativeType{typ: getApacheCassandraType(class.name), proto: class.proto} if info.typ == TypeCustom { // add the entire class definition info.custom = class.input @@ -1370,6 +1379,7 @@ func (t *typeParser) parseClassNode() (node *typeParserClassNode, ok bool) { name: name, params: params, input: t.input[startIndex:endIndex], + proto: t.proto, } return node, true } diff --git a/metadata_test.go b/metadata_test.go index 6e3633ccc..6b3d11982 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -636,12 +636,14 @@ func TestTypeParser(t *testing.T) { }, ) - // custom + // udt assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)", - assertTypeInfo{Type: TypeCustom, Custom: "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)"}, + assertTypeInfo{Type: TypeUDT, Custom: ""}, ) + + // custom assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)", @@ -700,7 +702,7 @@ func assertParseNonCompositeType( ) { log := &defaultLogger{} - result := parseType(def, log) + result := parseType(def, protoVersion4, log) if len(result.reversed) != 1 { t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed)) } @@ -731,7 +733,7 @@ func assertParseCompositeType( ) { log := &defaultLogger{} - result := parseType(def, log) + result := parseType(def, protoVersion4, log) if len(result.reversed) != len(typesExpected) { t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed)) } @@ -747,7 +749,7 @@ func assertParseCompositeType( if !result.isComposite { t.Errorf("%s: Expected composite", def) } - if result.collections == nil { + if result.collections == nil && collectionsExpected != nil { t.Errorf("%s: Expected non-nil collections: %v", def, result.collections) } diff --git a/vector_test.go b/vector_test.go new file mode 100644 index 000000000..4e52a8856 --- /dev/null +++ b/vector_test.go @@ -0,0 +1,401 @@ +//go:build all || cassandra +// +build all cassandra + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql + +import ( + "fmt" + "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" + "net" + "reflect" + "testing" + "time" +) + +type person struct { + FirstName string `cql:"first_name"` + LastName string `cql:"last_name"` + Age int `cql:"age"` +} + +func (p person) String() string { + return fmt.Sprintf("Person{firstName: %s, lastName: %s, Age: %d}", p.FirstName, p.LastName, p.Age) +} + +func TestVector_Marshaler(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + insertFixVec := []float32{8, 2.5, -5.0} + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, insertFixVec).Exec() + if err != nil { + t.Fatal(err) + } + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&selectFixVec) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "fixed size element vector", insertFixVec, selectFixVec) + + longText := randomText(500) + insertVarVec := []string{"apache", "cassandra", longText, "gocql"} + err = session.Query("INSERT INTO vector_variable(id, vec) VALUES(?, ?)", 1, insertVarVec).Exec() + if err != nil { + t.Fatal(err) + } + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&selectVarVec) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "variable size element vector", insertVarVec, selectVarVec) +} + +func TestVector_Types(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + timestamp1, _ := time.Parse("2006-01-02", "2000-01-01") + timestamp2, _ := time.Parse("2006-01-02 15:04:05", "2024-01-01 10:31:45") + timestamp3, _ := time.Parse("2006-01-02 15:04:05.000", "2024-05-01 10:31:45.987") + + date1, _ := time.Parse("2006-01-02", "2000-01-01") + date2, _ := time.Parse("2006-01-02", "2022-03-14") + date3, _ := time.Parse("2006-01-02", "2024-12-31") + + time1, _ := time.Parse("15:04:05", "01:00:00") + time2, _ := time.Parse("15:04:05", "15:23:59") + time3, _ := time.Parse("15:04:05.000", "10:31:45.987") + + duration1 := Duration{0, 1, 1920000000000} + duration2 := Duration{1, 1, 1920000000000} + duration3 := Duration{31, 0, 60000000000} + + map1 := make(map[string]int) + map1["a"] = 1 + map1["b"] = 2 + map1["c"] = 3 + map2 := make(map[string]int) + map2["abc"] = 123 + map3 := make(map[string]int) + + tests := []struct { + name string + cqlType string + value interface{} + comparator func(interface{}, interface{}) + }{ + {name: "ascii", cqlType: TypeAscii.String(), value: []string{"a", "1", "Z"}}, + {name: "bigint", cqlType: TypeBigInt.String(), value: []int64{1, 2, 3}}, + {name: "blob", cqlType: TypeBlob.String(), value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, + {name: "boolean", cqlType: TypeBoolean.String(), value: []bool{true, false, true}}, + {name: "counter", cqlType: TypeCounter.String(), value: []int64{5, 6, 7}}, + {name: "decimal", cqlType: TypeDecimal.String(), value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, + {name: "double", cqlType: TypeDouble.String(), value: []float64{0.1, -1.2, 3}}, + {name: "float", cqlType: TypeFloat.String(), value: []float32{0.1, -1.2, 3}}, + {name: "int", cqlType: TypeInt.String(), value: []int32{1, 2, 3}}, + {name: "text", cqlType: TypeText.String(), value: []string{"a", "b", "c"}}, + {name: "timestamp", cqlType: TypeTimestamp.String(), value: []time.Time{timestamp1, timestamp2, timestamp3}}, + {name: "uuid", cqlType: TypeUUID.String(), value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, + {name: "varchar", cqlType: TypeVarchar.String(), value: []string{"abc", "def", "ghi"}}, + {name: "varint", cqlType: TypeVarint.String(), value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, + {name: "timeuuid", cqlType: TypeTimeUUID.String(), value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, + { + name: "inet", + cqlType: TypeInet.String(), + value: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(192, 168, 1, 1), net.IPv4(8, 8, 8, 8)}, + comparator: func(e interface{}, a interface{}) { + expected := e.([]net.IP) + actual := a.([]net.IP) + assertEqual(t, "vector size", len(expected), len(actual)) + for i, _ := range expected { + assertTrue(t, "vector", expected[i].Equal(actual[i])) + } + }, + }, + {name: "date", cqlType: TypeDate.String(), value: []time.Time{date1, date2, date3}}, + {name: "time", cqlType: TypeTimestamp.String(), value: []time.Time{time1, time2, time3}}, + {name: "smallint", cqlType: TypeSmallInt.String(), value: []int16{127, 256, -1234}}, + {name: "tinyint", cqlType: TypeTinyInt.String(), value: []int8{127, 9, -123}}, + {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}}, + {name: "vector_vector_float", cqlType: "vector", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}}, + {name: "vector_vector_set_float", cqlType: "vector, 5>", value: [][][]float32{ + {{1, 2}, {2, -1}, {3}, {0}, {-1.3}}, + {{2, 3}, {2, -1}, {3}, {0}, {-1.3}}, + {{1, 1000.0}, {0}, {}, {12, 14, 15, 16}, {-1.3}}, + }}, + {name: "vector_tuple_text_int_float", cqlType: "tuple", value: [][]interface{}{{"a", 1, float32(0.5)}, {"b", 2, float32(-1.2)}, {"c", 3, float32(0)}}}, + {name: "vector_tuple_text_list_text", cqlType: "tuple>", value: [][]interface{}{{"a", []string{"b", "c"}}, {"d", []string{"e", "f", "g"}}, {"h", []string{"i"}}}}, + {name: "vector_set_text", cqlType: "set", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}}, + {name: "vector_list_int", cqlType: "list", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}}, + {name: "vector_map_text_int", cqlType: "map", value: []map[string]int{map1, map2, map3}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tableName := fmt.Sprintf("vector_%s", test.name) + err := createTable(session, fmt.Sprintf(`CREATE TABLE IF NOT EXISTS gocql_test.%s(id int primary key, vec vector<%s, 3>);`, tableName, test.cqlType)) + if err != nil { + t.Fatal(err) + } + + err = session.Query(fmt.Sprintf("INSERT INTO %s(id, vec) VALUES(?, ?)", tableName), 1, test.value).Exec() + if err != nil { + t.Fatal(err) + } + + v := reflect.New(reflect.TypeOf(test.value)) + err = session.Query(fmt.Sprintf("SELECT vec FROM %s WHERE id = ?", tableName), 1).Scan(v.Interface()) + if err != nil { + t.Fatal(err) + } + if test.comparator != nil { + test.comparator(test.value, v.Elem().Interface()) + } else { + assertDeepEqual(t, "vector", test.value, v.Elem().Interface()) + } + }) + } +} + +func TestVector_MarshalerUDT(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TYPE gocql_test.person( + first_name text, + last_name text, + age int);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE gocql_test.vector_relatives( + id int, + couple vector, + primary key(id) + );`) + if err != nil { + t.Fatal(err) + } + + p1 := person{"Johny", "Bravo", 25} + p2 := person{"Capitan", "Planet", 5} + insVec := []person{p1, p2} + + err = session.Query("INSERT INTO vector_relatives(id, couple) VALUES(?, ?)", 1, insVec).Exec() + if err != nil { + t.Fatal(err) + } + + var selVec []person + + err = session.Query("SELECT couple FROM vector_relatives WHERE id = ?", 1).Scan(&selVec) + if err != nil { + t.Fatal(err) + } + + assertDeepEqual(t, "udt", &insVec, &selVec) +} + +func TestVector_Empty(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed_null(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable_null(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed_null WHERE id = ?", 1).Scan(&selectFixVec) + if err != nil { + t.Fatal(err) + } + assertTrue(t, "fixed size element vector is empty", selectFixVec == nil) + + err = session.Query("INSERT INTO vector_variable_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable_null WHERE id = ?", 1).Scan(&selectVarVec) + if err != nil { + t.Fatal(err) + } + assertTrue(t, "variable size element vector is empty", selectVarVec == nil) +} + +func TestVector_MissingDimension(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 2") + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0, 1, 3}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 4") +} + +func TestVector_SubTypeParsing(t *testing.T) { + tests := []struct { + name string + custom string + expected TypeInfo + }{ + {name: "text", custom: "org.apache.cassandra.db.marshal.UTF8Type", expected: NativeType{typ: TypeVarchar}}, + {name: "set_int", custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)", expected: CollectionType{NativeType{typ: TypeSet}, nil, NativeType{typ: TypeInt}}}, + { + name: "udt", + custom: "org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type)", + expected: UDTTypeInfo{ + NativeType{typ: TypeUDT}, + "gocql_test", + "person", + []UDTField{ + UDTField{"first_name", NativeType{typ: TypeVarchar}}, + UDTField{"last_name", NativeType{typ: TypeVarchar}}, + UDTField{"age", NativeType{typ: TypeInt}}, + }, + }, + }, + { + name: "tuple", + custom: "org.apache.cassandra.db.marshal.TupleType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)", + expected: TupleTypeInfo{ + NativeType{typ: TypeTuple}, + []TypeInfo{ + NativeType{typ: TypeVarchar}, + NativeType{typ: TypeInt}, + NativeType{typ: TypeVarchar}, + }, + }, + }, + { + name: "vector_vector_inet", + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", + expected: VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + NativeType{typ: TypeInet}, + 2, + }, + 3, + }, + }, + { + name: "map_int_vector_text", + custom: "org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 10))", + expected: CollectionType{ + NativeType{typ: TypeMap}, + NativeType{typ: TypeInt}, + VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + NativeType{typ: TypeVarchar}, + 10, + }, + }, + }, + { + name: "set_map_vector_text_text", + custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 10),org.apache.cassandra.db.marshal.UTF8Type))", + expected: CollectionType{ + NativeType{typ: TypeSet}, + nil, + CollectionType{ + NativeType{typ: TypeMap}, + VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + NativeType{typ: TypeInt}, + 10, + }, + NativeType{typ: TypeVarchar}, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := newFramer(nil, 0) + f.writeShort(0) + f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom)) + parsedType := f.readTypeInfo() + require.IsType(t, parsedType, VectorType{}) + vectorType := parsedType.(VectorType) + assertEqual(t, "dimensions", 2, vectorType.Dimensions) + assertDeepEqual(t, "vector", test.expected, vectorType.SubType) + }) + } +}