diff --git a/xla/service/algebraic_simplifier.cc b/xla/service/algebraic_simplifier.cc index f54864220e2e6..ff9a2f688cc87 100644 --- a/xla/service/algebraic_simplifier.cc +++ b/xla/service/algebraic_simplifier.cc @@ -530,7 +530,8 @@ bool AlgebraicSimplifierVisitor::IsNonNegative( return hlo->operand(0) == hlo->operand(1); } case HloOpcode::kAbs: - case HloOpcode::kExp: { + case HloOpcode::kExp: + case HloOpcode::kIota: { return true; } case HloOpcode::kBroadcast: { diff --git a/xla/service/algebraic_simplifier_test.cc b/xla/service/algebraic_simplifier_test.cc index 1af8721066b8b..ea67d07a14196 100644 --- a/xla/service/algebraic_simplifier_test.cc +++ b/xla/service/algebraic_simplifier_test.cc @@ -10561,6 +10561,18 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationSelMaxBcast) { m::Broadcast(m::ConstantScalar()))))); } +TEST_F(AlgebraicSimplifierTest, AbsEliminationIota) { + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(R"( + e { + i = s32[3,2] iota(), iota_dimension=0 + ROOT a = s32[3,2] abs(i) + } + )")); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Iota())); +} + TEST_F(AlgebraicSimplifierTest, SimplifyRedundantBitcastConvert) { const char* kModuleStr = R"( HloModule m