Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Repair aten.where with Numpy + Broadcast #2372

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

import numpy as np
import tensorrt as trt
Expand All @@ -11,7 +11,7 @@
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.slice import expand
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


Expand All @@ -20,23 +20,13 @@ def where(
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
other: TRTTensor,
condition: TRTTensor,
input: Union[TRTTensor, np.ndarray, torch.Tensor],
other: Union[TRTTensor, np.ndarray, torch.Tensor],
condition: Union[TRTTensor, np.ndarray, torch.Tensor],
) -> TRTTensor:
if not (broadcastable(input, other)):
assert "The two torch tensors should be broadcastable"

# get output shape
# purpose of this is to bring input and other rank same as
# output_shape to input it to the add_expand operation
# condition will have dimension of either input or other
input, other = broadcast(ctx.net, input, other, f"{name}_x", f"{name}_y")
if len(tuple(condition.shape)) != len(tuple(input.shape)):
condition, input = broadcast(
ctx.net, condition, input, f"{name}_condition", f"{name}_x"
)

x_shape = list(input.shape)
y_shape = list(other.shape)
condition_shape = list(condition.shape)
Expand Down Expand Up @@ -71,7 +61,11 @@ def where(
if isinstance(input, torch.Tensor)
else np.expand_dims(input, axis=0)
)
input = input.expand(output_shape)
input = (
input.expand(output_shape)
if isinstance(input, torch.Tensor)
else np.broadcast_to(input, output_shape)
)
x_val = get_trt_tensor(ctx, input, f"{name}_x")
else:
x_val = input
Expand All @@ -89,7 +83,11 @@ def where(
if isinstance(other, torch.Tensor)
else np.expand_dims(other, axis=0)
)
other = other.expand(output_shape)
other = (
other.expand(output_shape)
if isinstance(other, torch.Tensor)
else np.broadcast_to(other, output_shape)
)
y_val = get_trt_tensor(ctx, other, f"{name}_y")
else:
y_val = other
Expand Down
17 changes: 17 additions & 0 deletions tests/py/dynamo/conversion/test_where_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ def forward(self, condition):
(condition,),
)

def test_const_input_with_broadcast(self):
class Where(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.inputY = torch.randn((1,))
self.inputX = torch.randn((1,))

def forward(self, condition):
return torch.ops.aten.where.self(condition, self.inputX, self.inputY)

input1 = torch.randn((5, 6, 7))
condition = input1 < 0
self.run_test(
Where(),
(condition,),
)


if __name__ == "__main__":
run_tests()
Loading