Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fori_loop|While_loop] Placeholder lower torch.while_loop with python dispatch for simple addition test case #6532

Merged
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
0a2cc00
Create fori_loop.py
ManfeiBai Feb 13, 2024
c4b4384
Update fori_loop.py
ManfeiBai Feb 13, 2024
8ac0bb5
Update init_python_bindings.cpp
ManfeiBai Feb 13, 2024
0a0f8c6
Update init_python_bindings.cpp
ManfeiBai Feb 13, 2024
c62ad7d
Update init_python_bindings.cpp
ManfeiBai Feb 13, 2024
f0dd53e
Update fori_loop.py
ManfeiBai Feb 13, 2024
dcf65fc
Update fori_loop.py
ManfeiBai Feb 13, 2024
a1ed583
Create test_fori_loop.py
ManfeiBai Feb 13, 2024
cea9062
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
7e1ffcb
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
6694c09
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
caeeb50
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
6a4ad4f
use code from xb
ManfeiBai Feb 13, 2024
76a1cfa
only xb test
ManfeiBai Feb 13, 2024
694e797
check original version has used python dispatch or not
ManfeiBai Feb 13, 2024
fc4c8bb
check original version has used python dispatch or not again
ManfeiBai Feb 13, 2024
cf6aae9
add test script for xla
ManfeiBai Feb 13, 2024
04a8eff
add test script for xla again
ManfeiBai Feb 13, 2024
9c4ba74
test with while_loop_dense
ManfeiBai Feb 13, 2024
e377b77
change dispatchkey
ManfeiBai Feb 13, 2024
27d66f8
re dispatch
ManfeiBai Feb 13, 2024
82612f1
re dispatch again
ManfeiBai Feb 13, 2024
0de6bae
check again
ManfeiBai Feb 13, 2024
0977c54
check type
ManfeiBai Feb 13, 2024
e76484f
checkpoint to show body/cond hlo
ManfeiBai Feb 14, 2024
f5dea5d
correct xb
ManfeiBai Feb 14, 2024
a189666
add example code script
ManfeiBai Feb 14, 2024
12e6886
add example code script again
ManfeiBai Feb 14, 2024
202f1ef
only test tpu
ManfeiBai Feb 14, 2024
d6fd934
only test tpu
ManfeiBai Feb 14, 2024
8b6df8e
add result value check
ManfeiBai Feb 14, 2024
2db33ab
add result value check
ManfeiBai Feb 14, 2024
6d3f3a4
clean code
ManfeiBai Feb 14, 2024
700ed2a
try torchxla code
ManfeiBai Feb 14, 2024
0f05d76
try torchxla code again
ManfeiBai Feb 14, 2024
d7e54c3
try torchxla code again again
ManfeiBai Feb 14, 2024
1b11308
try torchxla code again again again
ManfeiBai Feb 14, 2024
fea71a6
add test on CPU/GPU tests
ManfeiBai Feb 14, 2024
9a42faf
add test on TPU test trigger's
ManfeiBai Feb 14, 2024
0d21e47
add test in TPU CI workflow
ManfeiBai Feb 14, 2024
45c4433
Merge branch 'master' into while_loop-lowering-with-simplecalculation…
ManfeiBai Feb 14, 2024
5050eb4
placeholder for xlacomputation
ManfeiBai Feb 17, 2024
92b69f4
add test
ManfeiBai Feb 17, 2024
c4ed61a
modif test code
ManfeiBai Feb 17, 2024
1c8395e
modify fori_loop
ManfeiBai Feb 17, 2024
c9e7b5f
format
ManfeiBai Feb 19, 2024
02f738f
format
ManfeiBai Feb 19, 2024
d689acc
format
ManfeiBai Feb 19, 2024
122d6e5
format
ManfeiBai Feb 19, 2024
e95ddc3
test log
ManfeiBai Feb 21, 2024
589b1b0
test
ManfeiBai Feb 21, 2024
02caa42
trigger test again
ManfeiBai Feb 21, 2024
ab5ae68
comment to explain
ManfeiBai Feb 22, 2024
594280c
comment for explain
ManfeiBai Feb 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tpu_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ jobs:
python3 -u test/test_autocast.py
python3 -u test/dynamo/test_dynamo.py
python3 -u test/spmd/test_spmd_debugging.py
python3 -u test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla
import torch_xla.experimental.fori_loop
ManfeiBai marked this conversation as resolved.
Show resolved Hide resolved
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


