-
Notifications
You must be signed in to change notification settings - Fork 10
/
RG.py
32 lines (27 loc) · 1.04 KB
/
RG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch import nn
class RG(nn.Module):
    """Recurrent Generative Module.

    Each call computes::

        state_t = ConvTranspose2d(x_t) + Conv2d(state_{t-1})

    The input is upsampled by a transposed convolution. When a previous
    state is supplied, a bias-free stride-1 convolution of that state is
    added to it.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        """Initialise RG module (parameters as ``nn.ConvTranspose2d``).

        :param in_channels: number of channels of the input ``x``
        :param out_channels: number of channels of the produced state
        :param kernel_size: kernel size shared by both convolutions
        :param stride: stride of the input (transposed) convolution only
        :param padding: padding shared by both convolutions
        """
        super().__init__()
        self.from_input = nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
        )
        # Stride-1 conv: it preserves the state's spatial size only when
        # 2 * padding == kernel_size - 1 (odd kernel, "same" padding).
        # Otherwise the sum in forward() mixes tensors of different spatial
        # sizes and generally fails with a shape mismatch.
        self.from_state = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=kernel_size, padding=padding, bias=False,
        )

    def forward(self, x, state=None):
        """Compute the current state from the input and the previous state.

        :param x: either a tuple/list ``(input, output_size)`` forwarded to
            ``nn.ConvTranspose2d`` (``output_size`` resolves the ambiguous
            output shape of a strided transposed convolution), or a bare
            input tensor
        :type x: tuple or torch.Tensor
        :param state: previous output, or ``None`` on the first step
        :type state: torch.Tensor or None
        :return: current state
        :rtype: torch.Tensor
        """
        # A bare tensor must not be star-unpacked: that would split it
        # along dim 0 into bogus (input, output_size) arguments.
        if isinstance(x, (tuple, list)):
            out = self.from_input(*x)
        else:
            out = self.from_input(x)
        # Compare against None explicitly: truth-testing a multi-element
        # tensor raises, and a zero-valued scalar state would be skipped.
        if state is not None:
            out = out + self.from_state(state)
        return out