diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index fc89a87c5cd25..af29766d1fad9 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -479,13 +479,11 @@ def inverse(self): """ assert self.n == self.m, 'Only square matrices are invertible' if self.n == 1: - return Matrix([1 / self(0, 0)], disable_local_tensor=True) + return Matrix([1 / self(0, 0)]) if self.n == 2: - inv_det = impl.expr_init(1.0 / self.determinant()) - # Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323 - return inv_det * Matrix([[self(1, 1), -self(0, 1)], - [-self(1, 0), self(0, 0)]], - disable_local_tensor=True).variable() + inv_determinant = impl.expr_init(1.0 / self.determinant()) + return inv_determinant * Matrix([[self( + 1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]]) if self.n == 3: n = 3 inv_determinant = impl.expr_init(1.0 / self.determinant()) @@ -496,10 +494,10 @@ def E(x, y): for i in range(n): for j in range(n): - entries[j][i] = impl.expr_init( - inv_determinant * (E(i + 1, j + 1) * E(i + 2, j + 2) - - E(i + 2, j + 1) * E(i + 1, j + 2))) - return Matrix(entries, disable_local_tensor=True) + entries[j][i] = inv_determinant * ( + E(i + 1, j + 1) * E(i + 2, j + 2) - + E(i + 2, j + 1) * E(i + 1, j + 2)) + return Matrix(entries) if self.n == 4: n = 4 inv_determinant = impl.expr_init(1.0 / self.determinant()) @@ -510,18 +508,15 @@ def E(x, y): for i in range(n): for j in range(n): - entries[j][i] = impl.expr_init( - inv_determinant * (-1)**(i + j) * - ((E(i + 1, j + 1) * - (E(i + 2, j + 2) * E(i + 3, j + 3) - - E(i + 3, j + 2) * E(i + 2, j + 3)) - - E(i + 2, j + 1) * - (E(i + 1, j + 2) * E(i + 3, j + 3) - - E(i + 3, j + 2) * E(i + 1, j + 3)) + - E(i + 3, j + 1) * - (E(i + 1, j + 2) * E(i + 2, j + 3) - - E(i + 2, j + 2) * E(i + 1, j + 3))))) - return Matrix(entries, disable_local_tensor=True) + entries[j][i] = inv_determinant * (-1)**(i + j) * (( + E(i + 1, j + 1) * + (E(i + 2, j + 2) * E(i + 3, j + 3) - + E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) * + (E(i + 1, j + 2) * E(i + 3, j + 3) - + E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) * + (E(i + 1, j + 2) * E(i + 2, j + 3) - + E(i + 2, j + 2) * E(i + 1, j + 3)))) + return Matrix(entries) raise Exception( "Inversions of matrices with sizes >= 5 are not supported") @@ -567,10 +562,8 @@ def transpose(self): Get the transpose of a matrix. """ - ret = Matrix([[self[i, j] for i in range(self.n)] - for j in range(self.m)], - disable_local_tensor=True) - return ret + return Matrix([[self[i, j] for i in range(self.n)] + for j in range(self.m)]) @taichi_scope def determinant(a): @@ -790,10 +783,8 @@ def zero(dt, n, m=None): """ if m is None: - return Vector([ti.cast(0, dt) for _ in range(n)], - disable_local_tensor=True) - return Matrix([[ti.cast(0, dt) for _ in range(m)] for _ in range(n)], - disable_local_tensor=True) + return Vector([ti.cast(0, dt) for _ in range(n)]) + return Matrix([[ti.cast(0, dt) for _ in range(m)] for _ in range(n)]) @staticmethod @taichi_scope @@ -810,10 +801,8 @@ def one(dt, n, m=None): """ if m is None: - return Vector([ti.cast(1, dt) for _ in range(n)], - disable_local_tensor=True) - return Matrix([[ti.cast(1, dt) for _ in range(m)] for _ in range(n)], - disable_local_tensor=True) + return Vector([ti.cast(1, dt) for _ in range(n)]) + return Matrix([[ti.cast(1, dt) for _ in range(m)] for _ in range(n)]) @staticmethod @taichi_scope @@ -832,8 +821,7 @@ def unit(n, i, dt=None): if dt is None: dt = int assert 0 <= i < n - return Matrix([ti.cast(int(j == i), dt) for j in range(n)], - disable_local_tensor=True) + return Vector([ti.cast(int(j == i), dt) for j in range(n)]) @staticmethod @taichi_scope @@ -849,8 +837,7 @@ def identity(dt, n): """ return Matrix([[ti.cast(int(i == j), dt) for j in range(n)] - for i in range(n)], - disable_local_tensor=True) + for i in range(n)]) @staticmethod def rotation2d(alpha): @@ -1107,18 +1094,15 @@ def dot(self, other): @kern_mod.pyfunc def _cross3d(self, other): - ret = Matrix([ + return Matrix([ self[1] * other[2] - self[2] * other[1], self[2] * other[0] - self[0] * other[2], self[0] * other[1] - self[1] * other[0], - ], - disable_local_tensor=True) - return ret + ]) @kern_mod.pyfunc def _cross2d(self, other): - ret = self[0] * other[1] - self[1] * other[0] - return ret + return self[0] * other[1] - self[1] * other[0] def cross(self, other): """Perform the cross product with the input Vector (1-D Matrix). @@ -1156,10 +1140,8 @@ def outer_product(self, other): impl.static( impl.static_assert(other.m == 1, "rhs for outer_product is not a vector")) - ret = Matrix([[self[i] * other[j] for j in range(other.n)] - for i in range(self.n)], - disable_local_tensor=True) - return ret + return Matrix([[self[i] * other[j] for j in range(other.n)] + for i in range(self.n)]) def Vector(n, dt=None, **kwargs):