-
Notifications
You must be signed in to change notification settings - Fork 18
/
diff_models.py
141 lines (117 loc) · 6.29 KB
/
diff_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from layers import *
class Guide_diff(nn.Module):
    """Guided diffusion noise-prediction network.

    Maps a noisy target (plus side information and the diffusion step) to a
    noise estimate of shape (B, K, L).  When ``is_itp`` is True an
    interpolation estimate ``itp_x`` is concatenated to the input and, after
    projection and graph-based refinement (``GuidanceConstruct``), is fed to
    every residual layer as guidance.

    Args:
        config: dict with keys "channels", "nheads", "num_steps",
            "diffusion_embedding_dim", "side_dim", "adj_file", "device",
            "is_adp", "proj_t", "layers", "is_cross_t", "is_cross_s".
        inputdim: number of input channels; when ``is_itp`` this includes the
            extra interpolation channel(s), so the guide branch uses
            ``inputdim - 1`` channels.
        target_dim: number of spatial nodes K.
        is_itp: whether the interpolation-guidance branch is enabled.
    """

    def __init__(self, config, inputdim=1, target_dim=36, is_itp=False):
        super().__init__()
        self.channels = config["channels"]
        self.is_itp = is_itp
        self.itp_channels = None
        if self.is_itp:
            self.itp_channels = config["channels"]
            # Guide branch: lift the interpolation channels to feature space,
            # fuse with projected side info, then refine over the graph.
            self.itp_projection = Conv1d_with_init(inputdim - 1, self.itp_channels, 1)
            self.itp_modeling = GuidanceConstruct(channels=self.itp_channels, nheads=config["nheads"], target_dim=target_dim,
                                                  order=2, include_self=True, device=config["device"], is_adp=config["is_adp"],
                                                  adj_file=config["adj_file"], proj_t=config["proj_t"])
            self.cond_projection = Conv1d_with_init(config["side_dim"], self.itp_channels, 1)
            self.itp_projection2 = Conv1d_with_init(self.itp_channels, 1, 1)

        self.diffusion_embedding = DiffusionEmbedding(
            num_steps=config["num_steps"],
            embedding_dim=config["diffusion_embedding_dim"],
        )

        # Static adjacency for the chosen dataset; fail fast on an unknown
        # name instead of hitting an AttributeError at self.adj below.
        if config["adj_file"] == 'AQI36':
            self.adj = get_adj_AQI36()
        elif config["adj_file"] == 'metr-la':
            self.adj = get_similarity_metrla(thr=0.1)
        elif config["adj_file"] == 'pems-bay':
            self.adj = get_similarity_pemsbay(thr=0.1)
        else:
            raise ValueError(f"unknown adj_file: {config['adj_file']!r}")
        self.device = config["device"]
        self.support = compute_support_gwn(self.adj, device=config["device"])
        self.is_adp = config["is_adp"]
        if self.is_adp:
            # Learnable adaptive adjacency factors (node_num x 10) (10 x node_num).
            # NOTE: do not call .to() on an nn.Parameter — a non-trivial move
            # returns a plain Tensor and silently drops parameter registration;
            # moving the tensor before wrapping is sufficient.
            node_num = self.adj.shape[0]
            self.nodevec1 = nn.Parameter(torch.randn(node_num, 10).to(self.device), requires_grad=True)
            self.nodevec2 = nn.Parameter(torch.randn(10, node_num).to(self.device), requires_grad=True)
            self.support.append([self.nodevec1, self.nodevec2])

        self.input_projection = Conv1d_with_init(inputdim, self.channels, 1)
        self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1)
        self.output_projection2 = Conv1d_with_init(self.channels, 1, 1)
        # Zero-init the final projection so the network initially predicts
        # zero noise (standard diffusion-model initialization).
        nn.init.zeros_(self.output_projection2.weight)

        self.residual_layers = nn.ModuleList(
            [
                NoiseProject(
                    side_dim=config["side_dim"],
                    channels=self.channels,
                    diffusion_embedding_dim=config["diffusion_embedding_dim"],
                    nheads=config["nheads"],
                    target_dim=target_dim,
                    proj_t=config["proj_t"],
                    is_adp=config["is_adp"],
                    device=config["device"],
                    adj_file=config["adj_file"],
                    is_cross_t=config["is_cross_t"],
                    is_cross_s=config["is_cross_s"],
                )
                for _ in range(config["layers"])
            ]
        )

    def forward(self, x, side_info, diffusion_step, itp_x, cond_mask):
        """Predict diffusion noise.

        Args:
            x: noisy input, (B, inputdim[-1], K, L) — assumed 4-D; the guide
               channel is concatenated here when ``is_itp``.
            side_info: conditioning features, (B, side_dim, K, L).
            diffusion_step: diffusion step indices for the embedding.
            itp_x: interpolation estimate, (B, inputdim-1, K, L); passed
               through to the residual layers (unused content when not is_itp).
            cond_mask: unused in this forward pass (kept for interface parity).

        Returns:
            Noise prediction of shape (B, K, L).
        """
        if self.is_itp:
            x = torch.cat([x, itp_x], dim=1)
        B, inputdim, K, L = x.shape

        # Flatten the (K, L) grid so Conv1d acts pointwise over all positions.
        x = x.reshape(B, inputdim, K * L)
        x = self.input_projection(x)
        x = F.relu(x)
        x = x.reshape(B, self.channels, K, L)

        if self.is_itp:
            # Build the guidance tensor: projected interpolation + projected
            # side info, refined by the graph module.
            itp_x = itp_x.reshape(B, inputdim - 1, K * L)
            itp_x = self.itp_projection(itp_x)
            itp_cond_info = side_info.reshape(B, -1, K * L)
            itp_cond_info = self.cond_projection(itp_cond_info)
            itp_x = itp_x + itp_cond_info
            itp_x = self.itp_modeling(itp_x, [B, self.itp_channels, K, L], self.support)
            itp_x = F.relu(itp_x)
            itp_x = itp_x.reshape(B, self.itp_channels, K, L)

        diffusion_emb = self.diffusion_embedding(diffusion_step)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, side_info, diffusion_emb, itp_x, self.support)
            skip.append(skip_connection)

        # Average skip connections with 1/sqrt(n) scaling (DiffWave-style).
        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = x.reshape(B, self.channels, K * L)
        x = self.output_projection1(x)  # (B, channel, K*L)
        x = F.relu(x)
        x = self.output_projection2(x)  # (B, 1, K*L)
        x = x.reshape(B, K, L)
        return x
