Caveat in last commit #12

Open
ipod825 opened this issue Apr 15, 2018 · 13 comments

ipod825 commented Apr 15, 2018

99c4cbe#diff-40d9c2c37e955447b1175a32afab171fL353
This is not an unnecessary detach.
l_t is used in
log_pi = Normal(mu, self.std).log_prob(l_t)
which is then used in
loss_reinforce = torch.sum(-log_pi*adjusted_reward, dim=1)
This means that when minimizing the REINFORCE loss, you are updating the location network through both mu and l_t (and yes, log_pi is differentiable w.r.t. both). However, l_t is just mu + noise, and we only want the gradient to flow through mu.
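
A minimal sketch of what is being proposed (toy tensors standing in for the repo's variables; the std value and rewards are made up):

    import torch
    from torch.distributions import Normal

    std = 0.17                                           # hypothetical policy std (self.std in the repo)
    mu = torch.tensor([0.3, -0.5], requires_grad=True)   # stand-in for the location network output
    noise = torch.randn_like(mu) * std
    l_t = mu + noise                                     # reparametrized sample

    # Detach l_t inside log_prob only: log_pi then depends on the weights through mu alone.
    log_pi = Normal(mu, std).log_prob(l_t.detach())

    adjusted_reward = torch.tensor([1.0, -1.0])          # dummy rewards
    loss_reinforce = torch.sum(-log_pi * adjusted_reward)
    loss_reinforce.backward()
    print(mu.grad)                                       # gradient reaches mu, but not via l_t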

@ipod825 ipod825 changed the title Cavet in last commit Caveat in last commit Apr 15, 2018

kevinzakka commented Apr 15, 2018

Won't l_t.detach() stop the gradients for l_t and hence for mu? I mean, we'll still have a gradient for mu from log_pi, but there's a contribution from l_t as well.

This is my understanding:

We want the weights of the location network to be trained using REINFORCE. Now the hidden state vector h_t (detached because we do not want to train the weights of the RNN with REINFORCE) is fed through the fully-connected layer of the location net to produce the mean mu. Then using the reparametrization trick, we sample a location vector l_t from a Gaussian parametrized by mu.

Doing the above, we've made it possible to backpropagate through l_t and hence back to mu which means we can train the weights of the location network.
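
For concreteness, a minimal sketch of that flow (layer sizes and std are made-up values; the tanh mirrors the repo snippet quoted further down):

    import torch
    import torch.nn as nn

    class LocationNet(nn.Module):
        # Illustrative sketch of the flow described above; names and sizes are assumptions.
        def __init__(self, hidden_size=256, loc_dim=2, std=0.17):
            super().__init__()
            self.fc = nn.Linear(hidden_size, loc_dim)
            self.std = std

        def forward(self, h_t):
            # Detach h_t so REINFORCE does not update the core RNN weights.
            mu = torch.tanh(self.fc(h_t.detach()))
            # Reparametrized sample: l_t = mu + noise.
            l_t = mu + torch.randn_like(mu) * self.std
            return mu, l_t

    h_t = torch.randn(1, 256)        # dummy hidden state
    mu, l_t = LocationNet()(h_t)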


ipod825 commented Apr 15, 2018

mu and l_t are two separate Variables (though highly correlated). l_t.detach() does not stop you from calculating d loss_reinforce / d mu. For example, consider the following math:

x = 1
y = x + 1
z = x * y

Both dz/dx and dz/dy are well defined. Even if you "detach" y, you can still calculate dz/dx.
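
The same point can be checked directly in PyTorch (a toy example, not repo code):

    import torch

    x = torch.tensor(1.0, requires_grad=True)
    y = (x + 1).detach()   # "detach" y: treat it as a constant
    z = x * y
    z.backward()
    print(x.grad)          # dz/dx = y = 2.0; without the detach it would be y + x = 3.0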


ipod825 commented Apr 15, 2018

Oh. I think you know what I meant.
I'll think more and reply again.


ipod825 commented Apr 15, 2018

I think we shouldn't flow the information through l_t.
Intuitively, for loss_reinforce to decrease, we want log_pi to increase.
To have log_pi increase, we want mu and l_t to be closer.
Assume mu < l_t; the gradient flow then tries to increase mu and decrease l_t
simultaneously. However, decreasing l_t essentially decreases mu, since l_t = mu + noise.
If you derive the formula for the gradients, one is the negative of the other,
because the Gaussian kernel is (l_t - mu)^2, so they cancel each other.
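
Writing the Gaussian log-density as log N(l_t; mu, sigma) = -(l_t - mu)^2 / (2 sigma^2) + const, the partial derivative w.r.t. mu is (l_t - mu)/sigma^2 and the partial derivative w.r.t. l_t is -(l_t - mu)/sigma^2; since dl_t/dmu = 1, the two paths cancel exactly. A quick check with toy values (not repo code):

    import torch
    from torch.distributions import Normal

    std = 0.17
    mu = torch.tensor([0.3], requires_grad=True)
    l_t = mu + torch.randn_like(mu) * std    # no detach: gradient flows through l_t as well

    log_pi = Normal(mu, std).log_prob(l_t)
    log_pi.sum().backward()
    print(mu.grad)                           # tensor([0.]): the two contributions cancel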


kevinzakka commented Apr 15, 2018

@ipod825 I need to think about it some more. Empirically, I haven't seen a performance difference between the two: I still reach ~1.3-1.4% error in about 30 epochs of training.

What's bugging me right now is that I learned about the reparametrization trick this weekend, which essentially makes it possible to backprop through a sampled variable. So right now, I'm confused as to why we even need REINFORCE to train our network. We could just use the reparametrization trick like in VAEs to make the whole process differentiable and directly optimize for the weights of the location network.

I'll give it some more thought tonight.


ipod825 commented Apr 15, 2018

The performance might not be related to this gradient issue. If you check this thread, you'll see that many of the implementations online don't even learn anything for their location network, but still get good performance on MNIST.


ipod825 commented Apr 15, 2018

Also, I don't think the reparametrization trick applies to this scenario.
Reparametrization requires the target function (in our scenario, the reward) to be differentiable w.r.t. its inputs. However, our reward is just an unknown function; we don't even have a formula for it.
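
In other words (my summary, not from the thread): the score-function estimator only needs the scalar reward, while the reparametrized estimator needs the reward's gradient w.r.t. the sample:

    % REINFORCE / score-function estimator: only R(l_t) itself is needed
    \nabla_\theta \,\mathbb{E}_{l_t \sim \pi_\theta}\!\left[R(l_t)\right]
        = \mathbb{E}_{l_t \sim \pi_\theta}\!\left[R(l_t)\,\nabla_\theta \log \pi_\theta(l_t)\right]

    % Reparametrization (l_t = \mu_\theta + \sigma\epsilon): requires \partial R / \partial l_t
    \nabla_\theta \,\mathbb{E}_{\epsilon}\!\left[R(\mu_\theta + \sigma\epsilon)\right]
        = \mathbb{E}_{\epsilon}\!\left[\frac{\partial R}{\partial l_t}\,\nabla_\theta\!\left(\mu_\theta + \sigma\epsilon\right)\right]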


ipod825 commented Apr 18, 2018

l_t = F.tanh(l_t)

This line is related to this issue.
You shouldn't apply tanh to l_t again. Say the fc output is 100, so mu = tanh(100) = 1.0. Even after adding noise, tanh(l_t) ≈ tanh(1.0) = 0.76159, so the sampled location gets pulled back toward the center.

A better idea is to use torch.clamp(l_t, -1, 1)
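
A two-line check of the numbers above (toy values):

    import torch

    mu = torch.tanh(torch.tensor(100.0))   # saturated: mu = 1.0
    l_t = mu                               # even with zero noise...
    print(torch.tanh(l_t))                 # tensor(0.7616): the second tanh pulls it toward the center
    print(torch.clamp(l_t, -1, 1))         # tensor(1.): clamp keeps the intended location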

kevinzakka commented

@ipod825 A normal distribution has unbounded support, so it is not guaranteed that l_t stays within [-1, 1].

I was against using torch.clamp because it is not as smooth as tanh. Why do you think it's a better idea?


ipod825 commented Apr 18, 2018

        mu = F.tanh(self.fc(h_t.detach()))
        # reparametrization trick
        noise = torch.zeros_like(mu)
        noise.data.normal_(std=self.std)
        l_t = mu + noise

        # bound between [-1, 1]
        l_t = F.tanh(l_t)

l_t is squeezed by tanh twice, while mu is squeezed only once.
When mu saturates to 1.0, l_t is almost surely smaller than 1.0, as I described above.
Second, if you modify the code as follows

        mu = torch.clamp(self.fc(h_t.detach()), -1, 1)
        # reparametrization trick
        noise = torch.zeros_like(mu)
        noise.data.normal_(std=self.std)
        l_t = mu + noise

        # bound between [-1, 1]
        l_t = torch.clamp(l_t, -1, 1)

And do not detach l_t in

log_pi = Normal(mu, self.std).log_prob(l_t)

You can check that the gradient in the location network is actually 0, as predicted by the discussion above. But if you use tanh, the gradient won't be 0, since mu and l_t are not squeezed in the same way.
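
Here is one way to run that check with stand-in tensors (not repo code); note that clamp behaves as the identity as long as nothing falls outside [-1, 1], which is the regime where the cancellation is exact:

    import torch
    from torch.distributions import Normal

    std = 0.17
    fc_out = torch.tensor([0.3], requires_grad=True)    # stand-in for self.fc(h_t.detach())

    def reinforce_grad(bound):
        fc_out.grad = None
        mu = bound(fc_out)
        l_t = bound(mu + torch.randn_like(mu) * std)    # no detach, as discussed above
        Normal(mu, std).log_prob(l_t).sum().backward()
        return fc_out.grad.clone()

    torch.manual_seed(0)
    print(reinforce_grad(lambda x: torch.clamp(x, -1, 1)))   # tensor([0.]) while nothing saturates
    torch.manual_seed(0)
    print(reinforce_grad(torch.tanh))                        # nonzero: the second tanh breaks the symmetry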

xycforgithub commented

@ipod825 Have you tried your implementation using clamp and l_t.detach()? I tried that and got very good performance on the 6 glimpses, 8x8, 1 scale setting: around 0.58% error. The paper reported 1.12%.


ipod825 commented May 31, 2018

I never got an error lower than 1%. If you use only a vanilla RNN (as already implemented by @kevinzakka), that would be an interesting result. If you consistently get similar results, it would be nice if you could share your code and let others figure out why it works so well.


sujoyp commented Nov 5, 2018

@kevinzakka

We want the weights of the location network to be trained using REINFORCE. Now the hidden state vector h_t (detached because we do not want to train the weights of the RNN with REINFORCE) is fed through the fully-connected layer of the location net to produce the mean mu. Then using the reparametrization trick, we sample a location vector l_t from a Gaussian parametrized by mu.

Why should we not train the weights of the RNN with REINFORCE?
