
MatMul and blocksparse matmul: incorrect precision for some shapes #1808

Open
Qu-Xiangjun opened this issue Jun 21, 2023 · 3 comments

Comments

@Qu-Xiangjun

I am using Triton 2.0.0, PyTorch 2.0.0, Python 3.9.16, and CUDA 11.6 on a PC running CentOS release 7.4.1708 with an NVIDIA A100. I use the matmul and blocksparse/matmul ops from https://github.com/openai/triton/tree/main/python/triton/ops, with test code modeled on [test_matmul.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_matmul.py) and [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py).

I found that the Triton matmul result differs from torch.matmul when compared with torch.allclose(atol=1e-5, rtol=0), as follows:

Matmul Test

The test code is as follows:

import torch
import triton
import triton.ops  # explicit import, in case triton.ops is not re-exported by `import triton`

M, N, K = 2048, 2048, 2048
torch.manual_seed(0)
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)

# compute with torch
torch_output = torch.matmul(a, b)
# compute with triton
triton_output = triton.ops.matmul(a, b)

# compare
diff = torch.sum(torch.abs(triton_output - torch_output))
print("total difference: {:10f}".format(diff))

if torch.allclose(triton_output, torch_output, atol=1e-5, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")

This code prints a total difference greater than 0.0, and torch.allclose returns False.

I then observed the following behavior:

  1. The difference grows as the shape increases, so I guessed it might come from accumulated rounding error in the computation. However, when I run the same code with M, K, N = 4096, 4096, 4096 on my machine, it passes ✅ the allclose check and diff = 0.000000. Is it still related to shape? Only certain shapes trigger the problem. (A float32 reference check is sketched after this list.)

  2. I also tried some special input data with M, N, K = 2048, 2048, 2048:

    • With a = torch.ones and b = torch.ones, the test always passes ✅, so the problem does not depend on the shape alone.

    • With a = torch.ones and b = torch.randn, every row of the result matrix is identical, and the incorrect elements are the same in every row.
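
As a quick sanity check (my own sketch, not part of the original test files), both fp16 results can be compared against a float32 reference computed from the same inputs. If both deviate from the reference by a similar amount, the mismatch is plausibly fp16 rounding combined with a different accumulation order rather than a wrong kernel:

import torch
import triton
import triton.ops  # explicit import, in case triton.ops is not re-exported by `import triton`

M, N, K = 2048, 2048, 2048
torch.manual_seed(0)
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)

# float32 reference computed from the same data
ref = torch.matmul(a.float(), b.float())

torch_output = torch.matmul(a, b)
triton_output = triton.ops.matmul(a, b)

# distance of each fp16 result from the float32 reference
print("torch  max |err| vs fp32 ref:", (torch_output.float() - ref).abs().max().item())
print("triton max |err| vs fp32 ref:", (triton_output.float() - ref).abs().max().item())

# for fp16 outputs of a K=2048 reduction, a relative tolerance is more meaningful than atol=1e-5
print(torch.allclose(triton_output, torch_output, rtol=1e-2, atol=1e-2))

If both error values come out in the same ballpark, then atol=1e-5 with rtol=0 is simply too strict for fp16 accumulations over K=2048 terms, and the two libraries are just rounding differently.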

Blocksparse Matmul Test

The precision mismatch also appears in the blocksparse matmul function. The test code below only exercises the forward pass of [test_blocksparse.py](https://github.com/openai/triton/blob/main/python/test/unit/operators/test_blocksparse.py):

import torch
import triton
import triton.ops  # explicit import, in case triton.ops is not re-exported by `import triton`


def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret

def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None,
                 dtype=torch.float32):
    if data is None:
        data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
    ref_ret = data
    ref_ret = ref_ret * alpha + beta
    ref_ret = ref_ret.half().to(dtype)
    if trans:
        ref_ret = ref_ret.t().requires_grad_()
    ref_ret = ref_ret.detach().requires_grad_()
    tri_ret = ref_ret.clone().detach().requires_grad_()
    return ref_ret, tri_ret

def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret

def test_blocksparsematmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
    seed = 0
    torch.manual_seed(seed)
    is_sdd = MODE == "sdd"
    is_dsd = MODE == "dsd"
    is_dds = MODE == "dds"
    do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK)
    do_mask = lambda x: mask_tensor(x, layout, BLOCK)


    # create inputs
    # create op
    a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K)
    b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N)
    c_shape = (Z, H, M, N)
    shape = {
        "sdd": (M, N),
        "dsd": (a_shape[2], a_shape[3]),
        "dds": (b_shape[2], b_shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))

    # create data
    a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE)
    b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE)
    dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE)

    # compute [torch]
    a_ref = do_mask(a_ref) if is_dsd else a_ref
    b_ref = do_mask(b_ref) if is_dds else b_ref
    a_ref.retain_grad()
    b_ref.retain_grad()
    c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                         b_ref.transpose(2, 3) if TRANS_B else b_ref)
    c_ref = do_sparsify(c_ref) if is_sdd else c_ref      
    # dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
    # c_ref.backward(dc_ref) 
    # da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad
    # db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad

    # triton result
    a_tri = do_sparsify(a_tri) if is_dsd else a_tri
    b_tri = do_sparsify(b_tri) if is_dds else b_tri
    a_tri.retain_grad()
    b_tri.retain_grad()
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
    c_tri = op(a_tri, b_tri)
    # dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
    # c_tri.backward(dc_tri)
    # da_tri = a_tri.grad
    # db_tri = b_tri.grad

    # compare
    print("--------------------------------------------------------------")
    perf = lambda ms: 2 * M * N * K * Z * H * 1e-9 / (ms * 1e-3)  # GFLOPS
    total_op = 2 * M * N * K * Z * H

    print('''MODE={}, Z={}, H={}, M={}, N={}, K={}, total_op={}. '''
            .format(MODE,Z, H, M, N, K, total_op))

    diff = torch.sum(torch.abs(c_ref - c_tri))
    print('total diff = {:.10f}'.format(diff))

    if(torch.allclose(c_ref, c_tri, atol = 1e-5, rtol = 0)):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

    ms, _, _ = triton.testing.do_bench(lambda: op(a_tri, b_tri), rep = 20)
    print('''Triton: GFLOPS: {:.3f}, time: {:.6f}ms.'''.format(perf(ms), ms))

    ms_torch, _, _ = triton.testing.do_bench(
        lambda: torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
                            b_ref.transpose(2, 3) if TRANS_B else b_ref),
        rep = 20
    )
    print('''Torch: GFLOPS: {:.3f}, time: {:.6f}ms.'''.format(perf(ms_torch), ms_torch))

    return perf(ms), perf(ms_torch), diff

test_blocksparsematmul('dds', False, False, 32, torch.float16, Z = 1, H = 2, M = 64, K = 4096, N = 4096)

This code prints a total difference greater than 0.0, and torch.allclose returns False.

I then observed the following behavior:

  • For small shapes such as M, N, K = 256, 256, 256, the code always passes ✅.
  • I swept M over the range [1, 1024] with N, K = 4096, 4096; more than half of those shapes print ❌ Triton and Torch differ. (A sketch of this sweep is shown after the list.)
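
The sweep from the last bullet can be reproduced with a small loop such as the one below (a rough sketch; the step size and the collection of failing shapes are my own illustrative choices):

# Sketch of the M-sweep described above, reusing test_blocksparsematmul from this report.
failures = []
for m in range(32, 1025, 32):  # illustrative step; the original sweep covered [1, 1024]
    _, _, diff = test_blocksparsematmul('dds', False, False, 32, torch.float16,
                                        Z=1, H=2, M=m, K=4096, N=4096)
    if diff > 0:
        failures.append((m, float(diff)))
print("shapes with a non-zero difference:", failures)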

So what could be causing the precision mismatch, and how can it be solved?

@992355092

I have also encountered this problem. Have you resolved it?

@Qu-Xiangjun
Author

Unfortunately, I didn't find the reason.

@Qu-Xiangjun Qu-Xiangjun closed this as not planned Nov 10, 2023
@Qu-Xiangjun Qu-Xiangjun reopened this Nov 10, 2023
@FelixSchoen

FelixSchoen commented Mar 21, 2024

I faced a similar issue when playing around with blocksparse matrix multiplication; here is my code:

    import torch
    import triton.ops

    device = torch.device("cuda")
    dtype = torch.float16

    # Parameters
    batch_size = 2
    head_size = 1
    sequence_length = 32
    d_model = 16
    block_size = 16

    use_int_tensors = False
    max_int_val = 20

    # Tensors
    if use_int_tensors:
        tensor_a = torch.randint(1, max_int_val, (head_size, batch_size, sequence_length, d_model), device=device,
                                 dtype=dtype)
        tensor_b = torch.randint(1, max_int_val, (head_size, batch_size, sequence_length, d_model), device=device,
                                 dtype=dtype)
    else:
        tensor_a = torch.rand((head_size, batch_size, sequence_length, d_model), device=device,
                              dtype=dtype)
        tensor_b = torch.rand((head_size, batch_size, sequence_length, d_model), device=device,
                              dtype=dtype)
    identity_matrix = torch.eye(sequence_length, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0).repeat(
        (head_size, batch_size, 1, 1))
    sparsity_layout = torch.tensor([[[1, 1],
                                     [1, 1]],

                                    [[1, 1],
                                     [1, 1]]
                                    ])

    # Triton Matmul
    triton_matmul_sdd = triton.ops.blocksparse.matmul(sparsity_layout, block_size, "sdd", device)
    triton_result = triton_matmul_sdd(tensor_a, torch.transpose(tensor_b, -1, -2))
    triton_matmul_dsd = triton.ops.blocksparse.matmul(sparsity_layout, block_size, "dsd", device)
    triton_identity = triton_matmul_dsd(triton_result, identity_matrix)

    # Conventional Matmul
    conventional_result = torch.matmul(tensor_a, torch.transpose(tensor_b, -1, -2))
    conventional_identity = torch.matmul(conventional_result, identity_matrix)

    assert torch.allclose(conventional_result, conventional_identity)

    # Passes only up to max_int_val ~ 15 if computing with dtype float16 and integer values
    if not torch.allclose(triton_identity, conventional_identity):
        difference = torch.abs(conventional_identity - triton_identity)
        indices = torch.where(difference > 1e-6)

        # Print the differing values
        print("Conventional Identity at differing indices: ", conventional_identity[indices])
        print("Triton Identity at differing indices: ", triton_identity[indices])
        print("Number of differing values: ", len(conventional_identity[indices]))

In this snippet I generate two tensors containing only integer values (or floating-point values if use_int_tensors=False) and multiply them using torch.matmul and triton.ops.blocksparse.matmul (with a sparsity layout that corresponds to a regular "full" matrix multiplication). When using float32 as the dtype I see numerical inaccuracies; with float16 everything seems to work fine.

Am I overlooking something here? For now I'll stick to float16; please let me know if there are other ways of achieving better precision!
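
One factor worth ruling out (a guess on my part, not something confirmed in this thread) is TF32: on Ampere GPUs Triton's dot typically uses TF32 tensor cores for float32 inputs, while recent PyTorch versions keep torch.matmul at full float32 precision by default, so the two backends can legitimately disagree at float32. A quick check, assuming the tensors above are created with dtype=torch.float32, is to let torch.matmul use TF32 as well and compare the size of the resulting gaps:

    # Hypothetical TF32 check (assumes dtype = torch.float32 in the snippet above).
    torch.backends.cuda.matmul.allow_tf32 = True   # let torch.matmul use TF32
    tf32_result = torch.matmul(tensor_a, torch.transpose(tensor_b, -1, -2))
    tf32_identity = torch.matmul(tf32_result, identity_matrix)
    torch.backends.cuda.matmul.allow_tf32 = False  # restore the default

    print("torch fp32 vs torch tf32:", (conventional_identity - tf32_identity).abs().max().item())
    print("torch fp32 vs triton    :", (conventional_identity - triton_identity).abs().max().item())

If the two printed gaps are of similar magnitude, the float32 mismatch is explained by reduced-precision accumulation rather than a bug in the kernel.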

ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this issue Aug 16, 2024