Skip to content

Commit

Permalink
Merge pull request #44 from ElenaRyumina/main
Browse files Browse the repository at this point in the history
Conversion of tensorflow models to pytorch models
  • Loading branch information
DmitryRyumin authored Oct 2, 2024
2 parents 404d03b + 5f9cc48 commit 052840c
Show file tree
Hide file tree
Showing 18 changed files with 1,557 additions and 1,583 deletions.
44 changes: 40 additions & 4 deletions oceanai/modules/core/core.py

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions oceanai/modules/lab/architectures/audio_architectures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Архитектуры аудио моделей для Torch
"""

from __future__ import print_function

import torch.nn as nn
import torchvision.models as models

class audio_model_hc(nn.Module):
    """Two-layer LSTM audio model with a 5-output linear head.

    Two batch-first LSTMs (64 and 128 hidden units) are stacked with dropout
    in between; the hidden state of the last time step is used as the feature
    vector and is mapped to 5 outputs.

    Args:
        input_size: Number of features per time step fed to the first LSTM.
    """

    def __init__(self, input_size=25):
        super().__init__()

        # Attribute names and their order define the state_dict layout.
        self.lstm1 = nn.LSTM(input_size=input_size, hidden_size=64, batch_first=True)
        self.dropout1 = nn.Dropout(p=0.2)
        self.lstm2 = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
        self.dropout2 = nn.Dropout(p=0.2)
        self.fc = nn.Linear(in_features=128, out_features=5)

    def extract_features(self, x):
        """Return the second LSTM's output at the last time step.

        Args:
            x: Batch-first input sequence (the LSTMs use ``batch_first=True``).

        Returns:
            Tensor with 128 features per batch element.
        """
        first_pass, _ = self.lstm1(x)
        second_pass, _ = self.lstm2(self.dropout1(first_pass))
        # Keep only the final time step of every sequence.
        return second_pass[:, -1, :]

    def forward(self, x):
        """Return ``(outputs, features)`` for the input batch.

        ``features`` is the un-dropped result of :meth:`extract_features`;
        dropout is applied only on the path to the linear head.
        """
        features = self.extract_features(x)
        outputs = self.fc(self.dropout2(features))
        return outputs, features

class audio_model_nn(nn.Module):
    """VGG16-based audio model with a 256-feature bottleneck and 5 outputs.

    Only the convolutional ``features`` stage of torchvision's VGG16 takes
    part in the forward pass; the stock classifier is replaced by
    ``nn.Identity``. The flattened feature map goes through two ReLU-activated
    linear layers (512 and 256 units) and a final 5-unit linear layer.

    Args:
        input_size: Accepted for interface compatibility; the current
            implementation does not read it.
    """

    def __init__(self, input_size=512):
        super().__init__()

        # NOTE(review): newer torchvision releases deprecate ``pretrained`` in
        # favour of ``weights`` - confirm the pinned version before changing.
        backbone = models.vgg16(pretrained=False)
        backbone.classifier = nn.Identity()
        self.vgg = backbone

        self.flatten = nn.Flatten()
        # fc1 expects a flattened 7 x 7 x 512 feature map
        # (presumably 224 x 224 network inputs - TODO confirm).
        self.fc1 = nn.Linear(in_features=512 * 7 * 7, out_features=512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=5)

    def extract_features(self, x):
        """Return the 256-feature representation of the input batch."""
        conv_maps = self.vgg.features(x)
        # Channels are moved last before flattening, so the flat order is
        # (H, W, C). NOTE(review): presumably mirrors the layout of the
        # TensorFlow model the weights were converted from - confirm.
        channels_last = conv_maps.permute(0, 2, 3, 1)
        hidden = self.relu(self.fc1(self.flatten(channels_last)))
        hidden = self.dropout(hidden)
        return self.relu(self.fc2(hidden))

    def forward(self, x):
        """Return ``(outputs, features)`` where ``outputs = fc3(features)``."""
        features = self.extract_features(x)
        return self.fc3(features), features

class audio_model_b5(nn.Module):
    """Single linear unit followed by a sigmoid.

    Args:
        input_size: Number of input features of the linear layer.
    """

    def __init__(self, input_size=32):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(fc(x))``; the last dimension of the result is 1."""
        return self.sigmoid(self.fc(x))
