Java convolutional neural network (CNN) package integrated with the Apache Spark framework
DeepSpark_java is an early version of the ongoing DeepSpark project (https://github.com/deepspark/deepspark), implemented in pure Java with jBlas. It provides GPU acceleration through jCublas (`gpuAccel` option).
DeepSpark_java supports both local training on a single machine and distributed (synchronous and asynchronous) training on Apache Spark (http://spark.apache.org/).
Class | Description |
---|---|
Tensor | Base tensor class, implemented using jBlas |
Weight | Container for network parameters |
Sample | Container for data samples |
`Weight` and `Sample` are both implemented using `Tensor`.
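Since `Weight` and `Sample` both wrap `Tensor`, it may help to picture how a tensor maps a 4-D index onto flat storage. The class below is a hypothetical sketch (`SimpleTensor` is not part of DeepSpark) that assumes a (num, channels, height, width) shape and a row-major layout; the layout that DeepSpark's jBlas-backed `Tensor` actually uses is not documented here.

```java
/**
 * Hypothetical 4-D tensor, NOT DeepSpark's Tensor class: all values live in one
 * flat double[] and (n, c, h, w) is mapped to a row-major offset (assumed layout).
 */
final class SimpleTensor {
    private final int[] shape;   // {num, channels, height, width}
    private final double[] data; // flat storage in row-major order

    SimpleTensor(int num, int channels, int height, int width) {
        shape = new int[] {num, channels, height, width};
        data = new double[num * channels * height * width];
    }

    /** Maps (n, c, h, w) to a flat offset: ((n * C + c) * H + h) * W + w. */
    int offset(int n, int c, int h, int w) {
        int[] idx = {n, c, h, w};
        for (int d = 0; d < 4; d++) {
            // Reject per-dimension overflow, which would otherwise silently wrap into the next row.
            if (idx[d] < 0 || idx[d] >= shape[d]) {
                throw new IndexOutOfBoundsException("index " + idx[d] + " out of range in dim " + d);
            }
        }
        return ((n * shape[1] + c) * shape[2] + h) * shape[3] + w;
    }

    double get(int n, int c, int h, int w) { return data[offset(n, c, h, w)]; }

    void set(int n, int c, int h, int w, double value) { data[offset(n, c, h, w)] = value; }

    int size() { return data.length; }
}
```

With this layout the width index varies fastest, so neighbouring pixels in a row are adjacent in memory.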
To load a custom dataset, users should write their own data loader that produces data compatible with `Sample`.
Built-in MNIST, CIFAR, and ImageNet loaders are provided (see examples in `src/main/java/org/acl/deepspark/utils/`).
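As a starting point for a custom loader, the sketch below parses MNIST's IDX files (a big-endian header with a magic number, item count, and image dimensions, followed by one unsigned byte per pixel or label) into a hypothetical `LabeledImage` container. `LabeledImage` and `IdxMnistLoader` are illustrative names only, not DeepSpark classes; converting the result into DeepSpark's `Sample` is left out because its constructor is not described here.

```java
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical data container standing in for DeepSpark's Sample. */
final class LabeledImage {
    final double[] pixels; // row-major, scaled to [0, 1]
    final int rows;
    final int cols;
    final int label;

    LabeledImage(double[] pixels, int rows, int cols, int label) {
        this.pixels = pixels;
        this.rows = rows;
        this.cols = cols;
        this.label = label;
    }
}

/** Hypothetical loader for the MNIST IDX image/label file pair. */
final class IdxMnistLoader {
    private static final int IMAGE_MAGIC = 0x00000803; // 2051
    private static final int LABEL_MAGIC = 0x00000801; // 2049

    static List<LabeledImage> load(InputStream imageIn, InputStream labelIn) throws IOException {
        // DataInputStream.readInt() is big-endian, matching the IDX header format.
        DataInputStream images = new DataInputStream(imageIn);
        DataInputStream labels = new DataInputStream(labelIn);

        if (images.readInt() != IMAGE_MAGIC) throw new IOException("not an IDX image file");
        if (labels.readInt() != LABEL_MAGIC) throw new IOException("not an IDX label file");

        int count = images.readInt();
        int rows = images.readInt();
        int cols = images.readInt();
        if (labels.readInt() != count) throw new IOException("image/label count mismatch");

        List<LabeledImage> out = new ArrayList<>(count);
        byte[] buf = new byte[rows * cols];
        for (int i = 0; i < count; i++) {
            images.readFully(buf);
            double[] px = new double[buf.length];
            for (int p = 0; p < buf.length; p++) {
                px[p] = (buf[p] & 0xFF) / 255.0; // unsigned byte -> [0, 1]
            }
            out.add(new LabeledImage(px, rows, cols, labels.readUnsignedByte()));
        }
        return out;
    }
}
```

In practice the two streams would come from the downloaded `train-images-idx3-ubyte` and `train-labels-idx1-ubyte` files, for example via `FileInputStream` wrapped in a `BufferedInputStream`.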
Layer | Description |
---|---|
Layer | Base interface for layers |
BaseLayer | Abstract class implementing Layer interface |
ConvolutionLayer | Convolutional layer |
PoolingLayer | Pooling (subsampling) layer |
FullyConnectedLayer | Standard fully connected layer |
Users should define a `LayerConf` to specify layer details (`LayerType`, kernel width/height, stride, padding, etc.).
To add more options, see `src/main/java/org/acl/deepspark/nn/conf/LayerConf`.
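The kernel, stride, and padding settings in `LayerConf` determine each layer's spatial output size. The hypothetical helper below (`ConvShape` is not a DeepSpark class) applies the common formula out = floor((in + 2 * pad - kernel) / stride) + 1. Whether DeepSpark rounds down, or rounds up as some frameworks (Caffe among them) do for pooling, is an assumption to check against the source.

```java
/** Hypothetical helper: spatial output size of a convolution or pooling window along one axis. */
final class ConvShape {
    /** out = floor((in + 2 * pad - kernel) / stride) + 1, using floor rounding (assumed). */
    static int outputSize(int input, int kernel, int stride, int pad) {
        if (kernel <= 0 || stride <= 0 || pad < 0) {
            throw new IllegalArgumentException("kernel and stride must be positive, pad non-negative");
        }
        int span = input + 2 * pad - kernel;
        if (span < 0) {
            throw new IllegalArgumentException("kernel larger than padded input");
        }
        return span / stride + 1; // integer division floors for non-negative span
    }
}
```

For example, a 28x28 MNIST image through a 5x5 convolution (stride 1, no padding) gives 24x24, and a following 2x2 pooling layer with stride 2 gives 12x12. Height and width are computed independently, so non-square kernels simply call the helper once per axis.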
Class | Description |
---|---|
NeuralNet | Represents the overall network. Provides methods for initialization, training, and inference |
NeuralNetRunner | Runs a NeuralNet on the local machine |
DistNeuralNetRunner | Runs a NeuralNet in a synchronous distributed setting |
DistAsyncNeuralNetRunner | Runs a NeuralNet in an asynchronous distributed setting |
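In the synchronous setting, one simple scheme is for every worker to train on its own data partition and for the driver to average the resulting weights before the next round. The sketch below shows only that averaging step in plain Java; `SyncAveraging` is hypothetical, and whether `DistNeuralNetRunner` averages weights or gradients, and how it collects them through Spark, is not stated here.

```java
/** Hypothetical sketch of one synchronous round: element-wise average of every worker's weights. */
final class SyncAveraging {
    static double[] average(double[][] workerWeights) {
        if (workerWeights.length == 0) {
            throw new IllegalArgumentException("no workers");
        }
        int n = workerWeights[0].length;
        double[] avg = new double[n];
        for (double[] w : workerWeights) {
            if (w.length != n) {
                throw new IllegalArgumentException("weight vectors differ in length");
            }
            for (int i = 0; i < n; i++) {
                avg[i] += w[i]; // accumulate sum across workers
            }
        }
        for (int i = 0; i < n; i++) {
            avg[i] /= workerWeights.length; // divide by worker count
        }
        return avg;
    }
}
```

Because every worker must finish its round before the average is taken, the slowest worker sets the pace; the asynchronous runner avoids this barrier.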
Users should define a `NeuralNetConf` to specify training details (`lr`, `l2_lambda`, `momentum`, `gpuAccel`, etc.).
To add more options, see `src/main/java/org/acl/deepspark/nn/conf/NeuralNetConf`.
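To illustrate how `lr`, `l2_lambda`, and `momentum` typically interact, the sketch below implements classical SGD with momentum and L2 weight decay: v ← momentum · v - lr · (g + l2_lambda · w), then w ← w + v. This is a standard formulation assumed for illustration; `MomentumSgd` is a hypothetical class, and DeepSpark's exact update rule may differ.

```java
/** Hypothetical SGD optimizer with momentum and L2 weight decay (assumed update rule). */
final class MomentumSgd {
    private final double lr;
    private final double momentum;
    private final double l2Lambda;
    private final double[] velocity; // one entry per weight, starts at zero

    MomentumSgd(int size, double lr, double momentum, double l2Lambda) {
        this.lr = lr;
        this.momentum = momentum;
        this.l2Lambda = l2Lambda;
        this.velocity = new double[size];
    }

    /** Updates weights in place: v = momentum * v - lr * (grad + l2Lambda * w); w += v. */
    void step(double[] weights, double[] grad) {
        if (weights.length != velocity.length || grad.length != velocity.length) {
            throw new IllegalArgumentException("size mismatch");
        }
        for (int i = 0; i < weights.length; i++) {
            // The L2 term adds l2Lambda * w to the gradient, pulling weights toward zero.
            velocity[i] = momentum * velocity[i] - lr * (grad[i] + l2Lambda * weights[i]);
            weights[i] += velocity[i];
        }
    }
}
```

With `momentum` set to 0 this reduces to plain gradient descent with weight decay; larger values let past gradients keep pushing the weights in a consistent direction.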
For asynchronous updates, simple ParameterServer/Client classes are implemented; see `src/main/java/org/acl/deepspark/nn/async`.
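The asynchronous setting can be pictured as workers that repeatedly pull the latest weights, compute a gradient, and push it back, with the server applying each push as soon as it arrives. The in-process sketch below uses `synchronized` methods and a thread pool in place of network communication; `ToyParameterServer` is hypothetical and is not DeepSpark's `ParameterServer`, whose protocol is not described here. The workers push a fixed gradient purely so the result is easy to check.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/** Hypothetical in-process parameter server; NOT DeepSpark's ParameterServer. */
final class ToyParameterServer {
    private final double[] weights;
    private final double lr;
    private long updates = 0;

    ToyParameterServer(double[] init, double lr) {
        this.weights = init.clone();
        this.lr = lr;
    }

    /** Returns a snapshot of the current weights; it may be stale by the time the worker uses it. */
    synchronized double[] pull() {
        return weights.clone();
    }

    /** Applies one worker's gradient immediately, without waiting for other workers. */
    synchronized void push(double[] grad) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= lr * grad[i];
        }
        updates++;
    }

    synchronized long updateCount() {
        return updates;
    }

    /** Runs workers that each pull and push `steps` times with no barrier between them. */
    static void runWorkers(ToyParameterServer ps, int workers, int steps, double[] grad)
            throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        for (int t = 0; t < workers; t++) {
            pool.submit(() -> {
                for (int k = 0; k < steps; k++) {
                    double[] local = ps.pull(); // a real worker would compute grad from `local` and its data
                    ps.push(grad);
                }
            });
        }
        pool.shutdown();
        if (!pool.awaitTermination(30, TimeUnit.SECONDS)) {
            throw new IllegalStateException("workers did not finish");
        }
    }
}
```

Because `pull` and `push` are separate calls, other workers may update the weights in between, so each gradient can be computed from slightly stale parameters. That staleness is the usual price of asynchronous updates, paid in exchange for never waiting on the slowest worker.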
For complete usage code, see the examples in `src/test/java/org/acl/deepspark/nn/driver`.
Type | Path |
---|---|
Single Machine | MnistTest.java / CIFARTest.java |
Distributed (sync) | SyncMnistTest.java |
Distributed (async) | AsyncMnistTest.java |
Kim, Hanjoo, Jaehong Park, Jaehee Jang, and Sungroh Yoon. "DeepSpark: Spark-Based Deep Learning Supporting Asynchronous Updates and Caffe Compatibility." arXiv preprint arXiv:1602.08191 (2016).