Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[backport][Fori_loop|While_loop] Enable while_loop/fori_loop, add tes…
Showing 7 changed files with 287 additions and 247 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,114 +1,72 @@ | ||
# Fori_loop | ||
`fori_loop` replaces a pure Python `for` loop. PyTorch/XLA provides `torch_xla.experimental.fori_loop` to keep the loop's computation graph rolled during compilation, | ||
like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), instead of unrolling the loop by repeating the computation for every iteration. | ||
`fori_loop` may reduce memory usage and compilation time. | ||
# `while_loop` optimizes memory utilization and compilation | ||
|
||
Use `fori_loop` like this: | ||
```python | ||
from torch_xla.experimental.fori_loop import fori_loop | ||
res = fori_loop(upper, lower, body_fun, init)  # body_fun is user-defined | ||
``` | ||
|
||
Currently `fori_loop` only supports simple cases like [this test](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py); you can also try the [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU. | ||
<br> | ||
|
||
Implementation details: | ||
- When the loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop). | ||
Like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), `while_loop` in PyTorch/XLA is intended to work with | ||
native PyTorch and the XLA backend operator `XLA::While`. Because `while_loop` does not support autograd, it can be used for inference only. | ||
### `while_loop` | ||
`while_loop` replaces a pure Python `while` loop. PyTorch supports `while_loop` through | ||
[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66). | ||
PyTorch/XLA provides experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`. | ||
|
||
- When the loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan). | ||
Like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using the `XLA::While` operator. | ||
This implementation would be very similar to that of `while_loop`. `scan` supports autograd, so it can be used for both training and inference. | ||
|
||
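The dynamic-range case described above, where `fori_loop` is built on `while_loop`, can be sketched in plain Python. This is only an illustration of the idea, not the `torch_xla` implementation: `reference_while_loop` and `fori_loop_via_while` are hypothetical names, Python integers stand in for XLA tensors, and the `body_fun(i, carry)` signature follows the `jax.lax.fori_loop` convention as an assumption.

```python
def reference_while_loop(cond_fn, body_fn, operands):
    """Eager reference semantics: apply body_fn while cond_fn holds."""
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands


def fori_loop_via_while(lower, upper, body_fun, init):
    """Run body_fun(i, carry) for i in [lower, upper) using a while loop."""

    def cond_fn(i, upper, carry):
        return i < upper

    def body_fn(i, upper, carry):
        # Every carried value is returned, including the unchanged bound,
        # so the loop state keeps the same structure on each iteration.
        return i + 1, upper, body_fun(i, carry)

    _, _, result = reference_while_loop(cond_fn, body_fn, (lower, upper, init))
    return result


# Add 1 ten times, starting from 0.
total = fori_loop_via_while(0, 10, lambda i, acc: acc + 1, 0)
print(total)  # 10
```

Carrying the counter and the upper bound as loop state is what lets the range stay dynamic: the condition reads them at run time instead of baking them into an unrolled graph.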
# while_loop | ||
`while_loop` replaces a pure Python `while` loop. PyTorch supports `while_loop` in | ||
[this code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69). | ||
PyTorch/XLA aims to support `while_loop` with native PyTorch and the XLA backend operator `XLA::While`. | ||
|
||
Use `while_loop` like this: | ||
#### Usage: | ||
```python | ||
import torch_xla.experimental.fori_loop | ||
from torch._higher_order_ops.while_loop import while_loop | ||
res = while_loop(cond_fn, body_fn, init)  # cond_fn, body_fn: user-defined; init: tuple or list | ||
result = while_loop(cond_fn, body_fn, init) | ||
``` | ||
Currently `while_loop` only supports simple cases like [this test](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py); you can also try the [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU. | ||
|
||
- `cond_fn`: User-defined condition function. | ||
- `body_fn`: User-defined loop body function. | ||
- `init`: Initial values (tuple or list). | ||
|
||
# [WIP]scan | ||
Like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference, since it supports autograd. | ||
`scan` is WIP. | ||
|
||
|
||
# Simple user guide | ||
You can try these three simple test cases to compare a pure Python `for` loop with `fori_loop` and `while_loop`. All three share the same logic: add 1 cumulatively, ten times: | ||
|
||
### simple example with pure python for loop | ||
```bash | ||
# python | ||
>>> import torch | ||
>>> init = torch.tensor([0], dtype=torch.int32) | ||
>>> one_value = torch.ones(1, dtype=torch.int32) | ||
>>> | ||
>>> for i in range(10): | ||
... init = init + one_value | ||
... | ||
>>> init | ||
tensor([10], dtype=torch.int32) | ||
``` | ||
### simple example with `while_loop`: | ||
#### simple example with `while_loop`: | ||
```bash | ||
# PJRT_DEVICE=TPU python | ||
>>> import torch | ||
>>> import torch_xla | ||
>>> import torch_xla.experimental.fori_loop | ||
>>> from torch_xla.experimental.fori_loop import fori_loop | ||
>>> from torch._higher_order_ops.while_loop import while_loop | ||
>>> import torch_xla.core.xla_model as xm | ||
>>> import torch_xla.core.xla_builder as xb | ||
>>> | ||
>>> device = xm.xla_device() | ||
>>> | ||
>>> def cond_fn(init, limit_value): | ||
... return limit_value[0] >= init[0] | ||
>>> def cond_fn(iteri, x): | ||
... return iteri > 0 | ||
... | ||
>>> def body_fn(init, limit_value): | ||
... one_value = torch.ones(1, dtype=torch.int32, device=device) | ||
... return (torch.add(init, one_value), limit_value.clone()) | ||
>>> def body_fn(iteri, x): | ||
... return iteri - 1, torch.add(x, 1) | ||
... | ||
>>> init = torch.tensor([0], dtype=torch.int32, device=device) | ||
>>> limit_value = torch.tensor([10], dtype=torch.int32, device=device) | ||
>>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value)) | ||
>>> res_ | ||
>>> init_val = torch.tensor(3, device=device) | ||
>>> iteri = torch.tensor(10, device=device) | ||
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val)) | ||
>>> res | ||
FunctionalTensor(lvl=0, value=\ | ||
tensor([11], device='xla:0', dtype=torch.int32)) | ||
tensor(13, device='xla:0')) | ||
``` | ||
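The result of the `cond_fn(iteri, x)` / `body_fn(iteri, x)` example can be sanity-checked with a plain-Python reference loop in the style of the `_fake_while_loop` helper from the test file. This is a sketch that uses Python integers in place of XLA tensors; `reference_while_loop` is an illustrative name, not part of PyTorch or PyTorch/XLA.

```python
def reference_while_loop(cond_fn, body_fn, operands):
    """Eager reference semantics: apply body_fn while cond_fn holds."""
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands


def cond_fn(iteri, x):
    # Keep looping while the counter is positive.
    return iteri > 0


def body_fn(iteri, x):
    # Decrement the counter and add 1 to the carried value.
    return iteri - 1, x + 1


# Same starting values as the example: 10 iterations, x starts at 3.
iteri, res = reference_while_loop(cond_fn, body_fn, (10, 3))
print(iteri, res)  # 0 13
```

Ten iterations starting from 3 give 13, which matches the `tensor(13, device='xla:0')` output shown for the `iteri`-based example.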
|
||
### simple example with `fori_loop`: | ||
<br> | ||
|
||
## Control group test case | ||
To compare a pure Python `while` loop with `while_loop`, the following control test case uses a pure Python `while` loop with similar logic: cumulatively adding 1 (50 times in the example below): | ||
|
||
### Control group example with pure python `while` loop | ||
```bash | ||
# PJRT_DEVICE=TPU python | ||
>>> import torch | ||
>>> import torch_xla | ||
>>> import torch_xla.experimental.fori_loop | ||
>>> from torch_xla.experimental.fori_loop import fori_loop | ||
>>> from torch._higher_order_ops.while_loop import while_loop | ||
>>> import torch_xla.core.xla_model as xm | ||
>>> import torch_xla.core.xla_builder as xb | ||
>>> | ||
>>> device = xm.xla_device() | ||
>>> | ||
>>> lower = torch.tensor([2], dtype=torch.int32, device=device) | ||
>>> upper = torch.tensor([52], dtype=torch.int32, device=device) | ||
>>> plus_value = torch.tensor([1], dtype=torch.int32, device=device) | ||
>>> init_val = torch.tensor([1], dtype=torch.int32, device=device) | ||
>>> init_val = torch.tensor(1, device=device) | ||
>>> iteri = torch.tensor(50, device=device) | ||
>>> | ||
>>> def body_fun(*argus): | ||
... plus_value, init_val = argus | ||
... return plus_value, torch.add(plus_value, init_val) | ||
>>> while iteri > 0: | ||
... init_val = init_val + 1 | ||
... iteri -= 1 | ||
... | ||
>>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val) | ||
>>> res_ | ||
tensor([51], device='xla:0', dtype=torch.int32) | ||
>>> init_val | ||
tensor(51, device='xla:0') | ||
``` | ||
For more examples and a detailed user guide, see [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test cases; complex test cases and support for `fori_loop` and `scan` would be added after 2.3. | ||
PyTorch/XLA would include `while_loop` support in 2.4 with test cases; support for `fori_loop` would be added after 2.4. For `while_loop`, `body_fn` must currently be defined so that its inputs and its outputs (return values) have the same shape. |
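The shape requirement on `body_fn` can be illustrated with a small pre-flight check. This is a hedged sketch, not something PyTorch/XLA provides: `check_body_fn_shapes` and `FakeTensor` are hypothetical names, and the check assumes each carried value exposes a `.shape` attribute, as tensors do.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor; only `.shape` is needed here."""
    shape: tuple


def check_body_fn_shapes(body_fn, init):
    """Raise ValueError if body_fn's outputs differ from its inputs in count or shape."""
    outputs = body_fn(*init)
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)
    if len(outputs) != len(init):
        raise ValueError(
            f"body_fn returned {len(outputs)} values, expected {len(init)}")
    for idx, (inp, out) in enumerate(zip(init, outputs)):
        if inp.shape != out.shape:
            raise ValueError(
                f"carried value {idx} changed shape: {inp.shape} -> {out.shape}")


init = (FakeTensor(shape=()), FakeTensor(shape=(2, 2)))

# Shapes are preserved, so this passes silently.
check_body_fn_shapes(lambda i, x: (i, x), init)

# Changing the shape of a carried value is reported.
try:
    check_body_fn_shapes(lambda i, x: (i, FakeTensor(shape=(4,))), init)
except ValueError as err:
    print(err)
```

Running a check like this once before calling `while_loop` makes a mismatched `body_fn` fail with a readable message rather than inside compilation.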
test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py: 106 changes (0 additions, 106 deletions)
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
import unittest | ||
from typing import Callable, Dict, List | ||
|
||
import torch | ||
import torch_xla | ||
# We need to import the underlying implementation function to register with the dispatcher | ||
import torch_xla.experimental.fori_loop | ||
from torch_xla.experimental.fori_loop import fori_loop | ||
from torch._higher_order_ops.while_loop import while_loop | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.core.xla_builder as xb | ||
import torch_xla.utils.utils as xu | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
|
||
|
||
def _fake_while_loop(cond_fn, body_fn, operands): | ||
# operands need to be more than one here | ||
while cond_fn(*operands): | ||
operands = body_fn(*operands) | ||
return operands | ||
|
||
|
||
class WhileLoopTest(unittest.TestCase): | ||
|
||
def test_while_loop_addition(self): | ||
device = xm.xla_device() | ||
|
||
def cond_fn(iteri, x): | ||
return iteri > 0 | ||
|
||
def body_fn(iteri, x): | ||
return iteri - 1, torch.add(x, 1) | ||
|
||
init_val = torch.tensor(3, dtype=torch.int32, device=device) | ||
iteri = torch.tensor(10, device=device) | ||
_, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val)) | ||
_, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val)) | ||
self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) | ||
|
||
def test_while_loop_addition_nested(self): | ||
device = xm.xla_device() | ||
|
||
def cond_fn(iteri, x): | ||
return iteri > 0 | ||
|
||
def body_fn(iteri, x): | ||
return iteri - 1, torch.add(torch.add(x, 1), 1) | ||
|
||
init_val = torch.tensor(2, dtype=torch.int32, device=device) | ||
iteri = torch.tensor(10, device=device) | ||
_, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val)) | ||
_, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val)) | ||
self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) | ||
|
||
def test_while_loop_simple_linear_inside_loop(self): | ||
device = xm.xla_device() | ||
torch.set_grad_enabled(False) | ||
|
||
class SimpleLinear(torch.nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.linear = torch.nn.Linear(2, 2) | ||
|
||
def forward(self, iteri, x): | ||
|
||
def cond_fn(iteri, x): | ||
return iteri > 0 | ||
|
||
def body_fn(iteri, x): | ||
return iteri - 1, self.linear(x) | ||
|
||
return while_loop(cond_fn, body_fn, (iteri, x)) | ||
|
||
def forward_without_while_loop_op(self, iteri, x): | ||
while (iteri > 0): | ||
x = self.linear(x) | ||
iteri -= 1 | ||
return iteri, x | ||
|
||
linear_model = SimpleLinear() | ||
linear_model.to(device) | ||
l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device) | ||
iteri = torch.tensor(10, dtype=torch.int32, device=device) | ||
_, res_with_loop = linear_model(iteri, l_in_0) | ||
_, res_without_loop = linear_model.forward_without_while_loop_op( | ||
iteri, l_in_0) | ||
|
||
self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) | ||
|
||
# ====== fori_loop ====== | ||
@unittest.skip("Fori_loop is not supported now due to unstable result.") | ||
def test_fori_loop_addition(self): | ||
device = xm.xla_device() | ||
|
||
lower = torch.tensor(0, device=device) | ||
upper = torch.tensor(50, device=device) | ||
init_val = torch.tensor(1, dtype=torch.int32, device=device) | ||
|
||
def body_fun(x): | ||
return torch.add(x, 1) | ||
|
||
_, res_with_loop = fori_loop(lower, upper, body_fun, (init_val)) | ||
|
||
# === expected === | ||
for i in range(upper - lower): | ||
init_val = torch.add(init_val, 1) | ||
res_without_loop = init_val | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |