From 1f2479908323daff3b08d1d585517239cae637d2 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 7 Aug 2024 15:45:03 +0200 Subject: [PATCH] GH-43487: [Python] Sanitize Python reference handling in UDF implementation (#43557) 1. Remove spurious increfs (the function object is already incref'ed at an upper level) 2. Add unit test with an ephemeral Python function object 3. Streamline and improve Python reference handling * GitHub Issue: #43487 Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- python/pyarrow/src/arrow/python/udf.cc | 149 +++++++++---------------- python/pyarrow/tests/test_udf.py | 31 +++++ 2 files changed, 85 insertions(+), 95 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index b6a862af8ca07..2c1e97c3ea03d 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -43,35 +43,18 @@ namespace py { namespace { struct PythonUdfKernelState : public compute::KernelState { + // NOTE: this KernelState constructor doesn't require the GIL. + // If it did, the corresponding KernelInit::operator() should be wrapped + // within SafeCallIntoPython (GH-43487). explicit PythonUdfKernelState(std::shared_ptr function) - : function(function) { - Py_INCREF(function->obj()); - } - - // function needs to be destroyed at process exit - // and Python may no longer be initialized. - ~PythonUdfKernelState() { - if (Py_IsFinalizing()) { - function->detach(); - } - } + : function(std::move(function)) {} std::shared_ptr function; }; struct PythonUdfKernelInit { explicit PythonUdfKernelInit(std::shared_ptr function) - : function(function) { - Py_INCREF(function->obj()); - } - - // function needs to be destroyed at process exit - // and Python may no longer be initialized. - ~PythonUdfKernelInit() { - if (Py_IsFinalizing()) { - function->detach(); - } - } + : function(std::move(function)) {} Result> operator()( compute::KernelContext*, const compute::KernelInitArgs&) { @@ -94,68 +77,56 @@ struct HashUdfAggregator : public compute::KernelState { virtual Status Finalize(KernelContext* ctx, Datum* out) = 0; }; -arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, - const compute::ExecSpan& batch) { +Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { return checked_cast(ctx->state())->Consume(ctx, batch); } -arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, - compute::KernelState* dst) { +Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, + compute::KernelState* dst) { return checked_cast(dst)->MergeFrom(ctx, std::move(src)); } -arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { +Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { return checked_cast(ctx->state())->Finalize(ctx, out); } -arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) { +Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) { return checked_cast(ctx->state())->Resize(ctx, size); } -arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) { +Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) { return checked_cast(ctx->state())->Consume(ctx, batch); } -arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, - const ArrayData& group_id_mapping) { +Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, + const ArrayData& group_id_mapping) { return checked_cast(ctx->state()) ->Merge(ctx, std::move(src), group_id_mapping); } -arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) { +Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) { return checked_cast(ctx->state())->Finalize(ctx, out); } struct PythonTableUdfKernelInit { PythonTableUdfKernelInit(std::shared_ptr function_maker, UdfWrapperCallback cb) - : function_maker(function_maker), cb(cb) { - Py_INCREF(function_maker->obj()); - } - - // function needs to be destroyed at process exit - // and Python may no longer be initialized. - ~PythonTableUdfKernelInit() { - if (Py_IsFinalizing()) { - function_maker->detach(); - } - } + : function_maker(std::move(function_maker)), cb(std::move(cb)) {} Result> operator()( compute::KernelContext* ctx, const compute::KernelInitArgs&) { - UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; - std::unique_ptr function; - RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { - OwnedRef empty_tuple(PyTuple_New(0)); - function = std::make_unique( - cb(function_maker->obj(), udf_context, empty_tuple.obj())); - RETURN_NOT_OK(CheckPyError()); - return Status::OK(); - })); - if (!PyCallable_Check(function->obj())) { - return Status::TypeError("Expected a callable Python object."); - } - return std::make_unique(std::move(function)); + return SafeCallIntoPython( + [this, ctx]() -> Result> { + UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; + OwnedRef empty_tuple(PyTuple_New(0)); + auto function = std::make_shared( + cb(function_maker->obj(), udf_context, empty_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + if (!PyCallable_Check(function->obj())) { + return Status::TypeError("Expected a callable Python object."); + } + return std::make_unique(std::move(function)); + }); } std::shared_ptr function_maker; @@ -167,8 +138,9 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { UdfWrapperCallback cb, std::vector> input_types, std::shared_ptr output_type) - : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { - Py_INCREF(function->obj()); + : function(std::move(function)), + cb(std::move(cb)), + output_type(std::move(output_type)) { std::vector> fields; for (size_t i = 0; i < input_types.size(); i++) { fields.push_back(field("", input_types[i])); @@ -176,12 +148,6 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { input_schema = schema(std::move(fields)); }; - ~PythonUdfScalarAggregatorImpl() override { - if (Py_IsFinalizing()) { - function->detach(); - } - } - Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { ARROW_ASSIGN_OR_RAISE( auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); @@ -263,8 +229,9 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { UdfWrapperCallback cb, std::vector> input_types, std::shared_ptr output_type) - : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { - Py_INCREF(function->obj()); + : function(std::move(function)), + cb(std::move(cb)), + output_type(std::move(output_type)) { std::vector> fields; fields.reserve(input_types.size()); for (size_t i = 0; i < input_types.size(); i++) { @@ -273,12 +240,6 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { input_schema = schema(std::move(fields)); }; - ~PythonUdfHashAggregatorImpl() override { - if (Py_IsFinalizing()) { - function->detach(); - } - } - // same as ApplyGrouping in partition.cc // replicated the code here to avoid complicating the dependencies static Result ApplyGroupings( @@ -416,10 +377,10 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { struct PythonUdf : public PythonUdfKernelState { PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, std::vector input_types, compute::OutputType output_type) - : PythonUdfKernelState(function), - cb(cb), - input_types(input_types), - output_type(output_type) {} + : PythonUdfKernelState(std::move(function)), + cb(std::move(cb)), + input_types(std::move(input_types)), + output_type(std::move(output_type)) {} UdfWrapperCallback cb; std::vector input_types; @@ -440,7 +401,7 @@ struct PythonUdf : public PythonUdfKernelState { Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch, compute::ExecResult* out) { auto state = arrow::internal::checked_cast(ctx->state()); - std::shared_ptr& function = state->function; + PyObject* function = state->function->obj(); const int num_args = batch.num_values(); UdfContext udf_context{ctx->memory_pool(), batch.length}; @@ -458,7 +419,7 @@ struct PythonUdf : public PythonUdfKernelState { } } - OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); + OwnedRef result(cb(function, udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_array(result.obj())) { @@ -497,12 +458,13 @@ Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init, } auto scalar_func = std::make_shared(options.func_name, options.arity, options.func_doc); - Py_INCREF(function); std::vector input_types; for (const auto& in_dtype : options.input_types) { input_types.emplace_back(in_dtype); } compute::OutputType output_type(options.output_type); + // Take reference before wrapping with OwnedRefNoGIL + Py_INCREF(function); auto udf_data = std::make_shared( std::make_shared(function), cb, TypeHolder::FromTypes(options.input_types), options.output_type); @@ -565,11 +527,6 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb registry = compute::GetFunctionRegistry(); } - // Py_INCREF here so that once a function is registered - // its refcount gets increased by 1 and doesn't get gced - // if all existing refs are gone - Py_INCREF(function); - static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); auto aggregate_func = std::make_shared( @@ -582,12 +539,16 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb } compute::OutputType output_type(options.output_type); - compute::KernelInit init = [cb, function, options](compute::KernelContext* ctx, - const compute::KernelInitArgs& args) + // Take reference before wrapping with OwnedRefNoGIL + Py_INCREF(function); + auto function_ref = std::make_shared(function); + + compute::KernelInit init = [cb, function_ref, options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result> { return std::make_unique( - std::make_shared(function), cb, options.input_types, - options.output_type); + function_ref, cb, options.input_types, options.output_type); }; auto sig = compute::KernelSignature::Make( @@ -638,10 +599,6 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb, registry = compute::GetFunctionRegistry(); } - // Py_INCREF here so that once a function is registered - // its refcount gets increased by 1 and doesn't get gced - // if all existing refs are gone - Py_INCREF(function); UdfOptions hash_options = AdjustForHashAggregate(options); std::vector input_types; @@ -656,13 +613,15 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb, hash_options.func_name, hash_options.arity, hash_options.func_doc, &default_hash_aggregate_options); - compute::KernelInit init = [function, cb, hash_options]( + // Take reference before wrapping with OwnedRefNoGIL + Py_INCREF(function); + auto function_ref = std::make_shared(function); + compute::KernelInit init = [function_ref, cb, hash_options]( compute::KernelContext* ctx, const compute::KernelInitArgs& args) -> Result> { return std::make_unique( - std::make_shared(function), cb, hash_options.input_types, - hash_options.output_type); + function_ref, cb, hash_options.input_types, hash_options.output_type); }; auto sig = compute::KernelSignature::Make( diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index c8e376fefb3b8..22fefbbb58ba9 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -219,6 +219,31 @@ def nullary_func(context): return nullary_func, func_name +@pytest.fixture(scope="session") +def ephemeral_nullary_func_fixture(): + """ + Register a nullary scalar function with an ephemeral Python function. + This stresses that the Python function object is properly kept alive by the + registered function. + """ + def nullary_func(context): + return pa.array([42] * context.batch_length, type=pa.int64(), + memory_pool=context.memory_pool) + + func_doc = { + "summary": "random function", + "description": "generates a random value" + } + func_name = "test_ephemeral_nullary_func" + pc.register_scalar_function(nullary_func, + func_name, + func_doc, + {}, + pa.int64()) + + return func_name + + @pytest.fixture(scope="session") def wrong_output_type_func_fixture(): """ @@ -505,6 +530,12 @@ def test_nullary_function(nullary_func_fixture): batch_length=1) +def test_ephemeral_function(ephemeral_nullary_func_fixture): + name = ephemeral_nullary_func_fixture + result = pc.call_function(name, [], length=1) + assert result.to_pylist() == [42] + + def test_wrong_output_type(wrong_output_type_func_fixture): _, func_name = wrong_output_type_func_fixture