
Manually resetting BrainPy's delay variables causes a JAX tracer leak error #626

Closed
CloudyDory opened this issue Feb 19, 2024 · 7 comments

@CloudyDory
Contributor

CloudyDory commented Feb 19, 2024

I need to simulate a network with long trial durations. To avoid out-of-memory errors, I divide each trial into several segments, so the simulation code has two nested for loops: the outer loop iterates over trials, and the inner loop iterates over the segments of a single trial.

The network's synapse has a LengthDelay variable. I find that this variable is not reset even after manually calling bp.reset_state(model), so I implemented a custom reset_state() function in the Synapse object. However, this function causes a JAX tracer leak error, and I am not sure how to deal with it.

The example code to reproduce the issue is given below:

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')

#%% Configurations
cfg = {}
cfg['T'] = 100.0    # total length of trial, ms
cfg['dt'] = 1.0     # ms
cfg['trial_num'] = 5
cfg['segment_num'] = 4  # divide the trial into segments to save memory

#%% Network definition
class Synapse(bp.dyn.SynConn):
    def __init__(self, pre, post, delay_step):
        super().__init__(pre, post, conn=None)
        self.delay_step = int(delay_step)
        self.pre_spike_buffer = bm.LengthDelay(self.pre.spike, self.delay_step)
        
    def reset_state(self, *args):
        self.pre_spike_buffer = bm.LengthDelay(self.pre.spike, self.delay_step)
    
    def update(self, pre_spike):
        self.pre_spike_buffer.update(pre_spike)

class Network(bp.DynSysGroup):
    def __init__(self, size):
        super().__init__()
        self.neu = bp.dyn.LifRef(size=size)
        self.syn  = Synapse(self.neu, self.neu, delay_step=2)
    
    def update(self):
        spike = bm.random.rand(*self.neu.size) < 0.5
        self.syn.update(spike)

#%% Run the simulation
model = Network(size=10)
runner = bp.DSRunner(model, jit=True, dt=cfg['dt'], progress_bar=False)

for i in range(cfg['trial_num']):
    print('Running the simulation... ')
    bp.reset_state(model)
    for j in range(cfg['segment_num']):
        runner.run(duration=cfg['T']/cfg['segment_num'])

The code generates the following error when running the second trial:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[1] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _step_func_predict at /home/xxx/miniconda3/envs/brainpy2.5/lib/python3.11/site-packages/brainpy/_src/runners.py:619 traced for scan.
------------------------------
The leaked intermediate value was created on line /home/xxx/project/test_bug.py:32:8 (Synapse.update). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<frozen runpy>:88:4 (_run_code)
/tmp/ipykernel_420910/489073642.py:1 (<module>)
/home/xxx/project/test_bug.py:52:8 (<module>)
/home/xxx/project/test_bug.py:42:8 (Network.update)
/home/xxx/project/test_bug.py:32:8 (Synapse.update)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Removing the reset_state() function from the Synapse object avoids the error, but then the LengthDelay variable won't be reset after a trial finishes, and the spikes stored in the previous trial will leak into the next trial.

Another way that I found to solve the issue is to use the following reset function:

def reset_state(self, *args):
    self.pre_spike_buffer.data = self.pre_spike_buffer.data.at[:].set(0)

But I am not sure whether this will also reset the gradient information associated with the spike data when training with back-propagation.

Environment:

  • Ubuntu 22.04
  • Python 3.11
  • brainpy 2.5.0
  • brainpylib 0.2.6 (with cuda12 support)
  • jax and jaxlib 0.4.24 (with cuda support)
  • taichi 1.7.0
@chaoming0625
Collaborator

So many thanks!

The error is caused by incorrect usage of reset_state.

A LengthDelay cannot be reset by reassigning a new LengthDelay instance. If you define

    def reset_state(self, *args):
        self.pre_spike_buffer = bm.LengthDelay(self.pre.spike, self.delay_step)
    

then self.pre_spike_buffer will refer to a new LengthDelay instance, which causes errors once the function has been compiled. Instead,

    def reset_state(self, *args):
        self.pre_spike_buffer.reset(self.pre.spike, self.delay_step)
    

will reset the variables inside the traced LengthDelay without replacing the traced LengthDelay instance.

It is important to note that BrainPyObject has special handling for Variable (and all subclasses of Variable). Users can write a reset_state function that assigns another Variable instance without replacing the originally instantiated Variable, like this:

def reset_state(self, *args):
    self.a = bm.Variable(bm.ones(10))

During the first call of reset_state, this function creates self.a since a does not yet exist in the class. However, during a later call, the class detects that a already exists and is a Variable, so it does not rebind the reference; instead it updates the value of a with the new data. The overall behavior looks like this:

def reset_state(self, *args):
    if not hasattr(self, 'a'):
        self.a = bm.Variable(bm.ones(10))
    else:
        self.a.value = bm.ones(10)

However, this behavior only works for bm.Variable.
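
Putting the two rules together, a minimal sketch (the class name, sizes, and attributes here are only illustrative, not taken from the code above) could look like:

import brainpy as bp
import brainpy.math as bm

class Example(bp.DynSysGroup):
    def __init__(self, size, delay_step):
        super().__init__()
        self.size = size
        self.delay_step = delay_step
        self.v = bm.Variable(bm.zeros(size))              # plain state variable
        self.buffer = bm.LengthDelay(self.v, delay_step)  # delay buffer

    def reset_state(self, *args):
        # Reassigning a Variable is fine: BrainPyObject keeps the original
        # instance and only updates its value.
        self.v = bm.Variable(bm.zeros(self.size))
        # A LengthDelay must be reset in place rather than replaced.
        self.buffer.reset(self.v, self.delay_step)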

@CloudyDory
Contributor Author

Thank you very much for the explanation. This fixes the error.

@CloudyDory
Contributor Author

Hi, I find that resetting the delay variables in training mode still leads to a JAX tracer leak error. Example code:

import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm

bm.clear_buffer_memory()
bm.set(float_=bm.float32)
bm.set_platform('cpu')

#%% Network definition
class Network(bp.DynSysGroup):
    def __init__(self):
        super().__init__()
        self.neu = bp.dyn.Lif(size=2, spk_fun=bm.surrogate.Arctan())
        self.delay_len = 2
        self.spike_buffer = bm.LengthDelay(self.neu.spike, delay_len=self.delay_len)
        self.weight = bm.TrainVar(bm.random.randn(2))
    
    def reset_state(self, *args):
        self.neu.reset_state(self.neu.mode)
        self.spike_buffer.reset(self.neu.spike, delay_len=self.delay_len)
    
    def update(self, data):
        spike = self.neu(data)  # [batch, 2] 
        self.spike_buffer.update(spike)
        spike_delay = self.spike_buffer.retrieve(self.delay_len)  # [batch, 2] 
        out = bm.sum(self.weight * spike_delay)  # scalar
        return out

#%% Create network and fake data
print('Creating network... ')
with bm.training_environment():
    model = Network()
    optimizer = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

print('Creating data... ')
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1,-1]]),
                             np.random.randn(100, 2) + np.array([[ 1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32), 
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

#%% Training functions
def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [feature]
        y_single: scalar
    '''
    model.reset_state()
    predict = model.step_run(0, x_single)  # scalar
    loss = bp.losses.binary_logistic_loss(predict, y_single)  # scalar
    acc = bm.mean(bm.int32(predict >= 0.0) == y_single)  # scalar
    return loss, acc

grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

def grad_fun(last_grad, x_y_single):
    '''
    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single)  # PyTree of gradients, scalar, scalar
    new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    '''
    train_vars = model.train_vars().unique()
    
    # Gradient accumulation
    grads = jax.tree_map(bm.zeros_like, train_vars)
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
    optimizer.update(grads)
    
    loss = losses.mean()  # scalar
    acc = acces.mean()    # scalar
    return loss, acc

#%% Start training
print('Start training...')
train_epochs = 15
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)

for e in range(train_epochs):
    train_loss[e], train_acc[e] = train(train_data, train_label)
    print("Epoch {}, train_loss={:.3f}, train_acc={:.2f}%".format(e, train_loss[e], train_acc[e]*100.0))
    
print('Done!')

The error message is:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[3,1,2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was fun2scan at /home/xxx/miniconda3/envs/brainpy2.5/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:929 traced for scan.
------------------------------
The leaked intermediate value was created on line /home/xxx/train_delayvar_reset.py:29:8 (Network.reset_state). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/xxx/train_delayvar_reset.py:101:34 (<module>)
/home/xxx/train_delayvar_reset.py:87:29 (train)
/home/xxx/train_delayvar_reset.py:72:23 (grad_fun)
/home/xxx/train_delayvar_reset.py:57:4 (loss_fun)
/home/xxx/train_delayvar_reset.py:29:8 (Network.reset_state)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Is there a fix for this error?

CloudyDory reopened this Feb 22, 2024
chaoming0625 added a commit that referenced this issue Feb 22, 2024
@chaoming0625
Collaborator

Thanks for the report. There was a bug in the delay reset; I have fixed it in #631.

Moreover, in the training environment, it's better to initialize the delay as

bm.LengthDelay(self.neu.spike, delay_len=self.delay_len, update_method='concat')

@CloudyDory
Contributor Author

CloudyDory commented Feb 22, 2024

Why is it better to use the 'concat' update method? I originally thought the concat method needs to create new tensors from old tensors, so it involves more memory operations.

@chaoming0625
Collaborator

Yes. The rotation method does not support autograd. For gradient-based learning, you should use concat.

@CloudyDory
Contributor Author

Thank you very much for the information. In general, is there a way to check whether the gradient is computed correctly?
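
For instance, I imagine something like a finite-difference comparison on a toy function (a generic JAX sketch, not specific to BrainPy; the loss and values are just placeholders):

import jax
import jax.numpy as jnp

def loss(w, x):
    # toy differentiable loss standing in for the real network loss
    return jnp.sum(jnp.tanh(w * x))

w = jnp.array([0.3, -1.2])
x = jnp.array([1.0, 2.0])

analytic = jax.grad(loss)(w, x)  # gradient w.r.t. w

# central finite-difference estimate of the same gradient
eps = 1e-3
fd = jnp.stack([(loss(w.at[i].add(eps), x) - loss(w.at[i].add(-eps), x)) / (2 * eps)
                for i in range(w.size)])

print(jnp.allclose(analytic, fd, atol=1e-3))  # expected: True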
