diff --git a/tests/test_base.py b/tests/test_base.py index 62fd702..d558cfc 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -6,7 +6,6 @@ from types import MappingProxyType from typing import Any -import jax.numpy as jnp import pytest import quaxed.array_api as xp @@ -161,7 +160,7 @@ def test_asdict(self, vector): # Test with a different dict_factory adict = vector.asdict(dict_factory=UserDict) assert isinstance(adict, UserDict) - assert all(qnp.array_equal(v, getattr(vector, k)) for k, v in adict.items()) + assert all(jnp.array_equal(v, getattr(vector, k)) for k, v in adict.items()) def test_components(self, vector): """Test :meth:`AbstractVector.components`."""