Skip to content

Commit

Permalink
Fix sparse parser. (#1262)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

1. Fix: memory leak when parsing sparse vector.
2. Fix: parser reduce warning by adding new DataType: EmptyArray
3. Refactor: knn expr parser.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
  • Loading branch information
small-turtle-1 authored May 30, 2024
1 parent 5be183c commit 91ec6de
Show file tree
Hide file tree
Showing 19 changed files with 1,702 additions and 1,713 deletions.
4 changes: 4 additions & 0 deletions src/function/cast/cast_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import embedding_cast;
import varchar_cast;
import tensor_cast;
import tensor_array_cast;
import empty_array_cast;
import logger;
import stl;
import sparse_cast;
Expand Down Expand Up @@ -249,6 +250,9 @@ BoundCastFunc CastFunction::GetBoundFunc(const DataType &source, const DataType
UnrecoverableError(error_message);
break;
}
case kEmptyArray: {
return BindEmptyArrayCast(source, target);
}
default:
String error_message = fmt::format("Can't cast from {} to {}", source.ToString(), target.ToString());
LOG_CRITICAL(error_message);
Expand Down
58 changes: 58 additions & 0 deletions src/function/cast/empty_array_cast.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module empty_array_cast;

import stl;
import data_type;
import bound_cast_func;
import column_vector_cast;
import infinity_exception;
import internal_types;
import logical_type;
import third_party;

namespace infinity {

struct EmptyTryCastToFixlen;

export inline BoundCastFunc BindEmptyArrayCast(const DataType &source, const DataType &target) {
switch (target.type()) {
case LogicalType::kSparse: {
return BoundCastFunc(&ColumnVectorCast::TryCastColumnVector<EmptyArrayT, SparseT, EmptyTryCastToFixlen>);
}
default: {
UnrecoverableError("Not implemented");
}
}
return BoundCastFunc(nullptr);
}

struct EmptyTryCastToFixlen {
template <typename SourceType, typename TargetType>
static bool Run(SourceType, TargetType &) {
UnrecoverableError(fmt::format("Not support to cast from {} to {}", DataType::TypeToString<SourceType>(), DataType::TypeToString<TargetType>()));
return false;
}
};

template<>
bool EmptyTryCastToFixlen::Run(EmptyArrayT, SparseT &target) {
target.nnz_ = 0;
return true;
}

}
8 changes: 6 additions & 2 deletions src/parser/expr/constant_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ std::string ConstantExpr::ToString() const {
size_t nnz = double_sparse_array_.first.size();
return SparseT::Sparse2StringT2(double_sparse_array_.first.data(), double_sparse_array_.second.data(), nnz);
}
default: {
ParserError("Unexpected branch");
case LiteralType::kEmptyArray: {
return {};
}
}
Expand All @@ -156,6 +155,7 @@ int32_t ConstantExpr::GetSizeInBytes() const {
size += sizeof(int64_t);
break;
}
case LiteralType::kEmptyArray:
case LiteralType::kNull: {
break;
}
Expand Down Expand Up @@ -221,6 +221,7 @@ void ConstantExpr::WriteAdv(char *&ptr) const {
WriteBufAdv<int64_t>(ptr, integer_value_);
break;
}
case LiteralType::kEmptyArray:
case LiteralType::kNull: {
break;
}
Expand Down Expand Up @@ -307,6 +308,7 @@ std::shared_ptr<ParsedExpr> ConstantExpr::ReadAdv(char *&ptr, int32_t maxbytes)
const_expr->integer_value_ = integer_value;
break;
}
case LiteralType::kEmptyArray:
case LiteralType::kNull: {
break;
}
Expand Down Expand Up @@ -401,6 +403,7 @@ nlohmann::json ConstantExpr::Serialize() const {
j["value"] = integer_value_;
break;
}
case LiteralType::kEmptyArray:
case LiteralType::kNull: {
break;
}
Expand Down Expand Up @@ -468,6 +471,7 @@ std::shared_ptr<ParsedExpr> ConstantExpr::Deserialize(const nlohmann::json &cons
const_expr->integer_value_ = constant_expr["value"].get<int64_t>();
break;
}
case LiteralType::kEmptyArray:
case LiteralType::kNull: {
break;
}
Expand Down
1 change: 1 addition & 0 deletions src/parser/expr/constant_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ enum class LiteralType : int32_t {
kInterval,
kLongSparseArray,
kDoubleSparseArray,
kEmptyArray,
};

