Skip to content

Commit

Permalink
add a parameter base_criterion to models (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgnassou committed Jul 19, 2024
1 parent ce1d60c commit 28aaf82
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 16 deletions.
18 changes: 16 additions & 2 deletions skada/deep/_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def DANN(
reg=1,
domain_classifier=None,
num_features=None,
base_criterion=None,
domain_criterion=None,
**kwargs,
):
Expand Down Expand Up @@ -109,6 +110,9 @@ def DANN(
the feature extractor.
If domain_classifier is None, num_features has to be
provided.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
domain_criterion : torch criterion (class)
The criterion (loss) used to compute the
DANN loss. If None, a BCELoss is used.
Expand All @@ -127,14 +131,17 @@ def DANN(
)
domain_classifier = DomainClassifier(num_features=num_features)

if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
module__domain_classifier=domain_classifier,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=DANNLoss(domain_criterion=domain_criterion),
**kwargs,
Expand Down Expand Up @@ -319,6 +326,7 @@ def CDAN(
domain_classifier=None,
num_features=None,
n_classes=None,
base_criterion=None,
domain_criterion=None,
**kwargs,
):
Expand Down Expand Up @@ -351,6 +359,9 @@ def CDAN(
n_classes : int, default None
Number of output classes.
If domain_classifier is None, n_classes has to be provided.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
domain_criterion : torch criterion (class)
The criterion (loss) used to compute the
CDAN loss. If None, a BCELoss is used.
Expand All @@ -372,6 +383,9 @@ def CDAN(
num_features = np.min([num_features * n_classes, max_features])
domain_classifier = DomainClassifier(num_features=num_features)

if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=CDANModule,
module__base_module=module,
Expand All @@ -380,7 +394,7 @@ def CDAN(
module__max_features=max_features,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=CDANLoss(domain_criterion=domain_criterion),
**kwargs,
Expand Down
24 changes: 17 additions & 7 deletions skada/deep/_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forward(
return loss


def DeepCoral(module, layer_name, reg=1, **kwargs):
def DeepCoral(module, layer_name, reg=1, base_criterion=None, **kwargs):
"""DeepCORAL domain adaptation method.
From [12]_.
Expand All @@ -64,20 +64,26 @@ def DeepCoral(module, layer_name, reg=1, **kwargs):
collected during the training for the adaptation.
reg : float, optional (default=1)
The regularization parameter of the covariance estimator.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
References
----------
.. [12] Baochen Sun and Kate Saenko. Deep coral:
Correlation alignment for deep domain
adaptation. In ECCV Workshops, 2016.
"""
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=torch.nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=DeepCoralLoss(),
**kwargs,
Expand Down Expand Up @@ -123,7 +129,7 @@ def forward(
return loss


def DAN(module, layer_name, reg=1, sigmas=None, **kwargs):
def DAN(module, layer_name, reg=1, sigmas=None, base_criterion=None, **kwargs):
"""DAN domain adaptation method.
See [14]_.
Expand All @@ -139,22 +145,26 @@ def DAN(module, layer_name, reg=1, sigmas=None, **kwargs):
The regularization parameter of the covariance estimator.
sigmas : array-like, optional (default=None)
The sigmas for the Gaussian kernel.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
References
----------
.. [14] Mingsheng Long et. al. Learning Transferable
Features with Deep Adaptation Networks.
In ICML, 2015.
"""
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion(
torch.nn.CrossEntropyLoss(), DANLoss(sigmas=sigmas), reg=reg
),
criterion__criterion=torch.nn.CrossEntropyLoss(),
criterion=DomainAwareCriterion,
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=DANLoss(sigmas=sigmas),
**kwargs,
Expand Down
20 changes: 17 additions & 3 deletions skada/deep/_optimal_transport.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#
# License: BSD 3-Clause
from torch import nn
import torch

from skada.deep.base import (
BaseDALoss,
Expand Down Expand Up @@ -67,7 +67,15 @@ def forward(
return loss


def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs):
def DeepJDOT(
module,
layer_name,
reg=1,
reg_cl=1,
base_criterion=None,
target_criterion=None,
**kwargs,
):
"""DeepJDOT.
See [13]_.
Expand All @@ -83,6 +91,9 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwarg
Regularization parameter.
reg_cl : float, default=1
Class distance term regularization parameter.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
target_criterion : torch criterion (class)
The uninitialized criterion (loss) used to compute the
DeepJDOT loss. The criterion should support reduction='none'.
Expand All @@ -96,13 +107,16 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwarg
15th European Conference on Computer Vision,
September 2018. Springer.
"""
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__adapt_criterion=DeepJDOTLoss(reg_cl, target_criterion),
criterion__reg=reg,
**kwargs,
Expand Down
8 changes: 4 additions & 4 deletions skada/deep/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DomainAwareCriterion(torch.nn.Module):
Parameters
----------
criterion : torch criterion (class)
base_criterion : torch criterion (class)
The initialized criterion (loss) used to optimize the
module with prediction on source.
adapt_criterion : torch criterion (class)
Expand All @@ -32,9 +32,9 @@ class DomainAwareCriterion(torch.nn.Module):
Regularization parameter.
"""

def __init__(self, criterion, adapt_criterion, reg=1):
def __init__(self, base_criterion, adapt_criterion, reg=1):
super(DomainAwareCriterion, self).__init__()
self.criterion = criterion
self.base_criterion = base_criterion
self.adapt_criterion = adapt_criterion
self.reg = reg

Expand Down Expand Up @@ -73,7 +73,7 @@ def forward(
features_t = features[~source_idx]

# predict
return self.criterion(
return self.base_criterion(
y_pred_s, y_true[source_idx]
) + self.reg * self.adapt_criterion(
y_true[source_idx],
Expand Down

0 comments on commit 28aaf82

Please sign in to comment.