feat(python): Support pytorch Tensor and Dataset export with new `to_torch` DataFrame/Series method #15931
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##             main   #15931      +/-   ##
==========================================
+ Coverage   81.26%   81.31%   +0.04%
==========================================
  Files        1381     1382       +1
  Lines      176636   176953     +317
  Branches     3034     3056      +22
==========================================
+ Hits       143549   143882     +333
+ Misses      32606    32589      -17
- Partials      481      482       +1
Good. I think we should keep that for third-party exports, as it feels very much out of our control and not part of our core. Otherwise I'd be much more hesitant. On the testing, I'd like to have a new pytest mark for larger third-party integrations. That way we can test this in CI without being bothered by the extra testing overhead when we are developing locally on the core. I don't anticipate any of this breaking when I do large refactors.
Yup. While it's a large commit, the functionality itself is actually straightforward. Unless PyTorch undergoes some fairly radical restructuring we won't have any trouble. Definitely should leave it 'unstable' for now, but we can look to soften/drop that message once we have a decent number of releases without incident/modification ✌️
Done! Added a new Also, this started out in |
Thanks Alex. Looks good. Especially given the "unstable", let's give it a spin!
This adds streamlined DataFrame (and Series) export to PyTorch `Tensor` and `Dataset`. (Note: a `torch.IterableDataset` [1] option would likely be useful for particularly large DataFrames. Needs some experimentation, so will be left to a subsequent PR 🤔).

Features
Supported PyTorch export types:
- `df.to_torch()`: export the entire frame to a single 2D Tensor (equivalent to `df.to_torch("tensor")`).
- `df.to_torch("dict")`: export frame to a dictionary of Tensors.
- `df.to_torch("dataset")`: export frame to a `PolarsDataset` (inheriting from `TensorDataset`, but additionally offering clean frame integration and a handful of other niceties; can also be imported independently from `polars.ml.torch`).

The `PolarsDataset` object:
- Inherits from `TensorDataset` and, once initialised, is drop-in compatible with it.
- Supports selectors [2].
- Has a `.half()` method for experimenting with `float16` data early.
- Has a `schema` attribute, showing the current feature/label dtypes.
- Has `features` and `labels` attributes (in addition to `tensors`).
- Has a `repr`, for example: `<PolarsDataset [len:20640, features:8, labels:1] at 0x3301B1EE0>`.

Examples
As 2D Tensor:
As dict of Tensors:
Demonstrate `PolarsDataset` usage with some `scikit-learn` data:
Establish a DataFrame from the sklearn datasets...
...trivially export a float32 Dataset with features/labels...
...and pass to a DataLoader:
Follow-up
Feedback on this one is welcome! It has been marked as unstable (though likely only for a limited time) to allow for quick iteration/tweaks if necessary.
Likely upcoming additions:
- A `PolarsIterableDataset` could be useful for constraining peak memory usage (e.g. don't materialise all frame data to Tensor format up-front).
- A `sequence_length` parameter for Dataset export has been suggested, which could be helpful for transformer use-cases (vs linear regression).

Note
The associated unit tests provide 100% line coverage of the new code 🎯
Footnotes
1. https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
2. https://docs.pola.rs/py-polars/html/reference/selectors.html