Skip to content

Commit

Permalink
apacheGH-34588:[C++][Python] Add a MetaFunction for "dictionary_decod…
Browse files Browse the repository at this point in the history
…e" (apache#35356)

**Rationale for this change**
This PR is for [Issue-34588](apache#34588). Discussing with @ westonpace, a MetaFunction for "dictionary_decode" is implemented instead of adding a compute kernel.

**What changes are included in this PR?**
C++: Meta Function of dictionary_decode.
Python: Test

**Are these changes tested?**
One test in tests/test_compute.py

* Closes: apache#34588

Lead-authored-by: Junming Chen <junming.chen.r@outlook.com>
Co-authored-by: Alenka Frim <AlenkaF@users.noreply.github.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
3 people authored Jul 18, 2023
1 parent e821473 commit c7741fb
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
37 changes: 37 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "arrow/array/dict_internal.h"
#include "arrow/array/util.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/result.h"
#include "arrow/util/hashing.h"
Expand Down Expand Up @@ -762,6 +763,38 @@ const FunctionDoc dictionary_encode_doc(
("Return a dictionary-encoded version of the input array."), {"array"},
"DictionaryEncodeOptions");

// ----------------------------------------------------------------------
// This function does not use any hashing utilities
// but is kept in this file to be near dictionary_encode
// Dictionary decode implementation

const FunctionDoc dictionary_decode_doc{
"Decodes a DictionaryArray to an Array",
("Return a plain-encoded version of the array input\n"
"This function does nothing if the input is not a dictionary."),
{"dictionary_array"}};

class DictionaryDecodeMetaFunction : public MetaFunction {
public:
DictionaryDecodeMetaFunction()
: MetaFunction("dictionary_decode", Arity::Unary(), dictionary_decode_doc) {}

Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
const FunctionOptions* options,
ExecContext* ctx) const override {
if (args[0].type() == nullptr || args[0].type()->id() != Type::DICTIONARY) {
return args[0];
}

if (args[0].is_array() || args[0].is_chunked_array()) {
DictionaryType* dict_type = checked_cast<DictionaryType*>(args[0].type().get());
CastOptions cast_options = CastOptions::Safe(dict_type->value_type());
return CallFunction("cast", args, &cast_options, ctx);
} else {
return Status::TypeError("Expected an Array or a Chunked Array");
}
}
};
} // namespace

void RegisterVectorHash(FunctionRegistry* registry) {
Expand Down Expand Up @@ -819,6 +852,10 @@ void RegisterVectorHash(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(dict_encode)));
}

void RegisterDictionaryDecode(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::make_shared<DictionaryDecodeMetaFunction>()));
}

} // namespace internal
} // namespace compute
} // namespace arrow
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {

// Register core kernels
RegisterScalarCast(registry.get());
RegisterDictionaryDecode(registry.get());
RegisterVectorHash(registry.get());
RegisterVectorSelection(registry.get());

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/registry_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace internal {
void RegisterScalarArithmetic(FunctionRegistry* registry);
void RegisterScalarBoolean(FunctionRegistry* registry);
void RegisterScalarCast(FunctionRegistry* registry);
void RegisterDictionaryDecode(FunctionRegistry* registry);
void RegisterScalarComparison(FunctionRegistry* registry);
void RegisterScalarIfElse(FunctionRegistry* registry);
void RegisterScalarNested(FunctionRegistry* registry);
Expand Down
11 changes: 11 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,17 @@ def test_logical():
assert pc.invert(a) == pa.array([False, True, True, None])


def test_dictionary_decode():
array = pa.array(["a", "a", "b", "c", "b"])
dictionary_array = array.dictionary_encode()
dictionary_array_decode = pc.dictionary_decode(dictionary_array)

assert array != dictionary_array

assert array == dictionary_array_decode
assert array == pc.dictionary_decode(array)


def test_cast():
arr = pa.array([1, 2, 3, 4], type='int64')
options = pc.CastOptions(pa.int8())
Expand Down

0 comments on commit c7741fb

Please sign in to comment.