Add "multi device" support #59

Open
wants to merge 17 commits into
base: main

betatim
Member

@betatim betatim commented Sep 4, 2024

Having more than one device is useful during testing to allow you to find bugs related to how arrays on different devices are handled. Closes #56

With scikit-learn we run into the frustrating situation where contributors run the tests locally and they all pass, but then see failures on the CI because, e.g., PyTorch has several devices and some things work on the CPU device but not on the CUDA/MPS device. If you have neither of those on your local machine, you can't test this upfront, and you have to rely on the CI to debug it.

The idea of this PR is to add support for multiple devices to array-api-strict to make testing easier. The default device continues to be the CPU device, and for arrays that use it nothing should change. However, you can now place an array on a different device with array_api_strict.Device("pony") (or some other string; each string is a new device). For arrays on a device that isn't the CPU device, calls like np.asarray(some_strict_array) will raise an error. This mirrors how PyTorch treats arrays on the CPU and MPS devices.
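A rough sketch of the behaviour described above, using toy classes rather than array-api-strict itself. The names ToyArray and to_host are hypothetical and only illustrate the idea; in a real array type the check would sit in the hook that np.asarray() calls (__array__).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Toy device: any string names a device; equal names compare equal."""
    name: str = "CPU_DEVICE"


CPU_DEVICE = Device()


class ToyArray:
    """Stand-in for a strict array: a list of values tagged with a device."""

    def __init__(self, data, device=CPU_DEVICE):
        self._data = list(data)
        self.device = device

    def to_host(self):
        # Hypothetical method standing in for the __array__ hook.
        # Only arrays on the CPU device may be converted implicitly.
        if self.device != CPU_DEVICE:
            raise RuntimeError(
                f"cannot convert an array on {self.device} to a host array"
            )
        return list(self._data)


x = ToyArray([1, 2, 3])                          # default: CPU device
y = ToyArray([1, 2, 3], device=Device("pony"))   # some other device

print(x.to_host())
try:
    y.to_host()
except RuntimeError as exc:
    print("refused:", exc)
```

Code that only ever touches the default device behaves as before; only arrays explicitly placed on another device hit the error.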

What isn't yet implemented in this PR is raising an error if you try to operate on arrays that are not on the same device.

I wanted to open this PR now, after only a short amount of effort, to get feedback on what people think before putting in the time to update all the tests, etc.

@lucascolley
Contributor

+1! Another question is what the default should be (technically Device("pony") is more strict), but it is probably better to keep the CPU default for backwards compatibility.

@betatim
Member Author

betatim commented Sep 6, 2024

I think the CPU device should be the default. That way code that exists today should keep working and the only people who notice any changes are those who use the pony device.

@asmeurer
Member

asmeurer commented Sep 7, 2024

This looks good so far. We need to make sure the semantics specified at https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics are followed, namely, disallowing combining arrays from different devices, and making sure that if a function creates a new array based on an existing array that it uses the same device.
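The two semantics mentioned here can be sketched with a toy array type. This is a minimal illustration, not array-api-strict's code: Arr, _check_same_device and the choice of ValueError are assumptions made for the example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Arr:
    """Toy array: a tuple of values plus the name of its device."""
    data: tuple
    device: str = "CPU_DEVICE"


def _check_same_device(*arrays):
    # Refuse to combine arrays that live on different devices.
    devices = {a.device for a in arrays}
    if len(devices) > 1:
        raise ValueError(f"arrays are on different devices: {sorted(devices)}")
    return devices.pop()


def add(x1, x2, /):
    device = _check_same_device(x1, x2)
    # The result stays on the device shared by the inputs.
    return Arr(tuple(a + b for a, b in zip(x1.data, x2.data)), device)


def zeros_like(x, /, *, device=None):
    # Without an explicit device, follow the device of the input array.
    target = x.device if device is None else device
    return Arr((0,) * len(x.data), target)


a = Arr((1, 2), device="device1")
b = Arr((3, 4), device="device1")
c = Arr((5, 6))  # CPU_DEVICE

print(add(a, b))
print(zeros_like(a))
try:
    add(a, c)
except ValueError as exc:
    print("refused:", exc)
```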

For tests, ideally this would be tested in array-api-tests, but right now device support is not tested at all there. If you just want to add some basic tests here for now, that is fine.

Finally, there is the devices inspection API: https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.devices.html#array_api.info.devices We need to think about how that will work. One option would just be to create a small but fixed number of devices. Or we could add some flags to make it configurable: https://data-apis.org/array-api-strict/api.html#array_api_strict.set_array_api_strict_flags

@betatim
Member Author

betatim commented Sep 27, 2024

I've rebooted my work. The long pause is because I went on holiday :D

I think we have to limit ourselves to a fixed number of devices, otherwise we can't fulfill the requirement that the info extension can provide a list of devices. So now you can use the CPU_DEVICE and the (creatively named) device1 and device2.
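Why a fixed set helps can be seen in a small sketch. The device names come from the comment above; the class layout, the Info class and the error type are assumptions for illustration and not necessarily what the PR does.

```python
ALL_DEVICE_NAMES = ("CPU_DEVICE", "device1", "device2")


class Device:
    """Toy device that only accepts names from a fixed, known set."""

    def __init__(self, name="CPU_DEVICE"):
        if name not in ALL_DEVICE_NAMES:
            raise ValueError(f"unknown device {name!r}")
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Device) and self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return f"Device({self.name!r})"


class Info:
    """Toy version of the object behind the inspection API."""

    def default_device(self):
        return Device()

    def devices(self):
        # Because the set is closed, the full list can always be enumerated.
        return [Device(name) for name in ALL_DEVICE_NAMES]


info = Info()
print(info.default_device())
print(info.devices())
```

With arbitrary strings as device names, devices() would have no complete list to return.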

Slowly making progress towards the creation functions and "array combination" functions respecting the device

@betatim
Member Author

betatim commented Sep 27, 2024

Do you use ruff or black or something like that for formatting?

@betatim
Member Author

betatim commented Sep 27, 2024

It looks like it would be quite tricky to add device testing to array-api-tests, at least at my level of knowledge (I'm unfamiliar with hypothesis and array-api-tests). It looks like you'd add a helper, maybe all_devices, similar to all_dtypes, and then use it in the @given decorator. The tricky part is that how you specify a device depends on the library; with a new version of the standard you could use the inspection API to get all devices. So, for now, I might add some basic testing here.

@asmeurer
Member

Do you use ruff or black or something like that for formatting?

There's no autoformatting on this repo.

It looks like it would be quite tricky to add device testing to array-api-tests, at least at my level of knowledge (I'm unfamiliar with hypothesis and array-api-tests). It looks like you'd add a helper, maybe all_devices, similar to all_dtypes, and then use it in the @given decorator. The tricky part is that how you specify a device depends on the library; with a new version of the standard you could use the inspection API to get all devices. So, for now, I might add some basic testing here.

I think it would have to use the devices() function in the inspection API. That would mean that the tests would only work against the newest version of the standard and it would only work against the compat library, but I think that's fine. You'd also probably want to make it optional.

It's also possible to do some basic testing using only the default device, for example checking that x.device and the device= argument are consistent.

The annoying thing for the test suite is making sure every function everywhere is passing device through properly so that everything gets created on the same device. It would also probably require some upstream fixes to the hypothesis array-api support.
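A sketch of what an optional all_devices helper might look like. The helper name comes from the discussion above; its behaviour, the fallback to None and the fake namespaces used for the demo are assumptions. Only __array_namespace_info__() and its devices() method are taken from the standard's inspection API.

```python
from types import SimpleNamespace


def all_devices(xp):
    """Return the devices to test against for the array namespace ``xp``.

    Falls back to [None], which creation functions treat as "use the
    default device", when the namespace has no inspection API.
    """
    get_info = getattr(xp, "__array_namespace_info__", None)
    if get_info is None:
        return [None]
    return list(get_info().devices())


# Fake namespaces standing in for real libraries, for demonstration only.
fake_info = SimpleNamespace(devices=lambda: ["CPU_DEVICE", "device1", "device2"])
new_xp = SimpleNamespace(__array_namespace_info__=lambda: fake_info)
old_xp = SimpleNamespace()  # predates the inspection API

print(all_devices(new_xp))
print(all_devices(old_xp))
```

The returned list could then feed a parametrized test or a hypothesis strategy, so that the device checks are skipped naturally on namespaces that cannot enumerate their devices.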

@asmeurer
Member

I think what we need here are just some big parameterized tests combining basic example arrays with different devices across all the APIs. For instance, there's an existing test that checks type promotion and the "no mixing devices" test could look very similar to that.



-def logaddexp(x1: Array, x2: Array) -> Array:
+def logaddexp(x1: Array, x2: Array, /) -> Array:
Member Author

I think this was missing/typo.

@@ -19,8 +19,16 @@

import pytest

import array_api_strict


def nargs(func):
Member Author

I modified this so that it works with decorated/wrapped functions as well. len(getfullargspec(f).args) returns zero for functions that are decorated with the array API version decorator. From what I can tell from the Python docs this is kind of on purpose/to preserve existing behaviour. signature() does the right thing for wrapped functions, but it needs slightly more explicit work to count the arguments.

I think the intention/rule is that nargs() counts the number of positional-only arguments, which is basically the "number of arrays you need to pass to an elementwise function". I went with the very explicit way of counting the args partly to make it easier for people from the future to understand what nargs is meant to do (even if it contains a bug and doesn't actually do what it is meant to do).
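The difference between the two inspection functions can be reproduced with a small self-contained example. This is a sketch of the idea, not the code in the PR: requires_version is a hypothetical stand-in for the version decorator, and the function bodies are placeholders.

```python
import functools
import inspect


def requires_version(func):
    """Hypothetical stand-in for a decorator that wraps API functions."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def nargs(func):
    """Count the positional-only parameters of ``func``.

    inspect.signature() follows ``__wrapped__``, so decorated functions
    report the signature of the function they wrap.
    """
    params = inspect.signature(func).parameters.values()
    return sum(1 for p in params if p.kind == inspect.Parameter.POSITIONAL_ONLY)


@requires_version
def logaddexp(x1, x2, /):
    return x1 + x2  # placeholder body


@requires_version
def clip(x, /, min=None, max=None):
    return x  # placeholder body


# getfullargspec() sees only the wrapper's (*args, **kwargs) signature.
print(len(inspect.getfullargspec(logaddexp).args))
print(nargs(logaddexp), nargs(clip))
```

Counting only POSITIONAL_ONLY parameters is also what makes a missing trailing / visible: without it, the array arguments would be positional-or-keyword and would not be counted.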

@@ -91,12 +99,57 @@ def nargs(func):
"trunc": "real numeric",
}


def test_nargs():
Member Author

A short test to make sure nargs works, but also that all of the functions that we look at have "the right signature". I found that logaddexp was missing the trailing / when working on nargs, so it seems useful to have an "all functions have a reasonable number of arguments" test.

Member

That's good. The array-api test suite doesn't check for positional-only, though it probably could now that it uses inspect.signature.

Successfully merging this pull request may close these issues.

Add virtual devices to make it easier for array API consumer to check that they use device correctly
3 participants