diff --git a/skada/deep/_adversarial.py b/skada/deep/_adversarial.py index 7e2aeb0c..3e888aff 100644 --- a/skada/deep/_adversarial.py +++ b/skada/deep/_adversarial.py @@ -82,6 +82,7 @@ def DANN( reg=1, domain_classifier=None, num_features=None, + base_criterion=None, domain_criterion=None, **kwargs, ): @@ -109,6 +110,9 @@ def DANN( the feature extractor. If domain_classifier is None, num_features has to be provided. + base_criterion : torch criterion (class) + The base criterion used to compute the loss with source + labels. If None, the default is `torch.nn.CrossEntropyLoss`. domain_criterion : torch criterion (class) The criterion (loss) used to compute the DANN loss. If None, a BCELoss is used. @@ -127,6 +131,9 @@ def DANN( ) domain_classifier = DomainClassifier(num_features=num_features) + if base_criterion is None: + base_criterion = torch.nn.CrossEntropyLoss() + net = DomainAwareNet( module=DomainAwareModule, module__base_module=module, @@ -134,7 +141,7 @@ def DANN( module__domain_classifier=domain_classifier, iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, - criterion__criterion=nn.CrossEntropyLoss(), + criterion__base_criterion=base_criterion, criterion__reg=reg, criterion__adapt_criterion=DANNLoss(domain_criterion=domain_criterion), **kwargs, @@ -319,6 +326,7 @@ def CDAN( domain_classifier=None, num_features=None, n_classes=None, + base_criterion=None, domain_criterion=None, **kwargs, ): @@ -351,6 +359,9 @@ def CDAN( n_classes : int, default None Number of output classes. If domain_classifier is None, n_classes has to be provided. + base_criterion : torch criterion (class) + The base criterion used to compute the loss with source + labels. If None, the default is `torch.nn.CrossEntropyLoss`. domain_criterion : torch criterion (class) The criterion (loss) used to compute the CDAN loss. If None, a BCELoss is used. @@ -372,6 +383,9 @@ def CDAN( num_features = np.min([num_features * n_classes, max_features]) domain_classifier = DomainClassifier(num_features=num_features) + if base_criterion is None: + base_criterion = torch.nn.CrossEntropyLoss() + net = DomainAwareNet( module=CDANModule, module__base_module=module, @@ -380,7 +394,7 @@ def CDAN( module__max_features=max_features, iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, - criterion__criterion=nn.CrossEntropyLoss(), + criterion__base_criterion=base_criterion, criterion__reg=reg, criterion__adapt_criterion=CDANLoss(domain_criterion=domain_criterion), **kwargs, diff --git a/skada/deep/_divergence.py b/skada/deep/_divergence.py index cade5799..c1715f57 100644 --- a/skada/deep/_divergence.py +++ b/skada/deep/_divergence.py @@ -50,7 +50,7 @@ def forward( return loss -def DeepCoral(module, layer_name, reg=1, **kwargs): +def DeepCoral(module, layer_name, reg=1, base_criterion=None, **kwargs): """DeepCORAL domain adaptation method. From [12]_. @@ -64,6 +64,9 @@ def DeepCoral(module, layer_name, reg=1, **kwargs): collected during the training for the adaptation. reg : float, optional (default=1) The regularization parameter of the covariance estimator. + base_criterion : torch criterion (class) + The base criterion used to compute the loss with source + labels. If None, the default is `torch.nn.CrossEntropyLoss`. References ---------- @@ -71,13 +74,16 @@ def DeepCoral(module, layer_name, reg=1, **kwargs): Correlation alignment for deep domain adaptation. In ECCV Workshops, 2016. """ + if base_criterion is None: + base_criterion = torch.nn.CrossEntropyLoss() + net = DomainAwareNet( module=DomainAwareModule, module__base_module=module, module__layer_name=layer_name, iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, - criterion__criterion=torch.nn.CrossEntropyLoss(), + criterion__base_criterion=base_criterion, criterion__reg=reg, criterion__adapt_criterion=DeepCoralLoss(), **kwargs, @@ -123,7 +129,7 @@ def forward( return loss -def DAN(module, layer_name, reg=1, sigmas=None, **kwargs): +def DAN(module, layer_name, reg=1, sigmas=None, base_criterion=None, **kwargs): """DAN domain adaptation method. See [14]_. @@ -139,6 +145,9 @@ def DAN(module, layer_name, reg=1, sigmas=None, **kwargs): The regularization parameter of the covariance estimator. sigmas : array-like, optional (default=None) The sigmas for the Gaussian kernel. + base_criterion : torch criterion (class) + The base criterion used to compute the loss with source + labels. If None, the default is `torch.nn.CrossEntropyLoss`. References ---------- @@ -146,15 +155,16 @@ def DAN(module, layer_name, reg=1, sigmas=None, **kwargs): Features with Deep Adaptation Networks. In ICML, 2015. """ + if base_criterion is None: + base_criterion = torch.nn.CrossEntropyLoss() + net = DomainAwareNet( module=DomainAwareModule, module__base_module=module, module__layer_name=layer_name, iterator_train=DomainBalancedDataLoader, - criterion=DomainAwareCriterion( - torch.nn.CrossEntropyLoss(), DANLoss(sigmas=sigmas), reg=reg - ), - criterion__criterion=torch.nn.CrossEntropyLoss(), + criterion=DomainAwareCriterion, + criterion__base_criterion=base_criterion, criterion__reg=reg, criterion__adapt_criterion=DANLoss(sigmas=sigmas), **kwargs, diff --git a/skada/deep/_optimal_transport.py b/skada/deep/_optimal_transport.py index 8f78e285..f1f0f547 100644 --- a/skada/deep/_optimal_transport.py +++ b/skada/deep/_optimal_transport.py @@ -1,7 +1,7 @@ # Author: Theo Gnassounou # # License: BSD 3-Clause -from torch import nn +import torch from skada.deep.base import ( BaseDALoss, @@ -67,7 +67,15 @@ def forward( return loss -def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs): +def DeepJDOT( + module, + layer_name, + reg=1, + reg_cl=1, + base_criterion=None, + target_criterion=None, + **kwargs, +): """DeepJDOT. See [13]_. @@ -83,6 +91,9 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwarg Regularization parameter. reg_cl : float, default=1 Class distance term regularization parameter. + base_criterion : torch criterion (class) + The base criterion used to compute the loss with source + labels. If None, the default is `torch.nn.CrossEntropyLoss`. target_criterion : torch criterion (class) The uninitialized criterion (loss) used to compute the DeepJDOT loss. The criterion should support reduction='none'. @@ -96,13 +107,16 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwarg 15th European Conference on Computer Vision, September 2018. Springer. """ + if base_criterion is None: + base_criterion = torch.nn.CrossEntropyLoss() + net = DomainAwareNet( module=DomainAwareModule, module__base_module=module, module__layer_name=layer_name, iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, - criterion__criterion=nn.CrossEntropyLoss(), + criterion__base_criterion=base_criterion, criterion__adapt_criterion=DeepJDOTLoss(reg_cl, target_criterion), criterion__reg=reg, **kwargs, diff --git a/skada/deep/base.py b/skada/deep/base.py index 69e0ba5d..d9dd2514 100644 --- a/skada/deep/base.py +++ b/skada/deep/base.py @@ -22,7 +22,7 @@ class DomainAwareCriterion(torch.nn.Module): Parameters ---------- - criterion : torch criterion (class) + base_criterion : torch criterion (class) The initialized criterion (loss) used to optimize the module with prediction on source. adapt_criterion : torch criterion (class) @@ -32,9 +32,9 @@ class DomainAwareCriterion(torch.nn.Module): Regularization parameter. """ - def __init__(self, criterion, adapt_criterion, reg=1): + def __init__(self, base_criterion, adapt_criterion, reg=1): super(DomainAwareCriterion, self).__init__() - self.criterion = criterion + self.base_criterion = base_criterion self.adapt_criterion = adapt_criterion self.reg = reg @@ -73,7 +73,7 @@ def forward( features_t = features[~source_idx] # predict - return self.criterion( + return self.base_criterion( y_pred_s, y_true[source_idx] ) + self.reg * self.adapt_criterion( y_true[source_idx],