TF Sharded #17713
The documentation is not available anymore as the PR was closed or merged.
Okay so the tfopt_for_causal_lm/tfopt_model
In general, LGTM 👍
There are two main sets of comments which I think it would be nice to address before merging:
- Some documentation is copy/paste from the PT side, which means it needs some tweaks for TF. Also, some typos were copied over :D
- There is duplicated functionality that I think we could move to a shared module
@@ -1075,3 +1080,114 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
    except Exception:
        # We don't want to error in case of connection errors of any kind.
        pass


def convert_file_size_to_int(size: Union[int, str]):
Yes! It will be removed from modeling_utils
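For context, the helper discussed in this thread turns a human-readable size such as '5GB' into a byte count. A minimal sketch of that idea follows. It is not the code from this PR: the name `parse_file_size` and the set of supported units are assumptions chosen for illustration, and only the error message is taken from the diff.

```python
from typing import Union

# Hypothetical unit table for illustration; the real helper may support other units.
_UNITS = {
    "GIB": 2**30,
    "MIB": 2**20,
    "KIB": 2**10,
    "GB": 10**9,
    "MB": 10**6,
    "KB": 10**3,
}


def parse_file_size(size: Union[int, str]) -> int:
    """Convert a size such as "5GB" or "200MiB" into a number of bytes."""
    if isinstance(size, int):
        # Integers are assumed to already be a byte count.
        return size
    text = size.strip().upper()
    # Walk the unit table in order and stop at the first matching suffix.
    for suffix, factor in _UNITS.items():
        if text.endswith(suffix):
            number = text[: -len(suffix)]
            if number.isdigit():
                return int(number) * factor
            break
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
```

Keeping a helper like this in one shared module, as suggested in the review, means the PyTorch and TF code paths would interpret a maximum shard size in the same way.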
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")


def get_checkpoint_shard_files(
The same comment as above -- there is a very similar function here.
The docstring also needs a minor update: PreTrainedModel has a corresponding TF version, TFPreTrainedModel.
Same as above: the functions should be removed from modeling_utils in a follow-up PR.
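As background for this thread: a sharded checkpoint is usually accompanied by an index file that maps each weight name to the shard that stores it. The sketch below shows how the list of shard files could be derived from such an index. It is an illustration only, not the function from this PR: the helper name `list_shard_files`, the example file names, and the `weight_map` / `metadata` layout are assumptions modeled on the PyTorch sharding format.

```python
import json
import os
import tempfile


def list_shard_files(index_path: str) -> list:
    """Return the full paths of all shards referenced by a checkpoint index file."""
    with open(index_path, "r", encoding="utf-8") as index_file:
        index = json.load(index_file)
    # Several weights can live in the same shard, so deduplicate and sort.
    shard_names = sorted(set(index["weight_map"].values()))
    folder = os.path.dirname(index_path)
    return [os.path.join(folder, name) for name in shard_names]


# Usage with a made-up index written to a temporary folder.
with tempfile.TemporaryDirectory() as tmp_dir:
    example_index = {
        "metadata": {"total_size": 12},
        "weight_map": {
            "layer_0/kernel": "shard-00001-of-00002.h5",
            "layer_0/bias": "shard-00001-of-00002.h5",
            "layer_1/kernel": "shard-00002-of-00002.h5",
        },
    }
    index_path = os.path.join(tmp_dir, "model.index.json")
    with open(index_path, "w", encoding="utf-8") as index_file:
        json.dump(example_index, index_file)
    shard_basenames = [os.path.basename(p) for p in list_shard_files(index_path)]
```

Because this logic does not depend on the framework that wrote the shards, it is a natural candidate for the shared module suggested in the review.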
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
If there are plans to move shared functionality in a future PR, I'm happy to approve 👍
Thanks a lot for working on this!
unexpected_keys = set()
# Read the H5 file
try:
    with h5py.File(resolved_archive_file, "r") as f:
Since it's used a lot below, can we give this f a better (more descriptive) name?
)
if is_sharded:
    for file in resolved_archive_file:
        assert os.path.isfile(file), f"Error retrieving files {file}"
No new asserts in the codebase ;-) Please use a test and raise the appropriate error.
The error is pretty much already handled with the OSError, as the call to load_tf_... is already in a try/except clause.
In this case you can remove the assert :-p
        ignore_mismatched_sizes=ignore_mismatched_sizes,
    )
else:
    assert os.path.isfile(resolved_archive_file), f"Error retrieving file {resolved_archive_file}"
Same here :-)
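The two threads above ask for the asserts to be replaced or dropped. If an explicit check were kept instead of relying on the surrounding try/except, it could look like the sketch below. The helper name `ensure_files_exist` is hypothetical, and raising `OSError` is an assumption based on the error type mentioned in the discussion.

```python
import os
from typing import List, Union


def ensure_files_exist(resolved_archive_file: Union[str, List[str]]) -> None:
    """Raise an OSError if any checkpoint file (or shard) is missing."""
    # A sharded checkpoint resolves to a list of files, a regular one to a single path.
    if isinstance(resolved_archive_file, list):
        files = resolved_archive_file
    else:
        files = [resolved_archive_file]
    for file in files:
        if not os.path.isfile(file):
            raise OSError(f"Error retrieving file {file}")
```

Unlike an `assert`, this check is not skipped when Python runs with the `-O` flag, and callers can catch a specific exception type.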
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* initial commit
* update modeeling tf utils
* quality
* clean and update args
* update
* remove potential bug
* code quality
* update
* update max shard
* update tests for sharding from pretrained
* fix remaining test
* make style
* h5py if tf available
* update and fix test
* fix test
* style
* modified push to hub to support shard for TF
* quick fix
* update code
* merge branch main and style
* Apply suggestions from code review
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* update based on reviews
* update doc
* update and style
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update based on reviews
* fix typo
* style
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
What does this PR do?
Introduces sharding of TF models, following the PyTorch implementation.
A simple working example is the following: