Improve TensorFlowTrainer compatibility for TF<2.9 (#598)
* Improve `TensorFlowTrainer` compatibility for TF<2.9

* Fix the version number for `support_reduce_retracing`

* Make `support_reduce_retracing` private
taehoonlee authored Jul 26, 2023
1 parent e784f6e commit 28c29e7
Showing 1 changed file with 32 additions and 16 deletions.
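The change follows a standard pattern for supporting multiple TensorFlow releases: `tf.function` only accepts the `reduce_retracing` argument from TF 2.9 onward, so the trainer checks the installed version once in `__init__` and adds the argument to a keyword dictionary only when it is supported. A minimal standalone sketch of the same pattern, with the helper name `make_compiled_fn` invented for illustration (it is not part of the commit):

import tensorflow as tf
from packaging.version import Version

# `reduce_retracing` was added to `tf.function` in TF 2.9.
_SUPPORTS_REDUCE_RETRACING = Version(tf.__version__) >= Version('2.9.0')


def make_compiled_fn(fn, jit_compile=False):
    """Wrap `fn` in `tf.function`, requesting retrace reduction if available."""
    kwargs = {'jit_compile': jit_compile}
    if _SUPPORTS_REDUCE_RETRACING:
        kwargs.update({'reduce_retracing': True})
    return tf.function(fn, **kwargs)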
keras_core/backend/tensorflow/trainer.py

@@ -4,6 +4,8 @@
 import numpy as np
 import tensorflow as tf
 import tree
+
+from packaging.version import Version
 from tensorflow.python.eager import context as tf_context
 
 from keras_core import callbacks as callbacks_module
@@ -22,6 +24,11 @@ def __init__(self):
         self.test_function = None
         self.predict_function = None
 
+        if Version(tf.__version__) >= Version('2.9.0'):
+            self._support_reduce_retracing = True
+        else:
+            self._support_reduce_retracing = False
+
         # Model must be created under scope of DistStrat it will be trained
         # with.
         if tf.distribute.has_strategy():
@@ -99,11 +106,10 @@ def one_step_on_data(data):
             return self.train_step(data)
 
         if not self.run_eagerly:
-            one_step_on_data = tf.function(
-                one_step_on_data,
-                jit_compile=self.jit_compile,
-                reduce_retracing=True,
-            )
+            kwargs = {'jit_compile': self.jit_compile}
+            if self._support_reduce_retracing:
+                kwargs.update({'reduce_retracing': True})
+            one_step_on_data = tf.function(one_step_on_data, **kwargs)
 
         @tf.autograph.experimental.do_not_convert
         def one_step_on_iterator(iterator):
@@ -131,7 +137,10 @@ def multi_step_on_iterator(iterator):
             train_function = one_step_on_iterator
 
         if not self.run_eagerly:
-            train_function = tf.function(train_function, reduce_retracing=True)
+            kwargs = {}
+            if self._support_reduce_retracing:
+                kwargs.update({'reduce_retracing': True})
+            train_function = tf.function(train_function, **kwargs)
 
         self.train_function = train_function
 
@@ -145,9 +154,10 @@ def one_step_on_data(data):
             return self.test_step(data)
 
         if not self.run_eagerly and self.jit_compile:
-            one_step_on_data = tf.function(
-                one_step_on_data, jit_compile=True, reduce_retracing=True
-            )
+            kwargs = {'jit_compile': True}
+            if self._support_reduce_retracing:
+                kwargs.update({'reduce_retracing': True})
+            one_step_on_data = tf.function(one_step_on_data, **kwargs)
 
         @tf.autograph.experimental.do_not_convert
         def one_step_on_iterator(iterator):
@@ -175,7 +185,10 @@ def multi_step_on_iterator(iterator):
             test_function = one_step_on_iterator
 
         if not self.run_eagerly:
-            test_function = tf.function(test_function, reduce_retracing=True)
+            kwargs = {}
+            if self._support_reduce_retracing:
+                kwargs.update({'reduce_retracing': True})
+            test_function = tf.function(test_function, **kwargs)
 
         self.test_function = test_function
 
@@ -189,9 +202,10 @@ def one_step_on_data(data):
             return self.predict_step(data)
 
         if not self.run_eagerly and self.jit_compile:
-            one_step_on_data = tf.function(
-                one_step_on_data, jit_compile=True, reduce_retracing=True
-            )
+            kwargs = {'jit_compile': True}
+            if self._support_reduce_retracing:
+                kwargs.update({'reduce_retracing': True})
+            one_step_on_data = tf.function(one_step_on_data, **kwargs)
 
         @tf.autograph.experimental.do_not_convert
         def one_step_on_data_distributed(data):
@@ -222,9 +236,11 @@ def multi_step_on_data(data):
             predict_function = one_step_on_data_distributed
 
         if not self.run_eagerly:
-            predict_function = tf.function(
-                predict_function, reduce_retracing=True
-            )
+            kwargs = {}
+            if self._support_reduce_retracing:
+                kwargs.update({'reduce_retracing': True})
+
+            predict_function = tf.function(predict_function, **kwargs)
 
         self.predict_function = predict_function

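For context on why the gate is needed: on TF 2.8 and earlier, `tf.function` does not recognize `reduce_retracing` (the older, since-deprecated spelling was `experimental_relax_shapes`), so passing it unconditionally raises a `TypeError`. An illustration of the failure mode, not taken from the commit:

import tensorflow as tf


def double(x):
    return x * 2


try:
    # Accepted on TF >= 2.9; raises TypeError on older releases.
    fn = tf.function(double, reduce_retracing=True)
except TypeError:
    # Pre-2.9 spelling of the same option.
    fn = tf.function(double, experimental_relax_shapes=True)

The commit opts for an explicit version check instead of a try/except probe like the one above; probing also works, but a bare `except TypeError` can hide unrelated argument errors.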