Skip to content

Commit

Permalink
[Misc] Refactored flattend_values() to avoid potential conflicts in f…
Browse files Browse the repository at this point in the history
…lattened statements (taichi-dev#6749)

Issue: taichi-dev#5819

Overriding the flattened statement `stmt` of an `Expression` can cause
conflicts, for example:
```
@ti.kernel
def test():
    x = ti.Vector([1, 2, 3, 4])
    tmp = x + x[0] # implicit broadcast
```

In `x + x[0]`, the `x` on the lhs serves as rvalue whereas the `x` in
the `x[0]` serves as a lvalue, so the result of `flatten_rvalue()` and
`flatten_lvalue()` will override each other.

To avoid such conflicts, this PR refactored the `flatten_values()`
functions:
1. Flattened statement `stmt` of an `Expression` will only get modified
by `Expression::flatten()`, any other overriding will be forbidden.
2. `flatten_rvalue()` and `flatten_lvalue()` now returns the flattened
statement as the result. External users such as `irpass::lower_ast()`
will turn to use the returned statement.

Co-authored-by: Yi Xu <xy_xuyi@foxmail.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 340324f commit 3d9e36f
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 125 deletions.
2 changes: 1 addition & 1 deletion taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr.const_value);
emit(expr.atomic);
auto *e = expr.expr.get();
emit(e->stmt);
emit(e->get_flattened_stmt());
emit(e->attributes);
emit(e->ret_type);
expr.expr->accept(this);
Expand Down
8 changes: 7 additions & 1 deletion taichi/ir/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ class ExpressionVisitor;

// always a tree - used as rvalues
class Expression {
public:
protected:
Stmt *stmt;

public:
std::string tb;
std::map<std::string, std::string> attributes;
DataType ret_type;
Expand Down Expand Up @@ -53,6 +55,10 @@ class Expression {

virtual ~Expression() {
}

Stmt *get_flattened_stmt() const {
return stmt;
}
};

class ExprGroup {
Expand Down
Loading

0 comments on commit 3d9e36f

Please sign in to comment.