[Onnx] Add Adagrad (#9001)
* adagrad impl

* passing tests

* docstring

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
AndrewZhaoLuo and Andrew Zhao Luo authored Sep 14, 2021
1 parent 80c8f35 commit 1b99adc
Showing 2 changed files with 52 additions and 3 deletions.
53 changes: 52 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -3581,6 +3581,56 @@ def _impl_v13(cls, inputs, attr, params):
        return loss


class Adagrad(OnnxOpConverter):
    """Operator converter for the Adagrad op."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        decay_factor = attr.get("decay_factor", 0.0)
        epsilon = attr.get("epsilon", 0.0)
        norm_coefficient = attr.get("norm_coefficient", 0.0)

        R = inputs[0]
        T = inputs[1]

        # Convert attributes to constants with the same dtype as the tensors being updated.
        dtype_inputs = infer_type(inputs[3]).checked_type.dtype
        decay_factor = relay.const(decay_factor, dtype=dtype_inputs)
        epsilon = relay.const(epsilon, dtype=dtype_inputs)
        norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs)
        T = relay.cast_like(T, inputs[3])

        assert (
            len(inputs) - 2
        ) % 3 == 0, f"Expect triplets for remaining inputs, found {len(inputs) - 2}"

        # Remaining inputs are:
        # [x_1, x_2, ..., x_1_gradient, x_2_gradient, ..., x_1_sq_g, x_2_sq_g, ...]
        num_input_tensors = (len(inputs) - 2) // 3
        output_tensors = []
        output_accumulated_squared_gradients = []
        for i in range(num_input_tensors):
            x = inputs[i + 2]
            gradient = inputs[i + 2 + num_input_tensors]
            accumulated_squared_gradient = inputs[i + 2 + 2 * num_input_tensors]

            # Decay the learning rate with the training step count, regularize the
            # gradient, and accumulate its square to form the adaptive denominator.
            r = R / (relay.const(1.0, dtype=dtype_inputs) + T * decay_factor)
            g_regularized = norm_coefficient * x + gradient
            new_accumulated_squared_gradient = (
                accumulated_squared_gradient + g_regularized * g_regularized
            )
            h_adaptive = relay.sqrt(new_accumulated_squared_gradient) + epsilon

            x_new = x - r * g_regularized / h_adaptive

            output_tensors.append(x_new)
            output_accumulated_squared_gradients.append(new_accumulated_squared_gradient)

        # Concatenate the lists; the accumulated squared gradients come after the
        # updated tensors, matching the ONNX output ordering.
        result = output_tensors + output_accumulated_squared_gradients
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))
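
For reference, the math above mirrors the update rule of the ONNX Adagrad training op. Below is a minimal NumPy sketch of a single step (the helper name adagrad_step is hypothetical, not part of this commit), using the same attribute defaults as the converter:

import numpy as np

def adagrad_step(r, t, x, g, h, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
    """One Adagrad update for a single tensor (hypothetical reference helper)."""
    # Decay the base learning rate r by the training step count t.
    r_t = r / (1.0 + t * decay_factor)
    # Regularize the gradient.
    g_reg = norm_coefficient * x + g
    # Accumulate squared gradients and form the adaptive denominator.
    h_new = h + g_reg * g_reg
    h_adaptive = np.sqrt(h_new) + epsilon
    # Parameter update.
    x_new = x - r_t * g_reg / h_adaptive
    return x_new, h_new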


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -3764,8 +3814,9 @@ def _get_convert_map(opset):
"ConvInteger": ConvInteger.get_converter(opset),
# Random number generation.
"RandomUniform": RandomUniform.get_converter(opset),
# Loss functions
# Loss functions / training
"NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
"Adagrad": Adagrad.get_converter(opset),
}
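
With the converter registered, an ONNX graph containing an Adagrad node should import through the regular frontend entry point. A hedged sketch, assuming the node lives in the ai.onnx.preview.training domain as in the ONNX spec; the names and shapes below are illustrative, not from this commit:

from onnx import TensorProto, helper

from tvm import relay

# One Adagrad node updating a single tensor X, given gradient G and
# accumulated squared gradient H; R is the learning rate, T the step count.
node = helper.make_node(
    "Adagrad",
    inputs=["R", "T", "X", "G", "H"],
    outputs=["X_new", "H_new"],
    domain="ai.onnx.preview.training",
)
graph = helper.make_graph(
    [node],
    "adagrad_example",
    [
        helper.make_tensor_value_info("R", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("T", TensorProto.INT64, []),
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("G", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("H", TensorProto.FLOAT, [3]),
    ],
    [
        helper.make_tensor_value_info("X_new", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("H_new", TensorProto.FLOAT, [3]),
    ],
)
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 14),
        helper.make_opsetid("ai.onnx.preview.training", 1),
    ],
)
mod, params = relay.frontend.from_onnx(model)

The two outputs per tensor (updated value, then updated accumulator) line up with the TupleWrapper the converter returns.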


2 changes: 0 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
@@ -4711,8 +4711,6 @@ def verify_eyelike(indata):
    )

unsupported_onnx_tests = [
    "test_adagrad",
    "test_adagrad_multiple",
    "test_adam",
    "test_adam_multiple",
    "test_cast_BFLOAT16_to_FLOAT",
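With the two adagrad entries removed from the skip list, the corresponding ONNX backend node tests are expected to run as part of this suite; an invocation along the lines of pytest tests/python/frontend/onnx/test_forward.py -k adagrad (exact form depends on the local setup) should exercise them.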
