Fix custom gradient lookup bug (#512)
* Bazel script cleanup

* Fix custom gradient lookup issue
karllessard authored Jan 16, 2024
1 parent 5aeda89 commit f6acf31
Showing 8 changed files with 50 additions and 100 deletions.
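
In short: the native registry keys custom gradient adapters by op type, but the gradient lookup used the node name, so a second node of the same op type (for example "Merge_1" alongside "Merge") found no adapter. A minimal sketch of the failing scenario, adapted from the new test below; it assumes only the public TensorFlow Java API that the test itself exercises:

    try (Graph g = new Graph()) {
      // Registered once per op type ("Merge"), not per node.
      TensorFlow.registerCustomGradient(
          Merge.Inputs.class,
          (tf, op, gradInputs) ->
              gradInputs.stream().map(i -> tf.constant(-10)).collect(Collectors.toList()));
      var tf = Ops.create(g);
      var initialValue = tf.constant(10);
      var merge1 = tf.merge(List.of(initialValue, tf.constant(20)));    // node name "Merge"
      var merge2 = tf.merge(List.of(merge1.output(), tf.constant(30))); // node name "Merge_1"
      // Before this fix, the per-name lookup failed here for "Merge_1".
      g.addGradients(merge2.output(), new Output<?>[] {initialValue.asOutput()});
    }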
@@ -22,12 +22,15 @@
import org.junit.jupiter.api.condition.OS;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Merge;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.nn.NthElement;
import org.tensorflow.proto.DataType;
import org.tensorflow.types.TFloat32;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.junit.jupiter.api.Assertions.*;

@@ -100,6 +103,24 @@ public void testCustomGradient() {
}
}

@DisabledOnOs(OS.WINDOWS)
@Test
public void applyGradientOnMultipleNodesOfSameOpType() {
try (Graph g = new Graph()) {
assertTrue(TensorFlow.registerCustomGradient(
Merge.Inputs.class,
(tf, op, gradInputs) -> gradInputs.stream().map(i -> tf.constant(-10)).collect(Collectors.toList())
));
var tf = Ops.create(g);
var initialValue = tf.constant(10);
var merge1 = tf.merge(List.of(initialValue, tf.constant(20)));
var merge2 = tf.merge(List.of(merge1.output(), tf.constant(30)));

// Just make sure that it won't throw
g.addGradients(merge2.output(), toArray(initialValue.asOutput()));
}
}

private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
}
51 changes: 1 addition & 50 deletions tensorflow-core/tensorflow-core-native/BUILD
@@ -1,58 +1,9 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary", "clean_dep", "if_windows")
load("@rules_java//java:defs.bzl", "java_proto_library")

java_proto_library(
name = "java_proto_gen_sources",
deps = [
clean_dep("//tensorflow/core:protos_all"),
"@org_tensorflow//tensorflow/core:protos_all",
"@local_tsl//tsl/protobuf:protos_all"
]
)

# cc_import(
# name = "libtensorflow_import",
# includes = [
# "external/org_tensorflow",
# "bazel-out/darwin_arm64-opt/bin/external/org_tensorflow",
# "external/com_google_protobuf/src/",
# "external/local_tsl"
# ],
# interface_library = if_windows(
# "tensorflow.lib",
# otherwise = None
# ),
# shared_library = select({
# clean_dep("//tensorflow:macos"): "libtensorflow_cc.dylib",
# clean_dep("//tensorflow:windows"): "tensorflow.dll",
# "//conditions:default": "libtensorflow_cc.so"
# })
# )
#
# cc_binary(
# name = "java_api_import",
# srcs = [
# "src/tools/native/api_import.cc",
# ],
# deps = [
# ":libtensorflow_import"
# ]
# )
#
# filegroup(
# name = "custom_ops_test",
# srcs = select({
# # FIXME(karllessard) Disable custom ops test on Windows since TF is still monolithic on this platform
# clean_dep("//tensorflow:windows"): [],
# "//conditions:default": [":libcustom_ops_test.so"],
# })
# )
#
# tf_cc_binary(
# name = "libcustom_ops_test.so",
# srcs = ["src/test/native/my_test_op.cc"],
# linkshared = 1,
# linkopts = ["-lm"],
# deps = [
# ":libtensorflow_import"
# ]
# )
22 changes: 0 additions & 22 deletions tensorflow-core/tensorflow-core-native/rules.bzl

This file was deleted.

11 changes: 0 additions & 11 deletions tensorflow-core/tensorflow-core-native/scripts/bazel_generate.sh
@@ -8,9 +8,6 @@ echo "Generate sources with Bazel"
${BAZEL_CMD:=bazel} --bazelrc=tensorflow.bazelrc $BUILD_FLAGS ${BUILD_USER_FLAGS:-} \
@org_tensorflow//tensorflow/tools/lib_package:jnilicenses_generate \
:java_proto_gen_sources
# :java_op_exporter \
# :java_api_import \
# :custom_ops_test

echo "Rebuilding generated source directories"
GEN_SRCS_DIR=src/gen/java
@@ -31,11 +28,3 @@ echo "Extracting TF/TSL proto Java sources"
cd $GEN_SRCS_DIR
find $TENSORFLOW_BIN $BAZEL_BIN/external/local_tsl/tsl -name \*-speed-src.jar -exec jar xf {} \;
rm -rf META-INF

# Export op defs
#echo "Exporting Ops"
#$BAZEL_BIN/java_op_exporter \
# $GEN_RESOURCE_DIR/ops.pb \
# $GEN_RESOURCE_DIR/ops.pbtxt \
# $BAZEL_SRCS/external/org_tensorflow/tensorflow/core/api_def/base_api \
# src/bazel/api_def
@@ -4905,10 +4905,9 @@ public static native void TFE_InitializeLocalOnlyContext(TFE_Context ctx,
public static native @Cast("bool") boolean TFJ_HasGradient(@Cast("const char*") BytePointer op_type);
public static native @Cast("bool") boolean TFJ_HasGradient(String op_type);

/** Registers a gradient function for operations of type {@code op_type}. It is possible to register a new function even if another has already been registered for this
* type of operations (this will only generate a warning).
/** Registers a gradient function for operations of type {@code op_type}.
*
* Returns true if the function has been registered successfully */
* Returns true if the function has been registered successfully, false if operation failed or if gradient function is already registered to that {@code op_type}. */
public static native @Cast("bool") boolean TFJ_RegisterCustomGradient(@Cast("const char*") BytePointer op_type, TFJ_GradFuncAdapter custom_gradient_adapter);
public static native @Cast("bool") boolean TFJ_RegisterCustomGradient(String op_type, TFJ_GradFuncAdapter custom_gradient_adapter);

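
The behavioral change for callers of these internal bindings: re-registering a gradient for an op type no longer overwrites the existing function with a warning, it returns false. A small guard sketch under stated assumptions ("MyCustomOp" is a hypothetical op type, and the bare TFJ_GradFuncAdapter is a placeholder, as in GradientTest below):

    import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
    import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_HasGradient;
    import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_RegisterCustomGradient;

    public class RegisterGradientOnce {
      public static void main(String[] args) {
        String opType = "MyCustomOp"; // hypothetical op type
        if (!TFJ_HasGradient(opType)) {
          // False now means the call failed or a function was already registered.
          boolean registered = TFJ_RegisterCustomGradient(opType, new TFJ_GradFuncAdapter());
          System.out.println(opType + " gradient registered: " + registered);
        }
      }
    }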
@@ -35,10 +35,9 @@ typedef int (*TFJ_GradFuncAdapter)(TFJ_GraphId graphId, TFJ_Scope* scope, TF_Ope
/// Returns true if a gradient function has already be registered for operations of type `op_type`
TF_CAPI_EXPORT extern bool TFJ_HasGradient(const char* op_type);

/// Registers a gradient function for operations of type `op_type`. It is possible to register a new function even if another has already been registered for this
/// type of operations (this will only generate a warning).
/// Registers a gradient function for operations of type `op_type`.
///
/// Returns true if the function has been registered successfully
/// Returns true if the function has been registered successfully, false if operation failed or if gradient function is already registered to that `op_type`.
TF_CAPI_EXPORT extern bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter custom_gradient_adapter);

#ifdef __cplusplus
@@ -46,9 +46,10 @@ namespace tensorflow {
const vector<Output>& grad_inputs,
vector<Output>* grad_outputs)
{
auto found_adapter = g_grad_func_adapters.find(op.node()->name());
const string& op_type = op.node()->type_string();
auto found_adapter = g_grad_func_adapters.find(op_type);
if (found_adapter == g_grad_func_adapters.end()) {
return errors::NotFound("No gradient adapter found for operation ", op.node()->name());
return errors::NotFound("No gradient adapter found for operation ", op_type);
}
int num_inputs = grad_inputs.size();
TF_Output* inputs = (TF_Output*)malloc(num_inputs * sizeof(TF_Output));
@@ -58,7 +59,7 @@ inputs[i].index = grad_input.index();
inputs[i].index = grad_input.index();
}
TF_Output* outputs = NULL;
LOG(INFO) << "Calling custom gradient function for operation " << op.node()->name();
LOG(INFO) << "Calling Java gradient function for operation of type " << op_type;
int num_outputs = found_adapter->second(
static_cast<TFJ_GraphId>(scope.graph()),
struct_cast<TFJ_Scope>(const_cast<Scope*>(&scope)),
@@ -90,11 +91,10 @@ bool TFJ_HasGradient(const char* op_type) {
}

bool TFJ_RegisterCustomGradient(const char* op_type, TFJ_GradFuncAdapter grad_func_adapter) {
if (TFJ_HasGradient(op_type)) {
LOG(WARNING) << "Registering gradient function for operation " << op_type
<< ", which has already an existing gradient function";
} else {
LOG(INFO) << "Registering gradient function for operation " << op_type;
if (TFJ_HasGradient(op_type)) { // Check if gradient already exists otherwise the JVM might abort/crash
LOG(WARNING) << "Tried to register Java gradient function for operation " << op_type
<< ", which has already a registered function";
return false;
}
bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
if (registered) {
@@ -16,26 +16,39 @@
*/
package org.tensorflow.internal.c_api;

import org.bytedeco.javacpp.PointerPointer;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_HasGradient;
import static org.tensorflow.internal.c_api.global.tensorflow.TFJ_RegisterCustomGradient;

// WARNING: Gradient registry in native library is stateful across all tests
public class GradientTest {

@Test
public void testExistingGradientCheck() {
assertTrue(TFJ_HasGradient("Cast"));
}

@Test
public void testNonExistingGradientCheck() {
assertFalse(TFJ_HasGradient("NthElement"));
}

@Test
public void testNonExistingOpGradientCheck() {
assertFalse(TFJ_HasGradient("IDontExists"));
}

@Test
public void registerCustomGradientAdapter() {
assertTrue(TFJ_RegisterCustomGradient("Merge", new TFJ_GradFuncAdapter()));
}

@Test
public void registerCustomGradientAdapterFailedIfGradFuncAlreadyRegistered() {
assertFalse(TFJ_RegisterCustomGradient("Add", new TFJ_GradFuncAdapter()));
}
}
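
Note the stateful-registry warning above: registrations are process-wide, and with this commit a duplicate can no longer replace an earlier one. The same boolean surfaces through the public API used in the graph test earlier, so callers can detect the duplicate case; a sketch (the return-value handling is the point, the gradient body just mirrors the test):

    boolean registered = TensorFlow.registerCustomGradient(
        Merge.Inputs.class,
        (tf, op, gradInputs) ->
            gradInputs.stream().map(i -> tf.constant(-10)).collect(Collectors.toList()));
    if (!registered) {
      // A gradient function already exists for the Merge op type in this process.
    }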
