# Hamiltonian Neural Networks | 2019
# Sam Greydanus, Misko Dzamba, Jason Yosinski
# forked from greydanus/hamiltonian-nn
import torch
import numpy as np
from nn_models import MLP
from utils import rk4

# Number of columns in each of the two fields (F1, F2) that HNN.forward
# splits the model output into; the wrapped model must emit 2 * input_dim1
# values per sample.
input_dim1 = 3

class HNN(torch.nn.Module):
'''Learn arbitrary vector fields that are sums of conservative and solenoidal fields'''
def __init__(self, input_dim, differentiable_model, field_type='solenoidal',
baseline=False, assume_canonical_coords=True):
super(HNN, self).__init__()
self.baseline = baseline
self.differentiable_model = differentiable_model
self.assume_canonical_coords = assume_canonical_coords
self.M = self.permutation_tensor(input_dim) # Levi-Civita permutation tensor
self.field_type = field_type

    def forward(self, x):
# traditional forward pass
if self.baseline:
return self.differentiable_model(x)
        y = self.differentiable_model(x)
        assert y.dim() == 2 and y.shape[1] == 2 * input_dim1, \
            "Output tensor should have shape [batch_size, 2 * input_dim1]"
        # split the 2 * input_dim1 output columns into two fields F1 and F2
        cols = y.split(1, 1)  # tuple of 2 * input_dim1 single-column tensors
        return torch.cat(cols[:input_dim1], 1), torch.cat(cols[input_dim1:], 1)

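    # Shape note (illustrative, not original code): with input_dim1 = 3 and a
    # batch of B states, the wrapped model maps [B, input_dim] -> [B, 6] and
    # forward returns two [B, 3] tensors, F1 and F2.
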
def rk4_time_derivative(self, x, dt):
return rk4(fun=self.time_derivative, y0=x, t=0, dt=dt)

    def time_derivative(self, x, t=None, separate_fields=False):
        # neural-ODE-style vector field
        if self.baseline:
            return self.differentiable_model(x)

        # neural-Hamiltonian-style vector field
        F1, F2 = self.forward(x)  # traditional forward pass

conservative_field = torch.zeros_like(x) # start out with both components set to 0
solenoidal_field = torch.zeros_like(x)
if self.field_type != 'solenoidal':
dF1 = torch.autograd.grad(F1.sum(), x, create_graph=True)[0] # gradients for conservative field
conservative_field = dF1 @ torch.eye(*self.M.shape)
if self.field_type != 'conservative':
dF2 = torch.autograd.grad(F2.sum(), x, create_graph=True)[0] # gradients for solenoidal field
solenoidal_field = dF2 @ self.M.t()
if separate_fields:
return [conservative_field, solenoidal_field]
return conservative_field + solenoidal_field

    def permutation_tensor(self, n):
M = None
if self.assume_canonical_coords:
M = torch.eye(n)
M = torch.cat([M[n//2:], -M[:n//2]])
else:
'''Constructs the Levi-Civita permutation tensor'''
M = torch.ones(n,n) # matrix of ones
M *= 1 - torch.eye(n) # clear diagonals
M[::2] *= -1 # pattern of signs
M[:,::2] *= -1
for i in range(n): # make asymmetric
for j in range(i+1, n):
M[i,j] *= -1
return M
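
# Illustrative note (not part of the original module): with
# assume_canonical_coords=True, permutation_tensor(n) returns the canonical
# symplectic matrix J. For n = 4:
#     [[ 0,  0,  1,  0],
#      [ 0,  0,  0,  1],
#      [-1,  0,  0,  0],
#      [ 0, -1,  0,  0]]
# In time_derivative, dF2 @ J.t() maps the gradient (dH/dq, dH/dp) to
# (dH/dp, -dH/dq), i.e. Hamilton's equations dq/dt = dH/dp, dp/dt = -dH/dq.
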
class PixelHNN(torch.nn.Module):
def __init__(self, input_dim, hidden_dim, autoencoder,
field_type='solenoidal', nonlinearity='tanh', baseline=False):
super(PixelHNN, self).__init__()
self.autoencoder = autoencoder
self.baseline = baseline
        output_dim = input_dim if baseline else 2 * input_dim1  # HNN.forward asserts 2 * input_dim1 outputs
nn_model = MLP(input_dim, hidden_dim, output_dim, nonlinearity)
self.hnn = HNN(input_dim, differentiable_model=nn_model, field_type=field_type, baseline=baseline)

    def encode(self, x):
        return self.autoencoder.encode(x)

    def decode(self, z):
        return self.autoencoder.decode(z)

    def time_derivative(self, z, separate_fields=False):
        # pass separate_fields by keyword: the second positional argument of
        # HNN.time_derivative is t, not separate_fields
        return self.hnn.time_derivative(z, separate_fields=separate_fields)

    def forward(self, x):
        z = self.encode(x)
        z_next = z + self.time_derivative(z)
        return self.decode(z_next)
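
# A minimal usage sketch (added for illustration; not in the original file).
# It assumes nn_models.MLP takes (input_dim, hidden_dim, output_dim,
# nonlinearity), as in the PixelHNN constructor above, and that the input
# dimension equals 2 * input_dim1 so HNN.forward's assertion holds. The
# hidden width of 200 is an arbitrary choice for the demo.
if __name__ == '__main__':
    input_dim = 2 * input_dim1                  # e.g. (q1..q3, p1..p3)
    nn_model = MLP(input_dim, 200, input_dim, 'tanh')
    model = HNN(input_dim, differentiable_model=nn_model)

    x = torch.randn(16, input_dim, requires_grad=True)  # random batch of states
    dxdt = model.time_derivative(x)             # Hamiltonian vector field at x
    x_next = x + 0.01 * dxdt                    # one explicit Euler step
    print('dx/dt shape:', dxdt.shape)           # expected: torch.Size([16, 6])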