-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add JAX device selection when formatting #5547
Conversation
The code below was throwing a warning: class JaxFormatter(Formatter[Mapping, "jax.Array", Mapping]):
def __init__(self, features=None, device=None, **jnp_array_kwargs):
super().__init__(features=features)
import jax
from jaxlib.xla_extension import Device
self.device = (
device if isinstance(device, Device) else jax.devices()[0]
)
self.jnp_array_kwargs = jnp_array_kwargs
...
def _tensorize(self, value):
...
with jax.default_device(self.device):
# calling jnp.array on a np.ndarray does copy the data
# see https://github.com/google/jax/issues/4486
return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs}) When providing from datasets import Dataset
import jax
ds = Dataset.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
ds = ds.with_format("jax", device=jax.devices()[0])
print(ds[0]) Producing the following warning:
That's why I decided to map all the available devices, and assign their string representation e.g. class JaxFormatter(Formatter[Mapping, "jax.Array", Mapping]):
def __init__(self, features=None, device=None, **jnp_array_kwargs):
super().__init__(features=features)
import jax
from jaxlib.xla_client import Device
self.device_mapping = self._map_devices_to_str()
self.device = (
device if isinstance(device, str) else str(device) if isinstance(device, Device) else str(jax.devices()[0])
)
self.jnp_array_kwargs = jnp_array_kwargs
def _map_devices_to_str(self) -> Mapping[str, "jaxlib.xla_extension.Device"]:
import jax
return {str(device): device for device in jax.devices()}
...
def _tensorize(self, value):
...
with jax.default_device(self.device_mapping[self.device]):
# calling jnp.array on a np.ndarray does copy the data
# see https://github.com/google/jax/issues/4486
return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs}) But note that the latter also throws a warning if the provided |
The documentation is not available anymore as the PR was closed or merged. |
After some investigation, it seems that when using So I think we can either add that specifically in Do you think there's any other possible solution to this issue? Thanks, @lhoestq |
Cool ! Specifying the device is indeed super important. I think we can just require |
Sure, then I'll restrict it to string for now! Also regarding the documentation update, should we wait until #5535 is merged so that I add this on top of that? |
For some reason it was throwing a `TypeError` as explained at jax-ml/jax#4867, but suddenly it disappeared
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
@lhoestq already moved to a global variable, I can confirm that the following now works: import copy
import pickle
import jax
import pyarrow as pa
from datasets.formatting import JaxFormatter
_COL_A = [0, 1, 2]
_COL_B = ["foo", "bar", "foobar"]
_COL_C = [[[1.0, 0.0, 0.0]] * 2, [[0.0, 1.0, 0.0]] * 2, [[0.0, 0.0, 1.0]] * 2]
pa_table = pa.Table.from_pydict({"a": _COL_A, "b": _COL_B, "c": _COL_C})
device = jax.devices()[0]
formatter = JaxFormatter(device=str(device))
pickle.dumps(formatter)
copy.deepcopy(formatter) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks all good now thank you !
Is there anything else you wanted to add ? Otherwise I think it's ready for merge
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Nothing else to add, I've already applied your suggestions, so ready to merge! Thanks for your input/feedback @lhoestq 🤗 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's merge then :)
Show benchmarksPyArrow==6.0.0 Show updated benchmarks!Benchmark: benchmark_array_xd.json
Benchmark: benchmark_getitem_100B.json
Benchmark: benchmark_indices_mapping.json
Benchmark: benchmark_iterating.json
Benchmark: benchmark_map_filter.json
Show updated benchmarks!Benchmark: benchmark_array_xd.json
Benchmark: benchmark_getitem_100B.json
Benchmark: benchmark_indices_mapping.json
Benchmark: benchmark_iterating.json
Benchmark: benchmark_map_filter.json
|
What's in this PR?
After exploring for a while the JAX integration in 🤗
datasets
, I found out that, even though JAX prioritizes the TPU and GPU as the default device when available, theJaxFormatter
doesn't let you specify the device where you want to place thejax.Array
s in case you don't want to rely on JAX's default array placement.So on, I've included the
device
param inJaxFormatter
but there are some things to take into consideration:Dataset
is copied withcopy.deepcopy
which means that if one adds the paramdevice
inJaxFormatter
as ajaxlib.xla_extension.Device
, it "fails" because that object cannot be serialized (instead of serializing the param adds a random hash instead). That's the reason why I added a function_map_devices_to_str
to basically create a mapping of strings tojaxlib.xla_extension.Device
s so thatself.device
is a string and not ajaxlib.xla_extension.Device
.jax.Array
in a device you need to either create it in the default device and then move it to the desired device withjax.device_put
or directly create it in the device you want withjax.default_device()
context manager.jax.devices()[0]
More information on JAX device management is available at https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
What's missing in this PR?
I've tested it both locally in CPU (Mac M2 and Mac M1, as no GPU support for Mac yet), and in GPU and TPU in Google Colab, let me know if you want me to provide you the Notebook for the latter.
But I did not implement any integration test as I wanted to get your feedback first.