Skip to content

Commit

Permalink
Modify export to use refactored _graph_module_flat_inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Feb 13, 2024
1 parent 1bcc014 commit 3ae78a4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 45 deletions.
1 change: 0 additions & 1 deletion test/stablehlo/test_saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def forward(self, a, b):
b,
), constraints=constraints)
shlo = exported_program_to_stablehlo(exported)
print(shlo.get_stablehlo_text())
with tempfile.TemporaryDirectory() as tempdir:
save_stablehlo_graph_as_tf(
shlo,
Expand Down
74 changes: 31 additions & 43 deletions torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,36 +298,19 @@ def _exported_program_to_stablehlo_bundle(exported_model,
options = StableHLOExportOptions()
exported_model = exported_model.run_decompositions()
exported_model = exported_model.run_decompositions(_extra_decompositions)
input_args = _extract_input_args(exported_model, options)

device = xm.xla_device()
input_args = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
input_args)
args, kwargs = exported_model.example_inputs

# NOTE call convention: (parameters, buffers, user_inputs)
param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
state_dict = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
exported_model.state_dict)
assert len(kwargs) == 0, "Export to stablehlo doesnt support kwargs yet."

if (constants := getattr(exported_model, 'constants')) is not None:
state_dict.update(
pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
constants))
device = xm.xla_device()

param_buffer_values = (state_dict[key] for key in param_and_buffer_keys)
_flat_input_args = exported_model._graph_module_flat_inputs(args, {})
_flat_input_args = pytree.tree_map_only(torch.Tensor,
lambda x: x.to(device=device),
_flat_input_args)

if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
ordered_tensor_constants = tuple(
exported_model.tensor_constants[name]
for name in exported_model.graph_signature.lifted_tensor_constants)
else:
ordered_tensor_constants = ()

ordered_tensor_constants = pytree.tree_map_only(torch.Tensor,
lambda x: x.to(device=device),
ordered_tensor_constants)
num_mutations = len(exported_model.graph_signature.buffers_to_mutate)

xm.mark_step()
xm.wait_device_ops()
metrics.clear_counters()
Expand All @@ -336,11 +319,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
# Run the fx graph tracing using lazy tensor
xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device)
with torch.no_grad():
res = xla_interpreter.run(
*param_buffer_values,
*ordered_tensor_constants,
*input_args,
enable_io_processing=False)
res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False)
res = res[num_mutations:]

# If there are any fallback ops, this means that in torch/XLA side,
Expand All @@ -350,17 +329,31 @@ def _exported_program_to_stablehlo_bundle(exported_model,
"\n".join(fallback_ops))
raise RuntimeError(message)

InputKind = torch.export.graph_signature.InputKind
tensor_id_to_state_name = {}
state_dict = {}
input_ids = {}
for i, (tensor, input_spec) in enumerate(
zip(_flat_input_args, exported_model.graph_signature.input_specs)):
# Assumption:
# All states come first in the list of args, and user-provided inputs come later.
# Also, there are no kwargs.
if not isinstance(tensor, torch.Tensor):
continue

tensor_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
if input_spec.kind == InputKind.USER_INPUT:
input_position = i - len(state_dict)
input_ids[tensor_id] = input_position
else:
state_dict[input_spec.target] = tensor
tensor_id_to_state_name[tensor_id] = input_spec.target

(
graph_input_tensor_ids,
graph_input_xla_values,
) = torch_xla._XLAC._get_tensors_xla_device_data_node(res)

tensor_id_to_state_name = {
torch_xla._XLAC._xla_get_tensor_id(value): name
for name, value in state_dict.items()
if isinstance(value, torch.Tensor)
}

stablehlo_content = xm.get_stablehlo_bytecode(res)
if options.include_human_readable_text:
stablehlo_text = xm.get_stablehlo(res)
Expand All @@ -372,14 +365,9 @@ def _exported_program_to_stablehlo_bundle(exported_model,
input_locations = []
input_signatures = []
additional_constants = []
input_ids = {
torch_xla._XLAC._xla_get_tensor_id(tensor): pos
for pos, tensor in enumerate(input_args)
if isinstance(tensor, torch.Tensor)
}

# There might be inputs that are part of the input args but are not consumed by the HLO graph.
unused_input_positions = set(range(len(input_args)))
unused_input_positions = set(range(len(args)))

for hlo_input_pos, (tensor_id, tensor_value) in enumerate(
zip(graph_input_tensor_ids, graph_input_xla_values)):
Expand All @@ -405,7 +393,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
unused_inputs = []
for i in unused_input_positions:
pos = InputLocation.input_arg(position=i)
arg = input_args[i]
arg = args[i]
if isinstance(arg, torch.Tensor):
signature = VariableSignature(
shape=list(arg.shape), dtype=str(arg.dtype).replace('torch.', ''))
Expand Down Expand Up @@ -441,7 +429,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
stablehlo_funcs=[StableHLOFunc(meta, stablehlo_content, stablehlo_text)],
state_dict=pytree.tree_map_only(torch.Tensor,
lambda x: x.detach().cpu().numpy(),
exported_model.state_dict),
state_dict),
additional_constants=additional_constants,
)

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def inner(*args):
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
return tfxla.call_module(
tuple(call_args),
version=6,
version=5,
Tout=Touts, # dtype information
Sout=Souts, # Shape information
function_list=[],
Expand Down

0 comments on commit 3ae78a4

Please sign in to comment.