From b316464b7ea9f445140e274adbd9cb2c90180ff3 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 5 Dec 2023 12:46:53 -0800 Subject: [PATCH 1/6] Creates DJL manual engine initialization fixes #2875 reverts #2876 This adds new support for DJL manual initialization of engines to support `DJL_ENGINE_MANUAL_INIT`. Once done, no engines providers will be found or loaded on startup. Instead, they can be added manually by: ```java PtEngineProvider provider = new PtEngineProvider(); provider.getEngine(); // Optional, throws exception if the provider can not load Engine.registerEngine(provider); Engine.setDefaultEngine(provider.getEngineName()); // Optional, sets as default ``` --- api/src/main/java/ai/djl/engine/Engine.java | 49 ++++++++++++++----- docs/development/troubleshooting.md | 5 ++ .../djl/pytorch/engine/PtEngineProvider.java | 15 ++---- 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 8a1fc8871ac..47a6e50f386 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -59,7 +59,7 @@ public abstract class Engine { private static final Map ALL_ENGINES = new ConcurrentHashMap<>(); - private static final String DEFAULT_ENGINE = initEngine(); + private static String defaultEngine = initEngine(); private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE); @@ -69,6 +69,10 @@ public abstract class Engine { private Integer seed; private static synchronized String initEngine() { + if (Boolean.parseBoolean(Utils.getenv("DJL_ENGINE_MANUAL_INIT"))) { + return null; + } + ServiceLoader loaders = ServiceLoader.load(EngineProvider.class); for (EngineProvider provider : loaders) { registerEngine(provider); @@ -80,21 +84,21 @@ private static synchronized String initEngine() { } String def = System.getProperty("ai.djl.default_engine"); - String defaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); - if (defaultEngine == null || defaultEngine.isEmpty()) { + String newDefaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); + if (newDefaultEngine == null || newDefaultEngine.isEmpty()) { int rank = Integer.MAX_VALUE; for (EngineProvider provider : ALL_ENGINES.values()) { if (provider.getEngineRank() < rank) { - defaultEngine = provider.getEngineName(); + newDefaultEngine = provider.getEngineName(); rank = provider.getEngineRank(); } } - } else if (!ALL_ENGINES.containsKey(defaultEngine)) { - throw new EngineException("Unknown default engine: " + defaultEngine); + } else if (!ALL_ENGINES.containsKey(newDefaultEngine)) { + throw new EngineException("Unknown default engine: " + newDefaultEngine); } - logger.debug("Found default engine: {}", defaultEngine); - Ec2Utils.callHome(defaultEngine); - return defaultEngine; + logger.debug("Found default engine: {}", newDefaultEngine); + Ec2Utils.callHome(newDefaultEngine); + return newDefaultEngine; } /** @@ -124,7 +128,7 @@ private static synchronized String initEngine() { * @return the default Engine name */ public static String getDefaultEngineName() { - return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE); + return System.getProperty("ai.djl.default_engine", defaultEngine); } /** @@ -134,7 +138,7 @@ public static String getDefaultEngineName() { * @see EngineProvider */ public static Engine getInstance() { - if (DEFAULT_ENGINE == null) { + if (defaultEngine == null) { throw new EngineException( "No deep learning engine found." + System.lineSeparator() @@ -166,6 +170,29 @@ public static void registerEngine(EngineProvider provider) { ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider); } + /** + * Returns the default engine. + * + * @return the default engine + */ + public static String getDefaultEngine() { + return defaultEngine; + } + + /** + * Sets the default engine returned by {@link #getInstance()}. + * + * @param engineName the new default engine's name + */ + public static void setDefaultEngine(String engineName) { + Engine engine = getEngine(engineName); + if (engine != null) { + // Requires an engine to be loaded (without exception) before being the default + logger.debug("Setting new default engine: {}", engineName); + defaultEngine = engineName; + } + } + /** * Returns a set of engine names that are loaded. * diff --git a/docs/development/troubleshooting.md b/docs/development/troubleshooting.md index ff03d32648e..1a04592dc12 100644 --- a/docs/development/troubleshooting.md +++ b/docs/development/troubleshooting.md @@ -105,6 +105,11 @@ For more information, please refer to [DJL Cache Management](cache_management.md It happened when you had a wrong version with DJL and Deep Engines. You can check the combination [here](dependency_management.md) and use DJL BOM to solve the issue. +### 1.6 Manual initialization + +If you are using manual engine initialization, you must both register an engine and set it as the default. +This can be done with `Engine.registerEngine(..)` and `Engine.setDefaultEngine(..)`. + ## 2. IntelliJ throws the `No Log4j 2 configuration file found.` exception. The following exception may appear after running the `./gradlew clean` command: diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 42ca3c5b8a5..1b9cdd0ab19 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,8 +18,6 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,13 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (PtEngineProvider.class) { - if (engine == null) { - engine = PtEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = PtEngine.newInstance(); } } From 134c581093bbf906efc1a193b6d6c216e8401a53 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 5 Dec 2023 12:59:13 -0800 Subject: [PATCH 2/6] Revert "[tensorflow] Revert InstanceHolder for TensorFlow engine (#2884)" This reverts commit 586bb0709dbde3950509f217c80a9c6b829a06fd. --- .../djl/tensorflow/engine/TfEngineProvider.java | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index ad440a47951..f42f691d222 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,8 +18,6 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD - /** {@inheritDoc} */ @Override public String getEngineName() { @@ -35,13 +33,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { - synchronized (TfEngineProvider.class) { - if (engine == null) { - engine = TfEngine.newInstance(); - } - } - } - return engine; + return InstanceHolder.INSTANCE; + } + + private static class InstanceHolder { + static final Engine INSTANCE = TfEngine.newInstance(); } } From 0be9c2587377e9eebb074074b9287c90a6e832eb Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 5 Dec 2023 15:15:19 -0800 Subject: [PATCH 3/6] Revert "[api] Replace double-check singlton with lazy initialization (#2826)" This reverts commit 39278672bebe9c5f9d590d73914a2a9e2f53fb91. --- .../java/ai/djl/ml/lightgbm/LgbmEngineProvider.java | 13 ++++++++----- .../java/ai/djl/ml/xgboost/XgbEngineProvider.java | 13 ++++++++----- .../java/ai/djl/mxnet/engine/MxEngineProvider.java | 13 ++++++++----- .../djl/onnxruntime/engine/OrtEngineProvider.java | 13 ++++++++----- .../djl/paddlepaddle/engine/PpEngineProvider.java | 13 ++++++++----- .../ai/djl/pytorch/engine/PtEngineProvider.java | 13 ++++++++----- .../ai/djl/tensorflow/engine/TfEngineProvider.java | 13 ++++++++----- .../ai/djl/tensorrt/engine/TrtEngineProvider.java | 13 ++++++++----- .../java/ai/djl/tensorrt/engine/TrtEngineTest.java | 2 +- .../ai/djl/tensorrt/engine/TrtNDManagerTest.java | 2 +- .../java/ai/djl/tensorrt/integration/TrtTest.java | 6 +++--- .../ai/djl/tflite/engine/TfLiteEngineProvider.java | 13 ++++++++----- 12 files changed, 77 insertions(+), 50 deletions(-) diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index f8c84c753ef..a253ce3d246 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,6 +18,8 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = LgbmEngine.newInstance(); + if (engine == null) { + synchronized (LgbmEngineProvider.class) { + engine = LgbmEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 5859f3f344d..19cba32cc71 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,6 +18,8 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */ public class XgbEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = XgbEngine.newInstance(); + if (engine == null) { + synchronized (XgbEngineProvider.class) { + engine = XgbEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index 5f45116f615..f30a6a89252 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,6 +18,8 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */ public class MxEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = MxEngine.newInstance(); + if (engine == null) { + synchronized (MxEngineProvider.class) { + engine = MxEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index 005c0fa25f1..c673b3dcbf1 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,6 +18,8 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */ public class OrtEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = OrtEngine.newInstance(); + if (engine == null) { + synchronized (OrtEngineProvider.class) { + engine = OrtEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index 59e5cd90724..e2b5bdd35a0 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,6 +18,8 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */ public class PpEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = PpEngine.newInstance(); + if (engine == null) { + synchronized (PpEngineProvider.class) { + engine = PpEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 1b9cdd0ab19..57ae6c09d34 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,6 +18,8 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = PtEngine.newInstance(); + if (engine == null) { + synchronized (PtEngineProvider.class) { + engine = PtEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index f42f691d222..d964ea5c295 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,6 +18,8 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TfEngine.newInstance(); + if (engine == null) { + synchronized (TfEngineProvider.class) { + engine = TfEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index d92ed9e449d..05a7eceeb41 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,6 +18,8 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TrtEngine.newInstance(); + if (engine == null) { + synchronized (TrtEngineProvider.class) { + engine = TrtEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java index efd9d89e509..96066b380e1 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java @@ -26,7 +26,7 @@ public void getVersion() { try { Engine engine = Engine.getEngine("TensorRT"); version = engine.getVersion(); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Assert.assertEquals(version, "8.4.1"); diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java index 09001f0e2da..24d734af54c 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java @@ -28,7 +28,7 @@ public void testNDArray() { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java index 99cbc6f763e..105e057ba0a 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java @@ -49,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -75,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -112,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Device device = engine.defaultDevice(); diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index fb61551a3bf..aa0fdb73d21 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,6 +18,8 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */ public class TfLiteEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TfLiteEngine.newInstance(); + if (engine == null) { + synchronized (TfLiteEngineProvider.class) { + engine = TfLiteEngine.newInstance(); + } + } + return engine; } } From 09cddb0a05edb2accaa2b9676aeb202f3099518e Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Mon, 11 Dec 2023 13:37:00 -0800 Subject: [PATCH 4/6] Make engines initialized This makes several updates: - engines will now initialize once per instance of EngineProvider rather than re-attempt to initialize - Registering an engine can overwrite the existing one - All engines now use the synchronized form rather than the static instance holder. This allows them to have multiple versions and be local to the instance rather than global (but the instance is saved globally) --- api/src/main/java/ai/djl/engine/Engine.java | 2 +- .../java/ai/djl/ml/lightgbm/LgbmEngineProvider.java | 10 +++++++--- .../main/java/ai/djl/ml/xgboost/XgbEngineProvider.java | 10 +++++++--- .../java/ai/djl/mxnet/engine/MxEngineProvider.java | 10 +++++++--- .../ai/djl/onnxruntime/engine/OrtEngineProvider.java | 10 +++++++--- .../ai/djl/paddlepaddle/engine/PpEngineProvider.java | 10 +++++++--- .../java/ai/djl/pytorch/engine/PtEngineProvider.java | 10 +++++++--- .../ai/djl/tensorflow/engine/TfEngineProvider.java | 10 +++++++--- .../java/ai/djl/tensorrt/engine/TrtEngineProvider.java | 10 +++++++--- .../ai/djl/tflite/engine/TfLiteEngineProvider.java | 10 +++++++--- 10 files changed, 64 insertions(+), 28 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 47a6e50f386..d67b184e2d5 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -167,7 +167,7 @@ public static boolean hasEngine(String engineName) { */ public static void registerEngine(EngineProvider provider) { logger.debug("Registering EngineProvider: {}", provider.getEngineName()); - ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider); + ALL_ENGINES.put(provider.getEngineName(), provider); } /** diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index a253ce3d246..583cd8132b2 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,7 +18,8 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (LgbmEngineProvider.class) { - engine = LgbmEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = LgbmEngine.newInstance(); + } } } return engine; diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 19cba32cc71..8b534d5196c 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,7 +18,8 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */ public class XgbEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (XgbEngineProvider.class) { - engine = XgbEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = XgbEngine.newInstance(); + } } } return engine; diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index f30a6a89252..2a5ab970560 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,7 +18,8 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */ public class MxEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (MxEngineProvider.class) { - engine = MxEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = MxEngine.newInstance(); + } } } return engine; diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index c673b3dcbf1..5616eb80edb 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */ public class OrtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (OrtEngineProvider.class) { - engine = OrtEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = OrtEngine.newInstance(); + } } } return engine; diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index e2b5bdd35a0..e2fb86974f5 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,7 +18,8 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */ public class PpEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (PpEngineProvider.class) { - engine = PpEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = PpEngine.newInstance(); + } } } return engine; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 57ae6c09d34..24be3e91d7a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (PtEngineProvider.class) { - engine = PtEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = PtEngine.newInstance(); + } } } return engine; diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index d964ea5c295..fa7813a49fb 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TfEngineProvider.class) { - engine = TfEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = TfEngine.newInstance(); + } } } return engine; diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index 05a7eceeb41..8c90859c6c6 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TrtEngineProvider.class) { - engine = TrtEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = TrtEngine.newInstance(); + } } } return engine; diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index aa0fdb73d21..b46cad53b99 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */ public class TfLiteEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,12 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TfLiteEngineProvider.class) { - engine = TfLiteEngine.newInstance(); + if (!initialized) { + initialized = true; + engine = TfLiteEngine.newInstance(); + } } } return engine; From 406f655dc4c9c469f166d701fe82b7c2e8d44815 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Mon, 11 Dec 2023 16:52:16 -0800 Subject: [PATCH 5/6] Throws Exception on bad getEngine --- api/src/main/java/ai/djl/engine/Engine.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index d67b184e2d5..de354bd3e52 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -214,7 +214,12 @@ public static Engine getEngine(String engineName) { if (provider == null) { throw new IllegalArgumentException("Deep learning engine not found: " + engineName); } - return provider.getEngine(); + Engine engine = provider.getEngine(); + if (engine == null) { + throw new IllegalStateException( + "The engine " + engineName + " was not able to initialize"); + } + return engine; } /** From 16fbd1f08aa112f105b5869006efcf45ecca3ecd Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 9 Jan 2024 11:08:26 -0800 Subject: [PATCH 6/6] Removes unnecessary check --- api/src/main/java/ai/djl/engine/Engine.java | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index de354bd3e52..a799c70f600 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -185,12 +185,11 @@ public static String getDefaultEngine() { * @param engineName the new default engine's name */ public static void setDefaultEngine(String engineName) { - Engine engine = getEngine(engineName); - if (engine != null) { - // Requires an engine to be loaded (without exception) before being the default - logger.debug("Setting new default engine: {}", engineName); - defaultEngine = engineName; - } + // Requires an engine to be loaded (without exception) before being the default + getEngine(engineName); + + logger.debug("Setting new default engine: {}", engineName); + defaultEngine = engineName; } /**