Skip to content

Commit

Permalink
[lang] Fix matrix type inference and remove _MatrixEntriesInitializer (
Browse files Browse the repository at this point in the history
…taichi-dev#6928)

Issue: taichi-dev#5819

### Brief Summary

Before this PR, matrix type inference directly takes the type of the
first element, which is problematic. This PR fixes the inference by
calculating the common type of all elements and removes the redundant
`_MatrixEntriesInitializer`.
  • Loading branch information
strongoier authored and quadpixels committed May 13, 2023
1 parent b132a57 commit a6638b8
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 60 deletions.
101 changes: 41 additions & 60 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import numbers
import warnings
from collections.abc import Iterable
Expand Down Expand Up @@ -102,25 +103,44 @@ def prop_setter(instance, value):
return cls


def _infer_entry_dt(entry):
if isinstance(entry, (int, np.integer)):
return impl.get_runtime().default_ip
if isinstance(entry, float):
return impl.get_runtime().default_fp
if isinstance(entry, expr.Expr):
dt = entry.ptr.get_ret_type()
if dt == ti_python_core.DataType_unknown:
raise TaichiTypeError(
'Element type of the matrix cannot be inferred. Please set dt instead for now.'
)
return dt
raise TaichiTypeError('Element type of the matrix is invalid.')


def _infer_array_dt(arr):
assert len(arr) > 0
return functools.reduce(ti_python_core.promoted_type,
map(_infer_entry_dt, arr))


def make_matrix(arr, dt=None):
if len(arr) == 0:
# the only usage of an empty vector is to serve as field indices
is_matrix = False
shape = [0]
dt = primitive_types.i32
else:
is_matrix = isinstance(arr[0], Iterable)
if isinstance(arr[0], Iterable): # matrix
shape = [len(arr), len(arr[0])]
arr = [elt for row in arr for elt in row]
else: # vector
shape = [len(arr)]
if dt is None:
dt = _make_entries_initializer(is_matrix).infer_dt(arr)
dt = _infer_array_dt(arr)
else:
dt = cook_dtype(dt)
if not is_matrix:
return impl.Expr(
impl.make_matrix_expr([len(arr)], dt,
[expr.Expr(elt).ptr for elt in arr]))
return impl.Expr(
impl.make_matrix_expr(
[len(arr), len(arr[0])], dt,
[expr.Expr(elt).ptr for row in arr for elt in row]))
return expr.Expr(
impl.make_matrix_expr(shape, dt, [expr.Expr(elt).ptr for elt in arr]))


def is_vector(x):
Expand Down Expand Up @@ -241,48 +261,6 @@ def _set_entries(self, value):
self[i, j] = value[i][j]


class _MatrixEntriesInitializer:
def pyscope(self, arr):
raise NotImplementedError('Override')

def _get_entry_to_infer(self, arr):
raise NotImplementedError('Override')

def infer_dt(self, arr):
entry = self._get_entry_to_infer(arr)
if isinstance(entry, (int, np.integer)):
return impl.get_runtime().default_ip
if isinstance(entry, float):
return impl.get_runtime().default_fp
if isinstance(entry, expr.Expr):
dt = entry.ptr.get_ret_type()
if dt == ti_python_core.DataType_unknown:
raise TypeError(
'Element type of the matrix cannot be inferred. Please set dt instead for now.'
)
return dt
raise Exception(
'dt required when using dynamic_index for local tensor')


def _make_entries_initializer(is_matrix: bool) -> _MatrixEntriesInitializer:
class _VecImpl(_MatrixEntriesInitializer):
def pyscope(self, arr):
return [[x] for x in arr]

def _get_entry_to_infer(self, arr):
return arr[0]

class _MatImpl(_MatrixEntriesInitializer):
def pyscope(self, arr):
return [list(row) for row in arr]

def _get_entry_to_infer(self, arr):
return arr[0][0]

return _MatImpl() if is_matrix else _VecImpl()


@_gen_swizzles
class Matrix(TaichiOperations):
"""The matrix class.
Expand Down Expand Up @@ -346,13 +324,16 @@ def __init__(self, arr, dt=None, ndim=None):
is_matrix = isinstance(arr[0],
Iterable) and not is_vector(self)
self.ndim = 2 if is_matrix else 1
initializer = _make_entries_initializer(is_matrix)
if not is_matrix and isinstance(arr[0], Iterable):
flattened = []
for row in arr:
flattened += row
arr = flattened
mat = initializer.pyscope(arr)

if is_matrix:
mat = [list(row) for row in arr]
else:
if isinstance(arr[0], Iterable):
flattened = []
for row in arr:
flattened += row
arr = flattened
mat = [[x] for x in arr]

self.n, self.m = len(mat), 1
if len(mat) > 0:
Expand Down
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,8 @@ void export_lang(py::module &m) {

py::class_<Type>(m, "Type").def("to_string", &Type::to_string);

m.def("promoted_type", promoted_type);

// Note that it is important to specify py::return_value_policy::reference for
// the factory methods, otherwise pybind11 will delete the Types owned by
// TypeFactory on Python-scope pointer destruction.
Expand Down
10 changes: 10 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,3 +1205,13 @@ def foo() -> ti.types.vector(4, ti.i32):
return ti.Vector([a[0, 0], a[0, 1], a[1, 0], a[1, 1]])

assert (foo() == [1, 2, 3, 4]).all()


@test_utils.test(debug=True)
def test_matrix_type_inference():
@ti.kernel
def foo():
a = ti.Vector([1, 2.5])[1] # should be f32 instead of i32
assert a == 2.5

foo()

0 comments on commit a6638b8

Please sign in to comment.