Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flag to select tfrt backend for CPU. #7042

Merged
merged 1 commit into from
Jun 22, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions jax/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@
'jax_disable_most_optimizations', False,
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')

flags.DEFINE_string(
'jax_cpu_backend_variant', 'tfrt',
'jax_cpu_backend_variant selects cpu backend variant: stream_executor or '
'tfrt')

def get_compile_options(
num_replicas: int,
Expand Down Expand Up @@ -148,28 +151,39 @@ def register_backend_factory(name, factory, *, priority=0):


if jax.lib._xla_extension_version >= 23:
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
register_backend_factory('cpu', xla_client.make_cpu_client,
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300)
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
if jax.lib._xla_extension_version >= 24:
if FLAGS.jax_cpu_backend_variant == 'stream_executor':
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
else:
assert FLAGS.jax_cpu_backend_variant == 'tfrt'
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
else:
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300)
else:
register_backend_factory('interpreter',
xla_client._interpreter_backend_factory,
priority=-100)
register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client._gpu_backend_factory,
priority=200)
register_backend_factory('tpu', xla_client._tpu_backend_factory,
priority=300)

register_backend_factory('interpreter',
xla_client._interpreter_backend_factory,
priority=-100)
register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client._gpu_backend_factory,
priority=200)
register_backend_factory('tpu', xla_client._tpu_backend_factory,
priority=300)

_default_backend = None
_backends = None
Expand Down