From 6738e4409a6defc51d89e9b7abba66c2ab49b905 Mon Sep 17 00:00:00 2001 From: umadevimcw Date: Tue, 19 Mar 2024 10:10:43 +0000 Subject: [PATCH] #4405: Update logit implementation --- .../op_library/composite/composite_ops.cpp | 58 ++++++------------- 1 file changed, 17 insertions(+), 41 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index 9a21726d0b7..ab2a260e511 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -830,49 +830,25 @@ Tensor addcdiv( // logit(input, eps)=log(input / 1 - input) Tensor _logit(const Tensor& input_a, float eps, const MemoryConfig& output_mem_config) { - Tensor t_eps = mk_filled_tensor_like(input_a, eps, output_mem_config); - Tensor t_one = ones_like(input_a, output_mem_config); - Tensor t_inf = mul_unary(t_one, std::numeric_limits::infinity(), output_mem_config); - Tensor partial_output(input_a); - { - Tensor result(input_a); - { - Tensor output(input_a); - { - Tensor mul_input(input_a); - { - Tensor sub_input(input_a); - { - Tensor neg_input = neg(input_a, output_mem_config); - sub_input = add_unary(neg_input, 1.0f, output_mem_config); - } - mul_input = mul(input_a, recip(sub_input, output_mem_config), std::nullopt, output_mem_config); - } - - Tensor mul_eps(input_a); - { - Tensor sub_eps(input_a); - { - Tensor neg_eps = neg(t_eps, output_mem_config); - sub_eps = add_unary(neg_eps, 1.0f, output_mem_config); - } - mul_eps = mul(t_eps, recip(sub_eps, output_mem_config), std::nullopt, output_mem_config); - } - { - Tensor ia_lt_eps = lt(input_a, t_eps, std::nullopt, output_mem_config); - output = where(ia_lt_eps, mul_eps, mul_input, output_mem_config); - } - } - result = log(output, output_mem_config); - } - { - Tensor in_eq_one = eq(input_a, t_one, std::nullopt, output_mem_config); - partial_output = where(in_eq_one, t_inf, result, output_mem_config); - } - } + Tensor logit_input = input_a; + Tensor t_eps = full_like(input_a, eps, output_mem_config); + Tensor t1m_eps = full_like(input_a, (1 - eps), output_mem_config); + Tensor xlt_eps = lt(input_a, t_eps, std::nullopt, output_mem_config); + Tensor in_range = logical_and(lte(t_eps, input_a, std::nullopt, output_mem_config), + lte(input_a, t1m_eps, std::nullopt, output_mem_config), std::nullopt, output_mem_config); + logit_input = where(eq_unary(xlt_eps, 1.0, output_mem_config), t_eps, + where(eq_unary(in_range, 1.0, output_mem_config), input_a, t1m_eps, output_mem_config), output_mem_config); + xlt_eps.deallocate(); + in_range.deallocate(); + Tensor linput_m1 = rsub(logit_input, 1.0, output_mem_config); + Tensor partial_output = log(mul(logit_input, recip(linput_m1, output_mem_config), std::nullopt, output_mem_config), output_mem_config); + linput_m1.deallocate(); + float t_inf = std::numeric_limits::infinity(); + float t_nan = std::nanf(""); + partial_output = where(eq_unary(input_a, 1.0, output_mem_config), t_inf, partial_output); Tensor final_result(input_a); { - float t_nan = std::nanf(""); + Tensor t_one = ones_like(input_a, output_mem_config); Tensor eps_gt_one = gt(t_eps, t_one, std::nullopt, output_mem_config); Tensor eps_eq_one = eq(t_eps, t_one, std::nullopt, output_mem_config); final_result =