From 464d3091d94ce9e79930e953e99383ff99b5f9b1 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 25 Sep 2024 04:21:36 -0700 Subject: [PATCH] PR #17579: Algebraic simplifier: mark iota non-negative. Imported from GitHub PR https://github.com/openxla/xla/pull/17579 Copybara import of the project: -- 02c09a8dd5bb62ffd3729a23813a0e66f672a5a3 by Ilia Sergachev : Algebraic simplifier: mark iota non-negative. -- 4735edc2bac278ea1e87035f128a2f5d0f2a7a59 by Ilia Sergachev : Fix unrelated clang-format issues to make CI happy Merging this change closes #17579 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17579 from openxla:iota_non_neg 4735edc2bac278ea1e87035f128a2f5d0f2a7a59 PiperOrigin-RevId: 678640567 --- xla/service/algebraic_simplifier.cc | 3 ++- xla/service/algebraic_simplifier_test.cc | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index f54864220e2e6..ff9a2f688cc87 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -530,7 +530,8 @@ bool AlgebraicSimplifierVisitor::IsNonNegative( return hlo->operand(0) == hlo->operand(1); } case HloOpcode::kAbs: - case HloOpcode::kExp: { + case HloOpcode::kExp: + case HloOpcode::kIota: { return true; } case HloOpcode::kBroadcast: { diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index 1af8721066b8b..ea67d07a14196 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -10561,6 +10561,18 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationSelMaxBcast) { m::Broadcast(m::ConstantScalar()))))); } +TEST_F(AlgebraicSimplifierTest, AbsEliminationIota) { + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(R"( + e { + i = s32[3,2] iota(), iota_dimension=0 + ROOT a = s32[3,2] abs(i) + } + )")); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Iota())); +} + TEST_F(AlgebraicSimplifierTest, SimplifyRedundantBitcastConvert) { const char* kModuleStr = R"( HloModule m