108 changes: 108 additions & 0 deletions oceanai/modules/lab/architectures/fusion_architectures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Архитектуры моделей слияния для Torch
"""

from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class GFL(nn.Module):
    """Gated fusion of four feature tensors into one ``output_dim``-sized tensor.

    Each input gets its own tanh projection. Two sigmoid gates, one computed
    from the concatenated (HCF1, HCF2) pair and one from the concatenated
    (DF1, DF2) pair, blend the projections inside each pair; the two blended
    results are then added.

    Args:
        output_dim: Size of the last dimension of the fused output.
        input_shapes: Indexable with at least four feature sizes, ordered as
            ``[HCF1, HCF2, DF1, DF2]`` (the same order :meth:`forward`
            expects its inputs in).
    """

    def __init__(self, output_dim, input_shapes):
        super().__init__()
        self.output_dim = output_dim

        size_hcf1 = input_shapes[0]
        size_hcf2 = input_shapes[1]
        size_df1 = input_shapes[2]
        size_df2 = input_shapes[3]

        # Creation order (HCF1, DF1, HCF2, DF2, HCF, DF) fixes both the
        # state_dict order and the order in which random numbers are drawn.
        self.W_HCF1 = self._xavier_weight(size_hcf1, output_dim)
        self.W_DF1 = self._xavier_weight(size_df1, output_dim)
        self.W_HCF2 = self._xavier_weight(size_hcf2, output_dim)
        self.W_DF2 = self._xavier_weight(size_df2, output_dim)

        # Gate weights act on the concatenation of each pair of inputs.
        self.W_HCF = self._xavier_weight(size_hcf1 + size_hcf2, output_dim)
        self.W_DF = self._xavier_weight(size_df1 + size_df2, output_dim)

    @staticmethod
    def _xavier_weight(rows, cols):
        """Create a ``rows x cols`` parameter initialised with Xavier-uniform."""
        weight = nn.Parameter(torch.Tensor(rows, cols))
        init.xavier_uniform_(weight)
        return weight

    def forward(self, inputs):
        """Fuse ``inputs = (HCF1, HCF2, DF1, DF2)`` and return the result."""
        HCF1, HCF2, DF1, DF2 = inputs

        proj_hcf1 = torch.tanh(HCF1 @ self.W_HCF1)
        proj_hcf2 = torch.tanh(HCF2 @ self.W_HCF2)
        proj_df1 = torch.tanh(DF1 @ self.W_DF1)
        proj_df2 = torch.tanh(DF2 @ self.W_DF2)

        hcf_pair = torch.cat((HCF1, HCF2), dim=-1)
        df_pair = torch.cat((DF1, DF2), dim=-1)
        gate_hcf = torch.sigmoid(hcf_pair @ self.W_HCF)
        gate_df = torch.sigmoid(df_pair @ self.W_DF)

        # Single left-to-right sum, so floating-point rounding is unchanged.
        return (
            gate_hcf * proj_hcf1
            + (1 - gate_hcf) * proj_hcf2
            + gate_df * proj_df1
            + (1 - gate_df) * proj_df2
        )

class LayerNormalization(nn.Module):
    """Thin wrapper that applies ``nn.LayerNorm`` over the given shape.

    Args:
        dim: ``normalized_shape`` passed to ``nn.LayerNorm``.
    """

    def __init__(self, dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=dim)

    def forward(self, x):
        """Return the layer-normalised ``x``."""
        normalized = self.layer_norm(x)
        return normalized

class avt_model_b5(nn.Module):
    """Fusion model: six layer-normalised inputs, three gated fusions, 5 outputs.

    Every input is layer-normalised, the normalised tensors are fused pairwise
    by three ``GFL`` layers (t+a, t+v, a+v), the fused tensors are
    concatenated and passed through a ReLU hidden layer and a sigmoid output
    layer with 5 units.

    Args:
        input_shapes: Indexable with six feature sizes ordered as
            ``[hc_t, nn_t, hc_a, nn_a, hc_v, nn_v]`` (presumably
            text / audio / video streams - TODO confirm with the caller).
        output_dim: Output size of each ``GFL`` layer.
        hidden_states: Width of the hidden linear layer.
    """

    def __init__(self, input_shapes, output_dim=64, hidden_states=50):
        super().__init__()

        size_hc_t = input_shapes[0]
        size_nn_t = input_shapes[1]
        size_hc_a = input_shapes[2]
        size_nn_a = input_shapes[3]
        size_hc_v = input_shapes[4]
        size_nn_v = input_shapes[5]

        self.ln_hc_t = LayerNormalization(size_hc_t)
        self.ln_nn_t = LayerNormalization(size_nn_t)
        self.ln_hc_a = LayerNormalization(size_hc_a)
        self.ln_nn_a = LayerNormalization(size_nn_a)
        self.ln_hc_v = LayerNormalization(size_hc_v)
        self.ln_nn_v = LayerNormalization(size_nn_v)

        # GFL takes its sizes as [hc of stream 1, hc of stream 2,
        # nn of stream 1, nn of stream 2].
        self.gf_ta = GFL(
            output_dim=output_dim,
            input_shapes=[size_hc_t, size_hc_a, size_nn_t, size_nn_a],
        )
        self.gf_tv = GFL(
            output_dim=output_dim,
            input_shapes=[size_hc_t, size_hc_v, size_nn_t, size_nn_v],
        )
        self.gf_av = GFL(
            output_dim=output_dim,
            input_shapes=[size_hc_a, size_hc_v, size_nn_a, size_nn_v],
        )

        self.fc1 = nn.Linear(in_features=output_dim * 3, out_features=hidden_states)
        self.fc2 = nn.Linear(in_features=hidden_states, out_features=5)

    def forward(self, hc_t, nn_t, hc_a, nn_a, hc_v, nn_v):
        """Return the 5 sigmoid outputs for the six input tensors."""
        norm_hc_t = self.ln_hc_t(hc_t)
        norm_nn_t = self.ln_nn_t(nn_t)
        norm_hc_a = self.ln_hc_a(hc_a)
        norm_nn_a = self.ln_nn_a(nn_a)
        norm_hc_v = self.ln_hc_v(hc_v)
        norm_nn_v = self.ln_nn_v(nn_v)

        fused_ta = self.gf_ta([norm_hc_t, norm_hc_a, norm_nn_t, norm_nn_a])
        fused_tv = self.gf_tv([norm_hc_t, norm_hc_v, norm_nn_t, norm_nn_v])
        fused_av = self.gf_av([norm_hc_a, norm_hc_v, norm_nn_a, norm_nn_v])

        fused = torch.cat([fused_ta, fused_tv, fused_av], dim=-1)
        hidden = F.relu(self.fc1(fused))
        return torch.sigmoid(self.fc2(hidden))

class av_model_b5(nn.Module):
    """Single linear unit followed by a sigmoid.

    Args:
        input_size: Number of input features of the linear layer.
    """

    def __init__(self, input_size=64):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(fc(x))``; the last dimension of the result is 1."""
        return self.sigmoid(self.fc(x))

