PR4: Add deep_mmd_loss files #170
base: main
Conversation
            EvaluationLosses: an instance of EvaluationLosses containing checkpoint loss and additional losses
                indexed by name.
        """
        for layer in self.flatten_feature_extraction_layers.keys():
If you're going to be indexing into `self.deep_mmd_losses` anyway, could we simply do

for layer_loss_module in self.deep_mmd_losses.values():
    layer_loss_module.training = False

For Ditto, we do this process in `validate` and `train_by_steps`/`train_by_epochs` for the global model, maybe we can just do this there?
I think it's still worth overriding `compute_evaluation_loss` and `compute_training_loss` and asserting that all `layer_loss_module.training == False` (or vice versa) though, to be safe 🙂
I also might be missing this, but I don't see where we set `layer_loss_module.training` to `True` in the client. Based on the loss code, this would mean that we won't run training of the deep kernels after the first server round, which I think we want to keep doing?
Good catch! The `True` setting was indeed missing, so I added it to the `update_before_train` function. Following your suggestion, I moved the `False` setting to the `validate` function. I kept the assertions in both the `compute_evaluation_loss` and `compute_training_loss` functions for consistency.
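A minimal sketch of the toggling pattern described in this exchange. The class and method names (`update_before_train`, `validate`, `compute_training_loss`, `compute_evaluation_loss`) are taken from the thread, not from the actual fl4health client code, which may differ:

```python
# Sketch: toggle deep-kernel training per server round and assert the mode
# in the loss-computation hooks, as suggested in the review. All names here
# are assumptions based on the thread, not the exact fl4health API.
import torch.nn as nn


class DittoDeepMmdClientSketch:
    def __init__(self, deep_mmd_losses: dict[str, nn.Module]) -> None:
        self.deep_mmd_losses = deep_mmd_losses

    def update_before_train(self) -> None:
        # Re-enable deep-kernel training at the start of each server round.
        for layer_loss_module in self.deep_mmd_losses.values():
            layer_loss_module.training = True

    def validate(self) -> None:
        # Freeze the deep kernels before evaluation.
        for layer_loss_module in self.deep_mmd_losses.values():
            layer_loss_module.training = False

    def compute_training_loss(self) -> None:
        # Safety assertion from the review: kernels must be in training mode here.
        assert all(m.training for m in self.deep_mmd_losses.values())

    def compute_evaluation_loss(self) -> None:
        # And must be frozen here.
        assert not any(m.training for m in self.deep_mmd_losses.values())
```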
We can just iterate through the `self.deep_mmd_losses` values and do assertions, I think?

for layer_loss_module in self.deep_mmd_losses.values():
    assert not layer_loss_module.training
fl4health/losses/deep_mmd_loss.py
Outdated
            list(self.featurizer.parameters()) + [self.epsilonOPT] + [self.sigmaOPT] + [self.sigma0OPT], lr=self.lr
        )

    def Pdist2(self, x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
Maybe expand this to `pairwise_distance_squared`?
It looks like we don't leverage the fact that y can be `None` to get the distances of x with itself. Maybe we just drop that option and require y to be passed, to simplify this function.
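A sketch of the simplified function the two comments point at: y is required (the `Optional` branch is dropped) and the name is the reviewer's proposed rename, not necessarily what landed in the PR:

```python
# Sketch of the simplified, renamed distance helper suggested in the review.
# The function name is the reviewer's proposal; the PR's actual code may differ.
import torch


def pairwise_distance_squared(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the (n, m) matrix of squared Euclidean distances between rows of x and y."""
    x_norm = x.pow(2).sum(dim=1, keepdim=True)      # shape (n, 1)
    y_norm = y.pow(2).sum(dim=1, keepdim=True).t()  # shape (1, m)
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * <x_i, y_j>
    dist = x_norm + y_norm - 2.0 * x @ y.t()
    # Guard against tiny negative values caused by floating-point rounding.
    return dist.clamp_min(0.0)
```

Self-distances of x are still available by calling `pairwise_distance_squared(x, x)`, so dropping the `None` branch loses no functionality.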
fl4health/losses/deep_mmd_loss.py
Outdated
        # Compute output of deep network
        model_output = self.featurizer(features)
        # Compute epsilon, sigma and sigma_0
        ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
Rename this to `epsilon` and note that it is the epsilon in the deep-kernel equation.
It doesn't look like we did this? I think both the rename and comment are worthwhile
Yeah, I missed this.
fl4health/losses/deep_mmd_loss.py
Outdated
        # Compute epsilon, sigma and sigma_0
        ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
        sigma = self.sigmaOPT**2
        sigma0_u = self.sigma0OPT**2
Based on the implementation of `MMDu`, I would suggest renaming `sigma0` to `sigma_phi`, `sigma0OPT` to `sigma_phi_opt`, and `sigma0_u` to `sigma_phi` (since there doesn't seem to be any reason to have `_u` in there anyway). Similarly, anything that is `sigma` or `sigmaOPT` can be `sigma_q` or `sigma_q_opt` to match the notation of the paper.
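Putting the rename suggestions from both threads together, the three lines quoted above might read as follows. The `_opt` spellings are the reviewer's proposal, and the bandwidth interpretations in the comments are assumptions based on the paper's notation, not text from the PR:

```python
# Sketch of the renamed parameter computation suggested in the review.
# Names (epsilon, sigma_q, sigma_phi, *_opt) follow the reviewer's proposal.
import torch

epsilon_opt = torch.tensor(0.0)    # raw, unconstrained parameter
sigma_q_opt = torch.tensor(2.0)
sigma_phi_opt = torch.tensor(3.0)

# exp(e) / (1 + exp(e)) is exactly the sigmoid, mapping the raw value into (0, 1).
epsilon = torch.sigmoid(epsilon_opt)
sigma_q = sigma_q_opt**2      # assumed: bandwidth for the kernel on raw inputs
sigma_phi = sigma_phi_opt**2  # assumed: bandwidth for the kernel on featurizer outputs
```

Writing the epsilon line as `torch.sigmoid` also sidesteps potential overflow in `torch.exp` for large raw values.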
research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm
Outdated
for more information, see https://pre-commit.ci
Really nice changes. I've just added a few small comments and reminders of a few pieces you might have overlooked from my earlier comments. Very close to ready to go!
            EvaluationLosses: an instance of EvaluationLosses containing checkpoint loss and additional losses
                indexed by name.
        """
        for layer in self.flatten_feature_extraction_layers.keys():
We can just iterate through the `self.deep_mmd_losses` values and do assertions, I think?

for layer_loss_module in self.deep_mmd_losses.values():
    assert not layer_loss_module.training
        for layer, layer_deep_mmd_loss in self.deep_mmd_losses.items():
            deep_mmd_loss = layer_deep_mmd_loss(features[layer], features[" ".join(["init_global", layer])])
            additional_losses["_".join(["deep_mmd_loss", layer])] = deep_mmd_loss
            total_deep_mmd_loss += deep_mmd_loss
        total_loss += self.deep_mmd_loss_weight * total_deep_mmd_loss
        additional_losses["deep_mmd_loss_total"] = total_deep_mmd_loss
Just to be safe, maybe we can clone `total_deep_mmd_loss` here?
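A small sketch of why the clone helps: storing a clone in `additional_losses` decouples the logged value from any later in-place modification of the accumulator. Variable names mirror the snippet above; the surrounding method is assumed:

```python
# Sketch of the clone suggested in the review. The loss values here are
# dummy tensors standing in for per-layer deep MMD losses.
import torch

total_deep_mmd_loss = torch.tensor(0.0)
additional_losses: dict[str, torch.Tensor] = {}

for deep_mmd_loss in [torch.tensor(0.25), torch.tensor(0.5)]:
    total_deep_mmd_loss = total_deep_mmd_loss + deep_mmd_loss

# Clone so the stored value cannot change if the accumulator is later
# modified in place (e.g. by an in-place weighting step).
additional_losses["deep_mmd_loss_total"] = total_deep_mmd_loss.clone()
```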
I added that, but I am checking a bunch of other Ditto versions and we don't have it anywhere. I am wondering whether I should update them or not.
fl4health/losses/deep_mmd_loss.py
Outdated
        # Compute output of deep network
        model_output = self.featurizer(features)
        # Compute epsilon, sigma and sigma_0
        ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
It doesn't look like we did this? I think both the rename and comment are worthwhile
fl4health/losses/deep_mmd_loss.py
Outdated
        ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
        sigma = self.sigmaOPT**2
        sigma0_u = self.sigma0OPT**2
        # Compute J (STAT_u)
I'd include the notation mention in your comment as well if you're alright with it (\hat{J}_{\lambda})
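For context, the statistic that notation refers to, as I recall it from Liu et al. (2020), "Learning Deep Kernels for Non-Parametric Two-Sample Tests" (worth double-checking against the paper before adding it to the comment):

```latex
\hat{J}_{\lambda} = \frac{\widehat{\mathrm{MMD}}_u^2}{\sqrt{\hat{\sigma}_{H_1,\lambda}^2}},
\qquad \hat{\sigma}_{H_1,\lambda}^2 = \hat{\sigma}_{H_1}^2 + \lambda,
```

where \widehat{\mathrm{MMD}}_u^2 is the unbiased MMD estimate and \lambda is a small regularizer keeping the variance estimate positive.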
PR Type
[Feature]
Short Description
This is a tentative implementation of the deep MMD loss.
Tests Added
No tests added yet.