Skip to content

Commit

Permalink
#4405: Update logit implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Mar 22, 2024
1 parent 26f4140 commit 6738e44
Showing 1 changed file with 17 additions and 41 deletions.
58 changes: 17 additions & 41 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,49 +830,25 @@ Tensor addcdiv(

// logit(input, eps)=log(input / 1 - input)
Tensor _logit(const Tensor& input_a, float eps, const MemoryConfig& output_mem_config) {
Tensor t_eps = mk_filled_tensor_like(input_a, eps, output_mem_config);
Tensor t_one = ones_like(input_a, output_mem_config);
Tensor t_inf = mul_unary(t_one, std::numeric_limits<float>::infinity(), output_mem_config);
Tensor partial_output(input_a);
{
Tensor result(input_a);
{
Tensor output(input_a);
{
Tensor mul_input(input_a);
{
Tensor sub_input(input_a);
{
Tensor neg_input = neg(input_a, output_mem_config);
sub_input = add_unary(neg_input, 1.0f, output_mem_config);
}
mul_input = mul(input_a, recip(sub_input, output_mem_config), std::nullopt, output_mem_config);
}

Tensor mul_eps(input_a);
{
Tensor sub_eps(input_a);
{
Tensor neg_eps = neg(t_eps, output_mem_config);
sub_eps = add_unary(neg_eps, 1.0f, output_mem_config);
}
mul_eps = mul(t_eps, recip(sub_eps, output_mem_config), std::nullopt, output_mem_config);
}
{
Tensor ia_lt_eps = lt(input_a, t_eps, std::nullopt, output_mem_config);
output = where(ia_lt_eps, mul_eps, mul_input, output_mem_config);
}
}
result = log(output, output_mem_config);
}
{
Tensor in_eq_one = eq(input_a, t_one, std::nullopt, output_mem_config);
partial_output = where(in_eq_one, t_inf, result, output_mem_config);
}
}
Tensor logit_input = input_a;
Tensor t_eps = full_like(input_a, eps, output_mem_config);
Tensor t1m_eps = full_like(input_a, (1 - eps), output_mem_config);
Tensor xlt_eps = lt(input_a, t_eps, std::nullopt, output_mem_config);
Tensor in_range = logical_and(lte(t_eps, input_a, std::nullopt, output_mem_config),
lte(input_a, t1m_eps, std::nullopt, output_mem_config), std::nullopt, output_mem_config);
logit_input = where(eq_unary(xlt_eps, 1.0, output_mem_config), t_eps,
where(eq_unary(in_range, 1.0, output_mem_config), input_a, t1m_eps, output_mem_config), output_mem_config);
xlt_eps.deallocate();
in_range.deallocate();
Tensor linput_m1 = rsub(logit_input, 1.0, output_mem_config);
Tensor partial_output = log(mul(logit_input, recip(linput_m1, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
linput_m1.deallocate();
float t_inf = std::numeric_limits<float>::infinity();
float t_nan = std::nanf("");
partial_output = where(eq_unary(input_a, 1.0, output_mem_config), t_inf, partial_output);
Tensor final_result(input_a);
{
float t_nan = std::nanf("");
Tensor t_one = ones_like(input_a, output_mem_config);
Tensor eps_gt_one = gt(t_eps, t_one, std::nullopt, output_mem_config);
Tensor eps_eq_one = eq(t_eps, t_one, std::nullopt, output_mem_config);
final_result =
Expand Down

0 comments on commit 6738e44

Please sign in to comment.