Skip to content

Commit

Permalink
Add missing ignore_index tests (#310)
Browse files: browse the repository at this point in the history
## Summary
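The `ignore_index` option of `fused_linear_cross_entropy` was not covered by any test. This commit parametrizes `test_correctness` over `ignore_index` and sets a random subset of the targets to that value.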

## Testing Done

- Hardware Type: gpu-ci
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
Co-authored-by: Yun Dai <yundai424@gmail.com>
  • Loading branch information
3 people authored Oct 30, 2024
1 parent 1c0c75c commit 48aa62d
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions test/transformers/test_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,28 @@ def forward(self, x, y):
],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("label_smoothing", [0, 0.1])
@pytest.mark.parametrize("label_smoothing, ignore_index", [(0.0, -100), (0.1, 42)])
def test_correctness(
B, T, H, V, scalar, dtype, bias, label_smoothing, reduction, atol, rtol
B,
T,
H,
V,
scalar,
dtype,
bias,
label_smoothing,
ignore_index,
reduction,
atol,
rtol,
):
device = "cuda"
torch_lm_head_ce = TorchLMHeadCE(
H=H,
V=V,
bias=bias,
label_smoothing=label_smoothing,
ignore_index=ignore_index,
reduction=reduction,
dtype=dtype,
).to(device)
Expand All @@ -118,6 +130,7 @@ def test_correctness(
V=V,
bias=bias,
label_smoothing=label_smoothing,
ignore_index=ignore_index,
reduction=reduction,
dtype=dtype,
).to(device)
Expand All @@ -137,6 +150,14 @@ def test_correctness(
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
# Assign some random number of elements as ignore_index
num_elements_to_assign = torch.randint(
1, B * T // 2, (1,)
).item() # Random number of elements to set to ignore_index
indices_to_assign = torch.randperm(B * T)[
:num_elements_to_assign
] # Randomly select indices
target[indices_to_assign] = ignore_index

output1 = torch_lm_head_ce(_input1, target)
output2 = liger_lm_head_ce(_input2, target)
Expand Down

0 comments on commit 48aa62d

Please sign in to comment.