DL4J supports models created in the popular Python frameworks TensorFlow and Keras. As of 1.0.0-M2, Keras models (including tf.keras models) can be imported into Deeplearning4j, and TensorFlow models in frozen format can be imported into SameDiff.
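The two Keras import paths can be sketched as follows. This is a minimal sketch, not code taken from the examples. It assumes the `deeplearning4j-modelimport` dependency is on the classpath. It uses `KerasModelImport.importKerasSequentialModelAndWeights` to load a Sequential model as a `MultiLayerNetwork` and `KerasModelImport.importKerasModelAndWeights` to load a functional Model as a `ComputationGraph`. The file name `sequential_mlp.h5` and the 784-feature input shape are hypothetical placeholders.

```java
import java.io.IOException;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KerasImportSketch {

    /** Keras Sequential model saved as .h5 -> DL4J MultiLayerNetwork. */
    static MultiLayerNetwork importSequential(String h5Path)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return KerasModelImport.importKerasSequentialModelAndWeights(h5Path);
    }

    /** Keras functional Model saved as .h5 -> DL4J ComputationGraph. */
    static ComputationGraph importFunctional(String h5Path)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return KerasModelImport.importKerasModelAndWeights(h5Path);
    }

    public static void main(String[] args) throws Exception {
        // Hypothetical file name: point this at your own Sequential .h5 model.
        MultiLayerNetwork net = importSequential("sequential_mlp.h5");

        // Hypothetical input: one example with 784 features (e.g. a flattened 28x28 image).
        INDArray input = Nd4j.rand(1, 784);

        // Forward pass (inference) through the imported network.
        INDArray prediction = net.output(input);
        System.out.println(prediction);
    }
}
```

Returning the two model types from separate helpers mirrors the split between Sequential and functional Keras models. SimpleSequentialMlpImport.java and SimpleFunctionalMlpImport.java show the full versions.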
TensorFlow models must first be converted to "frozen" protobuf (.pb) files. More information on freezing TensorFlow models can be found here for TensorFlow 1.x and here for TensorFlow 2.x. Keras models must be saved in HDF5 (.h5) format; more information can be found here. Both Keras 1 and Keras 2 models can be imported, as can models saved with tf.keras. General TensorFlow operations inside Keras models (i.e., operations that are not part of the tf.keras API) can be imported but are supported for inference only. Full training is supported for anything that is part of the Keras API.
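Importing a frozen graph into SameDiff and running inference on it can be sketched as below. This is a minimal sketch under stated assumptions. It assumes the SameDiff TensorFlow-import dependencies for your DL4J version are on the classpath. It also assumes the `SameDiff.importFrozenTF(File)` entry point is available in that version, so check it against the API you are using. The file name `frozen_model.pb` and the node names `input` and `output` are hypothetical. Printing `sd.summary()` is one way to find the real placeholder and output names in your graph.

```java
import java.io.File;
import java.util.Collections;
import java.util.Map;

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class FrozenTfImportSketch {

    /** Frozen TensorFlow graph (.pb) -> SameDiff graph. */
    static SameDiff importFrozen(File frozenPb) {
        return SameDiff.importFrozenTF(frozenPb);
    }

    public static void main(String[] args) {
        // Hypothetical file name: point this at your own frozen graph.
        SameDiff sd = importFrozen(new File("frozen_model.pb"));

        // Lists the imported variables and ops, including placeholder names.
        System.out.println(sd.summary());

        // "input" and "output" are hypothetical node names; use the ones from your graph.
        INDArray input = Nd4j.rand(1, 784);
        Map<String, INDArray> results =
                sd.output(Collections.singletonMap("input", input), "output");
        System.out.println(results.get("output"));
    }
}
```

MNISTMLP.java and ModifyMNISTMLP.java in this project show complete versions of this workflow on a real frozen model.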
Go back to the main repository page to explore other features and functionality of the Eclipse Deeplearning4j ecosystem. To request new features, file an issue here.
The examples in this project, and what they demonstrate, are briefly described below in the recommended order of exploration. An FAQ gathered from the example READMEs is also available here.
- SimpleSequentialMlpImport.java: Basic example of importing a Keras Sequential model into DL4J for training or inference.
- SimpleFunctionalMlpImport.java: Basic example of importing a Keras functional Model into DL4J for training or inference.
- ImportDeepMoji.java: Imports the DeepMoji application and demonstrates implementing a custom layer for import.
- MNISTMLP.java: Basic example that imports a frozen TF model trained on MNIST. The Python scripts used are also available.
- BostonHousingPricesModel.java: Another basic example, using the Boston Housing prices dataset.
- ModifyMNISTMLP.java: Imports a frozen TF model, demonstrates static execution, modifies the graph, and then executes it dynamically.
- ImportMobileNetExample.md: Imports MobileNet and runs inference on it, reproducing the metrics obtained in TensorFlow.
- TFGraphRunnerExample.java: Runs a TensorFlow graph from Java using the TensorFlow graph runner.
- MobileNetTransferLearningExample.md: Transfer learning on an imported TF MobileNet model for CIFAR-10.
- BertInferenceExample.md: Runs inference on a BERT model trained in TensorFlow, reproducing the metrics obtained in TensorFlow.