A question about the gradients #9

Open
yurunsheng1 opened this issue Nov 27, 2019 · 1 comment
yurunsheng1 commented Nov 27, 2019

Hi,
First, thank you for providing such nice work!

But I have a question and really need your help:

In your MeLU.py lines 71-79:

grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
# local update
for i in range(self.weight_len):
    if self.weight_name[i] in self.local_update_target_weight_name:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
    else:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
self.model.load_state_dict(self.fast_weights)
query_set_y_pred = self.model(query_set_x)

I understand this is the standard MAML approach (inner loop).

However, load_state_dict() erases (breaks) the gradient graph (https://discuss.pytorch.org/t/loading-a-state-dict-seems-to-erase-grad/56676), so the global update no longer backpropagates through the local update in the final optimization. As a result, create_graph=True has no effect and the algorithm may no longer be standard MAML. I am wondering whether I am missing some insight here.
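A tiny self-contained reproduction (my own sketch, not code from this repo) shows the effect:

import torch

# Hypothetical minimal setup, for illustration only.
lin = torch.nn.Linear(1, 1, bias=False)
x = torch.ones(1, 1)

loss = lin(x).sum()
(g,) = torch.autograd.grad(loss, lin.parameters(), create_graph=True)

fast_w = lin.weight - 0.1 * g             # graph-connected tensor
print(fast_w.grad_fn is not None)         # True: the inner update is on the graph
lin.load_state_dict({"weight": fast_w})   # copies values under no_grad
print(lin.weight.grad_fn)                 # None: the path to the inner update is gone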

Looking forward to your reply!

@GGchen1997
I believe you are right and the original code is wrong.
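
For reference, here is a minimal sketch of an inner loop that keeps the graph intact (assuming PyTorch >= 2.0; the names model, support_x, support_y, query_x, query_y, and local_lr are made up for illustration, and the functional forward via torch.func.functional_call is one common fix, not this repository's actual code):

import torch
from torch.func import functional_call

def inner_then_outer_loss(model, support_x, support_y, query_x, query_y, local_lr):
    loss_fn = torch.nn.functional.mse_loss
    params = dict(model.named_parameters())

    # Support-set loss; create_graph=True so the outer loss can
    # differentiate through the inner update (second-order MAML).
    support_pred = functional_call(model, params, (support_x,))
    grads = torch.autograd.grad(
        loss_fn(support_pred, support_y), list(params.values()), create_graph=True
    )

    # Fast weights are plain graph-connected tensors; nothing is written
    # back into the module, so there is no load_state_dict and no detaching.
    fast_weights = {
        name: p - local_lr * g for (name, p), g in zip(params.items(), grads)
    }

    # Query-set forward runs with the fast weights, still on the graph.
    query_pred = functional_call(model, fast_weights, (query_x,))
    return loss_fn(query_pred, query_y)

Calling .backward() on the returned loss then populates gradients on model.parameters() that flow through the inner update, which is what standard MAML needs.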
