[WIP] Integrates JorgeKFAC with diffusion model training script #3

Draft
wants to merge 5 commits into base: jorge
Conversation

@keshprad keshprad (Collaborator) commented Jan 7, 2025

No description provided.

@keshprad keshprad (Collaborator, Author) commented Jan 7, 2025

  1. I am still having issues with sampling, as I shared on Slack. This happens with both Adam and Jorge.
  2. I'm not sure what to look for to verify correctness of the implementation.
    - The general trend is that the loss/MSE decreases, but loss_sampled stays around 0.9-1 throughout training.
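
For context: in standard K-FAC, curvature is estimated from targets sampled from the model's own predictive distribution rather than from the real targets. If y_sampled here is model_output plus unit-variance Gaussian noise (an assumption about this PR, not something the diff confirms), then loss_sampled measures distance to that sampled target and will hover near 1 by construction, regardless of training progress, so a flat value around 0.9-1 is not by itself a sign of a bug. A minimal sketch under that assumption:

import torch

# Hypothetical illustration, not the PR's exact code: draw a target from the
# model's predictive Gaussian with unit variance, as "true Fisher" estimation
# in K-FAC does with sampled labels.
model_output = torch.randn(4, 3, 32, 32)                    # stand-in for the predicted noise
y_sampled = model_output + torch.randn_like(model_output)   # y ~ N(model_output, I)

# The sampled loss compares against the sampled target, so its expectation is
# the sampling variance (about 1), independent of how good model_output is.
loss_sampled = ((y_sampled - model_output) ** 2).mean()
print(loss_sampled.item())  # roughly 1.0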

terms["loss_sampled"] = mean_flat((y_sampled - model_output) ** 2)
if "vb" in terms:
    terms["loss"] = terms["mse"] + terms["vb"]
    # TODO: Should terms["vb"] be added to terms["loss_sampled"]?

keshprad (Collaborator, Author):
Should terms["vb"] be added to terms["loss_sampled"]?

Collaborator:

Most likely "vb" shouldn't be in terms. Can you check and let me know?
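
For background on how the proxy loss is consumed: common K-FAC implementations for PyTorch follow roughly the pattern sketched below (a classification example; the optimizer name and the acc_stats flag are assumptions based on typical K-FAC code, not this repo's API). The sampled-target loss is backpropagated only so the optimizer can accumulate Fisher statistics, and its gradients are cleared before the true loss drives the actual update, which is why loss_sampled is usually kept as a plain sampled-target loss.

# Rough sketch of a typical K-FAC training step (assumed names and pattern,
# not this repo's exact code): the sampled-label loss only feeds the Fisher
# statistics; the real loss drives the parameter update.
import torch
import torch.nn.functional as F

def kfac_step(model, optimizer, x, y):
    optimizer.zero_grad()
    logits = model(x)

    # Accumulate curvature statistics from a loss against labels sampled
    # from the model's own predictive distribution.
    optimizer.acc_stats = True
    with torch.no_grad():
        y_sampled = torch.multinomial(F.softmax(logits, dim=1), 1).squeeze(1)
    loss_sampled = F.cross_entropy(logits, y_sampled)
    loss_sampled.backward(retain_graph=True)

    # Stop accumulating, clear the proxy gradients, then step on the true loss.
    optimizer.acc_stats = False
    optimizer.zero_grad()
    loss = F.cross_entropy(logits, y)
    loss.backward()
    optimizer.step()
    return loss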

improved_diffusion/train_util_jorge.py
loss_sampled = (losses["loss_sampled"] * weights).mean()
loss_sampled.backward(retain_graph=True)
self.opt.acc_stats = False
self.opt.zero_grad()  # clear the gradient for computing true-fisher

if isinstance(self.schedule_sampler, LossAwareSampler):
    self.schedule_sampler.update_with_local_losses(

@keshprad keshprad (Collaborator, Author) commented Jan 7, 2025

I don't understand what L233 (self.schedule_sampler.update_with_local_losses) does. Should anything here be modified to do something similar for losses["loss_sampled"]?

Collaborator:

I don't think we need to use the proxy loss (i.e. loss_sampled) anywhere in this function. This can remain untouched.
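
On the question above: as far as I can tell from the upstream improved_diffusion training loop, update_with_local_losses only gathers each rank's per-timestep losses so a LossAwareSampler can importance-sample timesteps based on their recent loss values; it does not touch gradients. The upstream call (from memory, worth verifying against the repo) passes the true per-timestep losses, detached, which is consistent with leaving the proxy loss out of it:

# Upstream usage as I recall it (verify against improved_diffusion/train_util.py):
# the sampler only needs the true per-timestep losses, detached from the graph.
if isinstance(self.schedule_sampler, LossAwareSampler):
    self.schedule_sampler.update_with_local_losses(
        t, losses["loss"].detach()
    )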

@keshprad keshprad marked this pull request as draft January 8, 2025 02:20