import torch
from torch import nn
from torch.optim import RMSprop, Adam

# Feature corruption: randomly zero out a fraction of the input features
def feature_corruption(x, corruption_ratio=0.3):
    # Sample a Bernoulli mask; each feature is kept with probability 1 - corruption_ratio
    corruption_mask = torch.bernoulli(torch.full(x.shape, 1 - corruption_ratio)).to(x.device)
    return x * corruption_mask


# Encoder network built from three stacked Transformer encoder layers
class Encoder(nn.Module):
    def __init__(self, feature_size, num_heads=2):
        super(Encoder, self).__init__()
        # feature_size must be divisible by num_heads
        self.transformer_layers = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=feature_size, nhead=num_heads),
            nn.TransformerEncoderLayer(d_model=feature_size, nhead=num_heads),
            nn.TransformerEncoderLayer(d_model=feature_size, nhead=num_heads)
        )

    def forward(self, x):
        # nn.TransformerEncoderLayer expects (seq_length, batch, features) by default;
        # we assume x is already shaped accordingly
        return self.transformer_layers(x)


# Projector network: a single linear layer followed by a sigmoid
class Projector(nn.Module):
    def __init__(self, feature_size):
        super(Projector, self).__init__()
        self.linear = nn.Linear(feature_size, feature_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.linear(x))


# Decoder network: maps concatenated features back to the input feature space
class Decoder(nn.Module):
    def __init__(self, input_feature_size, output_feature_size):
        super(Decoder, self).__init__()
        self.linear = nn.Linear(input_feature_size, output_feature_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.linear(x))


# Prediction head used in the stages that have access to labels
class Predictor(nn.Module):
    def __init__(self, feature_size, num_classes):
        super(Predictor, self).__init__()
        self.linear = nn.Linear(feature_size, num_classes)

    def forward(self, x):
        return self.linear(x)


# Implementing the self-supervised learning algorithm with the components above
def self_supervised_learning_with_switchtab(data, batch_size, feature_size, num_classes):
    # data is assumed to be a PyTorch dataset yielding (features, label) pairs
    dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

    # Initialize the components with the feature size
    f_encoder = Encoder(feature_size)
    pm_mutual = Projector(feature_size)
    ps_salient = Projector(feature_size)
    # The decoder consumes the concatenation of mutual and salient features,
    # so its input size is twice the feature size
    d_decoder = Decoder(2 * feature_size, feature_size)
    pred_predictor = Predictor(feature_size, num_classes)  # For the stage with labels

    # Loss function and optimizer
    mse_loss = nn.MSELoss()

    # Optimizer for pre-training
    pretrain_optimizer = RMSprop(
        list(f_encoder.parameters()) + list(pm_mutual.parameters()) +
        list(ps_salient.parameters()) + list(d_decoder.parameters()) +
        list(pred_predictor.parameters()),
        lr=0.0003
    )

    # Pre-training loop
    print_interval = 50
    for epoch in range(1000):
        # Two independently shuffled passes over the dataloader yield two different batches;
        # labels are ignored during pre-training
        for (x1_batch, _), (x2_batch, _) in zip(dataloader, dataloader):
            # Feature corruption
            x1_corrupted = feature_corruption(x1_batch)
            x2_corrupted = feature_corruption(x2_batch)

            # Data encoding
            z1_encoded = f_encoder(x1_corrupted)
            z2_encoded = f_encoder(x2_corrupted)

            # Feature decoupling into salient and mutual parts
            s1_salient = ps_salient(z1_encoded)
            m1_mutual = pm_mutual(z1_encoded)
            s2_salient = ps_salient(z2_encoded)
            m2_mutual = pm_mutual(z2_encoded)

            # Data reconstruction: plain reconstructions and switched (cross-sample) ones
            x1_reconstructed = d_decoder(torch.cat((m1_mutual, s1_salient), dim=-1))
            x2_reconstructed = d_decoder(torch.cat((m2_mutual, s2_salient), dim=-1))
            x1_switched = d_decoder(torch.cat((m2_mutual, s1_salient), dim=-1))
            x2_switched = d_decoder(torch.cat((m1_mutual, s2_salient), dim=-1))

            # Reconstruction loss over both plain and switched reconstructions
            loss = (mse_loss(x1_reconstructed, x1_batch) + mse_loss(x2_reconstructed, x2_batch) +
                    mse_loss(x1_switched, x1_batch) + mse_loss(x2_switched, x2_batch))

            # Update model parameters
            pretrain_optimizer.zero_grad()
            loss.backward()
            pretrain_optimizer.step()

        # Print loss every print_interval epochs
        if (epoch + 1) % print_interval == 0:
            print(f'Epoch [{epoch+1}/1000], Pre-training Loss: {loss.item():.4f}')

    # Fine-tuning loop: train the encoder and prediction head with the labels
    fine_tuning_loss_function = nn.CrossEntropyLoss()
    fine_tuning_optimizer = Adam(
        list(f_encoder.parameters()) + list(pred_predictor.parameters()), lr=0.001
    )
    for epoch in range(200):
        for x_batch, labels in dataloader:
            # Encode the (uncorrupted) batch and predict class logits
            z_encoded = f_encoder(x_batch)
            predictions = pred_predictor(z_encoded)
            prediction_loss = fine_tuning_loss_function(predictions, labels)

            fine_tuning_optimizer.zero_grad()
            prediction_loss.backward()
            fine_tuning_optimizer.step()

        # Print loss every print_interval epochs
        if (epoch + 1) % print_interval == 0:
            print(f'Epoch [{epoch+1}/200], Fine-tuning Loss: {prediction_loss.item():.4f}')

# You would call this function with your dataset, batch size, feature size, and the number of classes
# self_supervised_learning_with_switchtab(data, batch_size, feature_size, num_classes)
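
# A minimal usage sketch, assuming a toy classification dataset wrapped in a
# TensorDataset; the sample count, feature size, class count, and batch size below
# are illustrative placeholders, not values prescribed by the code above.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    num_samples, feature_size, num_classes = 512, 16, 3
    features = torch.randn(num_samples, feature_size)
    labels = torch.randint(0, num_classes, (num_samples,))
    toy_dataset = TensorDataset(features, labels)

    # Note: the loops above run 1000 pre-training and 200 fine-tuning epochs
    self_supervised_learning_with_switchtab(
        toy_dataset, batch_size=64, feature_size=feature_size, num_classes=num_classes
    )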