
Commit

Merge branch 'lucidrains-master'
RobinBruegger committed Jan 9, 2020
2 parents df55b91 + 73db4a8 commit aaa1c5d
Showing 2 changed files with 11 additions and 7 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -102,6 +102,9 @@ Tested with Python 3.6 and PyTorch 1.1.0. Should work with any version of Python

## Changelog

#### Version 0.2.1
- Added option to select the dimension along which the tensor is split ([Pull request](https://github.com/RobinBruegger/RevTorch/pull/2)); see the usage sketch below

#### Version 0.2.0
- Fixed memory leak when not consuming output of the reversible block ([Issue](https://github.com/RobinBruegger/RevTorch/issues/1))

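To make the new option concrete, here is a brief usage sketch (editorial, not part of this commit). It assumes the package exposes `ReversibleBlock` at the top level via `import revtorch as rv`; the sub-networks and tensor shapes are invented for illustration.

```python
# Editorial sketch: exercising the split_along_dim keyword added in v0.2.1.
# The sub-networks f and g and the tensor shapes are illustrative assumptions;
# any modules whose output shape equals their input shape will do.
import torch
import torch.nn as nn
import revtorch as rv  # assumes ReversibleBlock is exposed at the package top level

f = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
g = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Split along dim=2 instead of the default channel dimension (dim=1)
block = rv.ReversibleBlock(f, g, split_along_dim=2)

x = torch.randn(4, 8, 32, 16)  # dim 2 (size 32) is halved into two chunks of 16
y = block(x)                   # forward pass runs under torch.no_grad()
print(y.shape)                 # torch.Size([4, 8, 32, 16]) -- same shape as the input
```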
15 changes: 8 additions & 7 deletions revtorch/revtorch.py
@@ -15,24 +15,25 @@ class ReversibleBlock(nn.Module):
g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
'''

-def __init__(self, f_block, g_block):
+def __init__(self, f_block, g_block, split_along_dim=1):
super(ReversibleBlock, self).__init__()
self.f_block = f_block
self.g_block = g_block
+self.split_along_dim = split_along_dim

def forward(self, x):
"""
Performs the forward pass of the reversible block. Does not record any gradients.
:param x: Input tensor. Must be splittable into two halves along the dimension given by split_along_dim (default: 1).
:return: Output tensor of the same shape as the input tensor
"""
-x1, x2 = torch.chunk(x, 2, dim=1)
+x1, x2 = torch.chunk(x, 2, dim=self.split_along_dim)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f_block(x2)
y2 = x2 + self.g_block(y1)

-return torch.cat([y1, y2], dim=1)
+return torch.cat([y1, y2], dim=self.split_along_dim)
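The forward pass above is an additive coupling: y1 = x1 + f(x2), then y2 = x2 + g(y1). Because the coupling is exactly invertible, the block's inputs can be recomputed later instead of being stored. A minimal standalone sketch of that inverse (editorial; the linear layers are placeholders):

```python
# Sketch: the coupling y1 = x1 + f(x2), y2 = x2 + g(y1) is inverted by
#   x2 = y2 - g(y1), then x1 = y1 - f(x2),
# so the block's inputs can be recomputed instead of kept in memory.
import torch
import torch.nn as nn

f = nn.Linear(8, 8)   # placeholder sub-networks
g = nn.Linear(8, 8)

x = torch.randn(2, 16)
x1, x2 = torch.chunk(x, 2, dim=1)      # default split_along_dim=1

with torch.no_grad():
    y1 = x1 + f(x2)                    # forward coupling
    y2 = x2 + g(y1)
    x2_rec = y2 - g(y1)                # invert the second coupling
    x1_rec = y1 - f(x2_rec)            # then the first

print(torch.allclose(x1, x1_rec), torch.allclose(x2, x2_rec))  # True True
```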

def backward_pass(self, y, dy):
"""
@@ -47,11 +48,11 @@ def backward_pass(self, y, dy):
"""

# Split the arguments along the split dimension
-y1, y2 = torch.chunk(y, 2, dim=1)
+y1, y2 = torch.chunk(y, 2, dim=self.split_along_dim)
del y
assert (not y1.requires_grad), "y1 must already be detached"
assert (not y2.requires_grad), "y2 must already be detached"
-dy1, dy2 = torch.chunk(dy, 2, dim=1)
+dy1, dy2 = torch.chunk(dy, 2, dim=self.split_along_dim)
del dy
assert (not dy1.requires_grad), "dy1 must not require grad"
assert (not dy2.requires_grad), "dy2 must not require grad"
@@ -100,8 +101,8 @@ def backward_pass(self, y, dy):
x2.grad = None

# Undo the split along the split dimension
-x = torch.cat([x1, x2.detach()], dim=1)
-dx = torch.cat([dx1, dx2], dim=1)
+x = torch.cat([x1, x2.detach()], dim=self.split_along_dim)
+dx = torch.cat([dx1, dx2], dim=self.split_along_dim)

return x, dx
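For orientation, a sketch of what backward_pass does when driven by hand (editorial; in normal use the library's reversible sequence calls it from a custom autograd function, and the modules and shapes below are assumptions): it reconstructs the block's input from its output and returns it together with the gradient with respect to the input, while the sub-networks' parameter gradients are accumulated along the way.

```python
# Editorial sketch: calling backward_pass directly. The modules and shapes
# are assumptions for illustration; rv.ReversibleBlock at the package top
# level is assumed as well.
import torch
import torch.nn as nn
import revtorch as rv

block = rv.ReversibleBlock(nn.Linear(8, 8), nn.Linear(8, 8))

x = torch.randn(2, 16)
y = block(x)                      # produced under no_grad, so y is detached
dy = torch.ones_like(y)           # stand-in for the upstream gradient

x_rec, dx = block.backward_pass(y, dy)

print(torch.allclose(x_rec, x))   # True (up to float rounding): input recomputed from y
print(dx.shape)                   # gradient w.r.t. the input, same shape as x
# block.f_block.weight.grad and block.g_block.weight.grad should now be populated
```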

