diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 23bfb955a5c4..48089d164a2f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3631,6 +3631,73 @@ def _impl_v1(cls, inputs, attr, params):
         return _expr.TupleWrapper(_expr.Tuple(result), len(result))
 
 
+class Adam(OnnxOpConverter):
+    """Operator converter for Adam op."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = attr.get("alpha", 0.9)
+        beta = attr.get("beta", 0.999)
+
+        # Note: in the docs the epsilon default is 0.0, but in the tests it is set to 1e-2:
+        # https://git.io/Ju5C4
+        epsilon = attr.get("epsilon", 1e-2)
+        norm_coefficient = attr.get("norm_coefficient", 0.0)
+        norm_coefficient_post = attr.get("norm_coefficient_post", 0.0)
+
+        R = inputs[0]
+        T = inputs[1]
+
+        assert (
+            len(inputs) - 2
+        ) % 4 == 0, f"Expect groups of 4 for the remaining inputs, found {len(inputs) - 2}"
+
+        # convert attributes to constants with the proper dtype
+        dtype_inputs = infer_type(inputs[3]).checked_type.dtype
+        inverse_alpha = relay.const(1 - alpha, dtype=dtype_inputs)
+        alpha = relay.const(alpha, dtype=dtype_inputs)
+        inverse_beta = relay.const(1 - beta, dtype=dtype_inputs)
+        beta = relay.const(beta, dtype=dtype_inputs)
+        epsilon = relay.const(epsilon, dtype=dtype_inputs)
+        norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs)
+        norm_coefficient_post = relay.const(norm_coefficient_post, dtype=dtype_inputs)
+        one = relay.const(1, dtype=dtype_inputs)
+        T = relay.cast_like(T, inputs[3])
+
+        # The remaining inputs are:
+        # [x_1, x_2 ..., x_1_grad, x_2_grad, ... x_1_g_accum, x_2_g_accum..., x_1_g_sq_accum, ...]
+        num_input_tensors = (len(inputs) - 2) // 4
+        output_tensors = []
+        output_accumulated_gradients = []
+        output_accumulated_squared_gradients = []
+        for i in range(num_input_tensors):
+            x = inputs[i + 2]
+            g = inputs[i + 2 + num_input_tensors]
+            v = inputs[i + 2 + 2 * num_input_tensors]
+            h = inputs[i + 2 + 3 * num_input_tensors]
+
+            g_regularized = norm_coefficient * x + g
+            v_new = alpha * v + inverse_alpha * g_regularized
+            h_new = beta * h + inverse_beta * g_regularized * g_regularized
+            h_sqrt = relay.sqrt(h_new) + epsilon
+
+            true_branch = R * relay.sqrt(one - relay.power(beta, T)) / (one - relay.power(alpha, T))
+            R_adjusted = relay.If(T > relay.const(0, dtype=dtype_inputs), true_branch, R)
+
+            x_new = x - R_adjusted * (v_new / h_sqrt)
+            x_result = (one - norm_coefficient_post) * x_new
+
+            output_tensors.append(x_result)
+            output_accumulated_gradients.append(v_new)
+            output_accumulated_squared_gradients.append(h_new)
+
+        # concatenate the lists to get the final result
+        result = (
+            output_tensors + output_accumulated_gradients + output_accumulated_squared_gradients
+        )
+        return _expr.TupleWrapper(_expr.Tuple(result), len(result))
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -3817,6 +3884,7 @@ def _get_convert_map(opset):
         # Loss functions / training
         "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
         "Adagrad": Adagrad.get_converter(opset),
+        "Adam": Adam.get_converter(opset),
     }
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index d283cbf3624c..d9f2e97f8247 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4711,8 +4711,6 @@ def verify_eyelike(indata):
     )
 
 unsupported_onnx_tests = [
-    "test_adam",
-    "test_adam_multiple",
     "test_cast_BFLOAT16_to_FLOAT",
    "test_cast_DOUBLE_to_FLOAT16",
    "test_cast_FLOAT_to_BFLOAT16",
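
For reference, the converter unrolls the per-tensor update rule from the ONNX Adam spec into Relay expressions. Below is a minimal NumPy sketch of one step of that rule, mirroring the math in the converter above; `adam_step` and its argument names are illustrative, not part of this patch:

```python
import numpy as np

def adam_step(r, t, x, g, v, h, norm_coefficient=0.0, alpha=0.9, beta=0.999,
              epsilon=1e-2, norm_coefficient_post=0.0):
    """One Adam update, mirroring the Relay expressions the converter emits."""
    g_reg = norm_coefficient * x + g            # L2-regularized gradient
    v_new = alpha * v + (1 - alpha) * g_reg     # accumulated gradient (1st moment)
    h_new = beta * h + (1 - beta) * g_reg ** 2  # accumulated squared gradient (2nd moment)
    h_sqrt = np.sqrt(h_new) + epsilon
    # bias-correct the learning rate only after the first step (T > 0)
    r_adj = r * np.sqrt(1 - beta ** t) / (1 - alpha ** t) if t > 0 else r
    x_new = x - r_adj * (v_new / h_sqrt)
    return (1 - norm_coefficient_post) * x_new, v_new, h_new
```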
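
To exercise the new converter end to end, one way is to build a single-node ONNX graph in the `ai.onnx.preview.training` domain (where Adam lives) and import it with `relay.frontend.from_onnx`. A sketch assuming a single optimized tensor of shape `[3]`; all tensor names here are illustrative:

```python
import onnx
from onnx import helper, TensorProto
from tvm import relay

# Adam takes learning rate R, step count T, then (X, G, V, H) per tensor.
node = helper.make_node(
    "Adam",
    inputs=["R", "T", "X", "G", "V", "H"],
    outputs=["X_new", "V_new", "H_new"],
    domain="ai.onnx.preview.training",
)
graph = helper.make_graph(
    [node],
    "adam_test",
    inputs=[
        helper.make_tensor_value_info("R", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("T", TensorProto.INT64, []),
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("G", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("V", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("H", TensorProto.FLOAT, [3]),
    ],
    outputs=[
        helper.make_tensor_value_info("X_new", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("V_new", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("H_new", TensorProto.FLOAT, [3]),
    ],
)
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 14),
        helper.make_opsetid("ai.onnx.preview.training", 1),
    ],
)
mod, params = relay.frontend.from_onnx(model)
```

In practice the same path is covered by the ONNX conformance suite: the `test_adam` and `test_adam_multiple` cases removed from `unsupported_onnx_tests` above now run against this converter.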