diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py
index 8ba20225f16d..ead4e7be0bb6 100644
--- a/jax/lib/xla_bridge.py
+++ b/jax/lib/xla_bridge.py
@@ -63,7 +63,10 @@
     'jax_disable_most_optimizations', False,
     'Try not to do much optimization work. This can be useful if the cost of '
     'optimization is greater than that of running a less-optimized program.')
-
+flags.DEFINE_string(
+    'jax_cpu_backend_variant', 'tfrt',
+    'jax_cpu_backend_variant selects cpu backend variant: stream_executor or '
+    'tfrt')
 
 def get_compile_options(
     num_replicas: int,
@@ -148,28 +151,46 @@ def register_backend_factory(name, factory, *, priority=0):
 
 
 if jax.lib._xla_extension_version >= 23:
-  register_backend_factory('interpreter', xla_client.make_interpreter_client,
-                           priority=-100)
-  register_backend_factory('cpu', xla_client.make_cpu_client,
-                           priority=0)
-  register_backend_factory('tpu_driver', _make_tpu_driver_client,
-                           priority=100)
-  register_backend_factory('gpu', xla_client.make_gpu_client,
-                           priority=200)
-  register_backend_factory('tpu', xla_client.make_tpu_client,
-                           priority=300)
+  register_backend_factory('interpreter', xla_client.make_interpreter_client,
+                           priority=-100)
+  # The CPU client variant is only selectable from extension version 24 on;
+  # earlier versions are always registered with use_tfrt=False.
+  if jax.lib._xla_extension_version >= 24:
+    if FLAGS.jax_cpu_backend_variant == 'stream_executor':
+      register_backend_factory(
+          'cpu', partial(xla_client.make_cpu_client, use_tfrt=False),
+          priority=0)
+    elif FLAGS.jax_cpu_backend_variant == 'tfrt':
+      register_backend_factory(
+          'cpu', partial(xla_client.make_cpu_client, use_tfrt=True),
+          priority=0)
+    else:
+      # Raise instead of assert: asserts are stripped under `python -O`, which
+      # would let an unrecognized variant silently select the TFRT client.
+      raise ValueError(
+          "Unknown jax_cpu_backend_variant {!r}; expected 'stream_executor' "
+          "or 'tfrt'.".format(FLAGS.jax_cpu_backend_variant))
+  else:
+    register_backend_factory(
+        'cpu', partial(xla_client.make_cpu_client, use_tfrt=False),
+        priority=0)
+  register_backend_factory('tpu_driver', _make_tpu_driver_client,
+                           priority=100)
+  register_backend_factory('gpu', xla_client.make_gpu_client,
+                           priority=200)
+  register_backend_factory('tpu', xla_client.make_tpu_client,
+                           priority=300)
 else:
-  register_backend_factory('interpreter',
-                           xla_client._interpreter_backend_factory,
-                           priority=-100)
-  register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
-  register_backend_factory('tpu_driver', _make_tpu_driver_client,
-                           priority=100)
-  register_backend_factory('gpu', xla_client._gpu_backend_factory,
-                           priority=200)
-  register_backend_factory('tpu', xla_client._tpu_backend_factory,
-                           priority=300)
-
+  register_backend_factory('interpreter',
+                           xla_client._interpreter_backend_factory,
+                           priority=-100)
+  register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
+  register_backend_factory('tpu_driver', _make_tpu_driver_client,
+                           priority=100)
+  register_backend_factory('gpu', xla_client._gpu_backend_factory,
+                           priority=200)
+  register_backend_factory('tpu', xla_client._tpu_backend_factory,
+                           priority=300)
 
 _default_backend = None
 _backends = None