Skip to content

Commit

Permalink
Fix an error in the shape logic
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Aug 16, 2022
1 parent 168fe6d commit b5f3d43
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 53 deletions.
105 changes: 55 additions & 50 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5441,10 +5441,6 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
Expand Down Expand Up @@ -5524,6 +5520,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.floor_divide.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
Expand Down Expand Up @@ -6434,17 +6434,17 @@ module {
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.conv_transpose2d.input"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%0 = torch.aten.len.t %arg7 : !torch.list<int> -> !torch.int
%1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
%4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%5 = torch.aten.append.t %3, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
%6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
%6 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.append.t %3, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
Expand Down Expand Up @@ -6477,62 +6477,67 @@ module {
return %3 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%0 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
%1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%3 = torch.prim.ListConstruct : () -> !torch.list<int>
%4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%5 = torch.aten.append.t %3, %4 : !torch.list<int>, !torch.int -> !torch.list<int>
%6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
%7 = torch.aten.append.t %3, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
%8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %8, %true, init() {
%4 = torch.prim.If %arg6 -> (!torch.int) {
torch.prim.If.yield %int1 : !torch.int
} else {
torch.prim.If.yield %int0 : !torch.int
}
%5 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
%6 = torch.aten.append.t %3, %5 : !torch.list<int>, !torch.int -> !torch.list<int>
%7 = torch.aten.__getitem__.t %arg1, %4 : !torch.list<int>, !torch.int -> !torch.int
%8 = torch.aten.append.t %3, %7 : !torch.list<int>, !torch.int -> !torch.list<int>
%9 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
torch.prim.Loop %9, %true, init() {
^bb0(%arg9: !torch.int):
%9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%10 = torch.prim.If %1 -> (!torch.int) {
%11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %12 : !torch.int
%10 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
%11 = torch.prim.If %1 -> (!torch.int) {
%12 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.__getitem__.t %arg5, %12 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %13 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
torch.prim.If %arg6 -> () {
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
%15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.mul.int %int2, %20 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int
%23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
%12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
%18 = torch.aten.__getitem__.t %arg3, %17 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
%21 = torch.aten.__getitem__.t %arg4, %20 : !torch.list<int>, !torch.int -> !torch.int
%22 = torch.aten.mul.int %int2, %21 : !torch.int, !torch.int -> !torch.int
%23 = torch.aten.sub.int %19, %22 : !torch.int, !torch.int -> !torch.int
%24 = torch.aten.add.int %23, %14 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%26 = torch.aten.append.t %3, %25 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
} else {
%11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list<int>, !torch.int -> !torch.int
%12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int
%13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<int>, !torch.int -> !torch.int
%16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list<int>, !torch.int -> !torch.int
%18 = torch.aten.mul.int %int2, %17 : !torch.int, !torch.int -> !torch.int
%19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int
%21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list<int>, !torch.int -> !torch.int
%23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int
%24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.append.t %3, %24 : !torch.list<int>, !torch.int -> !torch.list<int>
%12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list<int>, !torch.int -> !torch.int
%13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int
%14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int
%15 = torch.aten.add.int %14, %int1 : !torch.int, !torch.int -> !torch.int
%16 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
%17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
%18 = torch.aten.__getitem__.t %arg4, %17 : !torch.list<int>, !torch.int -> !torch.int
%19 = torch.aten.mul.int %int2, %18 : !torch.int, !torch.int -> !torch.int
%20 = torch.aten.add.int %16, %19 : !torch.int, !torch.int -> !torch.int
%21 = torch.aten.sub.int %20, %15 : !torch.int, !torch.int -> !torch.int
%22 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
%23 = torch.aten.__getitem__.t %arg3, %22 : !torch.list<int>, !torch.int -> !torch.int
%24 = torch.aten.floordiv.int %21, %23 : !torch.int, !torch.int -> !torch.int
%25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int
%26 = torch.aten.append.t %3, %25 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield
}
torch.prim.Loop.condition %true, iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self

def aten〇arange〇start_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
def aten〇arange〇start_step(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)

def aten〇arange〇start(start: float, end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
Expand Down