def _fake_while_loop(cond_fn, body_fn, operands):
while cond_fn(*operands):
operands = body_fn(*operands)
return operands


class WhileLoopTest(unittest.TestCase):
  """Exercises `torch.while_loop` dispatched through the XLA backend."""

  def test_while_loop_tpu(self):
    # Loop keeps adding one while the running sum is still <= 10, so a
    # one-element tensor starting at 1 should finish at 11.

    def cond_fn(x):
      return x.sum() <= 10

    def body_fn(x):
      return (x + 1,)

    device = xm.xla_device()
    x = torch.ones(1, dtype=torch.int, device=device)
    res = while_loop(cond_fn, body_fn, (x,))
    print("while_loop result: ", res)
    expected = torch.tensor(11, dtype=torch.int, device=device)
    print("expected result: ", expected)
    self.assertEqual(expected, res[0])


if __name__ == '__main__':
  # `sys` is not imported at module top; import it here in the script guard.
  import sys
  # exit=False is required: plain unittest.main() calls sys.exit() itself,
  # which would make the explicit exit-code line below unreachable.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ spec:
python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
python3 /src/pytorch/xla/test/pjrt/test_dtypes.py
python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
python3 /src/pytorch/xla/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
volumeMounts:
- mountPath: /dev/shm
name: dshm
Expand Down
42 changes: 42 additions & 0 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.core.xla_op_registry as xor

from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
import torch._higher_order_ops.while_loop
from torch._higher_order_ops.while_loop import while_loop_op


@while_loop_op.py_impl(DispatchKey.XLA)
def while_loop(cond_fn, body_fn, operands):
  """XLA dispatch-key implementation of `torch.while_loop`.

  Registered via the python-dispatch mechanism so that while_loop calls on
  XLA tensors route here instead of the dense fallback.

  Args:
    cond_fn: callable; loop continues while it returns true.
    body_fn: callable; produces the next loop state.
    operands: tuple of possibly nested dict/list/tuple of tensors.

  Returns:
    The final loop state produced by the lowered XLA while computation.
  """
  return _xla_while_loop(cond_fn, body_fn, operands)


def _xla_while_loop(cond_fn, body_fn, operands):
  """Build and run an XLA `while` computation over the first operand.

  NOTE(review): this is a placeholder lowering — `cond_fn` and `body_fn`
  are accepted but not used yet; the hard-coded cond/body placeholders
  below run instead (see the TODO). Only `operands[0]` is consumed.
  """

  def op_fn(internal_x):
    # TODO(manfei): replace cond_fn_placeholder and body_fn_placeholder after confirm xlacomputation could be in xla::while
    def cond_fn_placeholder(counter, internal_x):
      # Continue while the loop counter is below the hard-coded limit of 10.
      return counter < xb.Op.scalar(internal_x.builder(), 10, dtype=xb.Type.S32)

    def body_fn_placeholder(counter, internal_x):
      # One loop step: increment both the counter and the carried value.
      next_counter = counter + xb.Op.scalar(
          counter.builder(), 1, dtype=xb.Type.S32)
      internal_x = internal_x + xb.Op.scalar(
          internal_x.builder(), 1, dtype=xb.Type.S32)
      # Loop state is carried as a (counter, value) tuple.
      return xb.Op.tuple((next_counter, internal_x))

    # Counter starts at zero; the while op threads (counter, value) through
    # the placeholder cond/body computations.
    zero = xb.Op.scalar(internal_x.builder(), 0, dtype=xb.Type.S32)
    w = xb.Op.mkwhile((zero, internal_x), cond_fn_placeholder,
                      body_fn_placeholder)
    # Return only the carried value, dropping the internal counter.
    return w.get_tuple_element(1)

  # Register the computation as a named XLA op and apply it to the first
  # operand; result is normalized to a list for the caller.
  op = xor.register('test_while', op_fn)
  return xu.as_list(op(operands[0]))
Loading