class NoiseProject(nn.Module):
    """One gated residual layer of the guided diffusion network.

    Adds the diffusion-step embedding to the input, applies temporal then
    spatial (graph) attention, fuses side information, and splits the gated
    output into a residual path and a skip path (WaveNet/DiffWave pattern).

    Args:
        side_dim: channel count of the side information tensor.
        channels: feature channels of the residual stream.
        diffusion_embedding_dim: width of the diffusion-step embedding.
        nheads: attention heads for the temporal/spatial modules.
        target_dim: number of spatial nodes K.
        proj_t: temporal projection size for the spatial module.
        order / include_self / device / is_adp / adj_file: forwarded to
            SpatialLearning (graph convolution configuration).
        is_cross_t / is_cross_s: enable cross-attention against the guide
            in the temporal / spatial modules.
    """

    def __init__(self, side_dim, channels, diffusion_embedding_dim, nheads, target_dim, proj_t, order=2, include_self=True,
                 device=None, is_adp=False, adj_file=None, is_cross_t=False, is_cross_s=True):
        super().__init__()
        self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels)
        self.cond_projection = Conv1d_with_init(side_dim, 2 * channels, 1)
        self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1)
        self.output_projection = Conv1d_with_init(channels, 2 * channels, 1)
        self.forward_time = TemporalLearning(channels=channels, nheads=nheads, is_cross=is_cross_t)
        self.forward_feature = SpatialLearning(channels=channels, nheads=nheads, target_dim=target_dim,
                                               order=order, include_self=include_self, device=device, is_adp=is_adp,
                                               adj_file=adj_file, proj_t=proj_t, is_cross=is_cross_s)

    def forward(self, x, side_info, diffusion_emb, itp_info, support):
        """Run one residual layer.

        Args:
            x: residual stream, (B, channel, K, L).
            side_info: conditioning features, (B, side_dim, K, L).
            diffusion_emb: diffusion-step embedding, (B, diffusion_embedding_dim).
            itp_info: guidance tensor for cross-attention (may be unused
                when the cross flags are off — depends on the learning modules).
            support: graph supports passed to the spatial module.

        Returns:
            Tuple ``(residual_out, skip)``, each shaped like ``x``; the
            residual output is ``(x + residual) / sqrt(2)``.
        """
        B, channel, K, L = x.shape
        base_shape = x.shape
        x = x.reshape(B, channel, K * L)

        # Broadcast the step embedding over all K*L positions.
        diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(-1)  # (B, channel, 1)
        y = x + diffusion_emb

        y = self.forward_time(y, base_shape, itp_info)
        y = self.forward_feature(y, base_shape, support, itp_info)  # (B, channel, K*L)
        y = self.mid_projection(y)  # (B, 2*channel, K*L)

        _, side_dim, _, _ = side_info.shape
        side_info = side_info.reshape(B, side_dim, K * L)
        side_info = self.cond_projection(side_info)  # (B, 2*channel, K*L)
        y = y + side_info

        # Gated activation unit; avoid shadowing the builtin `filter`.
        gate, filt = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filt)  # (B, channel, K*L)
        y = self.output_projection(y)

        residual, skip = torch.chunk(y, 2, dim=1)
        x = x.reshape(base_shape)
        residual = residual.reshape(base_shape)
        skip = skip.reshape(base_shape)
        # 1/sqrt(2) keeps the residual stream's variance stable across layers.
        return (x + residual) / math.sqrt(2.0), skip