Skip to content

Commit

Permalink
[lang] Fix matrix type inference and remove _MatrixEntriesInitializer
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Dec 20, 2022
1 parent 6f4ce42 commit 6b16222
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 60 deletions.
101 changes: 41 additions & 60 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import numbers
import warnings
from collections.abc import Iterable
Expand Down Expand Up @@ -102,25 +103,44 @@ def prop_setter(instance, value):
return cls


def _infer_entry_dt(entry):
if isinstance(entry, (int, np.integer)):
return impl.get_runtime().default_ip
if isinstance(entry, float):
return impl.get_runtime().default_fp
if isinstance(entry, expr.Expr):
dt = entry.ptr.get_ret_type()
if dt == ti_python_core.DataType_unknown:
raise TaichiTypeError(
'Element type of the matrix cannot be inferred. Please set dt instead for now.'
)
return dt
raise TaichiTypeError('Element type of the matrix is invalid.')


def _infer_array_dt(arr):
assert len(arr) > 0
return functools.reduce(ti_python_core.promoted_type,
map(_infer_entry_dt, arr))


def make_matrix(arr, dt=None):
if len(arr) == 0:
# the only usage of an empty vector is to serve as field indices
is_matrix = False
shape = [0]
dt = primitive_types.i32
else:
is_matrix = isinstance(arr[0], Iterable)
if isinstance(arr[0], Iterable): # matrix
shape = [len(arr), len(arr[0])]
arr = [elt for row in arr for elt in row]
else: # vector
shape = [len(arr)]
if dt is None:
dt = _make_entries_initializer(is_matrix).infer_dt(arr)
dt = _infer_array_dt(arr)
else:
dt = cook_dtype(dt)
if not is_matrix:
return impl.Expr(
impl.make_matrix_expr([len(arr)], dt,
[expr.Expr(elt).ptr for elt in arr]))
return impl.Expr(
impl.make_matrix_expr(
[len(arr), len(arr[0])], dt,
[expr.Expr(elt).ptr for row in arr for elt in row]))
return expr.Expr(
impl.make_matrix_expr(shape, dt, [expr.Expr(elt).ptr for elt in arr]))


def is_vector(x):
Expand Down Expand Up @@ -241,48 +261,6 @@ def _set_entries(self, value):
self[i, j] = value[i][j]


class _MatrixEntriesInitializer:
def pyscope(self, arr):
raise NotImplementedError('Override')

def _get_entry_to_infer(self, arr):
raise NotImplementedError('Override')

def infer_dt(self, arr):
entry = self._get_entry_to_infer(arr)
if isinstance(entry, (int, np.integer)):
return impl.get_runtime().default_ip
if isinstance(entry, float):
return impl.get_runtime().default_fp
if isinstance(entry, expr.Expr):
dt = entry.ptr.get_ret_type()
if dt == ti_python_core.DataType_unknown:
raise TypeError(
'Element type of the matrix cannot be inferred. Please set dt instead for now.'
)
return dt
raise Exception(
'dt required when using dynamic_index for local tensor')


def _make_entries_initializer(is_matrix: bool) -> _MatrixEntriesInitializer:
class _VecImpl(_MatrixEntriesInitializer):
def pyscope(self, arr):
return [[x] for x in arr]

def _get_entry_to_infer(self, arr):
return arr[0]

class _MatImpl(_MatrixEntriesInitializer):
def pyscope(self, arr):
return [list(row) for row in arr]

def _get_entry_to_infer(self, arr):
return arr[0][0]

return _MatImpl() if is_matrix else _VecImpl()


@_gen_swizzles
class Matrix(TaichiOperations):
"""The matrix class.
Expand Down Expand Up @@ -346,13 +324,16 @@ def __init__(self, arr, dt=None, ndim=None):
is_matrix = isinstance(arr[0],
Iterable) and not is_vector(self)
self.ndim = 2 if is_matrix else 1
initializer = _make_entries_initializer(is_matrix)
if not is_matrix and isinstance(arr[0], Iterable):
flattened = []
for row in arr:
flattened += row
arr = flattened
mat = initializer.pyscope(arr)

if is_matrix:
mat = [list(row) for row in arr]
else:
if isinstance(arr[0], Iterable):
flattened = []
for row in arr:
flattened += row
arr = flattened
mat = [[x] for x in arr]

self.n, self.m = len(mat), 1
if len(mat) > 0:
Expand Down
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,8 @@ void export_lang(py::module &m) {

py::class_<Type>(m, "Type").def("to_string", &Type::to_string);

m.def("promoted_type", promoted_type);

// Note that it is important to specify py::return_value_policy::reference for
// the factory methods, otherwise pybind11 will delete the Types owned by
// TypeFactory on Python-scope pointer destruction.
Expand Down
10 changes: 10 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,3 +1192,13 @@ def foo() -> ti.types.vector(4, ti.i32):
return ti.Vector([a[0, 0], a[0, 1], a[1, 0], a[1, 1]])

assert (foo() == [1, 2, 3, 4]).all()


@test_utils.test(debug=True)
def test_matrix_type_inference():
@ti.kernel
def foo():
a = ti.Vector([1, 2.5])[1] # should be f32 instead of i32
assert a == 2.5

foo()

0 comments on commit 6b16222

Please sign in to comment.