Skip to content

Commit

Permalink
[Fori_loop|While_loop] Placeholder lower torch.while_loop with python…
Browse files Browse the repository at this point in the history
… dispatch for simple addition test case (pytorch#6532)
  • Loading branch information
ManfeiBai authored and amithrm committed Mar 1, 2024
1 parent 611bd44 commit 82e2f41
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tpu_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ jobs:
python3 -u test/test_autocast.py
python3 -u test/dynamo/test_dynamo.py
python3 -u test/spmd/test_spmd_debugging.py
python3 -u test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


def _fake_while_loop(cond_fn, body_fn, operands):
while cond_fn(*operands):
operands = body_fn(*operands)
return operands


class WhileLoopTest(unittest.TestCase):

def test_while_loop_tpu(self):

def cond_fn(x):
return x.sum() <= 10

def body_fn(x):
return (x + 1,)

device = xm.xla_device()
x = torch.ones(1, dtype=torch.int, device=device)
res = while_loop(cond_fn, body_fn, (x,))
expected = _fake_while_loop(cond_fn, body_fn, x)
self.assertEqual(expected, res)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ spec:
python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
python3 /src/pytorch/xla/test/pjrt/test_dtypes.py
python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
python3 /src/pytorch/xla/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
volumeMounts:
- mountPath: /dev/shm
name: dshm
Expand Down
42 changes: 42 additions & 0 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.core.xla_op_registry as xor

from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
import torch._higher_order_ops.while_loop
from torch._higher_order_ops.while_loop import while_loop_op


@while_loop_op.py_impl(DispatchKey.XLA)
def while_loop(cond_fn, body_fn, operands):
# cond_fn&body_fn: callable
# operands: (Tuple of possibly nested dict/list/tuple of tensors)
return _xla_while_loop(cond_fn, body_fn, operands)


def _xla_while_loop(cond_fn, body_fn, operands):

def op_fn(internal_x):
# TODO(manfei) replace cond_fn_placeholder and body_fn_placeholder once xla::while lowering in the backend is available
def cond_fn_placeholder(counter, internal_x):
return counter < xb.Op.scalar(internal_x.builder(), 10, dtype=xb.Type.S32)

def body_fn_placeholder(counter, internal_x):
next_counter = counter + xb.Op.scalar(
counter.builder(), 1, dtype=xb.Type.S32)
internal_x = internal_x + xb.Op.scalar(
internal_x.builder(), 1, dtype=xb.Type.S32)
return xb.Op.tuple((next_counter, internal_x))

zero = xb.Op.scalar(internal_x.builder(), 0, dtype=xb.Type.S32)
w = xb.Op.mkwhile((zero, internal_x), cond_fn_placeholder,
body_fn_placeholder)
return w.get_tuple_element(1)

op = xor.register('test_while', op_fn)
return xu.as_list(op(operands[0]))

0 comments on commit 82e2f41

Please sign in to comment.