
TF Sharded #17713

Merged · 34 commits · Jun 21, 2022
Conversation

@ArthurZucker (Collaborator)
What does this PR do?

Introduces sharding of TF models, following the PyTorch implementation.

A simple working example:

from transformers import TFOPTModel

save_directory = "opt-350m"
model = TFOPTModel.from_pretrained("facebook/opt-350m")
model.save_pretrained(save_directory, max_shard_size="1GB")
tf_model = TFOPTModel.from_pretrained(save_directory)
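The core idea behind `max_shard_size` can be sketched in plain Python: weights are appended to the current shard until adding the next one would overflow it, at which point a new shard is started. This is a hedged sketch of the grouping logic only, not the library's implementation; the helper name `shard_weights` and the `(name, size)` input format are illustrative.

```python
def shard_weights(weight_sizes, max_shard_size):
    """Group (name, size_in_bytes) pairs into shards of at most max_shard_size.

    Illustrative sketch of the sharding idea, not the transformers code.
    A single weight larger than max_shard_size still gets its own shard.
    """
    shards = [[]]
    current_size = 0
    for name, size in weight_sizes:
        # Start a new shard if adding this weight would overflow the current one.
        if current_size + size > max_shard_size and shards[-1]:
            shards.append([])
            current_size = 0
        shards[-1].append(name)
        current_size += size
    return shards
```

Each resulting group would then be written to its own checkpoint file, with an index JSON recording which file holds each weight.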

@ArthurZucker ArthurZucker self-assigned this Jun 15, 2022
@ArthurZucker ArthurZucker linked an issue Jun 15, 2022 that may be closed by this pull request
@ArthurZucker ArthurZucker added TensorFlow Anything TensorFlow Core: Modeling Internals of the library; Models. labels Jun 15, 2022
@HuggingFaceDocBuilderDev commented Jun 15, 2022

The documentation is not available anymore as the PR was closed or merged.

@ArthurZucker (Collaborator, Author)

Okay, so the tfopt_for_causal_lm/tfopt_model prefix in tfopt_for_causal_lm/model/decoder/embed_positions/weight:0 from the index JSON comes from the actual name of the layer (on the TF side). This also explains the hack we sometimes need when a layer is shared: for OPT we get 'decoder.embed_tokens/model.decoder.embed_tokens/weight:0', which thus becomes model.decoder.embed_tokens/weight:0. The most interesting part is that 'decoder.embed_tokens' comes from https://github.com/ArthurZucker/transformers/blob/e950ff48a91840e30966abaf86bdb02dc16fcdab/src/transformers/models/opt/modeling_tf_opt.py#L499-L511 (the load-weight-prefix hack using load_weight_prefix). I am sure something can be done about that, so I will detail it and dig a bit further.
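The prefix behaviour described above can be illustrated with a small string manipulation: the first path component of a TF variable name is the top-level layer name, and stripping it yields the key recorded in the index JSON. A hedged sketch; `strip_model_prefix` is an illustrative helper, not a transformers function.

```python
def strip_model_prefix(variable_name):
    """Drop the leading top-level layer name (e.g. 'tfopt_for_causal_lm/')
    from a TF variable name such as
    'tfopt_for_causal_lm/model/decoder/embed_positions/weight:0'.

    Illustrative helper, not part of transformers.
    """
    _, _, remainder = variable_name.partition("/")
    return remainder
```

Applied to the shared-layer example above, 'decoder.embed_tokens/model.decoder.embed_tokens/weight:0' becomes 'model.decoder.embed_tokens/weight:0'.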

@ArthurZucker ArthurZucker marked this pull request as ready for review June 17, 2022 12:06
@ArthurZucker ArthurZucker changed the title load and save tensorflow sharded checkpoints TF Sharded Jun 20, 2022
@gante (Member) left a comment:

In general, LGTM 👍

There are two main sets of comments that I think would be nice to address before merging:

  • Some documentation is copy/pasted from the PT side, so it needs some tweaks for TF. Also, some typos were copied over :D
  • There is duplicated functionality that I think we could move to a shared module

@@ -1075,3 +1080,114 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
except Exception:
# We don't want to error in case of connection errors of any kind.
pass


def convert_file_size_to_int(size: Union[int, str]):
@gante (Member):
This function is the same as in here.

Perhaps move the function to some file with shared functionality, like this one, and import from there? (cc @sgugger )

@ArthurZucker (Collaborator, Author) commented Jun 20, 2022:
Yes! It will be removed from modeling_utils

raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
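For context, the function under discussion parses a size like "5GB" into a byte count. A minimal sketch of that behaviour (assuming decimal units for GB/MB/KB and binary units for GiB/MiB/KiB, as in the PyTorch counterpart; the real implementation may differ in edge cases):

```python
from typing import Union


def convert_file_size_to_int(size: Union[int, str]) -> int:
    """Convert a size given as a string with a unit (e.g. "5GB") to bytes.

    Sketch of the discussed helper; the actual transformers code may
    support more units or handle errors differently.
    """
    if isinstance(size, int):
        return size
    upper = size.upper()
    # Binary units first, since "GIB" also ends with "B".
    if upper.endswith("GIB"):
        return int(size[:-3]) * 2**30
    if upper.endswith("MIB"):
        return int(size[:-3]) * 2**20
    if upper.endswith("KIB"):
        return int(size[:-3]) * 2**10
    if upper.endswith("GB"):
        return int(size[:-2]) * 10**9
    if upper.endswith("MB"):
        return int(size[:-2]) * 10**6
    if upper.endswith("KB"):
        return int(size[:-2]) * 10**3
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
```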


def get_checkpoint_shard_files(
@gante (Member):
The same comment as above -- there is a very similar function here

The docstring also needs a minor update: PreTrainedModel has a corresponding TF version, TFPreTrainedModel

@ArthurZucker (Collaborator, Author):
Same as above, should remove the functions from modeling_utils in a next PR
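For reference, the index JSON of a sharded checkpoint maps each weight name to the shard file that contains it, so resolving the list of shard files is essentially reading that map and deduplicating its values. A hedged sketch with an illustrative helper name and an in-memory index string, not the transformers helper itself:

```python
import json


def shard_files_from_index(index_json: str):
    """Return the sorted, deduplicated list of shard filenames referenced
    by a sharded-checkpoint index.

    Illustrative sketch; get_checkpoint_shard_files in transformers also
    downloads/caches the shards, which is omitted here.
    """
    index = json.loads(index_json)
    return sorted(set(index["weight_map"].values()))
```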

@patrickvonplaten (Contributor) commented Jun 20, 2022

Looks very nice to me!

I only did a very high-level review. Deferring to @gante and @sgugger here :-)

ArthurZucker and others added 5 commits June 20, 2022 22:25
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@gante (Member) left a comment:

If there are plans to move shared functionality in a future PR, I'm happy to approve 👍

@sgugger (Collaborator) left a comment:

Thanks a lot for working on this!

unexpected_keys = set()
# Read the H5 file
try:
with h5py.File(resolved_archive_file, "r") as f:
@sgugger (Collaborator):

Since it's used a lot below, can we give this `f` a better (more descriptive) name?

)
if is_sharded:
for file in resolved_archive_file:
assert os.path.isfile(file), f"Error retrieving files {file}"
@sgugger (Collaborator):
No new asserts in the codebase ;-) Please use a test and raise the appropriate error.

@ArthurZucker (Collaborator, Author):
The error is pretty much already handled with the OSError, as the call to load_tf_... is already inside a try/except clause.

@sgugger (Collaborator):
In this case you can remove the assert :-p

ignore_mismatched_sizes=ignore_mismatched_sizes,
)
else:
assert os.path.isfile(resolved_archive_file), f"Error retrieving file {resolved_archive_file}"
@sgugger (Collaborator):
Same here :-)

@ArthurZucker ArthurZucker merged commit 7cced02 into huggingface:main Jun 21, 2022
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 25, 2022
* initial commit

* update modeeling tf utils

* quality

* clean and update args

* update

* remove potential bug

* code quality

* update

* update max shard

* update tests for sharding from pretrained

* fix remaining test

* make style

* h5py if tf available

* update and fix test

* fix test

* style

* modified push to hub to support shard for TF

* quick fix

* update code

* merge branch main and style

* Apply suggestions from code review

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update based on reviews

* update doc

* update and style

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update based on reviews

* fix typo

* style

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 29, 2022
Labels
  • Core: Modeling (Internals of the library; Models.)
  • TensorFlow (Anything TensorFlow)
Projects
None yet
Development

Successfully merging this pull request may close these issues:
  • Shard checkpoint for tf and flax
5 participants