Manually resetting BrainPy's delay variables causes JAX leaking error #626
So many thanks! The error is caused by the wrong usage of `reset_state()`. Replacing

```python
def reset_state(self, *args):
  self.pre_spike_buffer = bm.LengthDelay(self.pre.spike, self.delay_step)
```

with

```python
def reset_state(self, *args):
  self.pre_spike_buffer.reset(self.pre.spike, self.delay_step)
```

will reset the traced delay variable in place instead of creating a new one. It is important to note that

```python
def reset_state():
  self.a = bm.Variable(bm.ones(10))
```

creates a new `Variable` on every call, which leaks the traced array when `reset_state()` is jitted. During the first call of `reset_state()` the variable does not exist yet, so a safer pattern is

```python
def reset_state():
  if not hasattr(self, 'a'):
    self.a = bm.Variable(bm.ones(10))
  else:
    self.a.value = bm.Variable(bm.ones(10))
```

However, this behavior only works for …
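For readers unfamiliar with the leaked-tracer error, the snippet below is a minimal, BrainPy-independent sketch of the same failure mode: an array produced inside a jitted function is stored on a plain Python object and then used after tracing has finished. The `Holder` class and `bad_step` function are illustrative names, not part of this issue.

```python
import jax
import jax.numpy as jnp

class Holder:
  pass

holder = Holder()

@jax.jit
def bad_step(x):
  # Storing a value created during tracing on a plain Python object
  # lets the tracer escape the jit boundary.
  holder.a = jnp.ones(3) * x
  return holder.a.sum()

bad_step(2.0)

try:
  holder.a + 1.0  # using the escaped tracer outside the trace
except Exception as e:
  print(type(e).__name__)  # prints 'UnexpectedTracerError'
```

Re-creating a `bm.Variable` (or a whole `bm.LengthDelay`) inside a jitted `reset_state()` leaks tracers in exactly this way, whereas resetting the existing object in place does not.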
Thank you very much for the explanation. This fixes the error.
Hi, I find that resetting the delay variables in training mode still leads to a JAX leaking error. Example code:

```python
import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm

bm.clear_buffer_memory()
bm.set(float_=bm.float32)
bm.set_platform('cpu')

#%% Network definition
class Network(bp.DynSysGroup):
  def __init__(self):
    super().__init__()
    self.neu = bp.dyn.Lif(size=2, spk_fun=bm.surrogate.Arctan())
    self.delay_len = 2
    self.spike_buffer = bm.LengthDelay(self.neu.spike, delay_len=self.delay_len)
    self.weight = bm.TrainVar(bm.random.randn(2))

  def reset_state(self, *args):
    self.neu.reset_state(self.neu.mode)
    self.spike_buffer.reset(self.neu.spike, delay_len=self.delay_len)

  def update(self, data):
    spike = self.neu(data)  # [batch, 2]
    self.spike_buffer.update(spike)
    spike_delay = self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]
    out = bm.sum(self.weight * spike_delay)  # scalar
    return out

#%% Create network and fake data
print('Creating network... ')
with bm.training_environment():
  model = Network()
optimizer = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

print('Creating data... ')
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1, -1]]),
                             np.random.randn(100, 2) + np.array([[ 1,  1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32),
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

#%% Training functions
def loss_fun(x_single, y_single):
  '''
  Inputs:
    x_single: [feature]
    y_single: scalar
  '''
  model.reset_state()
  predict = model.step_run(0, x_single)  # scalar
  loss = bp.losses.binary_logistic_loss(predict, y_single)  # scalar
  acc = bm.mean(bm.int32(predict >= 0.0) == y_single)  # scalar
  return loss, acc

grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

def grad_fun(last_grad, x_y_single):
  '''
  Inputs:
    last_grad: PyTree of gradients of each trainable parameter.
    x_y_single: tuple of ([feature], scalar), a single training sample.
  '''
  x_single, y_single = x_y_single  # [feature], scalar
  grads, loss, acc = grad_f(x_single, y_single)  # PyTree of gradients, scalar, scalar
  new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
  return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
  '''
  Inputs:
    x_batch: [batch, feature]
    y_batch: [batch]
  '''
  train_vars = model.train_vars().unique()
  # Gradient accumulation
  grads = jax.tree_map(bm.zeros_like, train_vars)
  grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
  optimizer.update(grads)
  loss = losses.mean()  # scalar
  acc = acces.mean()  # scalar
  return loss, acc

#%% Start training
print('Start training...')
train_epochs = 15
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)
for e in range(train_epochs):
  train_loss[e], train_acc[e] = train(train_data, train_label)
  print("Epoch {}, train_loss={:.3f}, train_acc={:.2f}%".format(e, train_loss[e], train_acc[e] * 100.0))
print('Done!')
```

The error message is:

Is there a fix for this error?
Thanks for the report. There is a bug in the delay reset. I have fixed it in #631. Moreover, in the training environment, it is better to initialize the delay as `bm.LengthDelay(self.neu.spike, delay_len=self.delay_len, update_method='concat')`.
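As an illustration (not part of the original reply), applying that suggestion to the `Network` class from the example above only changes the `LengthDelay` construction; everything else stays the same:

```python
import brainpy as bp
import brainpy.math as bm

class Network(bp.DynSysGroup):
  def __init__(self):
    super().__init__()
    self.neu = bp.dyn.Lif(size=2, spk_fun=bm.surrogate.Arctan())
    self.delay_len = 2
    # update_method='concat' updates the delay buffer by concatenation rather
    # than in-place rotation; the maintainer recommends it under the training
    # environment.
    self.spike_buffer = bm.LengthDelay(self.neu.spike,
                                       delay_len=self.delay_len,
                                       update_method='concat')
    self.weight = bm.TrainVar(bm.random.randn(2))
```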
Why is it better to use the 'concat' update method? I originally thought the …
Yes. The …
Thank you very much for the information. In general, is there a way to check whether the gradient is computed correctly?
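This question is not answered in the thread, but a common generic sanity check (shown below as a sketch, not taken from this issue) is to compare the autodiff gradient against a central finite-difference estimate on a small toy problem; JAX also provides `jax.test_util.check_grads` for the same purpose. The `loss` function and the numbers here are purely illustrative.

```python
import jax
import jax.numpy as jnp

def loss(w):
  # Hypothetical toy loss; in practice this would be the model's per-sample loss.
  return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.array([0.3, -1.2, 0.7])
auto_grad = jax.grad(loss)(w)

# Central finite differences, one coordinate at a time.
eps = 1e-3
fd_grad = jnp.stack([
    (loss(w.at[i].add(eps)) - loss(w.at[i].add(-eps))) / (2 * eps)
    for i in range(w.size)
])

# With float32 the two estimates should agree to roughly 1e-3.
print(jnp.max(jnp.abs(auto_grad - fd_grad)))
```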
I need to simulate a network with trials of long duration. To avoid out-of-memory errors, I divide each trial into several segments, so the simulation code has two nested for loops: the outer loop over trials and the inner loop over the segments of a single trial.

The network's synapse has a `LengthDelay` variable. I find that this variable is not reset even after manually calling `bp.reset_state(model)`, so I implemented a custom `reset_state()` function in the `Synapse` object. However, this function causes a JAX leaking error, and I am not sure how to deal with it. The example code to reproduce the issue is given below:

The code generates the following error when running the second trial:

Removing the `reset_state()` function from the `Synapse` object solves the issue, but then the `LengthDelay` variable won't be reset after a trial is finished, and the spikes stored in the previous trial will leak into the next trial. Another way that I found to solve the issue is to use the following reset function:

But I am not sure whether this will also reset the gradient information associated with the spike data when training by back-propagation.
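As a hypothetical sketch of the pattern the maintainer describes in the first reply (the class and attribute names below are illustrative, not the author's actual `Synapse` code), the delay buffer can be reset in place inside `reset_state()` rather than re-created:

```python
import brainpy as bp
import brainpy.math as bm

class DelayedSynapse(bp.DynamicalSystem):
  def __init__(self, pre, delay_step):
    super().__init__()
    self.pre = pre                      # presynaptic group exposing .spike
    self.delay_step = delay_step
    self.pre_spike_buffer = bm.LengthDelay(self.pre.spike, self.delay_step)

  def reset_state(self, *args):
    # Reset the existing buffer in place; creating a new LengthDelay here
    # would allocate new Variables and leak tracers under jit.
    self.pre_spike_buffer.reset(self.pre.spike, self.delay_step)

  def update(self):
    self.pre_spike_buffer.update(self.pre.spike)
    return self.pre_spike_buffer.retrieve(self.delay_step)
```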
Environment: