diff --git a/deepxde/callbacks.py b/deepxde/callbacks.py
index 3039390df..ce6957bc8 100644
--- a/deepxde/callbacks.py
+++ b/deepxde/callbacks.py
@@ -6,7 +6,7 @@
 from . import config
 from . import gradients as grad
 from . import utils
-from .backend import backend_name, tf, torch, paddle
+from .backend import backend_name, tf, torch, paddle, Variable
 
 
 class Callback:
@@ -571,3 +571,32 @@ def on_epoch_end(self):
         raise ValueError(
             "`num_bcs` changed! Please update the loss function by `model.compile`."
         )
+
+
+class SoftAdapt(Callback):
+    """Use adaptive loss balancing (SoftAdapt).
+
+    Args:
+        beta: If beta > 0, SoftAdapt pays more attention to the worst-performing
+            loss component; if beta < 0, it assigns higher weights to the
+            better-performing components. beta = 0 is the trivial case: all loss
+            components have weight 1.
+        epsilon: Small constant added to the denominator to prevent division by
+            zero or overflow.
+    """
+
+    def __init__(self, beta=0.1, epsilon=1e-8):
+        super().__init__()
+
+        self.beta = beta
+        self.epsilon = epsilon
+
+    def on_train_begin(self):
+        # Wrap the loss weights in a backend Variable so that they can be
+        # updated during training.
+        loss_weights = tf.constant(self.model.loss_weights)
+        loss_weights = Variable(loss_weights, dtype=loss_weights.dtype)
+        self.model.loss_weights = loss_weights
+        # TODO: Allow instances to be re-used.
+        # TODO: Evaluate coefficients.
+        # TODO: Update weights.
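The coefficient evaluation itself is still a TODO in this diff. For reference, here is a minimal NumPy sketch of the SoftAdapt weighting rule (Heydari et al., 2019) that the `Evaluate coefficients` step would compute, under the assumption that the standard rate-of-change variant is intended; `softadapt_weights`, `losses`, and `prev_losses` are hypothetical names for this illustration, not part of the DeepXDE API:

```python
import numpy as np


def softadapt_weights(losses, prev_losses, beta=0.1, epsilon=1e-8):
    """Hypothetical helper: compute SoftAdapt loss weights.

    losses / prev_losses: per-component loss values at the current and
    previous evaluation. beta > 0 emphasizes the worst-performing
    (slowest-improving) components; beta < 0 favors the better-performing
    ones; beta = 0 yields uniform weights. epsilon guards the
    normalization against division by zero.
    """
    losses = np.asarray(losses, dtype=float)
    prev_losses = np.asarray(prev_losses, dtype=float)
    # Rate of change of each loss component.
    s = losses - prev_losses
    # Softmax over beta * s, with max-subtraction for numerical stability.
    e = np.exp(beta * (s - s.max()))
    return e / (e.sum() + epsilon)


# Example: component 0 is improving while component 1 stagnates, so with
# beta > 0 the stagnating component receives the larger weight.
w = softadapt_weights([0.5, 1.0], [0.8, 1.0], beta=0.1)
print(w)  # approximately [0.49, 0.51]
```

In the callback, these weights would be written into the backend `Variable` created in `on_train_begin`, so that `model.loss_weights` is rebalanced as training progresses.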