Skip to content

Commit

Permalink
Simplify lit test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
praveen-g-ctt committed Feb 5, 2025
1 parent 59ab6ba commit f8e7132
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 65 deletions.
54 changes: 26 additions & 28 deletions test/Conversion/TorchToLinalg/constraints.mlir
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.sym_constrain_range(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 7
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_7:.*]] = arith.constant 9223372036854775807 : i64
// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64
// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64
// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1
// CHECK: cf.assert %[[VAL_10]], "Size constraint failed. Expected range: [0, 9223372036854775807]"
// CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64
// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64
// CHECK: %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64
// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
// CHECK: cf.assert %[[VAL_15]], "Size constraint failed. Expected range: [0, 7]"
// CHECK: return %[[VAL_4]] : !torch.int

func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%int7 = torch.constant.int 7
%int0 = torch.constant.int 0
%none = torch.constant.none
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none
torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int
return %0 : !torch.int
// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[VAL_0]]
// CHECK: %[[VAL_2:.*]] = torch.constant.int 7
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.none
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64
// CHECK: %[[VAL_7:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_1]] : i64
// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_6]] : i64
// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_7]], %[[VAL_8]] : i1
// CHECK: cf.assert %[[VAL_9]], "Size constraint failed. Expected range: [0, 9223372036854775807]"
// CHECK: %[[VAL_10:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_11:.*]] = arith.constant 7 : i64
// CHECK: %[[VAL_12:.*]] = arith.cmpi sle, %[[VAL_10]], %[[VAL_1]] : i64
// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_11]] : i64
// CHECK: %[[VAL_14:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : i1
// CHECK: cf.assert %[[VAL_14]], "Size constraint failed. Expected range: [0, 7]"
// CHECK: return %[[VAL_0]] : !torch.int
// CHECK: }
func.func @torch.aten.sym_constrain_range(%arg0: !torch.int) -> !torch.int {
%int7 = torch.constant.int 7
%int0 = torch.constant.int 0
%none = torch.constant.none
torch.aten.sym_constrain_range %arg0, %int0, %none : !torch.int, !torch.int, !torch.none
torch.aten.sym_constrain_range %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int
return %arg0 : !torch.int
}
71 changes: 34 additions & 37 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -232,52 +232,49 @@ func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>)
// -----

// CHECK-LABEL: func.func @torch.aten.sym_constrain_range_for_size(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 7
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none
// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int
// CHECK: return %[[VAL_4]] : !torch.int
func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
%none = torch.constant.none
%none_0 = torch.constant.none
torch.aten.sym_constrain_range_for_size %0, %none, %none_0 : !torch.int, !torch.none, !torch.none
%int0_6 = torch.constant.int 0
%int7_7 = torch.constant.int 7
torch.aten.sym_constrain_range_for_size %0, %int0_6, %int7_7 : !torch.int, !torch.int, !torch.int
return %0 : !torch.int
// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none
// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int
// CHECK: return %[[VAL_0]] : !torch.int
// CHECK: }
func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.int) -> !torch.int {
%int7 = torch.constant.int 7
%int0 = torch.constant.int 0
%none = torch.constant.none
torch.aten.sym_constrain_range_for_size %arg0, %none, %none : !torch.int, !torch.none, !torch.none
torch.aten.sym_constrain_range_for_size %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int
return %arg0 : !torch.int
}

// -----

// CHECK-LABEL: func.func @torch.aten._assert_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = torch.constant.int 3
// CHECK: %[[VAL_3:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_4:.*]] = torch.aten.ge.int %[[VAL_3]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_5:.*]] = torch.aten.Int.bool %[[VAL_4]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_6:.*]] = torch.aten.Bool.int %[[VAL_5]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_6]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
// CHECK: %[[VAL_7:.*]] = torch.aten.gt.int %[[VAL_3]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_8:.*]] = torch.aten.Int.bool %[[VAL_7]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_9:.*]] = torch.aten.Bool.int %[[VAL_8]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_9]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
// CHECK: return %[[VAL_3]] : !torch.int
func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
%int3 = torch.constant.int 3
%1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int
%str = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
torch.aten._assert_scalar %2, %str : !torch.int, !torch.str
// CHECK: %[[VAL_3:.*]] = torch.aten.ge.int %[[VAL_0]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_4:.*]] = torch.aten.Int.bool %[[VAL_3]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_5:.*]] = torch.aten.Bool.int %[[VAL_4]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_5]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
// CHECK: %[[VAL_6:.*]] = torch.aten.gt.int %[[VAL_0]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_7:.*]] = torch.aten.Int.bool %[[VAL_6]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.Bool.int %[[VAL_7]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_8]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
// CHECK: return %[[VAL_0]] : !torch.int
// CHECK: }
func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
%str = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
%int2 = torch.constant.int 2
%3 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool
%4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int
%str_0 = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
torch.aten._assert_scalar %4, %str_0 : !torch.int, !torch.str
return %0 : !torch.int
%str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
%int3 = torch.constant.int 3
%0 = torch.aten.ge.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool
%1 = torch.aten.Int.bool %0 : !torch.bool -> !torch.int
torch.aten._assert_scalar %1, %str_0 : !torch.int, !torch.str
%2 = torch.aten.gt.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool
%3 = torch.aten.Int.bool %2 : !torch.bool -> !torch.int
torch.aten._assert_scalar %3, %str : !torch.int, !torch.str
return %arg0 : !torch.int
}

0 comments on commit f8e7132

Please sign in to comment.