
Commit

Update sample for TensorFlow and fix ops as necessary
saudet committed Oct 2, 2016
1 parent aec7255 commit dd4ffa7
Showing 4 changed files with 1,404 additions and 175 deletions.
42 changes: 24 additions & 18 deletions tensorflow/README.md
@@ -5,7 +5,7 @@ Introduction
------------
This directory contains the JavaCPP Presets module for:

* TensorFlow 0.8.0 http://www.tensorflow.org/
* TensorFlow 0.10.0 http://www.tensorflow.org/

Please refer to the parent README.md file for more detailed information about the JavaCPP Presets.

@@ -21,7 +21,7 @@ Sample Usage
------------
Here is a simple example of TensorFlow ported to Java from this C++ source file:

* https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc
* https://github.com/tensorflow/tensorflow/blob/v0.10.0/tensorflow/cc/tutorials/example_trainer.cc

We can use [Maven 3](http://maven.apache.org/) to download and install automatically all the class files as well as the native binaries. To run this sample code, after creating the `pom.xml` and `src/main/java/ExampleTrainer.java` source files below, simply execute on the command line:
```bash
@@ -34,23 +34,23 @@ We can use [Maven 3](http://maven.apache.org/) to download and install automatically
<modelVersion>4.0.0</modelVersion>
<groupId>org.bytedeco.javacpp-presets.tensorflow</groupId>
<artifactId>exampletrainer</artifactId>
<version>1.2</version>
<version>1.2.5-SNAPSHOT</version>
<properties>
<exec.mainClass>ExampleTrainer</exec.mainClass>
</properties>
<dependencies>
<dependency>
<groupId>org.bytedeco.javacpp-presets</groupId>
<artifactId>tensorflow</artifactId>
<version>0.8.0-1.2</version>
<version>0.10.0-1.2.5-SNAPSHOT</version>
</dependency>
</dependencies>
</project>
```

### The `src/main/java/ExampleTrainer.java` source file
```java
/* Copyright 2015 Google Inc. All Rights Reserved.
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -125,31 +125,31 @@ public class ExampleTrainer {
static GraphDef CreateGraphDef() throws Exception {
// TODO(jeff,opensource): This should really be a more interesting
// computation. Maybe turn this into an mnist model instead?
GraphDefBuilder b = new GraphDefBuilder();
Scope root = Scope.NewRootScope();

// Store rows [3, 2] and [-1, 0] in row major format.
Node a = Const(new float[] {3.f, 2.f, -1.f, 0.f}, new TensorShape(2, 2), b.opts());
// a = [3 2; -1 0]
Output a = Const(root, Tensor.create(new float[] {3.f, 2.f, -1.f, 0.f}, new TensorShape(2, 2)));

// x is from the feed.
Node x = Const(new float[] {0.f}, new TensorShape(2, 1), b.opts().WithName("x"));
// x = [1.0; 1.0]
Output x = Const(root.WithOpName("x"), Tensor.create(new float[] {1.f, 1.f}, new TensorShape(2, 1)));

// y = A * x
Node y = MatMul(a, x, b.opts().WithName("y"));
// y = a * x
MatMul y = new MatMul(root.WithOpName("y"), new Input(a), new Input(x));

// y2 = y.^2
Node y2 = Square(y, b.opts());
Square y2 = new Square(root, y.asInput());

// y2_sum = sum(y2)
Node y2_sum = Sum(y2, Const(0, b.opts()), b.opts());
Sum y2_sum = new Sum(root, y2.asInput(), new Input(0));

// y_norm = sqrt(y2_sum)
Node y_norm = Sqrt(y2_sum, b.opts());
Sqrt y_norm = new Sqrt(root, y2_sum.asInput());

// y_normalized = y ./ y_norm
Div(y, y_norm, b.opts().WithName("y_normalized"));
new Div(root.WithOpName("y_normalized"), y.asInput(), y_norm.asInput());

GraphDef def = new GraphDef();
Status s = b.ToGraphDef(def);
Status s = root.ToGraphDef(def);
if (!s.ok()) {
throw new Exception(s.error_message().getString());
}
@@ -161,7 +161,10 @@ public class ExampleTrainer {
assert y.NumElements() == 2;
FloatBuffer x_flat = x.createBuffer();
FloatBuffer y_flat = y.createBuffer();
float lambda = y_flat.get(0) / x_flat.get(0);
// Compute an estimate of the eigenvalue via
// (x' A x) / (x' x) = (x' y) / (x' x)
// and exploit the fact that x' x = 1 by assumption
float lambda = x_flat.get(0) * y_flat.get(0) + x_flat.get(1) * y_flat.get(1);
return String.format("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]",
lambda, x_flat.get(0), x_flat.get(1), y_flat.get(0), y_flat.get(1));
}
@@ -192,6 +195,9 @@ public class ExampleTrainer {
FloatBuffer x_flat = x.createBuffer();
x_flat.put(0, (float)Math.random());
x_flat.put(1, (float)Math.random());
float inv_norm = 1 / (float)Math.sqrt(x_flat.get(0) * x_flat.get(0) + x_flat.get(1) * x_flat.get(1));
x_flat.put(0, x_flat.get(0) * inv_norm);
x_flat.put(1, x_flat.get(1) * inv_norm);

// Iterations.
TensorVector outputs = new TensorVector();
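The iteration loop of `ExampleTrainer.java` is collapsed in this diff. For orientation, a rough sketch of how such a sample typically creates a session, feeds `x`, and fetches `y` and `y_normalized` is given below; the classes and names used (`Session`, `SessionOptions`, `StringTensorPairVector`, `StringVector`, `TensorVector`, `DebugString`) are assumed from the JavaCPP presets bindings and the original C++ example_trainer, not taken from this diff. Since `a = [3 2; -1 0]` has eigenvalues 1 and 2, feeding the normalized `y` back in as the next `x` is a power iteration, so `lambda = x'y = x'Ax` should converge to the dominant eigenvalue 2.

```java
// Hypothetical sketch only; assumes "import static org.bytedeco.javacpp.tensorflow.*;"
// and java.nio.FloatBuffer, as in the full sample.
Session session = new Session(new SessionOptions());
Status s = session.Create(CreateGraphDef());
if (!s.ok()) throw new Exception(s.error_message().getString());

// Start from a random vector of unit norm, as in the hunk above.
Tensor x = new Tensor(DT_FLOAT, new TensorShape(2, 1));
FloatBuffer x_flat = x.createBuffer();
x_flat.put(0, (float)Math.random());
x_flat.put(1, (float)Math.random());
float inv_norm = 1 / (float)Math.sqrt(x_flat.get(0) * x_flat.get(0) + x_flat.get(1) * x_flat.get(1));
x_flat.put(0, x_flat.get(0) * inv_norm);
x_flat.put(1, x_flat.get(1) * inv_norm);

for (int step = 0; step < 100; step++) {
    // Feed "x", fetch "y:0" and "y_normalized:0" as named in CreateGraphDef().
    TensorVector outputs = new TensorVector();
    s = session.Run(new StringTensorPairVector(new String[] {"x"}, new Tensor[] {x}),
                    new StringVector("y:0", "y_normalized:0"), new StringVector(), outputs);
    if (!s.ok()) throw new Exception(s.error_message().getString());
    System.out.println(DebugString(x, outputs.get(0)));
    x = outputs.get(1);  // the normalized y becomes the next x (power iteration)
}
```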
@@ -1,5 +1,5 @@
/*
* Copyright (C) 2015 Samuel Audet
* Copyright (C) 2015-2016 Samuel Audet
*
* Licensed either under the Apache License, Version 2.0, or (at your option)
* under the terms of the GNU General Public License as published by
@@ -22,7 +22,13 @@

package org.bytedeco.javacpp.helper;

import java.nio.ByteBuffer;
import java.nio.Buffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
@@ -107,6 +113,21 @@ public static abstract class AbstractTensor extends Pointer implements Indexable
static { Loader.load(); }
public AbstractTensor(Pointer p) { super(p); }

public static Tensor create(float[] data, TensorShape shape) { Tensor t = new Tensor(DT_FLOAT, shape); FloatBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(double[] data, TensorShape shape) { Tensor t = new Tensor(DT_DOUBLE, shape); DoubleBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(int[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT32, shape); IntBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(short[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT16, shape); ShortBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(byte[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT8, shape); ByteBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(long[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT64, shape); LongBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(String[] data, TensorShape shape) {
Tensor t = new Tensor(DT_STRING, shape);
StringArray a = t.createStringArray();
for (int i = 0; i < a.capacity(); i++) {
a.position(i).put(data[i]);
}
return t;
}

public abstract int dtype();
public abstract int dims();
public abstract long dim_size(int d);
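These new `create(...)` factory methods in the helper class are what the updated sample relies on to build tensors directly from Java arrays. A minimal usage sketch, assuming the usual `org.bytedeco.javacpp.tensorflow` imports (variable names are illustrative):

```java
// Build a 2x2 float tensor from a Java array and read it back through a FloatBuffer.
Tensor a = Tensor.create(new float[] {3.f, 2.f, -1.f, 0.f}, new TensorShape(2, 2));
FloatBuffer data = a.createBuffer();
System.out.println(data.get(0));  // prints 3.0

// String tensors are filled through a StringArray rather than a java.nio.Buffer.
Tensor names = Tensor.create(new String[] {"alpha", "beta"}, new TensorShape(2));
```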
@@ -112,15 +112,19 @@
"tensorflow/cc/framework/ops.h",
"tensorflow/cc/framework/cc_op_gen.h",
"tensorflow_adapters.h",
"tensorflow/cc/ops/standard_ops.h",
"tensorflow/cc/ops/const_op.h",
"tensorflow/cc/ops/array_ops.h",
"tensorflow/cc/ops/candidate_sampling_ops.h",
"tensorflow/cc/ops/control_flow_ops.h",
"tensorflow/cc/ops/data_flow_ops.h",
"tensorflow/cc/ops/image_ops.h",
"tensorflow/cc/ops/io_ops.h",
"tensorflow/cc/ops/linalg_ops.h",
"tensorflow/cc/ops/logging_ops.h",
"tensorflow/cc/ops/math_ops.h",
"tensorflow/cc/ops/nn_ops.h",
"tensorflow/cc/ops/no_op.h",
"tensorflow/cc/ops/parsing_ops.h",
"tensorflow/cc/ops/random_ops.h",
"tensorflow/cc/ops/sparse_ops.h",
@@ -259,9 +263,13 @@ public void map(InfoMap infoMap) {

infoMap.put(new Info("tensorflow::gtl::ArraySlice").annotations("@ArraySlice"))
.put(new Info("tensorflow::StringPiece").annotations("@StringPiece").valueTypes("BytePointer", "String").pointerTypes("BytePointer"))
.put(new Info("tensorflow::ops::Const(tensorflow::StringPiece, tensorflow::GraphDefBuilder::Options&)")
.javaText("@Namespace(\"tensorflow::ops\") public static native Node Const("
+ "@Cast({\"\", \"tensorflow::StringPiece&\"}) @StringPiece String s, @Const @ByRef GraphDefBuilder.Options options);"));
.put(new Info("tensorflow::ops::Input::Initializer").pointerTypes("Input.Initializer").valueTypes("@Const @ByRef Input.Initializer",
"@ByRef Tensor", "byte", "short", "int", "long", "float", "double", "boolean", "@StdString String", "@StdString BytePointer"));

String[] consts = {"unsigned char", "short", "int", "long long", "float", "double", "bool", "std::string", "tensorflow::StringPiece"};
for (int i = 0; i < consts.length; i++) {
infoMap.put(new Info("tensorflow::ops::Const<" + consts[i] + ">").javaNames("Const"));
}
}

public static class Fn extends FunctionPointer {
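The replaced `Const` mapping and the new `Input.Initializer` value types are what let the sample above pass a `Tensor` or a plain `int` wherever the C++ API expects an `Input::Initializer`, while the loop instantiates the `Const<T>` template for each listed element type under the single Java name `Const`. For example, both calls below are taken from the updated sample:

```java
// Const accepts a Tensor through the Initializer mapping...
Output a = Const(root, Tensor.create(new float[] {3.f, 2.f, -1.f, 0.f}, new TensorShape(2, 2)));
// ...and ops can take scalar constants directly, e.g. the reduction index of Sum.
Sum y2_sum = new Sum(root, y2.asInput(), new Input(0));
```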