91 changes: 91 additions & 0 deletions oceanai/modules/lab/architectures/text_architectures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Архитектуры текстовых моделей для Torch
"""

from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Parameter-free dot-product attention in which the keys double as values."""

    def __init__(self):
        super().__init__()

    def forward(self, query, key):
        """Return softmax-weighted sums of ``key`` rows for every query row.

        The similarity ``query @ key^T`` is softmax-normalised over the last
        dimension (no scaling is applied) and used to mix the rows of ``key``.
        """
        similarity = torch.matmul(query, key.transpose(-1, -2))
        weights = F.softmax(similarity, dim=-1)
        return torch.matmul(weights, key)

class Addition(nn.Module):
    """Statistics pooling: concatenates the mean and std taken over dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``cat(mean(x, 1), std(x, 1))`` joined along dim 1.

        NOTE(review): ``torch.std`` uses the unbiased estimator by default,
        which yields NaN when dim 1 has a single element - confirm this is the
        intended statistic for the converted model.
        """
        pooled = (torch.mean(x, dim=1), torch.std(x, dim=1))
        return torch.cat(pooled, dim=1)


class Concat(nn.Module):
    """Module form of ``torch.cat`` along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Concatenate the sequence of tensors ``inputs`` along dim 1."""
        joined = torch.cat(inputs, dim=1)
        return joined


