Commit 3afa73e: check python environment
tjtanaa committed Dec 6, 2024
1 parent 333d6ba commit 3afa73e
Showing 2 changed files with 11 additions and 8 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/amd-ci.yml
@@ -59,19 +59,21 @@ jobs:
       run: |
         cp -r /opt/rocm/share/amd_smi ./
         cd amd_smi
-        python3 -m pip install -e .
+        python -m pip install -e .
         cd ..
-        python3 -m pip install pytest pytest-xdist pytest-rerunfailures pytest-flakefinder pytest-cpp
-        python3 -m pip uninstall -y torch torchvision
-        python3 -m pip install --pre \
+        python -m pip install pytest pytest-xdist pytest-rerunfailures pytest-flakefinder pytest-cpp
+        python -m pip uninstall -y torch torchvision
+        python -m pip install --pre \
           torch==2.6.0.dev20241113+rocm6.2 \
           'setuptools-scm>=8' \
           torchvision==0.20.0.dev20241113+rocm6.2 \
           --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
-        python3 -m pip install triton==3.1.0 transformers==4.46.3
-        python3 -m pip install -e .[dev]
+        python -m pip install triton==3.1.0 transformers==4.46.3
+        python -m pip install -e .[dev]
+        python -m pip list
     - name: Run Unit Tests
       run: |
+        python -m pip list
         make test
         make test-convergence
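
The two new `python -m pip list` lines print the resolved package set into the CI log, which is what the commit title, "check python environment", refers to. As a local counterpart, here is a minimal sketch (not part of the commit; the expected version strings are taken from the pins above) that verifies the installed torch is the pinned ROCm nightly:

import torch

print(torch.__version__)   # expected: 2.6.0.dev20241113+rocm6.2
print(torch.version.hip)   # a HIP version string on ROCm builds, None on CUDA builds

assert torch.__version__.endswith("+rocm6.2"), "torch is not a ROCm 6.2 build"
assert torch.version.hip is not None, "expected a HIP-enabled torch build"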
5 changes: 3 additions & 2 deletions test/transformers/test_cross_entropy.py
@@ -12,6 +12,7 @@
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.utils import infer_device
+from liger_kernel.ops.utils import is_hip

 device = infer_device()
 set_seed(42)
@@ -763,7 +764,7 @@ def test_float32_internal():
         RETURN_Z_LOSS=0, # False
         HAS_SOFTCAPPING=False,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32 if torch.hip.version is None else 16,
+        num_warps=32 if not is_hip() else 16,
     )

     # Run kernel for float32
@@ -787,7 +788,7 @@ def test_float32_internal():
         RETURN_Z_LOSS=0, # False
         HAS_SOFTCAPPING=False,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32 if torch.hip.version is None else 16,
+        num_warps=32 if not is_hip() else 16,
     )

     torch.allclose(X_bf16, X_fp32.bfloat16())
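
The removed expression `torch.hip.version` raises AttributeError on every PyTorch build, because torch has no `hip` submodule; the HIP version lives at `torch.version.hip` (a version string on ROCm builds, None on CUDA builds). The commit replaces it with the project's `is_hip()` helper. A minimal sketch of what such a helper presumably looks like (an assumption; the real implementation in liger_kernel.ops.utils may differ):

import torch

def is_hip() -> bool:
    # torch.version.hip is a version string on ROCm/HIP builds and None otherwise.
    return torch.version.hip is not None

The HIP branch likely uses 16 warps because AMD wavefronts are 64 threads wide, so 32 warps would request 2048 threads and exceed the usual 1024-thread workgroup limit, while 32 warps of 32 threads fit exactly on CUDA hardware.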