From 98bd34c50fab0abab01353c94ca0cdd8da8da143 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 22 Oct 2024 04:42:24 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 688497033 --- t5x/export_lib.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/t5x/export_lib.py b/t5x/export_lib.py index dfb9296ae..5acc0ed10 100644 --- a/t5x/export_lib.py +++ b/t5x/export_lib.py @@ -309,6 +309,8 @@ def create_inference_function( output_len: Optional[int] = None, ) -> Callable[[Mapping[str, Any], Any], PyTree]: """Fetches a model and returns the inference function based on inference_mode.""" + # Always use native serialization. The non-native serialization is deprecated. + del native_lowering if partitioner and train_state_initializer: maybe_partition = lambda fn: partitioner.partition( # pylint:disable=g-long-lambda fn, @@ -390,13 +392,12 @@ def model_fn( if jax2tf_disable_platform_checks else [] ) - if native_lowering and (not native_lowering_platforms): + if not native_lowering_platforms: # Change default value to make the exported cpu model still work. native_lowering_platforms = ['cpu', 'tpu'] model_fn = jax2tf.convert( model_fn, polymorphic_shapes=[None, polymorphic_shapes_inputs], - native_serialization=native_lowering, native_serialization_platforms=native_lowering_platforms, native_serialization_disabled_checks=disabled_checks, enable_xla=enable_xla, @@ -1517,8 +1518,7 @@ def save( validation_examples: Optional list of validation examples. If proveded, they will be used to validate the latency and numeric accuracy of the TPU saved model. - native_lowering: for experimental purposes only -- if True, don't convert - Jax fns to TF fns. + native_lowering: deprecated, always True. native_lowering_platforms: In conjunction with `native_lowering`, specify the platform(s) for which to lower the code. Must be a tuple of strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'. The default @@ -1538,6 +1538,8 @@ def save( create_polymorphic_shapes_fn: Optional function to create polymorphic shapes for input tensors to the JAX model function. """ # fmt: skip + # Always use native serialization. The non-native serialization is deprecated. + del native_lowering jax.monitoring.record_event('/jax/t5x/export/beacon') output_dirs = _standardize_output_dirs(output_dir) del output_dir @@ -1584,7 +1586,7 @@ def save( if create_decoding_state_callback_fn is not None: decoding_state_callback_fn = create_decoding_state_callback_fn( vocab=output_vocab, - call_tf_graph=native_lowering, + call_tf_graph=True, ) model_tf_fn = create_inference_function_fn( @@ -1598,7 +1600,7 @@ def save( polymorphic_shapes_inputs=create_polymorphic_shapes_fn( input_signature, preprocessor ), - native_lowering=native_lowering, + native_lowering=True, native_lowering_platforms=native_lowering_platforms, )