Backprop issue with membrane potential reset in PPO #340

Open
ADebor opened this issue Jul 30, 2024 · 0 comments
ADebor commented Jul 30, 2024

  • snntorch version: 0.9.1
  • Python version: 3.10.2
  • Operating System: Ubuntu 22.04
  • device: GPU

Description

Hi there,

I'm trying to implement a basic RL training loop for spiking nets using snntorch and torchrl. In this actor-critic setting, the actor is an SNN made of three parts: a population encoder, a spiking MLP, and a population decoder. The critic is a non-spiking ANN.

For context, the PPO algorithm comprises a rollout phase, during which data is gathered from the environment, followed by an update phase, during which the actor and the critic are updated.

In the encoder and the spiking MLP, I use Leaky neurons to generate and process spikes. I initialize these neurons with init_hidden set to True. The encoder and the MLP are two separate nn.Modules, each defining its own forward method. In each of these methods, I initially called utils.reset(self.net) prior to any processing, where net is an nn.Sequential (containing a single Leaky neuron in the encoder, and Leaky neurons interleaved with linear layers in the MLP).
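For concreteness, here is a stripped-down sketch of the kind of module I mean (layer sizes and names are placeholders, not my actual code, and I omit the time-stepping loop):

```python
import torch.nn as nn
import snntorch as snn
from snntorch import utils


class SpikingMLP(nn.Module):
    """Sketch of the spiking MLP; the population encoder follows the same
    pattern with a single Leaky neuron after its linear projection."""

    def __init__(self, in_dim=32, hidden_dim=64, beta=0.9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            snn.Leaky(beta=beta, init_hidden=True),
            nn.Linear(hidden_dim, hidden_dim),
            snn.Leaky(beta=beta, init_hidden=True),
        )

    def forward(self, x):
        # Reset the hidden membrane potentials before any processing;
        # this is where I call utils.reset in my actual code.
        utils.reset(self.net)
        return self.net(x)
```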

This goes fine during the rollout phase (at least, it runs without throwing any errors). However, problems arise in the update phase: the update loop crashes when calling backward for the second time (i.e. on the second mini-batch of the data collected during the first rollout). I get the "RuntimeError: Trying to backward through the graph a second time (...)" error.
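Schematically, the two phases look like this (this is not my actual torchrl code and not a verified stand-alone repro, just the pattern; sizes and the loss are made up), and it is the second backward() call of the update loop that crashes in my setup:

```python
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import utils

# Hypothetical sizes, only to make the sketch self-contained.
N_ENVS, MINIBATCH_SIZE, OBS_DIM, ACT_DIM = 8, 4, 6, 2

actor = nn.Sequential(
    nn.Linear(OBS_DIM, 16),
    snn.Leaky(beta=0.9, init_hidden=True),
    nn.Linear(16, ACT_DIM),
)
optimizer = torch.optim.Adam(actor.parameters(), lr=1e-3)

# Rollout phase: batch size = number of parallel environments.
utils.reset(actor)
with torch.no_grad():
    rollout_actions = actor(torch.randn(N_ENVS, OBS_DIM))

# Update phase: batch size = minibatch size.
for minibatch in range(2):
    utils.reset(actor)  # what I expected would also cut the old graph
    obs = torch.randn(MINIBATCH_SIZE, OBS_DIM)
    loss = actor(obs).pow(2).mean()  # stand-in for the PPO loss
    optimizer.zero_grad()
    loss.backward()  # in my training loop, the second minibatch raises
                     # "Trying to backward through the graph a second time (...)"
    optimizer.step()
```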

After digging into it a bit, I noticed that this seems to come from the mem variable of the Leaky neuron(s). During the first iteration of the update loop, the batch size changes (from "number of parallel environments" to "minibatch size"), so self.mem is assigned a new tensor in the Leaky's forward method. During the second iteration, however, the batch size is the same as in the first iteration, so this reassignment does not happen. I thought that calling utils.reset(self.net) would have the same effect as assigning a new tensor, but that is not what I observed. In fact, the mem tensor manipulated in the reset_hidden class method does not seem to be the same as the one used in forward (its batch size is always equal to "number of parallel environments" in reset_hidden, while I'd expect it to change after the first training iteration). The fact that the same mem is carried over between two iterations seems to be what causes the backprop issue.

What I Did

The errors go away if I replace utils.reset(self.net) with net.a_leaky_neuron.reset_mem() (for each Leaky neuron) in my modules' forward methods, as in the sketch below. I'm not sure the training is done properly this way though; at the moment I'm only trying to get something running without errors.
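Concretely, the change amounts to something like this (same placeholder module as above; the names are made up):

```python
import torch.nn as nn
import snntorch as snn


class SpikingMLP(nn.Module):
    def __init__(self, in_dim=32, hidden_dim=64, beta=0.9):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            snn.Leaky(beta=beta, init_hidden=True),
            nn.Linear(hidden_dim, hidden_dim),
            snn.Leaky(beta=beta, init_hidden=True),
        )

    def forward(self, x):
        # Workaround: reset each Leaky neuron directly instead of calling
        # utils.reset(self.net) on the whole container.
        for layer in self.net:
            if isinstance(layer, snn.Leaky):
                layer.reset_mem()
        return self.net(x)
```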

I might be using the snntorch utils incorrectly, but does this problem ring a bell on your side? Do you see anything wrong in the way I reset the neurons? Why does calling the utils method not work?

If needed, I could share some code, but it is quite bulky at the moment and it would take me a bit of work to extract a minimal example.

Thanks a lot!
