From d35a3a8ec8884619f711c5e3ea8ac90c6cca6e24 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Wed, 16 Oct 2024 09:44:07 -0700 Subject: [PATCH] De-duplicate `get_ts_context` usages and move to ts_utils. PiperOrigin-RevId: 686540442 --- t5x/checkpoint_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py index d817b61af..f6f611a8e 100644 --- a/t5x/checkpoint_utils.py +++ b/t5x/checkpoint_utils.py @@ -280,7 +280,9 @@ def get_restore_parameters( restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure) flat_param_infos = {} is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory) - ts_context = ocp.type_handlers.get_ts_context() + ts_context = ocp.serialization.ts_utils.get_ts_context( + use_ocdbt=is_ocdbt_checkpoint + ) def _get_param_info( name: str,