diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py index 77c59eb36..7b605f40d 100644 --- a/keras_core/backend/jax/trainer.py +++ b/keras_core/backend/jax/trainer.py @@ -1,3 +1,5 @@ +from functools import partial + import jax import numpy as np import tree @@ -237,8 +239,11 @@ def multi_train_steps(state, data): train_step = one_train_step if not self.run_eagerly and self.jit_compile: - - @jax.jit + # Note that only the state (not the data) is donated to jax, + # so that jax can reuse its memory buffers for the outputs. + # This will reduce the memory usage of the training function by + # half. + @partial(jax.jit, donate_argnames="state") def compiled_train_step(state, data): return train_step(state, data) @@ -266,8 +271,11 @@ def multi_test_steps(state, data): test_step = one_test_step if not self.run_eagerly and self.jit_compile: - - @jax.jit + # Note that only the state (not the data) is donated to jax, + # so that jax can reuse its memory buffers for the outputs. + # This will reduce the memory usage of the test function by + # half. + @partial(jax.jit, donate_argnames="state") def compiled_test_step(state, data): return test_step(state, data) @@ -578,15 +586,18 @@ def evaluate( ) data = self._distribute_data(data) logs, state = self.test_function(state, data) - # Note that trainable variables are not returned since they're - # immutable here. - _, non_trainable_variables, metrics_variables = state + ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) = state # Setting _jax_state enables callbacks to force a state sync # if they need to. self._jax_state = { # I wouldn't recommend modifying non-trainable model state # during evaluate(), but it's allowed. 
+ "trainable_variables": trainable_variables, "non_trainable_variables": non_trainable_variables, "metrics_variables": metrics_variables, } @@ -764,8 +775,9 @@ def test_on_batch( logs, state = self.test_function(state, [data]) # State sync - _, non_trainable_variables, metrics_variables = state + trainable_variables, non_trainable_variables, metrics_variables = state self._jax_state = { + "trainable_variables": trainable_variables, "non_trainable_variables": non_trainable_variables, "metrics_variables": metrics_variables, }