diff --git a/README.md b/README.md
index 213a416..6f0ccd2 100644
--- a/README.md
+++ b/README.md
@@ -102,6 +102,9 @@ Tested with Python 3.6 and PyTorch 1.1.0. Should work with any version of Python
 
 ## Changelog
 
+#### Version 0.2.1
+- Added option to select the dimension along which the tensor is split ([Pull request](https://github.com/RobinBruegger/RevTorch/pull/2))
+
 #### Version 0.2.0
 - Fixed memory leak when not consuming output of the reversible block ([Issue](https://github.com/RobinBruegger/RevTorch/issues/1))
 
diff --git a/revtorch/revtorch.py b/revtorch/revtorch.py
index b9645a2..01ed46a 100644
--- a/revtorch/revtorch.py
+++ b/revtorch/revtorch.py
@@ -15,10 +15,11 @@ class ReversibleBlock(nn.Module):
         g_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
     '''
 
-    def __init__(self, f_block, g_block):
+    def __init__(self, f_block, g_block, split_along_dim=1):
         super(ReversibleBlock, self).__init__()
         self.f_block = f_block
         self.g_block = g_block
+        self.split_along_dim = split_along_dim
 
     def forward(self, x):
         """
@@ -26,13 +27,13 @@ def forward(self, x):
         :param x: Input tensor. Must be splittable along dimension 1.
         :return: Output tensor of the same shape as the input tensor
         """
-        x1, x2 = torch.chunk(x, 2, dim=1)
+        x1, x2 = torch.chunk(x, 2, dim=self.split_along_dim)
         y1, y2 = None, None
         with torch.no_grad():
             y1 = x1 + self.f_block(x2)
             y2 = x2 + self.g_block(y1)
 
-        return torch.cat([y1, y2], dim=1)
+        return torch.cat([y1, y2], dim=self.split_along_dim)
 
     def backward_pass(self, y, dy):
         """
@@ -47,11 +48,11 @@ def backward_pass(self, y, dy):
         """
 
         # Split the arguments channel-wise
-        y1, y2 = torch.chunk(y, 2, dim=1)
+        y1, y2 = torch.chunk(y, 2, dim=self.split_along_dim)
         del y
         assert (not y1.requires_grad), "y1 must already be detached"
         assert (not y2.requires_grad), "y2 must already be detached"
-        dy1, dy2 = torch.chunk(dy, 2, dim=1)
+        dy1, dy2 = torch.chunk(dy, 2, dim=self.split_along_dim)
         del dy
         assert (not dy1.requires_grad), "dy1 must not require grad"
         assert (not dy2.requires_grad), "dy2 must not require grad"
@@ -100,8 +101,8 @@ def backward_pass(self, y, dy):
         x2.grad = None
 
         # Undo the channelwise split
-        x = torch.cat([x1, x2.detach()], dim=1)
-        dx = torch.cat([dx1, dx2], dim=1)
+        x = torch.cat([x1, x2.detach()], dim=self.split_along_dim)
+        dx = torch.cat([dx1, dx2], dim=self.split_along_dim)
 
         return x, dx
 
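
For context, here is a minimal usage sketch of the new `split_along_dim` parameter. It assumes the `ReversibleSequence` wrapper exported by the same module (as in the project README); the subnetwork architecture, tensor shapes, and the choice of `split_along_dim=2` are illustrative, not part of this change.

```python
import torch
import torch.nn as nn
import revtorch as rv

# Each subnetwork must map its input to an output of the same shape.
# Here the input is one half of the feature dimension (128 / 2 = 64).
def subnet(features):
    return nn.Sequential(nn.Linear(features, features), nn.ReLU())

# New in 0.2.1: split along dimension 2 (the feature dimension of a
# (batch, sequence, features) tensor) instead of the default dim=1.
block = rv.ReversibleBlock(subnet(64), subnet(64), split_along_dim=2)
sequence = rv.ReversibleSequence(nn.ModuleList([block]))

# requires_grad on the input ensures autograd reaches the reversible
# backward pass, which reconstructs activations instead of storing them.
x = torch.randn(8, 100, 128, requires_grad=True)
y = sequence(x)     # same shape as the input: (8, 100, 128)
y.sum().backward()
```

Leaving `split_along_dim` at its default of 1 reproduces the previous channel-wise behaviour, so existing code is unaffected.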