
EXTREMELY SERIOUS BUGS THAT MAKE BLOCKSPARSE COMPLETELY USELESS IN TRAINING #419

Closed

btyu opened this issue Jan 6, 2022 · 13 comments

btyu (Contributor) commented Jan 6, 2022

I found that the blocksparse ops' backward gradients are totally wrong, which makes training meaningless. Take matmul (mode='sdd') as a simple example; my test code is below. As you can see, the forward result matches PyTorch's implementation, while the backward gradients are wrong, with large errors, even when the layout is fully dense.

I also tested the old version, and the problem still exists. For other operators and other modes, the gradients appear to be wrong as well. My environment is:

- python: 3.7
- pytorch: 1.10.1
- triton: 1.1.1
- CUDA: 11.3.1
- GPU: Nvidia Tesla V100

import torch
from triton.ops.blocksparse.matmul import matmul as TtMatmul
from triton.ops.blocksparse.softmax import softmax as TtSoftmax

# ===== Settings =====
batch = 1
head = 1
len1 = 1024
len2 = 768
block = 64
dim = 64

device = 'cuda:0'
requires_grad = True
dtype = torch.float32
same_heads = False  # all heads share the same attention pattern
dense_layout = False  # use the dense layout

repeat = 1000  # number of test iterations
raise_size_error = True  # raise an error if the output size does not match
do_not_use_full_zero_line_layout = True  # skip layouts that have a row or column of all zeros

# --- check input ---
assert block in (16, 32, 64, 128)
chunk1 = len1 // block
chunk2 = len2 // block
assert chunk1 * block == len1
assert chunk2 * block == len2

# ===== Basic Functions =====
def get_triton_matmul(*args, **kwargs):
    return TtMatmul(*args, **kwargs)

def get_triton_softmax(*args, **kwargs):
    return TtSoftmax(*args, **kwargs)

get_matmul = get_triton_matmul
get_softmax = get_triton_softmax

def layout_full_zero_check(layout):
    row_check = layout.sum(dim=2).eq(0).any()  # (H, L // block)
    col_check = layout.sum(dim=1).eq(0).any()
    return row_check or col_check

def generate_layout(num_heads, num_chunks_1, num_chunks_2, dtype=torch.long, device='cpu'):
    if dense_layout:
        layout = torch.ones(
            1 if same_heads else num_heads, num_chunks_1, num_chunks_2, 
            dtype=dtype, 
            device=device
        )
        num_selected_blocks = layout.numel()
    else:
        while True:
            layout = torch.randint(
                0, 2, (1 if same_heads else num_heads, num_chunks_1, num_chunks_2), 
                dtype=dtype, 
                device=device
            )
            num_selected_blocks = layout.sum().item()
            if num_selected_blocks > 1:
                if do_not_use_full_zero_line_layout and layout_full_zero_check(layout):
                    continue
                break
    if same_heads and num_heads > 1:
        layout = layout.expand(num_heads, -1, -1)
    
    return layout, num_selected_blocks


# ===== Test SDD Matmul =====

ka = 0
kb = 0
for _ in range(repeat):
    layout, num_selected_blocks = generate_layout(head, chunk1, chunk2)

    a = torch.rand((batch, head, len1, dim), dtype=dtype, device=device)
    b = torch.rand((batch, head, len2, dim), dtype=dtype, device=device)

    a_pytorch = a.clone()
    b_pytorch = b.clone()

    if requires_grad:
        for item in (a, b, a_pytorch, b_pytorch):
            item.requires_grad_()
    
    dot = get_matmul(layout, block, mode='sdd', trans_a=False, trans_b=True)

    c = dot(a, b)  # (batch, num_selected_blocks, block, block)

    try:
        assert c.shape[1] == num_selected_blocks
    except AssertionError:
        print('SIZE ERROR: %d\t%d' % (c.shape[1], num_selected_blocks))
        if raise_size_error:
            raise

    c_pytorch = torch.bmm(a_pytorch.view(-1, len1, dim), b_pytorch.view(-1, len2, dim).transpose(-2, -1))  # (batch * head, len1, len2)
    c_pytorch = c_pytorch.view(batch, head, chunk1, block, chunk2, block)
    c_pytorch = c_pytorch.permute(0, 1, 2, 4, 3, 5)
    c_pytorch = c_pytorch.masked_select(layout.bool().to(device)[None, :, :, :, None, None]).view(batch, -1, block, block)  # (batch, num_selected_blocks, block, block)

    assert torch.allclose(c, c_pytorch)

    sum_c, sum_c_pytorch = c.sum(), c_pytorch.sum()

    assert torch.allclose(sum_c, sum_c_pytorch)

    if requires_grad:
        sum_c.backward()
        sum_c_pytorch.backward()

        try:
            assert torch.allclose(a.grad, a_pytorch.grad)
        except AssertionError:
            ka += 1
            not_same = a.grad != a_pytorch.grad
            print('a.grad ERROR', (a.grad[not_same] - a_pytorch.grad[not_same]).abs().max(), float(not_same.sum() / a.grad.numel()))
        try:
            assert torch.allclose(b.grad, b_pytorch.grad)
        except AssertionError:
            kb += 1
            not_same = b.grad != b_pytorch.grad
            print('b.grad ERROR', (b.grad[not_same] - b_pytorch.grad[not_same]).abs().max(), float(not_same.sum() / b.grad.numel()))

if requires_grad:
    print('mismatch for a.grad:', ka, ka / repeat)
    print('mismatch for b.grad:', kb, kb / repeat)

Output from one of my runs:

a.grad ERROR tensor(1.3109e+09, device='cuda:0') 1.0
b.grad ERROR tensor(2.5181e+10, device='cuda:0') 1.0
a.grad ERROR tensor(1.4002e+12, device='cuda:0') 1.0
b.grad ERROR tensor(2.7797e+13, device='cuda:0') 1.0
a.grad ERROR tensor(3715.7754, device='cuda:0') 1.0
b.grad ERROR tensor(5357.4009, device='cuda:0') 1.0
a.grad ERROR tensor(2650638.7500, device='cuda:0') 1.0
b.grad ERROR tensor(5492.8169, device='cuda:0') 1.0
a.grad ERROR tensor(4457.6938, device='cuda:0') 1.0
b.grad ERROR tensor(5327.3140, device='cuda:0') 1.0
a.grad ERROR tensor(3303.7292, device='cuda:0') 1.0
b.grad ERROR tensor(4878.7446, device='cuda:0') 1.0
a.grad ERROR tensor(4601.9453, device='cuda:0') 1.0
b.grad ERROR tensor(6078.3545, device='cuda:0') 1.0
a.grad ERROR tensor(3987.0869, device='cuda:0') 1.0
b.grad ERROR tensor(5731.9072, device='cuda:0') 1.0
a.grad ERROR tensor(3546.3162, device='cuda:0') 1.0
b.grad ERROR tensor(6374.2583, device='cuda:0') 1.0
a.grad ERROR tensor(2690220.7500, device='cuda:0') 1.0
b.grad ERROR tensor(5759.3281, device='cuda:0') 1.0
a.grad ERROR tensor(142.5978, device='cuda:0') 1.0
b.grad ERROR tensor(11595.9834, device='cuda:0') 1.0
a.grad ERROR tensor(4336.1611, device='cuda:0') 1.0
b.grad ERROR tensor(5793.8066, device='cuda:0') 1.0
a.grad ERROR tensor(131.8495, device='cuda:0') 1.0
b.grad ERROR tensor(15427.1631, device='cuda:0') 1.0
a.grad ERROR tensor(152.3687, device='cuda:0') 1.0
b.grad ERROR tensor(700640.1250, device='cuda:0') 1.0
a.grad ERROR tensor(3544.4004, device='cuda:0') 1.0
b.grad ERROR tensor(5818.4619, device='cuda:0') 1.0
a.grad ERROR tensor(4417.7246, device='cuda:0') 1.0
b.grad ERROR tensor(5654.2842, device='cuda:0') 1.0
a.grad ERROR tensor(2918971.7500, device='cuda:0') 1.0
b.grad ERROR tensor(5756.5093, device='cuda:0') 1.0
a.grad ERROR tensor(2656234.7500, device='cuda:0') 1.0
b.grad ERROR tensor(4773.9717, device='cuda:0') 1.0
a.grad ERROR tensor(2660450.5000, device='cuda:0') 1.0
b.grad ERROR tensor(6892.2666, device='cuda:0') 1.0
a.grad ERROR tensor(4277.8979, device='cuda:0') 1.0
b.grad ERROR tensor(5337.3726, device='cuda:0') 1.0
a.grad ERROR tensor(2650022.2500, device='cuda:0') 1.0
b.grad ERROR tensor(1.1324e+08, device='cuda:0') 1.0
a.grad ERROR tensor(359.3810, device='cuda:0') 1.0
b.grad ERROR tensor(5316.9575, device='cuda:0') 1.0
a.grad ERROR tensor(356.5109, device='cuda:0') 1.0
b.grad ERROR tensor(4588.6304, device='cuda:0') 1.0
a.grad ERROR tensor(2628183.7500, device='cuda:0') 1.0
b.grad ERROR tensor(4068.6719, device='cuda:0') 1.0
a.grad ERROR tensor(289.8997, device='cuda:0') 1.0
b.grad ERROR tensor(4290.9443, device='cuda:0') 1.0
a.grad ERROR tensor(256.7503, device='cuda:0') 1.0
b.grad ERROR tensor(4042.9343, device='cuda:0') 1.0
a.grad ERROR tensor(2634063.5000, device='cuda:0') 1.0
b.grad ERROR tensor(351.9893, device='cuda:0') 1.0
a.grad ERROR tensor(287.7516, device='cuda:0') 1.0
b.grad ERROR tensor(353.8267, device='cuda:0') 1.0
a.grad ERROR tensor(2627630.2500, device='cuda:0') 1.0
b.grad ERROR tensor(386.0250, device='cuda:0') 1.0
a.grad ERROR tensor(286.2173, device='cuda:0') 1.0
b.grad ERROR tensor(322.8841, device='cuda:0') 1.0
a.grad ERROR tensor(2692211.2500, device='cuda:0') 1.0
b.grad ERROR tensor(321.6712, device='cuda:0') 1.0
a.grad ERROR tensor(286.0741, device='cuda:0') 1.0
b.grad ERROR tensor(325.7232, device='cuda:0') 1.0
a.grad ERROR tensor(2567459.5000, device='cuda:0') 1.0
b.grad ERROR tensor(354.9262, device='cuda:0') 1.0
a.grad ERROR tensor(286.2664, device='cuda:0') 1.0
b.grad ERROR tensor(419.1502, device='cuda:0') 1.0
a.grad ERROR tensor(2686237.7500, device='cuda:0') 1.0
b.grad ERROR tensor(389.2686, device='cuda:0') 1.0
a.grad ERROR tensor(255.2152, device='cuda:0') 1.0
b.grad ERROR tensor(286.1020, device='cuda:0') 1.0
a.grad ERROR tensor(2576712.5000, device='cuda:0') 1.0
b.grad ERROR tensor(358.9259, device='cuda:0') 1.0
a.grad ERROR tensor(289.7795, device='cuda:0') 1.0
b.grad ERROR tensor(322.9184, device='cuda:0') 1.0
a.grad ERROR tensor(2665924.7500, device='cuda:0') 1.0
b.grad ERROR tensor(384.9941, device='cuda:0') 1.0
a.grad ERROR tensor(289.0149, device='cuda:0') 1.0
b.grad ERROR tensor(392.9610, device='cuda:0') 1.0
a.grad ERROR tensor(2662453.5000, device='cuda:0') 1.0
b.grad ERROR tensor(357.0535, device='cuda:0') 1.0
a.grad ERROR tensor(287.9712, device='cuda:0') 1.0
b.grad ERROR tensor(357.6203, device='cuda:0') 1.0
a.grad ERROR tensor(2590091., device='cuda:0') 1.0
b.grad ERROR tensor(355.0422, device='cuda:0') 1.0
a.grad ERROR tensor(292.4987, device='cuda:0') 1.0
b.grad ERROR tensor(390.0421, device='cuda:0') 1.0
a.grad ERROR tensor(2567506.7500, device='cuda:0') 1.0
b.grad ERROR tensor(360.4383, device='cuda:0') 1.0
a.grad ERROR tensor(255.9906, device='cuda:0') 1.0
b.grad ERROR tensor(353.5637, device='cuda:0') 1.0
a.grad ERROR tensor(2691391.5000, device='cuda:0') 1.0
b.grad ERROR tensor(354.3253, device='cuda:0') 1.0
a.grad ERROR tensor(286.5569, device='cuda:0') 1.0
b.grad ERROR tensor(351.1349, device='cuda:0') 1.0
a.grad ERROR tensor(2614949.2500, device='cuda:0') 1.0
b.grad ERROR tensor(353.5500, device='cuda:0') 1.0
a.grad ERROR tensor(328.6048, device='cuda:0') 1.0
b.grad ERROR tensor(423.7781, device='cuda:0') 1.0
a.grad ERROR tensor(313.4121, device='cuda:0') 1.0
b.grad ERROR tensor(392.1626, device='cuda:0') 1.0
a.grad ERROR tensor(256.8100, device='cuda:0') 1.0
b.grad ERROR tensor(320.8966, device='cuda:0') 1.0
a.grad ERROR tensor(2681082.2500, device='cuda:0') 1.0
b.grad ERROR tensor(320.8078, device='cuda:0') 1.0
a.grad ERROR tensor(257.9711, device='cuda:0') 1.0
b.grad ERROR tensor(355.4615, device='cuda:0') 1.0
a.grad ERROR tensor(289.8659, device='cuda:0') 1.0
b.grad ERROR tensor(353.8013, device='cuda:0') 1.0
a.grad ERROR tensor(284.3614, device='cuda:0') 1.0
b.grad ERROR tensor(353.3466, device='cuda:0') 1.0
a.grad ERROR tensor(296.3285, device='cuda:0') 1.0
b.grad ERROR tensor(359.3662, device='cuda:0') 1.0
a.grad ERROR tensor(319.2778, device='cuda:0') 1.0
b.grad ERROR tensor(327.0521, device='cuda:0') 1.0
a.grad ERROR tensor(291.4463, device='cuda:0') 1.0
b.grad ERROR tensor(450.9517, device='cuda:0') 1.0
a.grad ERROR tensor(352.7974, device='cuda:0') 1.0
b.grad ERROR tensor(388.5782, device='cuda:0') 1.0
a.grad ERROR tensor(254.4195, device='cuda:0') 1.0
b.grad ERROR tensor(390.5658, device='cuda:0') 1.0
a.grad ERROR tensor(256.4656, device='cuda:0') 1.0
b.grad ERROR tensor(1.1066e+08, device='cuda:0') 1.0
a.grad ERROR tensor(255.1293, device='cuda:0') 1.0
b.grad ERROR tensor(353.7872, device='cuda:0') 1.0
a.grad ERROR tensor(288.5960, device='cuda:0') 1.0
b.grad ERROR tensor(393.4873, device='cuda:0') 1.0
a.grad ERROR tensor(256.1300, device='cuda:0') 1.0
b.grad ERROR tensor(355.0295, device='cuda:0') 1.0
a.grad ERROR tensor(323.6937, device='cuda:0') 1.0
b.grad ERROR tensor(365.0252, device='cuda:0') 1.0
a.grad ERROR tensor(287.0813, device='cuda:0') 1.0
b.grad ERROR tensor(387.5172, device='cuda:0') 1.0
a.grad ERROR tensor(253.4359, device='cuda:0') 1.0
b.grad ERROR tensor(324.5351, device='cuda:0') 1.0
a.grad ERROR tensor(290.9200, device='cuda:0') 1.0
b.grad ERROR tensor(321.8120, device='cuda:0') 1.0
a.grad ERROR tensor(323.4057, device='cuda:0') 1.0
b.grad ERROR tensor(321.2264, device='cuda:0') 1.0
a.grad ERROR tensor(326.8586, device='cuda:0') 1.0
b.grad ERROR tensor(356.5294, device='cuda:0') 1.0
a.grad ERROR tensor(291.3673, device='cuda:0') 1.0
b.grad ERROR tensor(402.2349, device='cuda:0') 1.0
a.grad ERROR tensor(322.2409, device='cuda:0') 1.0
b.grad ERROR tensor(390.6744, device='cuda:0') 1.0
a.grad ERROR tensor(317.9997, device='cuda:0') 1.0
b.grad ERROR tensor(358.8085, device='cuda:0') 1.0
a.grad ERROR tensor(287.7614, device='cuda:0') 1.0
b.grad ERROR tensor(423.3942, device='cuda:0') 1.0
a.grad ERROR tensor(286.4266, device='cuda:0') 1.0
b.grad ERROR tensor(355.7695, device='cuda:0') 1.0
a.grad ERROR tensor(322.1042, device='cuda:0') 1.0
b.grad ERROR tensor(417.8162, device='cuda:0') 1.0
a.grad ERROR tensor(285.9888, device='cuda:0') 1.0
b.grad ERROR tensor(324.0033, device='cuda:0') 1.0
a.grad ERROR tensor(286.2515, device='cuda:0') 1.0
b.grad ERROR tensor(421.8573, device='cuda:0') 1.0
a.grad ERROR tensor(318.7636, device='cuda:0') 1.0
b.grad ERROR tensor(324.4319, device='cuda:0') 1.0
a.grad ERROR tensor(319.1130, device='cuda:0') 1.0
b.grad ERROR tensor(358.5825, device='cuda:0') 1.0
a.grad ERROR tensor(318.4194, device='cuda:0') 1.0
b.grad ERROR tensor(292.2460, device='cuda:0') 1.0
a.grad ERROR tensor(288.0099, device='cuda:0') 1.0
b.grad ERROR tensor(350.2174, device='cuda:0') 1.0
a.grad ERROR tensor(2751064.2500, device='cuda:0') 1.0
b.grad ERROR tensor(488.0098, device='cuda:0') 1.0
a.grad ERROR tensor(224.5474, device='cuda:0') 1.0
b.grad ERROR tensor(324.3686, device='cuda:0') 1.0
a.grad ERROR tensor(350.1700, device='cuda:0') 1.0
b.grad ERROR tensor(425.4680, device='cuda:0') 1.0
a.grad ERROR tensor(287.2039, device='cuda:0') 1.0
b.grad ERROR tensor(451.3908, device='cuda:0') 1.0
a.grad ERROR tensor(254.7051, device='cuda:0') 1.0
b.grad ERROR tensor(1.0482e+08, device='cuda:0') 1.0
a.grad ERROR tensor(284.7216, device='cuda:0') 1.0
b.grad ERROR tensor(326.0213, device='cuda:0') 1.0
a.grad ERROR tensor(1.2124e+08, device='cuda:0') 1.0
b.grad ERROR tensor(352.3882, device='cuda:0') 1.0
a.grad ERROR tensor(253.9305, device='cuda:0') 1.0
b.grad ERROR tensor(324.9531, device='cuda:0') 1.0
a.grad ERROR tensor(2662769.5000, device='cuda:0') 1.0
b.grad ERROR tensor(360.7639, device='cuda:0') 1.0
a.grad ERROR tensor(289.1925, device='cuda:0') 1.0
b.grad ERROR tensor(385.9752, device='cuda:0') 1.0
a.grad ERROR tensor(2861392.5000, device='cuda:0') 1.0
b.grad ERROR tensor(320.4403, device='cuda:0') 1.0
a.grad ERROR tensor(2671811.5000, device='cuda:0') 1.0
b.grad ERROR tensor(388.5952, device='cuda:0') 1.0
a.grad ERROR tensor(2979740., device='cuda:0') 1.0
b.grad ERROR tensor(22014.4922, device='cuda:0') 1.0
a.grad ERROR tensor(2979724.2500, device='cuda:0') 1.0
b.grad ERROR tensor(2.0401e+08, device='cuda:0') 1.0
a.grad ERROR tensor(1.4005e+08, device='cuda:0') 1.0
b.grad ERROR tensor(1.0918e+10, device='cuda:0') 1.0
a.grad ERROR tensor(111.5738, device='cuda:0') 1.0
b.grad ERROR tensor(8.8034e+09, device='cuda:0') 1.0
a.grad ERROR tensor(116.8540, device='cuda:0') 1.0
b.grad ERROR tensor(13286.1270, device='cuda:0') 1.0
a.grad ERROR tensor(2574.5845, device='cuda:0') 1.0
b.grad ERROR tensor(59198.0352, device='cuda:0') 1.0
a.grad ERROR tensor(2796279.2500, device='cuda:0') 1.0
b.grad ERROR tensor(290.7633, device='cuda:0') 1.0
a.grad ERROR tensor(257.6186, device='cuda:0') 1.0
b.grad ERROR tensor(351.0964, device='cuda:0') 1.0
a.grad ERROR tensor(254.5203, device='cuda:0') 1.0
b.grad ERROR tensor(297.9558, device='cuda:0') 1.0
a.grad ERROR tensor(259.2481, device='cuda:0') 1.0
b.grad ERROR tensor(270.5669, device='cuda:0') 1.0
a.grad ERROR tensor(288.9724, device='cuda:0') 1.0
b.grad ERROR tensor(352.0467, device='cuda:0') 1.0
a.grad ERROR tensor(257.4395, device='cuda:0') 1.0
b.grad ERROR tensor(352.4917, device='cuda:0') 1.0
a.grad ERROR tensor(2630416.5000, device='cuda:0') 1.0
b.grad ERROR tensor(348.8304, device='cuda:0') 1.0
a.grad ERROR tensor(2611716.5000, device='cuda:0') 1.0
b.grad ERROR tensor(351.9631, device='cuda:0') 1.0
a.grad ERROR tensor(2.4866e+08, device='cuda:0') 1.0
b.grad ERROR tensor(40511900., device='cuda:0') 1.0
a.grad ERROR tensor(133.4997, device='cuda:0') 1.0
b.grad ERROR tensor(322.5428, device='cuda:0') 1.0
a.grad ERROR tensor(2725685.5000, device='cuda:0') 1.0
b.grad ERROR tensor(420.1107, device='cuda:0') 1.0
a.grad ERROR tensor(2781390.2500, device='cuda:0') 1.0
b.grad ERROR tensor(388.4075, device='cuda:0') 1.0
a.grad ERROR tensor(257.0772, device='cuda:0') 1.0
b.grad ERROR tensor(425.4004, device='cuda:0') 1.0
a.grad ERROR tensor(286.1646, device='cuda:0') 1.0
b.grad ERROR tensor(354.5847, device='cuda:0') 1.0
a.grad ERROR tensor(2775809., device='cuda:0') 1.0
b.grad ERROR tensor(384.3141, device='cuda:0') 1.0
a.grad ERROR tensor(2784289.2500, device='cuda:0') 1.0
b.grad ERROR tensor(390.9427, device='cuda:0') 1.0
a.grad ERROR tensor(1491.5585, device='cuda:0') 1.0
b.grad ERROR tensor(1.8443e+09, device='cuda:0') 1.0
a.grad ERROR tensor(2666701.7500, device='cuda:0') 1.0
b.grad ERROR tensor(1.1275e+08, device='cuda:0') 1.0
a.grad ERROR tensor(2653713., device='cuda:0') 1.0
b.grad ERROR tensor(353.5638, device='cuda:0') 1.0
a.grad ERROR tensor(322.0308, device='cuda:0') 1.0
b.grad ERROR tensor(357.3578, device='cuda:0') 1.0
a.grad ERROR tensor(254.9564, device='cuda:0') 1.0
b.grad ERROR tensor(323.5333, device='cuda:0') 1.0
a.grad ERROR tensor(333.6547, device='cuda:0') 1.0
b.grad ERROR tensor(359.9502, device='cuda:0') 1.0
a.grad ERROR tensor(286.8031, device='cuda:0') 1.0
b.grad ERROR tensor(389.5343, device='cuda:0') 1.0
a.grad ERROR tensor(2726983.5000, device='cuda:0') 1.0
b.grad ERROR tensor(357.1115, device='cuda:0') 1.0
a.grad ERROR tensor(288.9051, device='cuda:0') 1.0
b.grad ERROR tensor(357.7245, device='cuda:0') 1.0
a.grad ERROR tensor(876.8928, device='cuda:0') 1.0
b.grad ERROR tensor(19977906., device='cuda:0') 1.0
a.grad ERROR tensor(2686767.2500, device='cuda:0') 1.0
b.grad ERROR tensor(357.6304, device='cuda:0') 1.0
a.grad ERROR tensor(2765077.7500, device='cuda:0') 1.0
b.grad ERROR tensor(360.3122, device='cuda:0') 1.0
a.grad ERROR tensor(283.9131, device='cuda:0') 1.0
b.grad ERROR tensor(2.1615e+08, device='cuda:0') 1.0
a.grad ERROR tensor(2798058.7500, device='cuda:0') 1.0
b.grad ERROR tensor(19864.1621, device='cuda:0') 1.0
a.grad ERROR tensor(2597423.2500, device='cuda:0') 1.0
b.grad ERROR tensor(751808.8125, device='cuda:0') 1.0
a.grad ERROR tensor(69647880., device='cuda:0') 1.0
b.grad ERROR tensor(2.6077e+09, device='cuda:0') 1.0
a.grad ERROR tensor(2.6070e+09, device='cuda:0') 1.0
b.grad ERROR tensor(1.5997e+08, device='cuda:0') 1.0
a.grad ERROR tensor(59987.2266, device='cuda:0') 1.0
b.grad ERROR tensor(6796.1133, device='cuda:0') 1.0
a.grad ERROR tensor(3446316.2500, device='cuda:0') 1.0
b.grad ERROR tensor(5.5290e+09, device='cuda:0') 1.0
a.grad ERROR tensor(254.4641, device='cuda:0') 1.0
b.grad ERROR tensor(2.0320e+11, device='cuda:0') 1.0
a.grad ERROR tensor(1.5091e+09, device='cuda:0') 1.0
b.grad ERROR tensor(4.8775e+10, device='cuda:0') 1.0
a.grad ERROR tensor(4.9482e+10, device='cuda:0') 1.0
b.grad ERROR tensor(4.9815e+10, device='cuda:0') 1.0
a.grad ERROR tensor(288.4737, device='cuda:0') 1.0
b.grad ERROR tensor(22817.1172, device='cuda:0') 1.0
Traceback (most recent call last):
  File "test_matmul_sdd.py", line 112, in <module>
    sum_c_pytorch.backward()
  File "/opt/miniconda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/miniconda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`

Sometimes a RuntimeError like the one above appears, and sometimes it doesn't.

Could you please fix it as soon as possible? It's really serious, and many people may be using this module without being aware of the problem.

Thank you very much.

btyu (Contributor, Author) commented Jan 6, 2022

I saw your testing script, so I believe you have tested these ops. Maybe I'm getting something wrong, but this really confuses me.

ptillet (Collaborator) commented Jan 6, 2022

Yep, this big bug affects only the float32 gradient of sdd, and it has already been fixed in the v2.0 branch. I can push a hotfix to master that throws an error when float32 is used.

btyu (Contributor, Author) commented Jan 6, 2022

Thank you for the quick reply!

However, I tested fp16 (by changing the dtype setting to torch.float16 in my script), and the problem still exists.

And this time the RuntimeError still occurs, but the message has changed:

a.grad ERROR tensor(12432., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(3458., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(4692., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(2442., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(2372., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(5396., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(2974., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(5968., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(nan, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(4612., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(62304., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(4936., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(3882., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(4964., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(3520., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(3600., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(5008., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(4532., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(6056., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(287.5000, device='cuda:0', dtype=torch.float16) 0.9990234375
b.grad ERROR tensor(350.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(257.7500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(322.2500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(nan, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(348.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(360.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(nan, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(352., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(16736., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(386.5000, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(328.2500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(nan, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(358.2500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(285.2500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(325.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(319.2500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(353.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(289.7500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(325., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(257.5000, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(389.2500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(353., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(253.2500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(358., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(287.5000, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(316.2500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(287.7500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(327.5000, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(257., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(322.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(255.3750, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(294., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(252.2500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(356.2500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(290.7500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(387., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(291., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(254.8750, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(320.7500, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(349.5000, device='cuda:0', dtype=torch.float16) 1.0
Traceback (most recent call last):
  File "test_matmul_sdd.py", line 112, in <module>
    sum_c_pytorch.backward()
  File "/opt/miniconda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/miniconda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `result`

REALLY LOOKING FORWARD TO YOUR HELP AND FIX.

THANK YOU.

btyu (Contributor, Author) commented Jan 6, 2022

And it seems some grads become nan or inf.

ptillet (Collaborator) commented Jan 6, 2022

Can you try v2.0? I have high confidence in the version there. We use Triton blocksparse internally at OpenAI.

btyu (Contributor, Author) commented Jan 6, 2022

OK, I'll try it right now!

btyu (Contributor, Author) commented Jan 6, 2022

OK, I tested v2.0, but it still has the problem.

a.grad ERROR tensor(450., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(385.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(253.8750, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(352.5000, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(148.5000, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(318.7500, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(5860., device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(3924., device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
a.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
b.grad ERROR tensor(inf, device='cuda:0', dtype=torch.float16) 1.0
Traceback (most recent call last):
  File "test_matmul_sdd.py", line 112, in <module>
    sum_c_pytorch.backward()
  File "/opt/miniconda/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/miniconda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `result`

btyu (Contributor, Author) commented Jan 6, 2022

@ptillet The gradient problem still exists in v2.0, as well as the RuntimeErrors.
REALLY NEED HELP.

ptillet (Collaborator) commented Jan 6, 2022

I have confirmed that the bug doesn't appear if you replace

  sum_c, sum_c_pytorch = c.sum(), c_pytorch.sum()
  sum_c.backward()
  sum_c_pytorch.backward()

with

  dc = torch.randn_like(c)
  c.backward(dc)
  c_pytorch.backward(dc)

which shows that the op works in general and explains why many groups have been able to use it successfully in large models.

I am not sure why the sum makes things buggy -- it probably has to do with some contiguity requirement in the blocksparse op's backprop. Will look into it.

ptillet (Collaborator) commented Jan 6, 2022

Because the incoming gradient is an expanded scalar when the blocksparse op is followed by a sum, it has strides (0, 0, 0, 0), and this apparently throws off the compiler. I'll fix this.
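
For illustration, a minimal sketch in plain PyTorch (no Triton; the multiplication below is just a hypothetical stand-in for the blocksparse op) shows how a trailing sum() delivers a zero-strided gradient to the op underneath, whereas an explicit upstream gradient is contiguous:

import torch

x = torch.rand(2, 1, 4, 4, requires_grad=True)
c = x * 2.0  # stand-in for the blocksparse output

# Inspect the gradient that reaches c during backprop.
c.register_hook(lambda g: print('incoming grad strides:', g.stride()))

# The backward of sum() expands a scalar gradient to c's shape,
# so every stride is zero, e.g. (0, 0, 0, 0).
c.sum().backward()

# With an explicit, materialized upstream gradient the strides are
# contiguous, e.g. (16, 16, 4, 1) for this shape.
x.grad = None
c = x * 2.0
c.register_hook(lambda g: print('incoming grad strides:', g.stride()))
c.backward(torch.randn_like(c))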

btyu (Contributor, Author) commented Jan 6, 2022

Indeed, when using dc as above, the problem doesn't appear. However, there are always differences between the two implementations, and the max absolute error is mostly 0.0312. This is pretty strange and interesting ;) Could you please look into this as well, to make the error as small as possible?

a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3912200927734375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4085693359375
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3688812255859375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4048258662223816
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3965301513671875
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4230753779411316
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.39031982421875
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.41583251953125
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.38922119140625
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4140625
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3896026611328125
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4069010615348816
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3958587646484375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4134725034236908
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3994903564453125
b.grad ERROR tensor(0.0625, device='cuda:0', dtype=torch.float16) 0.4081624448299408
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3704986572265625
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4122314453125
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3907470703125
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4033610224723816
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.390960693359375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4114176630973816
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3888702392578125
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4073486328125
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3962554931640625
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4239095151424408
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.392425537109375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.41851806640625
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3952484130859375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4020792841911316
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.393646240234375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4062093198299408
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3912200927734375
b.grad ERROR tensor(0.0625, device='cuda:0', dtype=torch.float16) 0.4065144956111908
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.39208984375
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4174601435661316
a.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.3862457275390625
b.grad ERROR tensor(0.0312, device='cuda:0', dtype=torch.float16) 0.4197998046875

As for the stride problem, I added a line c = c + 0.0 before the sum (I guess it creates a new tensor without zero strides), but the problem still exists.
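
For what it's worth, a minimal plain-PyTorch sketch (again with a hypothetical stand-in op) suggests why c = c + 0.0 may not help: the addition's backward passes the incoming gradient through unchanged, so the zero-strided gradient produced by sum() still reaches the op underneath. Passing an explicit, contiguous upstream gradient instead, as in the dc snippet above, sidesteps it:

import torch

x = torch.rand(2, 1, 4, 4, requires_grad=True)
c = x * 2.0  # stand-in for the blocksparse output
c = c + 0.0  # new tensor, but the gradient path is unchanged
c.register_hook(lambda g: print('incoming grad strides:', g.stride()))
c.sum().backward()  # still reports all-zero strides, e.g. (0, 0, 0, 0)

# Mathematically equivalent to sum().backward(), but with a contiguous gradient:
x.grad = None
c = (x * 2.0) + 0.0
c.backward(torch.ones_like(c))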

Looking forward to your fix. Please let me know when it's done; this is really essential for my current work.

THANK YOU to you and your team!!!

ptillet (Collaborator) commented Jan 6, 2022

Hey! This is fixed in #420.

As for the error difference, I think this is nothing to be concerned about. It seems within the normal precision range for FP16 inputs, so in a sense the Torch result is not more accurate than the Triton one.
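
As a rough sanity check (assuming gradient magnitudes on the order of a few tens, which seems plausible for these reductions), 0.0312 is about one fp16 unit in the last place at that magnitude:

import torch

# fp16 has a 10-bit mantissa, so representable values near magnitude m are spaced
# roughly m * 2**-10 apart. For values of magnitude ~32-64 that spacing is already:
print(torch.finfo(torch.float16).eps * 32)  # 0.03125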

btyu (Contributor, Author) commented Jan 6, 2022

Amazing that you fixed this so quickly!
MANY HUGE THANKS and Best Wishes!

@btyu btyu closed this as completed Jan 6, 2022