Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add chunked kl div #62

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fla/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from fla.modules.layernorm import (GroupNorm, GroupNormLinear, LayerNorm,
LayerNormLinear, RMSNorm, RMSNormLinear)
from fla.modules.rotary import RotaryEmbedding
from fla.modules.chunked_kl_div import ChunkedKLDiv

__all__ = [
'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
'FusedCrossEntropyLoss',
'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
'RotaryEmbedding'
'RotaryEmbedding', 'ChunkedKLDiv'
]
112 changes: 112 additions & 0 deletions fla/modules/chunked_kl_div.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F


def reference_torch(
x: torch.Tensor,
x_weight: torch.Tensor,
target_x: torch.Tensor,
target_weight: torch.Tensor,
reduction: str = "batchmean",
):
V = x_weight.shape[0]
logits = F.linear(x, x_weight).view(-1, V)
target_probs = F.linear(target_x, target_weight).view(-1, V)
target_probs = F.softmax(target_probs, dim=-1)

kl_loss = F.kl_div(
F.log_softmax(logits, dim=-1),
target_probs,
reduction=reduction,
)
return kl_loss


class ChunkedKLDiv(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: torch.Tensor,
x_weight: torch.Tensor,
target_x: torch.Tensor,
target_weight: torch.Tensor,
reduction: str = "batchmean",
sp: int = 8,
):
T = x.size(1)
chunk_size = (T + sp - 1) // sp

if reduction == "batchmean":
reduction_factor = T
elif reduction == "mean":
reduction_factor = sp
elif reduction == "sum":
reduction_factor = 1
else:
raise ValueError(f"Invalid reduction type: {reduction}")

kl_loss = 0
for i in range(sp):
logits_i = F.linear(
x[:, i * chunk_size : (i + 1) * chunk_size, :], x_weight
)
target_probs_i = F.linear(
target_x[:, i * chunk_size : (i + 1) * chunk_size, :], target_weight
)
log_probs_i = F.log_softmax(logits_i, dim=-1)
target_probs_i = F.softmax(target_probs_i, dim=-1)

loss_i = F.kl_div(
log_probs_i,
target_probs_i,
reduction=reduction,
)

kl_loss = kl_loss + loss_i

kl_loss = kl_loss / reduction_factor

ctx.save_for_backward(x, x_weight, target_x, target_weight)

ctx.sp = sp
ctx.reduction = reduction
return kl_loss

@staticmethod
def backward(ctx, grad_output):
x, x_weight, target_x, target_weight = ctx.saved_tensors
sp = ctx.sp
reduction = ctx.reduction

B, T, _ = x.size()
V = x_weight.size(0)

chunk_size = (T + sp - 1) // sp

if reduction == "batchmean":
reduction_factor = B * T
elif reduction == "mean":
reduction_factor = B * T * V
elif reduction == "sum":
reduction_factor = 1

grad_x = []
grad_weight = 0
for i in range(sp):
chunk_x = x[:, i * chunk_size : (i + 1) * chunk_size, :]
logits = F.linear(chunk_x, x_weight)
target_probs = F.linear(
target_x[:, i * chunk_size : (i + 1) * chunk_size, :], target_weight
)
target_probs = F.softmax(target_probs, dim=-1)

d_logits = -target_probs + torch.softmax(logits, dim=-1)

d_logits = d_logits / reduction_factor

grad_x.append(torch.einsum("blv, vh -> blh", d_logits, x_weight))
grad_weight += torch.einsum("blv, blh -> vh", d_logits, chunk_x)

grad_x = torch.cat(grad_x, dim=1)

return grad_output * grad_x, grad_output * grad_weight, None, None, None, None
46 changes: 46 additions & 0 deletions tests/modules/test_chunked_kl_div.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from fla.modules import LowMemKLDiv


@pytest.mark.parametrize("B", [1, 4])
@pytest.mark.parametrize("T", [1, 50, 2048, 4096])
@pytest.mark.parametrize("D", [1024, 2048])
@pytest.mark.parametrize("V", [32000, 100000])
@pytest.mark.parametrize("reduction", ["mean", "batchmean", "sum"])
def test_fused(B: int, T: int, D: int, V: int, reduction: str):
torch.manual_seed(42)
x = torch.randn(B, T, D).cuda().requires_grad_()
x_weight = torch.randn(V, D).cuda().requires_grad_()
target_x = torch.randn(B, T, D).cuda()
target_weight = torch.randn(V, D).cuda()

logits = F.linear(x, x_weight)
target_probs = F.linear(target_x, target_weight)
target_probs = F.softmax(target_probs, dim=-1)

ref = F.kl_div(
F.log_softmax(logits, dim=-1),
target_probs,
reduction=reduction,
)

do = torch.randn_like(ref).cuda()

ref.backward(do)
ref_d_x, x.grad = x.grad.clone(), None
ref_d_x_weight, x_weight.grad = x_weight.grad.clone(), None

chunk = LowMemKLDiv.apply(x, x_weight, target_x, target_weight, reduction)
chunk.backward(do)
chunk_d, x.grad = x.grad.clone(), None
chunk_d_x_weight, x_weight.grad = x_weight.grad.clone(), None

torch.testing.assert_close(ref, chunk)
torch.testing.assert_close(ref_d_x, chunk_d)
torch.testing.assert_close(ref_d_x_weight, chunk_d_x_weight)
Loading