Don't wrap ints in torch.tensor; (pytorch#6539)
qihqi authored and amithrm committed Mar 1, 2024
1 parent ef97c23 commit 82f40f8
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions torch_xla/stablehlo.py
@@ -248,9 +248,8 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     args = tuple(
         map(
             lambda arg_spec: torch.tensor(arg_spec[0])
-            if isinstance(arg_spec[0], (float, int)) and type(arg_spec[
-                1].type) == torch.TensorType else arg_spec[0],
-            args_and_specs))
+            if isinstance(arg_spec[0], float) and type(arg_spec[1].type) ==
+            torch.TensorType else arg_spec[0], args_and_specs))
     return super().call_function(target, args, new_kwargs)
 
   def run_node(self, n) -> Any:
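
For readers skimming the change, the sketch below isolates the mapping this hunk edits. It is a standalone, hypothetical reproduction (FakeSpec and the sample args_and_specs pairs are made up; only the lambda is taken from the new code) showing that, after this commit, only Python floats whose spec type is torch.TensorType get wrapped in torch.tensor, while Python ints pass through untouched.

import torch

# Hypothetical stand-in for an FX node spec whose .type carries a
# TorchScript type annotation, as args_and_specs does in torch_xla/stablehlo.py.
class FakeSpec:

  def __init__(self, type_):
    self.type = type_


tensor_type = torch.TensorType.get()  # generic TorchScript Tensor type

args_and_specs = [
    (3, FakeSpec(tensor_type)),    # Python int: no longer wrapped after this commit
    (2.5, FakeSpec(tensor_type)),  # Python float with Tensor-typed spec: wrapped
    (1.0, FakeSpec(None)),         # Python float without Tensor-typed spec: kept as-is
]

# The lambda below mirrors the new code in call_function.
args = tuple(
    map(
        lambda arg_spec: torch.tensor(arg_spec[0])
        if isinstance(arg_spec[0], float) and type(arg_spec[1].type) ==
        torch.TensorType else arg_spec[0], args_and_specs))

print(args)  # (3, tensor(2.5000), 1.0): only the Tensor-typed float is wrapped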
