From 3f89475d661c66354edb9281c4d6ab9015868c6a Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Sat, 16 Dec 2023 11:39:24 +0800 Subject: [PATCH] binary_slice kernel for fixed size binary --- .../compute/kernels/scalar_string_ascii.cc | 29 ++++++ .../compute/kernels/scalar_string_internal.h | 2 + .../compute/kernels/scalar_string_test.cc | 90 +++++++++++++++++++ 3 files changed, 121 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index 6764845dfca81..58ba961e55024 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -20,6 +20,7 @@ #include #include #include +#include "arrow/compute/api_scalar.h" #ifdef ARROW_WITH_RE2 #include @@ -2436,6 +2437,7 @@ void AddAsciiStringReplaceSlice(FunctionRegistry* registry) { namespace { struct SliceBytesTransform : StringSliceTransformBase { + using StringSliceTransformBase::StringSliceTransformBase; int64_t MaxCodeunits(int64_t ninputs, int64_t input_bytes) override { const SliceOptions& opt = *this->options; if ((opt.start >= 0) != (opt.stop >= 0)) { @@ -2568,6 +2570,27 @@ struct SliceBytesTransform : StringSliceTransformBase { return dest - output; } + + static int32_t FixedOutputSize(SliceOptions options, int32_t input_width_32) { + auto step = options.step; + auto start = options.start; + auto stop = options.stop; + auto input_width = static_cast(input_width_32); + + if (start < 0) { + start = std::max(0L, start + input_width); + } + if (stop < 0) { + stop = std::max(0L, stop + input_width); + } + start = std::min(start, input_width); + stop = std::min(stop, input_width); + + if ((start >= stop and step > 0) || (start <= stop and step < 0) || start == stop) { + return 0; + } + return std::max(0L, (stop - start + (step - (step > 0 ? 1 : -1))) / step); + } }; template @@ -2594,6 +2617,12 @@ void AddAsciiStringSlice(FunctionRegistry* registry) { DCHECK_OK( func->AddKernel({ty}, ty, std::move(exec), SliceBytesTransform::State::Init)); } + using TransformExec = FixedSizeBinaryTransformExecWithState; + ScalarKernel fsb_kernel({InputType(Type::FIXED_SIZE_BINARY)}, + OutputType(TransformExec::OutputType), TransformExec::Exec, + StringSliceTransformBase::State::Init); + fsb_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(fsb_kernel))); DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_internal.h b/cpp/src/arrow/compute/kernels/scalar_string_internal.h index 1a9969441655d..a3ba48d41d89d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_string_internal.h @@ -250,6 +250,8 @@ struct StringSliceTransformBase : public StringTransformBase { using State = OptionsWrapper; const SliceOptions* options; + StringSliceTransformBase() = default; + explicit StringSliceTransformBase(const SliceOptions& options) : options{&options} {} Status PreExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) override { options = &State::Get(ctx); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index ff14f5e7a5c5d..516be92210fe6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include "arrow/type_fwd.h" #include #include @@ -712,6 +713,95 @@ TEST_F(TestFixedSizeBinaryKernels, BinaryLength) { "[6, null, 6]"); } +TEST_F(TestFixedSizeBinaryKernels, SliceBytesBasic) { + SliceOptions options{2, 4}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["ca", "fd"])", &options); + + SliceOptions options_edgecase_1{-3, 1}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(0), + R"(["", ""])", &options_edgecase_1); + + SliceOptions options_edgecase_2{-10, -3}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(3), + R"(["abc", "def"])", &options_edgecase_2); + + auto input = ArrayFromJSON(this->type(), R"(["foobaz"])"); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + testing::HasSubstr("Function 'binary_slice' cannot be called without options"), + CallFunction("binary_slice", {input})); + + SliceOptions options_invalid{2, 4, 0}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Slice step cannot be zero"), + CallFunction("binary_slice", {input}, &options_invalid)); +} + +TEST_F(TestFixedSizeBinaryKernels, SliceBytesPosPos) { + SliceOptions options_step{1, 5, 2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["ba", "ed"])", &options_step); + + SliceOptions options_step_neg{5, 0, -2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(3), + R"(["cab", "fde"])", &options_step_neg); +} + +TEST_F(TestFixedSizeBinaryKernels, SliceBytesPosNeg) { + SliceOptions options{2, -1}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(3), + R"(["cab", "fde"])", &options); + + SliceOptions options_step{1, -1, 2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["ba", "ed"])", &options_step); + + SliceOptions options_step_neg{5, -4, -2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["ca", "fd"])", &options_step_neg); + + options_step_neg.stop = -6; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(3), + R"(["cab", "fde"])", &options_step_neg); +} + +TEST_F(TestFixedSizeBinaryKernels, SliceBytesNegNeg) { + SliceOptions options{-2, -1}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(1), + R"(["b", "e"])", &options); + + SliceOptions options_step{-4, -1, 2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["cb", "fe"])", &options_step); + + SliceOptions options_step_neg{-1, -3, -2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(1), + R"(["c", "f"])", &options_step_neg); + + options_step_neg.stop = -4; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["ca", "fd"])", &options_step_neg); +} + +TEST_F(TestFixedSizeBinaryKernels, SliceBytesNegPos) { + SliceOptions options{-2, 4}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(0), + R"(["", ""])", &options); + + SliceOptions options_step{-4, 5, 2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["cb", "fe"])", &options_step); + + SliceOptions options_step_neg{-1, 1, -2}; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(2), + R"(["ca", "fd"])", &options_step_neg); + + options_step_neg.stop = 0; + CheckUnary("binary_slice", R"(["abcabc", "defdef"])", fixed_size_binary(3), + R"(["cab", "fde"])", &options_step_neg); +} + TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) { ReplaceSliceOptions options{0, 1, "XX"}; CheckUnary("binary_replace_slice", "[]", fixed_size_binary(7), "[]", &options);