
chunk3: Add custom operator to avoid torch.cat in BW #458

Merged: 7 commits merged on Oct 6, 2022
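
In short, the diff below removes the old `Chunk3` autograd function and adds a generic `xformers.ops.unbind` plus a `get_stack_strides` helper, so the backward pass can return the packed gradient as a strided view instead of materializing a `torch.stack`/`torch.cat` of the three gradients. As a rough usage sketch (not part of the diff; the shapes and the toy loss are made up for illustration):

import torch
import xformers.ops

# A packed qkv buffer of shape [B, M, 3, H, K].
qkv = torch.randn(2, 16, 3, 4, 8, requires_grad=True)

# bm3hk -> 3 x bmhk: the three outputs are views into qkv's storage,
# so the forward allocates nothing.
q, k, v = xformers.ops.unbind(qkv, 2)
assert q.data_ptr() == qkv.data_ptr()
assert not q.is_contiguous()

# In the backward, the custom op falls back to torch.stack only when it has to;
# if the incoming gradients are already slices of one buffer it returns an
# as_strided view instead (see get_stack_strides in xformers/ops.py below).
(q.sum() + 2 * k.sum() + 3 * v.sum()).backward()
assert qkv.grad.shape == qkv.shape

The tests and the benchmark below switch from `Chunk3.apply(...)` to this `unbind(...)` call.
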
44 changes: 42 additions & 2 deletions tests/test_mem_eff_attention.py
@@ -268,7 +268,7 @@ def test_forward(
query, key, value = c[0], c[1], c[2]
else:
# bm3hk -> 3 x bmhk
-            query, key, value = xformers.ops.Chunk3.apply(c, 2)
+            query, key, value = xformers.ops.unbind(c, 2)
assert not query.is_contiguous()

out = xformers.ops.memory_efficient_attention(
@@ -480,7 +480,7 @@ def test_backward(
qkv = torch.stack([query, key, value], 2)
qkv.requires_grad_(True)
# bm3hk -> 3 x bmhk
-    query, key, value = xformers.ops.Chunk3.apply(qkv, 2)
+    query, key, value = xformers.ops.unbind(qkv, 2)
assert not query.is_contiguous()

query.requires_grad_(True)
@@ -724,3 +724,43 @@ def test_memory_efficient_attention_full_block_masked(
assert_allclose(grad_q, query.grad, "grad_q", atol=atol)
assert_allclose(grad_k, key.grad, "grad_k", atol=atol)
assert_allclose(grad_v, value.grad, "grad_v", atol=atol)


@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4])
def test_unbind(dim):
x = torch.ones([10, 20, 4, 10, 3], requires_grad=True)
x2 = torch.ones([10, 20, 4, 10, 3], requires_grad=True)

# FW
tensors = xformers.ops.unbind(x, dim)
tensors2 = torch.unbind(x2, dim)
assert len(tensors) == len(tensors2)
for t1, t2 in zip(tensors, tensors2):
assert torch.allclose(t1, t2)

# BW
grads = torch.unbind(torch.randn(x.shape), dim)
loss1 = sum(g * t for (g, t) in zip(grads, tensors))
loss2 = sum(g * t for (g, t) in zip(grads, tensors2))
assert torch.allclose(loss1, loss2)
g = torch.ones_like(loss1)
loss1.backward(g)
loss2.backward(g)
assert torch.allclose(x.grad, x2.grad)


@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4])
def test_unbind_get_stack_strides(dim: int):
def not_stacked(t, d):
return xformers.ops.get_stack_strides(t, d) is None

x = torch.ones([10, 20, 4, 10, 3], requires_grad=True)
tensors = xformers.ops.unbind(x, dim)
ndim = x.ndim

assert not_stacked(tensors, (dim + 1) % ndim)
assert xformers.ops.get_stack_strides(tensors, dim) == x.stride()
assert xformers.ops.get_stack_strides(tensors[1:], dim) == x.stride()
assert not_stacked(tensors[::2], dim)
assert not_stacked([tensors[0], tensors[1].clone()], dim)
assert not_stacked(tensors, (dim + ndim - 1) % ndim)
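
The test above exercises `get_stack_strides`, which is added in xformers/ops.py further down. A minimal sketch (not part of the diff; the shapes are arbitrary) of what a successful result buys: the returned strides let a caller rebuild the packed tensor as a free `as_strided` view instead of paying for a `torch.stack`:

import torch
import xformers.ops

x = torch.ones([10, 20, 4, 10, 3])
tensors = xformers.ops.unbind(x, 2)

# The unbound views share x's storage, so they can be "re-stacked" for free.
strides = xformers.ops.get_stack_strides(tensors, 2)
assert strides == x.stride()

restacked = tensors[0].as_strided(x.shape, strides)
assert restacked.data_ptr() == x.data_ptr()                 # a view, no allocation
assert torch.equal(restacked, torch.stack(tensors, dim=2))  # same values as a real stack

# A layout that a single as_strided view cannot express yields None,
# and callers fall back to torch.stack.
assert xformers.ops.get_stack_strides([tensors[0], tensors[1].clone()], 2) is None

This is exactly the fast path `_Unbind.backward` takes further down.
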
2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -136,7 +136,7 @@ def create_tensors(shape, dtype, requires_grad=False):
qkv = torch.rand(
[B, M, 3, H, K], device=device, dtype=dtype, requires_grad=requires_grad
)
-    q, k, v = xformers.ops.Chunk3.apply(qkv, 2)
+    q, k, v = xformers.ops.unbind(qkv, 2)
return qkv, q, k, v


78 changes: 53 additions & 25 deletions xformers/ops.py
@@ -6,7 +6,7 @@

import math
from dataclasses import dataclass
-from typing import Any, List, Mapping, Optional, Set, Type, Union
+from typing import Any, List, Mapping, Optional, Sequence, Set, Type, Union

import torch

@@ -655,7 +655,46 @@ def from_arguments(
)


-class Chunk3(torch.autograd.Function):
def get_stack_strides(
tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Sequence[int]]:
"""
    If the tensors are already laid out in memory as a single tensor stacked
    along `dim`, returns the strides of that stacked tensor. Otherwise returns None.
"""
if len(tensors) <= 1 or dim > tensors[0].ndim:
return None

final_stride = []
for i in range(tensors[0].ndim + 1):
if i == dim:
final_stride.append(0)
continue
if i > dim:
i -= 1
final_stride.append(tensors[0].stride(i))

# Set the stride of the concat dimension
if dim == tensors[0].ndim:
final_stride[dim] = 1
else:
final_stride[dim] = final_stride[dim + 1] * tensors[0].shape[dim]

for i, x in enumerate(tensors):
# Sanity checks
if x.shape != tensors[0].shape:
return None
# Actual storage check
if x.storage().data_ptr() != tensors[0].storage().data_ptr():
return None
if x.stride() != tensors[0].stride():
return None
if x.storage_offset() != tensors[0].storage_offset() + i * final_stride[dim]:
return None
return tuple(final_stride)


+class _Unbind(torch.autograd.Function):
"""
Splits a packed `qkv` tensor into query, key and values.
The magic happens in the backward. We want to `torch.stack` the tensors
@@ -665,34 +704,23 @@ class Chunk3(torch.autograd.Function):

@staticmethod
# type: ignore
-    def forward(ctx, qkv: torch.Tensor, dim: int):
-        q, k, v = qkv.select(dim, 0), qkv.select(dim, 1), qkv.select(dim, 2)
+    def forward(ctx, x: torch.Tensor, dim: int):
        ctx.dim = dim
-        ctx.qkv_shape = qkv.shape
-        ctx.qkv_strides = qkv.stride()
-        ctx.q_stride = q.stride()
-        ctx.k_stride = k.stride()
-        ctx.v_stride = v.stride()
-        ctx.storage_offsets = (
-            q.storage_offset(),
-            k.storage_offset(),
-            v.storage_offset(),
-        )
-        return q, k, v
+        ctx.input_shape = x.shape
+        return x.unbind(dim)

    @classmethod
    # type: ignore
-    def backward(cls, ctx, gq: torch.Tensor, gk: torch.Tensor, gv: torch.Tensor):
+    def backward(cls, ctx, *tensors: torch.Tensor):
        # Fast path
-        if (
-            ctx.storage_offsets
-            == (gq.storage_offset(), gk.storage_offset(), gv.storage_offset())
-            and gq.stride() == ctx.q_stride
-            and gk.stride() == ctx.k_stride
-            and gv.stride() == ctx.v_stride
-        ):
-            return gq.as_strided(ctx.qkv_shape, ctx.qkv_strides), None
-        return torch.stack([gq, gk, gv], dim=ctx.dim), None
+        strides = get_stack_strides(tensors, ctx.dim)
+        if strides is not None:
+            return tensors[0].as_strided(ctx.input_shape, strides), None
+        return torch.stack(tensors, dim=ctx.dim), None


def unbind(x: torch.Tensor, dim: int) -> Sequence[torch.Tensor]:
return _Unbind.apply(x, dim)


def memory_efficient_attention(
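
Finally, an end-to-end sketch of the pattern this PR targets, mirroring `create_tensors` in the benchmark above: pack qkv once, unbind it into views, and let the backward land the gradient on the packed buffer. This is an illustration, not part of the diff; the helper name and shapes are made up, it assumes a CUDA build with a memory-efficient attention kernel available for these shapes and dtype, and whether the backward takes the as_strided fast path or falls back to `torch.stack` depends on how the attention kernels lay out grad_q / grad_k / grad_v.

import torch
import xformers.ops

def attention_step(qkv: torch.Tensor) -> torch.Tensor:
    # qkv: [B, M, 3, H, K]
    q, k, v = xformers.ops.unbind(qkv, 2)   # views, no copy
    return xformers.ops.memory_efficient_attention(q, k, v)

if torch.cuda.is_available():
    B, M, H, K = 2, 128, 8, 64
    qkv = torch.randn(
        B, M, 3, H, K, device="cuda", dtype=torch.float16, requires_grad=True
    )
    out = attention_step(qkv)
    out.sum().backward()
    # The gradient accumulates directly on the packed buffer.
    assert qkv.grad.shape == qkv.shape

If the incoming gradients do not share a buffer, `_Unbind.backward` simply falls back to `torch.stack`, so correctness does not depend on the fast path being taken.
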