Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/instance udfs #890

Merged
merged 9 commits into from
Oct 4, 2024
Merged

Conversation

timsaucer
Copy link
Contributor

Which issue does this PR close?

Closes #822

Rationale for this change

With this change we add an additional argument to udaf and udwf to be used for calls to __init__ of the class that implements the aggregation or partition evaluator. This allows users to pass in these arguments so we can reuse a single class that requires parameters.

The other option is to always pass in these parameters as lit() values which is not as performant.

What changes are included in this PR?

Updates the udaf call to add in the provided arguments when initializing the classes. Also adds unit tests.

Are there any user-facing changes?

There is a non-breaking change to the udaf function that allows for an optional set of arguments to be passed in. If any users are instead calling AggregateUDF() directly they will need to add in an empty list as the extra argument.

There is a breaking change to udwf but since it is not in any released version yet, it shouldn't count as a user-facing change in my opinion.

@timsaucer timsaucer self-assigned this Oct 2, 2024
@Michael-J-Ward
Copy link
Contributor

Instead of saving the __init__ arguments and calling them ourselves, wouldn't it be more flexible and less work for datafusion-python to allow the user to provide a factory function?

Modulo runtime assertions, this works on main currently:

def test_udaf_aggregate_with_arguments(df):
    bias = 10.0

    def factory() -> Accumulator:
        return Summarize(bias)

    summarize = udaf(
        factory,
        pa.float64(),
        pa.float64(),
        [pa.float64()],
        volatility="immutable",
        # arguments=[bias],
    )

    df1 = df.aggregate([], [summarize(column("a"))])

    # execute and collect the first (and only) batch
    result = df1.collect()[0]

    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])

@timsaucer
Copy link
Contributor Author

This is a very good idea. I'll change to your approach.

Copy link
Contributor

@Michael-J-Ward Michael-J-Ward left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@Michael-J-Ward
Copy link
Contributor

Ah, shoot. Sorry @timsaucer, I created a conflict when I merged the release-testing PR. This needs a rebase.

@timsaucer
Copy link
Contributor Author

No worries at all. I'll wait for #892 CI to finish, merge that, then rebase this.

@timsaucer timsaucer force-pushed the feature/instance-udfs branch from 9430084 to 250baea Compare October 4, 2024 15:31
@timsaucer timsaucer merged commit 1fd3762 into apache:main Oct 4, 2024
15 checks passed
@timsaucer timsaucer deleted the feature/instance-udfs branch October 4, 2024 16:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Enhance udf to take additional non-expr arguments
2 participants