diff --git a/.github/workflows/native_s3_huggingface.yml b/.github/workflows/native_s3_huggingface.yml index 5d26195de03..3c54d028e18 100644 --- a/.github/workflows/native_s3_huggingface.yml +++ b/.github/workflows/native_s3_huggingface.yml @@ -216,14 +216,16 @@ jobs: aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-aarch64 s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/$DJL_VERSION/linux-aarch64/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*" - build-tokenizers-jni-cu124: + build-tokenizers-jni-cu122: if: github.repository == 'deepjavalibrary/djl' runs-on: [ self-hosted, g5 ] timeout-minutes: 30 needs: create-runners container: - image: nvidia/cuda:12.4.1-cudnn-devel-ubuntu20.04 + image: nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04 options: --gpus all --runtime=nvidia + env: + CUDA_VERSION: cu122 steps: - name: Install Environment run: | @@ -254,9 +256,8 @@ jobs: ${{ runner.os }}-gradle- - name: Release JNI prep run: | - CUDA_VERSION=cu124 . "$HOME/.cargo/env" - ./gradlew :extensions:tokenizers:compileJNI -Pcuda=$CUDA_VERSION + ./gradlew :extensions:tokenizers:compileJNI -Pcuda=${{ env.CUDA_VERSION }} ./gradlew -Pjni :extensions:tokenizers:test - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v2 @@ -268,13 +269,13 @@ jobs: run: | DJL_VERSION=$(awk -F '=' '/djl / {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml) TOKENIZERS_VERSION="$(awk -F '=' '/tokenizers/ {gsub(/ ?"/, "", $2); print $2}' gradle/libs.versions.toml)" - aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/cu124 s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/$DJL_VERSION/linux-x86_64/cu124/ + aws s3 sync extensions/tokenizers/jnilib/$DJL_VERSION/linux-x86_64/${{ env.CUDA_VERSION }} s3://djl-ai/publish/tokenizers/${TOKENIZERS_VERSION}/jnilib/${DJL_VERSION}/linux-x86_64/${{ env.CUDA_VERSION }}/ aws cloudfront create-invalidation --distribution-id E371VB8JQ6NRVY --paths "/tokenizers/${TOKENIZERS_VERSION}/jnilib/*" stop-runners: if: ${{ github.repository == 'deepjavalibrary/djl' && always() }} runs-on: [ self-hosted, scheduler ] - needs: [ create-runners, build-tokenizer-jni-aarch64, build-tokenizers-jni-cu124 ] + needs: [ create-runners, build-tokenizer-jni-aarch64, build-tokenizers-jni-cu122 ] steps: - name: Stop all instances run: | diff --git a/extensions/tokenizers/build.sh b/extensions/tokenizers/build.sh index c3cbf793d70..58c4df4f017 100755 --- a/extensions/tokenizers/build.sh +++ b/extensions/tokenizers/build.sh @@ -29,8 +29,10 @@ javac -sourcepath src/main/java/ src/main/java/ai/djl/huggingface/tokenizers/jni javac -sourcepath src/main/java/ src/main/java/ai/djl/engine/rust/RustLibrary.java -h build/include -d build/classes RUST_MANIFEST=rust/Cargo.toml -if [ -x "$(command -v nvcc)" ]; then - cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn,cublaslt +if [[ "$FLAVOR" = "cpu"* ]]; then + cargo build --manifest-path $RUST_MANIFEST --release +elif [[ "$FLAVOR" = "cu"* && "$FLAVOR" > "cu121" ]]; then + cargo build --manifest-path $RUST_MANIFEST --release --features cuda,cublaslt,flash-attn else cargo build --manifest-path $RUST_MANIFEST --release fi diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java index 684504de47a..2caa6e9312e 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/LibUtils.java @@ -38,7 +38,7 @@ public final class LibUtils { private static final Pattern VERSION_PATTERN = Pattern.compile( "(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)-(\\d+\\.\\d+\\.\\d+)(-SNAPSHOT)?(-\\d+)?"); - private static final String FLAVOR_CU124 = "cu124"; + private static final int[] SUPPORTED_CUDA_VERSIONS = {122}; private static EngineException exception; @@ -89,8 +89,39 @@ private static Path copyJniLibrary(String[] libs) { Platform platform = Platform.detectPlatform("tokenizers"); String os = platform.getOsPrefix(); String classifier = platform.getClassifier(); - String flavor = platform.getFlavor(); String version = platform.getVersion(); + String flavor = Utils.getEnvOrSystemProperty("TOKENIZERS_FLAVOR"); + boolean override = flavor != null && !flavor.isEmpty(); + if (override) { + logger.info("Uses override TOKENIZERS_FLAVOR: {}", flavor); + } else { + if (Utils.isOfflineMode() || "win".equals(os)) { + flavor = "cpu"; + } else { + flavor = platform.getFlavor(); + } + } + + // Find the highest matching CUDA version + if (flavor.startsWith("cu")) { + int cudaVersion = Integer.parseInt(flavor.substring(2, 5)); + boolean match = false; + for (int v : SUPPORTED_CUDA_VERSIONS) { + if (override && cudaVersion == v) { + match = true; + break; + } else if (cudaVersion >= v) { + flavor = "cu" + v; + match = true; + break; + } + } + if (!match) { + logger.warn("No matching cuda flavor for {} found: {}.", classifier, flavor); + flavor = "cpu"; // Fallback to CPU + } + } + Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier); Path path = dir.resolve(LIB_NAME); logger.debug("Using cache dir: {}", dir); @@ -98,8 +129,13 @@ private static Path copyJniLibrary(String[] libs) { return dir.toAbsolutePath(); } - // For Linux cuda 12.x, download JNI library - if (flavor.startsWith("cu12") && !"win".equals(os)) { + // Copy JNI library from classpath + if (copyJniLibraryFromClasspath(libs, cacheDir, dir, classifier, flavor)) { + return dir.toAbsolutePath(); + } + + // Download JNI library + if (flavor.startsWith("cu")) { Matcher matcher = VERSION_PATTERN.matcher(version); if (!matcher.matches()) { throw new EngineException("Unexpected version: " + version); @@ -107,11 +143,14 @@ private static Path copyJniLibrary(String[] libs) { String jniVersion = matcher.group(1); String djlVersion = matcher.group(3); - downloadJniLib(dir, path, djlVersion, jniVersion, classifier, FLAVOR_CU124); + downloadJniLib(dir, path, djlVersion, jniVersion, classifier, flavor); return dir.toAbsolutePath(); } + return null; + } - // Extract JNI library from classpath + private static boolean copyJniLibraryFromClasspath( + String[] libs, Path cacheDir, Path dir, String classifier, String flavor) { Path tmp = null; try { Files.createDirectories(cacheDir); @@ -126,14 +165,15 @@ private static Path copyJniLibrary(String[] libs) { } } Utils.moveQuietly(tmp, dir); - return dir.toAbsolutePath(); + return true; } catch (IOException e) { - throw new IllegalStateException("Cannot copy jni files", e); + logger.error("Cannot copy jni files", e); } finally { if (tmp != null) { Utils.deleteQuietly(tmp); } } + return false; } private static void downloadJniLib(