
[Bug] fused_add_rmsnorm Fails Due to Misaligned Address #634

Closed
ovowei opened this issue Nov 24, 2024 · 1 comment
ovowei commented Nov 24, 2024

Runtime Error: Misaligned Address in fused_add_rmsnorm with hidden_dim=3584

I encountered a runtime error when using the fused_add_rmsnorm operator with a model configured with hidden_dim=3584 (28 * 128). The error message is as follows:

RuntimeError: CUDA error: misaligned address

This issue can be reproduced by modifying the test case in tests/test_norm.py. Specifically, setting:

@pytest.mark.parametrize("hidden_size", [3584])

will trigger the error during testing.

Upon investigation, I identified the problematic line in the code:

const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

As a temporary workaround, I replaced it with:

const uint32_t vec_size = 1;

This change resolves the issue; however, it may lead to performance degradation.
Below is the detailed pytest error output for reference:

collected 1 item                                                                                                                                                                                                                                                                                  

test_norm.py F                                                                                                                                                                                                                                                                              [100%]

============================================================================================================================================ FAILURES =============================================================================================================================================
______________________________________________________________________________________________________________________________ test_fused_add_rmsnorm[dtype0-3584-1] ______________________________________________________________________________________________________________________________

self = <[RuntimeError('CUDA error: misaligned address\nCUDA kernel errors might be asynchronously reported at some other API ...h `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] TensorLikePair object at 0x7f74a4e83cd0>

    def compare(self) -> None:
        actual, expected = self.actual, self.expected
    
        self._compare_attributes(actual, expected)
        if any(input.device.type == "meta" for input in (actual, expected)):
            return
    
        actual, expected = self._equalize_attributes(actual, expected)
>       self._compare_values(actual, expected)

/opt/conda/lib/python3.10/site-packages/torch/testing/_comparison.py:713: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/opt/conda/lib/python3.10/site-packages/torch/testing/_comparison.py:831: in _compare_values
    compare_fn(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <[RuntimeError('CUDA error: misaligned address\nCUDA kernel errors might be asynchronously reported at some other API ...h `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] TensorLikePair object at 0x7f74a4e83cd0>
actual = <[RuntimeError('CUDA error: misaligned address\nCUDA kernel errors might be asynchronously reported at some other API ...pile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] Tensor object at 0x7f74a4e6dfd0>
expected = <[RuntimeError('CUDA error: misaligned address\nCUDA kernel errors might be asynchronously reported at some other API ...pile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] Tensor object at 0x7f74a4e6f1a0>

    def _compare_regular_values_close(
        self,
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: float,
        atol: float,
        equal_nan: bool,
        identifier: Optional[Union[str, Callable[[str], str]]] = None,
    ) -> None:
        """Checks if the values of two tensors are close up to a desired tolerance."""
>       matches = torch.isclose(
            actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan
        )
E       RuntimeError: CUDA error: misaligned address
E       CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E       For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
E       Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

/opt/conda/lib/python3.10/site-packages/torch/testing/_comparison.py:1010: RuntimeError

During handling of the above exception, another exception occurred:

batch_size = 1, hidden_size = 3584, dtype = torch.float16

    @pytest.mark.parametrize("batch_size", [1])
    @pytest.mark.parametrize("hidden_size", [3584])
    # @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
    # @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
    @pytest.mark.parametrize("dtype", [torch.float16])
    def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
        eps = 1e-6
    
        x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
        residual = torch.randn_like(x)
        weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
    
        x_native, residual_native = fused_add_rms_norm(
            x.clone(), residual.clone(), weight, eps
        )
    
        x_fused = x.clone()
        residual_fused = residual.clone()
        flashinfer.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
    
>       torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)

test_norm.py:104: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/opt/conda/lib/python3.10/site-packages/torch/testing/_comparison.py:381: in __repr__
    body = [
/opt/conda/lib/python3.10/site-packages/torch/testing/_comparison.py:382: in <listcomp>
    f"    {name}={value!s},"
/opt/conda/lib/python3.10/site-packages/torch/_tensor.py:461: in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
/opt/conda/lib/python3.10/site-packages/torch/_tensor_str.py:677: in _str
    return _str_intern(self, tensor_contents=tensor_contents)
/opt/conda/lib/python3.10/site-packages/torch/_tensor_str.py:597: in _str_intern
    tensor_str = _tensor_str(self, indent)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <[RuntimeError('CUDA error: misaligned address\nCUDA kernel errors might be asynchronously reported at some other API ...pile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] Tensor object at 0x7f74a4e6dfd0>, indent = 7

    def _tensor_str(self, indent):
        if self.numel() == 0:
            return "[]"
    
        if self.has_names():
            # There are two main codepaths (possibly more) that tensor printing goes through:
            # - tensor data can fit comfortably on screen
            # - tensor data needs to be summarized
            # Some of the codepaths don't fully support named tensors, so we send in
            # an unnamed tensor to the formatting code as a workaround.
            self = self.rename(None)
    
        summarize = self.numel() > PRINT_OPTS.threshold
    
        if self._is_zerotensor():
            self = self.clone()
    
        # handle the negative bit
        if self.is_neg():
            self = self.resolve_neg()
    
        if self.dtype in [
            torch.float16,
            torch.bfloat16,
            torch.float8_e5m2,
            torch.float8_e5m2fnuz,
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
        ]:
>           self = self.float()
E           RuntimeError: CUDA error: misaligned address
E           CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E           For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
E           Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

/opt/conda/lib/python3.10/site-packages/torch/_tensor_str.py:331: RuntimeError
-------------------------------------------------------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------------------------------------------------------
2024-11-24 15:57:37,239 - INFO - flashinfer.jit: Loading JIT ops: norm
===================================================================================================================================== short test summary info =====================================================================================================================================
FAILED test_norm.py::test_fused_add_rmsnorm[dtype0-3584-1] - RuntimeError: CUDA error: misaligned address
======================================================================================================================================= 1 failed in 46.00s ========================================================================================================================================
yzh119 added a commit that referenced this issue Nov 25, 2024
…apes (#636)

This PR fixes issue #634, which was introduced by #592.
To use 16-byte vectorized reads/writes, we need to confirm that the
address is aligned to 16 bytes.
When `num_warps` is not a multiple of 4 (4 * sizeof(float) = 16 bytes), the
address of `smem + num_warps` might not be aligned to 16 bytes.

We fix this by shifting the start offset of the vectorized reads/writes to
`smem + ceil_div(num_warps, 4) * 4`, which forces the alignment.

cc @ovowei @Abatom
yzh119 commented Nov 25, 2024

Thanks for spotting this bug; it should be fixed by #636 :)

@yzh119 yzh119 closed this as completed Nov 25, 2024