Skip to content

Commit

Permalink
[rust] Update gpu build pipeline to cu122 (#3334)
Browse files Browse the repository at this point in the history
* [rust] Update gpu build pipeline to cu122

* Update

* Update .github/workflows/native_s3_huggingface.yml

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>

* Update

* Update

* Update

---------

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
xyang16 and frankfliu authored Jul 12, 2024
1 parent b171572 commit d55e84c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 16 deletions.
13 changes: 7 additions & 6 deletions .github/workflows/native_s3_huggingface.yml
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,16 @@ jobs:
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-aarch64 s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/$DJL_VERSION/linux-aarch64/
aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*"
build-tokenizers-jni-cu124:
build-tokenizers-jni-cu122:
if: github.repository == 'deepjavalibrary/djl'
runs-on: [ self-hosted, g5 ]
timeout-minutes: 30
needs: create-runners
container:
image: nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04
image: nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04
options: --gpus all --runtime=nvidia
env:
CUDA_VERSION: cu122
steps:
- name: Install Environment
run: |
Expand Down Expand Up @@ -254,9 +256,8 @@ jobs:
${{ runner.os }}-gradle-
- name: Release JNI prep
run: |
CUDA_VERSION=cu124
. "$HOME/.cargo/env"
./gradlew :extensions:tokenizers:compileJNI -Pcuda=$CUDA_VERSION
./gradlew :extensions:tokenizers:compileJNI -Pcuda=${{ env.CUDA_VERSION }}
./gradlew -Pjni :extensions:tokenizers:test
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@v2
Expand All @@ -268,13 +269,13 @@ jobs:
run: |
DJL_VERSION=$(awk -F '=' '/djl / {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)
TOKENIZERS_VERSION="$(awk -F '=' '/tokenizers/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)"
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/cu124 s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/$DJL_VERSION/linux-x86_64/cu124/
aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/${{ env.CUDA_VERSION }} s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/${DJL_VERSION}/linux-x86_64/${{ env.CUDA_VERSION }}/
aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*"
stop-runners:
if: ${{ github.repository == 'deepjavalibrary/djl' && always() }}
runs-on: [ self-hosted, scheduler ]
needs: [ create-runners, build-tokenizer-jni-aarch64, build-tokenizers-jni-cu124 ]
needs: [ create-runners, build-tokenizer-jni-aarch64, build-tokenizers-jni-cu122 ]
steps:
- name: Stop all instances
run: |
Expand Down
6 changes: 4 additions & 2 deletions extensions/tokenizers/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ javac -sourcepath src/main/java/ src/main/java/ai/djl/huggingface/tokenizers/jni
javac -sourcepath src/main/java/ src/main/java/ai/djl/engine/rust/RustLibrary.java -h build/include -d build/classes

RUST_MANIFEST=rust/Cargo.toml
if [ -x "$(command -v nvcc)" ]; then
cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn,cublaslt
if [[ "$FLAVOR" = "cpu"* ]]; then
cargo build --manifest-path $RUST_MANIFEST --release
elif [[ "$FLAVOR" = "cu"* && "$FLAVOR" > "cu121" ]]; then
cargo build --manifest-path $RUST_MANIFEST --release --features cuda,cublaslt,flash-attn
else
cargo build --manifest-path $RUST_MANIFEST --release
fi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public final class LibUtils {
private static final Pattern VERSION_PATTERN =
Pattern.compile(
"(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)-(\\d+\\.\\d+\\.\\d+)(-SNAPSHOT)?(-\\d+)?");
private static final String FLAVOR_CU124 = "cu124";
private static final int[] SUPPORTED_CUDA_VERSIONS = {122};

private static EngineException exception;

Expand Down Expand Up @@ -89,29 +89,68 @@ private static Path copyJniLibrary(String[] libs) {
Platform platform = Platform.detectPlatform("tokenizers");
String os = platform.getOsPrefix();
String classifier = platform.getClassifier();
String flavor = platform.getFlavor();
String version = platform.getVersion();
String flavor = Utils.getEnvOrSystemProperty("TOKENIZERS_FLAVOR");
boolean override = flavor != null && !flavor.isEmpty();
if (override) {
logger.info("Uses override TOKENIZERS_FLAVOR: {}", flavor);
} else {
if (Utils.isOfflineMode() || "win".equals(os)) {
flavor = "cpu";
} else {
flavor = platform.getFlavor();
}
}

// Find the highest matching CUDA version
if (flavor.startsWith("cu")) {
int cudaVersion = Integer.parseInt(flavor.substring(2, 5));
boolean match = false;
for (int v : SUPPORTED_CUDA_VERSIONS) {
if (override && cudaVersion == v) {
match = true;
break;
} else if (cudaVersion >= v) {
flavor = "cu" + v;
match = true;
break;
}
}
if (!match) {
logger.warn("No matching cuda flavor for {} found: {}.", classifier, flavor);
flavor = "cpu"; // Fallback to CPU
}
}

Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
Path path = dir.resolve(LIB_NAME);
logger.debug("Using cache dir: {}", dir);
if (Files.exists(path)) {
return dir.toAbsolutePath();
}

// For Linux cuda 12.x, download JNI library
if (flavor.startsWith("cu12") && !"win".equals(os)) {
// Copy JNI library from classpath
if (copyJniLibraryFromClasspath(libs, cacheDir, dir, classifier, flavor)) {
return dir.toAbsolutePath();
}

// Download JNI library
if (flavor.startsWith("cu")) {
Matcher matcher = VERSION_PATTERN.matcher(version);
if (!matcher.matches()) {
throw new EngineException("Unexpected version: " + version);
}
String jniVersion = matcher.group(1);
String djlVersion = matcher.group(3);

downloadJniLib(dir, path, djlVersion, jniVersion, classifier, FLAVOR_CU124);
downloadJniLib(dir, path, djlVersion, jniVersion, classifier, flavor);
return dir.toAbsolutePath();
}
return null;
}

// Extract JNI library from classpath
private static boolean copyJniLibraryFromClasspath(
String[] libs, Path cacheDir, Path dir, String classifier, String flavor) {
Path tmp = null;
try {
Files.createDirectories(cacheDir);
Expand All @@ -126,14 +165,15 @@ private static Path copyJniLibrary(String[] libs) {
}
}
Utils.moveQuietly(tmp, dir);
return dir.toAbsolutePath();
return true;
} catch (IOException e) {
throw new IllegalStateException("Cannot copy jni files", e);
logger.error("Cannot copy jni files", e);
} finally {
if (tmp != null) {
Utils.deleteQuietly(tmp);
}
}
return false;
}

private static void downloadJniLib(
Expand Down

0 comments on commit d55e84c

Please sign in to comment.