Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc] Update PyTorch interface documentation #4311

Merged
merged 12 commits into from
Mar 2, 2022
71 changes: 63 additions & 8 deletions docs/lang/articles/basic/external.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ support NumPy, e.g. `matplotlib`.
```python {8}
@ti.kernel
def my_kernel():
for i in x:
x[i] = i * 2
for i in x:
x[i] = i * 2

x = ti.field(ti.f32, 4)
my_kernel()
Expand All @@ -41,12 +41,38 @@ print(x[2]) # 3
print(x[3]) # 5
```

Likewise, Taichi fields can be **imported from and exported to PyTorch tensors**:
```python
@ti.kernel
def my_kernel():
for i in x:
x[i] = i * 2

x = ti.field(ti.f32, 4)
my_kernel()
x_torch = x.to_torch()
print(x_torch) # torch.tensor([0, 2, 4, 6])

x.from_numpy(torch.tensor([1, 7, 3, 5]))
print(x[0]) # 1
print(x[1]) # 7
print(x[2]) # 3
print(x[3]) # 5
```
When calling `to_torch()`, specify the PyTorch device where the Taichi field is exported using the `device` argument:
```python
x = ti.field(ti.f32, 4)
x.fill(3.0)
x_torch = x.to_torch(device="cuda:0")
print(x_torch.device) # device(type='cuda', index=0)
```

## External array shapes

Shapes of Taichi fields and those of corresponding NumPy arrays are closely
Shapes of Taichi fields and those of corresponding NumPy arrays or PyTorch tensors are closely
connected via the following rules:

- For scalar fields, **the shape of NumPy array is exactly the same as
- For scalar fields, **the shape of NumPy array or PyTorch tensor equals the shape of
the Taichi field**:

```python
Expand All @@ -60,7 +86,7 @@ field.from_numpy(array) # the input array must be of shape (256, 512)
```

- For vector fields, if the vector is `n`-D, then **the shape of NumPy
array should be** `(*field_shape, vector_n)`:
array or Pytorch tensor should be** `(*field_shape, vector_n)`:

```python
field = ti.Vector.field(3, ti.i32, shape=(256, 512))
Expand All @@ -74,7 +100,7 @@ field.from_numpy(array) # the input array must be of shape (256, 512, 3)
```

- For matrix fields, if the matrix is `n`-by-`m` (`n x m`), then **the shape of NumPy
array should be** `(*field_shape, matrix_n, matrix_m)`:
array or Pytorch Tensor should be** `(*field_shape, matrix_n, matrix_m)`:

```python
field = ti.Matrix.field(3, 4, ti.i32, shape=(256, 512))
Expand All @@ -88,7 +114,8 @@ array.shape # (256, 512, 3, 4)
field.from_numpy(array) # the input array must be of shape (256, 512, 3, 4)
```

- For struct fields, the external array will be exported as **a dictionary of arrays** with the keys being struct member names and values being struct member arrays. Nested structs will be exported as nested dictionaries:
- For struct fields, the external array will be exported as **a dictionary of NumPy arrays or PyTorch tensors** with keys
being struct member names and values being struct member arrays. Nested structs will be exported as nested dictionaries:

```python
field = ti.Struct.field({'a': ti.i32, 'b': ti.types.vector(float, 3)} shape=(256, 512))
Expand All @@ -104,7 +131,7 @@ field.from_numpy(array_dict) # the input array must have the same keys as the fi

## Using external arrays as Taichi kernel arguments

Use the type hint `ti.ext_arr()` for passing external arrays as kernel
Use type hint `ti.ext_arr()` or `ti.any_arr()` to pass external arrays as kernel
arguments. For example:

```python {10}
Expand Down Expand Up @@ -135,3 +162,31 @@ for i in range(n):
for j in range(m):
assert a[i, j] == i * j + i + j
```

Note that the elements in an external array must be indexed using a single square bracket.
This contrasts with a Taichi vector or matrix field where field and matrix indices are indexed separately:
```python
@ti.kernel
def copy_vector(x: ti.template(), y: ti.ext_arr()):
for i, j in ti.ndrange(n, m):
for k in ti.static(range(3)):
y[i, j, k] = x[i, j][k] # correct
# y[i][j][k] = x[i, j][k] incorrect
# y[i, j][k] = x[i, j][k] incorrect
```
Also, external arrays in a Taichi kernel are indexed using its **physical memory layout**. For PyTorch users,
this implies that the PyTorch tensor [needs to be made contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)
before being passed into a Taichi kernel:
```python
@ti.kernel
def copy_scalar(x: ti.template(), y: ti.ext_arr()):
for i, j in x:
y[i, j] = x[i, j]

x = ti.field(dtype=int, shape=(3, 3))
y = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y = y.T # Transposing the tensor returns a view of the tensor which is not contiguous
copy(x, y) # error!
copy(x, y.clone()) # correct
copy(x, y.contiguous()) # correct
```