class ConstantExpr : public ParsedExpr {
Expand Down
98 changes: 98 additions & 0 deletions src/parser/expr/knn_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,104 @@ std::string KnnExpr::ToString() const {
return expr_str;
}

bool KnnExpr::InitDistanceType(const char *distance_type) {
if (strcmp(distance_type, "l2") == 0) {
distance_type_ = infinity::KnnDistanceType::kL2;
} else if (strcmp(distance_type, "ip") == 0) {
distance_type_ = infinity::KnnDistanceType::kInnerProduct;
} else if (strcmp(distance_type, "cosine") == 0) {
distance_type_ = infinity::KnnDistanceType::kCosine;
} else if (strcmp(distance_type, "hamming") == 0) {
distance_type_ = infinity::KnnDistanceType::kHamming;
} else {
return false;
}
return true;
}

bool KnnExpr::InitEmbedding(const char *data_type, const ConstantExpr *query_vec) {
if (strcmp(data_type, "float") == 0 and distance_type_ != infinity::KnnDistanceType::kHamming) {
embedding_data_type_ = infinity::EmbeddingDataType::kElemFloat;
if (!(query_vec->double_array_.empty())) {
dimension_ = query_vec->double_array_.size();
embedding_data_ptr_ = new float[dimension_];
for (long i = 0; i < dimension_; ++i) {
((float *)(embedding_data_ptr_))[i] = query_vec->double_array_[i];
}
}
if (!(query_vec->long_array_.empty())) {
dimension_ = query_vec->long_array_.size();
embedding_data_ptr_ = new float[dimension_];
for (long i = 0; i < dimension_; ++i) {
((float *)(embedding_data_ptr_))[i] = query_vec->long_array_[i];
}
}
} else if (strcmp(data_type, "tinyint") == 0 and distance_type_ != infinity::KnnDistanceType::kHamming) {
dimension_ = query_vec->long_array_.size();
embedding_data_type_ = infinity::EmbeddingDataType::kElemInt8;
embedding_data_ptr_ = new char[dimension_];

for (long i = 0; i < dimension_; ++i) {
((char *)embedding_data_ptr_)[i] = query_vec->long_array_[i];
}
} else if (strcmp(data_type, "smallint") == 0 and distance_type_ != infinity::KnnDistanceType::kHamming) {
dimension_ = query_vec->long_array_.size();
embedding_data_type_ = infinity::EmbeddingDataType::kElemInt16;
embedding_data_ptr_ = new short int[dimension_];

for (long i = 0; i < dimension_; ++i) {
((short int *)embedding_data_ptr_)[i] = query_vec->long_array_[i];
}
} else if (strcmp(data_type, "integer") == 0 and distance_type_ != infinity::KnnDistanceType::kHamming) {
dimension_ = query_vec->long_array_.size();
embedding_data_type_ = infinity::EmbeddingDataType::kElemInt32;
embedding_data_ptr_ = new int[dimension_];

for (long i = 0; i < dimension_; ++i) {
((int *)embedding_data_ptr_)[i] = query_vec->long_array_[i];
}
} else if (strcmp(data_type, "bigint") == 0 and distance_type_ != infinity::KnnDistanceType::kHamming) {
dimension_ = query_vec->long_array_.size();
embedding_data_type_ = infinity::EmbeddingDataType::kElemInt64;
embedding_data_ptr_ = new long[dimension_];

memcpy(embedding_data_ptr_, (void *)query_vec->long_array_.data(), dimension_ * sizeof(long));
} else if (strcmp(data_type, "bit") == 0 and distance_type_ == infinity::KnnDistanceType::kHamming) {
dimension_ = query_vec->long_array_.size();
if (dimension_ % 8 == 0) {
embedding_data_type_ = infinity::EmbeddingDataType::kElemBit;
long embedding_size = dimension_ / 8;
char *char_ptr = new char[embedding_size];
uint8_t *data_ptr = reinterpret_cast<uint8_t *>(char_ptr);
embedding_data_ptr_ = char_ptr;
for (long i = 0; i < embedding_size; ++i) {
uint8_t embedding_unit = 0;
for (long bit_idx = 0; bit_idx < 8; ++bit_idx) {
if (query_vec->long_array_[i * 8 + bit_idx] == 1) {
embedding_unit |= (uint8_t(1) << bit_idx);
} else if (query_vec->long_array_[i * 8 + bit_idx] == 0) {
// no-op
} else {
return false;
}
}
data_ptr[i] = embedding_unit;
}
} else {
return false;
}
} else if (strcmp(data_type, "double") == 0 and distance_type_ != infinity::KnnDistanceType::kHamming) {
dimension_ = query_vec->double_array_.size();
embedding_data_type_ = infinity::EmbeddingDataType::kElemDouble;
embedding_data_ptr_ = new double[dimension_];

memcpy(embedding_data_ptr_, (void *)query_vec->double_array_.data(), dimension_ * sizeof(double));
} else {
return false;
}
return true;
}

std::string KnnExpr::KnnDistanceType2Str(KnnDistanceType knn_distance_type) {
switch (knn_distance_type) {
case KnnDistanceType::kL2: {
Expand Down
7 changes: 6 additions & 1 deletion src/parser/expr/knn_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

#pragma once

#include "expr/constant_expr.h"
#include "parsed_expr.h"
#include "statement/statement_common.h"
#include "type/complex/embedding_type.h"
#include <vector>

namespace infinity {

enum class KnnDistanceType {
Expand All @@ -36,6 +37,10 @@ class KnnExpr : public ParsedExpr {

[[nodiscard]] std::string ToString() const override;

bool InitDistanceType(const char *distance_type);

bool InitEmbedding(const char *data_type, const ConstantExpr *query_vec);

public:
static std::string KnnDistanceType2Str(KnnDistanceType knn_distance_type);

Expand Down
Loading

0 comments on commit 91ec6de

Please sign in to comment.