From 73696fe1bcb0d4c7ecc37c496b263526e3346796 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 9 Aug 2021 21:52:29 +0000 Subject: [PATCH] [AutoScheduler] Fix FLOPS estimation --- src/auto_scheduler/compute_dag.cc | 12 ++++++++---- .../unittest/test_auto_scheduler_compute_dag.py | 5 +++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index abbcba234848..e82830fa4d06 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -611,10 +611,14 @@ class FlopEstimator : public ExprFunctor { std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); } -#define VisitBinary(Node) \ - double VisitExpr_(const Node* op) final { \ - double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \ - return base + VisitExpr(op->a) + VisitExpr(op->b); \ +// Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in FLOPS. +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { \ + double base = 1.0; \ + if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() != cur_type_code_)) { \ + base = 0.0; \ + } \ + return base + VisitExpr(op->a) + VisitExpr(op->b); \ } #define VisitUnary(Node) \ diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py index b303ef56c1d2..e394115619a4 100644 --- a/tests/python/unittest/test_auto_scheduler_compute_dag.py +++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py @@ -62,6 +62,11 @@ def test_estimate_flop(): dag = auto_scheduler.ComputeDAG([A, B, F]) assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5 + A = te.placeholder((N, N), dtype="float32", name="A") + F = te.compute((N, N), lambda i, j: te.if_then_else(A[i, j] > 0, A[i, j], 0)) + dag = auto_scheduler.ComputeDAG([A, F]) + assert abs(dag.flop_ct - N ** 2) < 0.5 + def test_stage_order(): """Test if the stage order is preserved when recovering a DAG."""