diff --git a/t5x/utils.py b/t5x/utils.py index 36ead8d0a..67823c027 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -256,6 +256,7 @@ def save( path: str, item: train_state_lib.TrainState, force: bool = False, + custom: dict[str, Any] | None = None, state_transformation_fns: Sequence[ checkpoints.SaveStateTransformationFn ] = (), @@ -268,6 +269,7 @@ def save( path: path to save item to. item: a TrainState PyTree to save. force: unused. + custom: unused. state_transformation_fns: Transformations to apply, in order, to the state before writing. concurrent_gb: the approximate number of gigabytes of partitionable