Skip to content

Commit

Permalink
[bug] Fix the mapping from virtual axes to physical axes (#3159)
Browse files Browse the repository at this point in the history
* [bug] Fix the mapping from virtual axes to physical axes

* Auto Format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
strongoier and taichi-gardener authored Oct 12, 2021
1 parent 36fb11f commit bfa5c28
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 2 additions & 0 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ SNode &SNode::create_node(std::vector<Axis> axes,
new_node.physical_index_position[new_node.num_active_indices++] = i;
}
}
std::sort(new_node.physical_index_position,
new_node.physical_index_position + new_node.num_active_indices);
// infer extractors
int acc_shape = 1;
for (int i = taichi_max_num_indices - 1; i >= 0; i--) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_indices():

mapping_b = b.snode.physical_index_position()

assert mapping_b == {0: 1, 1: 0}
assert mapping_b == {0: 0, 1: 1}
# Note that b is column-major:
# the virtual first index exposed to the user comes second in memory layout.

Expand Down
10 changes: 5 additions & 5 deletions tests/python/test_struct_for_non_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def _test_2d():

@ti.kernel
def accumulate():
for i, j in x:
ti.atomic_add(sum[None], i + j * 2)
for i, k in x:
ti.atomic_add(sum[None], i + k * 2)

gt = 0
for i in range(n):
for j in range(m):
gt += i + j * 2
for k in range(n):
for i in range(m):
gt += i + k * 2

accumulate()

Expand Down
8 changes: 4 additions & 4 deletions tests/python/test_tensor_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_unordered():
blk3.place(val)

assert val.dtype == ti.i32
assert val.shape == (n, m, p)
assert val.shape == (m, p, n)
assert val.snode.parent(0) == val.snode
assert val.snode.parent() == blk3
assert val.snode.parent(1) == blk3
Expand All @@ -62,8 +62,8 @@ def test_unordered():
ti.get_runtime().materialize()
assert blk1 in ti.FieldsBuilder.finalized_roots()[0].get_children()

expected_str = f'ti.root => dense {[n]} => dense {[n, m]}' \
f' => dense {[n, m, p]} => place {[n, m, p]}'
expected_str = f'ti.root => dense {[n]} => dense {[m, n]}' \
f' => dense {[m, p, n]} => place {[m, p, n]}'
assert str(val.snode) == expected_str


Expand All @@ -80,7 +80,7 @@ def test_unordered_matrix():
blk3 = blk2.dense(ti.j, p)
blk3.place(val)

assert val.shape == (n, m, p)
assert val.shape == (m, p, n)
assert val.dtype == ti.i32
assert val.snode.parent(0) == val.snode
assert val.snode.parent() == blk3
Expand Down

0 comments on commit bfa5c28

Please sign in to comment.