[WIP] Add multi layer adaptation possibility + add flattening of each feature outputs #218

Closed · wants to merge 1 commit
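Summary note: the diff renames the `layer_name` argument to `layer_names`, which now accepts either a single layer name or a list of names, and flattens each collected feature map to 2D. A minimal sketch of the intended call, assuming the `MNISTtoUSPSNet` backbone from the touched examples (sketched after the example diffs below) and illustrative layer names:

```python
# Hedged sketch of the API this WIP moves toward; inferred from the diff,
# not a confirmed skada interface.
from skada.deep import DeepCoral

model = DeepCoral(
    MNISTtoUSPSNet(),              # backbone used by the example scripts
    layer_names=["conv1", "fc1"],  # hypothetical: adapt several layers at once
    # layer_names="fc1",           # a single string should keep working
    batch_size=128,
    max_epochs=5,
    train_split=False,
)
```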
2 changes: 1 addition & 1 deletion examples/deep/plot_adversarial.py
@@ -45,7 +45,7 @@
 # ----------------------------------------------------------------------------
 model = DANN(
     MNISTtoUSPSNet(),
-    layer_name="fc1",
+    layer_names="fc1",
     batch_size=128,
     max_epochs=5,
     train_split=False,
2 changes: 1 addition & 1 deletion examples/deep/plot_divergence.py
@@ -50,7 +50,7 @@
 # ----------------------------------------------------------------------------
 model = DeepCoral(
     MNISTtoUSPSNet(),
-    layer_name="fc1",
+    layer_names="fc1",
     batch_size=128,
     max_epochs=5,
     train_split=False,
2 changes: 1 addition & 1 deletion examples/deep/plot_optimal_transport.py
@@ -45,7 +45,7 @@
 # ----------------------------------------------------------------------------
 model = DeepJDOT(
     MNISTtoUSPSNet(),
-    layer_name="fc1",
+    layer_names="fc1",
     batch_size=128,
     max_epochs=5,
     train_split=False,
4 changes: 2 additions & 2 deletions examples/deep/plot_training_method.py
@@ -47,7 +47,7 @@
 # ----------------------------------------------------------------------------
 model = DeepCoral(
     MNISTtoUSPSNet(),
-    layer_name="fc1",
+    layer_names="fc1",
     batch_size=batch_size,
     max_epochs=max_epochs,
     train_split=False,
@@ -68,7 +68,7 @@
 
 model = DeepCoral(
     MNISTtoUSPSNet(),
-    layer_name="fc1",
+    layer_names="fc1",
     batch_size=batch_size,
     max_epochs=max_epochs,
     train_split=False,
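All four example scripts change only the keyword name, keeping `"fc1"` as the adapted layer. For context, a hypothetical backbone of the kind the examples use; the real `MNISTtoUSPSNet` may differ, the point is only that `"fc1"` names a submodule whose output can be hooked:

```python
import torch
import torch.nn as nn

class MNISTtoUSPSNet(nn.Module):
    """Hypothetical stand-in for the examples' backbone (architecture assumed)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 14 * 14, 128)  # "fc1": the layer being adapted
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = torch.relu(self.fc1(self.flatten(x)))
        return self.fc2(x)
```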
137 changes: 81 additions & 56 deletions skada/deep/_adversarial.py
@@ -78,7 +78,7 @@ def forward(
 
 def DANN(
     module,
-    layer_name,
+    layer_names,
     reg=1,
     domain_classifier=None,
     num_features=None,
@@ -91,11 +91,11 @@
 
     Parameters
    ----------
     module : torch module (class or instance)
         A PyTorch :class:`~torch.nn.Module`. In general, the
         uninstantiated class should be passed, although instantiated
         modules will also work.
-    layer_name : str
+    layer_names : str
         The name of the module's layer whose outputs are
         collected during the training.
     reg : float, default=1
@@ -130,7 +130,7 @@
     net = DomainAwareNet(
         module=DomainAwareModule,
         module__base_module=module,
-        module__layer_name=layer_name,
+        module__layer_names=layer_names,
         module__domain_classifier=domain_classifier,
         iterator_train=DomainBalancedDataLoader,
         criterion=DomainAwareCriterion,
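`DomainAwareModule` reads the collected activations from `self.intermediate_layers`, keyed by layer name. A minimal sketch of how such a dictionary can be filled with PyTorch forward hooks; `register_collectors` and `store` are illustrative names, not skada internals:

```python
import torch

def register_collectors(module, layer_names, store):
    """Record the output of each named submodule on every forward pass."""
    named = dict(module.named_modules())
    for name in layer_names:
        named[name].register_forward_hook(
            # the default argument binds the current name at definition time
            lambda mod, inputs, output, name=name: store.__setitem__(name, output)
        )

# Usage with the hypothetical backbone sketched earlier:
net = MNISTtoUSPSNet()
store = {}
register_collectors(net, ["fc1"], store)
_ = net(torch.randn(2, 1, 28, 28))
assert store["fc1"].shape == (2, 128)
```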
@@ -211,8 +211,9 @@ class CDANModule(DomainAwareModule):
     module : torch module (class or instance)
         A PyTorch :class:`~torch.nn.Module`.
     layer_name : str
-        The name of the module's layer whose outputs are
-        collected during the training for adaptation.
+        List of the names of the module's layers whose outputs are
+        collected during the training for the adaptation.
+        If only one layer is needed, it can be a string.
     domain_classifier : torch module
         A PyTorch :class:`~torch.nn.Module` used to classify the
         domain.
@@ -230,12 +231,12 @@ def __init__(
     def __init__(
         self,
         base_module,
-        layer_name,
+        layer_names,
         domain_classifier,
         max_features=4096,
         random_state=42,
     ):
-        super().__init__(base_module, layer_name, domain_classifier)
+        super().__init__(base_module, layer_names, domain_classifier)
         self.max_features = max_features
         self.random_state = random_state
@@ -245,75 +246,99 @@ def forward(self, X, sample_domain=None, is_fit=False, return_features=False):
 
             X_t = X[~source_idx]
             X_s = X[source_idx]
 
             # predict
             y_pred_s = self.base_module_(X_s)
-            features_s = self.intermediate_layers[self.layer_name]
+            features_s = [
+                self.intermediate_layers[layer_name].reshape(len(X_s), -1)
+                for layer_name in self.layer_names
+            ]
             y_pred_t = self.base_module_(X_t)
-            features_t = self.intermediate_layers[self.layer_name]
-
-            n_classes = y_pred_s.shape[1]
-            n_features = features_s.shape[1]
-            if n_features * n_classes > self.max_features:
-                random_layer = _RandomLayer(
-                    self.random_state,
-                    input_dims=[n_features, n_classes],
-                    output_dim=self.max_features,
-                )
-            else:
-                random_layer = None
-
-            # Compute the input for the domain classifier
-            if random_layer is None:
-                multilinear_map = torch.bmm(
-                    y_pred_s.unsqueeze(2), features_s.unsqueeze(1)
-                )
-                multilinear_map_target = torch.bmm(
-                    y_pred_t.unsqueeze(2), features_t.unsqueeze(1)
-                )
-
-                multilinear_map = multilinear_map.view(-1, n_features * n_classes)
-                multilinear_map_target = multilinear_map_target.view(
-                    -1, n_features * n_classes
-                )
-
-            else:
-                multilinear_map = random_layer.forward([features_s, y_pred_s])
-                multilinear_map_target = random_layer.forward([features_t, y_pred_t])
-
-            domain_pred_s = self.domain_classifier_(multilinear_map)
-            domain_pred_t = self.domain_classifier_(multilinear_map_target)
-            domain_pred = torch.empty(len(sample_domain), device=domain_pred_s.device)
-            domain_pred[source_idx] = domain_pred_s
-            domain_pred[~source_idx] = domain_pred_t
+            features_t = [
+                self.intermediate_layers[layer_name].reshape(len(X_t), -1)
+                for layer_name in self.layer_names
+            ]
 
             y_pred = torch.empty(
                 (len(sample_domain), y_pred_s.shape[1]), device=y_pred_s.device
             )
             y_pred[source_idx] = y_pred_s
             y_pred[~source_idx] = y_pred_t
+            n_classes = y_pred_s.shape[1]
 
-            features = torch.empty(
-                (len(sample_domain), features_s.shape[1]), device=features_s.device
-            )
-            features[source_idx] = features_s
-            features[~source_idx] = features_t
+            features = []
+            domain_preds = []
+            for i in range(len(features_s)):
+                features.append(
+                    torch.empty(
+                        (len(sample_domain), features_s[i].shape[1]),
+                        device=features_s[i].device,
+                    )
+                )
+                features[i][source_idx] = features_s[i]
+                features[i][~source_idx] = features_t[i]
+
+                n_features = features_s[i].shape[1]
+                if n_features * n_classes > self.max_features:
+                    random_layer = _RandomLayer(
+                        self.random_state,
+                        input_dims=[n_features, n_classes],
+                        output_dim=self.max_features,
+                    )
+                else:
+                    random_layer = None
+
+                # Compute the input for the domain classifier
+                if random_layer is None:
+                    multilinear_map = torch.bmm(
+                        y_pred_s.unsqueeze(2), features_s[i].unsqueeze(1)
+                    )
+                    multilinear_map_target = torch.bmm(
+                        y_pred_t.unsqueeze(2), features_t[i].unsqueeze(1)
+                    )
+
+                    multilinear_map = multilinear_map.view(-1, n_features * n_classes)
+                    multilinear_map_target = multilinear_map_target.view(
+                        -1, n_features * n_classes
+                    )
+
+                else:
+                    multilinear_map = random_layer.forward([features_s[i], y_pred_s])
+                    multilinear_map_target = random_layer.forward(
+                        [features_t[i], y_pred_t]
+                    )
+
+                domain_pred_s = self.domain_classifier_(multilinear_map)
+                domain_pred_t = self.domain_classifier_(multilinear_map_target)
+                domain_pred = torch.empty(
+                    len(sample_domain), device=domain_pred_s.device
+                )
+                domain_pred[source_idx] = domain_pred_s
+                domain_pred[~source_idx] = domain_pred_t
+                domain_preds.append(domain_pred)
 
             return (
                 y_pred,
-                domain_pred,
+                domain_preds,
                 features,
                 sample_domain,
             )
         else:
             if return_features:
-                return self.base_module_(X), self.intermediate_layers[self.layer_name]
+                return (
+                    self.base_module_(X),
+                    [
+                        self.intermediate_layers[layer_name].reshape(len(X), -1)
+                        for layer_name in self.layer_names
+                    ],
+                )
             else:
                 return self.base_module_(X)
 
 
 def CDAN(
     module,
-    layer_name,
+    layer_names,
     reg=1,
     max_features=4096,
     domain_classifier=None,
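The rewritten loop applies CDAN's multilinear conditioning once per collected layer: each flattened feature map is combined with the class predictions through a per-sample outer product, reshaped to `n_classes * n_features`, and routed through `_RandomLayer` instead whenever that product exceeds `max_features`. A standalone shape check of the outer-product branch, with arbitrary sizes:

```python
import torch

batch, n_classes, n_features = 4, 10, 32
y_pred = torch.randn(batch, n_classes)                 # classifier outputs
feat = torch.randn(batch, 2, 4, 4).reshape(batch, -1)  # flattened layer output

# Per-sample outer product: (batch, n_classes, 1) @ (batch, 1, n_features)
mm = torch.bmm(y_pred.unsqueeze(2), feat.unsqueeze(1))  # (batch, n_classes, n_features)
mm = mm.view(-1, n_features * n_classes)                # input to the domain classifier
assert mm.shape == (batch, n_classes * n_features)
```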
@@ -332,7 +357,7 @@
         A PyTorch :class:`~torch.nn.Module`. In general, the
         uninstantiated class should be passed, although instantiated
         modules will also work.
-    layer_name : str
+    layer_names : str
         The name of the module's layer whose outputs are
         collected during the training.
     reg : float, default=1
@@ -345,7 +370,7 @@
         A PyTorch :class:`~torch.nn.Module` used to classify the
         domain. If None, a domain classifier is created following [1]_.
     num_features : int, default=None
-        Size of the embedding space e.g. the size of the output of layer_name.
+        Size of the embedding space e.g. the size of the output of layer_names.
         If domain_classifier is None, num_features has to be
         provided.
     n_classes : int, default None
@@ -375,7 +400,7 @@
     net = DomainAwareNet(
         module=CDANModule,
         module__base_module=module,
-        module__layer_name=layer_name,
+        module__layer_names=layer_names,
         module__domain_classifier=domain_classifier,
         module__max_features=max_features,
         iterator_train=DomainBalancedDataLoader,
18 changes: 10 additions & 8 deletions skada/deep/_divergence.py
@@ -50,7 +50,7 @@ def forward(
     return loss
 
 
-def DeepCoral(module, layer_name, reg=1, **kwargs):
+def DeepCoral(module, layer_names, reg=1, **kwargs):
     """DeepCORAL domain adaptation method.
 
     From [12]_.
@@ -59,9 +59,10 @@ def DeepCoral(module, layer_name, reg=1, **kwargs):
     ----------
     module : torch module (class or instance)
         A PyTorch :class:`~torch.nn.Module`.
-    layer_name : str
-        The name of the module's layer whose outputs are
+    layer_names : str or list of str
+        List of the names of the module's layers whose outputs are
         collected during the training for the adaptation.
+        If only one layer is needed, it can be a string.
     reg : float, optional (default=1)
         The regularization parameter of the covariance estimator.
 
@@ -74,7 +75,7 @@ def DeepCoral(module, layer_name, reg=1, **kwargs):
     net = DomainAwareNet(
         module=DomainAwareModule,
         module__base_module=module,
-        module__layer_name=layer_name,
+        module__layer_names=layer_names,
         iterator_train=DomainBalancedDataLoader,
         criterion=DomainAwareCriterion,
         criterion__criterion=torch.nn.CrossEntropyLoss(),
@@ -123,7 +124,7 @@ def forward(
     return loss
 
 
-def DAN(module, layer_name, reg=1, sigmas=None, **kwargs):
+def DAN(module, layer_names, reg=1, sigmas=None, **kwargs):
     """DAN domain adaptation method.
 
     See [14]_.
@@ -132,9 +133,10 @@ def DAN(module, layer_name, reg=1, sigmas=None, **kwargs):
     ----------
     module : torch module (class or instance)
         A PyTorch :class:`~torch.nn.Module`.
-    layer_name : str
-        The name of the module's layer whose outputs are
+    layer_names : str or list of str
+        List of the names of the module's layers whose outputs are
         collected during the training for the adaptation.
+        If only one layer is needed, it can be a string.
     reg : float, optional (default=1)
         The regularization parameter of the covariance estimator.
     sigmas : array-like, optional (default=None)
@@ -149,7 +151,7 @@
     net = DomainAwareNet(
         module=DomainAwareModule,
         module__base_module=module,
-        module__layer_name=layer_name,
+        module__layer_names=layer_names,
         iterator_train=DomainBalancedDataLoader,
         criterion=DomainAwareCriterion(
             torch.nn.CrossEntropyLoss(), DANLoss(sigmas=sigmas), reg=reg
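`DANLoss` aligns source and target features with a multi-kernel MMD whose bandwidths come from `sigmas`. As a reference point, a minimal (biased) Gaussian multi-kernel MMD estimator looks roughly like this; it sketches the idea and is not skada's implementation:

```python
import torch

def gaussian_mmd(x, y, sigmas=(1.0, 2.0, 4.0)):
    """Biased multi-kernel MMD^2 between source features x and target features y."""
    def k(a, b):
        d2 = torch.cdist(a, b) ** 2
        return sum(torch.exp(-d2 / (2 * s**2)) for s in sigmas) / len(sigmas)
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

mmd = gaussian_mmd(torch.randn(64, 128), torch.randn(64, 128))
```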
9 changes: 5 additions & 4 deletions skada/deep/_optimal_transport.py
@@ -67,7 +67,7 @@ def forward(
     return loss
 
 
-def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs):
+def DeepJDOT(module, layer_names, reg=1, reg_cl=1, target_criterion=None, **kwargs):
     """DeepJDOT.
 
     See [13]_.
@@ -76,9 +76,10 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs):
     ----------
     module : torch module (class or instance)
         A PyTorch :class:`~torch.nn.Module`.
-    layer_name : str
-        The name of the module's layer whose outputs are
+    layer_names : str or list of str
+        List of the names of the module's layers whose outputs are
         collected during the training for the adaptation.
+        If only one layer is needed, it can be a string.
     reg : float, default=1
         Regularization parameter.
     reg_cl : float, default=1
@@ -99,7 +100,7 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs):
     net = DomainAwareNet(
         module=DomainAwareModule,
         module__base_module=module,
-        module__layer_name=layer_name,
+        module__layer_names=layer_names,
         iterator_train=DomainBalancedDataLoader,
         criterion=DomainAwareCriterion,
         criterion__criterion=nn.CrossEntropyLoss(),
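Under the renamed keyword, a multi-layer DeepJDOT call would then look like the sketch below; the list value and the second layer name are hypothetical and assume the backbone sketched earlier:

```python
from skada.deep import DeepJDOT

model = DeepJDOT(
    MNISTtoUSPSNet(),
    layer_names=["fc1", "fc2"],  # hypothetical multi-layer adaptation
    reg=1,      # OT plan regularization
    reg_cl=1,   # target classification term
    batch_size=128,
    max_epochs=5,
    train_split=False,
)
```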