-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(python): Add to_jax
methods to support Jax Array export from DataFrame
and Series
#16294
Conversation
bf5ed60
to
23ec6da
Compare
23ec6da
to
179319b
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #16294 +/- ##
==========================================
- Coverage 81.35% 81.34% -0.01%
==========================================
Files 1403 1403
Lines 183463 183515 +52
Branches 2929 2946 +17
==========================================
+ Hits 149253 149288 +35
- Misses 33707 33723 +16
- Partials 503 504 +1 ☔ View full report in Codecov by Sentry. |
92bfe04
to
5cf3de6
Compare
to_jax
methods to support export to Jax arrays from DataFrame
and Series
to_jax
methods to support Jax Array export from DataFrame
and Series
to_jax
methods to support Jax Array export from DataFrame
and Series
to_jax
methods to support Jax Array export from DataFrame
and Series
5c182c3
to
4f14dbe
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Personally I am not familiar with jax, but if you're enthousiastic about it that's good enough reason for me to accept an integration.
I left some minor comments, but overall looks great as usual 👍
… from `DataFrame`
e9b2043
to
a68782a
Compare
CodSpeed Performance ReportMerging #16294 will degrade performances by 26.57%Comparing Summary
Benchmarks breakdown
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what CodSpeed is on about, but I don't see any reason why this PR should cause a regression.
Indeed! And it was fine before adding the extra docstring info; must be having some transient issues on their side 🤷♂️ |
Just 2c, I happened to need this + find it literally right now. An immense thank you from me! |
Nice to hear; feedback welcome :) |
Continuing the theme of streamlining ML preprocessing tasks using Polars (following the previous PR offering
torch
integration), this PR adds support forjax
1 array export from DataFrame and/or Series.Features
Supported Jax export modes:
df.to_jax()
: export the entire frame to a single Array.df.to_jax("dict")
: export frame to a dictionary of column Arrays.df.to_jax("dict", label=…, features=…)
: export frame to a dictionary of label/features Arrays (this mode is also now available forto_torch
).Additional options:
Can create arrays on a specific device2 (eg: CPU, GPU, TPU, METAL3, etc).
Can specify the memory format (eg: "c" or "fortran" order).
df.to_jax(device="cpu", order="c")
df.to_jax(device=jax.devices("gpu")[1])
Examples
As Array:
As dict of Arrays:
As dict of label/features Arrays:
(note: if features are not specified they are implied as being "everything except the label")Notes
As with the torch PR, Jax support is designated "CI-only" for unit tests and requirements, and you'll need to use
make requirements-all
if you want Polars to install the related libraries in your local development environment. The doctests are similarly gated by the presence of thejax
library; they will run if you have it (or are executing on CI), and are omitted otherwise.Footnotes
Jax: https://jax.readthedocs.io/en/latest/index.html ↩
Placement on devices: https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices ↩
Accelerated JAX training on Mac: https://developer.apple.com/metal/jax/ ↩