
Support FSDPv2 compute dtype #8056

Open · wants to merge 3 commits into base: master
Conversation

lausannel (Contributor)

This PR does the following:

  1. Supports using torchdistx for model initialization to reduce memory usage
  2. Supports different compute dtypes (a usage sketch follows below)
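
For illustration, here is a minimal sketch of what the compute-dtype option might look like in use. The `compute_dtype` keyword is an assumption based on this PR's title and description, not a confirmed signature, and the SPMD mesh setup is elided.

```python
import torch
import torch.nn as nn
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

# Assumes an SPMD mesh has already been configured globally, e.g. via
# torch_xla.distributed.spmd.set_global_mesh(...); omitted for brevity.
model = nn.Linear(1024, 1024)

# Hypothetical: parameters keep their original dtype, while forward and
# backward computation run in bfloat16. The keyword name is an assumption
# based on the PR title, not the finalized API.
model = FSDPv2(model, compute_dtype=torch.bfloat16)
```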

JackCaoG (Collaborator)

AFAIK torchdistx is not being maintained, has that changed?

lausannel changed the title from "Support spmd distx init and compute dtype" to "Support FSDPv2 distx init and compute dtype" on Sep 25, 2024
lausannel (Contributor, Author)

You're correct. That said, torchdistx is still in use, and FSDPv1 employs similar initialization logic:

```python
_materialize_module(
    module,
    param_init_fn,
    [],  # TODO: ignored_params is set to empty now, pass in correct params when this feature is fully enabled
    deferred_init_check_fn=lambda k: not isinstance(k, wrapper_cls))
```
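
For context, a minimal sketch of the torchdistx deferred-initialization pattern that the call above supports (assuming torchdistx is installed); the layer sizes are illustrative.

```python
import torch.nn as nn
from torchdistx.deferred_init import deferred_init, materialize_module

# Construct the module without allocating real parameter storage; tensors
# are recorded as "fake" until they are materialized.
model = deferred_init(nn.Linear, 10_000, 10_000)

# Allocate and initialize the parameters only when actually needed.
materialize_module(model)
```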

JackCaoG (Collaborator)

Actually, I would prefer not to add this logic here. PyTorch has been moving toward meta tensors to solve the parameter-initialization problem, and I'd rather not add more logic tied to a library that has been deprecated.
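
For reference, a minimal sketch of the meta-tensor approach mentioned above, using standard PyTorch APIs; the target device is illustrative.

```python
import torch
import torch.nn as nn

# Build the module on the meta device: shapes and dtypes are tracked,
# but no real memory is allocated for the parameters.
with torch.device("meta"):
    model = nn.Linear(10_000, 10_000)

# Materialize uninitialized storage on the target device, then run the
# module's usual initialization.
model = model.to_empty(device="cpu")
model.reset_parameters()
```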

lausannel (Contributor, Author)

I have removed all the torchdistx-related content as per your suggestion. Could you please review the changes again?

lausannel changed the title from "Support FSDPv2 distx init and compute dtype" to "Support FSDPv2 compute dtype" on Sep 27, 2024