
Add JAX device selection when formatting #5547

Merged
merged 15 commits into huggingface:main on Feb 21, 2023
Conversation

alvarobartt
Member

@alvarobartt alvarobartt commented Feb 18, 2023

What's in this PR?

After exploring the JAX integration in 🤗 Datasets for a while, I found that, even though JAX prioritizes the TPU and GPU as the default device when available, the JaxFormatter doesn't let you specify the device where you want to place the jax.Arrays in case you don't want to rely on JAX's default array placement.

So I've included the device param in JaxFormatter, but there are some things to take into consideration:

  • A formatted Dataset is copied with copy.deepcopy, which means that if one passes the device param to JaxFormatter as a jaxlib.xla_extension.Device, it "fails" because that object cannot be serialized (instead of serializing the param, a random hash is used). That's why I added a _map_devices_to_str function that builds a mapping from strings to jaxlib.xla_extension.Devices, so that self.device is a string and not a jaxlib.xla_extension.Device.
  • To create a jax.Array on a specific device, you need to either create it on the default device and then move it to the desired device with jax.device_put, or create it directly on the target device within the jax.default_device() context manager.
  • By default, JAX creates arrays on jax.devices()[0].

More information on JAX device management is available at https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
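The placement behaviors listed above can be sketched as follows (a minimal illustration assuming a working JAX install; `target` just reuses the default device here, but any entry of `jax.devices()` would do, and device names vary by backend):

```python
import jax
import jax.numpy as jnp

# JAX creates arrays on jax.devices()[0] by default
x = jnp.array([1.0, 2.0, 3.0])

# Option 1: create on the default device, then move it explicitly
target = jax.devices()[0]  # pick any available device here
y = jax.device_put(x, target)

# Option 2: create the array directly on the target device
with jax.default_device(target):
    z = jnp.array([4.0, 5.0, 6.0])
```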

What's missing in this PR?

I've tested it locally on CPU (Mac M2 and Mac M1, as there's no GPU support for Mac yet), and on GPU and TPU in Google Colab; let me know if you want me to share the notebook for the latter.

But I did not implement any integration test as I wanted to get your feedback first.

@alvarobartt
Member Author

alvarobartt commented Feb 18, 2023

The code below was throwing a warning:

class JaxFormatter(Formatter[Mapping, "jax.Array", Mapping]):
    def __init__(self, features=None, device=None, **jnp_array_kwargs):
        super().__init__(features=features)
        import jax
        from jaxlib.xla_extension import Device
        
        self.device = (
            device if isinstance(device, Device) else jax.devices()[0]
        )
        self.jnp_array_kwargs = jnp_array_kwargs

    ...

    def _tensorize(self, value):
        ...

        with jax.default_device(self.device):
            # calling jnp.array on a np.ndarray does copy the data
            # see https://github.com/google/jax/issues/4486
            return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})

When providing device via param:

from datasets import Dataset
import jax

ds = Dataset.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
ds = ds.with_format("jax", device=jax.devices()[0])
print(ds[0])

Producing the following warning:

WARNING:datasets.fingerprint:Parameter 'device'=TFRT_CPU_0 of the transform datasets.arrow_dataset.Dataset.set_format couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.

That's why I decided to map all the available devices and assign their string representation (e.g. TFRT_CPU_0) to self.device instead of the jaxlib.xla_extension.Device, so that the value of the device param is hashable. So the code that remains at the end is:

class JaxFormatter(Formatter[Mapping, "jax.Array", Mapping]):
    def __init__(self, features=None, device=None, **jnp_array_kwargs):
        super().__init__(features=features)
        import jax
        from jaxlib.xla_client import Device

        self.device_mapping = self._map_devices_to_str()
        self.device = (
            device if isinstance(device, str) else str(device) if isinstance(device, Device) else str(jax.devices()[0])
        )
        self.jnp_array_kwargs = jnp_array_kwargs

    def _map_devices_to_str(self) -> Mapping[str, "jaxlib.xla_extension.Device"]:
        import jax

        return {str(device): device for device in jax.devices()}

    ...

    def _tensorize(self, value):
        ...

        with jax.default_device(self.device_mapping[self.device]):
            # calling jnp.array on a np.ndarray does copy the data
            # see https://github.com/google/jax/issues/4486
            return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})

But note that the latter also throws the warning if the provided device is a jaxlib.xla_extension.Device rather than a string, which is why it needs to be converted to a string.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Feb 19, 2023

The documentation is not available anymore as the PR was closed or merged.

@alvarobartt
Member Author

After some investigation, it seems that when using a jaxlib.xla_extension.Device as device instead of a string, the warning is shown and later formats fail, as that object cannot be unpickled.

So I think we can either document this specifically in the use_with_jax.mdx entry I'm creating at #5535, so that users know they need to wrap the jaxlib.xla_extension.Device with str(), or find a workaround that overrides the default deepcopy behavior with def __deepcopy__(self, memo) so that the device param is converted to a string if provided as a jaxlib.xla_extension.Device, but I'm not sure if the latter works 😕

Do you think there's any other possible solution to this issue? Thanks, @lhoestq
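For reference, the __deepcopy__ workaround mentioned above could look roughly like this (a hypothetical sketch, not what was merged; JaxFormatterSketch and its device handling are made up for illustration — note that the dunder takes a memo argument):

```python
import copy


class JaxFormatterSketch:
    """Hypothetical sketch: coerce the device to its string form on deepcopy."""

    def __init__(self, device=None):
        # device may be a string or a jaxlib.xla_extension.Device
        self.device = device

    def __deepcopy__(self, memo):
        # Build the copy by hand so a non-picklable Device object
        # is replaced with its string representation.
        new = JaxFormatterSketch(device=str(self.device))
        memo[id(self)] = new
        return new


formatter = JaxFormatterSketch(device="TFRT_CPU_0")
clone = copy.deepcopy(formatter)
print(clone.device)
```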

@lhoestq
Member

lhoestq commented Feb 19, 2023

Cool ! Specifying the device is indeed super important.

I think we can just require device to always be a string for now, and add an example in the doc on how to get the string that corresponds to a jaxlib.xla_extension.Device ? This way we never deal with objects that are not picklable
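The doc example could be as small as this (assuming a working JAX install; the exact device string, e.g. TFRT_CPU_0 on CPU, depends on the backend):

```python
import jax

device = jax.devices()[0]  # a jaxlib.xla_extension.Device object
device_str = str(device)   # the picklable/hashable string form
print(device_str)
```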

@alvarobartt
Member Author

Cool ! Specifying the device is indeed super important.

I think we can just require device to always be a string for now, and add an example in the doc on how to get the string that corresponds to a jaxlib.xla_extension.Device ? This way we never deal with objects that are not picklable

Sure, then I'll restrict it to string for now! Also regarding the documentation update, should we wait until #5535 is merged so that I add this on top of that?

@alvarobartt
Member Author

CI is failing due to the missing resampy dependency in librosa, already fixed by @lhoestq in #5554

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
@alvarobartt
Member Author

@lhoestq already moved to a global variable, I can confirm that the following now works:

import copy
import pickle

import jax
import pyarrow as pa

from datasets.formatting import JaxFormatter


_COL_A = [0, 1, 2]
_COL_B = ["foo", "bar", "foobar"]
_COL_C = [[[1.0, 0.0, 0.0]] * 2, [[0.0, 1.0, 0.0]] * 2, [[0.0, 0.0, 1.0]] * 2]
pa_table = pa.Table.from_pydict({"a": _COL_A, "b": _COL_B, "c": _COL_C})

device = jax.devices()[0]
formatter = JaxFormatter(device=str(device))

pickle.dumps(formatter)
copy.deepcopy(formatter)

Member

@lhoestq lhoestq left a comment


Looks all good now thank you !

Is there anything else you wanted to add ? Otherwise I think it's ready for merge

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
@alvarobartt
Member Author

alvarobartt commented Feb 21, 2023

Looks all good now thank you !

Is there anything else you wanted to add ? Otherwise I think it's ready for merge

Nothing else to add, I've already applied your suggestions, so ready to merge! Thanks for your input/feedback @lhoestq 🤗

Member

@lhoestq lhoestq left a comment


let's merge then :)

@lhoestq lhoestq merged commit 21c86d5 into huggingface:main Feb 21, 2023
@github-actions

Show benchmarks

Automated benchmark tables for PyArrow==6.0.0 and PyArrow==latest covering benchmark_array_xd.json, benchmark_getitem_100B.json, benchmark_indices_mapping.json, benchmark_iterating.json, and benchmark_map_filter.json (new vs. old timings per metric; tables omitted here).
