Skip to content

Commit

Permalink
Add flag to select tfrt backend for CPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangqiaorjc committed Jun 22, 2021
1 parent 23b6b0b commit 132a542
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions jax/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@
'jax_disable_most_optimizations', False,
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')

flags.DEFINE_string(
'jax_cpu_backend_variant', 'tfrt',
'jax_cpu_backend_variant selects cpu backend variant: stream_executor or '
'tfrt')

def get_compile_options(
num_replicas: int,
Expand Down Expand Up @@ -148,28 +151,39 @@ def register_backend_factory(name, factory, *, priority=0):


if jax.lib._xla_extension_version >= 23:
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
register_backend_factory('cpu', xla_client.make_cpu_client,
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300)
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
if jax.lib._xla_extension_version >= 24:
if FLAGS.jax_cpu_backend_variant == 'stream_executor':
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
else:
assert FLAGS.jax_cpu_backend_variant == 'tfrt'
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
else:
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300)
else:
register_backend_factory('interpreter',
xla_client._interpreter_backend_factory,
priority=-100)
register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client._gpu_backend_factory,
priority=200)
register_backend_factory('tpu', xla_client._tpu_backend_factory,
priority=300)

register_backend_factory('interpreter',
xla_client._interpreter_backend_factory,
priority=-100)
register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client._gpu_backend_factory,
priority=200)
register_backend_factory('tpu', xla_client._tpu_backend_factory,
priority=300)

_default_backend = None
_backends = None
Expand Down

0 comments on commit 132a542

Please sign in to comment.