diff --git a/test/stablehlo/test_saved_model.py b/test/stablehlo/test_saved_model.py
index caa0b47a7ebe..07e6b5b247ae 100644
--- a/test/stablehlo/test_saved_model.py
+++ b/test/stablehlo/test_saved_model.py
@@ -38,7 +38,6 @@ def forward(self, a, b):
             b,
         ),
         constraints=constraints)
     shlo = exported_program_to_stablehlo(exported)
-    print(shlo.get_stablehlo_text())
     with tempfile.TemporaryDirectory() as tempdir:
       save_stablehlo_graph_as_tf(
           shlo,
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 7227c9b6a586..dd1a43d793b9 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -298,36 +298,19 @@ def _exported_program_to_stablehlo_bundle(exported_model,
     options = StableHLOExportOptions()
 
   exported_model = exported_model.run_decompositions()
   exported_model = exported_model.run_decompositions(_extra_decompositions)
-  input_args = _extract_input_args(exported_model, options)
 
-  device = xm.xla_device()
-  input_args = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
-                                    input_args)
+  args, kwargs = exported_model.example_inputs
 
-  # NOTE call convention: (parameters, buffers, user_inputs)
-  param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
-  state_dict = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
-                                    exported_model.state_dict)
+  assert len(kwargs) == 0, "Export to stablehlo doesn't support kwargs yet."
 
-  if (constants := getattr(exported_model, 'constants')) is not None:
-    state_dict.update(
-        pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device),
-                             constants))
+  device = xm.xla_device()
 
-  param_buffer_values = (state_dict[key] for key in param_and_buffer_keys)
+  _flat_input_args = exported_model._graph_module_flat_inputs(args, {})
+  _flat_input_args = pytree.tree_map_only(torch.Tensor,
+                                          lambda x: x.to(device=device),
+                                          _flat_input_args)
 
-  if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
-    ordered_tensor_constants = tuple(
-        exported_model.tensor_constants[name]
-        for name in exported_model.graph_signature.lifted_tensor_constants)
-  else:
-    ordered_tensor_constants = ()
-
-  ordered_tensor_constants = pytree.tree_map_only(torch.Tensor,
-                                                  lambda x: x.to(device=device),
-                                                  ordered_tensor_constants)
   num_mutations = len(exported_model.graph_signature.buffers_to_mutate)
 
-  xm.mark_step()
   xm.wait_device_ops()
   metrics.clear_counters()
@@ -336,11 +319,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   # Run the fx graph tracing using lazy tensor
   xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device)
   with torch.no_grad():
-    res = xla_interpreter.run(
-        *param_buffer_values,
-        *ordered_tensor_constants,
-        *input_args,
-        enable_io_processing=False)
+    res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False)
     res = res[num_mutations:]
 
   # If there are any fallback ops, this means that in torch/XLA side,
@@ -350,17 +329,31 @@
                "\n".join(fallback_ops))
     raise RuntimeError(message)
 
+  InputKind = torch.export.graph_signature.InputKind
+  tensor_id_to_state_name = {}
+  state_dict = {}
+  input_ids = {}
+  for i, (tensor, input_spec) in enumerate(
+      zip(_flat_input_args, exported_model.graph_signature.input_specs)):
+    # Assumption:
+    # All states come first in the list of args, and user-provided inputs come later.
+    # Also there are no kwargs.
+    if not isinstance(tensor, torch.Tensor):
+      continue
+
+    tensor_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    if input_spec.kind == InputKind.USER_INPUT:
+      input_position = i - len(state_dict)
+      input_ids[tensor_id] = input_position
+    else:
+      state_dict[input_spec.target] = tensor
+      tensor_id_to_state_name[tensor_id] = input_spec.target
+
   (
       graph_input_tensor_ids,
       graph_input_xla_values,
   ) = torch_xla._XLAC._get_tensors_xla_device_data_node(res)
 
-  tensor_id_to_state_name = {
-      torch_xla._XLAC._xla_get_tensor_id(value): name
-      for name, value in state_dict.items()
-      if isinstance(value, torch.Tensor)
-  }
-
   stablehlo_content = xm.get_stablehlo_bytecode(res)
   if options.include_human_readable_text:
     stablehlo_text = xm.get_stablehlo(res)
@@ -372,14 +365,9 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   input_locations = []
   input_signatures = []
   additional_constants = []
-  input_ids = {
-      torch_xla._XLAC._xla_get_tensor_id(tensor): pos
-      for pos, tensor in enumerate(input_args)
-      if isinstance(tensor, torch.Tensor)
-  }
 
   # there might be inputs that is part of input but not consumed by HLO graph
-  unused_input_positions = set(range(len(input_args)))
+  unused_input_positions = set(range(len(args)))
 
   for hlo_input_pos, (tensor_id, tensor_value) in enumerate(
       zip(graph_input_tensor_ids, graph_input_xla_values)):
@@ -405,7 +393,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   unused_inputs = []
   for i in unused_input_positions:
     pos = InputLocation.input_arg(position=i)
-    arg = input_args[i]
+    arg = args[i]
     if isinstance(arg, torch.Tensor):
       signature = VariableSignature(
           shape=list(arg.shape), dtype=str(arg.dtype).replace('torch.', ''))
@@ -441,7 +429,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
       stablehlo_funcs=[StableHLOFunc(meta, stablehlo_content, stablehlo_text)],
      state_dict=pytree.tree_map_only(torch.Tensor,
                                       lambda x: x.detach().cpu().numpy(),
-                                      exported_model.state_dict),
+                                      state_dict),
       additional_constants=additional_constants,
   )
 
diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py
index 946d5460f5f1..2a6fe1c6fe8a 100644
--- a/torch_xla/tf_saved_model_integration.py
+++ b/torch_xla/tf_saved_model_integration.py
@@ -31,13 +31,13 @@ def inner(*args):
     Touts = [sig.dtype for sig in func.meta.output_signature]
     Souts = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
     call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
     return tfxla.call_module(
         tuple(call_args),
-        version=6,
+        version=5,
         Tout=Touts,  # dtype information
         Sout=Souts,  # Shape information
         function_list=[],
         module=func.bytecode,
     )
 
   return inner
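Reviewer note: the new classification loop in `_exported_program_to_stablehlo_bundle` relies on `torch.export`'s flat calling convention, in which parameters, buffers, and lifted constants precede user inputs. The sketch below shows that logic in isolation so it can be sanity-checked without an XLA device. It is illustrative only: `TinyModel` is a made-up module, Python's `id()` stands in for `torch_xla._XLAC._xla_get_tensor_id`, and `_graph_module_flat_inputs` is a private `torch.export` API that may not exist in every PyTorch version.

```python
import torch
from torch.export import export
from torch.export.graph_signature import InputKind


class TinyModel(torch.nn.Module):
  """Illustrative stand-in for the model being exported."""

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 2)

  def forward(self, x):
    return self.linear(x)


args = (torch.randn(3, 4),)
exported = export(TinyModel(), args)

# Flatten inputs the way XLAExportInterpreter.run receives them:
# parameters/buffers/constants first, then user inputs.
flat_inputs = exported._graph_module_flat_inputs(args, {})  # private API

state_dict = {}  # FQN -> tensor, for parameters/buffers/constants
input_ids = {}  # tensor identity -> position among user-provided args
for i, (tensor, spec) in enumerate(
    zip(flat_inputs, exported.graph_signature.input_specs)):
  if not isinstance(tensor, torch.Tensor):
    continue
  if spec.kind == InputKind.USER_INPUT:
    # States precede user inputs, so subtracting the number of states
    # seen so far recovers the position within the user args.
    input_ids[id(tensor)] = i - len(state_dict)
  else:
    state_dict[spec.target] = tensor

print(sorted(state_dict))  # ['linear.bias', 'linear.weight']
print(list(input_ids.values()))  # [0]: x is user arg 0
```

The `i - len(state_dict)` arithmetic is what makes the states-first assumption load-bearing; if `torch.export` ever interleaved user inputs with states, the recorded positions would be silently wrong, so an assertion on `input_specs` ordering may be worth adding.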