
Fix Unit Test error brought by Transformer Breaking Changes #523

Closed
wants to merge 6 commits

Conversation

hebiao064 (Collaborator) commented Jan 15, 2025

Summary

There is a breaking change in HF Transformers that breaks our unit tests:

https://github.com/huggingface/transformers/blame/137965ca7d0b453b22806c0a5fffc51fde821c33/src/transformers/models/llama/modeling_llama.py#L85

For example, it broke my KTO Loss PR, which has nothing to do with the failing test `test/transformers/test_rope.py::test_correctness`:

https://github.com/linkedin/Liger-Kernel/actions/runs/12779518445/job/35624270834?pr=475

Testing Done

Unit tests are done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

ByronHsu previously approved these changes Jan 15, 2025

ByronHsu (Collaborator):

@hebiao064 can you make a condition for tests-bwd to pass with older transformers versions? Some users are on older transformers versions.

Something like

if transformer_version < ...:
  ...
else:
  ...
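A minimal sketch of that gate (the `4.48.0` cutoff is taken from the diff later in this PR; the helper name `uses_config_ctor` is illustrative, and in real code `packaging.version` would be the more robust parser):

```python
def uses_config_ctor(installed: str, changed_in: str = "4.48.0") -> bool:
    """True if this transformers version expects a LlamaConfig in the
    LlamaRotaryEmbedding constructor (signature changed in 4.48.0).

    Naive dotted-integer parsing for illustration only; prefer
    packaging.version.parse for pre-release/dev suffixes.
    """
    as_tuple = lambda v: tuple(int(p) for p in v.split("."))
    return as_tuple(installed) >= as_tuple(changed_in)

# In the test, branch on the installed version, e.g.:
# if uses_config_ctor(transformers.__version__):
#     rotary_emb = LlamaRotaryEmbedding(LlamaConfig(head_dim=head_dim), device=device)
# else:
#     rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
```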

hebiao064 (Collaborator, Author) replied:

Makes sense, added.

@hebiao064 hebiao064 enabled auto-merge (squash) January 15, 2025 19:00
Comment on lines +117 to +122

    if transformers_version < "4.48.0":
        # LlamaRotaryEmbedding constructor signature changed in transformers 4.48.0
        rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
    else:
        llama_config = LlamaConfig(head_dim=head_dim)
        rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)
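One caveat with the snippet above: comparing version strings lexicographically can misorder releases, since `"4.9.0"` sorts after `"4.48.0"` as plain text. A minimal illustration of the pitfall and the tuple-based fix:

```python
# Lexicographic string comparison goes character by character,
# so '9' > '4' makes "4.9.0" sort AFTER "4.48.0" -- the wrong order.
print("4.9.0" < "4.48.0")  # False as strings, though 4.9.0 is the older release

# Comparing numeric component tuples gives the intended ordering.
as_tuple = lambda v: tuple(int(p) for p in v.split("."))
print(as_tuple("4.9.0") < as_tuple("4.48.0"))  # True
```

In practice `packaging.version.parse` handles this plus pre-release/dev suffixes, which the naive tuple parser does not.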
kvignesh1420 (Collaborator):

Seems like we are repeating the same code in all places. Shall we prepare a common function to get the rotary embeddings?

hebiao064 (Collaborator, Author):

hey @kvignesh1420 thanks for the great advice, I added a more general helper function here: #526.

austin362667 added a commit that referenced this pull request Jan 20, 2025
## Summary

1. Add a general `version_dispatch` utility function that selects constructors and their args based on version comparisons.
2. Update `LlamaRotaryEmbedding`.

Closes Issue #525 and PR #523 (thanks to @hebiao064's nice work).

## Testing Done

- Hardware Type: <BLANK>
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <austin362667@gmail.com>
austin362667 (Collaborator) commented Jan 20, 2025

Closing as it was finished in #526.

auto-merge was automatically disabled January 20, 2025 03:03

Pull request was closed
