Skip to content

Commit

Permalink
Workaround #179, allow load shared library from different ClassLoader
Browse files Browse the repository at this point in the history
Change-Id: I60ba3469cc841c2bdb2d1696f2b0926c11db1f37
  • Loading branch information
frankfliu committed Oct 21, 2020
1 parent 518a08a commit 9106f95
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -71,7 +72,7 @@ public static void loadLibrary() {
if (System.getProperty("os.name").startsWith("Win")) {
loadWinDependencies(libName);
}
System.load(libName); // NOPMD
loadNativeLibrary(libName);
}

public static String getLibName() {
Expand Down Expand Up @@ -107,14 +108,14 @@ private static void loadWinDependencies(String libName) {
})
.map(path -> path.toAbsolutePath().toString())
.forEach(System::load);
System.load(libDir.resolve("fbgemm.dll").toAbsolutePath().toString());
System.load(libDir.resolve("torch_cpu.dll").toAbsolutePath().toString());
loadNativeLibrary(libDir.resolve("fbgemm.dll").toAbsolutePath().toString());
loadNativeLibrary(libDir.resolve("torch_cpu.dll").toAbsolutePath().toString());
if (Files.exists(libDir.resolve("c10_cuda.dll"))) {
// Windows System.load is global load
System.load(libDir.resolve("c10_cuda.dll").toAbsolutePath().toString());
System.load(libDir.resolve("torch_cuda.dll").toAbsolutePath().toString());
loadNativeLibrary(libDir.resolve("c10_cuda.dll").toAbsolutePath().toString());
loadNativeLibrary(libDir.resolve("torch_cuda.dll").toAbsolutePath().toString());
}
System.load(libDir.resolve("torch.dll").toAbsolutePath().toString());
loadNativeLibrary(libDir.resolve("torch.dll").toAbsolutePath().toString());
} catch (IOException e) {
throw new IllegalArgumentException("Folder not exist! " + libDir, e);
}
Expand Down Expand Up @@ -291,6 +292,20 @@ private static String copyNativeLibraryFromClasspath(Platform platform) {
}
}

private static void loadNativeLibrary(String path) {
String nativeHelper = System.getProperty("ai.djl.pytorch.native_helper");
if (nativeHelper != null && !nativeHelper.isEmpty()) {
try {
Class<?> clazz = Class.forName(nativeHelper);
Method method = clazz.getDeclaredMethod("load", String.class);
method.invoke(null, path);
} catch (ReflectiveOperationException e) {
throw new IllegalArgumentException("Invalid native_helper: " + nativeHelper, e);
}
}
System.load(path); // NOPMD
}

private static String downloadPyTorch(Platform platform, AtomicBoolean fallback)
throws IOException {
String version = platform.getVersion();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.integration;

import ai.djl.engine.Engine;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

public class LibUtilsTest {

@BeforeClass
public void setup() {
System.setProperty(
"ai.djl.pytorch.native_helper", "ai.djl.pytorch.integration.LibUtilsTest");
}

@AfterClass
public void teardown() {
System.setProperty("ai.djl.pytorch.native_helper", "");
}

@Test
public void test() {
Engine.getInstance();
}

public static void load(String path) {
System.load(path); // NOPMD
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* and limitations under the License.
*/

package integration;
package ai.djl.pytorch.integration;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* and limitations under the License.
*/

package integration;
package ai.djl.pytorch.integration;

import ai.djl.Application;
import ai.djl.MalformedModelException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
* and limitations under the License.
*/
/** The integration test for testing PyTorch specific features. */
package integration;
package ai.djl.pytorch.integration;
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.jni;

/** A helper class allows engine shared library to be loaded from different class loader. */
public final class NativeHelper {

private NativeHelper() {}

/**
* Load native shared library from file.
*
* @param path the file to load
*/
public static void load(String path) {
System.load(path); // NOPMD
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
/** Contains helper class to load native shared library. */
package ai.djl.pytorch.jni;

0 comments on commit 9106f95

Please sign in to comment.