Skip to content

Commit

Permalink
[Doc] Add docstring and documents for sparse matrix (#3119)
Browse files Browse the repository at this point in the history
* start adding docstring

* Apply suggestions from code review

Co-authored-by: FantasyVR <6712304+FantasyVR@users.noreply.github.com>

* updated sparse matrix docstrings

* docstrings for builder

* rebase @add docstrings for sparse solver

* update doc

* finish docs

* updated sparse solver docstrings

* add outputs for doc

* update doc

* update docstring

* move doc to new page

* updated

* Apply suggestions from code review

Co-authored-by: FantasyVR <6712304+FantasyVR@users.noreply.github.com>

* add headings

* Apply suggestions from code review

Co-authored-by: FantasyVR <6712304+FantasyVR@users.noreply.github.com>

* Auto Format

* Apply suggestions from code review

Co-authored-by: FantasyVR <6712304+FantasyVR@users.noreply.github.com>

* Update docs/lang/articles/advanced/sparse_matrix.md

* Update docs/lang/articles/advanced/sparse_matrix.md

* Update docs/lang/articles/advanced/sparse_matrix.md

Co-authored-by: FantasyVR <6712304+FantasyVR@users.noreply.github.com>
Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
Co-authored-by: Ye Kuang <k-ye@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 11, 2021
1 parent e253fd3 commit c3478c7
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 0 deletions.
188 changes: 188 additions & 0 deletions docs/lang/articles/advanced/sparse_matrix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Sparse Matrix
Sparse matrices are frequently used when solving linear systems in science and engineering. Taichi provides programmers with useful APIs for sparse matrices.

To use the sparse matrix in taichi programs, you should follow these three steps:
1. Create a `builder` using `ti.SparseMatrixBuilder()`.
2. Fill the `builder` with your matrices' data.
3. Create sparse matrices from the `builder`.

:::caution WARNING
The sparse matrix is still under implementation. There are some limitations:
- Only the CPU backend is supported.
- The data type of sparse matrix is float32.
- The storage format is column-major
:::
Here's an example:
```python
import taichi as ti
ti.init(arch=ti.x64) # only CPU backend is supported for now

n = 4
# step 1: create sparse matrix builder
K = ti.SparseMatrixBuilder(n, n, max_num_triplets=100)

@ti.kernel
def fill(A: ti.sparse_matrix_builder()):
for i in range(n):
A[i, i] += 1

# step 2: fill the builder with data.
fill(K)

print(">>>> K.print_triplets()")
K.print_triplets()
# outputs:
# >>>> K.print_triplets()
# n=4, m=4, num_triplets=4 (max=100)(0, 0) val=1.0(1, 1) val=1.0(2, 2) val=1.0(3, 3) val=1.0

# step 3: create a sparse matrix from the builder.
A = K.build()
print(">>>> A = K.build()")
print(A)
# outputs:
# >>>> A = K.build()
# [1, 0, 0, 0]
# [0, 1, 0, 0]
# [0, 0, 1, 0]
# [0, 0, 0, 1]
```

The basic operations like `+`, `-`, `*`, `@` and transpose of sparse matrices are supported now.

```python
print(">>>> Summation: C = A + A")
C = A + A
print(C)
# outputs:
# >>>> Summation: C = A + A
# [2, 0, 0, 0]
# [0, 2, 0, 0]
# [0, 0, 2, 0]
# [0, 0, 0, 2]

print(">>>> Subtraction: D = A - A")
D = A - A
print(D)
# outputs:
# >>>> Subtraction: D = A - A
# [0, 0, 0, 0]
# [0, 0, 0, 0]
# [0, 0, 0, 0]
# [0, 0, 0, 0]

print(">>>> Multiplication with a scalar on the right: E = A * 3.0")
E = A * 3.0
print(E)
# outputs:
# >>>> Multiplication with a scalar on the right: E = A * 3.0
# [3, 0, 0, 0]
# [0, 3, 0, 0]
# [0, 0, 3, 0]
# [0, 0, 0, 3]

print(">>>> Multiplication with a scalar on the left: E = 3.0 * A")
E = 3.0 * A
print(E)
# outputs:
# >>>> Multiplication with a scalar on the left: E = 3.0 * A
# [3, 0, 0, 0]
# [0, 3, 0, 0]
# [0, 0, 3, 0]
# [0, 0, 0, 3]

print(">>>> Transpose: F = A.transpose()")
F = A.transpose()
print(F)
# outputs:
# >>>> Transpose: F = A.transpose()
# [1, 0, 0, 0]
# [0, 1, 0, 0]
# [0, 0, 1, 0]
# [0, 0, 0, 1]

print(">>>> Matrix multiplication: G = E @ A")
G = E @ A
print(G)
# outputs:
# >>>> Matrix multiplication: G = E @ A
# [3, 0, 0, 0]
# [0, 3, 0, 0]
# [0, 0, 3, 0]
# [0, 0, 0, 3]

print(">>>> Element-wise multiplication: H = E * A")
H = E * A
print(H)
# outputs:
# >>>> Element-wise multiplication: H = E * A
# [3, 0, 0, 0]
# [0, 3, 0, 0]
# [0, 0, 3, 0]
# [0, 0, 0, 3]

print(f">>>> Element Access: A[0,0] = {A[0,0]}")
# outputs:
# >>>> Element Access: A[0,0] = 1.0
```

## Sparse linear solver
You may want to solve some linear equations using sparse matrices.
Then, the following steps could help:
1. Create a `solver` using `ti.SparseSolver(solver_type, ordering)`. Currently, the sparse solver supports `LLT`, `LDLT` and `LU` factorization types, and orderings including `AMD`, `COLAMD`
2. Analyze and factorize the sparse matrix you want to solve using `solver.analyze_pattern(sparse_matrix)` and `solver.factorize(sparse_matrix)`
3. Call `solver.solve(b)` to get your solutions, where `b` is a numpy array or taichi filed representing the right-hand side of the linear system.
4. Call `solver.info()` to check if the solving process succeeds.

Here's a full example.

```python
import taichi as ti

ti.init(arch=ti.x64)

n = 4

K = ti.SparseMatrixBuilder(n, n, max_num_triplets=100)
b = ti.field(ti.f32, shape=n)

@ti.kernel
def fill(A: ti.sparse_matrix_builder(), b: ti.template(), interval: ti.i32):
for i in range(n):
A[i, i] += 2.0

if i % interval == 0:
b[i] += 1.0

fill(K, b, 3)

A = K.build()
print(">>>> Matrix A:")
print(A)
print(">>>> Vector b:")
print(b)
# outputs:
# >>>> Matrix A:
# [2, 0, 0, 0]
# [0, 2, 0, 0]
# [0, 0, 2, 0]
# [0, 0, 0, 2]
# >>>> Vector b:
# [1. 0. 0. 1.]
solver = ti.SparseSolver(solver_type="LLT")
solver.analyze_pattern(A)
solver.factorize(A)
x = solver.solve(b)
isSuccess = solver.info()
print(">>>> Solve sparse linear systems Ax = b with the solution x:")
print(x)
print(f">>>> Computation was successful?: {isSuccess}")
# outputs:
# >>>> Solve sparse linear systems Ax = b with the solution x:
# [0.5 0. 0. 0.5]
# >>>> Computation was successful?: True
```
## Examples

Please have a look at our two demos for more information:
+ `examples/simulation/stable_fluid.py`: A 2D fluid simulation using a sparse Laplacian matrix to solve Poisson's pressure equation.
+ `examples/simulation/implicit_mass_spring.py`: A 2D cloth simulation demo using sparse matrices to solve the linear systems.
58 changes: 58 additions & 0 deletions python/taichi/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@


class SparseMatrix:
"""Taichi's Sparse Matrix class
A sparse matrix allows the programmer to solve a large linear system.
Args:
n (int): the first dimension of a sparse matrix.
m (int): the second dimension of a sparse matrix.
sm (SparseMatrix): another sparse matrix that will be built from.
"""
def __init__(self, n=None, m=None, sm=None, dtype=f32):
if sm is None:
self.n = n
Expand All @@ -16,16 +25,33 @@ def __init__(self, n=None, m=None, sm=None, dtype=f32):
self.matrix = sm

def __add__(self, other):
"""Addition operation for sparse matrix.
Returns:
The result sparse matrix of the addition.
"""
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
sm = self.matrix + other.matrix
return SparseMatrix(sm=sm)

def __sub__(self, other):
"""Subtraction operation for sparse matrix.
Returns:
The result sparse matrix of the subtraction.
"""
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
sm = self.matrix - other.matrix
return SparseMatrix(sm=sm)

def __mul__(self, other):
"""Sparse matrix's multiplication against real numbers or the hadamard product against another matrix
Args:
other (float or SparseMatrix): the other operand of multiplication.
Returns:
The result of multiplication.
"""
if isinstance(other, float):
sm = self.matrix * other
return SparseMatrix(sm=sm)
Expand All @@ -35,15 +61,34 @@ def __mul__(self, other):
return SparseMatrix(sm=sm)

def __rmul__(self, other):
"""Right scalar multiplication for sparse matrix.
Args:
other (float): the other operand of scalar multiplication.
Returns:
The result of multiplication.
"""
if isinstance(other, float):
sm = other * self.matrix
return SparseMatrix(sm=sm)

def transpose(self):
"""Sparse Matrix transpose.
Returns:
The transposed sparse mastrix.
"""
sm = self.matrix.transpose()
return SparseMatrix(sm=sm)

def __matmul__(self, other):
"""Matrix multiplication.
Args:
other (SparseMatrix, Field, or numpy.array): the other sparse matrix of the multiplication.
Returns:
The result of matrix multiplication.
"""
if isinstance(other, SparseMatrix):
assert self.m == other.n, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
sm = self.matrix.matmul(other.matrix)
Expand All @@ -66,13 +111,23 @@ def __setitem__(self, indices, value):
self.matrix.set_element(indices[0], indices[1], value)

def __str__(self):
"""Python scope matrix print support."""
return self.matrix.to_string()

def __repr__(self):
return self.matrix.to_string()


class SparseMatrixBuilder:
"""A python wrap around sparse matrix builder.
Use this builder to fill the sparse matrix.
Args:
num_rows (int): the first dimension of a sparse matrix.
num_cols (int): the second dimension of a sparse matrix.
max_num_triplets (int): the maximum number of triplets.
"""
def __init__(self,
num_rows=None,
num_cols=None,
Expand All @@ -85,11 +140,14 @@ def __init__(self,
num_rows, num_cols, max_num_triplets)

def get_addr(self):
"""Get the address of the sparse matrix"""
return self.ptr.get_addr()

def print_triplets(self):
"""Print the triplets stored in the builder"""
self.ptr.print_triplets()

def build(self, dtype=f32, format='CSR'):
"""Create a sparse matrix using the triplets"""
sm = self.ptr.build()
return SparseMatrix(sm=sm)
35 changes: 35 additions & 0 deletions python/taichi/linalg/sparse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@


class SparseSolver:
"""Sparse linear system solver
Use this class to solve linear systems represented by sparse matrices.
Args:
solver_type (str): The factorization type.
ordering (str): The method for matrices re-ordering.
"""
def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
solver_type_list = ["LLT", "LDLT", "LU"]
solver_ordering = ['AMD', 'COLAMD']
Expand All @@ -21,24 +29,46 @@ def type_assert(sparse_matrix):
assert False, f"The parameter type: {type(sparse_matrix)} is not supported in linear solvers for now."

def compute(self, sparse_matrix):
"""This method is equivalent to calling both `analyze_pattern` and then `factorize`.
Args:
sparse_matrix (SparseMatrix): The sparse matrix to be computed.
"""
if isinstance(sparse_matrix, SparseMatrix):
self.solver.compute(sparse_matrix.matrix)
else:
self.type_assert(sparse_matrix)

def analyze_pattern(self, sparse_matrix):
"""Reorder the nonzero elements of the matrix, such that the factorization step creates less fill-in.
Args:
sparse_matrix (SparseMatrix): The sparse matrix to be analyzed.
"""
if isinstance(sparse_matrix, SparseMatrix):
self.solver.analyze_pattern(sparse_matrix.matrix)
else:
self.type_assert(sparse_matrix)

def factorize(self, sparse_matrix):
"""Do the factorization step
Args:
sparse_matrix (SparseMatrix): The sparse matrix to be factorized.
"""
if isinstance(sparse_matrix, SparseMatrix):
self.solver.factorize(sparse_matrix.matrix)
else:
self.type_assert(sparse_matrix)

def solve(self, b):
"""Computes the solution of the linear systems.
Args:
b (numpy.array or Field): The right-hand side of the linear systems.
Returns:
numpy.array: The solution of linear systems.
"""
if isinstance(b, taichi.lang.Field):
return self.solver.solve(b.to_numpy())
elif isinstance(b, np.ndarray):
Expand All @@ -47,4 +77,9 @@ def solve(self, b):
assert False, f"The parameter type: {type(b)} is not supported in linear solvers for now."

def info(self):
"""Check if the linear systems are solved successfully.
Returns:
bool: True if the solving process succeeded, False otherwise.
"""
return self.solver.info()

0 comments on commit c3478c7

Please sign in to comment.