cross_entropy.py
import torch
import triton
import triton.language as tl

from fast_llm.functional.config import TritonConfig


@triton.jit
def triton_cross_entropy_forward_backward_kernel(
    logits_ptr,
    labels_ptr,
    grad_logits_ptr,
    losses_ptr,
    grad_losses,
    n_cols,
    logits_stride_0,
    grad_logits_stride_0,
    logits_scale_factor: tl.constexpr,
    block_size: tl.constexpr,
    ignore_index: tl.constexpr,
):
    # TODO: Int64 ptr only if needed?
    block_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, block_size)
    logits_ptr = logits_ptr + block_idx * logits_stride_0
    mask = col_offsets < n_cols
    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)
    if logits_scale_factor != 1.0:
        logits *= logits_scale_factor
    # Numerically stable softmax: subtract the row maximum before exponentiating.
    max_logits = tl.max(logits, 0)
    exp_logits = tl.exp(logits - max_logits)
    sum_exp_logits = tl.sum(exp_logits, 0)
    label_idx = tl.load(labels_ptr + block_idx)
    if label_idx < 0 or label_idx == ignore_index:
        # Ignored rows contribute zero loss (and skip the out-of-range label load).
        loss = 0.0
    else:
        label_logits = tl.load(logits_ptr + label_idx).to(tl.float32)
        if logits_scale_factor != 1.0:
            label_logits *= logits_scale_factor
        # loss = -log_softmax(logits)[label] = log(sum_j exp(logits_j - max)) + max - logits[label]
        loss = tl.log(sum_exp_logits) + max_logits - label_logits
    tl.store(losses_ptr + block_idx, loss)
    # Backward pass, reusing the forward intermediates:
    # grad_logits = grad_losses * (softmax(logits) - one_hot(label)).
    grad_logits_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0
    col_offsets = tl.arange(0, block_size)
    label_idx = tl.load(labels_ptr + block_idx)
    exp_logits = exp_logits / sum_exp_logits
    if logits_scale_factor != 1.0:
        exp_logits *= logits_scale_factor
    if label_idx < 0 or label_idx == ignore_index:
        # Ignored rows get zero gradient.
        grad_losses = 0.0
    grad_logits = grad_losses * tl.where(col_offsets == label_idx, exp_logits - 1.0, exp_logits)
    tl.store(grad_logits_ptr + col_offsets, grad_logits, mask=mask)


def triton_cross_entropy_forward_backward(
    logits: torch.Tensor,
    target: torch.Tensor,
    grad_output: float | None,
    logits_scale_factor: float = 1.0,
    ignore_index: int = -100,
    apply_loss_mask: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    A fast Triton implementation of cross-entropy that fuses the casting and the forward and backward passes
    into a single kernel.
    Compared to a standard PyTorch implementation, this reduces memory usage (of logits) by 3x and memory I/O by 5x.
    TODO: Better handling of `grad_output = None`
    """
    assert TritonConfig.TRITON_ENABLED
    # TODO: Improve assumptions.
    assert logits.is_contiguous()
    assert target.is_contiguous()
    n_rows, n_cols = logits.shape
    assert target.shape == (n_rows,)
    # One program instance per row, so the block must cover the full vocabulary dimension.
    block_size = triton.next_power_of_2(n_cols)
    assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES
    num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16)
    losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
    # TODO: Safe to do inplace?
    grad_logits = torch.empty_like(logits)
    triton_cross_entropy_forward_backward_kernel[(n_rows,)](
        logits,
        target,
        grad_logits,
        losses,
        # The mean reduction over rows is folded into the gradient scale.
        1 if grad_output is None else grad_output / n_rows,
        n_cols,
        logits.stride(0),
        grad_logits.stride(0),
        logits_scale_factor,
        block_size=block_size,
        num_warps=num_warps,
        ignore_index=ignore_index,
    )
    return losses.mean(), None if grad_output is None else grad_logits
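
For reference, a minimal usage sketch is below. It is not part of the original module: it assumes a CUDA device and a fast_llm build with `TritonConfig.TRITON_ENABLED` set, and the batch size, vocabulary size, and comparison against `torch.nn.functional.cross_entropy` are illustrative choices only.

# Minimal usage sketch (illustrative, not part of the original module).
# Assumes a CUDA device; sizes below are arbitrary examples.
if __name__ == "__main__":
    import torch.nn.functional as F

    n_rows, n_cols = 8, 32000
    logits = torch.randn(n_rows, n_cols, device="cuda", dtype=torch.float16)
    target = torch.randint(0, n_cols, (n_rows,), device="cuda")

    # Fused forward + backward: returns the mean loss and d(mean loss)/d(logits) in one call.
    loss, grad_logits = triton_cross_entropy_forward_backward(logits, target, grad_output=1.0)

    # Rough check against the PyTorch reference (should agree up to float16/float32 rounding).
    ref = F.cross_entropy(logits.float(), target)
    print(loss.item(), ref.item())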