Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple Sparse vector support. #1249

Merged
merged 3 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ export namespace std {
using std::remove_if;
using std::reverse;
using std::sort;
using std::max_element;
using std::min_element;
using std::unique;
using std::reduce;
using std::accumulate;
Expand Down
20 changes: 6 additions & 14 deletions src/executor/operator/physical_show.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1258,15 +1258,7 @@ void PhysicalShow::ExecuteShowColumns(QueryContext *query_context, ShowOperatorS
++output_column_idx;
{
// Append column type to the second column, if the column type is embedded type, append the embedded type
String column_type;
if (column->type()->type() == kEmbedding) {
auto type = column->type();
auto embedding_type = type->type_info()->ToString();
column_type = fmt::format("{}({})", type->ToString(), embedding_type);

} else {
column_type = column->type()->ToString();
}
String column_type = column->type()->ToString();
Value value = Value::MakeVarchar(column_type);
ValueExpression value_expr(value);
value_expr.AppendToChunk(output_block_ptr->column_vectors[output_column_idx]);
Expand Down Expand Up @@ -1859,7 +1851,7 @@ void PhysicalShow::ExecuteShowConfigs(QueryContext *query_context, ShowOperatorS
}

{
{// option name
{ // option name
Value value = Value::MakeVarchar(TIME_ZONE_OPTION_NAME);
ValueExpression value_expr(value);
value_expr.AppendToChunk(output_block_ptr->column_vectors[0]);
Expand Down Expand Up @@ -2013,10 +2005,10 @@ void PhysicalShow::ExecuteShowConfigs(QueryContext *query_context, ShowOperatorS

{
{
// option name
Value value = Value::MakeVarchar(LOG_FILENAME_OPTION_NAME);
ValueExpression value_expr(value);
value_expr.AppendToChunk(output_block_ptr->column_vectors[0]);
// option name
Value value = Value::MakeVarchar(LOG_FILENAME_OPTION_NAME);
ValueExpression value_expr(value);
value_expr.AppendToChunk(output_block_ptr->column_vectors[0]);
}
{
// option name type
Expand Down
4 changes: 4 additions & 0 deletions src/function/cast/cast_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import embedding_cast;
import varchar_cast;
import tensor_cast;
import tensor_array_cast;
import sparse_cast;

import third_party;

Expand Down Expand Up @@ -156,6 +157,9 @@ BoundCastFunc CastFunction::GetBoundFunc(const DataType &source, const DataType
case kTensorArray: {
return BindTensorArrayCast(source, target);
}
case kSparse: {
return BindSparseCast(source, target);
}
case kRowID: {
UnrecoverableError(fmt::format("Can't cast from {} to {}", source.ToString(), target.ToString()));
}
Expand Down
270 changes: 270 additions & 0 deletions src/function/cast/sparse_cast.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module sparse_cast;

import stl;
import bound_cast_func;
import data_type;
import logical_type;
import status;
import infinity_exception;
import sparse_info;
import column_vector_cast;
import internal_types;
import column_vector;
import third_party;
import vector_buffer;
import fix_heap;
import sparse_info;
import embedding_cast;

namespace infinity {

struct SparseTryCastToSparse;

export inline BoundCastFunc BindSparseCast(const DataType &source, const DataType &target) {
if (source.type() != LogicalType::kSparse || target.type() != LogicalType::kSparse) {
RecoverableError(Status::NotSupportedTypeConversion(source.ToString(), target.ToString()));
}
return BoundCastFunc(&ColumnVectorCast::TryCastColumnVectorVarlenWithType<SparseT, SparseT, SparseTryCastToSparse>);
}

struct SparseTryCastToSparse {
template <typename SourceT, typename TargetT>
static bool Run(const SourceT &source,
const DataType &source_type,
ColumnVector *source_vector_ptr,
TargetT &target,
const DataType &target_type,
ColumnVector *target_vector_ptr) {
UnrecoverableError("Unexpected case");
return false;
}
};

template <typename TargetValueType, typename TargetIndiceType, typename SourceValueType, typename SourceIndiceType>
void SparseTryCastToSparseFunInner(const SparseInfo *source_info,
const SparseT &source,
FixHeapManager *source_fix_heap_mgr,
const SparseInfo *target_info,
SparseT &target,
FixHeapManager *target_fix_heap_mgr) {
const auto &[source_nnz, source_chunk_id, source_chunk_offset] = source;
target.nnz_ = source_nnz;
const_ptr_t source_ptr = source_fix_heap_mgr->GetRawPtrFromChunk(source_chunk_id, source_chunk_offset);
SizeT sparse_bytes = source_info->SparseSize(source_nnz);
if constexpr (std::is_same_v<TargetValueType, SourceValueType>) {
if constexpr (std::is_same_v<TargetIndiceType, SourceIndiceType>) {
std::tie(target.chunk_id_, target.chunk_offset_) = target_fix_heap_mgr->AppendToHeap(source_ptr, sparse_bytes);
} else {
auto target_tmp_ptr = MakeUniqueForOverwrite<TargetIndiceType[]>(source_nnz);
const SizeT source_indice_size = source_info->IndiceSize(source_nnz);
const SizeT target_indice_size = target_info->IndiceSize(source_nnz);
if (!EmbeddingTryCastToFixlen::Run(reinterpret_cast<const SourceIndiceType *>(source_ptr),
reinterpret_cast<TargetIndiceType *>(target_tmp_ptr.get()),
source_nnz)) {
UnrecoverableError(fmt::format("Fail to case from sparse with idx {} to sparse with idx {}",
DataType::TypeToString<SourceValueType>(),
DataType::TypeToString<TargetValueType>()));
}
std::tie(target.chunk_id_, target.chunk_offset_) =
target_fix_heap_mgr->AppendToHeap(reinterpret_cast<const char *>(target_tmp_ptr.get()), target_indice_size);

const_ptr_t source_data_ptr = source_ptr + source_indice_size;
target_fix_heap_mgr->AppendToHeap(source_data_ptr, source_info->DataSize(source_nnz));
}
} else {
UnrecoverableError("Unimplemented");
}
}

template <typename TargetValueType, typename TargetIndiceType, typename SourceValueType, typename SourceIndiceType>
void SparseTryCastToSparseFunT4(const SparseInfo *source_info,
const SparseT &source,
ColumnVector *source_vector_ptr,
const SparseInfo *target_info,
SparseT &target,
ColumnVector *target_vector_ptr) {
SparseTryCastToSparseFunInner<TargetValueType, TargetIndiceType, SourceValueType, SourceIndiceType>(
source_info,
source,
source_vector_ptr->buffer_->fix_heap_mgr_.get(),
target_info,
target,
target_vector_ptr->buffer_->fix_heap_mgr_.get());
}

template <typename TargetValueType, typename TargetIndiceType, typename SourceValueType>
void SparseTryCastToSparseFunT3(const SparseInfo *source_info,
const SparseT &source,
ColumnVector *source_vector_ptr,
const SparseInfo *target_info,
SparseT &target,
ColumnVector *target_vector_ptr) {
switch (source_info->IndexType()) {
case kElemInt8: {
SparseTryCastToSparseFunT4<TargetValueType, TargetIndiceType, SourceValueType, TinyIntT>(source_info,
source,
source_vector_ptr,
target_info,
target,
target_vector_ptr);
break;
}
case kElemInt16: {
SparseTryCastToSparseFunT4<TargetValueType, TargetIndiceType, SourceValueType, SmallIntT>(source_info,
source,
source_vector_ptr,
target_info,
target,
target_vector_ptr);
break;
}
case kElemInt32: {
SparseTryCastToSparseFunT4<TargetValueType, TargetIndiceType, SourceValueType, IntegerT>(source_info,
source,
source_vector_ptr,
target_info,
target,
target_vector_ptr);
break;
}
case kElemInt64: {
SparseTryCastToSparseFunT4<TargetValueType, TargetIndiceType, SourceValueType, BigIntT>(source_info,
source,
source_vector_ptr,
target_info,
target,
target_vector_ptr);
break;
}
default: {
UnrecoverableError("Invalid source index type");
}
}
}

template <typename TargetValueType, typename TargetIndiceType>
void SparseTryCastToSparseFunT2(const SparseInfo *source_info,
const SparseT &source,
ColumnVector *source_vector_ptr,
const SparseInfo *target_info,
SparseT &target,
ColumnVector *target_vector_ptr) {
switch (source_info->DataType()) {
case kElemBit:
case kElemInt8:
case kElemInt16:
case kElemInt32:
case kElemInt64:
case kElemFloat:
case kElemInvalid: {
UnrecoverableError("Unimplemented");
}
case kElemDouble: {
SparseTryCastToSparseFunT3<TargetValueType, TargetIndiceType, DoubleT>(source_info,
source,
source_vector_ptr,
target_info,
target,
target_vector_ptr);
break;
}
default: {
UnrecoverableError("Unreachable code");
}
}
}

template <typename TargetValueType>
void SparseTryCastToSparseFunT1(const SparseInfo *source_info,
const SparseT &source,
ColumnVector *source_vector_ptr,
const SparseInfo *target_info,
SparseT &target,
ColumnVector *target_vector_ptr) {
switch (target_info->IndexType()) {
case kElemInt8: {
SparseTryCastToSparseFunT2<TargetValueType, TinyIntT>(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
break;
}
case kElemInt16: {
SparseTryCastToSparseFunT2<TargetValueType, SmallIntT>(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
break;
}
case kElemInt32: {
SparseTryCastToSparseFunT2<TargetValueType, IntegerT>(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
break;
}
case kElemInt64: {
SparseTryCastToSparseFunT2<TargetValueType, BigIntT>(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
break;
}
default: {
UnrecoverableError("Invalid target index type");
}
}
}

void SparseTryCastToSparseFun(const SparseInfo *source_info,
const SparseT &source,
ColumnVector *source_vector_ptr,
const SparseInfo *target_info,
SparseT &target,
ColumnVector *target_vector_ptr) {
switch (target_info->DataType()) {
case kElemBit:
case kElemInt8:
case kElemInt16:
case kElemInt32:
case kElemInt64:
case kElemFloat:
case kElemInvalid: {
UnrecoverableError("Unimplemented");
}
case kElemDouble: {
SparseTryCastToSparseFunT1<DoubleT>(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
break;
}
default: {
UnrecoverableError("Unreachable code");
}
}
}

template <>
bool SparseTryCastToSparse::Run(const SparseT &source,
const DataType &source_type,
ColumnVector *source_vector_ptr,
SparseT &target,
const DataType &target_type,
ColumnVector *target_vector_ptr) {
const auto *source_info = static_cast<const SparseInfo *>(source_type.type_info().get());
const auto *target_info = static_cast<const SparseInfo *>(target_type.type_info().get());
SizeT source_dim = source_info->Dimension();
SizeT target_dim = target_info->Dimension();
if (source_dim > target_dim) {
RecoverableError(Status::DataTypeMismatch(source_type.ToString(), target_type.ToString()));
}
if (target_vector_ptr->buffer_->buffer_type_ != VectorBufferType::kHeap) {
UnrecoverableError(fmt::format("Sparse column vector should use kHeap VectorBuffer."));
}
SparseTryCastToSparseFun(source_info, source, source_vector_ptr, target_info, target, target_vector_ptr);
return true;
}

} // namespace infinity
26 changes: 26 additions & 0 deletions src/network/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import session_manager;
import type_info;
import logical_type;
import embedding_info;
import sparse_info;
import data_type;

namespace infinity {
Expand Down Expand Up @@ -301,6 +302,31 @@ void Connection::SendTableDescription(const SharedPtr<DataTable> &result_table)
}
break;
}
case LogicalType::kSparse: {
if (column_type->type_info()->type() != TypeInfoType::kSparse) {
UnrecoverableError("Not sparse type");
}
const auto *sparse_info = static_cast<SparseInfo *>(column_type->type_info().get());
switch (sparse_info->DataType()) {
case kElemBit:
case kElemInt8:
case kElemInt16:
case kElemInt32:
case kElemInt64:
case kElemFloat: {
UnrecoverableError("Not implemented");
}
case kElemDouble: {
object_id = 1022;
object_width = 8;
break;
}
default: {
UnrecoverableError("Should not reach here");
}
}
break;
}
default: {
UnrecoverableError("Unexpected type");
}
Expand Down
2 changes: 1 addition & 1 deletion src/parser/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
project("sql_parser")

# execute_process(COMMAND ./generate_parser.sh WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
execute_process(COMMAND ./generate_parser.sh WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

file(GLOB_RECURSE
parser_files
Expand Down
Loading
Loading