forked from linkedin/Liger-Kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support CE after grad acc fix (linkedin#375)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> Based on linkedin#374, but make it leaner 1. The use of cross entropy in model code has changed after grad fix 2. It changed from module CrossEntropy to functional cross_entropy 3. Our monkey patching needs to change accordingly 4. While also make sure backward compatibility by adding a condition for different versions Notable Changes 1. Add a functional api for CE to take keyword args 2. Add back conv test with logits to test CE convergence 3. Add back comp test for transformers 4.44 ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
- Loading branch information
Showing
8 changed files
with
881 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from pathlib import Path | ||
|
||
import modal | ||
|
||
ROOT_PATH = Path(__file__).parent.parent.parent | ||
|
||
# tests_bwd is to ensure the backward compatibility of liger with older transformers | ||
image = ( | ||
modal.Image.debian_slim() | ||
.pip_install_from_pyproject( | ||
ROOT_PATH / "pyproject.toml", optional_dependencies=["dev"] | ||
) | ||
.pip_install("transformers==4.44.2") | ||
) | ||
|
||
app = modal.App("liger_tests", image=image) | ||
|
||
# mount: add local files to the remote container | ||
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel") | ||
|
||
|
||
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10) | ||
def liger_tests(): | ||
import subprocess | ||
|
||
subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel") | ||
subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel") | ||
subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.