GH-43487: [Python] Sanitize Python reference handling in UDF implementation (apache#43557)

1. Remove spurious increfs (the function object is already incref'ed at an upper level)
2. Add unit test with an ephemeral Python function object
3. Streamline and improve Python reference handling

* GitHub Issue: apache#43487

Authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
pitrou authored Aug 7, 2024
1 parent 9b58454 commit 1f24799
Showing 2 changed files with 85 additions and 95 deletions.
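
For orientation before the diffs: the fix centers on who owns the strong reference to the registered Python callable. Previously, each kernel struct called Py_INCREF on a callable that its shared OwnedRefNoGIL wrapper already owned, then compensated with detach() when the interpreter was finalizing. Now a single strong reference is taken at registration time, and every holder shares the same OwnedRefNoGIL through a std::shared_ptr. Below is a minimal sketch of the new pattern, using toy stand-ins: OwnedRefNoGIL is simplified (the real class in arrow/python/common.h also handles GIL acquisition and interpreter finalization in its destructor), and KernelState/MakeState are illustrative names, not Arrow's API.

    #include <Python.h>

    #include <memory>

    // Toy stand-in for Arrow's OwnedRefNoGIL.
    class OwnedRefNoGIL {
     public:
      explicit OwnedRefNoGIL(PyObject* obj) : obj_(obj) {}  // steals a reference
      ~OwnedRefNoGIL() { Py_XDECREF(obj_); }
      PyObject* obj() const { return obj_; }

     private:
      PyObject* obj_;
    };

    // Holders copy the shared_ptr and never touch the Python refcount, so
    // constructing them does not require the GIL.
    struct KernelState {
      explicit KernelState(std::shared_ptr<OwnedRefNoGIL> function)
          : function(std::move(function)) {}
      std::shared_ptr<OwnedRefNoGIL> function;
    };

    // Registration takes the one strong reference, exactly once; it is
    // released by ~OwnedRefNoGIL when the last shared_ptr copy goes away.
    std::shared_ptr<KernelState> MakeState(PyObject* function) {
      Py_INCREF(function);
      return std::make_shared<KernelState>(
          std::make_shared<OwnedRefNoGIL>(function));
    }
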
149 changes: 54 additions & 95 deletions python/pyarrow/src/arrow/python/udf.cc
@@ -43,35 +43,18 @@ namespace py {
 namespace {
 
 struct PythonUdfKernelState : public compute::KernelState {
+  // NOTE: this KernelState constructor doesn't require the GIL.
+  // If it did, the corresponding KernelInit::operator() should be wrapped
+  // within SafeCallIntoPython (GH-43487).
   explicit PythonUdfKernelState(std::shared_ptr<OwnedRefNoGIL> function)
-      : function(function) {
-    Py_INCREF(function->obj());
-  }
-
-  // function needs to be destroyed at process exit
-  // and Python may no longer be initialized.
-  ~PythonUdfKernelState() {
-    if (Py_IsFinalizing()) {
-      function->detach();
-    }
-  }
+      : function(std::move(function)) {}
 
   std::shared_ptr<OwnedRefNoGIL> function;
 };
 
 struct PythonUdfKernelInit {
   explicit PythonUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function)
-      : function(function) {
-    Py_INCREF(function->obj());
-  }
-
-  // function needs to be destroyed at process exit
-  // and Python may no longer be initialized.
-  ~PythonUdfKernelInit() {
-    if (Py_IsFinalizing()) {
-      function->detach();
-    }
-  }
+      : function(std::move(function)) {}
 
   Result<std::unique_ptr<compute::KernelState>> operator()(
       compute::KernelContext*, const compute::KernelInitArgs&) {
@@ -94,68 +77,56 @@ struct HashUdfAggregator : public compute::KernelState {
   virtual Status Finalize(KernelContext* ctx, Datum* out) = 0;
 };
 
-arrow::Status AggregateUdfConsume(compute::KernelContext* ctx,
-                                  const compute::ExecSpan& batch) {
+Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
   return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
 }
 
-arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src,
-                                compute::KernelState* dst) {
+Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src,
+                         compute::KernelState* dst) {
   return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx, std::move(src));
 }
 
-arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) {
+Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) {
   return checked_cast<ScalarUdfAggregator*>(ctx->state())->Finalize(ctx, out);
 }
 
-arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) {
+Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) {
   return checked_cast<HashUdfAggregator*>(ctx->state())->Resize(ctx, size);
 }
 
-arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) {
+Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) {
   return checked_cast<HashUdfAggregator*>(ctx->state())->Consume(ctx, batch);
 }
 
-arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src,
-                                    const ArrayData& group_id_mapping) {
+Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src,
+                             const ArrayData& group_id_mapping) {
   return checked_cast<HashUdfAggregator*>(ctx->state())
       ->Merge(ctx, std::move(src), group_id_mapping);
 }
 
-arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) {
+Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) {
   return checked_cast<HashUdfAggregator*>(ctx->state())->Finalize(ctx, out);
 }
 
 struct PythonTableUdfKernelInit {
   PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
                            UdfWrapperCallback cb)
-      : function_maker(function_maker), cb(cb) {
-    Py_INCREF(function_maker->obj());
-  }
-
-  // function needs to be destroyed at process exit
-  // and Python may no longer be initialized.
-  ~PythonTableUdfKernelInit() {
-    if (Py_IsFinalizing()) {
-      function_maker->detach();
-    }
-  }
+      : function_maker(std::move(function_maker)), cb(std::move(cb)) {}
 
   Result<std::unique_ptr<compute::KernelState>> operator()(
       compute::KernelContext* ctx, const compute::KernelInitArgs&) {
-    UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0};
-    std::unique_ptr<OwnedRefNoGIL> function;
-    RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] {
-      OwnedRef empty_tuple(PyTuple_New(0));
-      function = std::make_unique<OwnedRefNoGIL>(
-          cb(function_maker->obj(), udf_context, empty_tuple.obj()));
-      RETURN_NOT_OK(CheckPyError());
-      return Status::OK();
-    }));
-    if (!PyCallable_Check(function->obj())) {
-      return Status::TypeError("Expected a callable Python object.");
-    }
-    return std::make_unique<PythonUdfKernelState>(std::move(function));
+    return SafeCallIntoPython(
+        [this, ctx]() -> Result<std::unique_ptr<compute::KernelState>> {
+          UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0};
+          OwnedRef empty_tuple(PyTuple_New(0));
+          auto function = std::make_shared<OwnedRefNoGIL>(
+              cb(function_maker->obj(), udf_context, empty_tuple.obj()));
+          RETURN_NOT_OK(CheckPyError());
+          if (!PyCallable_Check(function->obj())) {
+            return Status::TypeError("Expected a callable Python object.");
+          }
+          return std::make_unique<PythonUdfKernelState>(std::move(function));
+        });
   }
 
   std::shared_ptr<OwnedRefNoGIL> function_maker;
@@ -167,21 +138,16 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
                                 UdfWrapperCallback cb,
                                 std::vector<std::shared_ptr<DataType>> input_types,
                                 std::shared_ptr<DataType> output_type)
-      : function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
-    Py_INCREF(function->obj());
+      : function(std::move(function)),
+        cb(std::move(cb)),
+        output_type(std::move(output_type)) {
     std::vector<std::shared_ptr<Field>> fields;
     for (size_t i = 0; i < input_types.size(); i++) {
       fields.push_back(field("", input_types[i]));
     }
     input_schema = schema(std::move(fields));
   };
 
-  ~PythonUdfScalarAggregatorImpl() override {
-    if (Py_IsFinalizing()) {
-      function->detach();
-    }
-  }
-
   Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override {
     ARROW_ASSIGN_OR_RAISE(
         auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
@@ -263,8 +229,9 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
                               UdfWrapperCallback cb,
                               std::vector<std::shared_ptr<DataType>> input_types,
                               std::shared_ptr<DataType> output_type)
-      : function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
-    Py_INCREF(function->obj());
+      : function(std::move(function)),
+        cb(std::move(cb)),
+        output_type(std::move(output_type)) {
     std::vector<std::shared_ptr<Field>> fields;
     fields.reserve(input_types.size());
     for (size_t i = 0; i < input_types.size(); i++) {
@@ -273,12 +240,6 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     input_schema = schema(std::move(fields));
   };
 
-  ~PythonUdfHashAggregatorImpl() override {
-    if (Py_IsFinalizing()) {
-      function->detach();
-    }
-  }
-
   // same as ApplyGrouping in partition.cc
   // replicated the code here to avoid complicating the dependencies
   static Result<RecordBatchVector> ApplyGroupings(
@@ -416,10 +377,10 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
 struct PythonUdf : public PythonUdfKernelState {
   PythonUdf(std::shared_ptr<OwnedRefNoGIL> function, UdfWrapperCallback cb,
             std::vector<TypeHolder> input_types, compute::OutputType output_type)
-      : PythonUdfKernelState(function),
-        cb(cb),
-        input_types(input_types),
-        output_type(output_type) {}
+      : PythonUdfKernelState(std::move(function)),
+        cb(std::move(cb)),
+        input_types(std::move(input_types)),
+        output_type(std::move(output_type)) {}
 
   UdfWrapperCallback cb;
   std::vector<TypeHolder> input_types;
@@ -440,7 +401,7 @@ struct PythonUdf : public PythonUdfKernelState {
   Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
               compute::ExecResult* out) {
     auto state = arrow::internal::checked_cast<PythonUdfKernelState*>(ctx->state());
-    std::shared_ptr<OwnedRefNoGIL>& function = state->function;
+    PyObject* function = state->function->obj();
     const int num_args = batch.num_values();
     UdfContext udf_context{ctx->memory_pool(), batch.length};
 
@@ -458,7 +419,7 @@
       }
     }
 
-    OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj()));
+    OwnedRef result(cb(function, udf_context, arg_tuple.obj()));
     RETURN_NOT_OK(CheckPyError());
     // unwrapping the output for expected output type
     if (is_array(result.obj())) {
@@ -497,12 +458,13 @@ Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init,
   }
   auto scalar_func =
       std::make_shared<Function>(options.func_name, options.arity, options.func_doc);
-  Py_INCREF(function);
   std::vector<compute::InputType> input_types;
   for (const auto& in_dtype : options.input_types) {
     input_types.emplace_back(in_dtype);
   }
   compute::OutputType output_type(options.output_type);
+  // Take reference before wrapping with OwnedRefNoGIL
+  Py_INCREF(function);
   auto udf_data = std::make_shared<PythonUdf>(
       std::make_shared<OwnedRefNoGIL>(function), cb,
       TypeHolder::FromTypes(options.input_types), options.output_type);
@@ -565,11 +527,6 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb
     registry = compute::GetFunctionRegistry();
   }
 
-  // Py_INCREF here so that once a function is registered
-  // its refcount gets increased by 1 and doesn't get gced
-  // if all existing refs are gone
-  Py_INCREF(function);
-
   static auto default_scalar_aggregate_options =
       compute::ScalarAggregateOptions::Defaults();
   auto aggregate_func = std::make_shared<compute::ScalarAggregateFunction>(
@@ -582,12 +539,16 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb
   }
   compute::OutputType output_type(options.output_type);
 
-  compute::KernelInit init = [cb, function, options](compute::KernelContext* ctx,
-                                                     const compute::KernelInitArgs& args)
+  // Take reference before wrapping with OwnedRefNoGIL
+  Py_INCREF(function);
+  auto function_ref = std::make_shared<OwnedRefNoGIL>(function);
+
+  compute::KernelInit init = [cb, function_ref, options](
+                                 compute::KernelContext* ctx,
+                                 const compute::KernelInitArgs& args)
       -> Result<std::unique_ptr<compute::KernelState>> {
     return std::make_unique<PythonUdfScalarAggregatorImpl>(
-        std::make_shared<OwnedRefNoGIL>(function), cb, options.input_types,
-        options.output_type);
+        function_ref, cb, options.input_types, options.output_type);
   };
 
   auto sig = compute::KernelSignature::Make(
@@ -638,10 +599,6 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb,
     registry = compute::GetFunctionRegistry();
   }
 
-  // Py_INCREF here so that once a function is registered
-  // its refcount gets increased by 1 and doesn't get gced
-  // if all existing refs are gone
-  Py_INCREF(function);
   UdfOptions hash_options = AdjustForHashAggregate(options);
 
   std::vector<compute::InputType> input_types;
@@ -656,13 +613,15 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb,
       hash_options.func_name, hash_options.arity, hash_options.func_doc,
       &default_hash_aggregate_options);
 
-  compute::KernelInit init = [function, cb, hash_options](
+  // Take reference before wrapping with OwnedRefNoGIL
+  Py_INCREF(function);
+  auto function_ref = std::make_shared<OwnedRefNoGIL>(function);
+  compute::KernelInit init = [function_ref, cb, hash_options](
                                  compute::KernelContext* ctx,
                                  const compute::KernelInitArgs& args)
       -> Result<std::unique_ptr<compute::KernelState>> {
     return std::make_unique<PythonUdfHashAggregatorImpl>(
-        std::make_shared<OwnedRefNoGIL>(function), cb, hash_options.input_types,
-        hash_options.output_type);
+        function_ref, cb, hash_options.input_types, hash_options.output_type);
   };
 
   auto sig = compute::KernelSignature::Make(
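Both aggregate registration paths above end in the same idiom: take one reference, wrap it in a shared OwnedRefNoGIL, and let the KernelInit lambda capture the shared_ptr by value, so each kernel instantiation shares one strong reference instead of minting its own. The other recurring change, visible in PythonTableUdfKernelInit::operator(), relies on SafeCallIntoPython forwarding its lambda's return type: the body can return a Result directly instead of threading the new object out through a captured reference. Here is a toy model of that control-flow shape; Result and SafeCallIntoPython are stubbed (the real helper in arrow/python/common.h acquires the GIL and guards against a finalizing interpreter).

    #include <utility>

    template <typename T>
    struct Result {  // stand-in for arrow::Result<T>
      T value;
    };

    template <typename Function>
    auto SafeCallIntoPython(Function&& func) -> decltype(func()) {
      // The real helper acquires the GIL around this call; elided in the stub.
      return std::forward<Function>(func)();
    }

    Result<int> InitState() {
      // Returning Result<T> from the lambda keeps every Python-touching
      // statement inside the single GIL-holding scope.
      return SafeCallIntoPython([]() -> Result<int> { return Result<int>{42}; });
    }

    int main() { return InitState().value == 42 ? 0 : 1; }
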
31 changes: 31 additions & 0 deletions python/pyarrow/tests/test_udf.py
@@ -219,6 +219,31 @@ def nullary_func(context):
     return nullary_func, func_name
 
 
+@pytest.fixture(scope="session")
+def ephemeral_nullary_func_fixture():
+    """
+    Register a nullary scalar function with an ephemeral Python function.
+    This stresses that the Python function object is properly kept alive by the
+    registered function.
+    """
+    def nullary_func(context):
+        return pa.array([42] * context.batch_length, type=pa.int64(),
+                        memory_pool=context.memory_pool)
+
+    func_doc = {
+        "summary": "random function",
+        "description": "generates a random value"
+    }
+    func_name = "test_ephemeral_nullary_func"
+    pc.register_scalar_function(nullary_func,
+                                func_name,
+                                func_doc,
+                                {},
+                                pa.int64())
+
+    return func_name
+
+
 @pytest.fixture(scope="session")
 def wrong_output_type_func_fixture():
     """
@@ -505,6 +530,12 @@ def test_nullary_function(nullary_func_fixture):
                            batch_length=1)
 
 
+def test_ephemeral_function(ephemeral_nullary_func_fixture):
+    name = ephemeral_nullary_func_fixture
+    result = pc.call_function(name, [], length=1)
+    assert result.to_pylist() == [42]
+
+
 def test_wrong_output_type(wrong_output_type_func_fixture):
     _, func_name = wrong_output_type_func_fixture
 