class text_model_hc(nn.Module):
    """Text model that sums three 64-feature branches before statistics pooling.

    Branches, all computed from the same input sequence:
      1. a bidirectional LSTM (2 x 32 units);
      2. self-attention over the output of branch 1;
      3. a ReLU linear projection to 64 features followed by a second
         bidirectional LSTM (2 x 32 units).

    The branch outputs are summed element-wise, pooled by ``Addition``
    (mean and std over dim 1, 128 features) and mapped to 5 sigmoid outputs.

    Args:
        input_shape: Indexable whose element ``[1]`` is the per-step feature
            size (presumably ``(seq_len, features)`` - TODO confirm).
    """

    def __init__(self, input_shape):
        super().__init__()

        feature_size = input_shape[1]

        self.lstm1 = nn.LSTM(
            input_size=feature_size,
            hidden_size=32,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = Attention()
        self.lstm2 = nn.LSTM(
            input_size=64,
            hidden_size=32,
            batch_first=True,
            bidirectional=True,
        )
        self.dense = nn.Linear(in_features=feature_size, out_features=32 * 2)
        self.addition = Addition()
        self.final_dense = nn.Linear(in_features=128, out_features=5)

    def forward(self, x):
        """Return ``(outputs, feat)``; ``feat`` is the pooled 128-feature vector."""
        recurrent, _ = self.lstm1(x)
        attended = self.attention(recurrent, recurrent)
        projected, _ = self.lstm2(F.relu(self.dense(x)))

        # Element-wise sum of the three branches via stack + sum.
        merged = torch.stack([recurrent, attended, projected], dim=0).sum(dim=0)

        feat = self.addition(merged)
        outputs = torch.sigmoid(self.final_dense(feat))
        return outputs, feat

class text_model_nn(nn.Module):
    """Text model: BiLSTM, self-attention, statistics pooling, 5 sigmoid outputs.

    The sequence passes through a bidirectional LSTM (2 x 32 units),
    self-attention and a linear layer to 128 features; ``Addition`` pools
    mean and std over dim 1 (256 features), which a linear layer reduces to
    the 128-feature vector returned as ``feat``.

    Args:
        input_shape: Indexable whose element ``[1]`` is the per-step feature
            size (presumably ``(seq_len, features)`` - TODO confirm).
    """

    def __init__(self, input_shape):
        super().__init__()

        self.lstm1 = nn.LSTM(
            input_size=input_shape[1],
            hidden_size=32,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = Attention()
        self.dense1 = nn.Linear(in_features=64, out_features=128)
        self.addition = Addition()
        self.dense2 = nn.Linear(in_features=128 * 2, out_features=128)
        self.final_dense = nn.Linear(in_features=128, out_features=5)

    def forward(self, x):
        """Return ``(outputs, feat)``; no activation is applied to ``feat``."""
        recurrent, _ = self.lstm1(x)
        attended = self.attention(recurrent, recurrent)
        pooled = self.addition(self.dense1(attended))
        feat = self.dense2(pooled)
        outputs = torch.sigmoid(self.final_dense(feat))
        return outputs, feat

class text_model_b5(nn.Module):
    """Combines two inputs (10 features in total) into 5 sigmoid outputs."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(in_features=10, out_features=5)

    def forward(self, input_1, input_2):
        """Concatenate both inputs along dim 1 and return ``sigmoid(dense(.))``.

        The concatenated tensor must have 10 features to match ``dense``.
        """
        merged = torch.cat((input_1, input_2), dim=1)
        return torch.sigmoid(self.dense(merged))
36 changes: 36 additions & 0 deletions oceanai/modules/lab/architectures/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Утилиты модели ResNet50
"""
from __future__ import print_function

import torch
from torchvision import transforms
from PIL import Image

def preprocess_input(fp):
    """Turn a PIL image into a channel-flipped, offset-subtracted float tensor.

    The image is resized to 224 x 224 with nearest-neighbour resampling,
    converted to a ``C x H x W`` tensor without value scaling, cast to
    float32, flipped along the channel axis, shifted by fixed per-channel
    offsets and given a leading batch dimension.

    Args:
        fp: A PIL image. Despite the name it is used as an image object, not
            as a file path or file pointer. Images whose mode is not ``"RGB"``
            are converted to RGB first, because the offsets below address
            exactly three channels.

    Returns:
        A ``torch.float32`` tensor with a leading batch dimension of size 1.
    """

    class PreprocessInput(torch.nn.Module):
        """Cast to float32, reverse the channel order, subtract the offsets."""

        # The original spelled this method ``init`` and called
        # ``super().init()``: it was never run as a constructor and would have
        # raised AttributeError if called.
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = x.to(torch.float32)
            # flip() returns a new tensor, so the in-place subtractions below
            # never modify the caller's data.
            x = torch.flip(x, dims=(0,))
            # Alternative offsets kept from the original source (presumably
            # for a different set of pretrained weights - confirm before use):
            # x[0, :, :] -= 91.4953
            # x[1, :, :] -= 103.8827
            # x[2, :, :] -= 131.0912
            x[0, :, :] -= 93.5940
            x[1, :, :] -= 104.7624
            x[2, :, :] -= 129.1863
            return x

    def get_img_torch(img, target_size=(224, 224)):
        """Resize ``img``, run the tensor pipeline and add the batch dimension."""
        # Grayscale / palette images would fail on channel index 1 or 2, and
        # RGBA / CMYK images would get the offsets applied to the wrong
        # channels after the flip, so normalise to three channels first.
        if img.mode != "RGB":
            img = img.convert("RGB")
        transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
        img = img.resize(target_size, Image.Resampling.NEAREST)
        img = transform(img)
        img = torch.unsqueeze(img, 0)
        return img

    return get_img_torch(fp)
Loading

0 comments on commit 052840c

Please sign in to comment.