From d174868d8c5443bcb2e3cda4a976bb317eb9a93a Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 11 Jun 2021 16:01:41 -0400 Subject: [PATCH 1/3] work on grads --- python/tvm/relay/op/_tensor_grad.py | 62 ++++++++++++++++++++-- tests/python/relay/test_op_grad_level10.py | 34 +++++++++++- 2 files changed, 89 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index d5b891088933..c2d0089b31b0 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-argument -"""Backend compiler related feature registration""" +"""Gradient definitions for Relay operators""" from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.error import OpError @@ -63,6 +63,7 @@ strided_set, arange, scatter_nd, + strided_set, ) @@ -527,10 +528,7 @@ def softmax_grad(orig, grad): @register_gradient("nn.log_softmax") def log_softmax_grad(orig, grad): """Gradient of log_softmax""" - x = orig.args[0] - sm = _nn.softmax(x, axis=orig.attrs.axis) - grad = grad / sm - return softmax_grad(sm, grad) + return [grad - _sum(grad, axis=orig.attrs.axis, keepdims=True) * exp(orig)] @register_gradient("nn.bias_add") @@ -596,6 +594,12 @@ def cast_grad(orig, grad): return [cast_like(grad, x)] +@register_gradient("cast_like") +def cast_like_grad(orig, grad): + x, like = orig.args + return [cast_like(grad, x), zeros_like(like)] + + @register_gradient("nn.batch_flatten") def batch_flatten_grad(orig, grad): """Returns grad reshaped to data dims""" @@ -873,3 +877,51 @@ def less_equal_grad(orig, grad): Returns the gradient of less_equal. """ return [zeros_like(orig.args[0]), zeros_like(orig.args[1])] + + +@register_gradient("not_equal") +def not_equal_grad(orig, grad): + """ + Returns the gradient of not_equal (just zeros). + """ + return [zeros_like(orig.args[0]), zeros_like(orig.args[1])] + + +# TODO: test +@register_gradient("strided_slice") +def strided_slice_grad(orig, grad): + """ + Returns the gradient of strided_slice, which is equal to grad where the + input was sliced and zero elsewhere. + """ + x = orig.args[0] + begin = get_const_tuple(orig.attrs.begin) + end = get_const_tuple(orig.attrs.end) + strides = get_const_tuple(orig.attrs.strides) + if orig.attrs.slice_mode == "size": + # convert sizes to ending indices + end = list(end) + for i, (start, size) in enumerate(zip(begin, end)): + if size == -1: + end[i] = x.checked_type.shape[i] + else: + end[i] = start + size + else: + assert orig.attrs.slice_mode == "end" + return [strided_set(zeros_like(x), grad, begin, end, strides)] + + +@register_gradient("one_hot") +def one_hot_grad(orig, grad): + """ + Returns the gradient of one_hot, which is the sum of grad at on and off + indices for on_value and off_value respectively. 
+ """ + indices, on_value, off_value = orig.args + + g_zeros = zeros_like(grad) + on_mask = equal(orig, on_value) + grad_on = collapse_sum_like(where(on_mask, grad, g_zeros), on_value) + grad_off = collapse_sum_like(where(on_mask, g_zeros, grad), off_value) + + return [zeros_like(indices), cast_like(grad_on, on_value), cast_like(grad_off, on_value)] diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index 4a6ffb933881..1cf6de43eaca 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np -from tvm import relay -from tvm.relay.testing import check_grad +from tvm import relay, transform +from tvm.relay.testing import check_grad, _np_randn_from_type, run_infer_type, run_opt_pass +from tvm.ir.instrument import pass_instrument def test_cross_entropy_grad(): @@ -72,5 +74,33 @@ def test_reverse_reshape_grad(): check_grad(relay.Function([x], relay.op.reverse_reshape(x, (-1, 0)))) +def test_one_hot_grad(): + indices_shape = (3, 4) + depth = 5 + axis = -1 + indices_dtype = "int32" + dtype = "float32" + inputs = [ + np.random.randint(depth, size=indices_shape, dtype=indices_dtype), + np.array(np.random.randn() * 1e-5).astype(dtype), + np.array(np.random.randn() * 1e-5).astype(dtype), + ] + test_inputs = inputs[1:] + + indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype) + on_val = relay.var("on_val", shape=tuple(), dtype=dtype) + off_val = relay.var("off_val", shape=tuple(), dtype=dtype) + y = relay.one_hot(indices, on_val, off_val, depth, axis, dtype) + f = run_infer_type(relay.Function([indices, on_val, off_val], y)) + + @pass_instrument + class Bruh: + def run_before_pass(self, mod, info): + print("bruh", info.name) + + with transform.PassContext(instruments=[Bruh()]): + check_grad(f, inputs=inputs, test_inputs=test_inputs, mode="first_order") + + if __name__ == "__main__": pytest.main([__file__]) From c5c71fba547fbcc4262387c7f8b277b17ef27c45 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Mon, 21 Jun 2021 22:59:06 -0700 Subject: [PATCH 2/3] add tests --- python/tvm/relay/op/_tensor_grad.py | 1 - tests/python/relay/test_op_grad_level10.py | 41 ++++++++++------------ tests/python/relay/test_op_grad_level3.py | 7 ++++ tests/python/relay/test_op_grad_level4.py | 27 ++++++++++++++ 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index c2d0089b31b0..e2fbbeec5777 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -63,7 +63,6 @@ strided_set, arange, scatter_nd, - strided_set, ) diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index 1cf6de43eaca..d4553fc5ea4d 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -78,28 +78,25 @@ def test_one_hot_grad(): indices_shape = (3, 4) depth = 5 axis = -1 - indices_dtype = "int32" - dtype = "float32" - inputs = [ - np.random.randint(depth, size=indices_shape, dtype=indices_dtype), - np.array(np.random.randn() * 1e-5).astype(dtype), - np.array(np.random.randn() * 1e-5).astype(dtype), - ] - test_inputs = inputs[1:] - - indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype) - on_val = relay.var("on_val", shape=tuple(), dtype=dtype) - off_val = relay.var("off_val", 
shape=tuple(), dtype=dtype) - y = relay.one_hot(indices, on_val, off_val, depth, axis, dtype) - f = run_infer_type(relay.Function([indices, on_val, off_val], y)) - - @pass_instrument - class Bruh: - def run_before_pass(self, mod, info): - print("bruh", info.name) - - with transform.PassContext(instruments=[Bruh()]): - check_grad(f, inputs=inputs, test_inputs=test_inputs, mode="first_order") + indices_dtype = "int64" + dtype = "float64" + + for indices_dtype in ["int32", "int64"]: + for val_dtype in ["float32", "float64"]: + inputs = [ + np.random.randint(depth, size=indices_shape, dtype=indices_dtype), + np.array(np.random.randn() * 1e-5).astype(dtype), + np.array(np.random.randn() * 1e-5).astype(dtype), + ] + test_inputs = inputs[1:] + + indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype) + on_val = relay.var("on_val", shape=tuple(), dtype=dtype) + off_val = relay.var("off_val", shape=tuple(), dtype=dtype) + y = relay.one_hot(indices, on_val, off_val, depth, axis, dtype) + f = relay.Function([indices, on_val, off_val], y) + + check_grad(f, inputs=inputs, test_inputs=test_inputs) if __name__ == "__main__": diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index 821e10f97e21..ae3fc2641a25 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -69,6 +69,13 @@ def test_cast_grad(): check_grad(fwd_func) +def test_cast_like_grad(): + data = relay.var("data", shape=(10, 4), dtype="float32") + like = relay.var("like", shape=(1,), dtype="float64") + fwd_func = relay.Function([data, like], relay.cast_like(data, like)) + check_grad(fwd_func) + + def test_copy_grad(): data = relay.var("data", relay.TensorType((10, 4), "float64")) fwd_func = relay.Function([data], relay.copy(data)) diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index 0f73e89c94ad..995e81f8fe47 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -86,5 +86,32 @@ def test_less_equal_grad(): check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6) +def test_not_equal_grad(): + x_type = relay.TensorType((2, 3, 4), "float32") + y_type = relay.TensorType((3, 1), "float32") + # We need to generate inputs far apart to get correct numerical gradients + # (otherwise adding epsilon may change comparison result). The gradient + # should always be zero for both inputs. 
+    inputs = [
+        np.random.choice([-1, 1], size=x_type.concrete_shape).astype(x_type.dtype),
+        np.random.choice([-2, 2], size=y_type.concrete_shape).astype(y_type.dtype),
+    ]
+
+    x = relay.var("x", type_annotation=x_type)
+    y = relay.var("y", type_annotation=y_type)
+    fwd_func = relay.Function([x, y], relay.not_equal(x, y))
+    check_grad(fwd_func, inputs=inputs, test_inputs=inputs, eps=1e-6)
+
+
+def test_strided_slice_grad():
+    def check(sh, dtype, begin, end, strides, slice_mode):
+        x = relay.var("x", shape=sh, dtype=dtype)
+        check_grad(relay.Function([x], relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)))
+
+    check((2, 3, 4), "float32", (0, 1, 0), (-1, -1, 1), (1, 1, 1), "size")
+    check((2, 3, 4), "float32", (0, 1, 0), (2, 3, 1), (1, 1, 1), "end")
+    check((2, 3, 4), "float32", (0, 0, 0), (-1, -1, -1), (1, 1, 2), "size")
+
+
 if __name__ == "__main__":
     pytest.main()

From 88d84991eb7ebba225234d993a912926dfc2b4cb Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Tue, 22 Jun 2021 11:21:57 -0700
Subject: [PATCH 3/3] finish up

---
 python/tvm/relay/op/_tensor_grad.py        | 13 +++++++------
 tests/python/relay/test_op_grad_level10.py | 19 ++++++++-----------
 tests/python/relay/test_op_grad_level4.py  |  8 +++++++-
 3 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index e2fbbeec5777..09b1435aac0f 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -886,25 +886,26 @@ def not_equal_grad(orig, grad):
     return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
 
 
-# TODO: test
 @register_gradient("strided_slice")
 def strided_slice_grad(orig, grad):
     """
     Returns the gradient of strided_slice, which is equal to grad where the
     input was sliced and zero elsewhere.
""" + assert orig.attrs.axes is None, "grad for strided_slice with axes is not yet supported" x = orig.args[0] begin = get_const_tuple(orig.attrs.begin) end = get_const_tuple(orig.attrs.end) strides = get_const_tuple(orig.attrs.strides) if orig.attrs.slice_mode == "size": - # convert sizes to ending indices + # convert sizes to ending indices and ignore strides end = list(end) for i, (start, size) in enumerate(zip(begin, end)): if size == -1: - end[i] = x.checked_type.shape[i] + end[i] = int(x.checked_type.shape[i]) else: end[i] = start + size + strides = None else: assert orig.attrs.slice_mode == "end" return [strided_set(zeros_like(x), grad, begin, end, strides)] @@ -920,7 +921,7 @@ def one_hot_grad(orig, grad): g_zeros = zeros_like(grad) on_mask = equal(orig, on_value) - grad_on = collapse_sum_like(where(on_mask, grad, g_zeros), on_value) - grad_off = collapse_sum_like(where(on_mask, g_zeros, grad), off_value) + grad_on = _sum(where(on_mask, grad, g_zeros)) + grad_off = _sum(where(on_mask, g_zeros, grad)) - return [zeros_like(indices), cast_like(grad_on, on_value), cast_like(grad_off, on_value)] + return [zeros_like(indices), cast_like(grad_on, on_value), cast_like(grad_off, off_value)] diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index d4553fc5ea4d..e2145f77b366 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -17,9 +17,8 @@ import pytest import numpy as np -from tvm import relay, transform -from tvm.relay.testing import check_grad, _np_randn_from_type, run_infer_type, run_opt_pass -from tvm.ir.instrument import pass_instrument +from tvm import relay +from tvm.relay.testing import check_grad def test_cross_entropy_grad(): @@ -78,22 +77,20 @@ def test_one_hot_grad(): indices_shape = (3, 4) depth = 5 axis = -1 - indices_dtype = "int64" - dtype = "float64" - + for indices_dtype in ["int32", "int64"]: for val_dtype in ["float32", "float64"]: inputs = [ np.random.randint(depth, size=indices_shape, dtype=indices_dtype), - np.array(np.random.randn() * 1e-5).astype(dtype), - np.array(np.random.randn() * 1e-5).astype(dtype), + np.array(np.random.randn() * 1e-5).astype(val_dtype), + np.array(np.random.randn() * 1e-5).astype(val_dtype), ] test_inputs = inputs[1:] indices = relay.var("indices", shape=indices_shape, dtype=indices_dtype) - on_val = relay.var("on_val", shape=tuple(), dtype=dtype) - off_val = relay.var("off_val", shape=tuple(), dtype=dtype) - y = relay.one_hot(indices, on_val, off_val, depth, axis, dtype) + on_val = relay.var("on_val", shape=tuple(), dtype=val_dtype) + off_val = relay.var("off_val", shape=tuple(), dtype=val_dtype) + y = relay.one_hot(indices, on_val, off_val, depth, axis, val_dtype) f = relay.Function([indices, on_val, off_val], y) check_grad(f, inputs=inputs, test_inputs=test_inputs) diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index 995e81f8fe47..17d30cacac41 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -106,11 +106,17 @@ def test_not_equal_grad(): def test_strided_slice_grad(): def check(sh, dtype, begin, end, strides, slice_mode): x = relay.var("x", shape=sh, dtype=dtype) - check_grad(relay.Function([x], relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode))) + f = relay.Function( + [x], + relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode), + ) + check_grad(f) check((2, 3, 
4), "float32", (0, 1, 0), (-1, -1, 1), (1, 1, 1), "size") check((2, 3, 4), "float32", (0, 1, 0), (2, 3, 1), (1, 1, 1), "end") + # check that strides are properly ignored when using "size" mode check((2, 3, 4), "float32", (0, 0, 0), (-1, -1, -1), (1, 1, 2), "size") + check((2, 3, 4), "float32", (0, 0, 0), (2, 3, 4), (1, 1, 2), "end") if __name__ == "__main__":
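
A note on the log_softmax rewrite in PATCH 1/3: for y = log_softmax(x), the vector-Jacobian product is dx = dy - sum(dy, axis, keepdims=True) * softmax(x), and since orig is already log_softmax(x), softmax(x) is simply exp(orig); the new form also drops the division by softmax used in the old implementation. The following standalone numpy sketch (illustrative only, not part of the patches; all names here are made up) checks that rule against central finite differences.

import numpy as np


def log_softmax(x, axis=-1):
    # numerically stable log_softmax
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def log_softmax_vjp(x, dy, axis=-1):
    # dx = dy - sum(dy, axis, keepdims) * exp(y), with y = log_softmax(x)
    y = log_softmax(x, axis)
    return dy - dy.sum(axis=axis, keepdims=True) * np.exp(y)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5))
dy = rng.standard_normal((2, 5))
analytic = log_softmax_vjp(x, dy)

eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    numeric[idx] = ((log_softmax(xp) - log_softmax(xm)) * dy).sum() / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)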
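The strided_slice gradient builds the input gradient with strided_set(zeros_like(x), grad, begin, end, strides): the incoming gradient is scattered back into a zero tensor over the sliced region. In "size" mode the end attribute holds extents (-1 meaning "to the end of that axis") and, after PATCH 3/3, strides are dropped entirely. A small numpy model of the same rule (illustrative names, not part of the patches):

import numpy as np


def strided_slice_vjp(x_shape, begin, end, strides, grad, slice_mode="end"):
    if slice_mode == "size":
        # "size" mode: `end` holds sizes, -1 means "to the end"; strides are ignored
        end = [x_shape[i] if size == -1 else start + size
               for i, (start, size) in enumerate(zip(begin, end))]
        strides = (1,) * len(begin)
    # scatter the incoming gradient into a zero tensor over the sliced region
    dx = np.zeros(x_shape, dtype=grad.dtype)
    dx[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))] = grad
    return dx


grad = np.ones((2, 2, 1), dtype="float32")  # gradient flowing back from x[0:2, 1:3, 0:1]
dx = strided_slice_vjp((2, 3, 4), (0, 1, 0), (2, 3, 1), (1, 1, 1), grad)
print(dx.sum())  # 4.0 -- only the sliced positions receive gradient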
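The one_hot gradient finds the "on" positions by comparing the forward output against on_value (which assumes on_value and off_value lead to distinguishable outputs, as in the test where both are drawn as small distinct randoms), sums the incoming gradient over those positions for the on_value gradient and over the remaining positions for the off_value gradient, and gives the integer indices a zero gradient. A numpy sketch of that rule for the default axis=-1 (illustrative only, not part of the patches):

import numpy as np


def one_hot_vjp(indices, depth, grad):
    # mask of the "on" positions, i.e. where the forward output equals on_value
    on_mask = np.arange(depth) == indices[..., None]
    grad_on = grad[on_mask].sum()    # gradient w.r.t. the scalar on_value
    grad_off = grad[~on_mask].sum()  # gradient w.r.t. the scalar off_value
    return np.zeros_like(indices), grad_on, grad_off


indices = np.array([2, 0, 1])
grad = np.arange(9, dtype="float64").reshape(3, 3)
print(one_hot_vjp(indices, 3, grad))  # (array([0, 0, 0]), 12.0, 24.0)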