
Precondition w.r.t. one loss, gradient w.r.t. another #2

Open
ConstantinPuiu opened this issue Jan 13, 2022 · 1 comment

@ConstantinPuiu

Can your code be used to compute the KFAC (Fisher) matrix using a different loss than the one we take the gradient of?

If so, how?

Thanks!

@n-gao
Owner

n-gao commented Feb 10, 2022

Hi, sorry for the late reply.
Yes, you can use one loss for KFAC and a different one for the gradients. For the KFAC loss, you should use the log-likelihood of your model's output distribution, while the loss you actually backpropagate for the update may be any loss.
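
(As a side note, this is why the KFAC targets below are sampled from the model rather than taken from the data: the Fisher is an expectation over the model's own predictive distribution,

$$F = \mathbb{E}_{x \sim p(x)}\, \mathbb{E}_{y \sim p_\theta(y \mid x)}\!\left[\nabla_\theta \log p_\theta(y \mid x)\, \nabla_\theta \log p_\theta(y \mid x)^\top\right].$$

Using the data labels instead would give the so-called empirical Fisher.)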

In the training loop from the MNIST example, it could look like this:

import torch
import torch.nn as nn
import tqdm

# Classifier, KFAC and train_loader are as in the repo's MNIST example.
model = Classifier().cuda()
optim = KFAC(model, 9e-3, 1e-3, momentum_type='regular', momentum=0.95,
             adapt_damping=True, update_cov_manually=True)
# Negative log-likelihood of the model's output distribution (categorical here).
model_logprob = nn.CrossEntropyLoss(reduction='mean')
loss_fn = <your loss here>

kfac_losses = []
with tqdm.tqdm(train_loader) as progress:
    for inp, labels in progress:
        inp, labels = inp.cuda(), labels.cuda()
        model.zero_grad()
        # Estimate the Fisher: sample targets from the model's own
        # predictive distribution instead of using the data labels.
        with optim.track_forward():
            out = model(inp)
            out_samples = torch.multinomial(torch.softmax(out.detach(), 1), 1).reshape(out.shape[0])
            loss = model_logprob(out, out_samples)
        with optim.track_backward():
            loss.backward()
        # Update the Kronecker factors from the tracked forward/backward pass
        # (needed because update_cov_manually=True).
        optim.update_cov()
        # Compute the loss we actually backprop for the parameter update.
        model.zero_grad()
        out = model(inp)
        loss = loss_fn(out, labels)
        loss.backward()
        optim.step(loss=loss)
        progress.set_postfix({
            'loss': loss.item(),
            'damping': optim.damping.item()
        })
        kfac_losses.append(loss.item())
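
The same pattern carries over to regression. Below is a minimal sketch (not from the original thread) assuming the model's output parametrizes a unit-variance Gaussian, so the matching KFAC loss is an MSE against targets sampled from that Gaussian (the Gaussian negative log-likelihood up to a constant factor), while the training loss can be anything, an L1 loss here. RegressionModel and reg_loader are hypothetical placeholders:

import torch
import torch.nn as nn

model = RegressionModel().cuda()   # hypothetical regression model
optim = KFAC(model, 9e-3, 1e-3, momentum_type='regular', momentum=0.95,
             adapt_damping=True, update_cov_manually=True)
model_logprob = nn.MSELoss(reduction='mean')  # Gaussian log-likelihood up to a constant factor
loss_fn = nn.L1Loss(reduction='mean')         # any training loss

for inp, targets in reg_loader:    # hypothetical data loader
    inp, targets = inp.cuda(), targets.cuda()
    model.zero_grad()
    # Fisher estimation: sample targets from the model's own (assumed) Gaussian.
    with optim.track_forward():
        out = model(inp)
        sampled = torch.normal(out.detach(), 1.0)  # unit-variance assumption
        loss = model_logprob(out, sampled)
    with optim.track_backward():
        loss.backward()
    optim.update_cov()
    # Parameter update with the actual training loss.
    model.zero_grad()
    loss = loss_fn(model(inp), targets)
    loss.backward()
    optim.step(loss=loss)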
