Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant reference to ocp.tree.serialize_tree(...) by removing dead code. #1605

Merged
merged 1 commit on Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -20,11 +20,10 @@

import enum
import os
- from typing import Any, BinaryIO, Optional, Tuple, Union
+ from typing import Any, BinaryIO, Optional, Union

from absl import logging
from etils import epath
- import jax
import msgpack
import orbax.checkpoint as ocp
from tensorflow.io import gfile
@@ -257,27 +256,21 @@ def _is_supported_empty_value(value: Any) -> bool:
return ocp.type_handlers.is_supported_empty_value(value)


- def get_restore_parameters(
- directory: epath.Path,
- structure: PyTree,
- ) -> Tuple[PyTree, PyTree]:
- """Construct parameters needed for restoration.
+ def get_restore_parameters(directory: epath.Path, structure: PyTree) -> PyTree:
+ """Construct ParamInfos tree needed for restoration.

- ParamInfos are
- constructed from the structure of the original checkpoint, and restore_args
- are serialized to a tree structure compatible with param_infos and structure.
+ ParamInfos are constructed from the structure of the original checkpoint.

Args:
directory: Checkpoint directory.
structure: The structure of the original checkpoint.

Returns:
- Tuple of param_infos, and restore_args.
+ PyTree of `ParamInfo`.
"""
flat_structure = ocp.tree.to_flat_dict(structure, keep_empty_nodes=True)
param_names = ocp.tree.get_param_names(structure)
flat_param_names = ocp.tree.to_flat_dict(param_names, keep_empty_nodes=True)
- restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure)
flat_param_infos = {}
is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory)
ts_context = ocp.type_handlers.get_ts_context()
@@ -305,9 +298,5 @@ def _get_param_info(

for key, meta in flat_structure.items():
flat_param_infos[key] = _get_param_info(flat_param_names[key], meta)
- restore_args = ocp.tree.serialize_tree(restore_args, keep_empty_nodes=True)

- return (
- ocp.tree.from_flat_dict(flat_param_infos, target=structure),
- restore_args,
- )
+ return ocp.tree.from_flat_dict(flat_param_infos, target=structure)
2 changes: 1 addition & 1 deletion t5x/checkpoints.py
Original file line number Diff line number Diff line change
@@ -2166,7 +2166,7 @@ def _modify_orbax_param_info(info, value):
return info

item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args)
- param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_)
+ param_infos_ = checkpoint_utils.get_restore_parameters(directory_, item_)
param_infos_ = jax.tree.map(
_modify_orbax_param_info, param_infos_, state_dict_to_restore
)
Expand Down
Loading