Skip to content

Uplift mlir and add shardy to dependencies #347

Uplift mlir and add shardy to dependencies

Uplift mlir and add shardy to dependencies #347

GitHub Actions / TT-XLA Tests failed Jan 28, 2025 in 0s

447 tests run, 383 passed, 60 skipped, 4 failed.

Annotations

Check failure on line 40 in tests/jax/ops/test_convert.py

See this annotation in the file changed.

@github-actions github-actions / TT-XLA Tests

test_convert.test_convert[uint32-bfloat16]

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

from_dtype = 'bfloat16', to_dtype = 'uint32'

    @pytest.mark.parametrize(
        "from_dtype",
        [
            "bfloat16",
            "float32",
        ],
    )
    @pytest.mark.parametrize(
        "to_dtype",
        [
            "uint32",
            "uint64",
            "int32",
            "int64",
            "bfloat16",
            "float32",
            "float64",
        ],
    )
    def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
        def convert(x: jax.Array) -> jax.Array:
            return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
    
        x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
        input = random_tensor(x_shape, dtype=from_dtype)
    
>       run_op_test(convert, [input])

tests/jax/ops/test_convert.py:40: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/op_tester.py:69: in run_op_test
    tester.test(workload)
tests/infra/op_tester.py:35: in test
    tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
    return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
    return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788ac1f0>>, args=[Array([[0.1015...   [0.265625, 0.492188, 0.78125, ..., 0.75, 0.429688, 0.820312]],      dtype=bfloat16)], kwargs={}, static_argnames=[])

    def execute(self) -> Any:
        """Calls callable passing stored args and kwargs directly."""
>       return self.executable(*self.args, **self.kwargs)
E       jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13

tests/infra/workload.py:29: XlaRuntimeError

Check failure on line 40 in tests/jax/ops/test_convert.py

See this annotation in the file changed.

@github-actions github-actions / TT-XLA Tests

test_convert.test_convert[uint32-float32]

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

from_dtype = 'float32', to_dtype = 'uint32'

    @pytest.mark.parametrize(
        "from_dtype",
        [
            "bfloat16",
            "float32",
        ],
    )
    @pytest.mark.parametrize(
        "to_dtype",
        [
            "uint32",
            "uint64",
            "int32",
            "int64",
            "bfloat16",
            "float32",
            "float64",
        ],
    )
    def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
        def convert(x: jax.Array) -> jax.Array:
            return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
    
        x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
        input = random_tensor(x_shape, dtype=from_dtype)
    
>       run_op_test(convert, [input])

tests/jax/ops/test_convert.py:40: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/op_tester.py:69: in run_op_test
    tester.test(workload)
tests/infra/op_tester.py:35: in test
    tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
    return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
    return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788ae440>>, args=[Array([[0.0947...2062037, 0.7978035 , ..., 0.9188044 , 0.5129981 ,
        0.53420603]], dtype=float32)], kwargs={}, static_argnames=[])

    def execute(self) -> Any:
        """Calls callable passing stored args and kwargs directly."""
>       return self.executable(*self.args, **self.kwargs)
E       jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13

tests/infra/workload.py:29: XlaRuntimeError

Check failure on line 40 in tests/jax/ops/test_convert.py

See this annotation in the file changed.

@github-actions github-actions / TT-XLA Tests

test_convert.test_convert[uint64-bfloat16]

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

from_dtype = 'bfloat16', to_dtype = 'uint64'

    @pytest.mark.parametrize(
        "from_dtype",
        [
            "bfloat16",
            "float32",
        ],
    )
    @pytest.mark.parametrize(
        "to_dtype",
        [
            "uint32",
            "uint64",
            "int32",
            "int64",
            "bfloat16",
            "float32",
            "float64",
        ],
    )
    def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
        def convert(x: jax.Array) -> jax.Array:
            return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
    
        x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
        input = random_tensor(x_shape, dtype=from_dtype)
    
>       run_op_test(convert, [input])

tests/jax/ops/test_convert.py:40: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/op_tester.py:69: in run_op_test
    tester.test(workload)
tests/infra/op_tester.py:35: in test
    tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
    return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
    return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788ad750>>, args=[Array([[0.1015...   [0.265625, 0.492188, 0.78125, ..., 0.75, 0.429688, 0.820312]],      dtype=bfloat16)], kwargs={}, static_argnames=[])

    def execute(self) -> Any:
        """Calls callable passing stored args and kwargs directly."""
>       return self.executable(*self.args, **self.kwargs)
E       jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13

tests/infra/workload.py:29: XlaRuntimeError

Check failure on line 40 in tests/jax/ops/test_convert.py

See this annotation in the file changed.

@github-actions github-actions / TT-XLA Tests

test_convert.test_convert[uint64-float32]

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13
Raw output
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

from_dtype = 'float32', to_dtype = 'uint64'

    @pytest.mark.parametrize(
        "from_dtype",
        [
            "bfloat16",
            "float32",
        ],
    )
    @pytest.mark.parametrize(
        "to_dtype",
        [
            "uint32",
            "uint64",
            "int32",
            "int64",
            "bfloat16",
            "float32",
            "float64",
        ],
    )
    def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
        def convert(x: jax.Array) -> jax.Array:
            return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
    
        x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
        input = random_tensor(x_shape, dtype=from_dtype)
    
>       run_op_test(convert, [input])

tests/jax/ops/test_convert.py:40: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/op_tester.py:69: in run_op_test
    tester.test(workload)
tests/infra/op_tester.py:35: in test
    tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
tests/infra/device_runner.py:24: in run_on_tt_device
    return DeviceRunner._run_on_device(DeviceType.TT, workload)
tests/infra/device_runner.py:73: in _run_on_device
    return device_workload.execute()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Workload(executable=<PjitFunction of <function test_convert.<locals>.convert at 0x7f57788aeef0>>, args=[Array([[0.0947...2062037, 0.7978035 , ..., 0.9188044 , 0.5129981 ,
        0.53420603]], dtype=float32)], kwargs={}, static_argnames=[])

    def execute(self) -> Any:
        """Calls callable passing stored args and kwargs directly."""
>       return self.executable(*self.args, **self.kwargs)
E       jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Error code: 13

tests/infra/workload.py:29: XlaRuntimeError