Skip to content

Commit

Permalink
refactor range logic; add consistency test in c++ and py
Browse files Browse the repository at this point in the history
  • Loading branch information
js8544 committed Jan 8, 2024
1 parent bec3747 commit 1a50d05
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 54 deletions.
118 changes: 65 additions & 53 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2457,22 +2457,15 @@ struct SliceBytesTransform : StringSliceTransformBase {
return SliceBackward(input, input_string_bytes, output);
}

int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes,
uint8_t* output) {
// Slice in forward order (step > 0)
const SliceOptions& opt = *this->options;
const uint8_t* begin = input;
const uint8_t* end = input + input_string_bytes;
const uint8_t* begin_sliced;
const uint8_t* end_sliced;

if (!input_string_bytes) {
return 0;
}
// First, compute begin_sliced and end_sliced
static std::pair<int64_t, int64_t> SliceForwardRange(const SliceOptions& opt,
int64_t input_string_bytes) {
int64_t begin = 0;
int64_t end = input_string_bytes;
int64_t begin_sliced = 0;
int64_t end_sliced = 0;
if (opt.start >= 0) {
// start counting from the left
begin_sliced = std::min(begin + opt.start, end);
begin_sliced = std::min(opt.start, end);
if (opt.stop > opt.start) {
// continue counting from begin_sliced
const int64_t length = opt.stop - opt.start;
Expand All @@ -2482,7 +2475,7 @@ struct SliceBytesTransform : StringSliceTransformBase {
end_sliced = std::max(end + opt.stop, begin_sliced);
} else {
// zero length slice
return 0;
return {0, 0};
}
} else {
// start counting from the right
Expand All @@ -2494,7 +2487,7 @@ struct SliceBytesTransform : StringSliceTransformBase {
// and therefore we also need this
if (end_sliced <= begin_sliced) {
// zero length slice
return 0;
return {0, 0};
}
} else if ((opt.stop < 0) && (opt.stop > opt.start)) {
// stop is negative, but larger than start, so we count again from the right
Expand All @@ -2504,12 +2497,30 @@ struct SliceBytesTransform : StringSliceTransformBase {
end_sliced = std::max(end + opt.stop, begin_sliced);
} else {
// zero length slice
return 0;
return {0, 0};
}
}
return {begin_sliced, end_sliced};
}

int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes,
uint8_t* output) {
// Slice in forward order (step > 0)
if (!input_string_bytes) {
return 0;
}

const SliceOptions& opt = *this->options;
auto [begin_index, end_index] = SliceForwardRange(opt, input_string_bytes);
const uint8_t* begin_sliced = input + begin_index;
const uint8_t* end_sliced = input + end_index;

if (begin_sliced == end_sliced) {
return 0;
}

// Second, copy computed slice to output
DCHECK(begin_sliced <= end_sliced);
DCHECK(begin_sliced < end_sliced);
if (opt.step == 1) {
// fast case, where we simply can finish with a memcpy
std::copy(begin_sliced, end_sliced, output);
Expand All @@ -2528,18 +2539,13 @@ struct SliceBytesTransform : StringSliceTransformBase {
return dest - output;
}

int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes,
uint8_t* output) {
static std::pair<int64_t, int64_t> SliceBackwardRange(const SliceOptions& opt,
int64_t input_string_bytes) {
// Slice in reverse order (step < 0)
const SliceOptions& opt = *this->options;
const uint8_t* begin = input;
const uint8_t* end = input + input_string_bytes;
const uint8_t* begin_sliced = begin;
const uint8_t* end_sliced = end;

if (!input_string_bytes) {
return 0;
}
int64_t begin = 0;
int64_t end = input_string_bytes;
int64_t begin_sliced = begin;
int64_t end_sliced = end;

if (opt.start >= 0) {
// +1 because begin_sliced acts as as the end of a reverse iterator
Expand All @@ -2558,6 +2564,28 @@ struct SliceBytesTransform : StringSliceTransformBase {
}
end_sliced--;

if (begin_sliced <= end_sliced) {
// zero length slice
return {0, 0};
}

return {begin_sliced, end_sliced};
}

int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes,
uint8_t* output) {
if (!input_string_bytes) {
return 0;
}

const SliceOptions& opt = *this->options;
auto [begin_index, end_index] = SliceBackwardRange(opt, input_string_bytes);
const uint8_t* begin_sliced = input + begin_index;
const uint8_t* end_sliced = input + end_index;

if (begin_sliced == end_sliced) {
return 0;
}
// Copy computed slice to output
uint8_t* dest = output;
const uint8_t* i = begin_sliced;
Expand All @@ -2577,31 +2605,15 @@ struct SliceBytesTransform : StringSliceTransformBase {
if (step == 0) {
return Status::Invalid("Slice step cannot be zero");
}
auto start = options.start;
auto stop = options.stop;
auto input_width = static_cast<int64_t>(input_width_32);

if (start < 0) {
start = std::max(static_cast<int64_t>(0), start + input_width);
}
if (stop < 0) {
stop = std::max(static_cast<int64_t>(0), stop + input_width);
}
start = std::min(start, input_width);
stop = std::min(stop, input_width);

if ((start == stop) || ((start >= stop) && (step > 0)) ||
((start <= stop) && (step < 0))) {
return 0;
}

if (step < 0) {
return static_cast<int32_t>(
std::max(static_cast<int64_t>(0), (stop - start + step + 1) / step));
if (step > 0) {
// forward slice
auto [begin_index, end_index] = SliceForwardRange(options, input_width_32);
return (end_index - begin_index + step - 1) / step;
} else {
// backward slice
auto [begin_index, end_index] = SliceBackwardRange(options, input_width_32);
return (end_index - begin_index + step + 1) / step;
}

return static_cast<int32_t>(
std::max(static_cast<int64_t>(0), (stop - start + step - 1) / step));
}
};

Expand Down
26 changes: 26 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,32 @@ TEST_F(TestFixedSizeBinaryKernels, BinarySliceNegPos) {
R"(["fdb", "zbo"])", &options_step_neg);
}

TEST_F(TestFixedSizeBinaryKernels, BinarySliceConsistentyWithVarLenBinary) {
std::string source_str = "abcdef";
for (size_t str_len = 0; str_len < source_str.size(); ++str_len) {
auto input_str = source_str.substr(0, str_len);
auto fixed_input =
ArrayFromJSON(fixed_size_binary(str_len), R"([")" + input_str + R"("])");
auto varlen_input = ArrayFromJSON(binary(), R"([")" + input_str + R"("])");
for (auto start = -6; start <= 6; ++start) {
for (auto stop = -6; stop <= 6; ++stop) {
for (auto step = -3; step <= 4; ++step) {
if (step == 0) {
continue;
}
SliceOptions options{start, stop, step};
auto expected =
CallFunction("binary_slice", {varlen_input}, &options).ValueOrDie();
auto actual =
CallFunction("binary_slice", {fixed_input}, &options).ValueOrDie();
actual = Cast(actual, binary()).ValueOrDie();
AssertDatumsEqual(expected, actual);
}
}
}
}
}

TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) {
ReplaceSliceOptions options{0, 1, "XX"};
CheckUnary("binary_replace_slice", "[]", fixed_size_binary(7), "[]", &options);
Expand Down
13 changes: 12 additions & 1 deletion python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,8 @@ def test_slice_compatibility():


def test_binary_slice_compatibility():
arr = pa.array([b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"])
data = [b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"]
arr = pa.array(data)
for start, stop, step in itertools.product(range(-6, 6),
range(-6, 6),
range(-3, 4)):
Expand All @@ -574,6 +575,16 @@ def test_binary_slice_compatibility():
assert expected.equals(result)
# Positional options
assert pc.binary_slice(arr, start, stop, step) == result
# Fixed size binary input / output
for item in data:
print(item)
print(start, stop, step)
fsb_scalar = pa.scalar(item, type=pa.binary(len(item)))
expected = item[start:stop:step]
print(expected)
actual = pc.binary_slice(fsb_scalar, start, stop, step)
assert actual.type == pa.binary(len(expected))
assert actual.as_py() == expected


def test_split_pattern():
Expand Down

0 comments on commit 1a50d05

Please sign in to comment.