diff --git a/fla/modules/__init__.py b/fla/modules/__init__.py
index e84247373..e5f22e1f5 100644
--- a/fla/modules/__init__.py
+++ b/fla/modules/__init__.py
@@ -10,11 +10,12 @@
 from fla.modules.layernorm import (GroupNorm, GroupNormLinear, LayerNorm,
                                    LayerNormLinear, RMSNorm, RMSNormLinear)
 from fla.modules.rotary import RotaryEmbedding
+from fla.modules.chunked_kl_div import ChunkedKLDiv
 
 __all__ = [
     'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
     'FusedCrossEntropyLoss',
     'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
     'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
-    'RotaryEmbedding'
+    'RotaryEmbedding', 'ChunkedKLDiv'
 ]
diff --git a/fla/modules/chunked_kl_div.py b/fla/modules/chunked_kl_div.py
new file mode 100644
index 000000000..2e477369d
--- /dev/null
+++ b/fla/modules/chunked_kl_div.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn.functional as F
+
+
+def reference_torch(
+    x: torch.Tensor,
+    x_weight: torch.Tensor,
+    target_x: torch.Tensor,
+    target_weight: torch.Tensor,
+    reduction: str = "batchmean",
+):
+    # Unchunked reference: materializes the full [B * T, V] logits of both
+    # projections before computing the KL divergence.
+    V = x_weight.shape[0]
+    logits = F.linear(x, x_weight).view(-1, V)
+    target_probs = F.linear(target_x, target_weight).view(-1, V)
+    target_probs = F.softmax(target_probs, dim=-1)
+
+    kl_loss = F.kl_div(
+        F.log_softmax(logits, dim=-1),
+        target_probs,
+        reduction=reduction,
+    )
+    return kl_loss
+
+
+class ChunkedKLDiv(torch.autograd.Function):
+    """KL divergence between the softmax outputs of two linear projections,
+    computed over ``sp`` chunks along the sequence dimension so that only one
+    chunk of ``[B, chunk_size, V]`` logits is materialized at a time.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        x_weight: torch.Tensor,
+        target_x: torch.Tensor,
+        target_weight: torch.Tensor,
+        reduction: str = "batchmean",
+        sp: int = 8,
+    ):
+        B, T, _ = x.size()
+        V = x_weight.size(0)
+        chunk_size = (T + sp - 1) // sp
+
+        if reduction == "batchmean":
+            reduction_factor = B * T
+        elif reduction == "mean":
+            reduction_factor = B * T * V
+        elif reduction == "sum":
+            reduction_factor = 1
+        else:
+            raise ValueError(f"Invalid reduction type: {reduction}")
+
+        # Accumulate per-chunk sums and normalize once at the end, so that a
+        # ragged final chunk (T not divisible by sp) does not skew the result.
+        kl_loss = 0
+        for i in range(sp):
+            chunk_x = x[:, i * chunk_size : (i + 1) * chunk_size, :]
+            if chunk_x.size(1) == 0:
+                break
+            logits_i = F.linear(chunk_x, x_weight)
+            target_probs_i = F.linear(
+                target_x[:, i * chunk_size : (i + 1) * chunk_size, :], target_weight
+            )
+            log_probs_i = F.log_softmax(logits_i, dim=-1)
+            target_probs_i = F.softmax(target_probs_i, dim=-1)
+
+            loss_i = F.kl_div(
+                log_probs_i,
+                target_probs_i,
+                reduction="sum",
+            )
+
+            kl_loss = kl_loss + loss_i
+
+        kl_loss = kl_loss / reduction_factor
+
+        ctx.save_for_backward(x, x_weight, target_x, target_weight)
+
+        ctx.sp = sp
+        ctx.reduction = reduction
+        return kl_loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, x_weight, target_x, target_weight = ctx.saved_tensors
+        sp = ctx.sp
+        reduction = ctx.reduction
+
+        B, T, _ = x.size()
+        V = x_weight.size(0)
+
+        chunk_size = (T + sp - 1) // sp
+
+        if reduction == "batchmean":
+            reduction_factor = B * T
+        elif reduction == "mean":
+            reduction_factor = B * T * V
+        elif reduction == "sum":
+            reduction_factor = 1
+
+        grad_x = []
+        grad_weight = 0
+        for i in range(sp):
+            chunk_x = x[:, i * chunk_size : (i + 1) * chunk_size, :]
+            if chunk_x.size(1) == 0:
+                break
+            logits = F.linear(chunk_x, x_weight)
+            target_probs = F.linear(
+                target_x[:, i * chunk_size : (i + 1) * chunk_size, :], target_weight
+            )
+            target_probs = F.softmax(target_probs, dim=-1)
+
+            # dKL/dlogits = softmax(logits) - target_probs, scaled by the same
+            # normalization used in the forward pass.
+            d_logits = -target_probs + torch.softmax(logits, dim=-1)
+
+            d_logits = d_logits / reduction_factor
+
+            grad_x.append(torch.einsum("blv, vh -> blh", d_logits, x_weight))
+            grad_weight += torch.einsum("blv, blh -> vh", d_logits, chunk_x)
+
+        grad_x = torch.cat(grad_x, dim=1)
+
+        return grad_output * grad_x, grad_output * grad_weight, None, None, None, None
diff --git a/tests/modules/test_chunked_kl_div.py b/tests/modules/test_chunked_kl_div.py
new file mode 100644
index 000000000..b316b9c6d
--- /dev/null
+++ b/tests/modules/test_chunked_kl_div.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from fla.modules import ChunkedKLDiv
+
+
+@pytest.mark.parametrize("B", [1, 4])
+@pytest.mark.parametrize("T", [1, 50, 2048, 4096])
+@pytest.mark.parametrize("D", [1024, 2048])
+@pytest.mark.parametrize("V", [32000, 100000])
+@pytest.mark.parametrize("reduction", ["mean", "batchmean", "sum"])
+def test_chunked_kl_div(B: int, T: int, D: int, V: int, reduction: str):
+    torch.manual_seed(42)
+    x = torch.randn(B, T, D).cuda().requires_grad_()
+    x_weight = torch.randn(V, D).cuda().requires_grad_()
+    target_x = torch.randn(B, T, D).cuda()
+    target_weight = torch.randn(V, D).cuda()
+
+    # Unchunked reference, flattened to [B * T, V] so that "batchmean"
+    # normalizes over all tokens, matching ChunkedKLDiv.
+    logits = F.linear(x, x_weight).view(-1, V)
+    target_probs = F.linear(target_x, target_weight).view(-1, V)
+    target_probs = F.softmax(target_probs, dim=-1)
+
+    ref = F.kl_div(
+        F.log_softmax(logits, dim=-1),
+        target_probs,
+        reduction=reduction,
+    )
+
+    do = torch.randn_like(ref).cuda()
+
+    ref.backward(do)
+    ref_d_x, x.grad = x.grad.clone(), None
+    ref_d_x_weight, x_weight.grad = x_weight.grad.clone(), None
+
+    chunk = ChunkedKLDiv.apply(x, x_weight, target_x, target_weight, reduction)
+    chunk.backward(do)
+    chunk_d_x, x.grad = x.grad.clone(), None
+    chunk_d_x_weight, x_weight.grad = x_weight.grad.clone(), None
+
+    torch.testing.assert_close(ref, chunk)
+    torch.testing.assert_close(ref_d_x, chunk_d_x)
+    torch.testing.assert_close(ref_d_x_weight, chunk_d_x_weight)
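
Usage sketch (illustrative, not part of the diff): ChunkedKLDiv is a torch.autograd.Function, so it is invoked via .apply with positional arguments; the shapes and the sp=4 chunk count below are made-up values for the example.

    import torch
    from fla.modules import ChunkedKLDiv

    B, T, D, V = 2, 128, 64, 1000
    x = torch.randn(B, T, D, requires_grad=True)       # hidden states projected by x_weight
    x_weight = torch.randn(V, D, requires_grad=True)   # projection producing the predicted distribution
    target_x = torch.randn(B, T, D)                    # hidden states of the target distribution (no grad)
    target_weight = torch.randn(V, D)

    loss = ChunkedKLDiv.apply(x, x_weight, target_x, target_weight, "batchmean", 4)
    loss.backward()  # gradients flow to x and x_weight only; the target inputs receive None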