
Support CE after grad acc fix #375

Merged
merged 4 commits into main from byhsu/fix-ce on Nov 12, 2024

Conversation

@ByronHsu (Collaborator) commented Nov 12, 2024

Summary

Based on #374, but leaner.

  1. The way model code uses cross entropy changed after the grad acc fix.
  2. It moved from the module CrossEntropy to the functional cross_entropy.
  3. Our monkey patching needs to change accordingly.
  4. We also preserve backward compatibility by branching on the transformers version; see the sketch below.
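
A minimal sketch of the version-gated patch. The version cutoff (4.46.0), the import paths, and the names LigerCrossEntropyLoss / liger_cross_entropy are assumptions for illustration, not necessarily the exact Liger-Kernel code:

```python
# Sketch only: the version cutoff, names, and import paths are assumptions.
import transformers
import transformers.models.llama.modeling_llama as modeling_llama
from packaging import version

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # assumed path
from liger_kernel.transformers.functional import liger_cross_entropy  # assumed path


def apply_liger_ce_patch():
    if version.parse(transformers.__version__) >= version.parse("4.46.0"):
        # Post grad-acc-fix: modeling code calls nn.functional.cross_entropy,
        # so patch the function reference the modeling module sees.
        modeling_llama.nn.functional.cross_entropy = liger_cross_entropy
    else:
        # Pre-fix: modeling code instantiates the CrossEntropyLoss module,
        # so patch the class reference instead.
        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
```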

Notable Changes

  1. Add a functional API for CE that takes keyword args (sketched below).
  2. Add back the convergence test with logits to verify CE convergence.
  3. Add back the compatibility test for transformers 4.44.
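
For (1), a hedged sketch of what such a functional wrapper could look like, assuming an autograd Function at liger_kernel.ops.cross_entropy and a reduced parameter set for illustration:

```python
import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction  # assumed path


def liger_cross_entropy(
    input: torch.Tensor,
    target: torch.Tensor,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
):
    # The new transformers call sites pass these as keyword args; the wrapper
    # accepts them by name and forwards them positionally to the kernel.
    return LigerCrossEntropyFunction.apply(
        input, target, ignore_index, label_smoothing, reduction
    )
```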

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@ByronHsu ByronHsu changed the title support CE after grad acc fix Support CE after grad acc fix Nov 12, 2024
@ByronHsu ByronHsu merged commit 5ef09d5 into main Nov 12, 2024
3 checks passed
@ByronHsu ByronHsu deleted the byhsu/fix-ce branch November 12, 2024 20:49
@hongpeng-guo (Collaborator) commented:

Thanks for providing this exemplary PR! While working on enabling kwargs for all the other operators, I ran into a few questions and would like to hear your suggestions. 😄

For the function signatures, should we try to follow their counterparts' signatures in torch.nn.functional? I found a few cases where it is hard to keep the signature compatible with torch, e.g.:

  1. rms_norm and layer_norm in torch require a non-optional normalized_shape argument that liger does not need;
  2. group_norm in liger requires a non-optional num_channels, which torch does not;
  3. gelu in liger takes two inputs a and b, while torch's takes only one.

There are also ops that are not available in torch.nn.functional at all, e.g. the fused ops and the jsd ops (see the comparison sketch below).
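
For concreteness, here is a comparison of the mismatches above. The torch signatures are printed directly from torch.nn.functional; the liger signatures are approximate sketches, not the exact Liger-Kernel definitions:

```python
import inspect
import torch.nn.functional as F

# torch's signatures, printed for reference:
print(inspect.signature(F.layer_norm))
# -> (input, normalized_shape, weight=None, bias=None, eps=1e-05)
print(inspect.signature(F.group_norm))
# -> (input, num_groups, weight=None, bias=None, eps=1e-05)

# Liger's counterparts (approximate sketches):
#   layer_norm(X, W, B, eps)                 # no normalized_shape: inferred from W
#   group_norm(X, W, B, num_channels, num_groups, eps)  # num_channels required
#   gelu(a, b)                               # two inputs (computes a * gelu(b));
#                                            # torch's F.gelu(input) takes one
```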

What would be a good strategy for redefining the function signatures here? cc @ByronHsu

@ByronHsu (Collaborator, Author) commented:

Thanks @hongpeng-guo! I think for now we don't need to be restricted by torch, since some of the layers, like rmsnorm, are not actually taken from torch. Let's just expose whatever args liger has as kwargs and not worry about torch; a sketch follows below. We can make adjustments once we receive more community feedback.
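
A minimal sketch of that suggestion applied to rms_norm, exposing Liger's own arguments as keyword args rather than mirroring torch. The import path, parameter names, and defaults (eps, offset, casting_mode) are assumptions for illustration, not the exact Liger-Kernel API:

```python
import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # assumed path


def liger_rms_norm(
    X: torch.Tensor,
    W: torch.Tensor,
    eps: float = 1e-6,
    offset: float = 0.0,          # assumed: some models add an offset to W
    casting_mode: str = "llama",  # assumed: controls dtype handling in the kernel
):
    # Unlike torch's F.rms_norm, there is no normalized_shape argument:
    # the kernel infers the normalized dimension from the weight tensor W.
    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode)
```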
