diff --git a/.github/workflows/amd-ci.yml b/.github/workflows/amd-ci.yml index 74e454706..6e95d65ee 100644 --- a/.github/workflows/amd-ci.yml +++ b/.github/workflows/amd-ci.yml @@ -56,35 +56,9 @@ jobs: with: python-version: '3.10' - - name: Check Docker Version - run: docker version - - - name: Check Ubuntu version - run: lsb_release -a - - - name: Check Hardware Specs - run: lscpu - - - name: ROCM-SMI Output - run: | - rocm-smi - rocm-smi --showproductname - - name: Setup Dependencies run: | - cp -r /opt/rocm/share/amd_smi ./ - cd amd_smi - python -m pip install -e . - cd .. - python -m pip install pytest pytest-xdist pytest-rerunfailures pytest-flakefinder pytest-cpp - python -m pip uninstall -y torch torchvision - python -m pip install --pre \ - torch==2.6.0.dev20241113+rocm6.2 \ - 'setuptools-scm>=8' \ - torchvision==0.20.0.dev20241113+rocm6.2 \ - --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 - python -m pip install triton==3.1.0 transformers==4.46.3 - python -m pip install -e .[dev] + python -m pip install -e .[dev,amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 - name: List Python Environments run: python -m pip list diff --git a/README.md b/README.md index 29800cd3d..417e33523 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ To install from source: git clone https://github.com/linkedin/Liger-Kernel.git cd Liger-Kernel pip install -e . +# or if installing on amd platform +pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2 # or if using transformers pip install -e .[transformers] ``` diff --git a/pyproject.toml b/pyproject.toml index fd76bdee3..c285d26fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "triton>=2.3.1", ] + [project.optional-dependencies] transformers = [ "transformers~=4.0" @@ -27,11 +28,22 @@ dev = [ "black>=24.4.2", "isort>=5.13.2", "pytest>=7.1.2", + "pytest-xdist", + "pytest-rerunfailures", "datasets>=2.19.2", "torchvision>=0.16.2", "seaborn", ] +amd = [ + "torch>=2.6.0.dev", + "setuptools-scm>=8", + "torchvision>=0.20.0.dev", + "triton>=3.0.0", +] + + + [tool.setuptools.packages.find] where = ["src"] include = ["liger_kernel", "liger_kernel.*"]