[Fori_loop|While_loop] Fori loop wrapped from while_loop (#6850)
ManfeiBai authored Apr 11, 2024
1 parent cfc70e6 commit f0354ec
Showing 3 changed files with 112 additions and 15 deletions.
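
At the caller level, this change makes torch_xla.experimental.fori_loop.fori_loop usable as a thin wrapper over while_loop. A minimal usage sketch, modeled on the new test added below (the tensor values here are illustrative, not from the commit):

import torch
import torch_xla.experimental.fori_loop  # importing registers the XLA while_loop impl
from torch_xla.experimental.fori_loop import fori_loop
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# bounds and carried values are single-element int32 XLA tensors in this commit
lower = torch.tensor([0], dtype=torch.int32, device=device)
upper = torch.tensor([10], dtype=torch.int32, device=device)
plus_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([0], dtype=torch.int32, device=device)

def body_fun(*argus):
  plus_value, init_val = argus
  return plus_value, torch.add(plus_value, init_val)

# mirrors the new test: bound tensors come back first, the final carried value last
_, _, _, result = fori_loop(upper, lower, body_fun, plus_value, init_val)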
@@ -6,17 +6,26 @@
import torch_xla
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


def _fake_while_loop(cond_fn, body_fn, operands):
  while cond_fn(operands[0], operands[1]):
    operands = body_fn(operands[0], operands[1])
  # operands need to be more than one here
  while cond_fn(*operands):
    operands = body_fn(*operands)
  return operands


def _fake_fori_loop(lower, upper, body_fun, *init_val):
  (plus_value, init_val) = init_val
  for i in range((upper - lower)[0]):
    plus_value, init_val = body_fun(plus_value, init_val)
  return init_val


class WhileLoopTest(unittest.TestCase):

  def test_while_loop_tpu_subtraction(self):
@@ -73,7 +82,25 @@ def body_fn(init, limit_value):
    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
    self.assertEqual(expected, res)

  def test_fori_loop_tpu_addition(self):

    xm.mark_step()
    device = xm.xla_device()

    lower = torch.tensor([2], dtype=torch.int32, device=device)
    upper = torch.tensor([52], dtype=torch.int32, device=device)
    plus_value = torch.tensor([1], dtype=torch.int32, device=device)
    init_val = torch.tensor([1], dtype=torch.int32, device=device)

    def body_fun(*argus):
      plus_value, init_val = argus
      return plus_value, torch.add(plus_value, init_val)

    _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
    expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
    self.assertEqual(expected, actual)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
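
For reference, the expected value in test_fori_loop_tpu_addition can be checked on CPU with the pure-Python helper above; a small sketch (not part of the commit):

import torch

lower, upper = torch.tensor([2]), torch.tensor([52])
plus_value, init_val = torch.tensor([1]), torch.tensor([1])
# upper - lower = 50 iterations, each adding plus_value (1) to init_val (starting at 1)
for _ in range((upper - lower)[0]):
  plus_value, init_val = plus_value, torch.add(plus_value, init_val)
print(init_val)  # tensor([51])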
38 changes: 38 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -902,6 +902,43 @@ class PyLoweringContext {
      lowering_ctx.AddResult(root);
    }
    computation = ConsumeValue(lowering_ctx.BuildXla());
  }

  // Builds an HLO graph given a set of output tensors, and adds unused
  // parameters needed in the xlacomputation.
  void BuildForiLoop(std::vector<at::Tensor> tensors,
                     std::vector<at::Tensor> input_arguments = {}) {
    if (GetNameString() == "condctx") {
      xla::XlaBuilder* local_builder = lowering_ctx.builder();
      // hard-code parameter_idx to 2 to skip existing upper/lower arguments
      int64_t parameter_idx = 2;
      for (at::Tensor input_argument : input_arguments) {
        xla::Shape shape =
            xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1});
        xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
                                      "UnusedArgumentsPlaceholder");
        parameter_idx += 1;
      }
    }

    // Get the backing XLA tensors from the output torch tensor handles
    std::vector<XLATensorPtr> xtensors =
        GetXlaTensors(tensors, /*want_all=*/true);

    // Get the lazy IR value from the output XLA tensors
    std::vector<torch::lazy::Value> ir_values;
    for (auto& xtensor : xtensors) {
      torch::lazy::Value value = xtensor->GetIrValue();
      ir_values.push_back(value);
    }

    // Lower the graph using the output IR values
    for (auto& ir_value : ir_values) {
      xla::XlaOp root = lowering_ctx.GetOutputOp(
          torch::lazy::Output(ir_value.node.get(), ir_value.index));
      lowering_ctx.AddResult(root);
    }
    computation = ConsumeValue(lowering_ctx.BuildXla());

    // wrap inputs of cond/body_computation
    if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
@@ -1044,6 +1081,7 @@ void BuildLoweringContextSubmodule(py::module* m) {

  lowering_context_class.def(py::init<>())
      .def("build", &PyLoweringContext::Build)
      .def("buildforiloop", &PyLoweringContext::BuildForiLoop)
      .def("hlo", &PyLoweringContext::GetHlo)
      .def("hlo_text", &PyLoweringContext::GetHloText)
      .def("hlo_json", &PyLoweringContext::GetHloJsonText)
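The new binding is reached through the existing LoweringContext wrapper. A minimal sketch of the intended call pattern for the condition computation, mirroring the fori_loop.py change below (the extra tensor is a placeholder):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
lower = torch.tensor([0], dtype=torch.int32, device=device)
upper = torch.tensor([10], dtype=torch.int32, device=device)
extra = torch.tensor([1], dtype=torch.int32, device=device)

cond_result = lower[0] < upper[0]  # traced condition output
cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.set_name_string("condctx")
# under the "condctx" name, BuildForiLoop appends one unused S32[1] parameter
# (starting at index 2) per extra argument so the cond signature lines up with the body
cond_ctx.buildforiloop([cond_result], [extra])
cond_hlo = cond_ctx.hlo()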
56 changes: 44 additions & 12 deletions torch_xla/experimental/fori_loop.py
@@ -12,40 +12,73 @@
from torch._higher_order_ops.while_loop import while_loop_op


def fori_loop(lower, upper, user_body_func, *init_val):

  device = xm.xla_device()

  def cond_fn(upper, lower, *init_val):
    return lower[0] < upper[0]

  def body_fn(upper, lower, *init_val):
    one_value_i = torch.ones(1, dtype=torch.int32, device=device)
    res_list = list(user_body_func(*init_val))
    res_list.insert(0, lower)
    res_list.insert(0, torch.sub(upper, one_value_i))
    return res_list

  res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
  return res


@while_loop_op.py_impl(DispatchKey.XLA)
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
  # TODO(@manfei): PyTorch requires carried_inputs to be a list/tuple; PyTorch/XLA's _xla_while_loop only accepts *operands, and *operands would tuple the items again: (a, '')
  # cond_fn&body_fn: callable
  # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
  if additional_inputs is None:
    additional_inputs = tuple()
  return _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs)
  return _xla_while_loop(
      cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)


def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
  # untuple carried_inputs from while_loop
  carried_inputs = carried_inputs[0]
  # fake carried_inputs to split formal code
  fake_carried_inputs = []
  for carried_input in carried_inputs:
    device = carried_input.device
    fake_carried_inputs.append(
        torch.randint(10, carried_input.size(),
                      dtype=carried_input.dtype).to(device))
  fake_carried_inputs = tuple(fake_carried_inputs)

  # create inputs placeholder
  # trans fake_carried_inputs from list(tensor) to list(xla::op)
  kwargs = {}
  shapes = xb.tensor_shape(carried_inputs)
  if type(fake_carried_inputs) is tuple:
    shapes = xb.tensor_shape(fake_carried_inputs)
  else:
    shapes = xb.tensor_shape((fake_carried_inputs))
  builder = xb.create_builder('test_while')
  params = []
  for shape in shapes:
    p = xb.mkparam(builder, len(params), shape)
    params.append(p)

  # generate cond_fn xlacomputation
  cond_result = cond_fn(carried_inputs[0], carried_inputs[1])
  cond_result = cond_fn(*fake_carried_inputs)
  cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
  cond_ctx.set_name_string("condctx")
  cond_ctx.build([cond_result])
  cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
  cond_hlo = cond_ctx.hlo()
  cond_computation = xb.computation_from_module_proto("condcomputation",
                                                      cond_hlo)

  # generate body_fn xlacomputation
  body_result = body_fn(carried_inputs[0], carried_inputs[1])
  body_result = body_fn(*fake_carried_inputs)
  body_ctx = torch_xla._XLAC.lowering.LoweringContext()
  body_ctx.set_name_string("bodyctx")
  body_ctx.build(list(body_result))
  body_ctx.buildforiloop(list(body_result), [])
  body_hlo = body_ctx.hlo()
  body_computation = xb.computation_from_module_proto("bodycomputation",
                                                      body_hlo)
@@ -61,7 +94,6 @@ def _xla_while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):

  # gain final result with generated while xlacomputation
  result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
                                                 tuple(carried_inputs),
                                                 computation)
                                                 (carried_inputs), computation)

  return result
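
For completeness, the dispatcher-level while_loop path changed above is exercised the same way as before from user code; a condensed sketch with placeholder cond/body functions (the full test bodies are not shown in this diff):

import torch
import torch_xla.experimental.fori_loop  # importing registers the XLA dispatch key
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm

device = xm.xla_device()
init = torch.tensor([0], dtype=torch.int32, device=device)
limit_value = torch.tensor([10], dtype=torch.int32, device=device)

def cond_fn(init, limit_value):
  return limit_value[0] > init[0]

def body_fn(init, limit_value):
  return (torch.add(init, 1), limit_value.clone())

# carried_inputs is passed as one tuple; the XLA impl now unpacks it via
# *carried_inputs and traces cond_fn/body_fn on fake tensors of matching shapes
res = while_loop(cond_fn, body_fn, (init, limit_value))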
