Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter tensor arguments from traced model. #5689

Merged
merged 7 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions test/dynamo/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,24 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
test_training_linear = make_training_test(LinearModule)
test_training_maxpool = make_training_test(MaxPoolModule)

def test_non_tensor_args_for_partition(self):

class Emb(torch.nn.Embedding):

def __init__(self):
super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0)

device = xm.xla_device()
module = Emb()
module.to(device)

@torch.compile(backend="openxla")
JackCaoG marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator Author

@ysiraichi ysiraichi Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, openxla doesn't have this problem. This only happens with openxla_eval. Since we are trying to get rid of openxla_eval, are we still interested in merging this PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is aot-autograd handles the argument. What's the non-tensor input passed to openxla_eval?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, exactly. One of the inputs passed was a custom class that inherited from nn.Embedding.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check with Brian whether it is possible for us to get non-tensor input after aot-autograd

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually @bdhirsh any thought?

def foo(x):
return module(x)

x = torch.randint(0, 10, (10,), device=device)
foo(x)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
27 changes: 17 additions & 10 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,19 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
for xla_arg in xla_args
]

args_tensor_ids = [
torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args
]
xla_tensor_args = [(i, xla_arg)
ysiraichi marked this conversation as resolved.
Show resolved Hide resolved
for i, xla_arg in enumerate(xla_args)
if isinstance(xla_arg, torch.Tensor)]

args_tensor_ids = [(index, torch_xla._XLAC._xla_get_tensor_id(xla_arg))
ysiraichi marked this conversation as resolved.
Show resolved Hide resolved
for index, xla_arg in xla_tensor_args]

if dynamo_debug:
print(f"Graph module:\n{xla_model.code}")
print(f"args_tensor_ids {args_tensor_ids}")

tensor_id_to_arg_idx = {
tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
tensor_id: index for index, tensor_id in args_tensor_ids
}

if xr.is_spmd():
Expand All @@ -258,15 +261,16 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):

# If a arg is being in place updated by model, we need to include arg as part of the graph result.
xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
xla_args)
[tensor for _, tensor in xla_tensor_args])
xla_args_need_update = []
arg_index_to_need_update_index = {}
for i, need_update in enumerate(xla_args_need_update_bool):
# Don't add inplace updated argument to the list if it's already
# being returned
if need_update and id(xla_args[i]) not in xla_out_ids:
arg_index_to_need_update_index[i] = len(xla_args_need_update)
xla_args_need_update.append(xla_args[i])
index, tensor = xla_tensor_args[i]
if need_update and id(tensor) not in xla_out_ids:
arg_index_to_need_update_index[index] = len(xla_args_need_update)
xla_args_need_update.append(tensor)

args_and_out = tuple(xla_args_need_update) + tuple(xla_out)

Expand Down Expand Up @@ -325,7 +329,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
def extract_internal(xla_model: torch.fx.GraphModule):
if dynamo_debug:
for xla_arg in xla_model.xla_args:
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
if isinstance(xla_arg, torch.Tensor):
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
xm.mark_step()
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
Expand All @@ -347,7 +352,9 @@ def optimized_mod(*args):

# mark_step needs to be blocking since we want to access args's XLADatas
# and they can't be placeholder.
if any(torch_xla._XLAC._check_tensor_need_materialization(args)):
if any(
torch_xla._XLAC._check_tensor_need_materialization(
[a for a in args if isinstance(a, torch.Tensor)])):
xm.mark_step(wait=True)

# If input sharding has changed from the previous program, dynamo current can
Expand Down
Loading