Uplift mlir and add shardy to dependencies #347
GitHub Actions / TT-XLA Tests
failed
Jan 28, 2025 in 0s
447 tests run, 383 passed, 60 skipped, 4 failed.
Annotations
Check failure on line 40 in tests/jax/ops/test_convert.py
github-actions / TT-XLA Tests
test_convert.test_convert[uint32-bfloat16]
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
from_dtype = 'bfloat16', to_dtype = 'uint32'
@pytest.mark.parametrize(
"from_dtype",
[
"bfloat16",
"float32",
],
)
@pytest.mark.parametrize(
"to_dtype",
[
"uint32",
"uint64",
"int32",
"int64",
"bfloat16",
"float32",
"float64",
],
)
def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
def convert(x: jax.Array) -> jax.Array:
return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
x_shape = (32, 32) # Shape does not make any impact here, thus not parametrized.
input = random_tensor(x_shape, dtype=from_dtype)
> run_op_test(convert, [input])
tests/jax/ops/test_convert.py:40:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/infra/op_tester.py:69: in run_op_test
tester.test(workload)
tests/infra/op_tester.py:35: in test
tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788ac1f0>>, args=[Array([[0.1015... [0.265625, 0.492188, 0.78125, ..., 0.75, 0.429688, 0.820312]], dtype=bfloat16)], kwargs={}, static_argnames=[])
def execute(self) -> Any:
"""Calls callable passing stored args and kwargs directly."""
> return self.executable(*self.args, **self.kwargs)
E jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
tests/infra/workload.py:29: XlaRuntimeError
Check failure on line 40 in tests/jax/ops/test_convert.py
github-actions / TT-XLA Tests
test_convert.test_convert[uint32-float32]
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
from_dtype = 'float32', to_dtype = 'uint32'
@pytest.mark.parametrize(
"from_dtype",
[
"bfloat16",
"float32",
],
)
@pytest.mark.parametrize(
"to_dtype",
[
"uint32",
"uint64",
"int32",
"int64",
"bfloat16",
"float32",
"float64",
],
)
def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
def convert(x: jax.Array) -> jax.Array:
return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
x_shape = (32, 32) # Shape does not make any impact here, thus not parametrized.
input = random_tensor(x_shape, dtype=from_dtype)
> run_op_test(convert, [input])
tests/jax/ops/test_convert.py:40:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/infra/op_tester.py:69: in run_op_test
tester.test(workload)
tests/infra/op_tester.py:35: in test
tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788ae440>>, args=[Array([[0.0947...2062037, 0.7978035 , ..., 0.9188044 , 0.5129981 ,
0.53420603]], dtype=float32)], kwargs={}, static_argnames=[])
def execute(self) -> Any:
"""Calls callable passing stored args and kwargs directly."""
> return self.executable(*self.args, **self.kwargs)
E jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
tests/infra/workload.py:29: XlaRuntimeError
Check failure on line 40 in tests/jax/ops/test_convert.py
github-actions / TT-XLA Tests
test_convert.test_convert[uint64-bfloat16]
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
from_dtype = 'bfloat16', to_dtype = 'uint64'
@pytest.mark.parametrize(
"from_dtype",
[
"bfloat16",
"float32",
],
)
@pytest.mark.parametrize(
"to_dtype",
[
"uint32",
"uint64",
"int32",
"int64",
"bfloat16",
"float32",
"float64",
],
)
def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
def convert(x: jax.Array) -> jax.Array:
return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
x_shape = (32, 32) # Shape does not make any impact here, thus not parametrized.
input = random_tensor(x_shape, dtype=from_dtype)
> run_op_test(convert, [input])
tests/jax/ops/test_convert.py:40:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/infra/op_tester.py:69: in run_op_test
tester.test(workload)
tests/infra/op_tester.py:35: in test
tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788ad750>>, args=[Array([[0.1015... [0.265625, 0.492188, 0.78125, ..., 0.75, 0.429688, 0.820312]], dtype=bfloat16)], kwargs={}, static_argnames=[])
def execute(self) -> Any:
"""Calls callable passing stored args and kwargs directly."""
> return self.executable(*self.args, **self.kwargs)
E jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
tests/infra/workload.py:29: XlaRuntimeError
Check failure on line 40 in tests/jax/ops/test_convert.py
github-actions / TT-XLA Tests
test_convert.test_convert[uint64-float32]
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
from_dtype = 'float32', to_dtype = 'uint64'
@pytest.mark.parametrize(
"from_dtype",
[
"bfloat16",
"float32",
],
)
@pytest.mark.parametrize(
"to_dtype",
[
"uint32",
"uint64",
"int32",
"int64",
"bfloat16",
"float32",
"float64",
],
)
def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
def convert(x: jax.Array) -> jax.Array:
return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
x_shape = (32, 32) # Shape does not make any impact here, thus not parametrized.
input = random_tensor(x_shape, dtype=from_dtype)
> run_op_test(convert, [input])
tests/jax/ops/test_convert.py:40:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/infra/op_tester.py:69: in run_op_test
tester.test(workload)
tests/infra/op_tester.py:35: in test
tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788aeef0>>, args=[Array([[0.0947...2062037, 0.7978035 , ..., 0.9188044 , 0.5129981 ,
0.53420603]], dtype=float32)], kwargs={}, static_argnames=[])
def execute(self) -> Any:
"""Calls callable passing stored args and kwargs directly."""
> return self.executable(*self.args, **self.kwargs)
E jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
tests/infra/workload.py:29: XlaRuntimeError