From a4382a04acd84b3fe4df92a56f7657b0d7273773 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Fri, 6 May 2022 17:28:04 -0400 Subject: [PATCH 01/13] Added e2e LTC Torch MLIR tests --- e2e_testing/lazy_tensor_core/bert.mlir | 1324 +++++++++++++++++ e2e_testing/lazy_tensor_core/main.py | 35 + e2e_testing/lazy_tensor_core/mnist.mlir | 55 + .../ltc_backend/csrc/backend/backend_impl.cpp | 25 +- .../ltc_backend/csrc/backend/backend_impl.h | 2 + .../csrc/example_mlir_backend_pybind.cpp | 11 + examples/ltc_backend_bert.py | 72 +- examples/ltc_backend_mnist.py | 50 +- .../mlir_lowering_context.cpp | 27 +- .../base_lazy_backend/mlir_lowering_context.h | 2 + 10 files changed, 1524 insertions(+), 79 deletions(-) create mode 100644 e2e_testing/lazy_tensor_core/bert.mlir create mode 100644 e2e_testing/lazy_tensor_core/main.py create mode 100644 e2e_testing/lazy_tensor_core/mnist.mlir diff --git a/e2e_testing/lazy_tensor_core/bert.mlir b/e2e_testing/lazy_tensor_core/bert.mlir new file mode 100644 index 00000000000..d057c034b6a --- /dev/null +++ b/e2e_testing/lazy_tensor_core/bert.mlir @@ -0,0 +1,1324 @@ +func.func @graph(%arg0: !torch.float, %arg1: !torch.int, %arg2: !torch.vtensor<[],f64>, %arg3: !torch.float, %arg4: !torch.int, %arg5: !torch.vtensor<[2,512],si64>, %arg6: !torch.vtensor<[32],f32>, %arg7: !torch.vtensor<[32],f32>, %arg8: !torch.int, %arg9: !torch.vtensor<[1,512],si64>, %arg10: !torch.vtensor<[512,32],f32>, %arg11: !torch.int, %arg12: !torch.vtensor<[2,512],si64>, %arg13: !torch.vtensor<[2,32],f32>, %arg14: !torch.vtensor<[28996,32],f32>, %arg15: !torch.int, %arg16: !torch.vtensor<[32,32],f32>, %arg17: !torch.int, %arg18: !torch.int, %arg19: !torch.vtensor<[32,32],f32>, %arg20: !torch.vtensor<[32],f32>, %arg21: !torch.vtensor<[],f64>, %arg22: !torch.int, %arg23: !torch.vtensor<[],f64>, %arg24: !torch.int, %arg25: !torch.float, %arg26: !torch.vtensor<[2,512],si64>, %arg27: !torch.int, %arg28: !torch.int, %arg29: !torch.vtensor<[32],f32>, %arg30: !torch.int, %arg31: !torch.int, %arg32: !torch.vtensor<[32,32],f32>, %arg33: !torch.vtensor<[32],f32>, %arg34: !torch.vtensor<[32,32],f32>, %arg35: !torch.vtensor<[32],f32>, %arg36: !torch.vtensor<[32],f32>, %arg37: !torch.int, %arg38: !torch.int, %arg39: !torch.int, %arg40: !torch.vtensor<[32],f32>, %arg41: !torch.int, %arg42: !torch.vtensor<[32,32],f32>, %arg43: !torch.int, %arg44: !torch.int, %arg45: !torch.vtensor<[32],f32>, %arg46: !torch.vtensor<[32,32],f32>, %arg47: !torch.vtensor<[32],f32>, %arg48: !torch.vtensor<[32],f32>, %arg49: !torch.int, %arg50: !torch.int, %arg51: !torch.int, %arg52: !torch.vtensor<[32],f32>, %arg53: !torch.vtensor<[2,512,32],f32>, %arg54: !torch.int, %arg55: !torch.int, %arg56: !torch.vtensor<[28996,32],f32>, %arg57: !torch.vtensor<[],f64>, %arg58: !torch.vtensor<[28996,32],f32>, %arg59: !torch.float, %arg60: !torch.vtensor<[],f64>, %arg61: !torch.vtensor<[28996,32],f32>, %arg62: !torch.float, %arg63: !torch.int, %arg64: !torch.float, %arg65: !torch.int, %arg66: !torch.vtensor<[512,32],f32>, %arg67: !torch.vtensor<[512,32],f32>, %arg68: !torch.float, %arg69: !torch.vtensor<[512,32],f32>, %arg70: !torch.float, %arg71: !torch.int, %arg72: !torch.float, %arg73: !torch.int, %arg74: !torch.vtensor<[2,32],f32>, %arg75: !torch.vtensor<[2,32],f32>, %arg76: !torch.float, %arg77: !torch.vtensor<[2,32],f32>, %arg78: !torch.float, %arg79: !torch.int, %arg80: !torch.float, %arg81: !torch.int, %arg82: !torch.vtensor<[32],f32>, %arg83: !torch.vtensor<[32],f32>, %arg84: !torch.float, %arg85: !torch.vtensor<[32],f32>, %arg86: !torch.float, %arg87: !torch.int, %arg88: !torch.float, %arg89: !torch.int, %arg90: !torch.vtensor<[32],f32>, %arg91: !torch.vtensor<[32],f32>, %arg92: !torch.float, %arg93: !torch.vtensor<[32],f32>, %arg94: !torch.float, %arg95: !torch.int, %arg96: !torch.float, %arg97: !torch.int, %arg98: !torch.vtensor<[32,32],f32>, %arg99: !torch.vtensor<[32,32],f32>, %arg100: !torch.float, %arg101: !torch.vtensor<[32,32],f32>, %arg102: !torch.float, %arg103: !torch.int, %arg104: !torch.float, %arg105: !torch.int, %arg106: !torch.vtensor<[32],f32>, %arg107: !torch.vtensor<[32],f32>, %arg108: !torch.float, %arg109: !torch.vtensor<[32],f32>, %arg110: !torch.float, %arg111: !torch.int, %arg112: !torch.float, %arg113: !torch.int, %arg114: !torch.vtensor<[32,32],f32>, %arg115: !torch.vtensor<[32,32],f32>, %arg116: !torch.float, %arg117: !torch.vtensor<[32,32],f32>, %arg118: !torch.float, %arg119: !torch.int, %arg120: !torch.float, %arg121: !torch.int, %arg122: !torch.vtensor<[32],f32>, %arg123: !torch.vtensor<[32],f32>, %arg124: !torch.float, %arg125: !torch.vtensor<[32],f32>, %arg126: !torch.float, %arg127: !torch.int, %arg128: !torch.float, %arg129: !torch.int, %arg130: !torch.vtensor<[32,32],f32>, %arg131: !torch.vtensor<[32,32],f32>, %arg132: !torch.float, %arg133: !torch.vtensor<[32,32],f32>, %arg134: !torch.float, %arg135: !torch.int, %arg136: !torch.float, %arg137: !torch.int, %arg138: !torch.vtensor<[32],f32>, %arg139: !torch.vtensor<[32],f32>, %arg140: !torch.float, %arg141: !torch.vtensor<[32],f32>, %arg142: !torch.float, %arg143: !torch.int, %arg144: !torch.float, %arg145: !torch.int, %arg146: !torch.vtensor<[32,32],f32>, %arg147: !torch.vtensor<[32,32],f32>, %arg148: !torch.float, %arg149: !torch.vtensor<[32,32],f32>, %arg150: !torch.float, %arg151: !torch.int, %arg152: !torch.float, %arg153: !torch.int, %arg154: !torch.vtensor<[32],f32>, %arg155: !torch.vtensor<[32],f32>, %arg156: !torch.float, %arg157: !torch.vtensor<[32],f32>, %arg158: !torch.float, %arg159: !torch.int, %arg160: !torch.float, %arg161: !torch.int, %arg162: !torch.vtensor<[32],f32>, %arg163: !torch.vtensor<[32],f32>, %arg164: !torch.float, %arg165: !torch.vtensor<[32],f32>, %arg166: !torch.float, %arg167: !torch.int, %arg168: !torch.float, %arg169: !torch.int, %arg170: !torch.vtensor<[32],f32>, %arg171: !torch.vtensor<[32],f32>, %arg172: !torch.float, %arg173: !torch.vtensor<[32],f32>, %arg174: !torch.float, %arg175: !torch.int, %arg176: !torch.float, %arg177: !torch.int, %arg178: !torch.vtensor<[32,32],f32>, %arg179: !torch.vtensor<[32,32],f32>, %arg180: !torch.float, %arg181: !torch.vtensor<[32,32],f32>, %arg182: !torch.float, %arg183: !torch.int, %arg184: !torch.float, %arg185: !torch.int, %arg186: !torch.vtensor<[32],f32>, %arg187: !torch.vtensor<[32],f32>, %arg188: !torch.float, %arg189: !torch.vtensor<[32],f32>, %arg190: !torch.float, %arg191: !torch.int, %arg192: !torch.float, %arg193: !torch.int, %arg194: !torch.vtensor<[32,32],f32>, %arg195: !torch.vtensor<[32,32],f32>, %arg196: !torch.float, %arg197: !torch.vtensor<[32,32],f32>, %arg198: !torch.float, %arg199: !torch.int, %arg200: !torch.float, %arg201: !torch.int, %arg202: !torch.vtensor<[32],f32>, %arg203: !torch.vtensor<[32],f32>, %arg204: !torch.float, %arg205: !torch.vtensor<[32],f32>, %arg206: !torch.float, %arg207: !torch.int, %arg208: !torch.float, %arg209: !torch.int, %arg210: !torch.vtensor<[32],f32>, %arg211: !torch.vtensor<[32],f32>, %arg212: !torch.float, %arg213: !torch.vtensor<[32],f32>, %arg214: !torch.float, %arg215: !torch.int, %arg216: !torch.float, %arg217: !torch.int, %arg218: !torch.vtensor<[32],f32>, %arg219: !torch.vtensor<[32],f32>, %arg220: !torch.float, %arg221: !torch.vtensor<[32],f32>, %arg222: !torch.float, %arg223: !torch.int, %arg224: !torch.float, %arg225: !torch.int, %arg226: !torch.int, %arg227: !torch.int, %arg228: !torch.vtensor<[32,32],f32>, %arg229: !torch.vtensor<[32],f32>, %arg230: !torch.vtensor<[2,32],f32>, %arg231: !torch.int, %arg232: !torch.int, %arg233: !torch.vtensor<[2],f32>, %arg234: !torch.vtensor<[2],si64>, %arg235: !torch.vtensor<[],f32>, %arg236: !torch.vtensor<[32,32],f32>, %arg237: !torch.vtensor<[32,32],f32>, %arg238: !torch.float, %arg239: !torch.vtensor<[32,32],f32>, %arg240: !torch.float, %arg241: !torch.int, %arg242: !torch.float, %arg243: !torch.int, %arg244: !torch.vtensor<[32],f32>, %arg245: !torch.vtensor<[32],f32>, %arg246: !torch.float, %arg247: !torch.vtensor<[32],f32>, %arg248: !torch.float, %arg249: !torch.int, %arg250: !torch.float, %arg251: !torch.int, %arg252: !torch.vtensor<[2,32],f32>, %arg253: !torch.vtensor<[2,32],f32>, %arg254: !torch.float, %arg255: !torch.vtensor<[2,32],f32>, %arg256: !torch.float, %arg257: !torch.int, %arg258: !torch.float, %arg259: !torch.int, %arg260: !torch.vtensor<[2],f32>, %arg261: !torch.vtensor<[2],f32>, %arg262: !torch.float, %arg263: !torch.vtensor<[2],f32>) -> (!torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32>) { + %int0 = torch.constant.int 0 + %int0_0 = torch.constant.int 0 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int1 = torch.constant.int 1 + %0 = torch.aten.slice.Tensor %arg9, %int0, %int0_0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64> + %int-1 = torch.constant.int -1 + %false = torch.constant.bool false + %false_1 = torch.constant.bool false + %1 = torch.aten.embedding %arg10, %0, %int-1, %false, %false_1 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,32],f32> + %int-1_2 = torch.constant.int -1 + %false_3 = torch.constant.bool false + %false_4 = torch.constant.bool false + %2 = torch.aten.embedding %arg13, %arg12, %int-1_2, %false_3, %false_4 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> + %int0_5 = torch.constant.int 0 + %false_6 = torch.constant.bool false + %false_7 = torch.constant.bool false + %3 = torch.aten.embedding %arg14, %arg5, %int0_5, %false_6, %false_7 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> + %4 = torch.aten.add.Tensor %3, %2, %arg11 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %5 = torch.aten.add.Tensor %4, %1, %arg8 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[1,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32 = torch.constant.int 32 + %6 = torch.prim.ListConstruct %int32 : (!torch.int) -> !torch.list + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %result0, %result1, %result2 = torch.aten.native_layer_norm %5, %6, %arg7, %arg6, %float1.000000e-05 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> + %7 = torch.prim.TupleConstruct %result0, %result1, %result2 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> + %int1_8 = torch.constant.int 1 + %int0_9 = torch.constant.int 0 + %8 = torch.prim.ListConstruct %int1_8, %int0_9 : (!torch.int, !torch.int) -> !torch.list + %int1_10 = torch.constant.int 1 + %int0_11 = torch.constant.int 0 + %9 = torch.prim.ListConstruct %int1_10, %int0_11 : (!torch.int, !torch.int) -> !torch.list + %10 = torch.aten.permute %arg16, %9 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_12 = torch.constant.int 1 + %int0_13 = torch.constant.int 0 + %11 = torch.prim.ListConstruct %int1_12, %int0_13 : (!torch.int, !torch.int) -> !torch.list + %int1_14 = torch.constant.int 1 + %int0_15 = torch.constant.int 0 + %12 = torch.prim.ListConstruct %int1_14, %int0_15 : (!torch.int, !torch.int) -> !torch.list + %13 = torch.aten.permute %10, %12 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_16 = torch.constant.int 1 + %int0_17 = torch.constant.int 0 + %14 = torch.prim.ListConstruct %int1_16, %int0_17 : (!torch.int, !torch.int) -> !torch.list + %int1_18 = torch.constant.int 1 + %int0_19 = torch.constant.int 0 + %15 = torch.prim.ListConstruct %int1_18, %int0_19 : (!torch.int, !torch.int) -> !torch.list + %16 = torch.aten.permute %arg19, %15 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1024 = torch.constant.int 1024 + %int32_20 = torch.constant.int 32 + %17 = torch.prim.ListConstruct %int1024, %int32_20 : (!torch.int, !torch.int) -> !torch.list + %int1024_21 = torch.constant.int 1024 + %int32_22 = torch.constant.int 32 + %18 = torch.prim.ListConstruct %int1024_21, %int32_22 : (!torch.int, !torch.int) -> !torch.list + %19 = torch.aten.reshape %result0, %18 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %20 = torch.aten.addmm %arg20, %19, %16, %arg18, %arg17 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2 = torch.constant.int 2 + %int512 = torch.constant.int 512 + %int32_23 = torch.constant.int 32 + %21 = torch.prim.ListConstruct %int2, %int512, %int32_23 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_24 = torch.constant.int 2 + %int512_25 = torch.constant.int 512 + %int32_26 = torch.constant.int 32 + %22 = torch.prim.ListConstruct %int2_24, %int512_25, %int32_26 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %23 = torch.aten.reshape %20, %22 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int2_27 = torch.constant.int 2 + %int512_28 = torch.constant.int 512 + %int16 = torch.constant.int 16 + %24 = torch.prim.ListConstruct %int2_27, %int512_28, %int2_27, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_29 = torch.constant.int 2 + %int512_30 = torch.constant.int 512 + %int16_31 = torch.constant.int 16 + %25 = torch.prim.ListConstruct %int2_29, %int512_30, %int2_29, %int16_31 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %26 = torch.aten.reshape %23, %25 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int0_32 = torch.constant.int 0 + %int2_33 = torch.constant.int 2 + %int1_34 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %27 = torch.prim.ListConstruct %int0_32, %int2_33, %int1_34, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_35 = torch.constant.int 0 + %int2_36 = torch.constant.int 2 + %int1_37 = torch.constant.int 1 + %int3_38 = torch.constant.int 3 + %28 = torch.prim.ListConstruct %int0_35, %int2_36, %int1_37, %int3_38 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %29 = torch.aten.permute %26, %28 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int-1_39 = torch.constant.int -1 + %int-2 = torch.constant.int -2 + %30 = torch.aten.transpose.int %29, %int-1_39, %int-2 : !torch.vtensor<[2,2,512,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,16,512],f32> + %int2_40 = torch.constant.int 2 + %int16_41 = torch.constant.int 16 + %int512_42 = torch.constant.int 512 + %31 = torch.prim.ListConstruct %int2_40, %int2_40, %int16_41, %int512_42 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_43 = torch.constant.bool false + %32 = torch.aten.expand %30, %31, %false_43 : !torch.vtensor<[2,2,16,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,16,512],f32> + %int4 = torch.constant.int 4 + %int16_44 = torch.constant.int 16 + %int512_45 = torch.constant.int 512 + %33 = torch.prim.ListConstruct %int4, %int16_44, %int512_45 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_46 = torch.constant.int 4 + %int16_47 = torch.constant.int 16 + %int512_48 = torch.constant.int 512 + %34 = torch.prim.ListConstruct %int4_46, %int16_47, %int512_48 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %35 = torch.aten.reshape %32, %34 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> + %int1_49 = torch.constant.int 1 + %int2_50 = torch.constant.int 2 + %36 = torch.aten.transpose.int %35, %int1_49, %int2_50 : !torch.vtensor<[4,16,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,512,16],f32> + %int0_51 = torch.constant.int 0 + %int0_52 = torch.constant.int 0 + %int9223372036854775807_53 = torch.constant.int 9223372036854775807 + %int1_54 = torch.constant.int 1 + %37 = torch.aten.slice.Tensor %arg26, %int0_51, %int0_52, %int9223372036854775807_53, %int1_54 : !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512],si64> + %int2_55 = torch.constant.int 2 + %int1_56 = torch.constant.int 1 + %int512_57 = torch.constant.int 512 + %38 = torch.prim.ListConstruct %int2_55, %int1_56, %int512_57 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_58 = torch.constant.int 2 + %int1_59 = torch.constant.int 1 + %int512_60 = torch.constant.int 512 + %39 = torch.prim.ListConstruct %int2_58, %int1_59, %int512_60 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %40 = torch.aten.reshape %37, %39 : !torch.vtensor<[2,512],si64>, !torch.list -> !torch.vtensor<[2,1,512],si64> + %int2_61 = torch.constant.int 2 + %int1_62 = torch.constant.int 1 + %int512_63 = torch.constant.int 512 + %41 = torch.prim.ListConstruct %int2_61, %int1_62, %int1_62, %int512_63 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_64 = torch.constant.int 2 + %int1_65 = torch.constant.int 1 + %int512_66 = torch.constant.int 512 + %42 = torch.prim.ListConstruct %int2_64, %int1_65, %int1_65, %int512_66 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %43 = torch.aten.reshape %40, %42 : !torch.vtensor<[2,1,512],si64>, !torch.list -> !torch.vtensor<[2,1,1,512],si64> + %int3_67 = torch.constant.int 3 + %int0_68 = torch.constant.int 0 + %int9223372036854775807_69 = torch.constant.int 9223372036854775807 + %int1_70 = torch.constant.int 1 + %44 = torch.aten.slice.Tensor %43, %int3_67, %int0_68, %int9223372036854775807_69, %int1_70 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,1,512],si64> + %int6 = torch.constant.int 6 + %none = torch.constant.none + %none_71 = torch.constant.none + %none_72 = torch.constant.none + %false_73 = torch.constant.bool false + %none_74 = torch.constant.none + %45 = torch.aten._to_copy %44, %int6, %none, %none_71, %none_72, %false_73, %none_74 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[2,1,1,512],f32> + %46 = torch.aten.rsub.Scalar %45, %arg25, %arg24 : !torch.vtensor<[2,1,1,512],f32>, !torch.float, !torch.int -> !torch.vtensor<[2,1,1,512],f32> + %47 = torch.aten.mul.Tensor %46, %arg23 : !torch.vtensor<[2,1,1,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,1,1,512],f32> + %int1_75 = torch.constant.int 1 + %int0_76 = torch.constant.int 0 + %48 = torch.prim.ListConstruct %int1_75, %int0_76 : (!torch.int, !torch.int) -> !torch.list + %int1_77 = torch.constant.int 1 + %int0_78 = torch.constant.int 0 + %49 = torch.prim.ListConstruct %int1_77, %int0_78 : (!torch.int, !torch.int) -> !torch.list + %50 = torch.aten.permute %arg16, %49 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1024_79 = torch.constant.int 1024 + %int32_80 = torch.constant.int 32 + %51 = torch.prim.ListConstruct %int1024_79, %int32_80 : (!torch.int, !torch.int) -> !torch.list + %int1024_81 = torch.constant.int 1024 + %int32_82 = torch.constant.int 32 + %52 = torch.prim.ListConstruct %int1024_81, %int32_82 : (!torch.int, !torch.int) -> !torch.list + %53 = torch.aten.reshape %result0, %52 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %54 = torch.aten.addmm %arg29, %53, %50, %arg28, %arg27 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_83 = torch.constant.int 2 + %int512_84 = torch.constant.int 512 + %int32_85 = torch.constant.int 32 + %55 = torch.prim.ListConstruct %int2_83, %int512_84, %int32_85 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_86 = torch.constant.int 2 + %int512_87 = torch.constant.int 512 + %int32_88 = torch.constant.int 32 + %56 = torch.prim.ListConstruct %int2_86, %int512_87, %int32_88 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %57 = torch.aten.reshape %54, %56 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int2_89 = torch.constant.int 2 + %int512_90 = torch.constant.int 512 + %int16_91 = torch.constant.int 16 + %58 = torch.prim.ListConstruct %int2_89, %int512_90, %int2_89, %int16_91 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_92 = torch.constant.int 2 + %int512_93 = torch.constant.int 512 + %int16_94 = torch.constant.int 16 + %59 = torch.prim.ListConstruct %int2_92, %int512_93, %int2_92, %int16_94 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %60 = torch.aten.reshape %57, %59 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int0_95 = torch.constant.int 0 + %int2_96 = torch.constant.int 2 + %int1_97 = torch.constant.int 1 + %int3_98 = torch.constant.int 3 + %61 = torch.prim.ListConstruct %int0_95, %int2_96, %int1_97, %int3_98 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_99 = torch.constant.int 0 + %int2_100 = torch.constant.int 2 + %int1_101 = torch.constant.int 1 + %int3_102 = torch.constant.int 3 + %62 = torch.prim.ListConstruct %int0_99, %int2_100, %int1_101, %int3_102 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %63 = torch.aten.permute %60, %62 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int2_103 = torch.constant.int 2 + %int512_104 = torch.constant.int 512 + %int16_105 = torch.constant.int 16 + %64 = torch.prim.ListConstruct %int2_103, %int2_103, %int512_104, %int16_105 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_106 = torch.constant.bool false + %65 = torch.aten.expand %63, %64, %false_106 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> + %int4_107 = torch.constant.int 4 + %int512_108 = torch.constant.int 512 + %int16_109 = torch.constant.int 16 + %66 = torch.prim.ListConstruct %int4_107, %int512_108, %int16_109 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_110 = torch.constant.int 4 + %int512_111 = torch.constant.int 512 + %int16_112 = torch.constant.int 16 + %67 = torch.prim.ListConstruct %int4_110, %int512_111, %int16_112 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %68 = torch.aten.reshape %65, %67 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %69 = torch.aten.bmm %68, %35 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> + %int2_113 = torch.constant.int 2 + %int512_114 = torch.constant.int 512 + %70 = torch.prim.ListConstruct %int2_113, %int2_113, %int512_114, %int512_114 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_115 = torch.constant.int 2 + %int512_116 = torch.constant.int 512 + %71 = torch.prim.ListConstruct %int2_115, %int2_115, %int512_116, %int512_116 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %72 = torch.aten._unsafe_view %69, %71 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> + %73 = torch.aten.div.Tensor %72, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> + %74 = torch.aten.add.Tensor %73, %47, %arg22 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,1,1,512],f32>, !torch.int -> !torch.vtensor<[2,2,512,512],f32> + %int-1_117 = torch.constant.int -1 + %false_118 = torch.constant.bool false + %75 = torch.aten._softmax %74, %int-1_117, %false_118 : !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> + %int1_119 = torch.constant.int 1 + %int0_120 = torch.constant.int 0 + %76 = torch.prim.ListConstruct %int1_119, %int0_120 : (!torch.int, !torch.int) -> !torch.list + %int1_121 = torch.constant.int 1 + %int0_122 = torch.constant.int 0 + %77 = torch.prim.ListConstruct %int1_121, %int0_122 : (!torch.int, !torch.int) -> !torch.list + %78 = torch.aten.permute %arg32, %77 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1024_123 = torch.constant.int 1024 + %int32_124 = torch.constant.int 32 + %79 = torch.prim.ListConstruct %int1024_123, %int32_124 : (!torch.int, !torch.int) -> !torch.list + %int1024_125 = torch.constant.int 1024 + %int32_126 = torch.constant.int 32 + %80 = torch.prim.ListConstruct %int1024_125, %int32_126 : (!torch.int, !torch.int) -> !torch.list + %81 = torch.aten.reshape %result0, %80 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %82 = torch.aten.addmm %arg33, %81, %78, %arg31, %arg30 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_127 = torch.constant.int 2 + %int512_128 = torch.constant.int 512 + %int32_129 = torch.constant.int 32 + %83 = torch.prim.ListConstruct %int2_127, %int512_128, %int32_129 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_130 = torch.constant.int 2 + %int512_131 = torch.constant.int 512 + %int32_132 = torch.constant.int 32 + %84 = torch.prim.ListConstruct %int2_130, %int512_131, %int32_132 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %85 = torch.aten.reshape %82, %84 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int2_133 = torch.constant.int 2 + %int512_134 = torch.constant.int 512 + %int16_135 = torch.constant.int 16 + %86 = torch.prim.ListConstruct %int2_133, %int512_134, %int2_133, %int16_135 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_136 = torch.constant.int 2 + %int512_137 = torch.constant.int 512 + %int16_138 = torch.constant.int 16 + %87 = torch.prim.ListConstruct %int2_136, %int512_137, %int2_136, %int16_138 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %88 = torch.aten.reshape %85, %87 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int0_139 = torch.constant.int 0 + %int2_140 = torch.constant.int 2 + %int1_141 = torch.constant.int 1 + %int3_142 = torch.constant.int 3 + %89 = torch.prim.ListConstruct %int0_139, %int2_140, %int1_141, %int3_142 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_143 = torch.constant.int 0 + %int2_144 = torch.constant.int 2 + %int1_145 = torch.constant.int 1 + %int3_146 = torch.constant.int 3 + %90 = torch.prim.ListConstruct %int0_143, %int2_144, %int1_145, %int3_146 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %91 = torch.aten.permute %88, %90 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int2_147 = torch.constant.int 2 + %int512_148 = torch.constant.int 512 + %int16_149 = torch.constant.int 16 + %92 = torch.prim.ListConstruct %int2_147, %int2_147, %int512_148, %int16_149 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_150 = torch.constant.bool false + %93 = torch.aten.expand %91, %92, %false_150 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> + %int4_151 = torch.constant.int 4 + %int512_152 = torch.constant.int 512 + %int16_153 = torch.constant.int 16 + %94 = torch.prim.ListConstruct %int4_151, %int512_152, %int16_153 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_154 = torch.constant.int 4 + %int512_155 = torch.constant.int 512 + %int16_156 = torch.constant.int 16 + %95 = torch.prim.ListConstruct %int4_154, %int512_155, %int16_156 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %96 = torch.aten.reshape %93, %95 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %int1_157 = torch.constant.int 1 + %int2_158 = torch.constant.int 2 + %97 = torch.aten.transpose.int %96, %int1_157, %int2_158 : !torch.vtensor<[4,512,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,16,512],f32> + %int1_159 = torch.constant.int 1 + %int0_160 = torch.constant.int 0 + %98 = torch.prim.ListConstruct %int1_159, %int0_160 : (!torch.int, !torch.int) -> !torch.list + %int1_161 = torch.constant.int 1 + %int0_162 = torch.constant.int 0 + %99 = torch.prim.ListConstruct %int1_161, %int0_162 : (!torch.int, !torch.int) -> !torch.list + %100 = torch.aten.permute %arg34, %99 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_163 = torch.constant.int 1 + %int0_164 = torch.constant.int 0 + %101 = torch.prim.ListConstruct %int1_163, %int0_164 : (!torch.int, !torch.int) -> !torch.list + %int1_165 = torch.constant.int 1 + %int0_166 = torch.constant.int 0 + %102 = torch.prim.ListConstruct %int1_165, %int0_166 : (!torch.int, !torch.int) -> !torch.list + %103 = torch.aten.permute %100, %102 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_167 = torch.constant.int 1 + %int0_168 = torch.constant.int 0 + %104 = torch.prim.ListConstruct %int1_167, %int0_168 : (!torch.int, !torch.int) -> !torch.list + %int1_169 = torch.constant.int 1 + %int0_170 = torch.constant.int 0 + %105 = torch.prim.ListConstruct %int1_169, %int0_170 : (!torch.int, !torch.int) -> !torch.list + %106 = torch.aten.permute %arg34, %105 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int2_171 = torch.constant.int 2 + %int512_172 = torch.constant.int 512 + %107 = torch.prim.ListConstruct %int2_171, %int2_171, %int512_172, %int512_172 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_173 = torch.constant.bool false + %108 = torch.aten.expand %75, %107, %false_173 : !torch.vtensor<[2,2,512,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> + %int4_174 = torch.constant.int 4 + %int512_175 = torch.constant.int 512 + %109 = torch.prim.ListConstruct %int4_174, %int512_175, %int512_175 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_176 = torch.constant.int 4 + %int512_177 = torch.constant.int 512 + %110 = torch.prim.ListConstruct %int4_176, %int512_177, %int512_177 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %111 = torch.aten.reshape %108, %110 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> + %112 = torch.aten.bmm %111, %96 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> + %int2_178 = torch.constant.int 2 + %int512_179 = torch.constant.int 512 + %int16_180 = torch.constant.int 16 + %113 = torch.prim.ListConstruct %int2_178, %int2_178, %int512_179, %int16_180 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_181 = torch.constant.int 2 + %int512_182 = torch.constant.int 512 + %int16_183 = torch.constant.int 16 + %114 = torch.prim.ListConstruct %int2_181, %int2_181, %int512_182, %int16_183 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %115 = torch.aten._unsafe_view %112, %114 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int0_184 = torch.constant.int 0 + %int2_185 = torch.constant.int 2 + %int1_186 = torch.constant.int 1 + %int3_187 = torch.constant.int 3 + %116 = torch.prim.ListConstruct %int0_184, %int2_185, %int1_186, %int3_187 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_188 = torch.constant.int 0 + %int2_189 = torch.constant.int 2 + %int1_190 = torch.constant.int 1 + %int3_191 = torch.constant.int 3 + %117 = torch.prim.ListConstruct %int0_188, %int2_189, %int1_190, %int3_191 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %118 = torch.aten.permute %115, %117 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_192 = torch.constant.int 2 + %int512_193 = torch.constant.int 512 + %int32_194 = torch.constant.int 32 + %119 = torch.prim.ListConstruct %int2_192, %int512_193, %int32_194 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_195 = torch.constant.int 2 + %int512_196 = torch.constant.int 512 + %int32_197 = torch.constant.int 32 + %120 = torch.prim.ListConstruct %int2_195, %int512_196, %int32_197 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %121 = torch.aten.reshape %118, %120 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_198 = torch.constant.int 1024 + %int32_199 = torch.constant.int 32 + %122 = torch.prim.ListConstruct %int1024_198, %int32_199 : (!torch.int, !torch.int) -> !torch.list + %int1024_200 = torch.constant.int 1024 + %int32_201 = torch.constant.int 32 + %123 = torch.prim.ListConstruct %int1024_200, %int32_201 : (!torch.int, !torch.int) -> !torch.list + %124 = torch.aten.reshape %121, %123 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %125 = torch.aten.addmm %arg40, %124, %106, %arg39, %arg38 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_202 = torch.constant.int 2 + %int512_203 = torch.constant.int 512 + %int32_204 = torch.constant.int 32 + %126 = torch.prim.ListConstruct %int2_202, %int512_203, %int32_204 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_205 = torch.constant.int 2 + %int512_206 = torch.constant.int 512 + %int32_207 = torch.constant.int 32 + %127 = torch.prim.ListConstruct %int2_205, %int512_206, %int32_207 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %128 = torch.aten.reshape %125, %127 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %129 = torch.aten.add.Tensor %128, %result0, %arg37 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32_208 = torch.constant.int 32 + %130 = torch.prim.ListConstruct %int32_208 : (!torch.int) -> !torch.list + %float1.000000e-05_209 = torch.constant.float 1.000000e-05 + %result0_210, %result1_211, %result2_212 = torch.aten.native_layer_norm %129, %130, %arg36, %arg35, %float1.000000e-05_209 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> + %131 = torch.prim.TupleConstruct %result0_210, %result1_211, %result2_212 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> + %int1_213 = torch.constant.int 1 + %int0_214 = torch.constant.int 0 + %132 = torch.prim.ListConstruct %int1_213, %int0_214 : (!torch.int, !torch.int) -> !torch.list + %int1_215 = torch.constant.int 1 + %int0_216 = torch.constant.int 0 + %133 = torch.prim.ListConstruct %int1_215, %int0_216 : (!torch.int, !torch.int) -> !torch.list + %134 = torch.aten.permute %arg42, %133 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_217 = torch.constant.int 1 + %int0_218 = torch.constant.int 0 + %135 = torch.prim.ListConstruct %int1_217, %int0_218 : (!torch.int, !torch.int) -> !torch.list + %int1_219 = torch.constant.int 1 + %int0_220 = torch.constant.int 0 + %136 = torch.prim.ListConstruct %int1_219, %int0_220 : (!torch.int, !torch.int) -> !torch.list + %137 = torch.aten.permute %134, %136 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_221 = torch.constant.int 1 + %int0_222 = torch.constant.int 0 + %138 = torch.prim.ListConstruct %int1_221, %int0_222 : (!torch.int, !torch.int) -> !torch.list + %int1_223 = torch.constant.int 1 + %int0_224 = torch.constant.int 0 + %139 = torch.prim.ListConstruct %int1_223, %int0_224 : (!torch.int, !torch.int) -> !torch.list + %140 = torch.aten.permute %arg42, %139 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1024_225 = torch.constant.int 1024 + %int32_226 = torch.constant.int 32 + %141 = torch.prim.ListConstruct %int1024_225, %int32_226 : (!torch.int, !torch.int) -> !torch.list + %int1024_227 = torch.constant.int 1024 + %int32_228 = torch.constant.int 32 + %142 = torch.prim.ListConstruct %int1024_227, %int32_228 : (!torch.int, !torch.int) -> !torch.list + %143 = torch.aten.reshape %result0_210, %142 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %144 = torch.aten.addmm %arg45, %143, %140, %arg44, %arg43 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_229 = torch.constant.int 2 + %int512_230 = torch.constant.int 512 + %int32_231 = torch.constant.int 32 + %145 = torch.prim.ListConstruct %int2_229, %int512_230, %int32_231 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_232 = torch.constant.int 2 + %int512_233 = torch.constant.int 512 + %int32_234 = torch.constant.int 32 + %146 = torch.prim.ListConstruct %int2_232, %int512_233, %int32_234 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %147 = torch.aten.reshape %144, %146 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1_235 = torch.constant.int 1 + %int0_236 = torch.constant.int 0 + %148 = torch.prim.ListConstruct %int1_235, %int0_236 : (!torch.int, !torch.int) -> !torch.list + %int1_237 = torch.constant.int 1 + %int0_238 = torch.constant.int 0 + %149 = torch.prim.ListConstruct %int1_237, %int0_238 : (!torch.int, !torch.int) -> !torch.list + %150 = torch.aten.permute %arg46, %149 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_239 = torch.constant.int 1 + %int0_240 = torch.constant.int 0 + %151 = torch.prim.ListConstruct %int1_239, %int0_240 : (!torch.int, !torch.int) -> !torch.list + %int1_241 = torch.constant.int 1 + %int0_242 = torch.constant.int 0 + %152 = torch.prim.ListConstruct %int1_241, %int0_242 : (!torch.int, !torch.int) -> !torch.list + %153 = torch.aten.permute %150, %152 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_243 = torch.constant.int 1 + %int0_244 = torch.constant.int 0 + %154 = torch.prim.ListConstruct %int1_243, %int0_244 : (!torch.int, !torch.int) -> !torch.list + %int1_245 = torch.constant.int 1 + %int0_246 = torch.constant.int 0 + %155 = torch.prim.ListConstruct %int1_245, %int0_246 : (!torch.int, !torch.int) -> !torch.list + %156 = torch.aten.permute %arg46, %155 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %str = torch.constant.str "none" + %157 = torch.aten.gelu %147, %str : !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> + %int1024_247 = torch.constant.int 1024 + %int32_248 = torch.constant.int 32 + %158 = torch.prim.ListConstruct %int1024_247, %int32_248 : (!torch.int, !torch.int) -> !torch.list + %int1024_249 = torch.constant.int 1024 + %int32_250 = torch.constant.int 32 + %159 = torch.prim.ListConstruct %int1024_249, %int32_250 : (!torch.int, !torch.int) -> !torch.list + %160 = torch.aten.reshape %157, %159 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %161 = torch.aten.addmm %arg52, %160, %156, %arg51, %arg50 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_251 = torch.constant.int 2 + %int512_252 = torch.constant.int 512 + %int32_253 = torch.constant.int 32 + %162 = torch.prim.ListConstruct %int2_251, %int512_252, %int32_253 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_254 = torch.constant.int 2 + %int512_255 = torch.constant.int 512 + %int32_256 = torch.constant.int 32 + %163 = torch.prim.ListConstruct %int2_254, %int512_255, %int32_256 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %164 = torch.aten.reshape %161, %163 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %165 = torch.aten.add.Tensor %164, %result0_210, %arg49 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32_257 = torch.constant.int 32 + %166 = torch.prim.ListConstruct %int32_257 : (!torch.int) -> !torch.list + %float1.000000e-05_258 = torch.constant.float 1.000000e-05 + %result0_259, %result1_260, %result2_261 = torch.aten.native_layer_norm %165, %166, %arg48, %arg47, %float1.000000e-05_258 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> + %167 = torch.prim.TupleConstruct %result0_259, %result1_260, %result2_261 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> + %168 = torch.aten.zero.functional %arg53 : !torch.vtensor<[2,512,32],f32> -> !torch.vtensor<[2,512,32],f32> + %int32_262 = torch.constant.int 32 + %169 = torch.prim.ListConstruct %int32_262 : (!torch.int) -> !torch.list + %true = torch.constant.bool true + %170 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0_263, %result1_264, %result2_265 = torch.aten.native_layer_norm_backward %168, %165, %169, %result1_260, %result2_261, %arg48, %arg47, %170 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> + %171 = torch.prim.TupleConstruct %result0_263, %result1_264, %result2_265 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> + %int1024_266 = torch.constant.int 1024 + %int32_267 = torch.constant.int 32 + %172 = torch.prim.ListConstruct %int1024_266, %int32_267 : (!torch.int, !torch.int) -> !torch.list + %int1024_268 = torch.constant.int 1024 + %int32_269 = torch.constant.int 32 + %173 = torch.prim.ListConstruct %int1024_268, %int32_269 : (!torch.int, !torch.int) -> !torch.list + %174 = torch.aten.reshape %result0_263, %173 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %175 = torch.aten.mm %174, %153 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_270 = torch.constant.int 2 + %int512_271 = torch.constant.int 512 + %int32_272 = torch.constant.int 32 + %176 = torch.prim.ListConstruct %int2_270, %int512_271, %int32_272 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_273 = torch.constant.int 2 + %int512_274 = torch.constant.int 512 + %int32_275 = torch.constant.int 32 + %177 = torch.prim.ListConstruct %int2_273, %int512_274, %int32_275 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %178 = torch.aten.reshape %175, %177 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %str_276 = torch.constant.str "none" + %179 = torch.aten.gelu_backward %178, %147, %str_276 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> + %int1024_277 = torch.constant.int 1024 + %int32_278 = torch.constant.int 32 + %180 = torch.prim.ListConstruct %int1024_277, %int32_278 : (!torch.int, !torch.int) -> !torch.list + %int1024_279 = torch.constant.int 1024 + %int32_280 = torch.constant.int 32 + %181 = torch.prim.ListConstruct %int1024_279, %int32_280 : (!torch.int, !torch.int) -> !torch.list + %182 = torch.aten.reshape %179, %181 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %183 = torch.aten.mm %182, %137 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_281 = torch.constant.int 2 + %int512_282 = torch.constant.int 512 + %int32_283 = torch.constant.int 32 + %184 = torch.prim.ListConstruct %int2_281, %int512_282, %int32_283 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_284 = torch.constant.int 2 + %int512_285 = torch.constant.int 512 + %int32_286 = torch.constant.int 32 + %185 = torch.prim.ListConstruct %int2_284, %int512_285, %int32_286 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %186 = torch.aten.reshape %183, %185 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %187 = torch.aten.add.Tensor %result0_263, %186, %arg41 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32_287 = torch.constant.int 32 + %188 = torch.prim.ListConstruct %int32_287 : (!torch.int) -> !torch.list + %true_288 = torch.constant.bool true + %189 = torch.prim.ListConstruct %true_288, %true_288, %true_288 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0_289, %result1_290, %result2_291 = torch.aten.native_layer_norm_backward %187, %129, %188, %result1_211, %result2_212, %arg36, %arg35, %189 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> + %190 = torch.prim.TupleConstruct %result0_289, %result1_290, %result2_291 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> + %int1024_292 = torch.constant.int 1024 + %int32_293 = torch.constant.int 32 + %191 = torch.prim.ListConstruct %int1024_292, %int32_293 : (!torch.int, !torch.int) -> !torch.list + %int1024_294 = torch.constant.int 1024 + %int32_295 = torch.constant.int 32 + %192 = torch.prim.ListConstruct %int1024_294, %int32_295 : (!torch.int, !torch.int) -> !torch.list + %193 = torch.aten.reshape %result0_289, %192 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %194 = torch.aten.mm %193, %103 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_296 = torch.constant.int 2 + %int512_297 = torch.constant.int 512 + %int32_298 = torch.constant.int 32 + %195 = torch.prim.ListConstruct %int2_296, %int512_297, %int32_298 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_299 = torch.constant.int 2 + %int512_300 = torch.constant.int 512 + %int32_301 = torch.constant.int 32 + %196 = torch.prim.ListConstruct %int2_299, %int512_300, %int32_301 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %197 = torch.aten.reshape %194, %196 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int2_302 = torch.constant.int 2 + %int512_303 = torch.constant.int 512 + %int16_304 = torch.constant.int 16 + %198 = torch.prim.ListConstruct %int2_302, %int512_303, %int2_302, %int16_304 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_305 = torch.constant.int 2 + %int512_306 = torch.constant.int 512 + %int16_307 = torch.constant.int 16 + %199 = torch.prim.ListConstruct %int2_305, %int512_306, %int2_305, %int16_307 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %200 = torch.aten.reshape %197, %199 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int0_308 = torch.constant.int 0 + %int2_309 = torch.constant.int 2 + %int1_310 = torch.constant.int 1 + %int3_311 = torch.constant.int 3 + %201 = torch.prim.ListConstruct %int0_308, %int2_309, %int1_310, %int3_311 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_312 = torch.constant.int 0 + %int2_313 = torch.constant.int 2 + %int1_314 = torch.constant.int 1 + %int3_315 = torch.constant.int 3 + %202 = torch.prim.ListConstruct %int0_312, %int2_313, %int1_314, %int3_315 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %203 = torch.aten.permute %200, %202 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int4_316 = torch.constant.int 4 + %int512_317 = torch.constant.int 512 + %int16_318 = torch.constant.int 16 + %204 = torch.prim.ListConstruct %int4_316, %int512_317, %int16_318 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_319 = torch.constant.int 4 + %int512_320 = torch.constant.int 512 + %int16_321 = torch.constant.int 16 + %205 = torch.prim.ListConstruct %int4_319, %int512_320, %int16_321 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %206 = torch.aten.reshape %203, %205 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %207 = torch.aten.bmm %206, %97 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> + %int2_322 = torch.constant.int 2 + %int512_323 = torch.constant.int 512 + %208 = torch.prim.ListConstruct %int2_322, %int2_322, %int512_323, %int512_323 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_324 = torch.constant.int 2 + %int512_325 = torch.constant.int 512 + %209 = torch.prim.ListConstruct %int2_324, %int2_324, %int512_325, %int512_325 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %210 = torch.aten.reshape %207, %209 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> + %int-1_326 = torch.constant.int -1 + %int6_327 = torch.constant.int 6 + %211 = torch.aten._softmax_backward_data %210, %75, %int-1_326, %int6_327 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,512,512],f32> + %212 = torch.aten.div.Tensor %211, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> + %int4_328 = torch.constant.int 4 + %int512_329 = torch.constant.int 512 + %213 = torch.prim.ListConstruct %int4_328, %int512_329, %int512_329 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_330 = torch.constant.int 4 + %int512_331 = torch.constant.int 512 + %214 = torch.prim.ListConstruct %int4_330, %int512_331, %int512_331 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %215 = torch.aten.reshape %212, %214 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> + %216 = torch.aten.bmm %215, %36 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> + %int2_332 = torch.constant.int 2 + %int512_333 = torch.constant.int 512 + %int16_334 = torch.constant.int 16 + %217 = torch.prim.ListConstruct %int2_332, %int2_332, %int512_333, %int16_334 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_335 = torch.constant.int 2 + %int512_336 = torch.constant.int 512 + %int16_337 = torch.constant.int 16 + %218 = torch.prim.ListConstruct %int2_335, %int2_335, %int512_336, %int16_337 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %219 = torch.aten.reshape %216, %218 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int0_338 = torch.constant.int 0 + %int2_339 = torch.constant.int 2 + %int1_340 = torch.constant.int 1 + %int3_341 = torch.constant.int 3 + %220 = torch.prim.ListConstruct %int0_338, %int2_339, %int1_340, %int3_341 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_342 = torch.constant.int 0 + %int2_343 = torch.constant.int 2 + %int1_344 = torch.constant.int 1 + %int3_345 = torch.constant.int 3 + %221 = torch.prim.ListConstruct %int0_342, %int2_343, %int1_344, %int3_345 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %222 = torch.aten.permute %219, %221 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_346 = torch.constant.int 2 + %int512_347 = torch.constant.int 512 + %int32_348 = torch.constant.int 32 + %223 = torch.prim.ListConstruct %int2_346, %int512_347, %int32_348 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_349 = torch.constant.int 2 + %int512_350 = torch.constant.int 512 + %int32_351 = torch.constant.int 32 + %224 = torch.prim.ListConstruct %int2_349, %int512_350, %int32_351 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %225 = torch.aten.reshape %222, %224 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_352 = torch.constant.int 1024 + %int32_353 = torch.constant.int 32 + %226 = torch.prim.ListConstruct %int1024_352, %int32_353 : (!torch.int, !torch.int) -> !torch.list + %int1024_354 = torch.constant.int 1024 + %int32_355 = torch.constant.int 32 + %227 = torch.prim.ListConstruct %int1024_354, %int32_355 : (!torch.int, !torch.int) -> !torch.list + %228 = torch.aten.reshape %225, %227 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %229 = torch.aten.mm %228, %13 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_356 = torch.constant.int 2 + %int512_357 = torch.constant.int 512 + %int32_358 = torch.constant.int 32 + %230 = torch.prim.ListConstruct %int2_356, %int512_357, %int32_358 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_359 = torch.constant.int 2 + %int512_360 = torch.constant.int 512 + %int32_361 = torch.constant.int 32 + %231 = torch.prim.ListConstruct %int2_359, %int512_360, %int32_361 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %232 = torch.aten.reshape %229, %231 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1_362 = torch.constant.int 1 + %int0_363 = torch.constant.int 0 + %233 = torch.prim.ListConstruct %int1_362, %int0_363 : (!torch.int, !torch.int) -> !torch.list + %int1_364 = torch.constant.int 1 + %int0_365 = torch.constant.int 0 + %234 = torch.prim.ListConstruct %int1_364, %int0_365 : (!torch.int, !torch.int) -> !torch.list + %235 = torch.aten.permute %arg19, %234 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_366 = torch.constant.int 1 + %int0_367 = torch.constant.int 0 + %236 = torch.prim.ListConstruct %int1_366, %int0_367 : (!torch.int, !torch.int) -> !torch.list + %int1_368 = torch.constant.int 1 + %int0_369 = torch.constant.int 0 + %237 = torch.prim.ListConstruct %int1_368, %int0_369 : (!torch.int, !torch.int) -> !torch.list + %238 = torch.aten.permute %235, %237 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_370 = torch.constant.int 1 + %int2_371 = torch.constant.int 2 + %239 = torch.aten.transpose.int %68, %int1_370, %int2_371 : !torch.vtensor<[4,512,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,16,512],f32> + %240 = torch.aten.bmm %239, %215 : !torch.vtensor<[4,16,512],f32>, !torch.vtensor<[4,512,512],f32> -> !torch.vtensor<[4,16,512],f32> + %int2_372 = torch.constant.int 2 + %int16_373 = torch.constant.int 16 + %int512_374 = torch.constant.int 512 + %241 = torch.prim.ListConstruct %int2_372, %int2_372, %int16_373, %int512_374 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_375 = torch.constant.int 2 + %int16_376 = torch.constant.int 16 + %int512_377 = torch.constant.int 512 + %242 = torch.prim.ListConstruct %int2_375, %int2_375, %int16_376, %int512_377 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %243 = torch.aten.reshape %240, %242 : !torch.vtensor<[4,16,512],f32>, !torch.list -> !torch.vtensor<[2,2,16,512],f32> + %int-1_378 = torch.constant.int -1 + %int-2_379 = torch.constant.int -2 + %244 = torch.aten.transpose.int %243, %int-1_378, %int-2_379 : !torch.vtensor<[2,2,16,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,512,16],f32> + %int0_380 = torch.constant.int 0 + %int2_381 = torch.constant.int 2 + %int1_382 = torch.constant.int 1 + %int3_383 = torch.constant.int 3 + %245 = torch.prim.ListConstruct %int0_380, %int2_381, %int1_382, %int3_383 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_384 = torch.constant.int 0 + %int2_385 = torch.constant.int 2 + %int1_386 = torch.constant.int 1 + %int3_387 = torch.constant.int 3 + %246 = torch.prim.ListConstruct %int0_384, %int2_385, %int1_386, %int3_387 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %247 = torch.aten.permute %244, %246 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_388 = torch.constant.int 2 + %int512_389 = torch.constant.int 512 + %int32_390 = torch.constant.int 32 + %248 = torch.prim.ListConstruct %int2_388, %int512_389, %int32_390 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_391 = torch.constant.int 2 + %int512_392 = torch.constant.int 512 + %int32_393 = torch.constant.int 32 + %249 = torch.prim.ListConstruct %int2_391, %int512_392, %int32_393 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %250 = torch.aten.reshape %247, %249 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_394 = torch.constant.int 1024 + %int32_395 = torch.constant.int 32 + %251 = torch.prim.ListConstruct %int1024_394, %int32_395 : (!torch.int, !torch.int) -> !torch.list + %int1024_396 = torch.constant.int 1024 + %int32_397 = torch.constant.int 32 + %252 = torch.prim.ListConstruct %int1024_396, %int32_397 : (!torch.int, !torch.int) -> !torch.list + %253 = torch.aten.reshape %250, %252 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %254 = torch.aten.mm %253, %238 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_398 = torch.constant.int 2 + %int512_399 = torch.constant.int 512 + %int32_400 = torch.constant.int 32 + %255 = torch.prim.ListConstruct %int2_398, %int512_399, %int32_400 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_401 = torch.constant.int 2 + %int512_402 = torch.constant.int 512 + %int32_403 = torch.constant.int 32 + %256 = torch.prim.ListConstruct %int2_401, %int512_402, %int32_403 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %257 = torch.aten.reshape %254, %256 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1_404 = torch.constant.int 1 + %int0_405 = torch.constant.int 0 + %258 = torch.prim.ListConstruct %int1_404, %int0_405 : (!torch.int, !torch.int) -> !torch.list + %int1_406 = torch.constant.int 1 + %int0_407 = torch.constant.int 0 + %259 = torch.prim.ListConstruct %int1_406, %int0_407 : (!torch.int, !torch.int) -> !torch.list + %260 = torch.aten.permute %arg32, %259 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_408 = torch.constant.int 1 + %int0_409 = torch.constant.int 0 + %261 = torch.prim.ListConstruct %int1_408, %int0_409 : (!torch.int, !torch.int) -> !torch.list + %int1_410 = torch.constant.int 1 + %int0_411 = torch.constant.int 0 + %262 = torch.prim.ListConstruct %int1_410, %int0_411 : (!torch.int, !torch.int) -> !torch.list + %263 = torch.aten.permute %260, %262 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_412 = torch.constant.int 1 + %int2_413 = torch.constant.int 2 + %264 = torch.aten.transpose.int %111, %int1_412, %int2_413 : !torch.vtensor<[4,512,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,512,512],f32> + %265 = torch.aten.bmm %264, %206 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> + %int2_414 = torch.constant.int 2 + %int512_415 = torch.constant.int 512 + %int16_416 = torch.constant.int 16 + %266 = torch.prim.ListConstruct %int2_414, %int2_414, %int512_415, %int16_416 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_417 = torch.constant.int 2 + %int512_418 = torch.constant.int 512 + %int16_419 = torch.constant.int 16 + %267 = torch.prim.ListConstruct %int2_417, %int2_417, %int512_418, %int16_419 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %268 = torch.aten.reshape %265, %267 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int0_420 = torch.constant.int 0 + %int2_421 = torch.constant.int 2 + %int1_422 = torch.constant.int 1 + %int3_423 = torch.constant.int 3 + %269 = torch.prim.ListConstruct %int0_420, %int2_421, %int1_422, %int3_423 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_424 = torch.constant.int 0 + %int2_425 = torch.constant.int 2 + %int1_426 = torch.constant.int 1 + %int3_427 = torch.constant.int 3 + %270 = torch.prim.ListConstruct %int0_424, %int2_425, %int1_426, %int3_427 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %271 = torch.aten.permute %268, %270 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_428 = torch.constant.int 2 + %int512_429 = torch.constant.int 512 + %int32_430 = torch.constant.int 32 + %272 = torch.prim.ListConstruct %int2_428, %int512_429, %int32_430 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_431 = torch.constant.int 2 + %int512_432 = torch.constant.int 512 + %int32_433 = torch.constant.int 32 + %273 = torch.prim.ListConstruct %int2_431, %int512_432, %int32_433 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %274 = torch.aten.reshape %271, %273 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_434 = torch.constant.int 1024 + %int32_435 = torch.constant.int 32 + %275 = torch.prim.ListConstruct %int1024_434, %int32_435 : (!torch.int, !torch.int) -> !torch.list + %int1024_436 = torch.constant.int 1024 + %int32_437 = torch.constant.int 32 + %276 = torch.prim.ListConstruct %int1024_436, %int32_437 : (!torch.int, !torch.int) -> !torch.list + %277 = torch.aten.reshape %274, %276 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %278 = torch.aten.mm %277, %263 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_438 = torch.constant.int 2 + %int512_439 = torch.constant.int 512 + %int32_440 = torch.constant.int 32 + %279 = torch.prim.ListConstruct %int2_438, %int512_439, %int32_440 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_441 = torch.constant.int 2 + %int512_442 = torch.constant.int 512 + %int32_443 = torch.constant.int 32 + %280 = torch.prim.ListConstruct %int2_441, %int512_442, %int32_443 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %281 = torch.aten.reshape %278, %280 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %282 = torch.aten.add.Tensor %result0_289, %281, %arg55 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %283 = torch.aten.add.Tensor %282, %257, %arg54 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %284 = torch.aten.add.Tensor %283, %232, %arg15 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32_444 = torch.constant.int 32 + %285 = torch.prim.ListConstruct %int32_444 : (!torch.int) -> !torch.list + %true_445 = torch.constant.bool true + %286 = torch.prim.ListConstruct %true_445, %true_445, %true_445 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0_446, %result1_447, %result2_448 = torch.aten.native_layer_norm_backward %284, %5, %285, %result1, %result2, %arg7, %arg6, %286 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> + %287 = torch.prim.TupleConstruct %result0_446, %result1_447, %result2_448 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> + %int28996 = torch.constant.int 28996 + %int0_449 = torch.constant.int 0 + %false_450 = torch.constant.bool false + %288 = torch.aten.embedding_dense_backward %result0_446, %arg5, %int28996, %int0_449, %false_450 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[28996,32],f32> + %289 = torch.aten.add.Tensor %arg56, %288, %arg4 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.int -> !torch.vtensor<[28996,32],f32> + %290 = torch.aten.mul.Tensor %arg58, %arg57 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> + %291 = torch.aten.addcmul %290, %289, %289, %arg3 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> + %292 = torch.aten.sqrt %291 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> + %293 = torch.aten.add.Tensor %292, %arg2, %arg1 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[28996,32],f32> + %294 = torch.aten.mul.Tensor %arg61, %arg60 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> + %295 = torch.aten.add.Tensor %294, %289, %arg59 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> + %296 = torch.aten.addcdiv %arg14, %295, %293, %arg0 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> + %int0_451 = torch.constant.int 0 + %297 = torch.prim.ListConstruct %int0_451 : (!torch.int) -> !torch.list + %true_452 = torch.constant.bool true + %none_453 = torch.constant.none + %298 = torch.aten.sum.dim_IntList %result0_446, %297, %true_452, %none_453 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,512,32],f32> + %int512_454 = torch.constant.int 512 + %int-1_455 = torch.constant.int -1 + %false_456 = torch.constant.bool false + %299 = torch.aten.embedding_dense_backward %298, %0, %int512_454, %int-1_455, %false_456 : !torch.vtensor<[1,512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[512,32],f32> + %300 = torch.aten.add.Tensor %arg66, %299, %arg65 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.int -> !torch.vtensor<[512,32],f32> + %301 = torch.aten.mul.Tensor %arg67, %arg57 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> + %302 = torch.aten.addcmul %301, %300, %300, %arg64 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> + %303 = torch.aten.sqrt %302 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> + %304 = torch.aten.add.Tensor %303, %arg2, %arg63 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[512,32],f32> + %305 = torch.aten.mul.Tensor %arg69, %arg60 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> + %306 = torch.aten.add.Tensor %305, %300, %arg68 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> + %307 = torch.aten.addcdiv %arg10, %306, %304, %arg62 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> + %int2_457 = torch.constant.int 2 + %int-1_458 = torch.constant.int -1 + %false_459 = torch.constant.bool false + %308 = torch.aten.embedding_dense_backward %result0_446, %arg12, %int2_457, %int-1_458, %false_459 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[2,32],f32> + %309 = torch.aten.add.Tensor %arg74, %308, %arg73 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> + %310 = torch.aten.mul.Tensor %arg75, %arg57 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %311 = torch.aten.addcmul %310, %309, %309, %arg72 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %312 = torch.aten.sqrt %311 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %313 = torch.aten.add.Tensor %312, %arg2, %arg71 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> + %314 = torch.aten.mul.Tensor %arg77, %arg60 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %315 = torch.aten.add.Tensor %314, %309, %arg76 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %316 = torch.aten.addcdiv %arg13, %315, %313, %arg70 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %317 = torch.aten.add.Tensor %arg82, %result1_447, %arg81 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %318 = torch.aten.mul.Tensor %arg83, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %319 = torch.aten.addcmul %318, %317, %317, %arg80 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %320 = torch.aten.sqrt %319 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %321 = torch.aten.add.Tensor %320, %arg2, %arg79 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %322 = torch.aten.mul.Tensor %arg85, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %323 = torch.aten.add.Tensor %322, %317, %arg84 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %324 = torch.aten.addcdiv %arg7, %323, %321, %arg78 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %325 = torch.aten.add.Tensor %arg90, %result2_448, %arg89 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %326 = torch.aten.mul.Tensor %arg91, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %327 = torch.aten.addcmul %326, %325, %325, %arg88 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %328 = torch.aten.sqrt %327 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %329 = torch.aten.add.Tensor %328, %arg2, %arg87 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %330 = torch.aten.mul.Tensor %arg93, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %331 = torch.aten.add.Tensor %330, %325, %arg92 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %332 = torch.aten.addcdiv %arg6, %331, %329, %arg86 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_460 = torch.constant.int 1024 + %int32_461 = torch.constant.int 32 + %333 = torch.prim.ListConstruct %int1024_460, %int32_461 : (!torch.int, !torch.int) -> !torch.list + %int1024_462 = torch.constant.int 1024 + %int32_463 = torch.constant.int 32 + %334 = torch.prim.ListConstruct %int1024_462, %int32_463 : (!torch.int, !torch.int) -> !torch.list + %335 = torch.aten.reshape %result0, %334 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_464 = torch.constant.int 1 + %int0_465 = torch.constant.int 0 + %336 = torch.prim.ListConstruct %int1_464, %int0_465 : (!torch.int, !torch.int) -> !torch.list + %int1_466 = torch.constant.int 1 + %int0_467 = torch.constant.int 0 + %337 = torch.prim.ListConstruct %int1_466, %int0_467 : (!torch.int, !torch.int) -> !torch.list + %338 = torch.aten.permute %335, %337 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %339 = torch.aten.mm %338, %228 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_468 = torch.constant.int 1 + %int0_469 = torch.constant.int 0 + %340 = torch.prim.ListConstruct %int1_468, %int0_469 : (!torch.int, !torch.int) -> !torch.list + %int1_470 = torch.constant.int 1 + %int0_471 = torch.constant.int 0 + %341 = torch.prim.ListConstruct %int1_470, %int0_471 : (!torch.int, !torch.int) -> !torch.list + %342 = torch.aten.permute %339, %341 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %343 = torch.aten.add.Tensor %arg98, %342, %arg97 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %344 = torch.aten.mul.Tensor %arg99, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %345 = torch.aten.addcmul %344, %343, %343, %arg96 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %346 = torch.aten.sqrt %345 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %347 = torch.aten.add.Tensor %346, %arg2, %arg95 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %348 = torch.aten.mul.Tensor %arg101, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %349 = torch.aten.add.Tensor %348, %343, %arg100 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %350 = torch.aten.addcdiv %arg16, %349, %347, %arg94 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_472 = torch.constant.int 0 + %351 = torch.prim.ListConstruct %int0_472 : (!torch.int) -> !torch.list + %true_473 = torch.constant.bool true + %none_474 = torch.constant.none + %352 = torch.aten.sum.dim_IntList %228, %351, %true_473, %none_474 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_475 = torch.constant.int 32 + %353 = torch.prim.ListConstruct %int32_475 : (!torch.int) -> !torch.list + %int32_476 = torch.constant.int 32 + %354 = torch.prim.ListConstruct %int32_476 : (!torch.int) -> !torch.list + %355 = torch.aten.reshape %352, %354 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %356 = torch.aten.add.Tensor %arg106, %355, %arg105 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %357 = torch.aten.mul.Tensor %arg107, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %358 = torch.aten.addcmul %357, %356, %356, %arg104 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %359 = torch.aten.sqrt %358 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %360 = torch.aten.add.Tensor %359, %arg2, %arg103 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %361 = torch.aten.mul.Tensor %arg109, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %362 = torch.aten.add.Tensor %361, %356, %arg108 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %363 = torch.aten.addcdiv %arg29, %362, %360, %arg102 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_477 = torch.constant.int 1024 + %int32_478 = torch.constant.int 32 + %364 = torch.prim.ListConstruct %int1024_477, %int32_478 : (!torch.int, !torch.int) -> !torch.list + %int1024_479 = torch.constant.int 1024 + %int32_480 = torch.constant.int 32 + %365 = torch.prim.ListConstruct %int1024_479, %int32_480 : (!torch.int, !torch.int) -> !torch.list + %366 = torch.aten.reshape %result0, %365 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_481 = torch.constant.int 1 + %int0_482 = torch.constant.int 0 + %367 = torch.prim.ListConstruct %int1_481, %int0_482 : (!torch.int, !torch.int) -> !torch.list + %int1_483 = torch.constant.int 1 + %int0_484 = torch.constant.int 0 + %368 = torch.prim.ListConstruct %int1_483, %int0_484 : (!torch.int, !torch.int) -> !torch.list + %369 = torch.aten.permute %366, %368 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %370 = torch.aten.mm %369, %253 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_485 = torch.constant.int 1 + %int0_486 = torch.constant.int 0 + %371 = torch.prim.ListConstruct %int1_485, %int0_486 : (!torch.int, !torch.int) -> !torch.list + %int1_487 = torch.constant.int 1 + %int0_488 = torch.constant.int 0 + %372 = torch.prim.ListConstruct %int1_487, %int0_488 : (!torch.int, !torch.int) -> !torch.list + %373 = torch.aten.permute %370, %372 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %374 = torch.aten.add.Tensor %arg114, %373, %arg113 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %375 = torch.aten.mul.Tensor %arg115, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %376 = torch.aten.addcmul %375, %374, %374, %arg112 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %377 = torch.aten.sqrt %376 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %378 = torch.aten.add.Tensor %377, %arg2, %arg111 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %379 = torch.aten.mul.Tensor %arg117, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %380 = torch.aten.add.Tensor %379, %374, %arg116 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %381 = torch.aten.addcdiv %arg19, %380, %378, %arg110 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_489 = torch.constant.int 0 + %382 = torch.prim.ListConstruct %int0_489 : (!torch.int) -> !torch.list + %true_490 = torch.constant.bool true + %none_491 = torch.constant.none + %383 = torch.aten.sum.dim_IntList %253, %382, %true_490, %none_491 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_492 = torch.constant.int 32 + %384 = torch.prim.ListConstruct %int32_492 : (!torch.int) -> !torch.list + %int32_493 = torch.constant.int 32 + %385 = torch.prim.ListConstruct %int32_493 : (!torch.int) -> !torch.list + %386 = torch.aten.reshape %383, %385 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %387 = torch.aten.add.Tensor %arg122, %386, %arg121 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %388 = torch.aten.mul.Tensor %arg123, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %389 = torch.aten.addcmul %388, %387, %387, %arg120 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %390 = torch.aten.sqrt %389 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %391 = torch.aten.add.Tensor %390, %arg2, %arg119 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %392 = torch.aten.mul.Tensor %arg125, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %393 = torch.aten.add.Tensor %392, %387, %arg124 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %394 = torch.aten.addcdiv %arg20, %393, %391, %arg118 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_494 = torch.constant.int 1024 + %int32_495 = torch.constant.int 32 + %395 = torch.prim.ListConstruct %int1024_494, %int32_495 : (!torch.int, !torch.int) -> !torch.list + %int1024_496 = torch.constant.int 1024 + %int32_497 = torch.constant.int 32 + %396 = torch.prim.ListConstruct %int1024_496, %int32_497 : (!torch.int, !torch.int) -> !torch.list + %397 = torch.aten.reshape %result0, %396 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_498 = torch.constant.int 1 + %int0_499 = torch.constant.int 0 + %398 = torch.prim.ListConstruct %int1_498, %int0_499 : (!torch.int, !torch.int) -> !torch.list + %int1_500 = torch.constant.int 1 + %int0_501 = torch.constant.int 0 + %399 = torch.prim.ListConstruct %int1_500, %int0_501 : (!torch.int, !torch.int) -> !torch.list + %400 = torch.aten.permute %397, %399 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %401 = torch.aten.mm %400, %277 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_502 = torch.constant.int 1 + %int0_503 = torch.constant.int 0 + %402 = torch.prim.ListConstruct %int1_502, %int0_503 : (!torch.int, !torch.int) -> !torch.list + %int1_504 = torch.constant.int 1 + %int0_505 = torch.constant.int 0 + %403 = torch.prim.ListConstruct %int1_504, %int0_505 : (!torch.int, !torch.int) -> !torch.list + %404 = torch.aten.permute %401, %403 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %405 = torch.aten.add.Tensor %arg130, %404, %arg129 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %406 = torch.aten.mul.Tensor %arg131, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %407 = torch.aten.addcmul %406, %405, %405, %arg128 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %408 = torch.aten.sqrt %407 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %409 = torch.aten.add.Tensor %408, %arg2, %arg127 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %410 = torch.aten.mul.Tensor %arg133, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %411 = torch.aten.add.Tensor %410, %405, %arg132 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %412 = torch.aten.addcdiv %arg32, %411, %409, %arg126 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_506 = torch.constant.int 0 + %413 = torch.prim.ListConstruct %int0_506 : (!torch.int) -> !torch.list + %true_507 = torch.constant.bool true + %none_508 = torch.constant.none + %414 = torch.aten.sum.dim_IntList %277, %413, %true_507, %none_508 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_509 = torch.constant.int 32 + %415 = torch.prim.ListConstruct %int32_509 : (!torch.int) -> !torch.list + %int32_510 = torch.constant.int 32 + %416 = torch.prim.ListConstruct %int32_510 : (!torch.int) -> !torch.list + %417 = torch.aten.reshape %414, %416 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %418 = torch.aten.add.Tensor %arg138, %417, %arg137 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %419 = torch.aten.mul.Tensor %arg139, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %420 = torch.aten.addcmul %419, %418, %418, %arg136 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %421 = torch.aten.sqrt %420 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %422 = torch.aten.add.Tensor %421, %arg2, %arg135 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %423 = torch.aten.mul.Tensor %arg141, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %424 = torch.aten.add.Tensor %423, %418, %arg140 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %425 = torch.aten.addcdiv %arg33, %424, %422, %arg134 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int0_511 = torch.constant.int 0 + %int2_512 = torch.constant.int 2 + %int1_513 = torch.constant.int 1 + %int3_514 = torch.constant.int 3 + %426 = torch.prim.ListConstruct %int0_511, %int2_512, %int1_513, %int3_514 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_515 = torch.constant.int 0 + %int2_516 = torch.constant.int 2 + %int1_517 = torch.constant.int 1 + %int3_518 = torch.constant.int 3 + %427 = torch.prim.ListConstruct %int0_515, %int2_516, %int1_517, %int3_518 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %428 = torch.aten.permute %115, %427 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_519 = torch.constant.int 2 + %int512_520 = torch.constant.int 512 + %int32_521 = torch.constant.int 32 + %429 = torch.prim.ListConstruct %int2_519, %int512_520, %int32_521 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_522 = torch.constant.int 2 + %int512_523 = torch.constant.int 512 + %int32_524 = torch.constant.int 32 + %430 = torch.prim.ListConstruct %int2_522, %int512_523, %int32_524 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %431 = torch.aten.reshape %428, %430 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_525 = torch.constant.int 1024 + %int32_526 = torch.constant.int 32 + %432 = torch.prim.ListConstruct %int1024_525, %int32_526 : (!torch.int, !torch.int) -> !torch.list + %int1024_527 = torch.constant.int 1024 + %int32_528 = torch.constant.int 32 + %433 = torch.prim.ListConstruct %int1024_527, %int32_528 : (!torch.int, !torch.int) -> !torch.list + %434 = torch.aten.reshape %431, %433 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_529 = torch.constant.int 1 + %int0_530 = torch.constant.int 0 + %435 = torch.prim.ListConstruct %int1_529, %int0_530 : (!torch.int, !torch.int) -> !torch.list + %int1_531 = torch.constant.int 1 + %int0_532 = torch.constant.int 0 + %436 = torch.prim.ListConstruct %int1_531, %int0_532 : (!torch.int, !torch.int) -> !torch.list + %437 = torch.aten.permute %434, %436 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %438 = torch.aten.mm %437, %193 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_533 = torch.constant.int 1 + %int0_534 = torch.constant.int 0 + %439 = torch.prim.ListConstruct %int1_533, %int0_534 : (!torch.int, !torch.int) -> !torch.list + %int1_535 = torch.constant.int 1 + %int0_536 = torch.constant.int 0 + %440 = torch.prim.ListConstruct %int1_535, %int0_536 : (!torch.int, !torch.int) -> !torch.list + %441 = torch.aten.permute %438, %440 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %442 = torch.aten.add.Tensor %arg146, %441, %arg145 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %443 = torch.aten.mul.Tensor %arg147, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %444 = torch.aten.addcmul %443, %442, %442, %arg144 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %445 = torch.aten.sqrt %444 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %446 = torch.aten.add.Tensor %445, %arg2, %arg143 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %447 = torch.aten.mul.Tensor %arg149, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %448 = torch.aten.add.Tensor %447, %442, %arg148 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %449 = torch.aten.addcdiv %arg34, %448, %446, %arg142 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_537 = torch.constant.int 0 + %450 = torch.prim.ListConstruct %int0_537 : (!torch.int) -> !torch.list + %true_538 = torch.constant.bool true + %none_539 = torch.constant.none + %451 = torch.aten.sum.dim_IntList %193, %450, %true_538, %none_539 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_540 = torch.constant.int 32 + %452 = torch.prim.ListConstruct %int32_540 : (!torch.int) -> !torch.list + %int32_541 = torch.constant.int 32 + %453 = torch.prim.ListConstruct %int32_541 : (!torch.int) -> !torch.list + %454 = torch.aten.reshape %451, %453 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %455 = torch.aten.add.Tensor %arg154, %454, %arg153 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %456 = torch.aten.mul.Tensor %arg155, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %457 = torch.aten.addcmul %456, %455, %455, %arg152 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %458 = torch.aten.sqrt %457 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %459 = torch.aten.add.Tensor %458, %arg2, %arg151 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %460 = torch.aten.mul.Tensor %arg157, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %461 = torch.aten.add.Tensor %460, %455, %arg156 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %462 = torch.aten.addcdiv %arg40, %461, %459, %arg150 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %463 = torch.aten.add.Tensor %arg162, %result1_290, %arg161 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %464 = torch.aten.mul.Tensor %arg163, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %465 = torch.aten.addcmul %464, %463, %463, %arg160 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %466 = torch.aten.sqrt %465 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %467 = torch.aten.add.Tensor %466, %arg2, %arg159 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %468 = torch.aten.mul.Tensor %arg165, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %469 = torch.aten.add.Tensor %468, %463, %arg164 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %470 = torch.aten.addcdiv %arg36, %469, %467, %arg158 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %471 = torch.aten.add.Tensor %arg170, %result2_291, %arg169 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %472 = torch.aten.mul.Tensor %arg171, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %473 = torch.aten.addcmul %472, %471, %471, %arg168 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %474 = torch.aten.sqrt %473 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %475 = torch.aten.add.Tensor %474, %arg2, %arg167 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %476 = torch.aten.mul.Tensor %arg173, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %477 = torch.aten.add.Tensor %476, %471, %arg172 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %478 = torch.aten.addcdiv %arg35, %477, %475, %arg166 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_542 = torch.constant.int 1024 + %int32_543 = torch.constant.int 32 + %479 = torch.prim.ListConstruct %int1024_542, %int32_543 : (!torch.int, !torch.int) -> !torch.list + %int1024_544 = torch.constant.int 1024 + %int32_545 = torch.constant.int 32 + %480 = torch.prim.ListConstruct %int1024_544, %int32_545 : (!torch.int, !torch.int) -> !torch.list + %481 = torch.aten.reshape %result0_210, %480 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_546 = torch.constant.int 1 + %int0_547 = torch.constant.int 0 + %482 = torch.prim.ListConstruct %int1_546, %int0_547 : (!torch.int, !torch.int) -> !torch.list + %int1_548 = torch.constant.int 1 + %int0_549 = torch.constant.int 0 + %483 = torch.prim.ListConstruct %int1_548, %int0_549 : (!torch.int, !torch.int) -> !torch.list + %484 = torch.aten.permute %481, %483 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %485 = torch.aten.mm %484, %182 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_550 = torch.constant.int 1 + %int0_551 = torch.constant.int 0 + %486 = torch.prim.ListConstruct %int1_550, %int0_551 : (!torch.int, !torch.int) -> !torch.list + %int1_552 = torch.constant.int 1 + %int0_553 = torch.constant.int 0 + %487 = torch.prim.ListConstruct %int1_552, %int0_553 : (!torch.int, !torch.int) -> !torch.list + %488 = torch.aten.permute %485, %487 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %489 = torch.aten.add.Tensor %arg178, %488, %arg177 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %490 = torch.aten.mul.Tensor %arg179, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %491 = torch.aten.addcmul %490, %489, %489, %arg176 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %492 = torch.aten.sqrt %491 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %493 = torch.aten.add.Tensor %492, %arg2, %arg175 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %494 = torch.aten.mul.Tensor %arg181, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %495 = torch.aten.add.Tensor %494, %489, %arg180 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %496 = torch.aten.addcdiv %arg42, %495, %493, %arg174 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_554 = torch.constant.int 0 + %497 = torch.prim.ListConstruct %int0_554 : (!torch.int) -> !torch.list + %true_555 = torch.constant.bool true + %none_556 = torch.constant.none + %498 = torch.aten.sum.dim_IntList %182, %497, %true_555, %none_556 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_557 = torch.constant.int 32 + %499 = torch.prim.ListConstruct %int32_557 : (!torch.int) -> !torch.list + %int32_558 = torch.constant.int 32 + %500 = torch.prim.ListConstruct %int32_558 : (!torch.int) -> !torch.list + %501 = torch.aten.reshape %498, %500 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %502 = torch.aten.add.Tensor %arg186, %501, %arg185 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %503 = torch.aten.mul.Tensor %arg187, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %504 = torch.aten.addcmul %503, %502, %502, %arg184 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %505 = torch.aten.sqrt %504 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %506 = torch.aten.add.Tensor %505, %arg2, %arg183 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %507 = torch.aten.mul.Tensor %arg189, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %508 = torch.aten.add.Tensor %507, %502, %arg188 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %509 = torch.aten.addcdiv %arg45, %508, %506, %arg182 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_559 = torch.constant.int 1024 + %int32_560 = torch.constant.int 32 + %510 = torch.prim.ListConstruct %int1024_559, %int32_560 : (!torch.int, !torch.int) -> !torch.list + %int1024_561 = torch.constant.int 1024 + %int32_562 = torch.constant.int 32 + %511 = torch.prim.ListConstruct %int1024_561, %int32_562 : (!torch.int, !torch.int) -> !torch.list + %512 = torch.aten.reshape %157, %511 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_563 = torch.constant.int 1 + %int0_564 = torch.constant.int 0 + %513 = torch.prim.ListConstruct %int1_563, %int0_564 : (!torch.int, !torch.int) -> !torch.list + %int1_565 = torch.constant.int 1 + %int0_566 = torch.constant.int 0 + %514 = torch.prim.ListConstruct %int1_565, %int0_566 : (!torch.int, !torch.int) -> !torch.list + %515 = torch.aten.permute %512, %514 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %516 = torch.aten.mm %515, %174 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_567 = torch.constant.int 1 + %int0_568 = torch.constant.int 0 + %517 = torch.prim.ListConstruct %int1_567, %int0_568 : (!torch.int, !torch.int) -> !torch.list + %int1_569 = torch.constant.int 1 + %int0_570 = torch.constant.int 0 + %518 = torch.prim.ListConstruct %int1_569, %int0_570 : (!torch.int, !torch.int) -> !torch.list + %519 = torch.aten.permute %516, %518 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %520 = torch.aten.add.Tensor %arg194, %519, %arg193 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %521 = torch.aten.mul.Tensor %arg195, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %522 = torch.aten.addcmul %521, %520, %520, %arg192 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %523 = torch.aten.sqrt %522 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %524 = torch.aten.add.Tensor %523, %arg2, %arg191 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %525 = torch.aten.mul.Tensor %arg197, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %526 = torch.aten.add.Tensor %525, %520, %arg196 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %527 = torch.aten.addcdiv %arg46, %526, %524, %arg190 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_571 = torch.constant.int 0 + %528 = torch.prim.ListConstruct %int0_571 : (!torch.int) -> !torch.list + %true_572 = torch.constant.bool true + %none_573 = torch.constant.none + %529 = torch.aten.sum.dim_IntList %174, %528, %true_572, %none_573 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_574 = torch.constant.int 32 + %530 = torch.prim.ListConstruct %int32_574 : (!torch.int) -> !torch.list + %int32_575 = torch.constant.int 32 + %531 = torch.prim.ListConstruct %int32_575 : (!torch.int) -> !torch.list + %532 = torch.aten.reshape %529, %531 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %533 = torch.aten.add.Tensor %arg202, %532, %arg201 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %534 = torch.aten.mul.Tensor %arg203, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %535 = torch.aten.addcmul %534, %533, %533, %arg200 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %536 = torch.aten.sqrt %535 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %537 = torch.aten.add.Tensor %536, %arg2, %arg199 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %538 = torch.aten.mul.Tensor %arg205, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %539 = torch.aten.add.Tensor %538, %533, %arg204 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %540 = torch.aten.addcdiv %arg52, %539, %537, %arg198 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %541 = torch.aten.add.Tensor %arg210, %result1_264, %arg209 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %542 = torch.aten.mul.Tensor %arg211, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %543 = torch.aten.addcmul %542, %541, %541, %arg208 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %544 = torch.aten.sqrt %543 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %545 = torch.aten.add.Tensor %544, %arg2, %arg207 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %546 = torch.aten.mul.Tensor %arg213, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %547 = torch.aten.add.Tensor %546, %541, %arg212 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %548 = torch.aten.addcdiv %arg48, %547, %545, %arg206 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %549 = torch.aten.add.Tensor %arg218, %result2_265, %arg217 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %550 = torch.aten.mul.Tensor %arg219, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %551 = torch.aten.addcmul %550, %549, %549, %arg216 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %552 = torch.aten.sqrt %551 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %553 = torch.aten.add.Tensor %552, %arg2, %arg215 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %554 = torch.aten.mul.Tensor %arg221, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %555 = torch.aten.add.Tensor %554, %549, %arg220 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %556 = torch.aten.addcdiv %arg47, %555, %553, %arg214 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1_576 = torch.constant.int 1 + %int0_577 = torch.constant.int 0 + %557 = torch.prim.ListConstruct %int1_576, %int0_577 : (!torch.int, !torch.int) -> !torch.list + %int1_578 = torch.constant.int 1 + %int0_579 = torch.constant.int 0 + %558 = torch.prim.ListConstruct %int1_578, %int0_579 : (!torch.int, !torch.int) -> !torch.list + %559 = torch.aten.permute %arg228, %558 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int0_580 = torch.constant.int 0 + %int0_581 = torch.constant.int 0 + %int9223372036854775807_582 = torch.constant.int 9223372036854775807 + %int1_583 = torch.constant.int 1 + %560 = torch.aten.slice.Tensor %result0_259, %int0_580, %int0_581, %int9223372036854775807_582, %int1_583 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int1_584 = torch.constant.int 1 + %int0_585 = torch.constant.int 0 + %561 = torch.aten.select.int %560, %int1_584, %int0_585 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,32],f32> + %562 = torch.aten.addmm %arg229, %561, %559, %arg227, %arg226 : !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,32],f32> + %563 = torch.aten.tanh %562 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %int1_586 = torch.constant.int 1 + %int0_587 = torch.constant.int 0 + %564 = torch.prim.ListConstruct %int1_586, %int0_587 : (!torch.int, !torch.int) -> !torch.list + %int1_588 = torch.constant.int 1 + %int0_589 = torch.constant.int 0 + %565 = torch.prim.ListConstruct %int1_588, %int0_589 : (!torch.int, !torch.int) -> !torch.list + %566 = torch.aten.permute %arg230, %565 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %int1_590 = torch.constant.int 1 + %int0_591 = torch.constant.int 0 + %567 = torch.prim.ListConstruct %int1_590, %int0_591 : (!torch.int, !torch.int) -> !torch.list + %int1_592 = torch.constant.int 1 + %int0_593 = torch.constant.int 0 + %568 = torch.prim.ListConstruct %int1_592, %int0_593 : (!torch.int, !torch.int) -> !torch.list + %569 = torch.aten.permute %566, %568 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> + %int1_594 = torch.constant.int 1 + %int0_595 = torch.constant.int 0 + %570 = torch.prim.ListConstruct %int1_594, %int0_595 : (!torch.int, !torch.int) -> !torch.list + %int1_596 = torch.constant.int 1 + %int0_597 = torch.constant.int 0 + %571 = torch.prim.ListConstruct %int1_596, %int0_597 : (!torch.int, !torch.int) -> !torch.list + %572 = torch.aten.permute %arg230, %571 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %573 = torch.aten.addmm %arg233, %563, %572, %arg232, %arg231 : !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> + %int2_598 = torch.constant.int 2 + %574 = torch.prim.ListConstruct %int2_598, %int2_598 : (!torch.int, !torch.int) -> !torch.list + %int2_599 = torch.constant.int 2 + %575 = torch.prim.ListConstruct %int2_599, %int2_599 : (!torch.int, !torch.int) -> !torch.list + %576 = torch.aten.reshape %573, %575 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> + %int1_600 = torch.constant.int 1 + %false_601 = torch.constant.bool false + %577 = torch.aten._log_softmax %576, %int1_600, %false_601 : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2],f32> + %int2_602 = torch.constant.int 2 + %578 = torch.prim.ListConstruct %int2_602 : (!torch.int) -> !torch.list + %int2_603 = torch.constant.int 2 + %579 = torch.prim.ListConstruct %int2_603 : (!torch.int) -> !torch.list + %580 = torch.aten.reshape %arg234, %579 : !torch.vtensor<[2],si64>, !torch.list -> !torch.vtensor<[2],si64> + %none_604 = torch.constant.none + %int1_605 = torch.constant.int 1 + %int-100 = torch.constant.int -100 + %output, %total_weight = torch.aten.nll_loss_forward %577, %580, %none_604, %int1_605, %int-100 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + %581 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.tuple, vtensor<[],f32>> + %none_606 = torch.constant.none + %int1_607 = torch.constant.int 1 + %int-100_608 = torch.constant.int -100 + %582 = torch.aten.nll_loss_backward %arg235, %577, %580, %none_606, %int1_607, %int-100_608, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[2,2],f32> + %int1_609 = torch.constant.int 1 + %int6_610 = torch.constant.int 6 + %583 = torch.aten._log_softmax_backward_data %582, %577, %int1_609, %int6_610 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> + %int2_611 = torch.constant.int 2 + %584 = torch.prim.ListConstruct %int2_611, %int2_611 : (!torch.int, !torch.int) -> !torch.list + %int2_612 = torch.constant.int 2 + %585 = torch.prim.ListConstruct %int2_612, %int2_612 : (!torch.int, !torch.int) -> !torch.list + %586 = torch.aten.reshape %583, %585 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> + %587 = torch.aten.mm %586, %569 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %588 = torch.aten.tanh_backward %587, %563 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %int1_613 = torch.constant.int 1 + %int0_614 = torch.constant.int 0 + %589 = torch.prim.ListConstruct %int1_613, %int0_614 : (!torch.int, !torch.int) -> !torch.list + %int1_615 = torch.constant.int 1 + %int0_616 = torch.constant.int 0 + %590 = torch.prim.ListConstruct %int1_615, %int0_616 : (!torch.int, !torch.int) -> !torch.list + %591 = torch.aten.permute %561, %590 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %592 = torch.aten.mm %591, %588 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_617 = torch.constant.int 1 + %int0_618 = torch.constant.int 0 + %593 = torch.prim.ListConstruct %int1_617, %int0_618 : (!torch.int, !torch.int) -> !torch.list + %int1_619 = torch.constant.int 1 + %int0_620 = torch.constant.int 0 + %594 = torch.prim.ListConstruct %int1_619, %int0_620 : (!torch.int, !torch.int) -> !torch.list + %595 = torch.aten.permute %592, %594 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %596 = torch.aten.add.Tensor %arg236, %595, %arg225 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %597 = torch.aten.mul.Tensor %arg237, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %598 = torch.aten.addcmul %597, %596, %596, %arg224 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %599 = torch.aten.sqrt %598 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %600 = torch.aten.add.Tensor %599, %arg2, %arg223 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %601 = torch.aten.mul.Tensor %arg239, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %602 = torch.aten.add.Tensor %601, %596, %arg238 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %603 = torch.aten.addcdiv %arg228, %602, %600, %arg222 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_621 = torch.constant.int 0 + %604 = torch.prim.ListConstruct %int0_621 : (!torch.int) -> !torch.list + %true_622 = torch.constant.bool true + %none_623 = torch.constant.none + %605 = torch.aten.sum.dim_IntList %588, %604, %true_622, %none_623 : !torch.vtensor<[2,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_624 = torch.constant.int 32 + %606 = torch.prim.ListConstruct %int32_624 : (!torch.int) -> !torch.list + %int32_625 = torch.constant.int 32 + %607 = torch.prim.ListConstruct %int32_625 : (!torch.int) -> !torch.list + %608 = torch.aten.reshape %605, %607 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %609 = torch.aten.add.Tensor %arg244, %608, %arg243 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %610 = torch.aten.mul.Tensor %arg245, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %611 = torch.aten.addcmul %610, %609, %609, %arg242 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %612 = torch.aten.sqrt %611 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %613 = torch.aten.add.Tensor %612, %arg2, %arg241 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %614 = torch.aten.mul.Tensor %arg247, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %615 = torch.aten.add.Tensor %614, %609, %arg246 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %616 = torch.aten.addcdiv %arg229, %615, %613, %arg240 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1_626 = torch.constant.int 1 + %int0_627 = torch.constant.int 0 + %617 = torch.prim.ListConstruct %int1_626, %int0_627 : (!torch.int, !torch.int) -> !torch.list + %int1_628 = torch.constant.int 1 + %int0_629 = torch.constant.int 0 + %618 = torch.prim.ListConstruct %int1_628, %int0_629 : (!torch.int, !torch.int) -> !torch.list + %619 = torch.aten.permute %563, %618 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %620 = torch.aten.mm %619, %586 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[32,2],f32> + %int1_630 = torch.constant.int 1 + %int0_631 = torch.constant.int 0 + %621 = torch.prim.ListConstruct %int1_630, %int0_631 : (!torch.int, !torch.int) -> !torch.list + %int1_632 = torch.constant.int 1 + %int0_633 = torch.constant.int 0 + %622 = torch.prim.ListConstruct %int1_632, %int0_633 : (!torch.int, !torch.int) -> !torch.list + %623 = torch.aten.permute %620, %622 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> + %624 = torch.aten.add.Tensor %arg252, %623, %arg251 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> + %625 = torch.aten.mul.Tensor %arg253, %arg57 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %626 = torch.aten.addcmul %625, %624, %624, %arg250 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %627 = torch.aten.sqrt %626 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %628 = torch.aten.add.Tensor %627, %arg2, %arg249 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> + %629 = torch.aten.mul.Tensor %arg255, %arg60 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %630 = torch.aten.add.Tensor %629, %624, %arg254 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %631 = torch.aten.addcdiv %arg230, %630, %628, %arg248 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %int0_634 = torch.constant.int 0 + %632 = torch.prim.ListConstruct %int0_634 : (!torch.int) -> !torch.list + %true_635 = torch.constant.bool true + %none_636 = torch.constant.none + %633 = torch.aten.sum.dim_IntList %586, %632, %true_635, %none_636 : !torch.vtensor<[2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,2],f32> + %int2_637 = torch.constant.int 2 + %634 = torch.prim.ListConstruct %int2_637 : (!torch.int) -> !torch.list + %int2_638 = torch.constant.int 2 + %635 = torch.prim.ListConstruct %int2_638 : (!torch.int) -> !torch.list + %636 = torch.aten.reshape %633, %635 : !torch.vtensor<[1,2],f32>, !torch.list -> !torch.vtensor<[2],f32> + %637 = torch.aten.add.Tensor %arg260, %636, %arg259 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32> + %638 = torch.aten.mul.Tensor %arg261, %arg57 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> + %639 = torch.aten.addcmul %638, %637, %637, %arg258 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> + %640 = torch.aten.sqrt %639 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> + %641 = torch.aten.add.Tensor %640, %arg2, %arg257 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2],f32> + %642 = torch.aten.mul.Tensor %arg263, %arg60 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> + %643 = torch.aten.add.Tensor %642, %637, %arg262 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> + %644 = torch.aten.addcdiv %arg233, %643, %641, %arg256 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> + %645 = torch.aten.zero.functional %637 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> + %646 = torch.aten.zero.functional %624 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %647 = torch.aten.zero.functional %609 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %648 = torch.aten.zero.functional %596 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %649 = torch.aten.zero.functional %541 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %650 = torch.aten.zero.functional %549 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %651 = torch.aten.zero.functional %533 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %652 = torch.aten.zero.functional %520 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %653 = torch.aten.zero.functional %502 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %654 = torch.aten.zero.functional %489 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %655 = torch.aten.zero.functional %463 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %656 = torch.aten.zero.functional %471 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %657 = torch.aten.zero.functional %455 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %658 = torch.aten.zero.functional %442 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %659 = torch.aten.zero.functional %418 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %660 = torch.aten.zero.functional %405 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %661 = torch.aten.zero.functional %387 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %662 = torch.aten.zero.functional %374 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %663 = torch.aten.zero.functional %356 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %664 = torch.aten.zero.functional %343 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %665 = torch.aten.zero.functional %317 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %666 = torch.aten.zero.functional %325 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %667 = torch.aten.zero.functional %300 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> + %668 = torch.aten.zero.functional %309 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %669 = torch.aten.zero.functional %289 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> + return %296, %307, %316, %324, %332, %350, %363, %381, %394, %412, %425, %449, %462, %470, %478, %496, %509, %527, %540, %548, %556, %603, %616, %631, %644, %645, %646, %647, %648, %649, %650, %651, %652, %653, %654, %655, %656, %657, %658, %659, %660, %661, %662, %663, %664, %665, %666, %667, %668, %669, %295, %291, %306, %302, %315, %311, %323, %319, %331, %327, %349, %345, %362, %358, %380, %376, %393, %389, %411, %407, %424, %420, %448, %444, %461, %457, %469, %465, %477, %473, %495, %491, %508, %504, %526, %522, %539, %535, %547, %543, %555, %551, %602, %598, %615, %611, %630, %626, %643, %639, %arg234, %arg5, %arg12, %arg26, %573, %output : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32> +} diff --git a/e2e_testing/lazy_tensor_core/main.py b/e2e_testing/lazy_tensor_core/main.py new file mode 100644 index 00000000000..8561874e9bf --- /dev/null +++ b/e2e_testing/lazy_tensor_core/main.py @@ -0,0 +1,35 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend + +import unittest + +# Example models +import ltc_backend_bert +import ltc_backend_mnist +import os +import pathlib + + +class LTCTests(unittest.TestCase): + def run_test(self, run_model, mlir_path): + run_model() + + # Compare the generated MLIR with a known good output. + with open(os.path.join(pathlib.Path(__file__).parent.resolve(), mlir_path), 'r') as file: + self.assertEqual(ltc_backend.get_latest_computation().to_string(), file.read()) + + def test_bert(self): + self.run_test(ltc_backend_bert.main, 'bert.mlir') + + def test_mnist(self): + self.run_test(ltc_backend_mnist.main, 'mnist.mlir') + + +if __name__ == '__main__': + ltc_backend._initialize() + + unittest.main() diff --git a/e2e_testing/lazy_tensor_core/mnist.mlir b/e2e_testing/lazy_tensor_core/mnist.mlir new file mode 100644 index 00000000000..21df4e8eff6 --- /dev/null +++ b/e2e_testing/lazy_tensor_core/mnist.mlir @@ -0,0 +1,55 @@ +func.func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.vtensor<[10,5],f32>, %arg7: !torch.vtensor<[10],f32>, %arg8: !torch.vtensor<[],f32>, %arg9: !torch.float) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>) { + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list + %int1_0 = torch.constant.int 1 + %int0_1 = torch.constant.int 0 + %1 = torch.prim.ListConstruct %int1_0, %int0_1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.aten.permute %arg6, %1 : !torch.vtensor<[10,5],f32>, !torch.list -> !torch.vtensor<[5,10],f32> + %3 = torch.aten.addmm %arg7, %arg0, %2, %arg5, %arg4 : !torch.vtensor<[10],f32>, !torch.vtensor<[1,5],f32>, !torch.vtensor<[5,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,10],f32> + %4 = torch.aten.relu %3 : !torch.vtensor<[1,10],f32> -> !torch.vtensor<[1,10],f32> + %int1_2 = torch.constant.int 1 + %false = torch.constant.bool false + %5 = torch.aten._log_softmax %4, %int1_2, %false : !torch.vtensor<[1,10],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,10],f32> + %none = torch.constant.none + %int1_3 = torch.constant.int 1 + %int-100 = torch.constant.int -100 + %output, %total_weight = torch.aten.nll_loss_forward %5, %arg1, %none, %int1_3, %int-100 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + %6 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.tuple, vtensor<[],f32>> + %none_4 = torch.constant.none + %int1_5 = torch.constant.int 1 + %int-100_6 = torch.constant.int -100 + %7 = torch.aten.nll_loss_backward %arg8, %5, %arg1, %none_4, %int1_5, %int-100_6, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[1,10],f32> + %int1_7 = torch.constant.int 1 + %int6 = torch.constant.int 6 + %8 = torch.aten._log_softmax_backward_data %7, %5, %int1_7, %int6 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,10],f32> + %9 = torch.aten.threshold_backward %8, %4, %arg3 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32> + %int1_8 = torch.constant.int 1 + %int0_9 = torch.constant.int 0 + %10 = torch.prim.ListConstruct %int1_8, %int0_9 : (!torch.int, !torch.int) -> !torch.list + %int1_10 = torch.constant.int 1 + %int0_11 = torch.constant.int 0 + %11 = torch.prim.ListConstruct %int1_10, %int0_11 : (!torch.int, !torch.int) -> !torch.list + %12 = torch.aten.permute %arg0, %11 : !torch.vtensor<[1,5],f32>, !torch.list -> !torch.vtensor<[5,1],f32> + %13 = torch.aten.mm %12, %9 : !torch.vtensor<[5,1],f32>, !torch.vtensor<[1,10],f32> -> !torch.vtensor<[5,10],f32> + %int1_12 = torch.constant.int 1 + %int0_13 = torch.constant.int 0 + %14 = torch.prim.ListConstruct %int1_12, %int0_13 : (!torch.int, !torch.int) -> !torch.list + %int1_14 = torch.constant.int 1 + %int0_15 = torch.constant.int 0 + %15 = torch.prim.ListConstruct %int1_14, %int0_15 : (!torch.int, !torch.int) -> !torch.list + %16 = torch.aten.permute %13, %15 : !torch.vtensor<[5,10],f32>, !torch.list -> !torch.vtensor<[10,5],f32> + %17 = torch.aten.add.Tensor %arg6, %16, %arg2 : !torch.vtensor<[10,5],f32>, !torch.vtensor<[10,5],f32>, !torch.float -> !torch.vtensor<[10,5],f32> + %int0_16 = torch.constant.int 0 + %18 = torch.prim.ListConstruct %int0_16 : (!torch.int) -> !torch.list + %true = torch.constant.bool true + %none_17 = torch.constant.none + %19 = torch.aten.sum.dim_IntList %9, %18, %true, %none_17 : !torch.vtensor<[1,10],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,10],f32> + %int10 = torch.constant.int 10 + %20 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list + %int10_18 = torch.constant.int 10 + %21 = torch.prim.ListConstruct %int10_18 : (!torch.int) -> !torch.list + %22 = torch.aten.reshape %19, %21 : !torch.vtensor<[1,10],f32>, !torch.list -> !torch.vtensor<[10],f32> + %23 = torch.aten.add.Tensor %arg7, %22, %arg9 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + return %arg0, %arg1, %17, %23, %4, %output, %22, %16 : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32> +} diff --git a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp index 8f5b507cdc1..1b51346a6f0 100644 --- a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp +++ b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp @@ -12,8 +12,8 @@ #include #include -#include #include +#include #include #include #include @@ -60,10 +60,13 @@ class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl { // Vendor backend specific lowering can be exec here before returning. for (const auto &instance : instances) { - std::cout << "Instance received at Compile: \n" - << GetComputationBackendText(instance) << std::endl; + // Store computation instance for external access after compilation. + GetLatestComputation() = instance; } + std::cout << "Received " << instances.size() + << " computation instances at Compile!" << std::endl; + return instances; } @@ -133,9 +136,13 @@ class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl { * */ std::string GetComputationBackendText(const ComputationPtr computation) const override { - auto mlir_computation = - static_cast(computation.get()); - return mlir_computation->to_string(); + // Store computation instance for external access after compilation. + // We do this in GetComputationBackendText since there may be instances + // where a ComputationPtr does not pass through Compile (e.g. when using + // DumpUtil::ToBackend.) + GetLatestComputation() = computation; + + return computation->to_string(); } private: @@ -154,5 +161,11 @@ void InitExampleMlirBackend() { g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl())); } +ComputationPtr &GetLatestComputation() { + // Store the computation from the most recent compile. + static ComputationPtr computation; + return computation; +} + } // namespace lazy } // namespace torch diff --git a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h index 377ae4d219f..4c915fa9fdd 100644 --- a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h +++ b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h @@ -23,5 +23,7 @@ torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl(); void InitExampleMlirBackend(); +ComputationPtr &GetLatestComputation(); + } // namespace lazy } // namespace torch diff --git a/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp b/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp index 1474b4dc907..ff1aa766642 100644 --- a/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp +++ b/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp @@ -10,6 +10,8 @@ #include "torch/csrc/jit/python/pybind.h" #include "torch/csrc/lazy/backend/backend_interface.h" +#include + #include #include #include @@ -61,7 +63,16 @@ void Shutdown() { } // anonymous namespace PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) { + py::class_(m, "TorchMlirComputation") + .def("to_string", &torch::lazy::TorchMlirComputation::to_string) + .def("debug_string", &torch::lazy::TorchMlirComputation::debug_string); + m.doc() = ("pybind11 for example MLIR LTC backend."); + m.def("get_latest_computation", []() { + auto computation = static_cast( + torch::lazy::GetLatestComputation().get()); + return py::cast(computation); + }); m.def("_initialize", []() { NoGilSection gil; Initialize(); diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index d8434f5ef14..a43caddd62a 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -14,7 +14,12 @@ """ import argparse +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend +import sys import torch +import torch._C +import torch._lazy +import torch._lazy.ts_backend from datasets import load_dataset from datasets.dataset_dict import DatasetDict from torch.utils.data import DataLoader @@ -42,8 +47,7 @@ def train(model: BertForSequenceClassification, num_epochs: int, num_training_steps: int, train_dataloader: DataLoader, - device: torch.device, - do_mark_step: bool) -> List[torch.Tensor]: + device: torch.device) -> List[torch.Tensor]: optimizer = AdamW(model.parameters(), lr=5e-5) lr_scheduler = get_scheduler('linear', optimizer=optimizer, num_warmup_steps=0, @@ -63,31 +67,20 @@ def train(model: BertForSequenceClassification, lr_scheduler.step() optimizer.zero_grad() - if do_mark_step and 'lazy' in str(model.device): + if 'lazy' in str(model.device): print("Calling Mark Step") torch._lazy.mark_step() return losses -def main(device, lower_only, full_size): - if device in ("TS", "MLIR_EXAMPLE"): - import torch._lazy +def main(device='lazy', full_size=False): + """ + Load model to specified device. Ensure that any backends have been initialized by this point. - if device == "TS": - import torch._lazy.ts_backend - - torch._lazy.ts_backend.init() - - elif device == "MLIR_EXAMPLE": - import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend - - ltc_backend._initialize() - - device = "lazy" - print("Initialized backend") - else: - device = device.lower() + :param device: name of device to load tensors to + :param full_size: if true, use a full pretrained bert-base-cased model instead of a smaller variant + """ tokenized_datasets = tokenize_dataset(load_dataset('imdb')) small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \ @@ -117,18 +110,13 @@ def main(device, lower_only, full_size): num_epochs = 3 num_training_steps = num_epochs * len(train_dataloader) - losses = train(model, num_epochs, - num_training_steps, train_dataloader, device, not lower_only) - - if lower_only: - print('\nJIT Graph:') - import torch._C - graph_str = torch._C._lazy._get_tensors_backend([losses[0]]) - print(graph_str) - else: - # Execute computation - print('Loss: ', losses) + losses = train(model, num_epochs, num_training_steps, train_dataloader, device) + + # Get debug information from LTC + if 'ltc_backend' in sys.modules: + print(ltc_backend.get_latest_computation().debug_string()) + print('Loss: ', losses) if __name__ == "__main__": torch.manual_seed(0) @@ -142,13 +130,6 @@ def main(device, lower_only, full_size): default="MLIR_EXAMPLE", help="The device type", ) - parser.add_argument( - "-l", - "--lower_only", - action='store_true', - default=False, - help="Only get backend printout -- do not execute computation", - ) parser.add_argument( "-f", "--full_size", @@ -157,4 +138,17 @@ def main(device, lower_only, full_size): help="Use full sized BERT model instead of one with smaller parameterization", ) args = parser.parse_args() - main(args.device, args.lower_only, args.full_size) + + if args.device in ("TS", "MLIR_EXAMPLE"): + if args.device == "TS": + torch._lazy.ts_backend.init() + + elif args.device == "MLIR_EXAMPLE": + ltc_backend._initialize() + + device = "lazy" + print("Initialized backend") + else: + device = args.device.lower() + + main(device, args.full_size) diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index 7448bbc0b34..a40980dde6b 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -6,31 +6,20 @@ Example use of the example Torch MLIR LTC backend. """ import argparse - +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend +import sys +import torch +import torch._lazy +import torch._lazy.ts_backend import torch.nn.functional as F -def main(device): - import torch - - if device in ("TS", "MLIR_EXAMPLE"): - import torch._lazy - - if device == "TS": - import torch._lazy.ts_backend - - torch._lazy.ts_backend.init() - - elif device == "MLIR_EXAMPLE": - import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend - - ltc_backend._initialize() - - device = "lazy" - print("Initialized backend") - else: - device = device.lower() +def main(device='lazy'): + """ + Load model to specified device. Ensure that any backends have been initialized by this point. + :param device: name of device to load tensors to + """ inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device) assert inputs.device.type == device @@ -71,6 +60,10 @@ def forward(self, x): print() print(loss) + # Get debug information from LTC + if 'ltc_backend' in sys.modules: + print(ltc_backend.get_latest_computation().debug_string()) + if __name__ == "__main__": torch.manual_seed(0) @@ -85,4 +78,17 @@ def forward(self, x): help="The device type", ) args = parser.parse_args() - main(args.device) + + if args.device in ("TS", "MLIR_EXAMPLE"): + if args.device == "TS": + torch._lazy.ts_backend.init() + + elif args.device == "MLIR_EXAMPLE": + ltc_backend._initialize() + + device = "lazy" + print("Initialized backend") + else: + device = args.device.lower() + + main(device) diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index 2ebb963bd13..c024b73e076 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -308,24 +308,14 @@ std::shared_ptr TorchMlirComputation::graph() const { MlirOperation TorchMlirComputation::func_op() const { return func_op_; } -const std::string TorchMlirComputation::to_string() const { - // Since we use the C-MLIR API, we need to use a callback to print. - MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { - // user_data is a void ptr to some data structure of our choice -- in this - // case, the string stream where we'll be accumulating the strings. - std::stringstream* ss_ptr = static_cast(user_data); - *ss_ptr << std::string(part.data, part.length); - }; - +const std::string TorchMlirComputation::debug_string() const { std::stringstream ss; // JIT Graph ss << "JIT Graph: \n" << graph_->toString() << "\n\n"; // MLIR - ss << "MLIR: \n"; - mlirOperationPrint(func_op_, print_callback, &ss); - ss << "\n"; + ss << "MLIR: \n" << to_string() << "\n"; // Input/Output Mapping ss << "Input/Output Alias Mapping: \n"; @@ -341,5 +331,18 @@ const std::string TorchMlirComputation::to_string() const { return ss.str(); } +const std::string TorchMlirComputation::to_string() const { + // Since we use the C-MLIR API, we need to use a callback to print. + MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { + // user_data is a void ptr to some data structure of our choice -- in this + // case, the string stream where we'll be accumulating the strings. + std::stringstream* ss_ptr = static_cast(user_data); + *ss_ptr << std::string(part.data, part.length); + }; + std::stringstream ss; + mlirOperationPrint(func_op_, print_callback, &ss); + return ss.str(); +} + } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h index 4a025b5bb9e..5738d23837d 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h @@ -149,6 +149,8 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { MlirOperation func_op() const; + const std::string debug_string() const; + const std::string to_string() const; private: From ef3f32579d56c577173ba4be88f33458e120cbb1 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Wed, 1 Jun 2022 12:45:46 -0400 Subject: [PATCH 02/13] Fix seed for reproducability --- examples/ltc_backend_bert.py | 6 ++++-- examples/ltc_backend_mnist.py | 11 ++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index a43caddd62a..70040ebb5fd 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -14,12 +14,10 @@ """ import argparse -import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import sys import torch import torch._C import torch._lazy -import torch._lazy.ts_backend from datasets import load_dataset from datasets.dataset_dict import DatasetDict from torch.utils.data import DataLoader @@ -81,6 +79,7 @@ def main(device='lazy', full_size=False): :param device: name of device to load tensors to :param full_size: if true, use a full pretrained bert-base-cased model instead of a smaller variant """ + torch.manual_seed(0) tokenized_datasets = tokenize_dataset(load_dataset('imdb')) small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \ @@ -118,6 +117,7 @@ def main(device='lazy', full_size=False): print('Loss: ', losses) + if __name__ == "__main__": torch.manual_seed(0) @@ -141,9 +141,11 @@ def main(device='lazy', full_size=False): if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": + import torch._lazy.ts_backend torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": + import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend ltc_backend._initialize() device = "lazy" diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index a40980dde6b..2568e10e610 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -6,11 +6,9 @@ Example use of the example Torch MLIR LTC backend. """ import argparse -import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import sys import torch import torch._lazy -import torch._lazy.ts_backend import torch.nn.functional as F @@ -20,6 +18,8 @@ def main(device='lazy'): :param device: name of device to load tensors to """ + torch.manual_seed(0) + inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device) assert inputs.device.type == device @@ -57,13 +57,12 @@ def forward(self, x): print("Calling Mark Step") torch._lazy.mark_step() - print() - print(loss) - # Get debug information from LTC if 'ltc_backend' in sys.modules: print(ltc_backend.get_latest_computation().debug_string()) + print(loss) + if __name__ == "__main__": torch.manual_seed(0) @@ -81,9 +80,11 @@ def forward(self, x): if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": + import torch._lazy.ts_backend torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": + import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend ltc_backend._initialize() device = "lazy" From 4f7f99956a9c8a69ae0672a31de2b67bb7714cbb Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Wed, 1 Jun 2022 13:31:03 -0400 Subject: [PATCH 03/13] Check if computation is None before getting debug string --- examples/ltc_backend_bert.py | 11 +++++++---- examples/ltc_backend_mnist.py | 9 ++++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index 70040ebb5fd..0b808328efd 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -15,15 +15,18 @@ import argparse import sys +from typing import List + +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import torch import torch._C import torch._lazy +import torch._lazy.ts_backend from datasets import load_dataset from datasets.dataset_dict import DatasetDict from torch.utils.data import DataLoader from transformers import BertForSequenceClassification, \ BertConfig, BertTokenizer, AdamW, get_scheduler -from typing import List def tokenize_dataset(dataset: DatasetDict) -> DatasetDict: @@ -113,7 +116,9 @@ def main(device='lazy', full_size=False): # Get debug information from LTC if 'ltc_backend' in sys.modules: - print(ltc_backend.get_latest_computation().debug_string()) + computation = ltc_backend.get_latest_computation() + if computation: + print(computation.debug_string()) print('Loss: ', losses) @@ -141,11 +146,9 @@ def main(device='lazy', full_size=False): if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": - import torch._lazy.ts_backend torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": - import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend ltc_backend._initialize() device = "lazy" diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index 2568e10e610..312ea9fc2dd 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -7,8 +7,11 @@ """ import argparse import sys + +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import torch import torch._lazy +import torch._lazy.ts_backend import torch.nn.functional as F @@ -59,7 +62,9 @@ def forward(self, x): # Get debug information from LTC if 'ltc_backend' in sys.modules: - print(ltc_backend.get_latest_computation().debug_string()) + computation = ltc_backend.get_latest_computation() + if computation: + print(computation.debug_string()) print(loss) @@ -80,11 +85,9 @@ def forward(self, x): if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": - import torch._lazy.ts_backend torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": - import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend ltc_backend._initialize() device = "lazy" From 9546f105962ebacc2fb9957b16d26fb79837f640 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Mon, 6 Jun 2022 19:31:09 -0400 Subject: [PATCH 04/13] Updated unit tests, and added numeric tests --- e2e_testing/lazy_tensor_core/bert.mlir | 2638 ++++++++++++----------- e2e_testing/lazy_tensor_core/main.py | 45 +- e2e_testing/lazy_tensor_core/mnist.mlir | 26 +- examples/ltc_backend_bert.py | 2 + examples/ltc_backend_mnist.py | 26 +- 5 files changed, 1484 insertions(+), 1253 deletions(-) diff --git a/e2e_testing/lazy_tensor_core/bert.mlir b/e2e_testing/lazy_tensor_core/bert.mlir index d057c034b6a..294d48cd597 100644 --- a/e2e_testing/lazy_tensor_core/bert.mlir +++ b/e2e_testing/lazy_tensor_core/bert.mlir @@ -1,21 +1,21 @@ -func.func @graph(%arg0: !torch.float, %arg1: !torch.int, %arg2: !torch.vtensor<[],f64>, %arg3: !torch.float, %arg4: !torch.int, %arg5: !torch.vtensor<[2,512],si64>, %arg6: !torch.vtensor<[32],f32>, %arg7: !torch.vtensor<[32],f32>, %arg8: !torch.int, %arg9: !torch.vtensor<[1,512],si64>, %arg10: !torch.vtensor<[512,32],f32>, %arg11: !torch.int, %arg12: !torch.vtensor<[2,512],si64>, %arg13: !torch.vtensor<[2,32],f32>, %arg14: !torch.vtensor<[28996,32],f32>, %arg15: !torch.int, %arg16: !torch.vtensor<[32,32],f32>, %arg17: !torch.int, %arg18: !torch.int, %arg19: !torch.vtensor<[32,32],f32>, %arg20: !torch.vtensor<[32],f32>, %arg21: !torch.vtensor<[],f64>, %arg22: !torch.int, %arg23: !torch.vtensor<[],f64>, %arg24: !torch.int, %arg25: !torch.float, %arg26: !torch.vtensor<[2,512],si64>, %arg27: !torch.int, %arg28: !torch.int, %arg29: !torch.vtensor<[32],f32>, %arg30: !torch.int, %arg31: !torch.int, %arg32: !torch.vtensor<[32,32],f32>, %arg33: !torch.vtensor<[32],f32>, %arg34: !torch.vtensor<[32,32],f32>, %arg35: !torch.vtensor<[32],f32>, %arg36: !torch.vtensor<[32],f32>, %arg37: !torch.int, %arg38: !torch.int, %arg39: !torch.int, %arg40: !torch.vtensor<[32],f32>, %arg41: !torch.int, %arg42: !torch.vtensor<[32,32],f32>, %arg43: !torch.int, %arg44: !torch.int, %arg45: !torch.vtensor<[32],f32>, %arg46: !torch.vtensor<[32,32],f32>, %arg47: !torch.vtensor<[32],f32>, %arg48: !torch.vtensor<[32],f32>, %arg49: !torch.int, %arg50: !torch.int, %arg51: !torch.int, %arg52: !torch.vtensor<[32],f32>, %arg53: !torch.vtensor<[2,512,32],f32>, %arg54: !torch.int, %arg55: !torch.int, %arg56: !torch.vtensor<[28996,32],f32>, %arg57: !torch.vtensor<[],f64>, %arg58: !torch.vtensor<[28996,32],f32>, %arg59: !torch.float, %arg60: !torch.vtensor<[],f64>, %arg61: !torch.vtensor<[28996,32],f32>, %arg62: !torch.float, %arg63: !torch.int, %arg64: !torch.float, %arg65: !torch.int, %arg66: !torch.vtensor<[512,32],f32>, %arg67: !torch.vtensor<[512,32],f32>, %arg68: !torch.float, %arg69: !torch.vtensor<[512,32],f32>, %arg70: !torch.float, %arg71: !torch.int, %arg72: !torch.float, %arg73: !torch.int, %arg74: !torch.vtensor<[2,32],f32>, %arg75: !torch.vtensor<[2,32],f32>, %arg76: !torch.float, %arg77: !torch.vtensor<[2,32],f32>, %arg78: !torch.float, %arg79: !torch.int, %arg80: !torch.float, %arg81: !torch.int, %arg82: !torch.vtensor<[32],f32>, %arg83: !torch.vtensor<[32],f32>, %arg84: !torch.float, %arg85: !torch.vtensor<[32],f32>, %arg86: !torch.float, %arg87: !torch.int, %arg88: !torch.float, %arg89: !torch.int, %arg90: !torch.vtensor<[32],f32>, %arg91: !torch.vtensor<[32],f32>, %arg92: !torch.float, %arg93: !torch.vtensor<[32],f32>, %arg94: !torch.float, %arg95: !torch.int, %arg96: !torch.float, %arg97: !torch.int, %arg98: !torch.vtensor<[32,32],f32>, %arg99: !torch.vtensor<[32,32],f32>, %arg100: !torch.float, %arg101: !torch.vtensor<[32,32],f32>, %arg102: !torch.float, %arg103: !torch.int, %arg104: !torch.float, %arg105: !torch.int, %arg106: !torch.vtensor<[32],f32>, %arg107: !torch.vtensor<[32],f32>, %arg108: !torch.float, %arg109: !torch.vtensor<[32],f32>, %arg110: !torch.float, %arg111: !torch.int, %arg112: !torch.float, %arg113: !torch.int, %arg114: !torch.vtensor<[32,32],f32>, %arg115: !torch.vtensor<[32,32],f32>, %arg116: !torch.float, %arg117: !torch.vtensor<[32,32],f32>, %arg118: !torch.float, %arg119: !torch.int, %arg120: !torch.float, %arg121: !torch.int, %arg122: !torch.vtensor<[32],f32>, %arg123: !torch.vtensor<[32],f32>, %arg124: !torch.float, %arg125: !torch.vtensor<[32],f32>, %arg126: !torch.float, %arg127: !torch.int, %arg128: !torch.float, %arg129: !torch.int, %arg130: !torch.vtensor<[32,32],f32>, %arg131: !torch.vtensor<[32,32],f32>, %arg132: !torch.float, %arg133: !torch.vtensor<[32,32],f32>, %arg134: !torch.float, %arg135: !torch.int, %arg136: !torch.float, %arg137: !torch.int, %arg138: !torch.vtensor<[32],f32>, %arg139: !torch.vtensor<[32],f32>, %arg140: !torch.float, %arg141: !torch.vtensor<[32],f32>, %arg142: !torch.float, %arg143: !torch.int, %arg144: !torch.float, %arg145: !torch.int, %arg146: !torch.vtensor<[32,32],f32>, %arg147: !torch.vtensor<[32,32],f32>, %arg148: !torch.float, %arg149: !torch.vtensor<[32,32],f32>, %arg150: !torch.float, %arg151: !torch.int, %arg152: !torch.float, %arg153: !torch.int, %arg154: !torch.vtensor<[32],f32>, %arg155: !torch.vtensor<[32],f32>, %arg156: !torch.float, %arg157: !torch.vtensor<[32],f32>, %arg158: !torch.float, %arg159: !torch.int, %arg160: !torch.float, %arg161: !torch.int, %arg162: !torch.vtensor<[32],f32>, %arg163: !torch.vtensor<[32],f32>, %arg164: !torch.float, %arg165: !torch.vtensor<[32],f32>, %arg166: !torch.float, %arg167: !torch.int, %arg168: !torch.float, %arg169: !torch.int, %arg170: !torch.vtensor<[32],f32>, %arg171: !torch.vtensor<[32],f32>, %arg172: !torch.float, %arg173: !torch.vtensor<[32],f32>, %arg174: !torch.float, %arg175: !torch.int, %arg176: !torch.float, %arg177: !torch.int, %arg178: !torch.vtensor<[32,32],f32>, %arg179: !torch.vtensor<[32,32],f32>, %arg180: !torch.float, %arg181: !torch.vtensor<[32,32],f32>, %arg182: !torch.float, %arg183: !torch.int, %arg184: !torch.float, %arg185: !torch.int, %arg186: !torch.vtensor<[32],f32>, %arg187: !torch.vtensor<[32],f32>, %arg188: !torch.float, %arg189: !torch.vtensor<[32],f32>, %arg190: !torch.float, %arg191: !torch.int, %arg192: !torch.float, %arg193: !torch.int, %arg194: !torch.vtensor<[32,32],f32>, %arg195: !torch.vtensor<[32,32],f32>, %arg196: !torch.float, %arg197: !torch.vtensor<[32,32],f32>, %arg198: !torch.float, %arg199: !torch.int, %arg200: !torch.float, %arg201: !torch.int, %arg202: !torch.vtensor<[32],f32>, %arg203: !torch.vtensor<[32],f32>, %arg204: !torch.float, %arg205: !torch.vtensor<[32],f32>, %arg206: !torch.float, %arg207: !torch.int, %arg208: !torch.float, %arg209: !torch.int, %arg210: !torch.vtensor<[32],f32>, %arg211: !torch.vtensor<[32],f32>, %arg212: !torch.float, %arg213: !torch.vtensor<[32],f32>, %arg214: !torch.float, %arg215: !torch.int, %arg216: !torch.float, %arg217: !torch.int, %arg218: !torch.vtensor<[32],f32>, %arg219: !torch.vtensor<[32],f32>, %arg220: !torch.float, %arg221: !torch.vtensor<[32],f32>, %arg222: !torch.float, %arg223: !torch.int, %arg224: !torch.float, %arg225: !torch.int, %arg226: !torch.int, %arg227: !torch.int, %arg228: !torch.vtensor<[32,32],f32>, %arg229: !torch.vtensor<[32],f32>, %arg230: !torch.vtensor<[2,32],f32>, %arg231: !torch.int, %arg232: !torch.int, %arg233: !torch.vtensor<[2],f32>, %arg234: !torch.vtensor<[2],si64>, %arg235: !torch.vtensor<[],f32>, %arg236: !torch.vtensor<[32,32],f32>, %arg237: !torch.vtensor<[32,32],f32>, %arg238: !torch.float, %arg239: !torch.vtensor<[32,32],f32>, %arg240: !torch.float, %arg241: !torch.int, %arg242: !torch.float, %arg243: !torch.int, %arg244: !torch.vtensor<[32],f32>, %arg245: !torch.vtensor<[32],f32>, %arg246: !torch.float, %arg247: !torch.vtensor<[32],f32>, %arg248: !torch.float, %arg249: !torch.int, %arg250: !torch.float, %arg251: !torch.int, %arg252: !torch.vtensor<[2,32],f32>, %arg253: !torch.vtensor<[2,32],f32>, %arg254: !torch.float, %arg255: !torch.vtensor<[2,32],f32>, %arg256: !torch.float, %arg257: !torch.int, %arg258: !torch.float, %arg259: !torch.int, %arg260: !torch.vtensor<[2],f32>, %arg261: !torch.vtensor<[2],f32>, %arg262: !torch.float, %arg263: !torch.vtensor<[2],f32>) -> (!torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32>) { +func.func @graph(%arg0: !torch.float, %arg1: !torch.int, %arg2: !torch.vtensor<[],f64>, %arg3: !torch.float, %arg4: !torch.int, %arg5: !torch.vtensor<[2,512],si64>, %arg6: !torch.vtensor<[32],f32>, %arg7: !torch.vtensor<[32],f32>, %arg8: !torch.int, %arg9: !torch.vtensor<[1,512],si64>, %arg10: !torch.vtensor<[512,32],f32>, %arg11: !torch.int, %arg12: !torch.vtensor<[2,512],si64>, %arg13: !torch.vtensor<[2,32],f32>, %arg14: !torch.vtensor<[28996,32],f32>, %arg15: !torch.int, %arg16: !torch.vtensor<[32,32],f32>, %arg17: !torch.int, %arg18: !torch.int, %arg19: !torch.vtensor<[32,32],f32>, %arg20: !torch.vtensor<[32],f32>, %arg21: !torch.vtensor<[],f64>, %arg22: !torch.int, %arg23: !torch.vtensor<[],f64>, %arg24: !torch.int, %arg25: !torch.vtensor<[2,512],si64>, %arg26: !torch.vtensor<[],f64>, %arg27: !torch.int, %arg28: !torch.int, %arg29: !torch.vtensor<[32],f32>, %arg30: !torch.int, %arg31: !torch.int, %arg32: !torch.vtensor<[32,32],f32>, %arg33: !torch.vtensor<[32],f32>, %arg34: !torch.vtensor<[32,32],f32>, %arg35: !torch.vtensor<[32],f32>, %arg36: !torch.vtensor<[32],f32>, %arg37: !torch.int, %arg38: !torch.int, %arg39: !torch.int, %arg40: !torch.vtensor<[32],f32>, %arg41: !torch.int, %arg42: !torch.vtensor<[32,32],f32>, %arg43: !torch.int, %arg44: !torch.int, %arg45: !torch.vtensor<[32],f32>, %arg46: !torch.vtensor<[32,32],f32>, %arg47: !torch.vtensor<[32],f32>, %arg48: !torch.vtensor<[32],f32>, %arg49: !torch.int, %arg50: !torch.int, %arg51: !torch.int, %arg52: !torch.vtensor<[32],f32>, %arg53: !torch.vtensor<[32,32],f32>, %arg54: !torch.int, %arg55: !torch.int, %arg56: !torch.vtensor<[32],f32>, %arg57: !torch.vtensor<[2,32],f32>, %arg58: !torch.int, %arg59: !torch.int, %arg60: !torch.vtensor<[2],f32>, %arg61: !torch.vtensor<[2],si64>, %arg62: !torch.vtensor<[],f32>, %arg63: !torch.vtensor<[2,512,32],f32>, %arg64: !torch.vtensor<[2,512,32],f32>, %arg65: !torch.int, %arg66: !torch.int, %arg67: !torch.vtensor<[28996,32],f32>, %arg68: !torch.vtensor<[],f64>, %arg69: !torch.vtensor<[28996,32],f32>, %arg70: !torch.float, %arg71: !torch.vtensor<[],f64>, %arg72: !torch.vtensor<[28996,32],f32>, %arg73: !torch.float, %arg74: !torch.int, %arg75: !torch.float, %arg76: !torch.int, %arg77: !torch.vtensor<[512,32],f32>, %arg78: !torch.vtensor<[512,32],f32>, %arg79: !torch.float, %arg80: !torch.vtensor<[512,32],f32>, %arg81: !torch.float, %arg82: !torch.int, %arg83: !torch.float, %arg84: !torch.int, %arg85: !torch.vtensor<[2,32],f32>, %arg86: !torch.vtensor<[2,32],f32>, %arg87: !torch.float, %arg88: !torch.vtensor<[2,32],f32>, %arg89: !torch.float, %arg90: !torch.int, %arg91: !torch.float, %arg92: !torch.int, %arg93: !torch.vtensor<[32],f32>, %arg94: !torch.vtensor<[32],f32>, %arg95: !torch.float, %arg96: !torch.vtensor<[32],f32>, %arg97: !torch.float, %arg98: !torch.int, %arg99: !torch.float, %arg100: !torch.int, %arg101: !torch.vtensor<[32],f32>, %arg102: !torch.vtensor<[32],f32>, %arg103: !torch.float, %arg104: !torch.vtensor<[32],f32>, %arg105: !torch.float, %arg106: !torch.int, %arg107: !torch.float, %arg108: !torch.int, %arg109: !torch.vtensor<[32,32],f32>, %arg110: !torch.vtensor<[32,32],f32>, %arg111: !torch.float, %arg112: !torch.vtensor<[32,32],f32>, %arg113: !torch.float, %arg114: !torch.int, %arg115: !torch.float, %arg116: !torch.int, %arg117: !torch.vtensor<[32],f32>, %arg118: !torch.vtensor<[32],f32>, %arg119: !torch.float, %arg120: !torch.vtensor<[32],f32>, %arg121: !torch.float, %arg122: !torch.int, %arg123: !torch.float, %arg124: !torch.int, %arg125: !torch.vtensor<[32,32],f32>, %arg126: !torch.vtensor<[32,32],f32>, %arg127: !torch.float, %arg128: !torch.vtensor<[32,32],f32>, %arg129: !torch.float, %arg130: !torch.int, %arg131: !torch.float, %arg132: !torch.int, %arg133: !torch.vtensor<[32],f32>, %arg134: !torch.vtensor<[32],f32>, %arg135: !torch.float, %arg136: !torch.vtensor<[32],f32>, %arg137: !torch.float, %arg138: !torch.int, %arg139: !torch.float, %arg140: !torch.int, %arg141: !torch.vtensor<[32,32],f32>, %arg142: !torch.vtensor<[32,32],f32>, %arg143: !torch.float, %arg144: !torch.vtensor<[32,32],f32>, %arg145: !torch.float, %arg146: !torch.int, %arg147: !torch.float, %arg148: !torch.int, %arg149: !torch.vtensor<[32],f32>, %arg150: !torch.vtensor<[32],f32>, %arg151: !torch.float, %arg152: !torch.vtensor<[32],f32>, %arg153: !torch.float, %arg154: !torch.int, %arg155: !torch.float, %arg156: !torch.int, %arg157: !torch.vtensor<[32,32],f32>, %arg158: !torch.vtensor<[32,32],f32>, %arg159: !torch.float, %arg160: !torch.vtensor<[32,32],f32>, %arg161: !torch.float, %arg162: !torch.int, %arg163: !torch.float, %arg164: !torch.int, %arg165: !torch.vtensor<[32],f32>, %arg166: !torch.vtensor<[32],f32>, %arg167: !torch.float, %arg168: !torch.vtensor<[32],f32>, %arg169: !torch.float, %arg170: !torch.int, %arg171: !torch.float, %arg172: !torch.int, %arg173: !torch.vtensor<[32],f32>, %arg174: !torch.vtensor<[32],f32>, %arg175: !torch.float, %arg176: !torch.vtensor<[32],f32>, %arg177: !torch.float, %arg178: !torch.int, %arg179: !torch.float, %arg180: !torch.int, %arg181: !torch.vtensor<[32],f32>, %arg182: !torch.vtensor<[32],f32>, %arg183: !torch.float, %arg184: !torch.vtensor<[32],f32>, %arg185: !torch.float, %arg186: !torch.int, %arg187: !torch.float, %arg188: !torch.int, %arg189: !torch.vtensor<[32,32],f32>, %arg190: !torch.vtensor<[32,32],f32>, %arg191: !torch.float, %arg192: !torch.vtensor<[32,32],f32>, %arg193: !torch.float, %arg194: !torch.int, %arg195: !torch.float, %arg196: !torch.int, %arg197: !torch.vtensor<[32],f32>, %arg198: !torch.vtensor<[32],f32>, %arg199: !torch.float, %arg200: !torch.vtensor<[32],f32>, %arg201: !torch.float, %arg202: !torch.int, %arg203: !torch.float, %arg204: !torch.int, %arg205: !torch.vtensor<[32,32],f32>, %arg206: !torch.vtensor<[32,32],f32>, %arg207: !torch.float, %arg208: !torch.vtensor<[32,32],f32>, %arg209: !torch.float, %arg210: !torch.int, %arg211: !torch.float, %arg212: !torch.int, %arg213: !torch.vtensor<[32],f32>, %arg214: !torch.vtensor<[32],f32>, %arg215: !torch.float, %arg216: !torch.vtensor<[32],f32>, %arg217: !torch.float, %arg218: !torch.int, %arg219: !torch.float, %arg220: !torch.int, %arg221: !torch.vtensor<[32],f32>, %arg222: !torch.vtensor<[32],f32>, %arg223: !torch.float, %arg224: !torch.vtensor<[32],f32>, %arg225: !torch.float, %arg226: !torch.int, %arg227: !torch.float, %arg228: !torch.int, %arg229: !torch.vtensor<[32],f32>, %arg230: !torch.vtensor<[32],f32>, %arg231: !torch.float, %arg232: !torch.vtensor<[32],f32>, %arg233: !torch.float, %arg234: !torch.int, %arg235: !torch.float, %arg236: !torch.int, %arg237: !torch.vtensor<[32,32],f32>, %arg238: !torch.vtensor<[32,32],f32>, %arg239: !torch.float, %arg240: !torch.vtensor<[32,32],f32>, %arg241: !torch.float, %arg242: !torch.int, %arg243: !torch.float, %arg244: !torch.int, %arg245: !torch.vtensor<[32],f32>, %arg246: !torch.vtensor<[32],f32>, %arg247: !torch.float, %arg248: !torch.vtensor<[32],f32>, %arg249: !torch.float, %arg250: !torch.int, %arg251: !torch.float, %arg252: !torch.int, %arg253: !torch.vtensor<[2,32],f32>, %arg254: !torch.vtensor<[2,32],f32>, %arg255: !torch.float, %arg256: !torch.vtensor<[2,32],f32>, %arg257: !torch.float, %arg258: !torch.int, %arg259: !torch.float, %arg260: !torch.int, %arg261: !torch.vtensor<[2],f32>, %arg262: !torch.vtensor<[2],f32>, %arg263: !torch.float, %arg264: !torch.vtensor<[2],f32>) -> (!torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32>) { %int0 = torch.constant.int 0 %int0_0 = torch.constant.int 0 - %int9223372036854775807 = torch.constant.int 9223372036854775807 %int1 = torch.constant.int 1 - %0 = torch.aten.slice.Tensor %arg9, %int0, %int0_0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64> + %int1_1 = torch.constant.int 1 + %0 = torch.aten.slice.Tensor %arg9, %int0, %int0_0, %int1, %int1_1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64> %int-1 = torch.constant.int -1 %false = torch.constant.bool false - %false_1 = torch.constant.bool false - %1 = torch.aten.embedding %arg10, %0, %int-1, %false, %false_1 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,32],f32> - %int-1_2 = torch.constant.int -1 - %false_3 = torch.constant.bool false + %false_2 = torch.constant.bool false + %1 = torch.aten.embedding %arg10, %0, %int-1, %false, %false_2 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,32],f32> + %int-1_3 = torch.constant.int -1 %false_4 = torch.constant.bool false - %2 = torch.aten.embedding %arg13, %arg12, %int-1_2, %false_3, %false_4 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> - %int0_5 = torch.constant.int 0 - %false_6 = torch.constant.bool false + %false_5 = torch.constant.bool false + %2 = torch.aten.embedding %arg13, %arg12, %int-1_3, %false_4, %false_5 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> + %int0_6 = torch.constant.int 0 %false_7 = torch.constant.bool false - %3 = torch.aten.embedding %arg14, %arg5, %int0_5, %false_6, %false_7 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> + %false_8 = torch.constant.bool false + %3 = torch.aten.embedding %arg14, %arg5, %int0_6, %false_7, %false_8 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> %4 = torch.aten.add.Tensor %3, %2, %arg11 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> %5 = torch.aten.add.Tensor %4, %1, %arg8 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[1,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> %int32 = torch.constant.int 32 @@ -23,1302 +23,1484 @@ func.func @graph(%arg0: !torch.float, %arg1: !torch.int, %arg2: !torch.vtensor<[ %float1.000000e-05 = torch.constant.float 1.000000e-05 %result0, %result1, %result2 = torch.aten.native_layer_norm %5, %6, %arg7, %arg6, %float1.000000e-05 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> %7 = torch.prim.TupleConstruct %result0, %result1, %result2 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> - %int1_8 = torch.constant.int 1 - %int0_9 = torch.constant.int 0 - %8 = torch.prim.ListConstruct %int1_8, %int0_9 : (!torch.int, !torch.int) -> !torch.list - %int1_10 = torch.constant.int 1 - %int0_11 = torch.constant.int 0 - %9 = torch.prim.ListConstruct %int1_10, %int0_11 : (!torch.int, !torch.int) -> !torch.list + %int1_9 = torch.constant.int 1 + %int0_10 = torch.constant.int 0 + %8 = torch.prim.ListConstruct %int1_9, %int0_10 : (!torch.int, !torch.int) -> !torch.list + %int1_11 = torch.constant.int 1 + %int0_12 = torch.constant.int 0 + %9 = torch.prim.ListConstruct %int1_11, %int0_12 : (!torch.int, !torch.int) -> !torch.list %10 = torch.aten.permute %arg16, %9 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_12 = torch.constant.int 1 - %int0_13 = torch.constant.int 0 - %11 = torch.prim.ListConstruct %int1_12, %int0_13 : (!torch.int, !torch.int) -> !torch.list - %int1_14 = torch.constant.int 1 - %int0_15 = torch.constant.int 0 - %12 = torch.prim.ListConstruct %int1_14, %int0_15 : (!torch.int, !torch.int) -> !torch.list + %int1_13 = torch.constant.int 1 + %int0_14 = torch.constant.int 0 + %11 = torch.prim.ListConstruct %int1_13, %int0_14 : (!torch.int, !torch.int) -> !torch.list + %int1_15 = torch.constant.int 1 + %int0_16 = torch.constant.int 0 + %12 = torch.prim.ListConstruct %int1_15, %int0_16 : (!torch.int, !torch.int) -> !torch.list %13 = torch.aten.permute %10, %12 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_16 = torch.constant.int 1 - %int0_17 = torch.constant.int 0 - %14 = torch.prim.ListConstruct %int1_16, %int0_17 : (!torch.int, !torch.int) -> !torch.list - %int1_18 = torch.constant.int 1 - %int0_19 = torch.constant.int 0 - %15 = torch.prim.ListConstruct %int1_18, %int0_19 : (!torch.int, !torch.int) -> !torch.list + %int1_17 = torch.constant.int 1 + %int0_18 = torch.constant.int 0 + %14 = torch.prim.ListConstruct %int1_17, %int0_18 : (!torch.int, !torch.int) -> !torch.list + %int1_19 = torch.constant.int 1 + %int0_20 = torch.constant.int 0 + %15 = torch.prim.ListConstruct %int1_19, %int0_20 : (!torch.int, !torch.int) -> !torch.list %16 = torch.aten.permute %arg19, %15 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> %int1024 = torch.constant.int 1024 - %int32_20 = torch.constant.int 32 - %17 = torch.prim.ListConstruct %int1024, %int32_20 : (!torch.int, !torch.int) -> !torch.list - %int1024_21 = torch.constant.int 1024 - %int32_22 = torch.constant.int 32 - %18 = torch.prim.ListConstruct %int1024_21, %int32_22 : (!torch.int, !torch.int) -> !torch.list + %int32_21 = torch.constant.int 32 + %17 = torch.prim.ListConstruct %int1024, %int32_21 : (!torch.int, !torch.int) -> !torch.list + %int1024_22 = torch.constant.int 1024 + %int32_23 = torch.constant.int 32 + %18 = torch.prim.ListConstruct %int1024_22, %int32_23 : (!torch.int, !torch.int) -> !torch.list %19 = torch.aten.reshape %result0, %18 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> %20 = torch.aten.addmm %arg20, %19, %16, %arg18, %arg17 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> %int2 = torch.constant.int 2 %int512 = torch.constant.int 512 - %int32_23 = torch.constant.int 32 - %21 = torch.prim.ListConstruct %int2, %int512, %int32_23 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_24 = torch.constant.int 2 - %int512_25 = torch.constant.int 512 - %int32_26 = torch.constant.int 32 - %22 = torch.prim.ListConstruct %int2_24, %int512_25, %int32_26 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int32_24 = torch.constant.int 32 + %21 = torch.prim.ListConstruct %int2, %int512, %int32_24 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_25 = torch.constant.int 2 + %int512_26 = torch.constant.int 512 + %int32_27 = torch.constant.int 32 + %22 = torch.prim.ListConstruct %int2_25, %int512_26, %int32_27 : (!torch.int, !torch.int, !torch.int) -> !torch.list %23 = torch.aten.reshape %20, %22 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int2_27 = torch.constant.int 2 - %int512_28 = torch.constant.int 512 + %int2_28 = torch.constant.int 2 + %int512_29 = torch.constant.int 512 %int16 = torch.constant.int 16 - %24 = torch.prim.ListConstruct %int2_27, %int512_28, %int2_27, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_29 = torch.constant.int 2 - %int512_30 = torch.constant.int 512 - %int16_31 = torch.constant.int 16 - %25 = torch.prim.ListConstruct %int2_29, %int512_30, %int2_29, %int16_31 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %24 = torch.prim.ListConstruct %int2_28, %int512_29, %int2_28, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_30 = torch.constant.int 2 + %int512_31 = torch.constant.int 512 + %int16_32 = torch.constant.int 16 + %25 = torch.prim.ListConstruct %int2_30, %int512_31, %int2_30, %int16_32 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %26 = torch.aten.reshape %23, %25 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_32 = torch.constant.int 0 - %int2_33 = torch.constant.int 2 - %int1_34 = torch.constant.int 1 + %int0_33 = torch.constant.int 0 + %int2_34 = torch.constant.int 2 + %int1_35 = torch.constant.int 1 %int3 = torch.constant.int 3 - %27 = torch.prim.ListConstruct %int0_32, %int2_33, %int1_34, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_35 = torch.constant.int 0 - %int2_36 = torch.constant.int 2 - %int1_37 = torch.constant.int 1 - %int3_38 = torch.constant.int 3 - %28 = torch.prim.ListConstruct %int0_35, %int2_36, %int1_37, %int3_38 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %27 = torch.prim.ListConstruct %int0_33, %int2_34, %int1_35, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_36 = torch.constant.int 0 + %int2_37 = torch.constant.int 2 + %int1_38 = torch.constant.int 1 + %int3_39 = torch.constant.int 3 + %28 = torch.prim.ListConstruct %int0_36, %int2_37, %int1_38, %int3_39 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %29 = torch.aten.permute %26, %28 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int-1_39 = torch.constant.int -1 - %int-2 = torch.constant.int -2 - %30 = torch.aten.transpose.int %29, %int-1_39, %int-2 : !torch.vtensor<[2,2,512,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,16,512],f32> - %int2_40 = torch.constant.int 2 - %int16_41 = torch.constant.int 16 - %int512_42 = torch.constant.int 512 - %31 = torch.prim.ListConstruct %int2_40, %int2_40, %int16_41, %int512_42 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_43 = torch.constant.bool false - %32 = torch.aten.expand %30, %31, %false_43 : !torch.vtensor<[2,2,16,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,16,512],f32> + %int0_40 = torch.constant.int 0 + %int1_41 = torch.constant.int 1 + %int3_42 = torch.constant.int 3 + %int2_43 = torch.constant.int 2 + %30 = torch.prim.ListConstruct %int0_40, %int1_41, %int3_42, %int2_43 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_44 = torch.constant.int 0 + %int1_45 = torch.constant.int 1 + %int3_46 = torch.constant.int 3 + %int2_47 = torch.constant.int 2 + %31 = torch.prim.ListConstruct %int0_44, %int1_45, %int3_46, %int2_47 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %32 = torch.aten.permute %29, %31 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,16,512],f32> + %int2_48 = torch.constant.int 2 + %int16_49 = torch.constant.int 16 + %int512_50 = torch.constant.int 512 + %33 = torch.prim.ListConstruct %int2_48, %int2_48, %int16_49, %int512_50 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_51 = torch.constant.bool false + %34 = torch.aten.expand %32, %33, %false_51 : !torch.vtensor<[2,2,16,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,16,512],f32> %int4 = torch.constant.int 4 - %int16_44 = torch.constant.int 16 - %int512_45 = torch.constant.int 512 - %33 = torch.prim.ListConstruct %int4, %int16_44, %int512_45 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_46 = torch.constant.int 4 - %int16_47 = torch.constant.int 16 - %int512_48 = torch.constant.int 512 - %34 = torch.prim.ListConstruct %int4_46, %int16_47, %int512_48 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %35 = torch.aten.reshape %32, %34 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> - %int1_49 = torch.constant.int 1 - %int2_50 = torch.constant.int 2 - %36 = torch.aten.transpose.int %35, %int1_49, %int2_50 : !torch.vtensor<[4,16,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,512,16],f32> - %int0_51 = torch.constant.int 0 - %int0_52 = torch.constant.int 0 - %int9223372036854775807_53 = torch.constant.int 9223372036854775807 - %int1_54 = torch.constant.int 1 - %37 = torch.aten.slice.Tensor %arg26, %int0_51, %int0_52, %int9223372036854775807_53, %int1_54 : !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512],si64> - %int2_55 = torch.constant.int 2 - %int1_56 = torch.constant.int 1 - %int512_57 = torch.constant.int 512 - %38 = torch.prim.ListConstruct %int2_55, %int1_56, %int512_57 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int16_52 = torch.constant.int 16 + %int512_53 = torch.constant.int 512 + %35 = torch.prim.ListConstruct %int4, %int16_52, %int512_53 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_54 = torch.constant.int 4 + %int16_55 = torch.constant.int 16 + %int512_56 = torch.constant.int 512 + %36 = torch.prim.ListConstruct %int4_54, %int16_55, %int512_56 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %37 = torch.aten.reshape %34, %36 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> + %int0_57 = torch.constant.int 0 %int2_58 = torch.constant.int 2 %int1_59 = torch.constant.int 1 - %int512_60 = torch.constant.int 512 - %39 = torch.prim.ListConstruct %int2_58, %int1_59, %int512_60 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %40 = torch.aten.reshape %37, %39 : !torch.vtensor<[2,512],si64>, !torch.list -> !torch.vtensor<[2,1,512],si64> + %38 = torch.prim.ListConstruct %int0_57, %int2_58, %int1_59 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int0_60 = torch.constant.int 0 %int2_61 = torch.constant.int 2 %int1_62 = torch.constant.int 1 - %int512_63 = torch.constant.int 512 - %41 = torch.prim.ListConstruct %int2_61, %int1_62, %int1_62, %int512_63 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_64 = torch.constant.int 2 - %int1_65 = torch.constant.int 1 - %int512_66 = torch.constant.int 512 - %42 = torch.prim.ListConstruct %int2_64, %int1_65, %int1_65, %int512_66 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %43 = torch.aten.reshape %40, %42 : !torch.vtensor<[2,1,512],si64>, !torch.list -> !torch.vtensor<[2,1,1,512],si64> - %int3_67 = torch.constant.int 3 - %int0_68 = torch.constant.int 0 - %int9223372036854775807_69 = torch.constant.int 9223372036854775807 - %int1_70 = torch.constant.int 1 - %44 = torch.aten.slice.Tensor %43, %int3_67, %int0_68, %int9223372036854775807_69, %int1_70 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,1,512],si64> + %39 = torch.prim.ListConstruct %int0_60, %int2_61, %int1_62 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %40 = torch.aten.permute %37, %39 : !torch.vtensor<[4,16,512],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %int0_63 = torch.constant.int 0 + %int0_64 = torch.constant.int 0 + %int2_65 = torch.constant.int 2 + %int1_66 = torch.constant.int 1 + %41 = torch.aten.slice.Tensor %arg25, %int0_63, %int0_64, %int2_65, %int1_66 : !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512],si64> + %int2_67 = torch.constant.int 2 + %int1_68 = torch.constant.int 1 + %int512_69 = torch.constant.int 512 + %42 = torch.prim.ListConstruct %int2_67, %int1_68, %int512_69 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_70 = torch.constant.int 2 + %int1_71 = torch.constant.int 1 + %int512_72 = torch.constant.int 512 + %43 = torch.prim.ListConstruct %int2_70, %int1_71, %int512_72 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %44 = torch.aten.reshape %41, %43 : !torch.vtensor<[2,512],si64>, !torch.list -> !torch.vtensor<[2,1,512],si64> + %int2_73 = torch.constant.int 2 + %int1_74 = torch.constant.int 1 + %int512_75 = torch.constant.int 512 + %45 = torch.prim.ListConstruct %int2_73, %int1_74, %int1_74, %int512_75 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_76 = torch.constant.int 2 + %int1_77 = torch.constant.int 1 + %int512_78 = torch.constant.int 512 + %46 = torch.prim.ListConstruct %int2_76, %int1_77, %int1_77, %int512_78 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %47 = torch.aten.reshape %44, %46 : !torch.vtensor<[2,1,512],si64>, !torch.list -> !torch.vtensor<[2,1,1,512],si64> + %int3_79 = torch.constant.int 3 + %int0_80 = torch.constant.int 0 + %int512_81 = torch.constant.int 512 + %int1_82 = torch.constant.int 1 + %48 = torch.aten.slice.Tensor %47, %int3_79, %int0_80, %int512_81, %int1_82 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,1,512],si64> %int6 = torch.constant.int 6 %none = torch.constant.none - %none_71 = torch.constant.none - %none_72 = torch.constant.none - %false_73 = torch.constant.bool false - %none_74 = torch.constant.none - %45 = torch.aten._to_copy %44, %int6, %none, %none_71, %none_72, %false_73, %none_74 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[2,1,1,512],f32> - %46 = torch.aten.rsub.Scalar %45, %arg25, %arg24 : !torch.vtensor<[2,1,1,512],f32>, !torch.float, !torch.int -> !torch.vtensor<[2,1,1,512],f32> - %47 = torch.aten.mul.Tensor %46, %arg23 : !torch.vtensor<[2,1,1,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,1,1,512],f32> - %int1_75 = torch.constant.int 1 - %int0_76 = torch.constant.int 0 - %48 = torch.prim.ListConstruct %int1_75, %int0_76 : (!torch.int, !torch.int) -> !torch.list - %int1_77 = torch.constant.int 1 - %int0_78 = torch.constant.int 0 - %49 = torch.prim.ListConstruct %int1_77, %int0_78 : (!torch.int, !torch.int) -> !torch.list - %50 = torch.aten.permute %arg16, %49 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1024_79 = torch.constant.int 1024 - %int32_80 = torch.constant.int 32 - %51 = torch.prim.ListConstruct %int1024_79, %int32_80 : (!torch.int, !torch.int) -> !torch.list - %int1024_81 = torch.constant.int 1024 - %int32_82 = torch.constant.int 32 - %52 = torch.prim.ListConstruct %int1024_81, %int32_82 : (!torch.int, !torch.int) -> !torch.list - %53 = torch.aten.reshape %result0, %52 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %54 = torch.aten.addmm %arg29, %53, %50, %arg28, %arg27 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_83 = torch.constant.int 2 - %int512_84 = torch.constant.int 512 - %int32_85 = torch.constant.int 32 - %55 = torch.prim.ListConstruct %int2_83, %int512_84, %int32_85 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_86 = torch.constant.int 2 - %int512_87 = torch.constant.int 512 - %int32_88 = torch.constant.int 32 - %56 = torch.prim.ListConstruct %int2_86, %int512_87, %int32_88 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %57 = torch.aten.reshape %54, %56 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int2_89 = torch.constant.int 2 - %int512_90 = torch.constant.int 512 + %none_83 = torch.constant.none + %none_84 = torch.constant.none + %false_85 = torch.constant.bool false + %none_86 = torch.constant.none + %49 = torch.aten._to_copy %48, %int6, %none, %none_83, %none_84, %false_85, %none_86 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[2,1,1,512],f32> + %50 = torch.aten.sub.Tensor %arg26, %49, %arg24 : !torch.vtensor<[],f64>, !torch.vtensor<[2,1,1,512],f32>, !torch.int -> !torch.vtensor<[2,1,1,512],f32> + %51 = torch.aten.mul.Tensor %50, %arg23 : !torch.vtensor<[2,1,1,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,1,1,512],f32> + %int4_87 = torch.constant.int 4 + %int16_88 = torch.constant.int 16 + %int512_89 = torch.constant.int 512 + %52 = torch.prim.ListConstruct %int4_87, %int16_88, %int512_89 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_90 = torch.constant.int 4 %int16_91 = torch.constant.int 16 - %58 = torch.prim.ListConstruct %int2_89, %int512_90, %int2_89, %int16_91 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_92 = torch.constant.int 2 - %int512_93 = torch.constant.int 512 - %int16_94 = torch.constant.int 16 - %59 = torch.prim.ListConstruct %int2_92, %int512_93, %int2_92, %int16_94 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %60 = torch.aten.reshape %57, %59 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_95 = torch.constant.int 0 - %int2_96 = torch.constant.int 2 - %int1_97 = torch.constant.int 1 - %int3_98 = torch.constant.int 3 - %61 = torch.prim.ListConstruct %int0_95, %int2_96, %int1_97, %int3_98 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_99 = torch.constant.int 0 - %int2_100 = torch.constant.int 2 - %int1_101 = torch.constant.int 1 - %int3_102 = torch.constant.int 3 - %62 = torch.prim.ListConstruct %int0_99, %int2_100, %int1_101, %int3_102 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %63 = torch.aten.permute %60, %62 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int2_103 = torch.constant.int 2 - %int512_104 = torch.constant.int 512 - %int16_105 = torch.constant.int 16 - %64 = torch.prim.ListConstruct %int2_103, %int2_103, %int512_104, %int16_105 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_106 = torch.constant.bool false - %65 = torch.aten.expand %63, %64, %false_106 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> - %int4_107 = torch.constant.int 4 + %int512_92 = torch.constant.int 512 + %53 = torch.prim.ListConstruct %int4_90, %int16_91, %int512_92 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %54 = torch.aten.reshape %34, %53 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> + %int1_93 = torch.constant.int 1 + %int0_94 = torch.constant.int 0 + %55 = torch.prim.ListConstruct %int1_93, %int0_94 : (!torch.int, !torch.int) -> !torch.list + %int1_95 = torch.constant.int 1 + %int0_96 = torch.constant.int 0 + %56 = torch.prim.ListConstruct %int1_95, %int0_96 : (!torch.int, !torch.int) -> !torch.list + %57 = torch.aten.permute %arg16, %56 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1024_97 = torch.constant.int 1024 + %int32_98 = torch.constant.int 32 + %58 = torch.prim.ListConstruct %int1024_97, %int32_98 : (!torch.int, !torch.int) -> !torch.list + %int1024_99 = torch.constant.int 1024 + %int32_100 = torch.constant.int 32 + %59 = torch.prim.ListConstruct %int1024_99, %int32_100 : (!torch.int, !torch.int) -> !torch.list + %60 = torch.aten.reshape %result0, %59 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %61 = torch.aten.addmm %arg29, %60, %57, %arg28, %arg27 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_101 = torch.constant.int 2 + %int512_102 = torch.constant.int 512 + %int32_103 = torch.constant.int 32 + %62 = torch.prim.ListConstruct %int2_101, %int512_102, %int32_103 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_104 = torch.constant.int 2 + %int512_105 = torch.constant.int 512 + %int32_106 = torch.constant.int 32 + %63 = torch.prim.ListConstruct %int2_104, %int512_105, %int32_106 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %64 = torch.aten.reshape %61, %63 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int2_107 = torch.constant.int 2 %int512_108 = torch.constant.int 512 %int16_109 = torch.constant.int 16 - %66 = torch.prim.ListConstruct %int4_107, %int512_108, %int16_109 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_110 = torch.constant.int 4 + %65 = torch.prim.ListConstruct %int2_107, %int512_108, %int2_107, %int16_109 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_110 = torch.constant.int 2 %int512_111 = torch.constant.int 512 %int16_112 = torch.constant.int 16 - %67 = torch.prim.ListConstruct %int4_110, %int512_111, %int16_112 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %68 = torch.aten.reshape %65, %67 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %69 = torch.aten.bmm %68, %35 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> - %int2_113 = torch.constant.int 2 - %int512_114 = torch.constant.int 512 - %70 = torch.prim.ListConstruct %int2_113, %int2_113, %int512_114, %int512_114 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_115 = torch.constant.int 2 - %int512_116 = torch.constant.int 512 - %71 = torch.prim.ListConstruct %int2_115, %int2_115, %int512_116, %int512_116 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %72 = torch.aten._unsafe_view %69, %71 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> - %73 = torch.aten.div.Tensor %72, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> - %74 = torch.aten.add.Tensor %73, %47, %arg22 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,1,1,512],f32>, !torch.int -> !torch.vtensor<[2,2,512,512],f32> - %int-1_117 = torch.constant.int -1 - %false_118 = torch.constant.bool false - %75 = torch.aten._softmax %74, %int-1_117, %false_118 : !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> + %66 = torch.prim.ListConstruct %int2_110, %int512_111, %int2_110, %int16_112 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %67 = torch.aten.reshape %64, %66 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int0_113 = torch.constant.int 0 + %int2_114 = torch.constant.int 2 + %int1_115 = torch.constant.int 1 + %int3_116 = torch.constant.int 3 + %68 = torch.prim.ListConstruct %int0_113, %int2_114, %int1_115, %int3_116 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_117 = torch.constant.int 0 + %int2_118 = torch.constant.int 2 %int1_119 = torch.constant.int 1 - %int0_120 = torch.constant.int 0 - %76 = torch.prim.ListConstruct %int1_119, %int0_120 : (!torch.int, !torch.int) -> !torch.list - %int1_121 = torch.constant.int 1 - %int0_122 = torch.constant.int 0 - %77 = torch.prim.ListConstruct %int1_121, %int0_122 : (!torch.int, !torch.int) -> !torch.list - %78 = torch.aten.permute %arg32, %77 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1024_123 = torch.constant.int 1024 - %int32_124 = torch.constant.int 32 - %79 = torch.prim.ListConstruct %int1024_123, %int32_124 : (!torch.int, !torch.int) -> !torch.list - %int1024_125 = torch.constant.int 1024 - %int32_126 = torch.constant.int 32 - %80 = torch.prim.ListConstruct %int1024_125, %int32_126 : (!torch.int, !torch.int) -> !torch.list - %81 = torch.aten.reshape %result0, %80 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %82 = torch.aten.addmm %arg33, %81, %78, %arg31, %arg30 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_127 = torch.constant.int 2 - %int512_128 = torch.constant.int 512 - %int32_129 = torch.constant.int 32 - %83 = torch.prim.ListConstruct %int2_127, %int512_128, %int32_129 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_130 = torch.constant.int 2 - %int512_131 = torch.constant.int 512 - %int32_132 = torch.constant.int 32 - %84 = torch.prim.ListConstruct %int2_130, %int512_131, %int32_132 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %85 = torch.aten.reshape %82, %84 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int3_120 = torch.constant.int 3 + %69 = torch.prim.ListConstruct %int0_117, %int2_118, %int1_119, %int3_120 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %70 = torch.aten.permute %67, %69 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int2_121 = torch.constant.int 2 + %int512_122 = torch.constant.int 512 + %int16_123 = torch.constant.int 16 + %71 = torch.prim.ListConstruct %int2_121, %int2_121, %int512_122, %int16_123 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_124 = torch.constant.bool false + %72 = torch.aten.expand %70, %71, %false_124 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> + %int4_125 = torch.constant.int 4 + %int512_126 = torch.constant.int 512 + %int16_127 = torch.constant.int 16 + %73 = torch.prim.ListConstruct %int4_125, %int512_126, %int16_127 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_128 = torch.constant.int 4 + %int512_129 = torch.constant.int 512 + %int16_130 = torch.constant.int 16 + %74 = torch.prim.ListConstruct %int4_128, %int512_129, %int16_130 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %75 = torch.aten.reshape %72, %74 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %76 = torch.aten.bmm %75, %54 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> + %int2_131 = torch.constant.int 2 + %int512_132 = torch.constant.int 512 + %77 = torch.prim.ListConstruct %int2_131, %int2_131, %int512_132, %int512_132 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %int2_133 = torch.constant.int 2 %int512_134 = torch.constant.int 512 - %int16_135 = torch.constant.int 16 - %86 = torch.prim.ListConstruct %int2_133, %int512_134, %int2_133, %int16_135 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_136 = torch.constant.int 2 - %int512_137 = torch.constant.int 512 - %int16_138 = torch.constant.int 16 - %87 = torch.prim.ListConstruct %int2_136, %int512_137, %int2_136, %int16_138 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %88 = torch.aten.reshape %85, %87 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_139 = torch.constant.int 0 - %int2_140 = torch.constant.int 2 - %int1_141 = torch.constant.int 1 - %int3_142 = torch.constant.int 3 - %89 = torch.prim.ListConstruct %int0_139, %int2_140, %int1_141, %int3_142 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_143 = torch.constant.int 0 - %int2_144 = torch.constant.int 2 - %int1_145 = torch.constant.int 1 - %int3_146 = torch.constant.int 3 - %90 = torch.prim.ListConstruct %int0_143, %int2_144, %int1_145, %int3_146 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %91 = torch.aten.permute %88, %90 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int2_147 = torch.constant.int 2 - %int512_148 = torch.constant.int 512 - %int16_149 = torch.constant.int 16 - %92 = torch.prim.ListConstruct %int2_147, %int2_147, %int512_148, %int16_149 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_150 = torch.constant.bool false - %93 = torch.aten.expand %91, %92, %false_150 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> - %int4_151 = torch.constant.int 4 + %78 = torch.prim.ListConstruct %int2_133, %int2_133, %int512_134, %int512_134 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %79 = torch.aten.reshape %76, %78 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> + %80 = torch.aten.div.Tensor %79, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> + %81 = torch.aten.add.Tensor %80, %51, %arg22 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,1,1,512],f32>, !torch.int -> !torch.vtensor<[2,2,512,512],f32> + %int-1_135 = torch.constant.int -1 + %false_136 = torch.constant.bool false + %82 = torch.aten._softmax %81, %int-1_135, %false_136 : !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> + %int1_137 = torch.constant.int 1 + %int0_138 = torch.constant.int 0 + %83 = torch.prim.ListConstruct %int1_137, %int0_138 : (!torch.int, !torch.int) -> !torch.list + %int1_139 = torch.constant.int 1 + %int0_140 = torch.constant.int 0 + %84 = torch.prim.ListConstruct %int1_139, %int0_140 : (!torch.int, !torch.int) -> !torch.list + %85 = torch.aten.permute %arg32, %84 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1024_141 = torch.constant.int 1024 + %int32_142 = torch.constant.int 32 + %86 = torch.prim.ListConstruct %int1024_141, %int32_142 : (!torch.int, !torch.int) -> !torch.list + %int1024_143 = torch.constant.int 1024 + %int32_144 = torch.constant.int 32 + %87 = torch.prim.ListConstruct %int1024_143, %int32_144 : (!torch.int, !torch.int) -> !torch.list + %88 = torch.aten.reshape %result0, %87 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %89 = torch.aten.addmm %arg33, %88, %85, %arg31, %arg30 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_145 = torch.constant.int 2 + %int512_146 = torch.constant.int 512 + %int32_147 = torch.constant.int 32 + %90 = torch.prim.ListConstruct %int2_145, %int512_146, %int32_147 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_148 = torch.constant.int 2 + %int512_149 = torch.constant.int 512 + %int32_150 = torch.constant.int 32 + %91 = torch.prim.ListConstruct %int2_148, %int512_149, %int32_150 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %92 = torch.aten.reshape %89, %91 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int2_151 = torch.constant.int 2 %int512_152 = torch.constant.int 512 %int16_153 = torch.constant.int 16 - %94 = torch.prim.ListConstruct %int4_151, %int512_152, %int16_153 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_154 = torch.constant.int 4 + %93 = torch.prim.ListConstruct %int2_151, %int512_152, %int2_151, %int16_153 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_154 = torch.constant.int 2 %int512_155 = torch.constant.int 512 %int16_156 = torch.constant.int 16 - %95 = torch.prim.ListConstruct %int4_154, %int512_155, %int16_156 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %96 = torch.aten.reshape %93, %95 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %int1_157 = torch.constant.int 1 + %94 = torch.prim.ListConstruct %int2_154, %int512_155, %int2_154, %int16_156 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %95 = torch.aten.reshape %92, %94 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int0_157 = torch.constant.int 0 %int2_158 = torch.constant.int 2 - %97 = torch.aten.transpose.int %96, %int1_157, %int2_158 : !torch.vtensor<[4,512,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,16,512],f32> %int1_159 = torch.constant.int 1 - %int0_160 = torch.constant.int 0 - %98 = torch.prim.ListConstruct %int1_159, %int0_160 : (!torch.int, !torch.int) -> !torch.list - %int1_161 = torch.constant.int 1 - %int0_162 = torch.constant.int 0 - %99 = torch.prim.ListConstruct %int1_161, %int0_162 : (!torch.int, !torch.int) -> !torch.list - %100 = torch.aten.permute %arg34, %99 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int3_160 = torch.constant.int 3 + %96 = torch.prim.ListConstruct %int0_157, %int2_158, %int1_159, %int3_160 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_161 = torch.constant.int 0 + %int2_162 = torch.constant.int 2 %int1_163 = torch.constant.int 1 - %int0_164 = torch.constant.int 0 - %101 = torch.prim.ListConstruct %int1_163, %int0_164 : (!torch.int, !torch.int) -> !torch.list - %int1_165 = torch.constant.int 1 - %int0_166 = torch.constant.int 0 - %102 = torch.prim.ListConstruct %int1_165, %int0_166 : (!torch.int, !torch.int) -> !torch.list - %103 = torch.aten.permute %100, %102 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_167 = torch.constant.int 1 - %int0_168 = torch.constant.int 0 - %104 = torch.prim.ListConstruct %int1_167, %int0_168 : (!torch.int, !torch.int) -> !torch.list - %int1_169 = torch.constant.int 1 - %int0_170 = torch.constant.int 0 - %105 = torch.prim.ListConstruct %int1_169, %int0_170 : (!torch.int, !torch.int) -> !torch.list - %106 = torch.aten.permute %arg34, %105 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int2_171 = torch.constant.int 2 - %int512_172 = torch.constant.int 512 - %107 = torch.prim.ListConstruct %int2_171, %int2_171, %int512_172, %int512_172 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_173 = torch.constant.bool false - %108 = torch.aten.expand %75, %107, %false_173 : !torch.vtensor<[2,2,512,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> - %int4_174 = torch.constant.int 4 - %int512_175 = torch.constant.int 512 - %109 = torch.prim.ListConstruct %int4_174, %int512_175, %int512_175 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_176 = torch.constant.int 4 - %int512_177 = torch.constant.int 512 - %110 = torch.prim.ListConstruct %int4_176, %int512_177, %int512_177 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %111 = torch.aten.reshape %108, %110 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> - %112 = torch.aten.bmm %111, %96 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> - %int2_178 = torch.constant.int 2 - %int512_179 = torch.constant.int 512 - %int16_180 = torch.constant.int 16 - %113 = torch.prim.ListConstruct %int2_178, %int2_178, %int512_179, %int16_180 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_181 = torch.constant.int 2 - %int512_182 = torch.constant.int 512 - %int16_183 = torch.constant.int 16 - %114 = torch.prim.ListConstruct %int2_181, %int2_181, %int512_182, %int16_183 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %115 = torch.aten._unsafe_view %112, %114 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int3_164 = torch.constant.int 3 + %97 = torch.prim.ListConstruct %int0_161, %int2_162, %int1_163, %int3_164 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %98 = torch.aten.permute %95, %97 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int2_165 = torch.constant.int 2 + %int512_166 = torch.constant.int 512 + %int16_167 = torch.constant.int 16 + %99 = torch.prim.ListConstruct %int2_165, %int2_165, %int512_166, %int16_167 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_168 = torch.constant.bool false + %100 = torch.aten.expand %98, %99, %false_168 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> + %int4_169 = torch.constant.int 4 + %int512_170 = torch.constant.int 512 + %int16_171 = torch.constant.int 16 + %101 = torch.prim.ListConstruct %int4_169, %int512_170, %int16_171 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_172 = torch.constant.int 4 + %int512_173 = torch.constant.int 512 + %int16_174 = torch.constant.int 16 + %102 = torch.prim.ListConstruct %int4_172, %int512_173, %int16_174 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %103 = torch.aten.reshape %100, %102 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %int0_175 = torch.constant.int 0 + %int2_176 = torch.constant.int 2 + %int1_177 = torch.constant.int 1 + %104 = torch.prim.ListConstruct %int0_175, %int2_176, %int1_177 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int0_178 = torch.constant.int 0 + %int2_179 = torch.constant.int 2 + %int1_180 = torch.constant.int 1 + %105 = torch.prim.ListConstruct %int0_178, %int2_179, %int1_180 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %106 = torch.aten.permute %103, %105 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> + %int1_181 = torch.constant.int 1 + %int0_182 = torch.constant.int 0 + %107 = torch.prim.ListConstruct %int1_181, %int0_182 : (!torch.int, !torch.int) -> !torch.list + %int1_183 = torch.constant.int 1 %int0_184 = torch.constant.int 0 - %int2_185 = torch.constant.int 2 - %int1_186 = torch.constant.int 1 - %int3_187 = torch.constant.int 3 - %116 = torch.prim.ListConstruct %int0_184, %int2_185, %int1_186, %int3_187 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %108 = torch.prim.ListConstruct %int1_183, %int0_184 : (!torch.int, !torch.int) -> !torch.list + %109 = torch.aten.permute %arg34, %108 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_185 = torch.constant.int 1 + %int0_186 = torch.constant.int 0 + %110 = torch.prim.ListConstruct %int1_185, %int0_186 : (!torch.int, !torch.int) -> !torch.list + %int1_187 = torch.constant.int 1 %int0_188 = torch.constant.int 0 - %int2_189 = torch.constant.int 2 - %int1_190 = torch.constant.int 1 - %int3_191 = torch.constant.int 3 - %117 = torch.prim.ListConstruct %int0_188, %int2_189, %int1_190, %int3_191 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %118 = torch.aten.permute %115, %117 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_192 = torch.constant.int 2 - %int512_193 = torch.constant.int 512 - %int32_194 = torch.constant.int 32 - %119 = torch.prim.ListConstruct %int2_192, %int512_193, %int32_194 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_195 = torch.constant.int 2 - %int512_196 = torch.constant.int 512 - %int32_197 = torch.constant.int 32 - %120 = torch.prim.ListConstruct %int2_195, %int512_196, %int32_197 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %121 = torch.aten.reshape %118, %120 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_198 = torch.constant.int 1024 - %int32_199 = torch.constant.int 32 - %122 = torch.prim.ListConstruct %int1024_198, %int32_199 : (!torch.int, !torch.int) -> !torch.list - %int1024_200 = torch.constant.int 1024 - %int32_201 = torch.constant.int 32 - %123 = torch.prim.ListConstruct %int1024_200, %int32_201 : (!torch.int, !torch.int) -> !torch.list - %124 = torch.aten.reshape %121, %123 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %125 = torch.aten.addmm %arg40, %124, %106, %arg39, %arg38 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_202 = torch.constant.int 2 + %111 = torch.prim.ListConstruct %int1_187, %int0_188 : (!torch.int, !torch.int) -> !torch.list + %112 = torch.aten.permute %109, %111 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_189 = torch.constant.int 1 + %int0_190 = torch.constant.int 0 + %113 = torch.prim.ListConstruct %int1_189, %int0_190 : (!torch.int, !torch.int) -> !torch.list + %int1_191 = torch.constant.int 1 + %int0_192 = torch.constant.int 0 + %114 = torch.prim.ListConstruct %int1_191, %int0_192 : (!torch.int, !torch.int) -> !torch.list + %115 = torch.aten.permute %arg34, %114 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int4_193 = torch.constant.int 4 + %int512_194 = torch.constant.int 512 + %int16_195 = torch.constant.int 16 + %116 = torch.prim.ListConstruct %int4_193, %int512_194, %int16_195 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_196 = torch.constant.int 4 + %int512_197 = torch.constant.int 512 + %int16_198 = torch.constant.int 16 + %117 = torch.prim.ListConstruct %int4_196, %int512_197, %int16_198 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %118 = torch.aten.reshape %100, %117 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %int2_199 = torch.constant.int 2 + %int512_200 = torch.constant.int 512 + %119 = torch.prim.ListConstruct %int2_199, %int2_199, %int512_200, %int512_200 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_201 = torch.constant.bool false + %120 = torch.aten.expand %82, %119, %false_201 : !torch.vtensor<[2,2,512,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> + %int4_202 = torch.constant.int 4 %int512_203 = torch.constant.int 512 - %int32_204 = torch.constant.int 32 - %126 = torch.prim.ListConstruct %int2_202, %int512_203, %int32_204 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_205 = torch.constant.int 2 - %int512_206 = torch.constant.int 512 - %int32_207 = torch.constant.int 32 - %127 = torch.prim.ListConstruct %int2_205, %int512_206, %int32_207 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %128 = torch.aten.reshape %125, %127 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %129 = torch.aten.add.Tensor %128, %result0, %arg37 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_208 = torch.constant.int 32 - %130 = torch.prim.ListConstruct %int32_208 : (!torch.int) -> !torch.list - %float1.000000e-05_209 = torch.constant.float 1.000000e-05 - %result0_210, %result1_211, %result2_212 = torch.aten.native_layer_norm %129, %130, %arg36, %arg35, %float1.000000e-05_209 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> - %131 = torch.prim.TupleConstruct %result0_210, %result1_211, %result2_212 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> - %int1_213 = torch.constant.int 1 - %int0_214 = torch.constant.int 0 - %132 = torch.prim.ListConstruct %int1_213, %int0_214 : (!torch.int, !torch.int) -> !torch.list - %int1_215 = torch.constant.int 1 + %121 = torch.prim.ListConstruct %int4_202, %int512_203, %int512_203 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_204 = torch.constant.int 4 + %int512_205 = torch.constant.int 512 + %122 = torch.prim.ListConstruct %int4_204, %int512_205, %int512_205 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %123 = torch.aten.reshape %120, %122 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> + %124 = torch.aten.bmm %123, %118 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> + %int2_206 = torch.constant.int 2 + %int512_207 = torch.constant.int 512 + %int16_208 = torch.constant.int 16 + %125 = torch.prim.ListConstruct %int2_206, %int2_206, %int512_207, %int16_208 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_209 = torch.constant.int 2 + %int512_210 = torch.constant.int 512 + %int16_211 = torch.constant.int 16 + %126 = torch.prim.ListConstruct %int2_209, %int2_209, %int512_210, %int16_211 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %127 = torch.aten.reshape %124, %126 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int0_212 = torch.constant.int 0 + %int2_213 = torch.constant.int 2 + %int1_214 = torch.constant.int 1 + %int3_215 = torch.constant.int 3 + %128 = torch.prim.ListConstruct %int0_212, %int2_213, %int1_214, %int3_215 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %int0_216 = torch.constant.int 0 - %133 = torch.prim.ListConstruct %int1_215, %int0_216 : (!torch.int, !torch.int) -> !torch.list - %134 = torch.aten.permute %arg42, %133 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_217 = torch.constant.int 1 - %int0_218 = torch.constant.int 0 - %135 = torch.prim.ListConstruct %int1_217, %int0_218 : (!torch.int, !torch.int) -> !torch.list - %int1_219 = torch.constant.int 1 - %int0_220 = torch.constant.int 0 - %136 = torch.prim.ListConstruct %int1_219, %int0_220 : (!torch.int, !torch.int) -> !torch.list - %137 = torch.aten.permute %134, %136 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_221 = torch.constant.int 1 - %int0_222 = torch.constant.int 0 - %138 = torch.prim.ListConstruct %int1_221, %int0_222 : (!torch.int, !torch.int) -> !torch.list - %int1_223 = torch.constant.int 1 - %int0_224 = torch.constant.int 0 - %139 = torch.prim.ListConstruct %int1_223, %int0_224 : (!torch.int, !torch.int) -> !torch.list - %140 = torch.aten.permute %arg42, %139 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1024_225 = torch.constant.int 1024 - %int32_226 = torch.constant.int 32 - %141 = torch.prim.ListConstruct %int1024_225, %int32_226 : (!torch.int, !torch.int) -> !torch.list - %int1024_227 = torch.constant.int 1024 - %int32_228 = torch.constant.int 32 - %142 = torch.prim.ListConstruct %int1024_227, %int32_228 : (!torch.int, !torch.int) -> !torch.list - %143 = torch.aten.reshape %result0_210, %142 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %144 = torch.aten.addmm %arg45, %143, %140, %arg44, %arg43 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_229 = torch.constant.int 2 - %int512_230 = torch.constant.int 512 - %int32_231 = torch.constant.int 32 - %145 = torch.prim.ListConstruct %int2_229, %int512_230, %int32_231 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_232 = torch.constant.int 2 - %int512_233 = torch.constant.int 512 - %int32_234 = torch.constant.int 32 - %146 = torch.prim.ListConstruct %int2_232, %int512_233, %int32_234 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %147 = torch.aten.reshape %144, %146 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1_235 = torch.constant.int 1 - %int0_236 = torch.constant.int 0 - %148 = torch.prim.ListConstruct %int1_235, %int0_236 : (!torch.int, !torch.int) -> !torch.list - %int1_237 = torch.constant.int 1 - %int0_238 = torch.constant.int 0 - %149 = torch.prim.ListConstruct %int1_237, %int0_238 : (!torch.int, !torch.int) -> !torch.list - %150 = torch.aten.permute %arg46, %149 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_239 = torch.constant.int 1 - %int0_240 = torch.constant.int 0 - %151 = torch.prim.ListConstruct %int1_239, %int0_240 : (!torch.int, !torch.int) -> !torch.list + %int2_217 = torch.constant.int 2 + %int1_218 = torch.constant.int 1 + %int3_219 = torch.constant.int 3 + %129 = torch.prim.ListConstruct %int0_216, %int2_217, %int1_218, %int3_219 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %130 = torch.aten.permute %127, %129 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_220 = torch.constant.int 2 + %int512_221 = torch.constant.int 512 + %int32_222 = torch.constant.int 32 + %131 = torch.prim.ListConstruct %int2_220, %int512_221, %int32_222 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_223 = torch.constant.int 2 + %int512_224 = torch.constant.int 512 + %int32_225 = torch.constant.int 32 + %132 = torch.prim.ListConstruct %int2_223, %int512_224, %int32_225 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %133 = torch.aten.reshape %130, %132 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_226 = torch.constant.int 1024 + %int32_227 = torch.constant.int 32 + %134 = torch.prim.ListConstruct %int1024_226, %int32_227 : (!torch.int, !torch.int) -> !torch.list + %int1024_228 = torch.constant.int 1024 + %int32_229 = torch.constant.int 32 + %135 = torch.prim.ListConstruct %int1024_228, %int32_229 : (!torch.int, !torch.int) -> !torch.list + %136 = torch.aten.reshape %133, %135 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %137 = torch.aten.addmm %arg40, %136, %115, %arg39, %arg38 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_230 = torch.constant.int 2 + %int512_231 = torch.constant.int 512 + %int32_232 = torch.constant.int 32 + %138 = torch.prim.ListConstruct %int2_230, %int512_231, %int32_232 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_233 = torch.constant.int 2 + %int512_234 = torch.constant.int 512 + %int32_235 = torch.constant.int 32 + %139 = torch.prim.ListConstruct %int2_233, %int512_234, %int32_235 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %140 = torch.aten.reshape %137, %139 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %141 = torch.aten.add.Tensor %140, %result0, %arg37 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32_236 = torch.constant.int 32 + %142 = torch.prim.ListConstruct %int32_236 : (!torch.int) -> !torch.list + %float1.000000e-05_237 = torch.constant.float 1.000000e-05 + %result0_238, %result1_239, %result2_240 = torch.aten.native_layer_norm %141, %142, %arg36, %arg35, %float1.000000e-05_237 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> + %143 = torch.prim.TupleConstruct %result0_238, %result1_239, %result2_240 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> %int1_241 = torch.constant.int 1 %int0_242 = torch.constant.int 0 - %152 = torch.prim.ListConstruct %int1_241, %int0_242 : (!torch.int, !torch.int) -> !torch.list - %153 = torch.aten.permute %150, %152 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %144 = torch.prim.ListConstruct %int1_241, %int0_242 : (!torch.int, !torch.int) -> !torch.list %int1_243 = torch.constant.int 1 %int0_244 = torch.constant.int 0 - %154 = torch.prim.ListConstruct %int1_243, %int0_244 : (!torch.int, !torch.int) -> !torch.list + %145 = torch.prim.ListConstruct %int1_243, %int0_244 : (!torch.int, !torch.int) -> !torch.list + %146 = torch.aten.permute %arg42, %145 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> %int1_245 = torch.constant.int 1 %int0_246 = torch.constant.int 0 - %155 = torch.prim.ListConstruct %int1_245, %int0_246 : (!torch.int, !torch.int) -> !torch.list - %156 = torch.aten.permute %arg46, %155 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %str = torch.constant.str "none" - %157 = torch.aten.gelu %147, %str : !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> - %int1024_247 = torch.constant.int 1024 - %int32_248 = torch.constant.int 32 - %158 = torch.prim.ListConstruct %int1024_247, %int32_248 : (!torch.int, !torch.int) -> !torch.list - %int1024_249 = torch.constant.int 1024 - %int32_250 = torch.constant.int 32 - %159 = torch.prim.ListConstruct %int1024_249, %int32_250 : (!torch.int, !torch.int) -> !torch.list - %160 = torch.aten.reshape %157, %159 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %161 = torch.aten.addmm %arg52, %160, %156, %arg51, %arg50 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_251 = torch.constant.int 2 - %int512_252 = torch.constant.int 512 - %int32_253 = torch.constant.int 32 - %162 = torch.prim.ListConstruct %int2_251, %int512_252, %int32_253 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_254 = torch.constant.int 2 - %int512_255 = torch.constant.int 512 + %147 = torch.prim.ListConstruct %int1_245, %int0_246 : (!torch.int, !torch.int) -> !torch.list + %int1_247 = torch.constant.int 1 + %int0_248 = torch.constant.int 0 + %148 = torch.prim.ListConstruct %int1_247, %int0_248 : (!torch.int, !torch.int) -> !torch.list + %149 = torch.aten.permute %146, %148 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_249 = torch.constant.int 1 + %int0_250 = torch.constant.int 0 + %150 = torch.prim.ListConstruct %int1_249, %int0_250 : (!torch.int, !torch.int) -> !torch.list + %int1_251 = torch.constant.int 1 + %int0_252 = torch.constant.int 0 + %151 = torch.prim.ListConstruct %int1_251, %int0_252 : (!torch.int, !torch.int) -> !torch.list + %152 = torch.aten.permute %arg42, %151 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1024_253 = torch.constant.int 1024 + %int32_254 = torch.constant.int 32 + %153 = torch.prim.ListConstruct %int1024_253, %int32_254 : (!torch.int, !torch.int) -> !torch.list + %int1024_255 = torch.constant.int 1024 %int32_256 = torch.constant.int 32 - %163 = torch.prim.ListConstruct %int2_254, %int512_255, %int32_256 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %164 = torch.aten.reshape %161, %163 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %165 = torch.aten.add.Tensor %164, %result0_210, %arg49 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_257 = torch.constant.int 32 - %166 = torch.prim.ListConstruct %int32_257 : (!torch.int) -> !torch.list - %float1.000000e-05_258 = torch.constant.float 1.000000e-05 - %result0_259, %result1_260, %result2_261 = torch.aten.native_layer_norm %165, %166, %arg48, %arg47, %float1.000000e-05_258 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> - %167 = torch.prim.TupleConstruct %result0_259, %result1_260, %result2_261 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> - %168 = torch.aten.zero.functional %arg53 : !torch.vtensor<[2,512,32],f32> -> !torch.vtensor<[2,512,32],f32> + %154 = torch.prim.ListConstruct %int1024_255, %int32_256 : (!torch.int, !torch.int) -> !torch.list + %155 = torch.aten.reshape %result0_238, %154 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %156 = torch.aten.addmm %arg45, %155, %152, %arg44, %arg43 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_257 = torch.constant.int 2 + %int512_258 = torch.constant.int 512 + %int32_259 = torch.constant.int 32 + %157 = torch.prim.ListConstruct %int2_257, %int512_258, %int32_259 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_260 = torch.constant.int 2 + %int512_261 = torch.constant.int 512 %int32_262 = torch.constant.int 32 - %169 = torch.prim.ListConstruct %int32_262 : (!torch.int) -> !torch.list - %true = torch.constant.bool true - %170 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %result0_263, %result1_264, %result2_265 = torch.aten.native_layer_norm_backward %168, %165, %169, %result1_260, %result2_261, %arg48, %arg47, %170 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> - %171 = torch.prim.TupleConstruct %result0_263, %result1_264, %result2_265 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> - %int1024_266 = torch.constant.int 1024 - %int32_267 = torch.constant.int 32 - %172 = torch.prim.ListConstruct %int1024_266, %int32_267 : (!torch.int, !torch.int) -> !torch.list - %int1024_268 = torch.constant.int 1024 - %int32_269 = torch.constant.int 32 - %173 = torch.prim.ListConstruct %int1024_268, %int32_269 : (!torch.int, !torch.int) -> !torch.list - %174 = torch.aten.reshape %result0_263, %173 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %175 = torch.aten.mm %174, %153 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_270 = torch.constant.int 2 - %int512_271 = torch.constant.int 512 - %int32_272 = torch.constant.int 32 - %176 = torch.prim.ListConstruct %int2_270, %int512_271, %int32_272 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_273 = torch.constant.int 2 - %int512_274 = torch.constant.int 512 - %int32_275 = torch.constant.int 32 - %177 = torch.prim.ListConstruct %int2_273, %int512_274, %int32_275 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %178 = torch.aten.reshape %175, %177 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %str_276 = torch.constant.str "none" - %179 = torch.aten.gelu_backward %178, %147, %str_276 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> + %158 = torch.prim.ListConstruct %int2_260, %int512_261, %int32_262 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %159 = torch.aten.reshape %156, %158 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1_263 = torch.constant.int 1 + %int0_264 = torch.constant.int 0 + %160 = torch.prim.ListConstruct %int1_263, %int0_264 : (!torch.int, !torch.int) -> !torch.list + %int1_265 = torch.constant.int 1 + %int0_266 = torch.constant.int 0 + %161 = torch.prim.ListConstruct %int1_265, %int0_266 : (!torch.int, !torch.int) -> !torch.list + %162 = torch.aten.permute %arg46, %161 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_267 = torch.constant.int 1 + %int0_268 = torch.constant.int 0 + %163 = torch.prim.ListConstruct %int1_267, %int0_268 : (!torch.int, !torch.int) -> !torch.list + %int1_269 = torch.constant.int 1 + %int0_270 = torch.constant.int 0 + %164 = torch.prim.ListConstruct %int1_269, %int0_270 : (!torch.int, !torch.int) -> !torch.list + %165 = torch.aten.permute %162, %164 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_271 = torch.constant.int 1 + %int0_272 = torch.constant.int 0 + %166 = torch.prim.ListConstruct %int1_271, %int0_272 : (!torch.int, !torch.int) -> !torch.list + %int1_273 = torch.constant.int 1 + %int0_274 = torch.constant.int 0 + %167 = torch.prim.ListConstruct %int1_273, %int0_274 : (!torch.int, !torch.int) -> !torch.list + %168 = torch.aten.permute %arg46, %167 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %str = torch.constant.str "none" + %169 = torch.aten.gelu %159, %str : !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> + %int1024_275 = torch.constant.int 1024 + %int32_276 = torch.constant.int 32 + %170 = torch.prim.ListConstruct %int1024_275, %int32_276 : (!torch.int, !torch.int) -> !torch.list %int1024_277 = torch.constant.int 1024 %int32_278 = torch.constant.int 32 - %180 = torch.prim.ListConstruct %int1024_277, %int32_278 : (!torch.int, !torch.int) -> !torch.list - %int1024_279 = torch.constant.int 1024 - %int32_280 = torch.constant.int 32 - %181 = torch.prim.ListConstruct %int1024_279, %int32_280 : (!torch.int, !torch.int) -> !torch.list - %182 = torch.aten.reshape %179, %181 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %183 = torch.aten.mm %182, %137 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_281 = torch.constant.int 2 - %int512_282 = torch.constant.int 512 - %int32_283 = torch.constant.int 32 - %184 = torch.prim.ListConstruct %int2_281, %int512_282, %int32_283 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_284 = torch.constant.int 2 - %int512_285 = torch.constant.int 512 - %int32_286 = torch.constant.int 32 - %185 = torch.prim.ListConstruct %int2_284, %int512_285, %int32_286 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %186 = torch.aten.reshape %183, %185 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %187 = torch.aten.add.Tensor %result0_263, %186, %arg41 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_287 = torch.constant.int 32 - %188 = torch.prim.ListConstruct %int32_287 : (!torch.int) -> !torch.list - %true_288 = torch.constant.bool true - %189 = torch.prim.ListConstruct %true_288, %true_288, %true_288 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %result0_289, %result1_290, %result2_291 = torch.aten.native_layer_norm_backward %187, %129, %188, %result1_211, %result2_212, %arg36, %arg35, %189 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> - %190 = torch.prim.TupleConstruct %result0_289, %result1_290, %result2_291 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> - %int1024_292 = torch.constant.int 1024 - %int32_293 = torch.constant.int 32 - %191 = torch.prim.ListConstruct %int1024_292, %int32_293 : (!torch.int, !torch.int) -> !torch.list - %int1024_294 = torch.constant.int 1024 - %int32_295 = torch.constant.int 32 - %192 = torch.prim.ListConstruct %int1024_294, %int32_295 : (!torch.int, !torch.int) -> !torch.list - %193 = torch.aten.reshape %result0_289, %192 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %194 = torch.aten.mm %193, %103 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_296 = torch.constant.int 2 - %int512_297 = torch.constant.int 512 - %int32_298 = torch.constant.int 32 - %195 = torch.prim.ListConstruct %int2_296, %int512_297, %int32_298 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_299 = torch.constant.int 2 - %int512_300 = torch.constant.int 512 - %int32_301 = torch.constant.int 32 - %196 = torch.prim.ListConstruct %int2_299, %int512_300, %int32_301 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %197 = torch.aten.reshape %194, %196 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int2_302 = torch.constant.int 2 - %int512_303 = torch.constant.int 512 - %int16_304 = torch.constant.int 16 - %198 = torch.prim.ListConstruct %int2_302, %int512_303, %int2_302, %int16_304 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_305 = torch.constant.int 2 - %int512_306 = torch.constant.int 512 - %int16_307 = torch.constant.int 16 - %199 = torch.prim.ListConstruct %int2_305, %int512_306, %int2_305, %int16_307 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %200 = torch.aten.reshape %197, %199 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_308 = torch.constant.int 0 - %int2_309 = torch.constant.int 2 + %171 = torch.prim.ListConstruct %int1024_277, %int32_278 : (!torch.int, !torch.int) -> !torch.list + %172 = torch.aten.reshape %169, %171 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %173 = torch.aten.addmm %arg52, %172, %168, %arg51, %arg50 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> + %int2_279 = torch.constant.int 2 + %int512_280 = torch.constant.int 512 + %int32_281 = torch.constant.int 32 + %174 = torch.prim.ListConstruct %int2_279, %int512_280, %int32_281 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_282 = torch.constant.int 2 + %int512_283 = torch.constant.int 512 + %int32_284 = torch.constant.int 32 + %175 = torch.prim.ListConstruct %int2_282, %int512_283, %int32_284 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %176 = torch.aten.reshape %173, %175 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %177 = torch.aten.add.Tensor %176, %result0_238, %arg49 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32_285 = torch.constant.int 32 + %178 = torch.prim.ListConstruct %int32_285 : (!torch.int) -> !torch.list + %float1.000000e-05_286 = torch.constant.float 1.000000e-05 + %result0_287, %result1_288, %result2_289 = torch.aten.native_layer_norm %177, %178, %arg48, %arg47, %float1.000000e-05_286 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> + %179 = torch.prim.TupleConstruct %result0_287, %result1_288, %result2_289 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> + %int1_290 = torch.constant.int 1 + %int0_291 = torch.constant.int 0 + %180 = torch.prim.ListConstruct %int1_290, %int0_291 : (!torch.int, !torch.int) -> !torch.list + %int1_292 = torch.constant.int 1 + %int0_293 = torch.constant.int 0 + %181 = torch.prim.ListConstruct %int1_292, %int0_293 : (!torch.int, !torch.int) -> !torch.list + %182 = torch.aten.permute %arg53, %181 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_294 = torch.constant.int 1 + %int0_295 = torch.constant.int 0 + %183 = torch.prim.ListConstruct %int1_294, %int0_295 : (!torch.int, !torch.int) -> !torch.list + %int1_296 = torch.constant.int 1 + %int0_297 = torch.constant.int 0 + %184 = torch.prim.ListConstruct %int1_296, %int0_297 : (!torch.int, !torch.int) -> !torch.list + %185 = torch.aten.permute %182, %184 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_298 = torch.constant.int 1 + %int0_299 = torch.constant.int 0 + %186 = torch.prim.ListConstruct %int1_298, %int0_299 : (!torch.int, !torch.int) -> !torch.list + %int1_300 = torch.constant.int 1 + %int0_301 = torch.constant.int 0 + %187 = torch.prim.ListConstruct %int1_300, %int0_301 : (!torch.int, !torch.int) -> !torch.list + %188 = torch.aten.permute %arg53, %187 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int0_302 = torch.constant.int 0 + %int0_303 = torch.constant.int 0 + %int2_304 = torch.constant.int 2 + %int1_305 = torch.constant.int 1 + %189 = torch.aten.slice.Tensor %result0_287, %int0_302, %int0_303, %int2_304, %int1_305 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int0_306 = torch.constant.int 0 + %int0_307 = torch.constant.int 0 + %int2_308 = torch.constant.int 2 + %int1_309 = torch.constant.int 1 + %190 = torch.aten.slice.Tensor %189, %int0_306, %int0_307, %int2_308, %int1_309 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> %int1_310 = torch.constant.int 1 - %int3_311 = torch.constant.int 3 - %201 = torch.prim.ListConstruct %int0_308, %int2_309, %int1_310, %int3_311 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_312 = torch.constant.int 0 - %int2_313 = torch.constant.int 2 - %int1_314 = torch.constant.int 1 - %int3_315 = torch.constant.int 3 - %202 = torch.prim.ListConstruct %int0_312, %int2_313, %int1_314, %int3_315 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %203 = torch.aten.permute %200, %202 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int4_316 = torch.constant.int 4 - %int512_317 = torch.constant.int 512 - %int16_318 = torch.constant.int 16 - %204 = torch.prim.ListConstruct %int4_316, %int512_317, %int16_318 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_319 = torch.constant.int 4 - %int512_320 = torch.constant.int 512 - %int16_321 = torch.constant.int 16 - %205 = torch.prim.ListConstruct %int4_319, %int512_320, %int16_321 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %206 = torch.aten.reshape %203, %205 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %207 = torch.aten.bmm %206, %97 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> - %int2_322 = torch.constant.int 2 - %int512_323 = torch.constant.int 512 - %208 = torch.prim.ListConstruct %int2_322, %int2_322, %int512_323, %int512_323 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_324 = torch.constant.int 2 - %int512_325 = torch.constant.int 512 - %209 = torch.prim.ListConstruct %int2_324, %int2_324, %int512_325, %int512_325 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %210 = torch.aten.reshape %207, %209 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> - %int-1_326 = torch.constant.int -1 - %int6_327 = torch.constant.int 6 - %211 = torch.aten._softmax_backward_data %210, %75, %int-1_326, %int6_327 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,512,512],f32> - %212 = torch.aten.div.Tensor %211, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> - %int4_328 = torch.constant.int 4 - %int512_329 = torch.constant.int 512 - %213 = torch.prim.ListConstruct %int4_328, %int512_329, %int512_329 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_330 = torch.constant.int 4 - %int512_331 = torch.constant.int 512 - %214 = torch.prim.ListConstruct %int4_330, %int512_331, %int512_331 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %215 = torch.aten.reshape %212, %214 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> - %216 = torch.aten.bmm %215, %36 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> - %int2_332 = torch.constant.int 2 - %int512_333 = torch.constant.int 512 - %int16_334 = torch.constant.int 16 - %217 = torch.prim.ListConstruct %int2_332, %int2_332, %int512_333, %int16_334 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_311 = torch.constant.int 0 + %int1_312 = torch.constant.int 1 + %int1_313 = torch.constant.int 1 + %191 = torch.aten.slice.Tensor %190, %int1_310, %int0_311, %int1_312, %int1_313 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> + %int2_314 = torch.constant.int 2 + %int0_315 = torch.constant.int 0 + %int32_316 = torch.constant.int 32 + %int1_317 = torch.constant.int 1 + %192 = torch.aten.slice.Tensor %191, %int2_314, %int0_315, %int32_316, %int1_317 : !torch.vtensor<[2,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> + %int2_318 = torch.constant.int 2 + %int32_319 = torch.constant.int 32 + %193 = torch.prim.ListConstruct %int2_318, %int32_319 : (!torch.int, !torch.int) -> !torch.list + %int2_320 = torch.constant.int 2 + %int32_321 = torch.constant.int 32 + %194 = torch.prim.ListConstruct %int2_320, %int32_321 : (!torch.int, !torch.int) -> !torch.list + %195 = torch.aten.reshape %192, %194 : !torch.vtensor<[2,1,32],f32>, !torch.list -> !torch.vtensor<[2,32],f32> + %196 = torch.aten.addmm %arg56, %195, %188, %arg55, %arg54 : !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,32],f32> + %197 = torch.aten.tanh %196 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %int1_322 = torch.constant.int 1 + %int0_323 = torch.constant.int 0 + %198 = torch.prim.ListConstruct %int1_322, %int0_323 : (!torch.int, !torch.int) -> !torch.list + %int1_324 = torch.constant.int 1 + %int0_325 = torch.constant.int 0 + %199 = torch.prim.ListConstruct %int1_324, %int0_325 : (!torch.int, !torch.int) -> !torch.list + %200 = torch.aten.permute %arg57, %199 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %int1_326 = torch.constant.int 1 + %int0_327 = torch.constant.int 0 + %201 = torch.prim.ListConstruct %int1_326, %int0_327 : (!torch.int, !torch.int) -> !torch.list + %int1_328 = torch.constant.int 1 + %int0_329 = torch.constant.int 0 + %202 = torch.prim.ListConstruct %int1_328, %int0_329 : (!torch.int, !torch.int) -> !torch.list + %203 = torch.aten.permute %200, %202 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> + %int1_330 = torch.constant.int 1 + %int0_331 = torch.constant.int 0 + %204 = torch.prim.ListConstruct %int1_330, %int0_331 : (!torch.int, !torch.int) -> !torch.list + %int1_332 = torch.constant.int 1 + %int0_333 = torch.constant.int 0 + %205 = torch.prim.ListConstruct %int1_332, %int0_333 : (!torch.int, !torch.int) -> !torch.list + %206 = torch.aten.permute %arg57, %205 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %207 = torch.aten.addmm %arg60, %197, %206, %arg59, %arg58 : !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> + %int2_334 = torch.constant.int 2 + %208 = torch.prim.ListConstruct %int2_334, %int2_334 : (!torch.int, !torch.int) -> !torch.list %int2_335 = torch.constant.int 2 - %int512_336 = torch.constant.int 512 - %int16_337 = torch.constant.int 16 - %218 = torch.prim.ListConstruct %int2_335, %int2_335, %int512_336, %int16_337 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %219 = torch.aten.reshape %216, %218 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_338 = torch.constant.int 0 + %209 = torch.prim.ListConstruct %int2_335, %int2_335 : (!torch.int, !torch.int) -> !torch.list + %210 = torch.aten.reshape %207, %209 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> + %int1_336 = torch.constant.int 1 + %false_337 = torch.constant.bool false + %211 = torch.aten._log_softmax %210, %int1_336, %false_337 : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2],f32> + %int2_338 = torch.constant.int 2 + %212 = torch.prim.ListConstruct %int2_338 : (!torch.int) -> !torch.list %int2_339 = torch.constant.int 2 - %int1_340 = torch.constant.int 1 - %int3_341 = torch.constant.int 3 - %220 = torch.prim.ListConstruct %int0_338, %int2_339, %int1_340, %int3_341 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_342 = torch.constant.int 0 - %int2_343 = torch.constant.int 2 - %int1_344 = torch.constant.int 1 - %int3_345 = torch.constant.int 3 - %221 = torch.prim.ListConstruct %int0_342, %int2_343, %int1_344, %int3_345 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %222 = torch.aten.permute %219, %221 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_346 = torch.constant.int 2 - %int512_347 = torch.constant.int 512 - %int32_348 = torch.constant.int 32 - %223 = torch.prim.ListConstruct %int2_346, %int512_347, %int32_348 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %213 = torch.prim.ListConstruct %int2_339 : (!torch.int) -> !torch.list + %214 = torch.aten.reshape %arg61, %213 : !torch.vtensor<[2],si64>, !torch.list -> !torch.vtensor<[2],si64> + %none_340 = torch.constant.none + %int1_341 = torch.constant.int 1 + %int-100 = torch.constant.int -100 + %output, %total_weight = torch.aten.nll_loss_forward %211, %214, %none_340, %int1_341, %int-100 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + %215 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.tuple, vtensor<[],f32>> + %none_342 = torch.constant.none + %int1_343 = torch.constant.int 1 + %int-100_344 = torch.constant.int -100 + %216 = torch.aten.nll_loss_backward %arg62, %211, %214, %none_342, %int1_343, %int-100_344, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[2,2],f32> + %int1_345 = torch.constant.int 1 + %int6_346 = torch.constant.int 6 + %217 = torch.aten._log_softmax_backward_data %216, %211, %int1_345, %int6_346 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> + %int2_347 = torch.constant.int 2 + %218 = torch.prim.ListConstruct %int2_347, %int2_347 : (!torch.int, !torch.int) -> !torch.list + %int2_348 = torch.constant.int 2 + %219 = torch.prim.ListConstruct %int2_348, %int2_348 : (!torch.int, !torch.int) -> !torch.list + %220 = torch.aten.reshape %217, %219 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> + %221 = torch.aten.mm %220, %203 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %222 = torch.aten.tanh_backward %221, %197 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %223 = torch.aten.mm %222, %185 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[2,32],f32> %int2_349 = torch.constant.int 2 - %int512_350 = torch.constant.int 512 + %int1_350 = torch.constant.int 1 %int32_351 = torch.constant.int 32 - %224 = torch.prim.ListConstruct %int2_349, %int512_350, %int32_351 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %225 = torch.aten.reshape %222, %224 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_352 = torch.constant.int 1024 - %int32_353 = torch.constant.int 32 - %226 = torch.prim.ListConstruct %int1024_352, %int32_353 : (!torch.int, !torch.int) -> !torch.list - %int1024_354 = torch.constant.int 1024 - %int32_355 = torch.constant.int 32 - %227 = torch.prim.ListConstruct %int1024_354, %int32_355 : (!torch.int, !torch.int) -> !torch.list - %228 = torch.aten.reshape %225, %227 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %229 = torch.aten.mm %228, %13 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_356 = torch.constant.int 2 - %int512_357 = torch.constant.int 512 - %int32_358 = torch.constant.int 32 - %230 = torch.prim.ListConstruct %int2_356, %int512_357, %int32_358 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_359 = torch.constant.int 2 - %int512_360 = torch.constant.int 512 - %int32_361 = torch.constant.int 32 - %231 = torch.prim.ListConstruct %int2_359, %int512_360, %int32_361 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %232 = torch.aten.reshape %229, %231 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %224 = torch.prim.ListConstruct %int2_349, %int1_350, %int32_351 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_352 = torch.constant.int 2 + %int1_353 = torch.constant.int 1 + %int32_354 = torch.constant.int 32 + %225 = torch.prim.ListConstruct %int2_352, %int1_353, %int32_354 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %226 = torch.aten.reshape %223, %225 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[2,1,32],f32> + %227 = torch.aten.zero.functional %arg63 : !torch.vtensor<[2,512,32],f32> -> !torch.vtensor<[2,512,32],f32> + %none_355 = torch.constant.none + %228 = torch.aten.clone %227, %none_355 : !torch.vtensor<[2,512,32],f32>, !torch.none -> !torch.vtensor<[2,512,32],f32> + %int0_356 = torch.constant.int 0 + %int0_357 = torch.constant.int 0 + %int2_358 = torch.constant.int 2 + %int1_359 = torch.constant.int 1 + %229 = torch.aten.slice.Tensor %228, %int0_356, %int0_357, %int2_358, %int1_359 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int1_360 = torch.constant.int 1 + %int0_361 = torch.constant.int 0 %int1_362 = torch.constant.int 1 - %int0_363 = torch.constant.int 0 - %233 = torch.prim.ListConstruct %int1_362, %int0_363 : (!torch.int, !torch.int) -> !torch.list - %int1_364 = torch.constant.int 1 + %int1_363 = torch.constant.int 1 + %230 = torch.aten.slice.Tensor %229, %int1_360, %int0_361, %int1_362, %int1_363 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> + %int2_364 = torch.constant.int 2 %int0_365 = torch.constant.int 0 - %234 = torch.prim.ListConstruct %int1_364, %int0_365 : (!torch.int, !torch.int) -> !torch.list - %235 = torch.aten.permute %arg19, %234 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_366 = torch.constant.int 1 - %int0_367 = torch.constant.int 0 - %236 = torch.prim.ListConstruct %int1_366, %int0_367 : (!torch.int, !torch.int) -> !torch.list - %int1_368 = torch.constant.int 1 - %int0_369 = torch.constant.int 0 - %237 = torch.prim.ListConstruct %int1_368, %int0_369 : (!torch.int, !torch.int) -> !torch.list - %238 = torch.aten.permute %235, %237 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_370 = torch.constant.int 1 - %int2_371 = torch.constant.int 2 - %239 = torch.aten.transpose.int %68, %int1_370, %int2_371 : !torch.vtensor<[4,512,16],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,16,512],f32> - %240 = torch.aten.bmm %239, %215 : !torch.vtensor<[4,16,512],f32>, !torch.vtensor<[4,512,512],f32> -> !torch.vtensor<[4,16,512],f32> + %int32_366 = torch.constant.int 32 + %int1_367 = torch.constant.int 1 + %231 = torch.aten.slice.Tensor %230, %int2_364, %int0_365, %int32_366, %int1_367 : !torch.vtensor<[2,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> + %false_368 = torch.constant.bool false + %232 = torch.aten.copy_ %231, %226, %false_368 : !torch.vtensor<[2,1,32],f32>, !torch.vtensor<[2,1,32],f32>, !torch.bool -> !torch.vtensor<[2,1,32],f32> + %233 = torch.aten.zero.functional %arg64 : !torch.vtensor<[2,512,32],f32> -> !torch.vtensor<[2,512,32],f32> + %none_369 = torch.constant.none + %234 = torch.aten.clone %233, %none_369 : !torch.vtensor<[2,512,32],f32>, !torch.none -> !torch.vtensor<[2,512,32],f32> + %int0_370 = torch.constant.int 0 + %int0_371 = torch.constant.int 0 %int2_372 = torch.constant.int 2 - %int16_373 = torch.constant.int 16 - %int512_374 = torch.constant.int 512 - %241 = torch.prim.ListConstruct %int2_372, %int2_372, %int16_373, %int512_374 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_375 = torch.constant.int 2 - %int16_376 = torch.constant.int 16 - %int512_377 = torch.constant.int 512 - %242 = torch.prim.ListConstruct %int2_375, %int2_375, %int16_376, %int512_377 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %243 = torch.aten.reshape %240, %242 : !torch.vtensor<[4,16,512],f32>, !torch.list -> !torch.vtensor<[2,2,16,512],f32> - %int-1_378 = torch.constant.int -1 - %int-2_379 = torch.constant.int -2 - %244 = torch.aten.transpose.int %243, %int-1_378, %int-2_379 : !torch.vtensor<[2,2,16,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,512,16],f32> - %int0_380 = torch.constant.int 0 - %int2_381 = torch.constant.int 2 - %int1_382 = torch.constant.int 1 - %int3_383 = torch.constant.int 3 - %245 = torch.prim.ListConstruct %int0_380, %int2_381, %int1_382, %int3_383 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_384 = torch.constant.int 0 - %int2_385 = torch.constant.int 2 - %int1_386 = torch.constant.int 1 - %int3_387 = torch.constant.int 3 - %246 = torch.prim.ListConstruct %int0_384, %int2_385, %int1_386, %int3_387 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %247 = torch.aten.permute %244, %246 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_388 = torch.constant.int 2 - %int512_389 = torch.constant.int 512 - %int32_390 = torch.constant.int 32 - %248 = torch.prim.ListConstruct %int2_388, %int512_389, %int32_390 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_391 = torch.constant.int 2 - %int512_392 = torch.constant.int 512 + %int1_373 = torch.constant.int 1 + %235 = torch.aten.slice.Tensor %234, %int0_370, %int0_371, %int2_372, %int1_373 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> + %false_374 = torch.constant.bool false + %236 = torch.aten.copy_ %235, %228, %false_374 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.bool -> !torch.vtensor<[2,512,32],f32> + %int32_375 = torch.constant.int 32 + %237 = torch.prim.ListConstruct %int32_375 : (!torch.int) -> !torch.list + %true = torch.constant.bool true + %238 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0_376, %result1_377, %result2_378 = torch.aten.native_layer_norm_backward %234, %177, %237, %result1_288, %result2_289, %arg48, %arg47, %238 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> + %239 = torch.prim.TupleConstruct %result0_376, %result1_377, %result2_378 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> + %int1024_379 = torch.constant.int 1024 + %int32_380 = torch.constant.int 32 + %240 = torch.prim.ListConstruct %int1024_379, %int32_380 : (!torch.int, !torch.int) -> !torch.list + %int1024_381 = torch.constant.int 1024 + %int32_382 = torch.constant.int 32 + %241 = torch.prim.ListConstruct %int1024_381, %int32_382 : (!torch.int, !torch.int) -> !torch.list + %242 = torch.aten.reshape %result0_376, %241 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %243 = torch.aten.mm %242, %165 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_383 = torch.constant.int 2 + %int512_384 = torch.constant.int 512 + %int32_385 = torch.constant.int 32 + %244 = torch.prim.ListConstruct %int2_383, %int512_384, %int32_385 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_386 = torch.constant.int 2 + %int512_387 = torch.constant.int 512 + %int32_388 = torch.constant.int 32 + %245 = torch.prim.ListConstruct %int2_386, %int512_387, %int32_388 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %246 = torch.aten.reshape %243, %245 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %str_389 = torch.constant.str "none" + %247 = torch.aten.gelu_backward %246, %159, %str_389 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> + %int1024_390 = torch.constant.int 1024 + %int32_391 = torch.constant.int 32 + %248 = torch.prim.ListConstruct %int1024_390, %int32_391 : (!torch.int, !torch.int) -> !torch.list + %int1024_392 = torch.constant.int 1024 %int32_393 = torch.constant.int 32 - %249 = torch.prim.ListConstruct %int2_391, %int512_392, %int32_393 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %250 = torch.aten.reshape %247, %249 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_394 = torch.constant.int 1024 - %int32_395 = torch.constant.int 32 - %251 = torch.prim.ListConstruct %int1024_394, %int32_395 : (!torch.int, !torch.int) -> !torch.list - %int1024_396 = torch.constant.int 1024 - %int32_397 = torch.constant.int 32 - %252 = torch.prim.ListConstruct %int1024_396, %int32_397 : (!torch.int, !torch.int) -> !torch.list - %253 = torch.aten.reshape %250, %252 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %254 = torch.aten.mm %253, %238 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_398 = torch.constant.int 2 - %int512_399 = torch.constant.int 512 + %249 = torch.prim.ListConstruct %int1024_392, %int32_393 : (!torch.int, !torch.int) -> !torch.list + %250 = torch.aten.reshape %247, %249 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %251 = torch.aten.mm %250, %149 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_394 = torch.constant.int 2 + %int512_395 = torch.constant.int 512 + %int32_396 = torch.constant.int 32 + %252 = torch.prim.ListConstruct %int2_394, %int512_395, %int32_396 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_397 = torch.constant.int 2 + %int512_398 = torch.constant.int 512 + %int32_399 = torch.constant.int 32 + %253 = torch.prim.ListConstruct %int2_397, %int512_398, %int32_399 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %254 = torch.aten.reshape %251, %253 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %255 = torch.aten.add.Tensor %result0_376, %254, %arg41 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> %int32_400 = torch.constant.int 32 - %255 = torch.prim.ListConstruct %int2_398, %int512_399, %int32_400 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_401 = torch.constant.int 2 - %int512_402 = torch.constant.int 512 - %int32_403 = torch.constant.int 32 - %256 = torch.prim.ListConstruct %int2_401, %int512_402, %int32_403 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %257 = torch.aten.reshape %254, %256 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1_404 = torch.constant.int 1 - %int0_405 = torch.constant.int 0 - %258 = torch.prim.ListConstruct %int1_404, %int0_405 : (!torch.int, !torch.int) -> !torch.list - %int1_406 = torch.constant.int 1 - %int0_407 = torch.constant.int 0 - %259 = torch.prim.ListConstruct %int1_406, %int0_407 : (!torch.int, !torch.int) -> !torch.list - %260 = torch.aten.permute %arg32, %259 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_408 = torch.constant.int 1 - %int0_409 = torch.constant.int 0 - %261 = torch.prim.ListConstruct %int1_408, %int0_409 : (!torch.int, !torch.int) -> !torch.list - %int1_410 = torch.constant.int 1 - %int0_411 = torch.constant.int 0 - %262 = torch.prim.ListConstruct %int1_410, %int0_411 : (!torch.int, !torch.int) -> !torch.list - %263 = torch.aten.permute %260, %262 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_412 = torch.constant.int 1 - %int2_413 = torch.constant.int 2 - %264 = torch.aten.transpose.int %111, %int1_412, %int2_413 : !torch.vtensor<[4,512,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,512,512],f32> - %265 = torch.aten.bmm %264, %206 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> - %int2_414 = torch.constant.int 2 - %int512_415 = torch.constant.int 512 - %int16_416 = torch.constant.int 16 - %266 = torch.prim.ListConstruct %int2_414, %int2_414, %int512_415, %int16_416 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_417 = torch.constant.int 2 - %int512_418 = torch.constant.int 512 - %int16_419 = torch.constant.int 16 - %267 = torch.prim.ListConstruct %int2_417, %int2_417, %int512_418, %int16_419 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %268 = torch.aten.reshape %265, %267 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_420 = torch.constant.int 0 - %int2_421 = torch.constant.int 2 - %int1_422 = torch.constant.int 1 - %int3_423 = torch.constant.int 3 - %269 = torch.prim.ListConstruct %int0_420, %int2_421, %int1_422, %int3_423 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_424 = torch.constant.int 0 - %int2_425 = torch.constant.int 2 - %int1_426 = torch.constant.int 1 - %int3_427 = torch.constant.int 3 - %270 = torch.prim.ListConstruct %int0_424, %int2_425, %int1_426, %int3_427 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %271 = torch.aten.permute %268, %270 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_428 = torch.constant.int 2 - %int512_429 = torch.constant.int 512 - %int32_430 = torch.constant.int 32 - %272 = torch.prim.ListConstruct %int2_428, %int512_429, %int32_430 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_431 = torch.constant.int 2 - %int512_432 = torch.constant.int 512 - %int32_433 = torch.constant.int 32 - %273 = torch.prim.ListConstruct %int2_431, %int512_432, %int32_433 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %274 = torch.aten.reshape %271, %273 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_434 = torch.constant.int 1024 - %int32_435 = torch.constant.int 32 - %275 = torch.prim.ListConstruct %int1024_434, %int32_435 : (!torch.int, !torch.int) -> !torch.list - %int1024_436 = torch.constant.int 1024 - %int32_437 = torch.constant.int 32 - %276 = torch.prim.ListConstruct %int1024_436, %int32_437 : (!torch.int, !torch.int) -> !torch.list - %277 = torch.aten.reshape %274, %276 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %278 = torch.aten.mm %277, %263 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_438 = torch.constant.int 2 - %int512_439 = torch.constant.int 512 - %int32_440 = torch.constant.int 32 - %279 = torch.prim.ListConstruct %int2_438, %int512_439, %int32_440 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_441 = torch.constant.int 2 + %256 = torch.prim.ListConstruct %int32_400 : (!torch.int) -> !torch.list + %true_401 = torch.constant.bool true + %257 = torch.prim.ListConstruct %true_401, %true_401, %true_401 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0_402, %result1_403, %result2_404 = torch.aten.native_layer_norm_backward %255, %141, %256, %result1_239, %result2_240, %arg36, %arg35, %257 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> + %258 = torch.prim.TupleConstruct %result0_402, %result1_403, %result2_404 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> + %int1024_405 = torch.constant.int 1024 + %int32_406 = torch.constant.int 32 + %259 = torch.prim.ListConstruct %int1024_405, %int32_406 : (!torch.int, !torch.int) -> !torch.list + %int1024_407 = torch.constant.int 1024 + %int32_408 = torch.constant.int 32 + %260 = torch.prim.ListConstruct %int1024_407, %int32_408 : (!torch.int, !torch.int) -> !torch.list + %261 = torch.aten.reshape %result0_402, %260 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %262 = torch.aten.mm %261, %112 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_409 = torch.constant.int 2 + %int512_410 = torch.constant.int 512 + %int32_411 = torch.constant.int 32 + %263 = torch.prim.ListConstruct %int2_409, %int512_410, %int32_411 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_412 = torch.constant.int 2 + %int512_413 = torch.constant.int 512 + %int32_414 = torch.constant.int 32 + %264 = torch.prim.ListConstruct %int2_412, %int512_413, %int32_414 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %265 = torch.aten.reshape %262, %264 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int2_415 = torch.constant.int 2 + %int512_416 = torch.constant.int 512 + %int16_417 = torch.constant.int 16 + %266 = torch.prim.ListConstruct %int2_415, %int512_416, %int2_415, %int16_417 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_418 = torch.constant.int 2 + %int512_419 = torch.constant.int 512 + %int16_420 = torch.constant.int 16 + %267 = torch.prim.ListConstruct %int2_418, %int512_419, %int2_418, %int16_420 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %268 = torch.aten.reshape %265, %267 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int0_421 = torch.constant.int 0 + %int2_422 = torch.constant.int 2 + %int1_423 = torch.constant.int 1 + %int3_424 = torch.constant.int 3 + %269 = torch.prim.ListConstruct %int0_421, %int2_422, %int1_423, %int3_424 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_425 = torch.constant.int 0 + %int2_426 = torch.constant.int 2 + %int1_427 = torch.constant.int 1 + %int3_428 = torch.constant.int 3 + %270 = torch.prim.ListConstruct %int0_425, %int2_426, %int1_427, %int3_428 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %271 = torch.aten.permute %268, %270 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int4_429 = torch.constant.int 4 + %int512_430 = torch.constant.int 512 + %int16_431 = torch.constant.int 16 + %272 = torch.prim.ListConstruct %int4_429, %int512_430, %int16_431 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_432 = torch.constant.int 4 + %int512_433 = torch.constant.int 512 + %int16_434 = torch.constant.int 16 + %273 = torch.prim.ListConstruct %int4_432, %int512_433, %int16_434 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %274 = torch.aten.reshape %271, %273 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> + %275 = torch.aten.bmm %274, %106 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> + %int2_435 = torch.constant.int 2 + %int512_436 = torch.constant.int 512 + %276 = torch.prim.ListConstruct %int2_435, %int2_435, %int512_436, %int512_436 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_437 = torch.constant.int 2 + %int512_438 = torch.constant.int 512 + %277 = torch.prim.ListConstruct %int2_437, %int2_437, %int512_438, %int512_438 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %278 = torch.aten.reshape %275, %277 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> + %int-1_439 = torch.constant.int -1 + %int6_440 = torch.constant.int 6 + %279 = torch.aten._softmax_backward_data %278, %82, %int-1_439, %int6_440 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,512,512],f32> + %280 = torch.aten.div.Tensor %279, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> + %int4_441 = torch.constant.int 4 %int512_442 = torch.constant.int 512 - %int32_443 = torch.constant.int 32 - %280 = torch.prim.ListConstruct %int2_441, %int512_442, %int32_443 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %281 = torch.aten.reshape %278, %280 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %282 = torch.aten.add.Tensor %result0_289, %281, %arg55 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %283 = torch.aten.add.Tensor %282, %257, %arg54 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %284 = torch.aten.add.Tensor %283, %232, %arg15 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_444 = torch.constant.int 32 - %285 = torch.prim.ListConstruct %int32_444 : (!torch.int) -> !torch.list - %true_445 = torch.constant.bool true - %286 = torch.prim.ListConstruct %true_445, %true_445, %true_445 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %result0_446, %result1_447, %result2_448 = torch.aten.native_layer_norm_backward %284, %5, %285, %result1, %result2, %arg7, %arg6, %286 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> - %287 = torch.prim.TupleConstruct %result0_446, %result1_447, %result2_448 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> - %int28996 = torch.constant.int 28996 - %int0_449 = torch.constant.int 0 - %false_450 = torch.constant.bool false - %288 = torch.aten.embedding_dense_backward %result0_446, %arg5, %int28996, %int0_449, %false_450 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[28996,32],f32> - %289 = torch.aten.add.Tensor %arg56, %288, %arg4 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.int -> !torch.vtensor<[28996,32],f32> - %290 = torch.aten.mul.Tensor %arg58, %arg57 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> - %291 = torch.aten.addcmul %290, %289, %289, %arg3 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> - %292 = torch.aten.sqrt %291 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> - %293 = torch.aten.add.Tensor %292, %arg2, %arg1 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[28996,32],f32> - %294 = torch.aten.mul.Tensor %arg61, %arg60 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> - %295 = torch.aten.add.Tensor %294, %289, %arg59 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> - %296 = torch.aten.addcdiv %arg14, %295, %293, %arg0 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> + %281 = torch.prim.ListConstruct %int4_441, %int512_442, %int512_442 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_443 = torch.constant.int 4 + %int512_444 = torch.constant.int 512 + %282 = torch.prim.ListConstruct %int4_443, %int512_444, %int512_444 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %283 = torch.aten.reshape %280, %282 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> + %284 = torch.aten.bmm %283, %40 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> + %int2_445 = torch.constant.int 2 + %int512_446 = torch.constant.int 512 + %int16_447 = torch.constant.int 16 + %285 = torch.prim.ListConstruct %int2_445, %int2_445, %int512_446, %int16_447 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_448 = torch.constant.int 2 + %int512_449 = torch.constant.int 512 + %int16_450 = torch.constant.int 16 + %286 = torch.prim.ListConstruct %int2_448, %int2_448, %int512_449, %int16_450 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %287 = torch.aten.reshape %284, %286 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> %int0_451 = torch.constant.int 0 - %297 = torch.prim.ListConstruct %int0_451 : (!torch.int) -> !torch.list - %true_452 = torch.constant.bool true - %none_453 = torch.constant.none - %298 = torch.aten.sum.dim_IntList %result0_446, %297, %true_452, %none_453 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,512,32],f32> - %int512_454 = torch.constant.int 512 - %int-1_455 = torch.constant.int -1 - %false_456 = torch.constant.bool false - %299 = torch.aten.embedding_dense_backward %298, %0, %int512_454, %int-1_455, %false_456 : !torch.vtensor<[1,512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[512,32],f32> - %300 = torch.aten.add.Tensor %arg66, %299, %arg65 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.int -> !torch.vtensor<[512,32],f32> - %301 = torch.aten.mul.Tensor %arg67, %arg57 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> - %302 = torch.aten.addcmul %301, %300, %300, %arg64 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> - %303 = torch.aten.sqrt %302 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> - %304 = torch.aten.add.Tensor %303, %arg2, %arg63 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[512,32],f32> - %305 = torch.aten.mul.Tensor %arg69, %arg60 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> - %306 = torch.aten.add.Tensor %305, %300, %arg68 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> - %307 = torch.aten.addcdiv %arg10, %306, %304, %arg62 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> - %int2_457 = torch.constant.int 2 - %int-1_458 = torch.constant.int -1 - %false_459 = torch.constant.bool false - %308 = torch.aten.embedding_dense_backward %result0_446, %arg12, %int2_457, %int-1_458, %false_459 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[2,32],f32> - %309 = torch.aten.add.Tensor %arg74, %308, %arg73 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> - %310 = torch.aten.mul.Tensor %arg75, %arg57 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %311 = torch.aten.addcmul %310, %309, %309, %arg72 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %312 = torch.aten.sqrt %311 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %313 = torch.aten.add.Tensor %312, %arg2, %arg71 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> - %314 = torch.aten.mul.Tensor %arg77, %arg60 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %315 = torch.aten.add.Tensor %314, %309, %arg76 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %316 = torch.aten.addcdiv %arg13, %315, %313, %arg70 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %317 = torch.aten.add.Tensor %arg82, %result1_447, %arg81 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %318 = torch.aten.mul.Tensor %arg83, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %319 = torch.aten.addcmul %318, %317, %317, %arg80 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %320 = torch.aten.sqrt %319 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %321 = torch.aten.add.Tensor %320, %arg2, %arg79 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %322 = torch.aten.mul.Tensor %arg85, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %323 = torch.aten.add.Tensor %322, %317, %arg84 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %324 = torch.aten.addcdiv %arg7, %323, %321, %arg78 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %325 = torch.aten.add.Tensor %arg90, %result2_448, %arg89 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %326 = torch.aten.mul.Tensor %arg91, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %327 = torch.aten.addcmul %326, %325, %325, %arg88 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %328 = torch.aten.sqrt %327 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %329 = torch.aten.add.Tensor %328, %arg2, %arg87 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %330 = torch.aten.mul.Tensor %arg93, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %331 = torch.aten.add.Tensor %330, %325, %arg92 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %332 = torch.aten.addcdiv %arg6, %331, %329, %arg86 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_460 = torch.constant.int 1024 + %int2_452 = torch.constant.int 2 + %int1_453 = torch.constant.int 1 + %int3_454 = torch.constant.int 3 + %288 = torch.prim.ListConstruct %int0_451, %int2_452, %int1_453, %int3_454 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_455 = torch.constant.int 0 + %int2_456 = torch.constant.int 2 + %int1_457 = torch.constant.int 1 + %int3_458 = torch.constant.int 3 + %289 = torch.prim.ListConstruct %int0_455, %int2_456, %int1_457, %int3_458 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %290 = torch.aten.permute %287, %289 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_459 = torch.constant.int 2 + %int512_460 = torch.constant.int 512 %int32_461 = torch.constant.int 32 - %333 = torch.prim.ListConstruct %int1024_460, %int32_461 : (!torch.int, !torch.int) -> !torch.list - %int1024_462 = torch.constant.int 1024 - %int32_463 = torch.constant.int 32 - %334 = torch.prim.ListConstruct %int1024_462, %int32_463 : (!torch.int, !torch.int) -> !torch.list - %335 = torch.aten.reshape %result0, %334 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_464 = torch.constant.int 1 - %int0_465 = torch.constant.int 0 - %336 = torch.prim.ListConstruct %int1_464, %int0_465 : (!torch.int, !torch.int) -> !torch.list - %int1_466 = torch.constant.int 1 - %int0_467 = torch.constant.int 0 - %337 = torch.prim.ListConstruct %int1_466, %int0_467 : (!torch.int, !torch.int) -> !torch.list - %338 = torch.aten.permute %335, %337 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %339 = torch.aten.mm %338, %228 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_468 = torch.constant.int 1 - %int0_469 = torch.constant.int 0 - %340 = torch.prim.ListConstruct %int1_468, %int0_469 : (!torch.int, !torch.int) -> !torch.list - %int1_470 = torch.constant.int 1 - %int0_471 = torch.constant.int 0 - %341 = torch.prim.ListConstruct %int1_470, %int0_471 : (!torch.int, !torch.int) -> !torch.list - %342 = torch.aten.permute %339, %341 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %343 = torch.aten.add.Tensor %arg98, %342, %arg97 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %344 = torch.aten.mul.Tensor %arg99, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %345 = torch.aten.addcmul %344, %343, %343, %arg96 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %346 = torch.aten.sqrt %345 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %347 = torch.aten.add.Tensor %346, %arg2, %arg95 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %348 = torch.aten.mul.Tensor %arg101, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %349 = torch.aten.add.Tensor %348, %343, %arg100 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %350 = torch.aten.addcdiv %arg16, %349, %347, %arg94 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_472 = torch.constant.int 0 - %351 = torch.prim.ListConstruct %int0_472 : (!torch.int) -> !torch.list - %true_473 = torch.constant.bool true - %none_474 = torch.constant.none - %352 = torch.aten.sum.dim_IntList %228, %351, %true_473, %none_474 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_475 = torch.constant.int 32 - %353 = torch.prim.ListConstruct %int32_475 : (!torch.int) -> !torch.list - %int32_476 = torch.constant.int 32 - %354 = torch.prim.ListConstruct %int32_476 : (!torch.int) -> !torch.list - %355 = torch.aten.reshape %352, %354 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %356 = torch.aten.add.Tensor %arg106, %355, %arg105 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %357 = torch.aten.mul.Tensor %arg107, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %358 = torch.aten.addcmul %357, %356, %356, %arg104 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %359 = torch.aten.sqrt %358 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %360 = torch.aten.add.Tensor %359, %arg2, %arg103 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %361 = torch.aten.mul.Tensor %arg109, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %362 = torch.aten.add.Tensor %361, %356, %arg108 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %363 = torch.aten.addcdiv %arg29, %362, %360, %arg102 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_477 = torch.constant.int 1024 - %int32_478 = torch.constant.int 32 - %364 = torch.prim.ListConstruct %int1024_477, %int32_478 : (!torch.int, !torch.int) -> !torch.list - %int1024_479 = torch.constant.int 1024 - %int32_480 = torch.constant.int 32 - %365 = torch.prim.ListConstruct %int1024_479, %int32_480 : (!torch.int, !torch.int) -> !torch.list - %366 = torch.aten.reshape %result0, %365 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %291 = torch.prim.ListConstruct %int2_459, %int512_460, %int32_461 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_462 = torch.constant.int 2 + %int512_463 = torch.constant.int 512 + %int32_464 = torch.constant.int 32 + %292 = torch.prim.ListConstruct %int2_462, %int512_463, %int32_464 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %293 = torch.aten.reshape %290, %292 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_465 = torch.constant.int 1024 + %int32_466 = torch.constant.int 32 + %294 = torch.prim.ListConstruct %int1024_465, %int32_466 : (!torch.int, !torch.int) -> !torch.list + %int1024_467 = torch.constant.int 1024 + %int32_468 = torch.constant.int 32 + %295 = torch.prim.ListConstruct %int1024_467, %int32_468 : (!torch.int, !torch.int) -> !torch.list + %296 = torch.aten.reshape %293, %295 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %297 = torch.aten.mm %296, %13 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_469 = torch.constant.int 2 + %int512_470 = torch.constant.int 512 + %int32_471 = torch.constant.int 32 + %298 = torch.prim.ListConstruct %int2_469, %int512_470, %int32_471 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_472 = torch.constant.int 2 + %int512_473 = torch.constant.int 512 + %int32_474 = torch.constant.int 32 + %299 = torch.prim.ListConstruct %int2_472, %int512_473, %int32_474 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %300 = torch.aten.reshape %297, %299 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1_475 = torch.constant.int 1 + %int0_476 = torch.constant.int 0 + %301 = torch.prim.ListConstruct %int1_475, %int0_476 : (!torch.int, !torch.int) -> !torch.list + %int1_477 = torch.constant.int 1 + %int0_478 = torch.constant.int 0 + %302 = torch.prim.ListConstruct %int1_477, %int0_478 : (!torch.int, !torch.int) -> !torch.list + %303 = torch.aten.permute %arg19, %302 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_479 = torch.constant.int 1 + %int0_480 = torch.constant.int 0 + %304 = torch.prim.ListConstruct %int1_479, %int0_480 : (!torch.int, !torch.int) -> !torch.list %int1_481 = torch.constant.int 1 %int0_482 = torch.constant.int 0 - %367 = torch.prim.ListConstruct %int1_481, %int0_482 : (!torch.int, !torch.int) -> !torch.list - %int1_483 = torch.constant.int 1 - %int0_484 = torch.constant.int 0 - %368 = torch.prim.ListConstruct %int1_483, %int0_484 : (!torch.int, !torch.int) -> !torch.list - %369 = torch.aten.permute %366, %368 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %370 = torch.aten.mm %369, %253 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_485 = torch.constant.int 1 - %int0_486 = torch.constant.int 0 - %371 = torch.prim.ListConstruct %int1_485, %int0_486 : (!torch.int, !torch.int) -> !torch.list - %int1_487 = torch.constant.int 1 - %int0_488 = torch.constant.int 0 - %372 = torch.prim.ListConstruct %int1_487, %int0_488 : (!torch.int, !torch.int) -> !torch.list - %373 = torch.aten.permute %370, %372 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %374 = torch.aten.add.Tensor %arg114, %373, %arg113 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %375 = torch.aten.mul.Tensor %arg115, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %376 = torch.aten.addcmul %375, %374, %374, %arg112 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %377 = torch.aten.sqrt %376 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %378 = torch.aten.add.Tensor %377, %arg2, %arg111 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %379 = torch.aten.mul.Tensor %arg117, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %380 = torch.aten.add.Tensor %379, %374, %arg116 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %381 = torch.aten.addcdiv %arg19, %380, %378, %arg110 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %305 = torch.prim.ListConstruct %int1_481, %int0_482 : (!torch.int, !torch.int) -> !torch.list + %306 = torch.aten.permute %303, %305 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int4_483 = torch.constant.int 4 + %int512_484 = torch.constant.int 512 + %int16_485 = torch.constant.int 16 + %307 = torch.prim.ListConstruct %int4_483, %int512_484, %int16_485 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_486 = torch.constant.int 4 + %int512_487 = torch.constant.int 512 + %int16_488 = torch.constant.int 16 + %308 = torch.prim.ListConstruct %int4_486, %int512_487, %int16_488 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %309 = torch.aten.reshape %72, %308 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> %int0_489 = torch.constant.int 0 - %382 = torch.prim.ListConstruct %int0_489 : (!torch.int) -> !torch.list - %true_490 = torch.constant.bool true - %none_491 = torch.constant.none - %383 = torch.aten.sum.dim_IntList %253, %382, %true_490, %none_491 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_492 = torch.constant.int 32 - %384 = torch.prim.ListConstruct %int32_492 : (!torch.int) -> !torch.list - %int32_493 = torch.constant.int 32 - %385 = torch.prim.ListConstruct %int32_493 : (!torch.int) -> !torch.list - %386 = torch.aten.reshape %383, %385 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %387 = torch.aten.add.Tensor %arg122, %386, %arg121 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %388 = torch.aten.mul.Tensor %arg123, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %389 = torch.aten.addcmul %388, %387, %387, %arg120 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %390 = torch.aten.sqrt %389 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %391 = torch.aten.add.Tensor %390, %arg2, %arg119 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %392 = torch.aten.mul.Tensor %arg125, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %393 = torch.aten.add.Tensor %392, %387, %arg124 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %394 = torch.aten.addcdiv %arg20, %393, %391, %arg118 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_494 = torch.constant.int 1024 - %int32_495 = torch.constant.int 32 - %395 = torch.prim.ListConstruct %int1024_494, %int32_495 : (!torch.int, !torch.int) -> !torch.list - %int1024_496 = torch.constant.int 1024 - %int32_497 = torch.constant.int 32 - %396 = torch.prim.ListConstruct %int1024_496, %int32_497 : (!torch.int, !torch.int) -> !torch.list - %397 = torch.aten.reshape %result0, %396 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_498 = torch.constant.int 1 - %int0_499 = torch.constant.int 0 - %398 = torch.prim.ListConstruct %int1_498, %int0_499 : (!torch.int, !torch.int) -> !torch.list - %int1_500 = torch.constant.int 1 + %int2_490 = torch.constant.int 2 + %int1_491 = torch.constant.int 1 + %310 = torch.prim.ListConstruct %int0_489, %int2_490, %int1_491 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int0_492 = torch.constant.int 0 + %int2_493 = torch.constant.int 2 + %int1_494 = torch.constant.int 1 + %311 = torch.prim.ListConstruct %int0_492, %int2_493, %int1_494 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %312 = torch.aten.permute %309, %311 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> + %313 = torch.aten.bmm %312, %283 : !torch.vtensor<[4,16,512],f32>, !torch.vtensor<[4,512,512],f32> -> !torch.vtensor<[4,16,512],f32> + %int2_495 = torch.constant.int 2 + %int16_496 = torch.constant.int 16 + %int512_497 = torch.constant.int 512 + %314 = torch.prim.ListConstruct %int2_495, %int2_495, %int16_496, %int512_497 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_498 = torch.constant.int 2 + %int16_499 = torch.constant.int 16 + %int512_500 = torch.constant.int 512 + %315 = torch.prim.ListConstruct %int2_498, %int2_498, %int16_499, %int512_500 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %316 = torch.aten.reshape %313, %315 : !torch.vtensor<[4,16,512],f32>, !torch.list -> !torch.vtensor<[2,2,16,512],f32> %int0_501 = torch.constant.int 0 - %399 = torch.prim.ListConstruct %int1_500, %int0_501 : (!torch.int, !torch.int) -> !torch.list - %400 = torch.aten.permute %397, %399 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %401 = torch.aten.mm %400, %277 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> %int1_502 = torch.constant.int 1 - %int0_503 = torch.constant.int 0 - %402 = torch.prim.ListConstruct %int1_502, %int0_503 : (!torch.int, !torch.int) -> !torch.list - %int1_504 = torch.constant.int 1 + %int3_503 = torch.constant.int 3 + %int2_504 = torch.constant.int 2 + %317 = torch.prim.ListConstruct %int0_501, %int1_502, %int3_503, %int2_504 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %int0_505 = torch.constant.int 0 - %403 = torch.prim.ListConstruct %int1_504, %int0_505 : (!torch.int, !torch.int) -> !torch.list - %404 = torch.aten.permute %401, %403 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %405 = torch.aten.add.Tensor %arg130, %404, %arg129 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %406 = torch.aten.mul.Tensor %arg131, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %407 = torch.aten.addcmul %406, %405, %405, %arg128 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %408 = torch.aten.sqrt %407 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %409 = torch.aten.add.Tensor %408, %arg2, %arg127 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %410 = torch.aten.mul.Tensor %arg133, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %411 = torch.aten.add.Tensor %410, %405, %arg132 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %412 = torch.aten.addcdiv %arg32, %411, %409, %arg126 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_506 = torch.constant.int 0 - %413 = torch.prim.ListConstruct %int0_506 : (!torch.int) -> !torch.list - %true_507 = torch.constant.bool true - %none_508 = torch.constant.none - %414 = torch.aten.sum.dim_IntList %277, %413, %true_507, %none_508 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_509 = torch.constant.int 32 - %415 = torch.prim.ListConstruct %int32_509 : (!torch.int) -> !torch.list - %int32_510 = torch.constant.int 32 - %416 = torch.prim.ListConstruct %int32_510 : (!torch.int) -> !torch.list - %417 = torch.aten.reshape %414, %416 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %418 = torch.aten.add.Tensor %arg138, %417, %arg137 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %419 = torch.aten.mul.Tensor %arg139, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %420 = torch.aten.addcmul %419, %418, %418, %arg136 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %421 = torch.aten.sqrt %420 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %422 = torch.aten.add.Tensor %421, %arg2, %arg135 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %423 = torch.aten.mul.Tensor %arg141, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %424 = torch.aten.add.Tensor %423, %418, %arg140 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %425 = torch.aten.addcdiv %arg33, %424, %422, %arg134 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int0_511 = torch.constant.int 0 - %int2_512 = torch.constant.int 2 - %int1_513 = torch.constant.int 1 - %int3_514 = torch.constant.int 3 - %426 = torch.prim.ListConstruct %int0_511, %int2_512, %int1_513, %int3_514 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_515 = torch.constant.int 0 - %int2_516 = torch.constant.int 2 - %int1_517 = torch.constant.int 1 - %int3_518 = torch.constant.int 3 - %427 = torch.prim.ListConstruct %int0_515, %int2_516, %int1_517, %int3_518 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %428 = torch.aten.permute %115, %427 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_519 = torch.constant.int 2 - %int512_520 = torch.constant.int 512 - %int32_521 = torch.constant.int 32 - %429 = torch.prim.ListConstruct %int2_519, %int512_520, %int32_521 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_522 = torch.constant.int 2 - %int512_523 = torch.constant.int 512 + %int1_506 = torch.constant.int 1 + %int3_507 = torch.constant.int 3 + %int2_508 = torch.constant.int 2 + %318 = torch.prim.ListConstruct %int0_505, %int1_506, %int3_507, %int2_508 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %319 = torch.aten.permute %316, %318 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int0_509 = torch.constant.int 0 + %int2_510 = torch.constant.int 2 + %int1_511 = torch.constant.int 1 + %int3_512 = torch.constant.int 3 + %320 = torch.prim.ListConstruct %int0_509, %int2_510, %int1_511, %int3_512 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_513 = torch.constant.int 0 + %int2_514 = torch.constant.int 2 + %int1_515 = torch.constant.int 1 + %int3_516 = torch.constant.int 3 + %321 = torch.prim.ListConstruct %int0_513, %int2_514, %int1_515, %int3_516 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %322 = torch.aten.permute %319, %321 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_517 = torch.constant.int 2 + %int512_518 = torch.constant.int 512 + %int32_519 = torch.constant.int 32 + %323 = torch.prim.ListConstruct %int2_517, %int512_518, %int32_519 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_520 = torch.constant.int 2 + %int512_521 = torch.constant.int 512 + %int32_522 = torch.constant.int 32 + %324 = torch.prim.ListConstruct %int2_520, %int512_521, %int32_522 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %325 = torch.aten.reshape %322, %324 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_523 = torch.constant.int 1024 %int32_524 = torch.constant.int 32 - %430 = torch.prim.ListConstruct %int2_522, %int512_523, %int32_524 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %431 = torch.aten.reshape %428, %430 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %326 = torch.prim.ListConstruct %int1024_523, %int32_524 : (!torch.int, !torch.int) -> !torch.list %int1024_525 = torch.constant.int 1024 %int32_526 = torch.constant.int 32 - %432 = torch.prim.ListConstruct %int1024_525, %int32_526 : (!torch.int, !torch.int) -> !torch.list - %int1024_527 = torch.constant.int 1024 - %int32_528 = torch.constant.int 32 - %433 = torch.prim.ListConstruct %int1024_527, %int32_528 : (!torch.int, !torch.int) -> !torch.list - %434 = torch.aten.reshape %431, %433 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_529 = torch.constant.int 1 - %int0_530 = torch.constant.int 0 - %435 = torch.prim.ListConstruct %int1_529, %int0_530 : (!torch.int, !torch.int) -> !torch.list - %int1_531 = torch.constant.int 1 - %int0_532 = torch.constant.int 0 - %436 = torch.prim.ListConstruct %int1_531, %int0_532 : (!torch.int, !torch.int) -> !torch.list - %437 = torch.aten.permute %434, %436 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %438 = torch.aten.mm %437, %193 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %327 = torch.prim.ListConstruct %int1024_525, %int32_526 : (!torch.int, !torch.int) -> !torch.list + %328 = torch.aten.reshape %325, %327 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %329 = torch.aten.mm %328, %306 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_527 = torch.constant.int 2 + %int512_528 = torch.constant.int 512 + %int32_529 = torch.constant.int 32 + %330 = torch.prim.ListConstruct %int2_527, %int512_528, %int32_529 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_530 = torch.constant.int 2 + %int512_531 = torch.constant.int 512 + %int32_532 = torch.constant.int 32 + %331 = torch.prim.ListConstruct %int2_530, %int512_531, %int32_532 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %332 = torch.aten.reshape %329, %331 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> %int1_533 = torch.constant.int 1 %int0_534 = torch.constant.int 0 - %439 = torch.prim.ListConstruct %int1_533, %int0_534 : (!torch.int, !torch.int) -> !torch.list + %333 = torch.prim.ListConstruct %int1_533, %int0_534 : (!torch.int, !torch.int) -> !torch.list %int1_535 = torch.constant.int 1 %int0_536 = torch.constant.int 0 - %440 = torch.prim.ListConstruct %int1_535, %int0_536 : (!torch.int, !torch.int) -> !torch.list - %441 = torch.aten.permute %438, %440 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %442 = torch.aten.add.Tensor %arg146, %441, %arg145 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %443 = torch.aten.mul.Tensor %arg147, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %444 = torch.aten.addcmul %443, %442, %442, %arg144 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %445 = torch.aten.sqrt %444 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %446 = torch.aten.add.Tensor %445, %arg2, %arg143 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %447 = torch.aten.mul.Tensor %arg149, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %448 = torch.aten.add.Tensor %447, %442, %arg148 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %449 = torch.aten.addcdiv %arg34, %448, %446, %arg142 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_537 = torch.constant.int 0 - %450 = torch.prim.ListConstruct %int0_537 : (!torch.int) -> !torch.list - %true_538 = torch.constant.bool true - %none_539 = torch.constant.none - %451 = torch.aten.sum.dim_IntList %193, %450, %true_538, %none_539 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_540 = torch.constant.int 32 - %452 = torch.prim.ListConstruct %int32_540 : (!torch.int) -> !torch.list - %int32_541 = torch.constant.int 32 - %453 = torch.prim.ListConstruct %int32_541 : (!torch.int) -> !torch.list - %454 = torch.aten.reshape %451, %453 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %455 = torch.aten.add.Tensor %arg154, %454, %arg153 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %456 = torch.aten.mul.Tensor %arg155, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %457 = torch.aten.addcmul %456, %455, %455, %arg152 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %458 = torch.aten.sqrt %457 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %459 = torch.aten.add.Tensor %458, %arg2, %arg151 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %460 = torch.aten.mul.Tensor %arg157, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %461 = torch.aten.add.Tensor %460, %455, %arg156 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %462 = torch.aten.addcdiv %arg40, %461, %459, %arg150 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %463 = torch.aten.add.Tensor %arg162, %result1_290, %arg161 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %464 = torch.aten.mul.Tensor %arg163, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %465 = torch.aten.addcmul %464, %463, %463, %arg160 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %466 = torch.aten.sqrt %465 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %467 = torch.aten.add.Tensor %466, %arg2, %arg159 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %468 = torch.aten.mul.Tensor %arg165, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %469 = torch.aten.add.Tensor %468, %463, %arg164 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %470 = torch.aten.addcdiv %arg36, %469, %467, %arg158 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %471 = torch.aten.add.Tensor %arg170, %result2_291, %arg169 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %472 = torch.aten.mul.Tensor %arg171, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %473 = torch.aten.addcmul %472, %471, %471, %arg168 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %474 = torch.aten.sqrt %473 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %475 = torch.aten.add.Tensor %474, %arg2, %arg167 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %476 = torch.aten.mul.Tensor %arg173, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %477 = torch.aten.add.Tensor %476, %471, %arg172 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %478 = torch.aten.addcdiv %arg35, %477, %475, %arg166 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_542 = torch.constant.int 1024 - %int32_543 = torch.constant.int 32 - %479 = torch.prim.ListConstruct %int1024_542, %int32_543 : (!torch.int, !torch.int) -> !torch.list - %int1024_544 = torch.constant.int 1024 - %int32_545 = torch.constant.int 32 - %480 = torch.prim.ListConstruct %int1024_544, %int32_545 : (!torch.int, !torch.int) -> !torch.list - %481 = torch.aten.reshape %result0_210, %480 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_546 = torch.constant.int 1 - %int0_547 = torch.constant.int 0 - %482 = torch.prim.ListConstruct %int1_546, %int0_547 : (!torch.int, !torch.int) -> !torch.list - %int1_548 = torch.constant.int 1 - %int0_549 = torch.constant.int 0 - %483 = torch.prim.ListConstruct %int1_548, %int0_549 : (!torch.int, !torch.int) -> !torch.list - %484 = torch.aten.permute %481, %483 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %485 = torch.aten.mm %484, %182 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %334 = torch.prim.ListConstruct %int1_535, %int0_536 : (!torch.int, !torch.int) -> !torch.list + %335 = torch.aten.permute %arg32, %334 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int1_537 = torch.constant.int 1 + %int0_538 = torch.constant.int 0 + %336 = torch.prim.ListConstruct %int1_537, %int0_538 : (!torch.int, !torch.int) -> !torch.list + %int1_539 = torch.constant.int 1 + %int0_540 = torch.constant.int 0 + %337 = torch.prim.ListConstruct %int1_539, %int0_540 : (!torch.int, !torch.int) -> !torch.list + %338 = torch.aten.permute %335, %337 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %int4_541 = torch.constant.int 4 + %int512_542 = torch.constant.int 512 + %339 = torch.prim.ListConstruct %int4_541, %int512_542, %int512_542 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int4_543 = torch.constant.int 4 + %int512_544 = torch.constant.int 512 + %340 = torch.prim.ListConstruct %int4_543, %int512_544, %int512_544 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %341 = torch.aten.reshape %120, %340 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> + %int0_545 = torch.constant.int 0 + %int2_546 = torch.constant.int 2 + %int1_547 = torch.constant.int 1 + %342 = torch.prim.ListConstruct %int0_545, %int2_546, %int1_547 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int0_548 = torch.constant.int 0 + %int2_549 = torch.constant.int 2 %int1_550 = torch.constant.int 1 - %int0_551 = torch.constant.int 0 - %486 = torch.prim.ListConstruct %int1_550, %int0_551 : (!torch.int, !torch.int) -> !torch.list - %int1_552 = torch.constant.int 1 - %int0_553 = torch.constant.int 0 - %487 = torch.prim.ListConstruct %int1_552, %int0_553 : (!torch.int, !torch.int) -> !torch.list - %488 = torch.aten.permute %485, %487 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %489 = torch.aten.add.Tensor %arg178, %488, %arg177 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %490 = torch.aten.mul.Tensor %arg179, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %491 = torch.aten.addcmul %490, %489, %489, %arg176 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %492 = torch.aten.sqrt %491 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %493 = torch.aten.add.Tensor %492, %arg2, %arg175 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %494 = torch.aten.mul.Tensor %arg181, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %495 = torch.aten.add.Tensor %494, %489, %arg180 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %496 = torch.aten.addcdiv %arg42, %495, %493, %arg174 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_554 = torch.constant.int 0 - %497 = torch.prim.ListConstruct %int0_554 : (!torch.int) -> !torch.list - %true_555 = torch.constant.bool true - %none_556 = torch.constant.none - %498 = torch.aten.sum.dim_IntList %182, %497, %true_555, %none_556 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_557 = torch.constant.int 32 - %499 = torch.prim.ListConstruct %int32_557 : (!torch.int) -> !torch.list - %int32_558 = torch.constant.int 32 - %500 = torch.prim.ListConstruct %int32_558 : (!torch.int) -> !torch.list - %501 = torch.aten.reshape %498, %500 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %502 = torch.aten.add.Tensor %arg186, %501, %arg185 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %503 = torch.aten.mul.Tensor %arg187, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %504 = torch.aten.addcmul %503, %502, %502, %arg184 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %505 = torch.aten.sqrt %504 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %506 = torch.aten.add.Tensor %505, %arg2, %arg183 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %507 = torch.aten.mul.Tensor %arg189, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %508 = torch.aten.add.Tensor %507, %502, %arg188 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %509 = torch.aten.addcdiv %arg45, %508, %506, %arg182 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_559 = torch.constant.int 1024 - %int32_560 = torch.constant.int 32 - %510 = torch.prim.ListConstruct %int1024_559, %int32_560 : (!torch.int, !torch.int) -> !torch.list - %int1024_561 = torch.constant.int 1024 - %int32_562 = torch.constant.int 32 - %511 = torch.prim.ListConstruct %int1024_561, %int32_562 : (!torch.int, !torch.int) -> !torch.list - %512 = torch.aten.reshape %157, %511 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %343 = torch.prim.ListConstruct %int0_548, %int2_549, %int1_550 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %344 = torch.aten.permute %341, %343 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> + %345 = torch.aten.bmm %344, %274 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> + %int2_551 = torch.constant.int 2 + %int512_552 = torch.constant.int 512 + %int16_553 = torch.constant.int 16 + %346 = torch.prim.ListConstruct %int2_551, %int2_551, %int512_552, %int16_553 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_554 = torch.constant.int 2 + %int512_555 = torch.constant.int 512 + %int16_556 = torch.constant.int 16 + %347 = torch.prim.ListConstruct %int2_554, %int2_554, %int512_555, %int16_556 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %348 = torch.aten.reshape %345, %347 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int0_557 = torch.constant.int 0 + %int2_558 = torch.constant.int 2 + %int1_559 = torch.constant.int 1 + %int3_560 = torch.constant.int 3 + %349 = torch.prim.ListConstruct %int0_557, %int2_558, %int1_559, %int3_560 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_561 = torch.constant.int 0 + %int2_562 = torch.constant.int 2 %int1_563 = torch.constant.int 1 - %int0_564 = torch.constant.int 0 - %513 = torch.prim.ListConstruct %int1_563, %int0_564 : (!torch.int, !torch.int) -> !torch.list - %int1_565 = torch.constant.int 1 - %int0_566 = torch.constant.int 0 - %514 = torch.prim.ListConstruct %int1_565, %int0_566 : (!torch.int, !torch.int) -> !torch.list - %515 = torch.aten.permute %512, %514 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %516 = torch.aten.mm %515, %174 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_567 = torch.constant.int 1 - %int0_568 = torch.constant.int 0 - %517 = torch.prim.ListConstruct %int1_567, %int0_568 : (!torch.int, !torch.int) -> !torch.list - %int1_569 = torch.constant.int 1 - %int0_570 = torch.constant.int 0 - %518 = torch.prim.ListConstruct %int1_569, %int0_570 : (!torch.int, !torch.int) -> !torch.list - %519 = torch.aten.permute %516, %518 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %520 = torch.aten.add.Tensor %arg194, %519, %arg193 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %521 = torch.aten.mul.Tensor %arg195, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %522 = torch.aten.addcmul %521, %520, %520, %arg192 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %523 = torch.aten.sqrt %522 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %524 = torch.aten.add.Tensor %523, %arg2, %arg191 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %525 = torch.aten.mul.Tensor %arg197, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %526 = torch.aten.add.Tensor %525, %520, %arg196 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %527 = torch.aten.addcdiv %arg46, %526, %524, %arg190 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_571 = torch.constant.int 0 - %528 = torch.prim.ListConstruct %int0_571 : (!torch.int) -> !torch.list - %true_572 = torch.constant.bool true - %none_573 = torch.constant.none - %529 = torch.aten.sum.dim_IntList %174, %528, %true_572, %none_573 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int3_564 = torch.constant.int 3 + %350 = torch.prim.ListConstruct %int0_561, %int2_562, %int1_563, %int3_564 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %351 = torch.aten.permute %348, %350 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_565 = torch.constant.int 2 + %int512_566 = torch.constant.int 512 + %int32_567 = torch.constant.int 32 + %352 = torch.prim.ListConstruct %int2_565, %int512_566, %int32_567 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_568 = torch.constant.int 2 + %int512_569 = torch.constant.int 512 + %int32_570 = torch.constant.int 32 + %353 = torch.prim.ListConstruct %int2_568, %int512_569, %int32_570 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %354 = torch.aten.reshape %351, %353 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_571 = torch.constant.int 1024 + %int32_572 = torch.constant.int 32 + %355 = torch.prim.ListConstruct %int1024_571, %int32_572 : (!torch.int, !torch.int) -> !torch.list + %int1024_573 = torch.constant.int 1024 %int32_574 = torch.constant.int 32 - %530 = torch.prim.ListConstruct %int32_574 : (!torch.int) -> !torch.list - %int32_575 = torch.constant.int 32 - %531 = torch.prim.ListConstruct %int32_575 : (!torch.int) -> !torch.list - %532 = torch.aten.reshape %529, %531 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %533 = torch.aten.add.Tensor %arg202, %532, %arg201 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %534 = torch.aten.mul.Tensor %arg203, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %535 = torch.aten.addcmul %534, %533, %533, %arg200 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %536 = torch.aten.sqrt %535 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %537 = torch.aten.add.Tensor %536, %arg2, %arg199 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %538 = torch.aten.mul.Tensor %arg205, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %539 = torch.aten.add.Tensor %538, %533, %arg204 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %540 = torch.aten.addcdiv %arg52, %539, %537, %arg198 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %541 = torch.aten.add.Tensor %arg210, %result1_264, %arg209 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %542 = torch.aten.mul.Tensor %arg211, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %543 = torch.aten.addcmul %542, %541, %541, %arg208 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %544 = torch.aten.sqrt %543 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %545 = torch.aten.add.Tensor %544, %arg2, %arg207 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %546 = torch.aten.mul.Tensor %arg213, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %547 = torch.aten.add.Tensor %546, %541, %arg212 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %548 = torch.aten.addcdiv %arg48, %547, %545, %arg206 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %549 = torch.aten.add.Tensor %arg218, %result2_265, %arg217 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %550 = torch.aten.mul.Tensor %arg219, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %551 = torch.aten.addcmul %550, %549, %549, %arg216 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %552 = torch.aten.sqrt %551 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %553 = torch.aten.add.Tensor %552, %arg2, %arg215 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %554 = torch.aten.mul.Tensor %arg221, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %555 = torch.aten.add.Tensor %554, %549, %arg220 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %556 = torch.aten.addcdiv %arg47, %555, %553, %arg214 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1_576 = torch.constant.int 1 - %int0_577 = torch.constant.int 0 - %557 = torch.prim.ListConstruct %int1_576, %int0_577 : (!torch.int, !torch.int) -> !torch.list - %int1_578 = torch.constant.int 1 - %int0_579 = torch.constant.int 0 - %558 = torch.prim.ListConstruct %int1_578, %int0_579 : (!torch.int, !torch.int) -> !torch.list - %559 = torch.aten.permute %arg228, %558 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int0_580 = torch.constant.int 0 - %int0_581 = torch.constant.int 0 - %int9223372036854775807_582 = torch.constant.int 9223372036854775807 - %int1_583 = torch.constant.int 1 - %560 = torch.aten.slice.Tensor %result0_259, %int0_580, %int0_581, %int9223372036854775807_582, %int1_583 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int1_584 = torch.constant.int 1 - %int0_585 = torch.constant.int 0 - %561 = torch.aten.select.int %560, %int1_584, %int0_585 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,32],f32> - %562 = torch.aten.addmm %arg229, %561, %559, %arg227, %arg226 : !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,32],f32> - %563 = torch.aten.tanh %562 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %int1_586 = torch.constant.int 1 - %int0_587 = torch.constant.int 0 - %564 = torch.prim.ListConstruct %int1_586, %int0_587 : (!torch.int, !torch.int) -> !torch.list - %int1_588 = torch.constant.int 1 - %int0_589 = torch.constant.int 0 - %565 = torch.prim.ListConstruct %int1_588, %int0_589 : (!torch.int, !torch.int) -> !torch.list - %566 = torch.aten.permute %arg230, %565 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %int1_590 = torch.constant.int 1 - %int0_591 = torch.constant.int 0 - %567 = torch.prim.ListConstruct %int1_590, %int0_591 : (!torch.int, !torch.int) -> !torch.list - %int1_592 = torch.constant.int 1 - %int0_593 = torch.constant.int 0 - %568 = torch.prim.ListConstruct %int1_592, %int0_593 : (!torch.int, !torch.int) -> !torch.list - %569 = torch.aten.permute %566, %568 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> - %int1_594 = torch.constant.int 1 - %int0_595 = torch.constant.int 0 - %570 = torch.prim.ListConstruct %int1_594, %int0_595 : (!torch.int, !torch.int) -> !torch.list - %int1_596 = torch.constant.int 1 - %int0_597 = torch.constant.int 0 - %571 = torch.prim.ListConstruct %int1_596, %int0_597 : (!torch.int, !torch.int) -> !torch.list - %572 = torch.aten.permute %arg230, %571 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %573 = torch.aten.addmm %arg233, %563, %572, %arg232, %arg231 : !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> - %int2_598 = torch.constant.int 2 - %574 = torch.prim.ListConstruct %int2_598, %int2_598 : (!torch.int, !torch.int) -> !torch.list - %int2_599 = torch.constant.int 2 - %575 = torch.prim.ListConstruct %int2_599, %int2_599 : (!torch.int, !torch.int) -> !torch.list - %576 = torch.aten.reshape %573, %575 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> - %int1_600 = torch.constant.int 1 - %false_601 = torch.constant.bool false - %577 = torch.aten._log_softmax %576, %int1_600, %false_601 : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2],f32> - %int2_602 = torch.constant.int 2 - %578 = torch.prim.ListConstruct %int2_602 : (!torch.int) -> !torch.list - %int2_603 = torch.constant.int 2 - %579 = torch.prim.ListConstruct %int2_603 : (!torch.int) -> !torch.list - %580 = torch.aten.reshape %arg234, %579 : !torch.vtensor<[2],si64>, !torch.list -> !torch.vtensor<[2],si64> - %none_604 = torch.constant.none + %356 = torch.prim.ListConstruct %int1024_573, %int32_574 : (!torch.int, !torch.int) -> !torch.list + %357 = torch.aten.reshape %354, %356 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %358 = torch.aten.mm %357, %338 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> + %int2_575 = torch.constant.int 2 + %int512_576 = torch.constant.int 512 + %int32_577 = torch.constant.int 32 + %359 = torch.prim.ListConstruct %int2_575, %int512_576, %int32_577 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_578 = torch.constant.int 2 + %int512_579 = torch.constant.int 512 + %int32_580 = torch.constant.int 32 + %360 = torch.prim.ListConstruct %int2_578, %int512_579, %int32_580 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %361 = torch.aten.reshape %358, %360 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %362 = torch.aten.add.Tensor %result0_402, %361, %arg66 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %363 = torch.aten.add.Tensor %362, %332, %arg65 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %364 = torch.aten.add.Tensor %363, %300, %arg15 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int32_581 = torch.constant.int 32 + %365 = torch.prim.ListConstruct %int32_581 : (!torch.int) -> !torch.list + %true_582 = torch.constant.bool true + %366 = torch.prim.ListConstruct %true_582, %true_582, %true_582 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0_583, %result1_584, %result2_585 = torch.aten.native_layer_norm_backward %364, %5, %365, %result1, %result2, %arg7, %arg6, %366 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> + %367 = torch.prim.TupleConstruct %result0_583, %result1_584, %result2_585 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> + %int28996 = torch.constant.int 28996 + %int0_586 = torch.constant.int 0 + %false_587 = torch.constant.bool false + %368 = torch.aten.embedding_dense_backward %result0_583, %arg5, %int28996, %int0_586, %false_587 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[28996,32],f32> + %369 = torch.aten.add.Tensor %arg67, %368, %arg4 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.int -> !torch.vtensor<[28996,32],f32> + %370 = torch.aten.mul.Tensor %arg69, %arg68 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> + %371 = torch.aten.addcmul %370, %369, %369, %arg3 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> + %372 = torch.aten.sqrt %371 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> + %373 = torch.aten.add.Tensor %372, %arg2, %arg1 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[28996,32],f32> + %374 = torch.aten.mul.Tensor %arg72, %arg71 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> + %375 = torch.aten.add.Tensor %374, %369, %arg70 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> + %376 = torch.aten.addcdiv %arg14, %375, %373, %arg0 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> + %int0_588 = torch.constant.int 0 + %377 = torch.prim.ListConstruct %int0_588 : (!torch.int) -> !torch.list + %true_589 = torch.constant.bool true + %none_590 = torch.constant.none + %378 = torch.aten.sum.dim_IntList %result0_583, %377, %true_589, %none_590 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,512,32],f32> + %int512_591 = torch.constant.int 512 + %int-1_592 = torch.constant.int -1 + %false_593 = torch.constant.bool false + %379 = torch.aten.embedding_dense_backward %378, %0, %int512_591, %int-1_592, %false_593 : !torch.vtensor<[1,512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[512,32],f32> + %380 = torch.aten.add.Tensor %arg77, %379, %arg76 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.int -> !torch.vtensor<[512,32],f32> + %381 = torch.aten.mul.Tensor %arg78, %arg68 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> + %382 = torch.aten.addcmul %381, %380, %380, %arg75 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> + %383 = torch.aten.sqrt %382 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> + %384 = torch.aten.add.Tensor %383, %arg2, %arg74 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[512,32],f32> + %385 = torch.aten.mul.Tensor %arg80, %arg71 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> + %386 = torch.aten.add.Tensor %385, %380, %arg79 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> + %387 = torch.aten.addcdiv %arg10, %386, %384, %arg73 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> + %int2_594 = torch.constant.int 2 + %int-1_595 = torch.constant.int -1 + %false_596 = torch.constant.bool false + %388 = torch.aten.embedding_dense_backward %result0_583, %arg12, %int2_594, %int-1_595, %false_596 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[2,32],f32> + %389 = torch.aten.add.Tensor %arg85, %388, %arg84 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> + %390 = torch.aten.mul.Tensor %arg86, %arg68 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %391 = torch.aten.addcmul %390, %389, %389, %arg83 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %392 = torch.aten.sqrt %391 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %393 = torch.aten.add.Tensor %392, %arg2, %arg82 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> + %394 = torch.aten.mul.Tensor %arg88, %arg71 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %395 = torch.aten.add.Tensor %394, %389, %arg87 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %396 = torch.aten.addcdiv %arg13, %395, %393, %arg81 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %397 = torch.aten.add.Tensor %arg93, %result1_584, %arg92 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %398 = torch.aten.mul.Tensor %arg94, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %399 = torch.aten.addcmul %398, %397, %397, %arg91 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %400 = torch.aten.sqrt %399 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %401 = torch.aten.add.Tensor %400, %arg2, %arg90 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %402 = torch.aten.mul.Tensor %arg96, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %403 = torch.aten.add.Tensor %402, %397, %arg95 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %404 = torch.aten.addcdiv %arg7, %403, %401, %arg89 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %405 = torch.aten.add.Tensor %arg101, %result2_585, %arg100 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %406 = torch.aten.mul.Tensor %arg102, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %407 = torch.aten.addcmul %406, %405, %405, %arg99 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %408 = torch.aten.sqrt %407 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %409 = torch.aten.add.Tensor %408, %arg2, %arg98 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %410 = torch.aten.mul.Tensor %arg104, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %411 = torch.aten.add.Tensor %410, %405, %arg103 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %412 = torch.aten.addcdiv %arg6, %411, %409, %arg97 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_597 = torch.constant.int 1024 + %int32_598 = torch.constant.int 32 + %413 = torch.prim.ListConstruct %int1024_597, %int32_598 : (!torch.int, !torch.int) -> !torch.list + %int1024_599 = torch.constant.int 1024 + %int32_600 = torch.constant.int 32 + %414 = torch.prim.ListConstruct %int1024_599, %int32_600 : (!torch.int, !torch.int) -> !torch.list + %415 = torch.aten.reshape %result0, %414 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_601 = torch.constant.int 1 + %int0_602 = torch.constant.int 0 + %416 = torch.prim.ListConstruct %int1_601, %int0_602 : (!torch.int, !torch.int) -> !torch.list + %int1_603 = torch.constant.int 1 + %int0_604 = torch.constant.int 0 + %417 = torch.prim.ListConstruct %int1_603, %int0_604 : (!torch.int, !torch.int) -> !torch.list + %418 = torch.aten.permute %415, %417 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %419 = torch.aten.mm %418, %296 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> %int1_605 = torch.constant.int 1 - %int-100 = torch.constant.int -100 - %output, %total_weight = torch.aten.nll_loss_forward %577, %580, %none_604, %int1_605, %int-100 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> - %581 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.tuple, vtensor<[],f32>> - %none_606 = torch.constant.none + %int0_606 = torch.constant.int 0 + %420 = torch.prim.ListConstruct %int1_605, %int0_606 : (!torch.int, !torch.int) -> !torch.list %int1_607 = torch.constant.int 1 - %int-100_608 = torch.constant.int -100 - %582 = torch.aten.nll_loss_backward %arg235, %577, %580, %none_606, %int1_607, %int-100_608, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[2,2],f32> - %int1_609 = torch.constant.int 1 - %int6_610 = torch.constant.int 6 - %583 = torch.aten._log_softmax_backward_data %582, %577, %int1_609, %int6_610 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> - %int2_611 = torch.constant.int 2 - %584 = torch.prim.ListConstruct %int2_611, %int2_611 : (!torch.int, !torch.int) -> !torch.list - %int2_612 = torch.constant.int 2 - %585 = torch.prim.ListConstruct %int2_612, %int2_612 : (!torch.int, !torch.int) -> !torch.list - %586 = torch.aten.reshape %583, %585 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> - %587 = torch.aten.mm %586, %569 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %588 = torch.aten.tanh_backward %587, %563 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %int1_613 = torch.constant.int 1 - %int0_614 = torch.constant.int 0 - %589 = torch.prim.ListConstruct %int1_613, %int0_614 : (!torch.int, !torch.int) -> !torch.list - %int1_615 = torch.constant.int 1 - %int0_616 = torch.constant.int 0 - %590 = torch.prim.ListConstruct %int1_615, %int0_616 : (!torch.int, !torch.int) -> !torch.list - %591 = torch.aten.permute %561, %590 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %592 = torch.aten.mm %591, %588 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_617 = torch.constant.int 1 - %int0_618 = torch.constant.int 0 - %593 = torch.prim.ListConstruct %int1_617, %int0_618 : (!torch.int, !torch.int) -> !torch.list - %int1_619 = torch.constant.int 1 - %int0_620 = torch.constant.int 0 - %594 = torch.prim.ListConstruct %int1_619, %int0_620 : (!torch.int, !torch.int) -> !torch.list - %595 = torch.aten.permute %592, %594 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %596 = torch.aten.add.Tensor %arg236, %595, %arg225 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %597 = torch.aten.mul.Tensor %arg237, %arg57 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %598 = torch.aten.addcmul %597, %596, %596, %arg224 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %599 = torch.aten.sqrt %598 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %600 = torch.aten.add.Tensor %599, %arg2, %arg223 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %601 = torch.aten.mul.Tensor %arg239, %arg60 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %602 = torch.aten.add.Tensor %601, %596, %arg238 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %603 = torch.aten.addcdiv %arg228, %602, %600, %arg222 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_608 = torch.constant.int 0 + %421 = torch.prim.ListConstruct %int1_607, %int0_608 : (!torch.int, !torch.int) -> !torch.list + %422 = torch.aten.permute %419, %421 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %423 = torch.aten.add.Tensor %arg109, %422, %arg108 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %424 = torch.aten.mul.Tensor %arg110, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %425 = torch.aten.addcmul %424, %423, %423, %arg107 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %426 = torch.aten.sqrt %425 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %427 = torch.aten.add.Tensor %426, %arg2, %arg106 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %428 = torch.aten.mul.Tensor %arg112, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %429 = torch.aten.add.Tensor %428, %423, %arg111 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %430 = torch.aten.addcdiv %arg16, %429, %427, %arg105 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_609 = torch.constant.int 0 + %431 = torch.prim.ListConstruct %int0_609 : (!torch.int) -> !torch.list + %true_610 = torch.constant.bool true + %none_611 = torch.constant.none + %432 = torch.aten.sum.dim_IntList %296, %431, %true_610, %none_611 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_612 = torch.constant.int 32 + %433 = torch.prim.ListConstruct %int32_612 : (!torch.int) -> !torch.list + %int32_613 = torch.constant.int 32 + %434 = torch.prim.ListConstruct %int32_613 : (!torch.int) -> !torch.list + %435 = torch.aten.reshape %432, %434 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %436 = torch.aten.add.Tensor %arg117, %435, %arg116 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %437 = torch.aten.mul.Tensor %arg118, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %438 = torch.aten.addcmul %437, %436, %436, %arg115 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %439 = torch.aten.sqrt %438 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %440 = torch.aten.add.Tensor %439, %arg2, %arg114 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %441 = torch.aten.mul.Tensor %arg120, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %442 = torch.aten.add.Tensor %441, %436, %arg119 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %443 = torch.aten.addcdiv %arg29, %442, %440, %arg113 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_614 = torch.constant.int 1024 + %int32_615 = torch.constant.int 32 + %444 = torch.prim.ListConstruct %int1024_614, %int32_615 : (!torch.int, !torch.int) -> !torch.list + %int1024_616 = torch.constant.int 1024 + %int32_617 = torch.constant.int 32 + %445 = torch.prim.ListConstruct %int1024_616, %int32_617 : (!torch.int, !torch.int) -> !torch.list + %446 = torch.aten.reshape %result0, %445 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_618 = torch.constant.int 1 + %int0_619 = torch.constant.int 0 + %447 = torch.prim.ListConstruct %int1_618, %int0_619 : (!torch.int, !torch.int) -> !torch.list + %int1_620 = torch.constant.int 1 %int0_621 = torch.constant.int 0 - %604 = torch.prim.ListConstruct %int0_621 : (!torch.int) -> !torch.list - %true_622 = torch.constant.bool true - %none_623 = torch.constant.none - %605 = torch.aten.sum.dim_IntList %588, %604, %true_622, %none_623 : !torch.vtensor<[2,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_624 = torch.constant.int 32 - %606 = torch.prim.ListConstruct %int32_624 : (!torch.int) -> !torch.list - %int32_625 = torch.constant.int 32 - %607 = torch.prim.ListConstruct %int32_625 : (!torch.int) -> !torch.list - %608 = torch.aten.reshape %605, %607 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %609 = torch.aten.add.Tensor %arg244, %608, %arg243 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %610 = torch.aten.mul.Tensor %arg245, %arg57 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %611 = torch.aten.addcmul %610, %609, %609, %arg242 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %612 = torch.aten.sqrt %611 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %613 = torch.aten.add.Tensor %612, %arg2, %arg241 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %614 = torch.aten.mul.Tensor %arg247, %arg60 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %615 = torch.aten.add.Tensor %614, %609, %arg246 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %616 = torch.aten.addcdiv %arg229, %615, %613, %arg240 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1_626 = torch.constant.int 1 - %int0_627 = torch.constant.int 0 - %617 = torch.prim.ListConstruct %int1_626, %int0_627 : (!torch.int, !torch.int) -> !torch.list - %int1_628 = torch.constant.int 1 - %int0_629 = torch.constant.int 0 - %618 = torch.prim.ListConstruct %int1_628, %int0_629 : (!torch.int, !torch.int) -> !torch.list - %619 = torch.aten.permute %563, %618 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %620 = torch.aten.mm %619, %586 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[32,2],f32> - %int1_630 = torch.constant.int 1 - %int0_631 = torch.constant.int 0 - %621 = torch.prim.ListConstruct %int1_630, %int0_631 : (!torch.int, !torch.int) -> !torch.list - %int1_632 = torch.constant.int 1 - %int0_633 = torch.constant.int 0 - %622 = torch.prim.ListConstruct %int1_632, %int0_633 : (!torch.int, !torch.int) -> !torch.list - %623 = torch.aten.permute %620, %622 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> - %624 = torch.aten.add.Tensor %arg252, %623, %arg251 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> - %625 = torch.aten.mul.Tensor %arg253, %arg57 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %626 = torch.aten.addcmul %625, %624, %624, %arg250 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %627 = torch.aten.sqrt %626 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %628 = torch.aten.add.Tensor %627, %arg2, %arg249 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> - %629 = torch.aten.mul.Tensor %arg255, %arg60 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %630 = torch.aten.add.Tensor %629, %624, %arg254 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %631 = torch.aten.addcdiv %arg230, %630, %628, %arg248 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %int0_634 = torch.constant.int 0 - %632 = torch.prim.ListConstruct %int0_634 : (!torch.int) -> !torch.list - %true_635 = torch.constant.bool true - %none_636 = torch.constant.none - %633 = torch.aten.sum.dim_IntList %586, %632, %true_635, %none_636 : !torch.vtensor<[2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,2],f32> - %int2_637 = torch.constant.int 2 - %634 = torch.prim.ListConstruct %int2_637 : (!torch.int) -> !torch.list - %int2_638 = torch.constant.int 2 - %635 = torch.prim.ListConstruct %int2_638 : (!torch.int) -> !torch.list - %636 = torch.aten.reshape %633, %635 : !torch.vtensor<[1,2],f32>, !torch.list -> !torch.vtensor<[2],f32> - %637 = torch.aten.add.Tensor %arg260, %636, %arg259 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32> - %638 = torch.aten.mul.Tensor %arg261, %arg57 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> - %639 = torch.aten.addcmul %638, %637, %637, %arg258 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> - %640 = torch.aten.sqrt %639 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> - %641 = torch.aten.add.Tensor %640, %arg2, %arg257 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2],f32> - %642 = torch.aten.mul.Tensor %arg263, %arg60 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> - %643 = torch.aten.add.Tensor %642, %637, %arg262 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> - %644 = torch.aten.addcdiv %arg233, %643, %641, %arg256 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> - %645 = torch.aten.zero.functional %637 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> - %646 = torch.aten.zero.functional %624 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %647 = torch.aten.zero.functional %609 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %648 = torch.aten.zero.functional %596 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %649 = torch.aten.zero.functional %541 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %650 = torch.aten.zero.functional %549 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %651 = torch.aten.zero.functional %533 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %652 = torch.aten.zero.functional %520 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %653 = torch.aten.zero.functional %502 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %654 = torch.aten.zero.functional %489 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %655 = torch.aten.zero.functional %463 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %656 = torch.aten.zero.functional %471 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %657 = torch.aten.zero.functional %455 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %658 = torch.aten.zero.functional %442 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %659 = torch.aten.zero.functional %418 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %660 = torch.aten.zero.functional %405 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %661 = torch.aten.zero.functional %387 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %662 = torch.aten.zero.functional %374 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %663 = torch.aten.zero.functional %356 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %664 = torch.aten.zero.functional %343 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %665 = torch.aten.zero.functional %317 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %666 = torch.aten.zero.functional %325 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %667 = torch.aten.zero.functional %300 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> - %668 = torch.aten.zero.functional %309 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %669 = torch.aten.zero.functional %289 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> - return %296, %307, %316, %324, %332, %350, %363, %381, %394, %412, %425, %449, %462, %470, %478, %496, %509, %527, %540, %548, %556, %603, %616, %631, %644, %645, %646, %647, %648, %649, %650, %651, %652, %653, %654, %655, %656, %657, %658, %659, %660, %661, %662, %663, %664, %665, %666, %667, %668, %669, %295, %291, %306, %302, %315, %311, %323, %319, %331, %327, %349, %345, %362, %358, %380, %376, %393, %389, %411, %407, %424, %420, %448, %444, %461, %457, %469, %465, %477, %473, %495, %491, %508, %504, %526, %522, %539, %535, %547, %543, %555, %551, %602, %598, %615, %611, %630, %626, %643, %639, %arg234, %arg5, %arg12, %arg26, %573, %output : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32> + %448 = torch.prim.ListConstruct %int1_620, %int0_621 : (!torch.int, !torch.int) -> !torch.list + %449 = torch.aten.permute %446, %448 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %450 = torch.aten.mm %449, %328 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_622 = torch.constant.int 1 + %int0_623 = torch.constant.int 0 + %451 = torch.prim.ListConstruct %int1_622, %int0_623 : (!torch.int, !torch.int) -> !torch.list + %int1_624 = torch.constant.int 1 + %int0_625 = torch.constant.int 0 + %452 = torch.prim.ListConstruct %int1_624, %int0_625 : (!torch.int, !torch.int) -> !torch.list + %453 = torch.aten.permute %450, %452 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %454 = torch.aten.add.Tensor %arg125, %453, %arg124 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %455 = torch.aten.mul.Tensor %arg126, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %456 = torch.aten.addcmul %455, %454, %454, %arg123 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %457 = torch.aten.sqrt %456 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %458 = torch.aten.add.Tensor %457, %arg2, %arg122 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %459 = torch.aten.mul.Tensor %arg128, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %460 = torch.aten.add.Tensor %459, %454, %arg127 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %461 = torch.aten.addcdiv %arg19, %460, %458, %arg121 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_626 = torch.constant.int 0 + %462 = torch.prim.ListConstruct %int0_626 : (!torch.int) -> !torch.list + %true_627 = torch.constant.bool true + %none_628 = torch.constant.none + %463 = torch.aten.sum.dim_IntList %328, %462, %true_627, %none_628 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_629 = torch.constant.int 32 + %464 = torch.prim.ListConstruct %int32_629 : (!torch.int) -> !torch.list + %int32_630 = torch.constant.int 32 + %465 = torch.prim.ListConstruct %int32_630 : (!torch.int) -> !torch.list + %466 = torch.aten.reshape %463, %465 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %467 = torch.aten.add.Tensor %arg133, %466, %arg132 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %468 = torch.aten.mul.Tensor %arg134, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %469 = torch.aten.addcmul %468, %467, %467, %arg131 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %470 = torch.aten.sqrt %469 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %471 = torch.aten.add.Tensor %470, %arg2, %arg130 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %472 = torch.aten.mul.Tensor %arg136, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %473 = torch.aten.add.Tensor %472, %467, %arg135 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %474 = torch.aten.addcdiv %arg20, %473, %471, %arg129 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_631 = torch.constant.int 1024 + %int32_632 = torch.constant.int 32 + %475 = torch.prim.ListConstruct %int1024_631, %int32_632 : (!torch.int, !torch.int) -> !torch.list + %int1024_633 = torch.constant.int 1024 + %int32_634 = torch.constant.int 32 + %476 = torch.prim.ListConstruct %int1024_633, %int32_634 : (!torch.int, !torch.int) -> !torch.list + %477 = torch.aten.reshape %result0, %476 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_635 = torch.constant.int 1 + %int0_636 = torch.constant.int 0 + %478 = torch.prim.ListConstruct %int1_635, %int0_636 : (!torch.int, !torch.int) -> !torch.list + %int1_637 = torch.constant.int 1 + %int0_638 = torch.constant.int 0 + %479 = torch.prim.ListConstruct %int1_637, %int0_638 : (!torch.int, !torch.int) -> !torch.list + %480 = torch.aten.permute %477, %479 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %481 = torch.aten.mm %480, %357 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_639 = torch.constant.int 1 + %int0_640 = torch.constant.int 0 + %482 = torch.prim.ListConstruct %int1_639, %int0_640 : (!torch.int, !torch.int) -> !torch.list + %int1_641 = torch.constant.int 1 + %int0_642 = torch.constant.int 0 + %483 = torch.prim.ListConstruct %int1_641, %int0_642 : (!torch.int, !torch.int) -> !torch.list + %484 = torch.aten.permute %481, %483 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %485 = torch.aten.add.Tensor %arg141, %484, %arg140 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %486 = torch.aten.mul.Tensor %arg142, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %487 = torch.aten.addcmul %486, %485, %485, %arg139 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %488 = torch.aten.sqrt %487 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %489 = torch.aten.add.Tensor %488, %arg2, %arg138 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %490 = torch.aten.mul.Tensor %arg144, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %491 = torch.aten.add.Tensor %490, %485, %arg143 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %492 = torch.aten.addcdiv %arg32, %491, %489, %arg137 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_643 = torch.constant.int 0 + %493 = torch.prim.ListConstruct %int0_643 : (!torch.int) -> !torch.list + %true_644 = torch.constant.bool true + %none_645 = torch.constant.none + %494 = torch.aten.sum.dim_IntList %357, %493, %true_644, %none_645 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_646 = torch.constant.int 32 + %495 = torch.prim.ListConstruct %int32_646 : (!torch.int) -> !torch.list + %int32_647 = torch.constant.int 32 + %496 = torch.prim.ListConstruct %int32_647 : (!torch.int) -> !torch.list + %497 = torch.aten.reshape %494, %496 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %498 = torch.aten.add.Tensor %arg149, %497, %arg148 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %499 = torch.aten.mul.Tensor %arg150, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %500 = torch.aten.addcmul %499, %498, %498, %arg147 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %501 = torch.aten.sqrt %500 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %502 = torch.aten.add.Tensor %501, %arg2, %arg146 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %503 = torch.aten.mul.Tensor %arg152, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %504 = torch.aten.add.Tensor %503, %498, %arg151 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %505 = torch.aten.addcdiv %arg33, %504, %502, %arg145 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int2_648 = torch.constant.int 2 + %int512_649 = torch.constant.int 512 + %int16_650 = torch.constant.int 16 + %506 = torch.prim.ListConstruct %int2_648, %int2_648, %int512_649, %int16_650 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int2_651 = torch.constant.int 2 + %int512_652 = torch.constant.int 512 + %int16_653 = torch.constant.int 16 + %507 = torch.prim.ListConstruct %int2_651, %int2_651, %int512_652, %int16_653 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %508 = torch.aten.reshape %124, %507 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> + %int0_654 = torch.constant.int 0 + %int2_655 = torch.constant.int 2 + %int1_656 = torch.constant.int 1 + %int3_657 = torch.constant.int 3 + %509 = torch.prim.ListConstruct %int0_654, %int2_655, %int1_656, %int3_657 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %int0_658 = torch.constant.int 0 + %int2_659 = torch.constant.int 2 + %int1_660 = torch.constant.int 1 + %int3_661 = torch.constant.int 3 + %510 = torch.prim.ListConstruct %int0_658, %int2_659, %int1_660, %int3_661 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %511 = torch.aten.permute %508, %510 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> + %int2_662 = torch.constant.int 2 + %int512_663 = torch.constant.int 512 + %int32_664 = torch.constant.int 32 + %512 = torch.prim.ListConstruct %int2_662, %int512_663, %int32_664 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %int2_665 = torch.constant.int 2 + %int512_666 = torch.constant.int 512 + %int32_667 = torch.constant.int 32 + %513 = torch.prim.ListConstruct %int2_665, %int512_666, %int32_667 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %514 = torch.aten.reshape %511, %513 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> + %int1024_668 = torch.constant.int 1024 + %int32_669 = torch.constant.int 32 + %515 = torch.prim.ListConstruct %int1024_668, %int32_669 : (!torch.int, !torch.int) -> !torch.list + %int1024_670 = torch.constant.int 1024 + %int32_671 = torch.constant.int 32 + %516 = torch.prim.ListConstruct %int1024_670, %int32_671 : (!torch.int, !torch.int) -> !torch.list + %517 = torch.aten.reshape %514, %516 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_672 = torch.constant.int 1 + %int0_673 = torch.constant.int 0 + %518 = torch.prim.ListConstruct %int1_672, %int0_673 : (!torch.int, !torch.int) -> !torch.list + %int1_674 = torch.constant.int 1 + %int0_675 = torch.constant.int 0 + %519 = torch.prim.ListConstruct %int1_674, %int0_675 : (!torch.int, !torch.int) -> !torch.list + %520 = torch.aten.permute %517, %519 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %521 = torch.aten.mm %520, %261 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_676 = torch.constant.int 1 + %int0_677 = torch.constant.int 0 + %522 = torch.prim.ListConstruct %int1_676, %int0_677 : (!torch.int, !torch.int) -> !torch.list + %int1_678 = torch.constant.int 1 + %int0_679 = torch.constant.int 0 + %523 = torch.prim.ListConstruct %int1_678, %int0_679 : (!torch.int, !torch.int) -> !torch.list + %524 = torch.aten.permute %521, %523 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %525 = torch.aten.add.Tensor %arg157, %524, %arg156 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %526 = torch.aten.mul.Tensor %arg158, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %527 = torch.aten.addcmul %526, %525, %525, %arg155 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %528 = torch.aten.sqrt %527 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %529 = torch.aten.add.Tensor %528, %arg2, %arg154 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %530 = torch.aten.mul.Tensor %arg160, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %531 = torch.aten.add.Tensor %530, %525, %arg159 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %532 = torch.aten.addcdiv %arg34, %531, %529, %arg153 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_680 = torch.constant.int 0 + %533 = torch.prim.ListConstruct %int0_680 : (!torch.int) -> !torch.list + %true_681 = torch.constant.bool true + %none_682 = torch.constant.none + %534 = torch.aten.sum.dim_IntList %261, %533, %true_681, %none_682 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_683 = torch.constant.int 32 + %535 = torch.prim.ListConstruct %int32_683 : (!torch.int) -> !torch.list + %int32_684 = torch.constant.int 32 + %536 = torch.prim.ListConstruct %int32_684 : (!torch.int) -> !torch.list + %537 = torch.aten.reshape %534, %536 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %538 = torch.aten.add.Tensor %arg165, %537, %arg164 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %539 = torch.aten.mul.Tensor %arg166, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %540 = torch.aten.addcmul %539, %538, %538, %arg163 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %541 = torch.aten.sqrt %540 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %542 = torch.aten.add.Tensor %541, %arg2, %arg162 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %543 = torch.aten.mul.Tensor %arg168, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %544 = torch.aten.add.Tensor %543, %538, %arg167 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %545 = torch.aten.addcdiv %arg40, %544, %542, %arg161 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %546 = torch.aten.add.Tensor %arg173, %result1_403, %arg172 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %547 = torch.aten.mul.Tensor %arg174, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %548 = torch.aten.addcmul %547, %546, %546, %arg171 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %549 = torch.aten.sqrt %548 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %550 = torch.aten.add.Tensor %549, %arg2, %arg170 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %551 = torch.aten.mul.Tensor %arg176, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %552 = torch.aten.add.Tensor %551, %546, %arg175 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %553 = torch.aten.addcdiv %arg36, %552, %550, %arg169 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %554 = torch.aten.add.Tensor %arg181, %result2_404, %arg180 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %555 = torch.aten.mul.Tensor %arg182, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %556 = torch.aten.addcmul %555, %554, %554, %arg179 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %557 = torch.aten.sqrt %556 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %558 = torch.aten.add.Tensor %557, %arg2, %arg178 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %559 = torch.aten.mul.Tensor %arg184, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %560 = torch.aten.add.Tensor %559, %554, %arg183 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %561 = torch.aten.addcdiv %arg35, %560, %558, %arg177 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_685 = torch.constant.int 1024 + %int32_686 = torch.constant.int 32 + %562 = torch.prim.ListConstruct %int1024_685, %int32_686 : (!torch.int, !torch.int) -> !torch.list + %int1024_687 = torch.constant.int 1024 + %int32_688 = torch.constant.int 32 + %563 = torch.prim.ListConstruct %int1024_687, %int32_688 : (!torch.int, !torch.int) -> !torch.list + %564 = torch.aten.reshape %result0_238, %563 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_689 = torch.constant.int 1 + %int0_690 = torch.constant.int 0 + %565 = torch.prim.ListConstruct %int1_689, %int0_690 : (!torch.int, !torch.int) -> !torch.list + %int1_691 = torch.constant.int 1 + %int0_692 = torch.constant.int 0 + %566 = torch.prim.ListConstruct %int1_691, %int0_692 : (!torch.int, !torch.int) -> !torch.list + %567 = torch.aten.permute %564, %566 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %568 = torch.aten.mm %567, %250 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_693 = torch.constant.int 1 + %int0_694 = torch.constant.int 0 + %569 = torch.prim.ListConstruct %int1_693, %int0_694 : (!torch.int, !torch.int) -> !torch.list + %int1_695 = torch.constant.int 1 + %int0_696 = torch.constant.int 0 + %570 = torch.prim.ListConstruct %int1_695, %int0_696 : (!torch.int, !torch.int) -> !torch.list + %571 = torch.aten.permute %568, %570 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %572 = torch.aten.add.Tensor %arg189, %571, %arg188 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %573 = torch.aten.mul.Tensor %arg190, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %574 = torch.aten.addcmul %573, %572, %572, %arg187 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %575 = torch.aten.sqrt %574 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %576 = torch.aten.add.Tensor %575, %arg2, %arg186 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %577 = torch.aten.mul.Tensor %arg192, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %578 = torch.aten.add.Tensor %577, %572, %arg191 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %579 = torch.aten.addcdiv %arg42, %578, %576, %arg185 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_697 = torch.constant.int 0 + %580 = torch.prim.ListConstruct %int0_697 : (!torch.int) -> !torch.list + %true_698 = torch.constant.bool true + %none_699 = torch.constant.none + %581 = torch.aten.sum.dim_IntList %250, %580, %true_698, %none_699 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_700 = torch.constant.int 32 + %582 = torch.prim.ListConstruct %int32_700 : (!torch.int) -> !torch.list + %int32_701 = torch.constant.int 32 + %583 = torch.prim.ListConstruct %int32_701 : (!torch.int) -> !torch.list + %584 = torch.aten.reshape %581, %583 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %585 = torch.aten.add.Tensor %arg197, %584, %arg196 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %586 = torch.aten.mul.Tensor %arg198, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %587 = torch.aten.addcmul %586, %585, %585, %arg195 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %588 = torch.aten.sqrt %587 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %589 = torch.aten.add.Tensor %588, %arg2, %arg194 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %590 = torch.aten.mul.Tensor %arg200, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %591 = torch.aten.add.Tensor %590, %585, %arg199 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %592 = torch.aten.addcdiv %arg45, %591, %589, %arg193 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1024_702 = torch.constant.int 1024 + %int32_703 = torch.constant.int 32 + %593 = torch.prim.ListConstruct %int1024_702, %int32_703 : (!torch.int, !torch.int) -> !torch.list + %int1024_704 = torch.constant.int 1024 + %int32_705 = torch.constant.int 32 + %594 = torch.prim.ListConstruct %int1024_704, %int32_705 : (!torch.int, !torch.int) -> !torch.list + %595 = torch.aten.reshape %169, %594 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> + %int1_706 = torch.constant.int 1 + %int0_707 = torch.constant.int 0 + %596 = torch.prim.ListConstruct %int1_706, %int0_707 : (!torch.int, !torch.int) -> !torch.list + %int1_708 = torch.constant.int 1 + %int0_709 = torch.constant.int 0 + %597 = torch.prim.ListConstruct %int1_708, %int0_709 : (!torch.int, !torch.int) -> !torch.list + %598 = torch.aten.permute %595, %597 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> + %599 = torch.aten.mm %598, %242 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_710 = torch.constant.int 1 + %int0_711 = torch.constant.int 0 + %600 = torch.prim.ListConstruct %int1_710, %int0_711 : (!torch.int, !torch.int) -> !torch.list + %int1_712 = torch.constant.int 1 + %int0_713 = torch.constant.int 0 + %601 = torch.prim.ListConstruct %int1_712, %int0_713 : (!torch.int, !torch.int) -> !torch.list + %602 = torch.aten.permute %599, %601 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %603 = torch.aten.add.Tensor %arg205, %602, %arg204 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %604 = torch.aten.mul.Tensor %arg206, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %605 = torch.aten.addcmul %604, %603, %603, %arg203 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %606 = torch.aten.sqrt %605 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %607 = torch.aten.add.Tensor %606, %arg2, %arg202 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %608 = torch.aten.mul.Tensor %arg208, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %609 = torch.aten.add.Tensor %608, %603, %arg207 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %610 = torch.aten.addcdiv %arg46, %609, %607, %arg201 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_714 = torch.constant.int 0 + %611 = torch.prim.ListConstruct %int0_714 : (!torch.int) -> !torch.list + %true_715 = torch.constant.bool true + %none_716 = torch.constant.none + %612 = torch.aten.sum.dim_IntList %242, %611, %true_715, %none_716 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_717 = torch.constant.int 32 + %613 = torch.prim.ListConstruct %int32_717 : (!torch.int) -> !torch.list + %int32_718 = torch.constant.int 32 + %614 = torch.prim.ListConstruct %int32_718 : (!torch.int) -> !torch.list + %615 = torch.aten.reshape %612, %614 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %616 = torch.aten.add.Tensor %arg213, %615, %arg212 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %617 = torch.aten.mul.Tensor %arg214, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %618 = torch.aten.addcmul %617, %616, %616, %arg211 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %619 = torch.aten.sqrt %618 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %620 = torch.aten.add.Tensor %619, %arg2, %arg210 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %621 = torch.aten.mul.Tensor %arg216, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %622 = torch.aten.add.Tensor %621, %616, %arg215 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %623 = torch.aten.addcdiv %arg52, %622, %620, %arg209 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %624 = torch.aten.add.Tensor %arg221, %result1_377, %arg220 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %625 = torch.aten.mul.Tensor %arg222, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %626 = torch.aten.addcmul %625, %624, %624, %arg219 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %627 = torch.aten.sqrt %626 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %628 = torch.aten.add.Tensor %627, %arg2, %arg218 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %629 = torch.aten.mul.Tensor %arg224, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %630 = torch.aten.add.Tensor %629, %624, %arg223 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %631 = torch.aten.addcdiv %arg48, %630, %628, %arg217 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %632 = torch.aten.add.Tensor %arg229, %result2_378, %arg228 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %633 = torch.aten.mul.Tensor %arg230, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %634 = torch.aten.addcmul %633, %632, %632, %arg227 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %635 = torch.aten.sqrt %634 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %636 = torch.aten.add.Tensor %635, %arg2, %arg226 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %637 = torch.aten.mul.Tensor %arg232, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %638 = torch.aten.add.Tensor %637, %632, %arg231 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %639 = torch.aten.addcdiv %arg47, %638, %636, %arg225 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int0_719 = torch.constant.int 0 + %int0_720 = torch.constant.int 0 + %int2_721 = torch.constant.int 2 + %int1_722 = torch.constant.int 1 + %640 = torch.aten.slice.Tensor %result0_287, %int0_719, %int0_720, %int2_721, %int1_722 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int0_723 = torch.constant.int 0 + %int0_724 = torch.constant.int 0 + %int2_725 = torch.constant.int 2 + %int1_726 = torch.constant.int 1 + %641 = torch.aten.slice.Tensor %640, %int0_723, %int0_724, %int2_725, %int1_726 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> + %int1_727 = torch.constant.int 1 + %int0_728 = torch.constant.int 0 + %int1_729 = torch.constant.int 1 + %int1_730 = torch.constant.int 1 + %642 = torch.aten.slice.Tensor %641, %int1_727, %int0_728, %int1_729, %int1_730 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> + %int2_731 = torch.constant.int 2 + %int0_732 = torch.constant.int 0 + %int32_733 = torch.constant.int 32 + %int1_734 = torch.constant.int 1 + %643 = torch.aten.slice.Tensor %642, %int2_731, %int0_732, %int32_733, %int1_734 : !torch.vtensor<[2,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> + %int2_735 = torch.constant.int 2 + %int32_736 = torch.constant.int 32 + %644 = torch.prim.ListConstruct %int2_735, %int32_736 : (!torch.int, !torch.int) -> !torch.list + %int2_737 = torch.constant.int 2 + %int32_738 = torch.constant.int 32 + %645 = torch.prim.ListConstruct %int2_737, %int32_738 : (!torch.int, !torch.int) -> !torch.list + %646 = torch.aten.reshape %643, %645 : !torch.vtensor<[2,1,32],f32>, !torch.list -> !torch.vtensor<[2,32],f32> + %int1_739 = torch.constant.int 1 + %int0_740 = torch.constant.int 0 + %647 = torch.prim.ListConstruct %int1_739, %int0_740 : (!torch.int, !torch.int) -> !torch.list + %int1_741 = torch.constant.int 1 + %int0_742 = torch.constant.int 0 + %648 = torch.prim.ListConstruct %int1_741, %int0_742 : (!torch.int, !torch.int) -> !torch.list + %649 = torch.aten.permute %646, %648 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %650 = torch.aten.mm %649, %222 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[32,32],f32> + %int1_743 = torch.constant.int 1 + %int0_744 = torch.constant.int 0 + %651 = torch.prim.ListConstruct %int1_743, %int0_744 : (!torch.int, !torch.int) -> !torch.list + %int1_745 = torch.constant.int 1 + %int0_746 = torch.constant.int 0 + %652 = torch.prim.ListConstruct %int1_745, %int0_746 : (!torch.int, !torch.int) -> !torch.list + %653 = torch.aten.permute %650, %652 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> + %654 = torch.aten.add.Tensor %arg237, %653, %arg236 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> + %655 = torch.aten.mul.Tensor %arg238, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %656 = torch.aten.addcmul %655, %654, %654, %arg235 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %657 = torch.aten.sqrt %656 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %658 = torch.aten.add.Tensor %657, %arg2, %arg234 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> + %659 = torch.aten.mul.Tensor %arg240, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> + %660 = torch.aten.add.Tensor %659, %654, %arg239 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %661 = torch.aten.addcdiv %arg53, %660, %658, %arg233 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> + %int0_747 = torch.constant.int 0 + %662 = torch.prim.ListConstruct %int0_747 : (!torch.int) -> !torch.list + %true_748 = torch.constant.bool true + %none_749 = torch.constant.none + %663 = torch.aten.sum.dim_IntList %222, %662, %true_748, %none_749 : !torch.vtensor<[2,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> + %int32_750 = torch.constant.int 32 + %664 = torch.prim.ListConstruct %int32_750 : (!torch.int) -> !torch.list + %int32_751 = torch.constant.int 32 + %665 = torch.prim.ListConstruct %int32_751 : (!torch.int) -> !torch.list + %666 = torch.aten.reshape %663, %665 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> + %667 = torch.aten.add.Tensor %arg245, %666, %arg244 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> + %668 = torch.aten.mul.Tensor %arg246, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %669 = torch.aten.addcmul %668, %667, %667, %arg243 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %670 = torch.aten.sqrt %669 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %671 = torch.aten.add.Tensor %670, %arg2, %arg242 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> + %672 = torch.aten.mul.Tensor %arg248, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> + %673 = torch.aten.add.Tensor %672, %667, %arg247 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %674 = torch.aten.addcdiv %arg56, %673, %671, %arg241 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> + %int1_752 = torch.constant.int 1 + %int0_753 = torch.constant.int 0 + %675 = torch.prim.ListConstruct %int1_752, %int0_753 : (!torch.int, !torch.int) -> !torch.list + %int1_754 = torch.constant.int 1 + %int0_755 = torch.constant.int 0 + %676 = torch.prim.ListConstruct %int1_754, %int0_755 : (!torch.int, !torch.int) -> !torch.list + %677 = torch.aten.permute %197, %676 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> + %678 = torch.aten.mm %677, %220 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[32,2],f32> + %int1_756 = torch.constant.int 1 + %int0_757 = torch.constant.int 0 + %679 = torch.prim.ListConstruct %int1_756, %int0_757 : (!torch.int, !torch.int) -> !torch.list + %int1_758 = torch.constant.int 1 + %int0_759 = torch.constant.int 0 + %680 = torch.prim.ListConstruct %int1_758, %int0_759 : (!torch.int, !torch.int) -> !torch.list + %681 = torch.aten.permute %678, %680 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> + %682 = torch.aten.add.Tensor %arg253, %681, %arg252 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> + %683 = torch.aten.mul.Tensor %arg254, %arg68 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %684 = torch.aten.addcmul %683, %682, %682, %arg251 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %685 = torch.aten.sqrt %684 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %686 = torch.aten.add.Tensor %685, %arg2, %arg250 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> + %687 = torch.aten.mul.Tensor %arg256, %arg71 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> + %688 = torch.aten.add.Tensor %687, %682, %arg255 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %689 = torch.aten.addcdiv %arg57, %688, %686, %arg249 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> + %int0_760 = torch.constant.int 0 + %690 = torch.prim.ListConstruct %int0_760 : (!torch.int) -> !torch.list + %true_761 = torch.constant.bool true + %none_762 = torch.constant.none + %691 = torch.aten.sum.dim_IntList %220, %690, %true_761, %none_762 : !torch.vtensor<[2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,2],f32> + %int2_763 = torch.constant.int 2 + %692 = torch.prim.ListConstruct %int2_763 : (!torch.int) -> !torch.list + %int2_764 = torch.constant.int 2 + %693 = torch.prim.ListConstruct %int2_764 : (!torch.int) -> !torch.list + %694 = torch.aten.reshape %691, %693 : !torch.vtensor<[1,2],f32>, !torch.list -> !torch.vtensor<[2],f32> + %695 = torch.aten.add.Tensor %arg261, %694, %arg260 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32> + %696 = torch.aten.mul.Tensor %arg262, %arg68 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> + %697 = torch.aten.addcmul %696, %695, %695, %arg259 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> + %698 = torch.aten.sqrt %697 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> + %699 = torch.aten.add.Tensor %698, %arg2, %arg258 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2],f32> + %700 = torch.aten.mul.Tensor %arg264, %arg71 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> + %701 = torch.aten.add.Tensor %700, %695, %arg263 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> + %702 = torch.aten.addcdiv %arg60, %701, %699, %arg257 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> + %703 = torch.aten.zero.functional %695 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> + %704 = torch.aten.zero.functional %682 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %705 = torch.aten.zero.functional %667 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %706 = torch.aten.zero.functional %654 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %707 = torch.aten.zero.functional %624 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %708 = torch.aten.zero.functional %632 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %709 = torch.aten.zero.functional %616 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %710 = torch.aten.zero.functional %603 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %711 = torch.aten.zero.functional %585 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %712 = torch.aten.zero.functional %572 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %713 = torch.aten.zero.functional %546 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %714 = torch.aten.zero.functional %554 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %715 = torch.aten.zero.functional %538 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %716 = torch.aten.zero.functional %525 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %717 = torch.aten.zero.functional %498 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %718 = torch.aten.zero.functional %485 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %719 = torch.aten.zero.functional %467 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %720 = torch.aten.zero.functional %454 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %721 = torch.aten.zero.functional %436 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %722 = torch.aten.zero.functional %423 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> + %723 = torch.aten.zero.functional %397 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %724 = torch.aten.zero.functional %405 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> + %725 = torch.aten.zero.functional %380 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> + %726 = torch.aten.zero.functional %389 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> + %727 = torch.aten.zero.functional %369 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> + return %376, %387, %396, %404, %412, %430, %443, %461, %474, %492, %505, %532, %545, %553, %561, %579, %592, %610, %623, %631, %639, %661, %674, %689, %702, %703, %704, %705, %706, %707, %708, %709, %710, %711, %712, %713, %714, %715, %716, %717, %718, %719, %720, %721, %722, %723, %724, %725, %726, %727, %375, %371, %386, %382, %395, %391, %403, %399, %411, %407, %429, %425, %442, %438, %460, %456, %473, %469, %491, %487, %504, %500, %531, %527, %544, %540, %552, %548, %560, %556, %578, %574, %591, %587, %609, %605, %622, %618, %630, %626, %638, %634, %660, %656, %673, %669, %688, %684, %701, %697, %arg61, %arg5, %arg12, %arg25, %207, %output : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32> } diff --git a/e2e_testing/lazy_tensor_core/main.py b/e2e_testing/lazy_tensor_core/main.py index 8561874e9bf..9f51480d222 100644 --- a/e2e_testing/lazy_tensor_core/main.py +++ b/e2e_testing/lazy_tensor_core/main.py @@ -3,18 +3,53 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend - +import os +import pathlib import unittest +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend +from numpy.testing import assert_almost_equal + # Example models import ltc_backend_bert import ltc_backend_mnist -import os -import pathlib -class LTCTests(unittest.TestCase): +class LTCNumericTests(unittest.TestCase): + """ + This test suite validates numerics by comparing the output of the models when + executed using the MLIR LTC backend and ensuring they match the results on CPU. + """ + + def assert_tensors_list_almost_equal(self, tensors_a, tensors_b): + self.assertEqual(len(tensors_a), len(tensors_b)) + + for idx in range(len(tensors_a)): + a = tensors_a[idx].cpu().detach().numpy() + b = tensors_b[idx].cpu().detach().numpy() + + assert_almost_equal(a, b) + + def run_test(self, run_model): + model_torch_mlir, loss_torch_mlir = run_model('lazy') + model_cpu, loss_cpu = run_model('cpu') + + # Ensure that model states and losses are almost equal between LTC and CPU. + self.assert_tensors_list_almost_equal(loss_torch_mlir, loss_cpu) + self.assert_tensors_list_almost_equal(list(model_torch_mlir.parameters()), list(model_cpu.parameters())) + + def test_bert(self): + self.run_test(ltc_backend_bert.main) + + def test_mnist(self): + self.run_test(ltc_backend_mnist.main) + + +class LTCMlirTests(unittest.TestCase): + """ + This test suite validates that the emitted MLIR matches a known good output. + """ + def run_test(self, run_model, mlir_path): run_model() diff --git a/e2e_testing/lazy_tensor_core/mnist.mlir b/e2e_testing/lazy_tensor_core/mnist.mlir index 21df4e8eff6..26e9874c914 100644 --- a/e2e_testing/lazy_tensor_core/mnist.mlir +++ b/e2e_testing/lazy_tensor_core/mnist.mlir @@ -1,4 +1,4 @@ -func.func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.vtensor<[10,5],f32>, %arg7: !torch.vtensor<[10],f32>, %arg8: !torch.vtensor<[],f32>, %arg9: !torch.float) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>) { +func.func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.vtensor<[10,5],f32>, %arg7: !torch.vtensor<[10],f32>, %arg8: !torch.vtensor<[1],si64>, %arg9: !torch.vtensor<[],f32>, %arg10: !torch.vtensor<[10,5],f32>, %arg11: !torch.float, %arg12: !torch.int, %arg13: !torch.vtensor<[10],f32>) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32>) { %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list @@ -14,12 +14,12 @@ func.func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1],si6 %none = torch.constant.none %int1_3 = torch.constant.int 1 %int-100 = torch.constant.int -100 - %output, %total_weight = torch.aten.nll_loss_forward %5, %arg1, %none, %int1_3, %int-100 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + %output, %total_weight = torch.aten.nll_loss_forward %5, %arg8, %none, %int1_3, %int-100 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> %6 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.tuple, vtensor<[],f32>> %none_4 = torch.constant.none %int1_5 = torch.constant.int 1 %int-100_6 = torch.constant.int -100 - %7 = torch.aten.nll_loss_backward %arg8, %5, %arg1, %none_4, %int1_5, %int-100_6, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[1,10],f32> + %7 = torch.aten.nll_loss_backward %arg9, %5, %arg8, %none_4, %int1_5, %int-100_6, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[1,10],f32> %int1_7 = torch.constant.int 1 %int6 = torch.constant.int 6 %8 = torch.aten._log_softmax_backward_data %7, %5, %int1_7, %int6 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,10],f32> @@ -39,17 +39,21 @@ func.func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1],si6 %int0_15 = torch.constant.int 0 %15 = torch.prim.ListConstruct %int1_14, %int0_15 : (!torch.int, !torch.int) -> !torch.list %16 = torch.aten.permute %13, %15 : !torch.vtensor<[5,10],f32>, !torch.list -> !torch.vtensor<[10,5],f32> - %17 = torch.aten.add.Tensor %arg6, %16, %arg2 : !torch.vtensor<[10,5],f32>, !torch.vtensor<[10,5],f32>, !torch.float -> !torch.vtensor<[10,5],f32> + %17 = torch.aten.zero.functional %arg10 : !torch.vtensor<[10,5],f32> -> !torch.vtensor<[10,5],f32> + %18 = torch.aten.add.Tensor %17, %16, %arg2 : !torch.vtensor<[10,5],f32>, !torch.vtensor<[10,5],f32>, !torch.int -> !torch.vtensor<[10,5],f32> + %19 = torch.aten.add.Tensor %arg6, %18, %arg1 : !torch.vtensor<[10,5],f32>, !torch.vtensor<[10,5],f32>, !torch.float -> !torch.vtensor<[10,5],f32> %int0_16 = torch.constant.int 0 - %18 = torch.prim.ListConstruct %int0_16 : (!torch.int) -> !torch.list + %20 = torch.prim.ListConstruct %int0_16 : (!torch.int) -> !torch.list %true = torch.constant.bool true %none_17 = torch.constant.none - %19 = torch.aten.sum.dim_IntList %9, %18, %true, %none_17 : !torch.vtensor<[1,10],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,10],f32> + %21 = torch.aten.sum.dim_IntList %9, %20, %true, %none_17 : !torch.vtensor<[1,10],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,10],f32> %int10 = torch.constant.int 10 - %20 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list + %22 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list %int10_18 = torch.constant.int 10 - %21 = torch.prim.ListConstruct %int10_18 : (!torch.int) -> !torch.list - %22 = torch.aten.reshape %19, %21 : !torch.vtensor<[1,10],f32>, !torch.list -> !torch.vtensor<[10],f32> - %23 = torch.aten.add.Tensor %arg7, %22, %arg9 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> - return %arg0, %arg1, %17, %23, %4, %output, %22, %16 : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32> + %23 = torch.prim.ListConstruct %int10_18 : (!torch.int) -> !torch.list + %24 = torch.aten.reshape %21, %23 : !torch.vtensor<[1,10],f32>, !torch.list -> !torch.vtensor<[10],f32> + %25 = torch.aten.zero.functional %arg13 : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + %26 = torch.aten.add.Tensor %25, %24, %arg12 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32> + %27 = torch.aten.add.Tensor %arg7, %26, %arg11 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + return %arg0, %19, %27, %26, %18, %4, %output : !torch.vtensor<[1,5],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32> } diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index 0b808328efd..31cfe156e69 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -122,6 +122,8 @@ def main(device='lazy', full_size=False): print('Loss: ', losses) + return model, losses + if __name__ == "__main__": torch.manual_seed(0) diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index 312ea9fc2dd..7bb2db16deb 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -49,16 +49,22 @@ def forward(self, x): criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - optimizer.zero_grad() - outputs = model(inputs) - loss = criterion(outputs, targets) - loss.backward() - optimizer.step() + num_epochs = 3 + losses = [] + for _ in range(num_epochs): + optimizer.zero_grad() - if device == "lazy": - print("Calling Mark Step") - torch._lazy.mark_step() + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + losses.append(loss) + + optimizer.step() + + if device == "lazy": + print("Calling Mark Step") + torch._lazy.mark_step() # Get debug information from LTC if 'ltc_backend' in sys.modules: @@ -66,7 +72,9 @@ def forward(self, x): if computation: print(computation.debug_string()) - print(loss) + print(losses) + + return model, losses if __name__ == "__main__": From e76464b3bd29040338676e4607b99378a52118b8 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Mon, 6 Jun 2022 19:53:44 -0400 Subject: [PATCH 05/13] Print name of the model layer that fails numeric validation --- e2e_testing/lazy_tensor_core/main.py | 38 ++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/e2e_testing/lazy_tensor_core/main.py b/e2e_testing/lazy_tensor_core/main.py index 9f51480d222..0438cc31a22 100644 --- a/e2e_testing/lazy_tensor_core/main.py +++ b/e2e_testing/lazy_tensor_core/main.py @@ -21,22 +21,40 @@ class LTCNumericTests(unittest.TestCase): executed using the MLIR LTC backend and ensuring they match the results on CPU. """ - def assert_tensors_list_almost_equal(self, tensors_a, tensors_b): - self.assertEqual(len(tensors_a), len(tensors_b)) + def assert_tensors_almost_equal(self, tensor_a, tensor_b, message): + a, b = tensor_a.cpu().detach().numpy(), tensor_b.cpu().detach().numpy() - for idx in range(len(tensors_a)): - a = tensors_a[idx].cpu().detach().numpy() - b = tensors_b[idx].cpu().detach().numpy() - - assert_almost_equal(a, b) + assert_almost_equal(a, b, 7, message) def run_test(self, run_model): model_torch_mlir, loss_torch_mlir = run_model('lazy') model_cpu, loss_cpu = run_model('cpu') - # Ensure that model states and losses are almost equal between LTC and CPU. - self.assert_tensors_list_almost_equal(loss_torch_mlir, loss_cpu) - self.assert_tensors_list_almost_equal(list(model_torch_mlir.parameters()), list(model_cpu.parameters())) + # Check losses match. + self.assertEqual(len(loss_torch_mlir), len(loss_cpu)) + for idx in range(len(loss_torch_mlir)): + self.assert_tensors_almost_equal(loss_torch_mlir[idx], loss_cpu[idx], + f'Losses at index {idx} do not match!') + + # Check that number of parameters match. + torch_mlir_params, cpu_params = [list(model.named_parameters()) for model in (model_torch_mlir, model_cpu)] + self.assertEqual(len(torch_mlir_params), len(cpu_params)) + + # Check that names of parameters. + torch_mlir_keys = [] + for name, param in torch_mlir_params: + torch_mlir_keys.append(name) + + cpu_keys = [] + for name, param in cpu_params: + cpu_keys.append(name) + + self.assertEqual(torch_mlir_keys, cpu_keys) + + # Check contents of parameters match. + for idx in range(len(torch_mlir_params)): + self.assert_tensors_almost_equal(torch_mlir_params[idx][1], cpu_params[idx][1], + f'Parameters {torch_mlir_keys[idx]} do not match!') def test_bert(self): self.run_test(ltc_backend_bert.main) From d00f8c7f4c7b27f8eae487ea44d89e6e97e05806 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Tue, 7 Jun 2022 10:37:57 -0400 Subject: [PATCH 06/13] Run LTC e2e test with CI/CD --- .github/workflows/buildAndTest.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index e99c8c2477b..5496808c7e8 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -56,6 +56,11 @@ jobs: cd $GITHUB_WORKSPACE export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" python -m e2e_testing.torchscript.main --config=tosa -v + - name: Lazy Tensor Core end-to-end tests + run: | + cd $GITHUB_WORKSPACE + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.lazy_tensor_core.main build-out-of-tree: name: Build out-of-tree (Release Asserts) From f75c4730f53bab7c18b2a811f01fe8158dec655b Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Tue, 7 Jun 2022 10:41:39 -0400 Subject: [PATCH 07/13] Set seed in main function, instead of beginning of execution --- examples/ltc_backend_bert.py | 2 -- examples/ltc_backend_mnist.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index 31cfe156e69..d309ba87136 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -126,8 +126,6 @@ def main(device='lazy', full_size=False): if __name__ == "__main__": - torch.manual_seed(0) - parser = argparse.ArgumentParser() parser.add_argument( "-d", diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index 7bb2db16deb..65f8a4e5f37 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -78,8 +78,6 @@ def forward(self, x): if __name__ == "__main__": - torch.manual_seed(0) - parser = argparse.ArgumentParser() parser.add_argument( "-d", From 28c50e0ef5e5b13a0b6ee88e7fe1c41d88a84675 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Tue, 7 Jun 2022 16:15:35 -0400 Subject: [PATCH 08/13] Add comment to specify number of digits of precision --- e2e_testing/lazy_tensor_core/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/e2e_testing/lazy_tensor_core/main.py b/e2e_testing/lazy_tensor_core/main.py index 0438cc31a22..32ce19857c6 100644 --- a/e2e_testing/lazy_tensor_core/main.py +++ b/e2e_testing/lazy_tensor_core/main.py @@ -24,6 +24,7 @@ class LTCNumericTests(unittest.TestCase): def assert_tensors_almost_equal(self, tensor_a, tensor_b, message): a, b = tensor_a.cpu().detach().numpy(), tensor_b.cpu().detach().numpy() + # Ensure tensors match up to 7 decimals of precision. assert_almost_equal(a, b, 7, message) def run_test(self, run_model): From 0b737f7fe27b0e5fab10e16a3f53953c1d3880dd Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Tue, 7 Jun 2022 23:36:15 -0400 Subject: [PATCH 09/13] Fixed typo --- e2e_testing/lazy_tensor_core/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e2e_testing/lazy_tensor_core/main.py b/e2e_testing/lazy_tensor_core/main.py index 32ce19857c6..02f9d3dfef0 100644 --- a/e2e_testing/lazy_tensor_core/main.py +++ b/e2e_testing/lazy_tensor_core/main.py @@ -41,7 +41,7 @@ def run_test(self, run_model): torch_mlir_params, cpu_params = [list(model.named_parameters()) for model in (model_torch_mlir, model_cpu)] self.assertEqual(len(torch_mlir_params), len(cpu_params)) - # Check that names of parameters. + # Check that names of parameters match. torch_mlir_keys = [] for name, param in torch_mlir_params: torch_mlir_keys.append(name) From 9e23d975edf5640f57599c12e8d6b80525666805 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Thu, 9 Jun 2022 08:30:51 -0400 Subject: [PATCH 10/13] Remove tests for LTC example models --- .github/workflows/buildAndTest.yml | 5 - e2e_testing/lazy_tensor_core/bert.mlir | 1506 ----------------------- e2e_testing/lazy_tensor_core/main.py | 89 -- e2e_testing/lazy_tensor_core/mnist.mlir | 59 - 4 files changed, 1659 deletions(-) delete mode 100644 e2e_testing/lazy_tensor_core/bert.mlir delete mode 100644 e2e_testing/lazy_tensor_core/main.py delete mode 100644 e2e_testing/lazy_tensor_core/mnist.mlir diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 5496808c7e8..e99c8c2477b 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -56,11 +56,6 @@ jobs: cd $GITHUB_WORKSPACE export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" python -m e2e_testing.torchscript.main --config=tosa -v - - name: Lazy Tensor Core end-to-end tests - run: | - cd $GITHUB_WORKSPACE - export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" - python -m e2e_testing.lazy_tensor_core.main build-out-of-tree: name: Build out-of-tree (Release Asserts) diff --git a/e2e_testing/lazy_tensor_core/bert.mlir b/e2e_testing/lazy_tensor_core/bert.mlir deleted file mode 100644 index 294d48cd597..00000000000 --- a/e2e_testing/lazy_tensor_core/bert.mlir +++ /dev/null @@ -1,1506 +0,0 @@ -func.func @graph(%arg0: !torch.float, %arg1: !torch.int, %arg2: !torch.vtensor<[],f64>, %arg3: !torch.float, %arg4: !torch.int, %arg5: !torch.vtensor<[2,512],si64>, %arg6: !torch.vtensor<[32],f32>, %arg7: !torch.vtensor<[32],f32>, %arg8: !torch.int, %arg9: !torch.vtensor<[1,512],si64>, %arg10: !torch.vtensor<[512,32],f32>, %arg11: !torch.int, %arg12: !torch.vtensor<[2,512],si64>, %arg13: !torch.vtensor<[2,32],f32>, %arg14: !torch.vtensor<[28996,32],f32>, %arg15: !torch.int, %arg16: !torch.vtensor<[32,32],f32>, %arg17: !torch.int, %arg18: !torch.int, %arg19: !torch.vtensor<[32,32],f32>, %arg20: !torch.vtensor<[32],f32>, %arg21: !torch.vtensor<[],f64>, %arg22: !torch.int, %arg23: !torch.vtensor<[],f64>, %arg24: !torch.int, %arg25: !torch.vtensor<[2,512],si64>, %arg26: !torch.vtensor<[],f64>, %arg27: !torch.int, %arg28: !torch.int, %arg29: !torch.vtensor<[32],f32>, %arg30: !torch.int, %arg31: !torch.int, %arg32: !torch.vtensor<[32,32],f32>, %arg33: !torch.vtensor<[32],f32>, %arg34: !torch.vtensor<[32,32],f32>, %arg35: !torch.vtensor<[32],f32>, %arg36: !torch.vtensor<[32],f32>, %arg37: !torch.int, %arg38: !torch.int, %arg39: !torch.int, %arg40: !torch.vtensor<[32],f32>, %arg41: !torch.int, %arg42: !torch.vtensor<[32,32],f32>, %arg43: !torch.int, %arg44: !torch.int, %arg45: !torch.vtensor<[32],f32>, %arg46: !torch.vtensor<[32,32],f32>, %arg47: !torch.vtensor<[32],f32>, %arg48: !torch.vtensor<[32],f32>, %arg49: !torch.int, %arg50: !torch.int, %arg51: !torch.int, %arg52: !torch.vtensor<[32],f32>, %arg53: !torch.vtensor<[32,32],f32>, %arg54: !torch.int, %arg55: !torch.int, %arg56: !torch.vtensor<[32],f32>, %arg57: !torch.vtensor<[2,32],f32>, %arg58: !torch.int, %arg59: !torch.int, %arg60: !torch.vtensor<[2],f32>, %arg61: !torch.vtensor<[2],si64>, %arg62: !torch.vtensor<[],f32>, %arg63: !torch.vtensor<[2,512,32],f32>, %arg64: !torch.vtensor<[2,512,32],f32>, %arg65: !torch.int, %arg66: !torch.int, %arg67: !torch.vtensor<[28996,32],f32>, %arg68: !torch.vtensor<[],f64>, %arg69: !torch.vtensor<[28996,32],f32>, %arg70: !torch.float, %arg71: !torch.vtensor<[],f64>, %arg72: !torch.vtensor<[28996,32],f32>, %arg73: !torch.float, %arg74: !torch.int, %arg75: !torch.float, %arg76: !torch.int, %arg77: !torch.vtensor<[512,32],f32>, %arg78: !torch.vtensor<[512,32],f32>, %arg79: !torch.float, %arg80: !torch.vtensor<[512,32],f32>, %arg81: !torch.float, %arg82: !torch.int, %arg83: !torch.float, %arg84: !torch.int, %arg85: !torch.vtensor<[2,32],f32>, %arg86: !torch.vtensor<[2,32],f32>, %arg87: !torch.float, %arg88: !torch.vtensor<[2,32],f32>, %arg89: !torch.float, %arg90: !torch.int, %arg91: !torch.float, %arg92: !torch.int, %arg93: !torch.vtensor<[32],f32>, %arg94: !torch.vtensor<[32],f32>, %arg95: !torch.float, %arg96: !torch.vtensor<[32],f32>, %arg97: !torch.float, %arg98: !torch.int, %arg99: !torch.float, %arg100: !torch.int, %arg101: !torch.vtensor<[32],f32>, %arg102: !torch.vtensor<[32],f32>, %arg103: !torch.float, %arg104: !torch.vtensor<[32],f32>, %arg105: !torch.float, %arg106: !torch.int, %arg107: !torch.float, %arg108: !torch.int, %arg109: !torch.vtensor<[32,32],f32>, %arg110: !torch.vtensor<[32,32],f32>, %arg111: !torch.float, %arg112: !torch.vtensor<[32,32],f32>, %arg113: !torch.float, %arg114: !torch.int, %arg115: !torch.float, %arg116: !torch.int, %arg117: !torch.vtensor<[32],f32>, %arg118: !torch.vtensor<[32],f32>, %arg119: !torch.float, %arg120: !torch.vtensor<[32],f32>, %arg121: !torch.float, %arg122: !torch.int, %arg123: !torch.float, %arg124: !torch.int, %arg125: !torch.vtensor<[32,32],f32>, %arg126: !torch.vtensor<[32,32],f32>, %arg127: !torch.float, %arg128: !torch.vtensor<[32,32],f32>, %arg129: !torch.float, %arg130: !torch.int, %arg131: !torch.float, %arg132: !torch.int, %arg133: !torch.vtensor<[32],f32>, %arg134: !torch.vtensor<[32],f32>, %arg135: !torch.float, %arg136: !torch.vtensor<[32],f32>, %arg137: !torch.float, %arg138: !torch.int, %arg139: !torch.float, %arg140: !torch.int, %arg141: !torch.vtensor<[32,32],f32>, %arg142: !torch.vtensor<[32,32],f32>, %arg143: !torch.float, %arg144: !torch.vtensor<[32,32],f32>, %arg145: !torch.float, %arg146: !torch.int, %arg147: !torch.float, %arg148: !torch.int, %arg149: !torch.vtensor<[32],f32>, %arg150: !torch.vtensor<[32],f32>, %arg151: !torch.float, %arg152: !torch.vtensor<[32],f32>, %arg153: !torch.float, %arg154: !torch.int, %arg155: !torch.float, %arg156: !torch.int, %arg157: !torch.vtensor<[32,32],f32>, %arg158: !torch.vtensor<[32,32],f32>, %arg159: !torch.float, %arg160: !torch.vtensor<[32,32],f32>, %arg161: !torch.float, %arg162: !torch.int, %arg163: !torch.float, %arg164: !torch.int, %arg165: !torch.vtensor<[32],f32>, %arg166: !torch.vtensor<[32],f32>, %arg167: !torch.float, %arg168: !torch.vtensor<[32],f32>, %arg169: !torch.float, %arg170: !torch.int, %arg171: !torch.float, %arg172: !torch.int, %arg173: !torch.vtensor<[32],f32>, %arg174: !torch.vtensor<[32],f32>, %arg175: !torch.float, %arg176: !torch.vtensor<[32],f32>, %arg177: !torch.float, %arg178: !torch.int, %arg179: !torch.float, %arg180: !torch.int, %arg181: !torch.vtensor<[32],f32>, %arg182: !torch.vtensor<[32],f32>, %arg183: !torch.float, %arg184: !torch.vtensor<[32],f32>, %arg185: !torch.float, %arg186: !torch.int, %arg187: !torch.float, %arg188: !torch.int, %arg189: !torch.vtensor<[32,32],f32>, %arg190: !torch.vtensor<[32,32],f32>, %arg191: !torch.float, %arg192: !torch.vtensor<[32,32],f32>, %arg193: !torch.float, %arg194: !torch.int, %arg195: !torch.float, %arg196: !torch.int, %arg197: !torch.vtensor<[32],f32>, %arg198: !torch.vtensor<[32],f32>, %arg199: !torch.float, %arg200: !torch.vtensor<[32],f32>, %arg201: !torch.float, %arg202: !torch.int, %arg203: !torch.float, %arg204: !torch.int, %arg205: !torch.vtensor<[32,32],f32>, %arg206: !torch.vtensor<[32,32],f32>, %arg207: !torch.float, %arg208: !torch.vtensor<[32,32],f32>, %arg209: !torch.float, %arg210: !torch.int, %arg211: !torch.float, %arg212: !torch.int, %arg213: !torch.vtensor<[32],f32>, %arg214: !torch.vtensor<[32],f32>, %arg215: !torch.float, %arg216: !torch.vtensor<[32],f32>, %arg217: !torch.float, %arg218: !torch.int, %arg219: !torch.float, %arg220: !torch.int, %arg221: !torch.vtensor<[32],f32>, %arg222: !torch.vtensor<[32],f32>, %arg223: !torch.float, %arg224: !torch.vtensor<[32],f32>, %arg225: !torch.float, %arg226: !torch.int, %arg227: !torch.float, %arg228: !torch.int, %arg229: !torch.vtensor<[32],f32>, %arg230: !torch.vtensor<[32],f32>, %arg231: !torch.float, %arg232: !torch.vtensor<[32],f32>, %arg233: !torch.float, %arg234: !torch.int, %arg235: !torch.float, %arg236: !torch.int, %arg237: !torch.vtensor<[32,32],f32>, %arg238: !torch.vtensor<[32,32],f32>, %arg239: !torch.float, %arg240: !torch.vtensor<[32,32],f32>, %arg241: !torch.float, %arg242: !torch.int, %arg243: !torch.float, %arg244: !torch.int, %arg245: !torch.vtensor<[32],f32>, %arg246: !torch.vtensor<[32],f32>, %arg247: !torch.float, %arg248: !torch.vtensor<[32],f32>, %arg249: !torch.float, %arg250: !torch.int, %arg251: !torch.float, %arg252: !torch.int, %arg253: !torch.vtensor<[2,32],f32>, %arg254: !torch.vtensor<[2,32],f32>, %arg255: !torch.float, %arg256: !torch.vtensor<[2,32],f32>, %arg257: !torch.float, %arg258: !torch.int, %arg259: !torch.float, %arg260: !torch.int, %arg261: !torch.vtensor<[2],f32>, %arg262: !torch.vtensor<[2],f32>, %arg263: !torch.float, %arg264: !torch.vtensor<[2],f32>) -> (!torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32>) { - %int0 = torch.constant.int 0 - %int0_0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int1_1 = torch.constant.int 1 - %0 = torch.aten.slice.Tensor %arg9, %int0, %int0_0, %int1, %int1_1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64> - %int-1 = torch.constant.int -1 - %false = torch.constant.bool false - %false_2 = torch.constant.bool false - %1 = torch.aten.embedding %arg10, %0, %int-1, %false, %false_2 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,32],f32> - %int-1_3 = torch.constant.int -1 - %false_4 = torch.constant.bool false - %false_5 = torch.constant.bool false - %2 = torch.aten.embedding %arg13, %arg12, %int-1_3, %false_4, %false_5 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> - %int0_6 = torch.constant.int 0 - %false_7 = torch.constant.bool false - %false_8 = torch.constant.bool false - %3 = torch.aten.embedding %arg14, %arg5, %int0_6, %false_7, %false_8 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[2,512,32],f32> - %4 = torch.aten.add.Tensor %3, %2, %arg11 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %5 = torch.aten.add.Tensor %4, %1, %arg8 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[1,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32 = torch.constant.int 32 - %6 = torch.prim.ListConstruct %int32 : (!torch.int) -> !torch.list - %float1.000000e-05 = torch.constant.float 1.000000e-05 - %result0, %result1, %result2 = torch.aten.native_layer_norm %5, %6, %arg7, %arg6, %float1.000000e-05 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> - %7 = torch.prim.TupleConstruct %result0, %result1, %result2 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> - %int1_9 = torch.constant.int 1 - %int0_10 = torch.constant.int 0 - %8 = torch.prim.ListConstruct %int1_9, %int0_10 : (!torch.int, !torch.int) -> !torch.list - %int1_11 = torch.constant.int 1 - %int0_12 = torch.constant.int 0 - %9 = torch.prim.ListConstruct %int1_11, %int0_12 : (!torch.int, !torch.int) -> !torch.list - %10 = torch.aten.permute %arg16, %9 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_13 = torch.constant.int 1 - %int0_14 = torch.constant.int 0 - %11 = torch.prim.ListConstruct %int1_13, %int0_14 : (!torch.int, !torch.int) -> !torch.list - %int1_15 = torch.constant.int 1 - %int0_16 = torch.constant.int 0 - %12 = torch.prim.ListConstruct %int1_15, %int0_16 : (!torch.int, !torch.int) -> !torch.list - %13 = torch.aten.permute %10, %12 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_17 = torch.constant.int 1 - %int0_18 = torch.constant.int 0 - %14 = torch.prim.ListConstruct %int1_17, %int0_18 : (!torch.int, !torch.int) -> !torch.list - %int1_19 = torch.constant.int 1 - %int0_20 = torch.constant.int 0 - %15 = torch.prim.ListConstruct %int1_19, %int0_20 : (!torch.int, !torch.int) -> !torch.list - %16 = torch.aten.permute %arg19, %15 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1024 = torch.constant.int 1024 - %int32_21 = torch.constant.int 32 - %17 = torch.prim.ListConstruct %int1024, %int32_21 : (!torch.int, !torch.int) -> !torch.list - %int1024_22 = torch.constant.int 1024 - %int32_23 = torch.constant.int 32 - %18 = torch.prim.ListConstruct %int1024_22, %int32_23 : (!torch.int, !torch.int) -> !torch.list - %19 = torch.aten.reshape %result0, %18 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %20 = torch.aten.addmm %arg20, %19, %16, %arg18, %arg17 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2 = torch.constant.int 2 - %int512 = torch.constant.int 512 - %int32_24 = torch.constant.int 32 - %21 = torch.prim.ListConstruct %int2, %int512, %int32_24 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_25 = torch.constant.int 2 - %int512_26 = torch.constant.int 512 - %int32_27 = torch.constant.int 32 - %22 = torch.prim.ListConstruct %int2_25, %int512_26, %int32_27 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %23 = torch.aten.reshape %20, %22 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int2_28 = torch.constant.int 2 - %int512_29 = torch.constant.int 512 - %int16 = torch.constant.int 16 - %24 = torch.prim.ListConstruct %int2_28, %int512_29, %int2_28, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_30 = torch.constant.int 2 - %int512_31 = torch.constant.int 512 - %int16_32 = torch.constant.int 16 - %25 = torch.prim.ListConstruct %int2_30, %int512_31, %int2_30, %int16_32 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %26 = torch.aten.reshape %23, %25 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_33 = torch.constant.int 0 - %int2_34 = torch.constant.int 2 - %int1_35 = torch.constant.int 1 - %int3 = torch.constant.int 3 - %27 = torch.prim.ListConstruct %int0_33, %int2_34, %int1_35, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_36 = torch.constant.int 0 - %int2_37 = torch.constant.int 2 - %int1_38 = torch.constant.int 1 - %int3_39 = torch.constant.int 3 - %28 = torch.prim.ListConstruct %int0_36, %int2_37, %int1_38, %int3_39 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %29 = torch.aten.permute %26, %28 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_40 = torch.constant.int 0 - %int1_41 = torch.constant.int 1 - %int3_42 = torch.constant.int 3 - %int2_43 = torch.constant.int 2 - %30 = torch.prim.ListConstruct %int0_40, %int1_41, %int3_42, %int2_43 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_44 = torch.constant.int 0 - %int1_45 = torch.constant.int 1 - %int3_46 = torch.constant.int 3 - %int2_47 = torch.constant.int 2 - %31 = torch.prim.ListConstruct %int0_44, %int1_45, %int3_46, %int2_47 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %32 = torch.aten.permute %29, %31 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,16,512],f32> - %int2_48 = torch.constant.int 2 - %int16_49 = torch.constant.int 16 - %int512_50 = torch.constant.int 512 - %33 = torch.prim.ListConstruct %int2_48, %int2_48, %int16_49, %int512_50 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_51 = torch.constant.bool false - %34 = torch.aten.expand %32, %33, %false_51 : !torch.vtensor<[2,2,16,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,16,512],f32> - %int4 = torch.constant.int 4 - %int16_52 = torch.constant.int 16 - %int512_53 = torch.constant.int 512 - %35 = torch.prim.ListConstruct %int4, %int16_52, %int512_53 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_54 = torch.constant.int 4 - %int16_55 = torch.constant.int 16 - %int512_56 = torch.constant.int 512 - %36 = torch.prim.ListConstruct %int4_54, %int16_55, %int512_56 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %37 = torch.aten.reshape %34, %36 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> - %int0_57 = torch.constant.int 0 - %int2_58 = torch.constant.int 2 - %int1_59 = torch.constant.int 1 - %38 = torch.prim.ListConstruct %int0_57, %int2_58, %int1_59 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int0_60 = torch.constant.int 0 - %int2_61 = torch.constant.int 2 - %int1_62 = torch.constant.int 1 - %39 = torch.prim.ListConstruct %int0_60, %int2_61, %int1_62 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %40 = torch.aten.permute %37, %39 : !torch.vtensor<[4,16,512],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %int0_63 = torch.constant.int 0 - %int0_64 = torch.constant.int 0 - %int2_65 = torch.constant.int 2 - %int1_66 = torch.constant.int 1 - %41 = torch.aten.slice.Tensor %arg25, %int0_63, %int0_64, %int2_65, %int1_66 : !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512],si64> - %int2_67 = torch.constant.int 2 - %int1_68 = torch.constant.int 1 - %int512_69 = torch.constant.int 512 - %42 = torch.prim.ListConstruct %int2_67, %int1_68, %int512_69 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_70 = torch.constant.int 2 - %int1_71 = torch.constant.int 1 - %int512_72 = torch.constant.int 512 - %43 = torch.prim.ListConstruct %int2_70, %int1_71, %int512_72 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %44 = torch.aten.reshape %41, %43 : !torch.vtensor<[2,512],si64>, !torch.list -> !torch.vtensor<[2,1,512],si64> - %int2_73 = torch.constant.int 2 - %int1_74 = torch.constant.int 1 - %int512_75 = torch.constant.int 512 - %45 = torch.prim.ListConstruct %int2_73, %int1_74, %int1_74, %int512_75 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_76 = torch.constant.int 2 - %int1_77 = torch.constant.int 1 - %int512_78 = torch.constant.int 512 - %46 = torch.prim.ListConstruct %int2_76, %int1_77, %int1_77, %int512_78 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %47 = torch.aten.reshape %44, %46 : !torch.vtensor<[2,1,512],si64>, !torch.list -> !torch.vtensor<[2,1,1,512],si64> - %int3_79 = torch.constant.int 3 - %int0_80 = torch.constant.int 0 - %int512_81 = torch.constant.int 512 - %int1_82 = torch.constant.int 1 - %48 = torch.aten.slice.Tensor %47, %int3_79, %int0_80, %int512_81, %int1_82 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,1,512],si64> - %int6 = torch.constant.int 6 - %none = torch.constant.none - %none_83 = torch.constant.none - %none_84 = torch.constant.none - %false_85 = torch.constant.bool false - %none_86 = torch.constant.none - %49 = torch.aten._to_copy %48, %int6, %none, %none_83, %none_84, %false_85, %none_86 : !torch.vtensor<[2,1,1,512],si64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[2,1,1,512],f32> - %50 = torch.aten.sub.Tensor %arg26, %49, %arg24 : !torch.vtensor<[],f64>, !torch.vtensor<[2,1,1,512],f32>, !torch.int -> !torch.vtensor<[2,1,1,512],f32> - %51 = torch.aten.mul.Tensor %50, %arg23 : !torch.vtensor<[2,1,1,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,1,1,512],f32> - %int4_87 = torch.constant.int 4 - %int16_88 = torch.constant.int 16 - %int512_89 = torch.constant.int 512 - %52 = torch.prim.ListConstruct %int4_87, %int16_88, %int512_89 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_90 = torch.constant.int 4 - %int16_91 = torch.constant.int 16 - %int512_92 = torch.constant.int 512 - %53 = torch.prim.ListConstruct %int4_90, %int16_91, %int512_92 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %54 = torch.aten.reshape %34, %53 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> - %int1_93 = torch.constant.int 1 - %int0_94 = torch.constant.int 0 - %55 = torch.prim.ListConstruct %int1_93, %int0_94 : (!torch.int, !torch.int) -> !torch.list - %int1_95 = torch.constant.int 1 - %int0_96 = torch.constant.int 0 - %56 = torch.prim.ListConstruct %int1_95, %int0_96 : (!torch.int, !torch.int) -> !torch.list - %57 = torch.aten.permute %arg16, %56 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1024_97 = torch.constant.int 1024 - %int32_98 = torch.constant.int 32 - %58 = torch.prim.ListConstruct %int1024_97, %int32_98 : (!torch.int, !torch.int) -> !torch.list - %int1024_99 = torch.constant.int 1024 - %int32_100 = torch.constant.int 32 - %59 = torch.prim.ListConstruct %int1024_99, %int32_100 : (!torch.int, !torch.int) -> !torch.list - %60 = torch.aten.reshape %result0, %59 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %61 = torch.aten.addmm %arg29, %60, %57, %arg28, %arg27 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_101 = torch.constant.int 2 - %int512_102 = torch.constant.int 512 - %int32_103 = torch.constant.int 32 - %62 = torch.prim.ListConstruct %int2_101, %int512_102, %int32_103 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_104 = torch.constant.int 2 - %int512_105 = torch.constant.int 512 - %int32_106 = torch.constant.int 32 - %63 = torch.prim.ListConstruct %int2_104, %int512_105, %int32_106 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %64 = torch.aten.reshape %61, %63 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int2_107 = torch.constant.int 2 - %int512_108 = torch.constant.int 512 - %int16_109 = torch.constant.int 16 - %65 = torch.prim.ListConstruct %int2_107, %int512_108, %int2_107, %int16_109 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_110 = torch.constant.int 2 - %int512_111 = torch.constant.int 512 - %int16_112 = torch.constant.int 16 - %66 = torch.prim.ListConstruct %int2_110, %int512_111, %int2_110, %int16_112 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %67 = torch.aten.reshape %64, %66 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_113 = torch.constant.int 0 - %int2_114 = torch.constant.int 2 - %int1_115 = torch.constant.int 1 - %int3_116 = torch.constant.int 3 - %68 = torch.prim.ListConstruct %int0_113, %int2_114, %int1_115, %int3_116 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_117 = torch.constant.int 0 - %int2_118 = torch.constant.int 2 - %int1_119 = torch.constant.int 1 - %int3_120 = torch.constant.int 3 - %69 = torch.prim.ListConstruct %int0_117, %int2_118, %int1_119, %int3_120 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %70 = torch.aten.permute %67, %69 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int2_121 = torch.constant.int 2 - %int512_122 = torch.constant.int 512 - %int16_123 = torch.constant.int 16 - %71 = torch.prim.ListConstruct %int2_121, %int2_121, %int512_122, %int16_123 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_124 = torch.constant.bool false - %72 = torch.aten.expand %70, %71, %false_124 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> - %int4_125 = torch.constant.int 4 - %int512_126 = torch.constant.int 512 - %int16_127 = torch.constant.int 16 - %73 = torch.prim.ListConstruct %int4_125, %int512_126, %int16_127 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_128 = torch.constant.int 4 - %int512_129 = torch.constant.int 512 - %int16_130 = torch.constant.int 16 - %74 = torch.prim.ListConstruct %int4_128, %int512_129, %int16_130 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %75 = torch.aten.reshape %72, %74 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %76 = torch.aten.bmm %75, %54 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> - %int2_131 = torch.constant.int 2 - %int512_132 = torch.constant.int 512 - %77 = torch.prim.ListConstruct %int2_131, %int2_131, %int512_132, %int512_132 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_133 = torch.constant.int 2 - %int512_134 = torch.constant.int 512 - %78 = torch.prim.ListConstruct %int2_133, %int2_133, %int512_134, %int512_134 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %79 = torch.aten.reshape %76, %78 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> - %80 = torch.aten.div.Tensor %79, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> - %81 = torch.aten.add.Tensor %80, %51, %arg22 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,1,1,512],f32>, !torch.int -> !torch.vtensor<[2,2,512,512],f32> - %int-1_135 = torch.constant.int -1 - %false_136 = torch.constant.bool false - %82 = torch.aten._softmax %81, %int-1_135, %false_136 : !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> - %int1_137 = torch.constant.int 1 - %int0_138 = torch.constant.int 0 - %83 = torch.prim.ListConstruct %int1_137, %int0_138 : (!torch.int, !torch.int) -> !torch.list - %int1_139 = torch.constant.int 1 - %int0_140 = torch.constant.int 0 - %84 = torch.prim.ListConstruct %int1_139, %int0_140 : (!torch.int, !torch.int) -> !torch.list - %85 = torch.aten.permute %arg32, %84 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1024_141 = torch.constant.int 1024 - %int32_142 = torch.constant.int 32 - %86 = torch.prim.ListConstruct %int1024_141, %int32_142 : (!torch.int, !torch.int) -> !torch.list - %int1024_143 = torch.constant.int 1024 - %int32_144 = torch.constant.int 32 - %87 = torch.prim.ListConstruct %int1024_143, %int32_144 : (!torch.int, !torch.int) -> !torch.list - %88 = torch.aten.reshape %result0, %87 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %89 = torch.aten.addmm %arg33, %88, %85, %arg31, %arg30 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_145 = torch.constant.int 2 - %int512_146 = torch.constant.int 512 - %int32_147 = torch.constant.int 32 - %90 = torch.prim.ListConstruct %int2_145, %int512_146, %int32_147 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_148 = torch.constant.int 2 - %int512_149 = torch.constant.int 512 - %int32_150 = torch.constant.int 32 - %91 = torch.prim.ListConstruct %int2_148, %int512_149, %int32_150 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %92 = torch.aten.reshape %89, %91 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int2_151 = torch.constant.int 2 - %int512_152 = torch.constant.int 512 - %int16_153 = torch.constant.int 16 - %93 = torch.prim.ListConstruct %int2_151, %int512_152, %int2_151, %int16_153 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_154 = torch.constant.int 2 - %int512_155 = torch.constant.int 512 - %int16_156 = torch.constant.int 16 - %94 = torch.prim.ListConstruct %int2_154, %int512_155, %int2_154, %int16_156 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %95 = torch.aten.reshape %92, %94 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_157 = torch.constant.int 0 - %int2_158 = torch.constant.int 2 - %int1_159 = torch.constant.int 1 - %int3_160 = torch.constant.int 3 - %96 = torch.prim.ListConstruct %int0_157, %int2_158, %int1_159, %int3_160 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_161 = torch.constant.int 0 - %int2_162 = torch.constant.int 2 - %int1_163 = torch.constant.int 1 - %int3_164 = torch.constant.int 3 - %97 = torch.prim.ListConstruct %int0_161, %int2_162, %int1_163, %int3_164 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %98 = torch.aten.permute %95, %97 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int2_165 = torch.constant.int 2 - %int512_166 = torch.constant.int 512 - %int16_167 = torch.constant.int 16 - %99 = torch.prim.ListConstruct %int2_165, %int2_165, %int512_166, %int16_167 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_168 = torch.constant.bool false - %100 = torch.aten.expand %98, %99, %false_168 : !torch.vtensor<[2,2,512,16],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,16],f32> - %int4_169 = torch.constant.int 4 - %int512_170 = torch.constant.int 512 - %int16_171 = torch.constant.int 16 - %101 = torch.prim.ListConstruct %int4_169, %int512_170, %int16_171 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_172 = torch.constant.int 4 - %int512_173 = torch.constant.int 512 - %int16_174 = torch.constant.int 16 - %102 = torch.prim.ListConstruct %int4_172, %int512_173, %int16_174 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %103 = torch.aten.reshape %100, %102 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %int0_175 = torch.constant.int 0 - %int2_176 = torch.constant.int 2 - %int1_177 = torch.constant.int 1 - %104 = torch.prim.ListConstruct %int0_175, %int2_176, %int1_177 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int0_178 = torch.constant.int 0 - %int2_179 = torch.constant.int 2 - %int1_180 = torch.constant.int 1 - %105 = torch.prim.ListConstruct %int0_178, %int2_179, %int1_180 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %106 = torch.aten.permute %103, %105 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> - %int1_181 = torch.constant.int 1 - %int0_182 = torch.constant.int 0 - %107 = torch.prim.ListConstruct %int1_181, %int0_182 : (!torch.int, !torch.int) -> !torch.list - %int1_183 = torch.constant.int 1 - %int0_184 = torch.constant.int 0 - %108 = torch.prim.ListConstruct %int1_183, %int0_184 : (!torch.int, !torch.int) -> !torch.list - %109 = torch.aten.permute %arg34, %108 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_185 = torch.constant.int 1 - %int0_186 = torch.constant.int 0 - %110 = torch.prim.ListConstruct %int1_185, %int0_186 : (!torch.int, !torch.int) -> !torch.list - %int1_187 = torch.constant.int 1 - %int0_188 = torch.constant.int 0 - %111 = torch.prim.ListConstruct %int1_187, %int0_188 : (!torch.int, !torch.int) -> !torch.list - %112 = torch.aten.permute %109, %111 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_189 = torch.constant.int 1 - %int0_190 = torch.constant.int 0 - %113 = torch.prim.ListConstruct %int1_189, %int0_190 : (!torch.int, !torch.int) -> !torch.list - %int1_191 = torch.constant.int 1 - %int0_192 = torch.constant.int 0 - %114 = torch.prim.ListConstruct %int1_191, %int0_192 : (!torch.int, !torch.int) -> !torch.list - %115 = torch.aten.permute %arg34, %114 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int4_193 = torch.constant.int 4 - %int512_194 = torch.constant.int 512 - %int16_195 = torch.constant.int 16 - %116 = torch.prim.ListConstruct %int4_193, %int512_194, %int16_195 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_196 = torch.constant.int 4 - %int512_197 = torch.constant.int 512 - %int16_198 = torch.constant.int 16 - %117 = torch.prim.ListConstruct %int4_196, %int512_197, %int16_198 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %118 = torch.aten.reshape %100, %117 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %int2_199 = torch.constant.int 2 - %int512_200 = torch.constant.int 512 - %119 = torch.prim.ListConstruct %int2_199, %int2_199, %int512_200, %int512_200 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_201 = torch.constant.bool false - %120 = torch.aten.expand %82, %119, %false_201 : !torch.vtensor<[2,2,512,512],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,2,512,512],f32> - %int4_202 = torch.constant.int 4 - %int512_203 = torch.constant.int 512 - %121 = torch.prim.ListConstruct %int4_202, %int512_203, %int512_203 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_204 = torch.constant.int 4 - %int512_205 = torch.constant.int 512 - %122 = torch.prim.ListConstruct %int4_204, %int512_205, %int512_205 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %123 = torch.aten.reshape %120, %122 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> - %124 = torch.aten.bmm %123, %118 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> - %int2_206 = torch.constant.int 2 - %int512_207 = torch.constant.int 512 - %int16_208 = torch.constant.int 16 - %125 = torch.prim.ListConstruct %int2_206, %int2_206, %int512_207, %int16_208 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_209 = torch.constant.int 2 - %int512_210 = torch.constant.int 512 - %int16_211 = torch.constant.int 16 - %126 = torch.prim.ListConstruct %int2_209, %int2_209, %int512_210, %int16_211 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %127 = torch.aten.reshape %124, %126 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_212 = torch.constant.int 0 - %int2_213 = torch.constant.int 2 - %int1_214 = torch.constant.int 1 - %int3_215 = torch.constant.int 3 - %128 = torch.prim.ListConstruct %int0_212, %int2_213, %int1_214, %int3_215 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_216 = torch.constant.int 0 - %int2_217 = torch.constant.int 2 - %int1_218 = torch.constant.int 1 - %int3_219 = torch.constant.int 3 - %129 = torch.prim.ListConstruct %int0_216, %int2_217, %int1_218, %int3_219 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %130 = torch.aten.permute %127, %129 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_220 = torch.constant.int 2 - %int512_221 = torch.constant.int 512 - %int32_222 = torch.constant.int 32 - %131 = torch.prim.ListConstruct %int2_220, %int512_221, %int32_222 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_223 = torch.constant.int 2 - %int512_224 = torch.constant.int 512 - %int32_225 = torch.constant.int 32 - %132 = torch.prim.ListConstruct %int2_223, %int512_224, %int32_225 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %133 = torch.aten.reshape %130, %132 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_226 = torch.constant.int 1024 - %int32_227 = torch.constant.int 32 - %134 = torch.prim.ListConstruct %int1024_226, %int32_227 : (!torch.int, !torch.int) -> !torch.list - %int1024_228 = torch.constant.int 1024 - %int32_229 = torch.constant.int 32 - %135 = torch.prim.ListConstruct %int1024_228, %int32_229 : (!torch.int, !torch.int) -> !torch.list - %136 = torch.aten.reshape %133, %135 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %137 = torch.aten.addmm %arg40, %136, %115, %arg39, %arg38 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_230 = torch.constant.int 2 - %int512_231 = torch.constant.int 512 - %int32_232 = torch.constant.int 32 - %138 = torch.prim.ListConstruct %int2_230, %int512_231, %int32_232 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_233 = torch.constant.int 2 - %int512_234 = torch.constant.int 512 - %int32_235 = torch.constant.int 32 - %139 = torch.prim.ListConstruct %int2_233, %int512_234, %int32_235 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %140 = torch.aten.reshape %137, %139 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %141 = torch.aten.add.Tensor %140, %result0, %arg37 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_236 = torch.constant.int 32 - %142 = torch.prim.ListConstruct %int32_236 : (!torch.int) -> !torch.list - %float1.000000e-05_237 = torch.constant.float 1.000000e-05 - %result0_238, %result1_239, %result2_240 = torch.aten.native_layer_norm %141, %142, %arg36, %arg35, %float1.000000e-05_237 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> - %143 = torch.prim.TupleConstruct %result0_238, %result1_239, %result2_240 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> - %int1_241 = torch.constant.int 1 - %int0_242 = torch.constant.int 0 - %144 = torch.prim.ListConstruct %int1_241, %int0_242 : (!torch.int, !torch.int) -> !torch.list - %int1_243 = torch.constant.int 1 - %int0_244 = torch.constant.int 0 - %145 = torch.prim.ListConstruct %int1_243, %int0_244 : (!torch.int, !torch.int) -> !torch.list - %146 = torch.aten.permute %arg42, %145 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_245 = torch.constant.int 1 - %int0_246 = torch.constant.int 0 - %147 = torch.prim.ListConstruct %int1_245, %int0_246 : (!torch.int, !torch.int) -> !torch.list - %int1_247 = torch.constant.int 1 - %int0_248 = torch.constant.int 0 - %148 = torch.prim.ListConstruct %int1_247, %int0_248 : (!torch.int, !torch.int) -> !torch.list - %149 = torch.aten.permute %146, %148 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_249 = torch.constant.int 1 - %int0_250 = torch.constant.int 0 - %150 = torch.prim.ListConstruct %int1_249, %int0_250 : (!torch.int, !torch.int) -> !torch.list - %int1_251 = torch.constant.int 1 - %int0_252 = torch.constant.int 0 - %151 = torch.prim.ListConstruct %int1_251, %int0_252 : (!torch.int, !torch.int) -> !torch.list - %152 = torch.aten.permute %arg42, %151 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1024_253 = torch.constant.int 1024 - %int32_254 = torch.constant.int 32 - %153 = torch.prim.ListConstruct %int1024_253, %int32_254 : (!torch.int, !torch.int) -> !torch.list - %int1024_255 = torch.constant.int 1024 - %int32_256 = torch.constant.int 32 - %154 = torch.prim.ListConstruct %int1024_255, %int32_256 : (!torch.int, !torch.int) -> !torch.list - %155 = torch.aten.reshape %result0_238, %154 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %156 = torch.aten.addmm %arg45, %155, %152, %arg44, %arg43 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_257 = torch.constant.int 2 - %int512_258 = torch.constant.int 512 - %int32_259 = torch.constant.int 32 - %157 = torch.prim.ListConstruct %int2_257, %int512_258, %int32_259 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_260 = torch.constant.int 2 - %int512_261 = torch.constant.int 512 - %int32_262 = torch.constant.int 32 - %158 = torch.prim.ListConstruct %int2_260, %int512_261, %int32_262 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %159 = torch.aten.reshape %156, %158 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1_263 = torch.constant.int 1 - %int0_264 = torch.constant.int 0 - %160 = torch.prim.ListConstruct %int1_263, %int0_264 : (!torch.int, !torch.int) -> !torch.list - %int1_265 = torch.constant.int 1 - %int0_266 = torch.constant.int 0 - %161 = torch.prim.ListConstruct %int1_265, %int0_266 : (!torch.int, !torch.int) -> !torch.list - %162 = torch.aten.permute %arg46, %161 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_267 = torch.constant.int 1 - %int0_268 = torch.constant.int 0 - %163 = torch.prim.ListConstruct %int1_267, %int0_268 : (!torch.int, !torch.int) -> !torch.list - %int1_269 = torch.constant.int 1 - %int0_270 = torch.constant.int 0 - %164 = torch.prim.ListConstruct %int1_269, %int0_270 : (!torch.int, !torch.int) -> !torch.list - %165 = torch.aten.permute %162, %164 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_271 = torch.constant.int 1 - %int0_272 = torch.constant.int 0 - %166 = torch.prim.ListConstruct %int1_271, %int0_272 : (!torch.int, !torch.int) -> !torch.list - %int1_273 = torch.constant.int 1 - %int0_274 = torch.constant.int 0 - %167 = torch.prim.ListConstruct %int1_273, %int0_274 : (!torch.int, !torch.int) -> !torch.list - %168 = torch.aten.permute %arg46, %167 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %str = torch.constant.str "none" - %169 = torch.aten.gelu %159, %str : !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> - %int1024_275 = torch.constant.int 1024 - %int32_276 = torch.constant.int 32 - %170 = torch.prim.ListConstruct %int1024_275, %int32_276 : (!torch.int, !torch.int) -> !torch.list - %int1024_277 = torch.constant.int 1024 - %int32_278 = torch.constant.int 32 - %171 = torch.prim.ListConstruct %int1024_277, %int32_278 : (!torch.int, !torch.int) -> !torch.list - %172 = torch.aten.reshape %169, %171 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %173 = torch.aten.addmm %arg52, %172, %168, %arg51, %arg50 : !torch.vtensor<[32],f32>, !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[1024,32],f32> - %int2_279 = torch.constant.int 2 - %int512_280 = torch.constant.int 512 - %int32_281 = torch.constant.int 32 - %174 = torch.prim.ListConstruct %int2_279, %int512_280, %int32_281 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_282 = torch.constant.int 2 - %int512_283 = torch.constant.int 512 - %int32_284 = torch.constant.int 32 - %175 = torch.prim.ListConstruct %int2_282, %int512_283, %int32_284 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %176 = torch.aten.reshape %173, %175 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %177 = torch.aten.add.Tensor %176, %result0_238, %arg49 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_285 = torch.constant.int 32 - %178 = torch.prim.ListConstruct %int32_285 : (!torch.int) -> !torch.list - %float1.000000e-05_286 = torch.constant.float 1.000000e-05 - %result0_287, %result1_288, %result2_289 = torch.aten.native_layer_norm %177, %178, %arg48, %arg47, %float1.000000e-05_286 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> - %179 = torch.prim.TupleConstruct %result0_287, %result1_288, %result2_289 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32> -> !torch.tuple, vtensor<[2,512,1],f32>, vtensor<[2,512,1],f32>> - %int1_290 = torch.constant.int 1 - %int0_291 = torch.constant.int 0 - %180 = torch.prim.ListConstruct %int1_290, %int0_291 : (!torch.int, !torch.int) -> !torch.list - %int1_292 = torch.constant.int 1 - %int0_293 = torch.constant.int 0 - %181 = torch.prim.ListConstruct %int1_292, %int0_293 : (!torch.int, !torch.int) -> !torch.list - %182 = torch.aten.permute %arg53, %181 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_294 = torch.constant.int 1 - %int0_295 = torch.constant.int 0 - %183 = torch.prim.ListConstruct %int1_294, %int0_295 : (!torch.int, !torch.int) -> !torch.list - %int1_296 = torch.constant.int 1 - %int0_297 = torch.constant.int 0 - %184 = torch.prim.ListConstruct %int1_296, %int0_297 : (!torch.int, !torch.int) -> !torch.list - %185 = torch.aten.permute %182, %184 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_298 = torch.constant.int 1 - %int0_299 = torch.constant.int 0 - %186 = torch.prim.ListConstruct %int1_298, %int0_299 : (!torch.int, !torch.int) -> !torch.list - %int1_300 = torch.constant.int 1 - %int0_301 = torch.constant.int 0 - %187 = torch.prim.ListConstruct %int1_300, %int0_301 : (!torch.int, !torch.int) -> !torch.list - %188 = torch.aten.permute %arg53, %187 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int0_302 = torch.constant.int 0 - %int0_303 = torch.constant.int 0 - %int2_304 = torch.constant.int 2 - %int1_305 = torch.constant.int 1 - %189 = torch.aten.slice.Tensor %result0_287, %int0_302, %int0_303, %int2_304, %int1_305 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int0_306 = torch.constant.int 0 - %int0_307 = torch.constant.int 0 - %int2_308 = torch.constant.int 2 - %int1_309 = torch.constant.int 1 - %190 = torch.aten.slice.Tensor %189, %int0_306, %int0_307, %int2_308, %int1_309 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int1_310 = torch.constant.int 1 - %int0_311 = torch.constant.int 0 - %int1_312 = torch.constant.int 1 - %int1_313 = torch.constant.int 1 - %191 = torch.aten.slice.Tensor %190, %int1_310, %int0_311, %int1_312, %int1_313 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> - %int2_314 = torch.constant.int 2 - %int0_315 = torch.constant.int 0 - %int32_316 = torch.constant.int 32 - %int1_317 = torch.constant.int 1 - %192 = torch.aten.slice.Tensor %191, %int2_314, %int0_315, %int32_316, %int1_317 : !torch.vtensor<[2,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> - %int2_318 = torch.constant.int 2 - %int32_319 = torch.constant.int 32 - %193 = torch.prim.ListConstruct %int2_318, %int32_319 : (!torch.int, !torch.int) -> !torch.list - %int2_320 = torch.constant.int 2 - %int32_321 = torch.constant.int 32 - %194 = torch.prim.ListConstruct %int2_320, %int32_321 : (!torch.int, !torch.int) -> !torch.list - %195 = torch.aten.reshape %192, %194 : !torch.vtensor<[2,1,32],f32>, !torch.list -> !torch.vtensor<[2,32],f32> - %196 = torch.aten.addmm %arg56, %195, %188, %arg55, %arg54 : !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,32],f32> - %197 = torch.aten.tanh %196 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %int1_322 = torch.constant.int 1 - %int0_323 = torch.constant.int 0 - %198 = torch.prim.ListConstruct %int1_322, %int0_323 : (!torch.int, !torch.int) -> !torch.list - %int1_324 = torch.constant.int 1 - %int0_325 = torch.constant.int 0 - %199 = torch.prim.ListConstruct %int1_324, %int0_325 : (!torch.int, !torch.int) -> !torch.list - %200 = torch.aten.permute %arg57, %199 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %int1_326 = torch.constant.int 1 - %int0_327 = torch.constant.int 0 - %201 = torch.prim.ListConstruct %int1_326, %int0_327 : (!torch.int, !torch.int) -> !torch.list - %int1_328 = torch.constant.int 1 - %int0_329 = torch.constant.int 0 - %202 = torch.prim.ListConstruct %int1_328, %int0_329 : (!torch.int, !torch.int) -> !torch.list - %203 = torch.aten.permute %200, %202 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> - %int1_330 = torch.constant.int 1 - %int0_331 = torch.constant.int 0 - %204 = torch.prim.ListConstruct %int1_330, %int0_331 : (!torch.int, !torch.int) -> !torch.list - %int1_332 = torch.constant.int 1 - %int0_333 = torch.constant.int 0 - %205 = torch.prim.ListConstruct %int1_332, %int0_333 : (!torch.int, !torch.int) -> !torch.list - %206 = torch.aten.permute %arg57, %205 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %207 = torch.aten.addmm %arg60, %197, %206, %arg59, %arg58 : !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> - %int2_334 = torch.constant.int 2 - %208 = torch.prim.ListConstruct %int2_334, %int2_334 : (!torch.int, !torch.int) -> !torch.list - %int2_335 = torch.constant.int 2 - %209 = torch.prim.ListConstruct %int2_335, %int2_335 : (!torch.int, !torch.int) -> !torch.list - %210 = torch.aten.reshape %207, %209 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> - %int1_336 = torch.constant.int 1 - %false_337 = torch.constant.bool false - %211 = torch.aten._log_softmax %210, %int1_336, %false_337 : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,2],f32> - %int2_338 = torch.constant.int 2 - %212 = torch.prim.ListConstruct %int2_338 : (!torch.int) -> !torch.list - %int2_339 = torch.constant.int 2 - %213 = torch.prim.ListConstruct %int2_339 : (!torch.int) -> !torch.list - %214 = torch.aten.reshape %arg61, %213 : !torch.vtensor<[2],si64>, !torch.list -> !torch.vtensor<[2],si64> - %none_340 = torch.constant.none - %int1_341 = torch.constant.int 1 - %int-100 = torch.constant.int -100 - %output, %total_weight = torch.aten.nll_loss_forward %211, %214, %none_340, %int1_341, %int-100 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> - %215 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.tuple, vtensor<[],f32>> - %none_342 = torch.constant.none - %int1_343 = torch.constant.int 1 - %int-100_344 = torch.constant.int -100 - %216 = torch.aten.nll_loss_backward %arg62, %211, %214, %none_342, %int1_343, %int-100_344, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[2,2],f32> - %int1_345 = torch.constant.int 1 - %int6_346 = torch.constant.int 6 - %217 = torch.aten._log_softmax_backward_data %216, %211, %int1_345, %int6_346 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> - %int2_347 = torch.constant.int 2 - %218 = torch.prim.ListConstruct %int2_347, %int2_347 : (!torch.int, !torch.int) -> !torch.list - %int2_348 = torch.constant.int 2 - %219 = torch.prim.ListConstruct %int2_348, %int2_348 : (!torch.int, !torch.int) -> !torch.list - %220 = torch.aten.reshape %217, %219 : !torch.vtensor<[2,2],f32>, !torch.list -> !torch.vtensor<[2,2],f32> - %221 = torch.aten.mm %220, %203 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %222 = torch.aten.tanh_backward %221, %197 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %223 = torch.aten.mm %222, %185 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[2,32],f32> - %int2_349 = torch.constant.int 2 - %int1_350 = torch.constant.int 1 - %int32_351 = torch.constant.int 32 - %224 = torch.prim.ListConstruct %int2_349, %int1_350, %int32_351 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_352 = torch.constant.int 2 - %int1_353 = torch.constant.int 1 - %int32_354 = torch.constant.int 32 - %225 = torch.prim.ListConstruct %int2_352, %int1_353, %int32_354 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %226 = torch.aten.reshape %223, %225 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[2,1,32],f32> - %227 = torch.aten.zero.functional %arg63 : !torch.vtensor<[2,512,32],f32> -> !torch.vtensor<[2,512,32],f32> - %none_355 = torch.constant.none - %228 = torch.aten.clone %227, %none_355 : !torch.vtensor<[2,512,32],f32>, !torch.none -> !torch.vtensor<[2,512,32],f32> - %int0_356 = torch.constant.int 0 - %int0_357 = torch.constant.int 0 - %int2_358 = torch.constant.int 2 - %int1_359 = torch.constant.int 1 - %229 = torch.aten.slice.Tensor %228, %int0_356, %int0_357, %int2_358, %int1_359 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int1_360 = torch.constant.int 1 - %int0_361 = torch.constant.int 0 - %int1_362 = torch.constant.int 1 - %int1_363 = torch.constant.int 1 - %230 = torch.aten.slice.Tensor %229, %int1_360, %int0_361, %int1_362, %int1_363 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> - %int2_364 = torch.constant.int 2 - %int0_365 = torch.constant.int 0 - %int32_366 = torch.constant.int 32 - %int1_367 = torch.constant.int 1 - %231 = torch.aten.slice.Tensor %230, %int2_364, %int0_365, %int32_366, %int1_367 : !torch.vtensor<[2,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> - %false_368 = torch.constant.bool false - %232 = torch.aten.copy_ %231, %226, %false_368 : !torch.vtensor<[2,1,32],f32>, !torch.vtensor<[2,1,32],f32>, !torch.bool -> !torch.vtensor<[2,1,32],f32> - %233 = torch.aten.zero.functional %arg64 : !torch.vtensor<[2,512,32],f32> -> !torch.vtensor<[2,512,32],f32> - %none_369 = torch.constant.none - %234 = torch.aten.clone %233, %none_369 : !torch.vtensor<[2,512,32],f32>, !torch.none -> !torch.vtensor<[2,512,32],f32> - %int0_370 = torch.constant.int 0 - %int0_371 = torch.constant.int 0 - %int2_372 = torch.constant.int 2 - %int1_373 = torch.constant.int 1 - %235 = torch.aten.slice.Tensor %234, %int0_370, %int0_371, %int2_372, %int1_373 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> - %false_374 = torch.constant.bool false - %236 = torch.aten.copy_ %235, %228, %false_374 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.bool -> !torch.vtensor<[2,512,32],f32> - %int32_375 = torch.constant.int 32 - %237 = torch.prim.ListConstruct %int32_375 : (!torch.int) -> !torch.list - %true = torch.constant.bool true - %238 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %result0_376, %result1_377, %result2_378 = torch.aten.native_layer_norm_backward %234, %177, %237, %result1_288, %result2_289, %arg48, %arg47, %238 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> - %239 = torch.prim.TupleConstruct %result0_376, %result1_377, %result2_378 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> - %int1024_379 = torch.constant.int 1024 - %int32_380 = torch.constant.int 32 - %240 = torch.prim.ListConstruct %int1024_379, %int32_380 : (!torch.int, !torch.int) -> !torch.list - %int1024_381 = torch.constant.int 1024 - %int32_382 = torch.constant.int 32 - %241 = torch.prim.ListConstruct %int1024_381, %int32_382 : (!torch.int, !torch.int) -> !torch.list - %242 = torch.aten.reshape %result0_376, %241 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %243 = torch.aten.mm %242, %165 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_383 = torch.constant.int 2 - %int512_384 = torch.constant.int 512 - %int32_385 = torch.constant.int 32 - %244 = torch.prim.ListConstruct %int2_383, %int512_384, %int32_385 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_386 = torch.constant.int 2 - %int512_387 = torch.constant.int 512 - %int32_388 = torch.constant.int 32 - %245 = torch.prim.ListConstruct %int2_386, %int512_387, %int32_388 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %246 = torch.aten.reshape %243, %245 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %str_389 = torch.constant.str "none" - %247 = torch.aten.gelu_backward %246, %159, %str_389 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.str -> !torch.vtensor<[2,512,32],f32> - %int1024_390 = torch.constant.int 1024 - %int32_391 = torch.constant.int 32 - %248 = torch.prim.ListConstruct %int1024_390, %int32_391 : (!torch.int, !torch.int) -> !torch.list - %int1024_392 = torch.constant.int 1024 - %int32_393 = torch.constant.int 32 - %249 = torch.prim.ListConstruct %int1024_392, %int32_393 : (!torch.int, !torch.int) -> !torch.list - %250 = torch.aten.reshape %247, %249 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %251 = torch.aten.mm %250, %149 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_394 = torch.constant.int 2 - %int512_395 = torch.constant.int 512 - %int32_396 = torch.constant.int 32 - %252 = torch.prim.ListConstruct %int2_394, %int512_395, %int32_396 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_397 = torch.constant.int 2 - %int512_398 = torch.constant.int 512 - %int32_399 = torch.constant.int 32 - %253 = torch.prim.ListConstruct %int2_397, %int512_398, %int32_399 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %254 = torch.aten.reshape %251, %253 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %255 = torch.aten.add.Tensor %result0_376, %254, %arg41 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_400 = torch.constant.int 32 - %256 = torch.prim.ListConstruct %int32_400 : (!torch.int) -> !torch.list - %true_401 = torch.constant.bool true - %257 = torch.prim.ListConstruct %true_401, %true_401, %true_401 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %result0_402, %result1_403, %result2_404 = torch.aten.native_layer_norm_backward %255, %141, %256, %result1_239, %result2_240, %arg36, %arg35, %257 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> - %258 = torch.prim.TupleConstruct %result0_402, %result1_403, %result2_404 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> - %int1024_405 = torch.constant.int 1024 - %int32_406 = torch.constant.int 32 - %259 = torch.prim.ListConstruct %int1024_405, %int32_406 : (!torch.int, !torch.int) -> !torch.list - %int1024_407 = torch.constant.int 1024 - %int32_408 = torch.constant.int 32 - %260 = torch.prim.ListConstruct %int1024_407, %int32_408 : (!torch.int, !torch.int) -> !torch.list - %261 = torch.aten.reshape %result0_402, %260 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %262 = torch.aten.mm %261, %112 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_409 = torch.constant.int 2 - %int512_410 = torch.constant.int 512 - %int32_411 = torch.constant.int 32 - %263 = torch.prim.ListConstruct %int2_409, %int512_410, %int32_411 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_412 = torch.constant.int 2 - %int512_413 = torch.constant.int 512 - %int32_414 = torch.constant.int 32 - %264 = torch.prim.ListConstruct %int2_412, %int512_413, %int32_414 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %265 = torch.aten.reshape %262, %264 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int2_415 = torch.constant.int 2 - %int512_416 = torch.constant.int 512 - %int16_417 = torch.constant.int 16 - %266 = torch.prim.ListConstruct %int2_415, %int512_416, %int2_415, %int16_417 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_418 = torch.constant.int 2 - %int512_419 = torch.constant.int 512 - %int16_420 = torch.constant.int 16 - %267 = torch.prim.ListConstruct %int2_418, %int512_419, %int2_418, %int16_420 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %268 = torch.aten.reshape %265, %267 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int0_421 = torch.constant.int 0 - %int2_422 = torch.constant.int 2 - %int1_423 = torch.constant.int 1 - %int3_424 = torch.constant.int 3 - %269 = torch.prim.ListConstruct %int0_421, %int2_422, %int1_423, %int3_424 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_425 = torch.constant.int 0 - %int2_426 = torch.constant.int 2 - %int1_427 = torch.constant.int 1 - %int3_428 = torch.constant.int 3 - %270 = torch.prim.ListConstruct %int0_425, %int2_426, %int1_427, %int3_428 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %271 = torch.aten.permute %268, %270 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int4_429 = torch.constant.int 4 - %int512_430 = torch.constant.int 512 - %int16_431 = torch.constant.int 16 - %272 = torch.prim.ListConstruct %int4_429, %int512_430, %int16_431 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_432 = torch.constant.int 4 - %int512_433 = torch.constant.int 512 - %int16_434 = torch.constant.int 16 - %273 = torch.prim.ListConstruct %int4_432, %int512_433, %int16_434 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %274 = torch.aten.reshape %271, %273 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %275 = torch.aten.bmm %274, %106 : !torch.vtensor<[4,512,16],f32>, !torch.vtensor<[4,16,512],f32> -> !torch.vtensor<[4,512,512],f32> - %int2_435 = torch.constant.int 2 - %int512_436 = torch.constant.int 512 - %276 = torch.prim.ListConstruct %int2_435, %int2_435, %int512_436, %int512_436 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_437 = torch.constant.int 2 - %int512_438 = torch.constant.int 512 - %277 = torch.prim.ListConstruct %int2_437, %int2_437, %int512_438, %int512_438 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %278 = torch.aten.reshape %275, %277 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,512],f32> - %int-1_439 = torch.constant.int -1 - %int6_440 = torch.constant.int 6 - %279 = torch.aten._softmax_backward_data %278, %82, %int-1_439, %int6_440 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[2,2,512,512],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,2,512,512],f32> - %280 = torch.aten.div.Tensor %279, %arg21 : !torch.vtensor<[2,2,512,512],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,2,512,512],f32> - %int4_441 = torch.constant.int 4 - %int512_442 = torch.constant.int 512 - %281 = torch.prim.ListConstruct %int4_441, %int512_442, %int512_442 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_443 = torch.constant.int 4 - %int512_444 = torch.constant.int 512 - %282 = torch.prim.ListConstruct %int4_443, %int512_444, %int512_444 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %283 = torch.aten.reshape %280, %282 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> - %284 = torch.aten.bmm %283, %40 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> - %int2_445 = torch.constant.int 2 - %int512_446 = torch.constant.int 512 - %int16_447 = torch.constant.int 16 - %285 = torch.prim.ListConstruct %int2_445, %int2_445, %int512_446, %int16_447 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_448 = torch.constant.int 2 - %int512_449 = torch.constant.int 512 - %int16_450 = torch.constant.int 16 - %286 = torch.prim.ListConstruct %int2_448, %int2_448, %int512_449, %int16_450 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %287 = torch.aten.reshape %284, %286 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_451 = torch.constant.int 0 - %int2_452 = torch.constant.int 2 - %int1_453 = torch.constant.int 1 - %int3_454 = torch.constant.int 3 - %288 = torch.prim.ListConstruct %int0_451, %int2_452, %int1_453, %int3_454 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_455 = torch.constant.int 0 - %int2_456 = torch.constant.int 2 - %int1_457 = torch.constant.int 1 - %int3_458 = torch.constant.int 3 - %289 = torch.prim.ListConstruct %int0_455, %int2_456, %int1_457, %int3_458 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %290 = torch.aten.permute %287, %289 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_459 = torch.constant.int 2 - %int512_460 = torch.constant.int 512 - %int32_461 = torch.constant.int 32 - %291 = torch.prim.ListConstruct %int2_459, %int512_460, %int32_461 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_462 = torch.constant.int 2 - %int512_463 = torch.constant.int 512 - %int32_464 = torch.constant.int 32 - %292 = torch.prim.ListConstruct %int2_462, %int512_463, %int32_464 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %293 = torch.aten.reshape %290, %292 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_465 = torch.constant.int 1024 - %int32_466 = torch.constant.int 32 - %294 = torch.prim.ListConstruct %int1024_465, %int32_466 : (!torch.int, !torch.int) -> !torch.list - %int1024_467 = torch.constant.int 1024 - %int32_468 = torch.constant.int 32 - %295 = torch.prim.ListConstruct %int1024_467, %int32_468 : (!torch.int, !torch.int) -> !torch.list - %296 = torch.aten.reshape %293, %295 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %297 = torch.aten.mm %296, %13 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_469 = torch.constant.int 2 - %int512_470 = torch.constant.int 512 - %int32_471 = torch.constant.int 32 - %298 = torch.prim.ListConstruct %int2_469, %int512_470, %int32_471 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_472 = torch.constant.int 2 - %int512_473 = torch.constant.int 512 - %int32_474 = torch.constant.int 32 - %299 = torch.prim.ListConstruct %int2_472, %int512_473, %int32_474 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %300 = torch.aten.reshape %297, %299 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1_475 = torch.constant.int 1 - %int0_476 = torch.constant.int 0 - %301 = torch.prim.ListConstruct %int1_475, %int0_476 : (!torch.int, !torch.int) -> !torch.list - %int1_477 = torch.constant.int 1 - %int0_478 = torch.constant.int 0 - %302 = torch.prim.ListConstruct %int1_477, %int0_478 : (!torch.int, !torch.int) -> !torch.list - %303 = torch.aten.permute %arg19, %302 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_479 = torch.constant.int 1 - %int0_480 = torch.constant.int 0 - %304 = torch.prim.ListConstruct %int1_479, %int0_480 : (!torch.int, !torch.int) -> !torch.list - %int1_481 = torch.constant.int 1 - %int0_482 = torch.constant.int 0 - %305 = torch.prim.ListConstruct %int1_481, %int0_482 : (!torch.int, !torch.int) -> !torch.list - %306 = torch.aten.permute %303, %305 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int4_483 = torch.constant.int 4 - %int512_484 = torch.constant.int 512 - %int16_485 = torch.constant.int 16 - %307 = torch.prim.ListConstruct %int4_483, %int512_484, %int16_485 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_486 = torch.constant.int 4 - %int512_487 = torch.constant.int 512 - %int16_488 = torch.constant.int 16 - %308 = torch.prim.ListConstruct %int4_486, %int512_487, %int16_488 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %309 = torch.aten.reshape %72, %308 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[4,512,16],f32> - %int0_489 = torch.constant.int 0 - %int2_490 = torch.constant.int 2 - %int1_491 = torch.constant.int 1 - %310 = torch.prim.ListConstruct %int0_489, %int2_490, %int1_491 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int0_492 = torch.constant.int 0 - %int2_493 = torch.constant.int 2 - %int1_494 = torch.constant.int 1 - %311 = torch.prim.ListConstruct %int0_492, %int2_493, %int1_494 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %312 = torch.aten.permute %309, %311 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[4,16,512],f32> - %313 = torch.aten.bmm %312, %283 : !torch.vtensor<[4,16,512],f32>, !torch.vtensor<[4,512,512],f32> -> !torch.vtensor<[4,16,512],f32> - %int2_495 = torch.constant.int 2 - %int16_496 = torch.constant.int 16 - %int512_497 = torch.constant.int 512 - %314 = torch.prim.ListConstruct %int2_495, %int2_495, %int16_496, %int512_497 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_498 = torch.constant.int 2 - %int16_499 = torch.constant.int 16 - %int512_500 = torch.constant.int 512 - %315 = torch.prim.ListConstruct %int2_498, %int2_498, %int16_499, %int512_500 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %316 = torch.aten.reshape %313, %315 : !torch.vtensor<[4,16,512],f32>, !torch.list -> !torch.vtensor<[2,2,16,512],f32> - %int0_501 = torch.constant.int 0 - %int1_502 = torch.constant.int 1 - %int3_503 = torch.constant.int 3 - %int2_504 = torch.constant.int 2 - %317 = torch.prim.ListConstruct %int0_501, %int1_502, %int3_503, %int2_504 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_505 = torch.constant.int 0 - %int1_506 = torch.constant.int 1 - %int3_507 = torch.constant.int 3 - %int2_508 = torch.constant.int 2 - %318 = torch.prim.ListConstruct %int0_505, %int1_506, %int3_507, %int2_508 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %319 = torch.aten.permute %316, %318 : !torch.vtensor<[2,2,16,512],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_509 = torch.constant.int 0 - %int2_510 = torch.constant.int 2 - %int1_511 = torch.constant.int 1 - %int3_512 = torch.constant.int 3 - %320 = torch.prim.ListConstruct %int0_509, %int2_510, %int1_511, %int3_512 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_513 = torch.constant.int 0 - %int2_514 = torch.constant.int 2 - %int1_515 = torch.constant.int 1 - %int3_516 = torch.constant.int 3 - %321 = torch.prim.ListConstruct %int0_513, %int2_514, %int1_515, %int3_516 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %322 = torch.aten.permute %319, %321 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_517 = torch.constant.int 2 - %int512_518 = torch.constant.int 512 - %int32_519 = torch.constant.int 32 - %323 = torch.prim.ListConstruct %int2_517, %int512_518, %int32_519 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_520 = torch.constant.int 2 - %int512_521 = torch.constant.int 512 - %int32_522 = torch.constant.int 32 - %324 = torch.prim.ListConstruct %int2_520, %int512_521, %int32_522 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %325 = torch.aten.reshape %322, %324 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_523 = torch.constant.int 1024 - %int32_524 = torch.constant.int 32 - %326 = torch.prim.ListConstruct %int1024_523, %int32_524 : (!torch.int, !torch.int) -> !torch.list - %int1024_525 = torch.constant.int 1024 - %int32_526 = torch.constant.int 32 - %327 = torch.prim.ListConstruct %int1024_525, %int32_526 : (!torch.int, !torch.int) -> !torch.list - %328 = torch.aten.reshape %325, %327 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %329 = torch.aten.mm %328, %306 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_527 = torch.constant.int 2 - %int512_528 = torch.constant.int 512 - %int32_529 = torch.constant.int 32 - %330 = torch.prim.ListConstruct %int2_527, %int512_528, %int32_529 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_530 = torch.constant.int 2 - %int512_531 = torch.constant.int 512 - %int32_532 = torch.constant.int 32 - %331 = torch.prim.ListConstruct %int2_530, %int512_531, %int32_532 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %332 = torch.aten.reshape %329, %331 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1_533 = torch.constant.int 1 - %int0_534 = torch.constant.int 0 - %333 = torch.prim.ListConstruct %int1_533, %int0_534 : (!torch.int, !torch.int) -> !torch.list - %int1_535 = torch.constant.int 1 - %int0_536 = torch.constant.int 0 - %334 = torch.prim.ListConstruct %int1_535, %int0_536 : (!torch.int, !torch.int) -> !torch.list - %335 = torch.aten.permute %arg32, %334 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int1_537 = torch.constant.int 1 - %int0_538 = torch.constant.int 0 - %336 = torch.prim.ListConstruct %int1_537, %int0_538 : (!torch.int, !torch.int) -> !torch.list - %int1_539 = torch.constant.int 1 - %int0_540 = torch.constant.int 0 - %337 = torch.prim.ListConstruct %int1_539, %int0_540 : (!torch.int, !torch.int) -> !torch.list - %338 = torch.aten.permute %335, %337 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %int4_541 = torch.constant.int 4 - %int512_542 = torch.constant.int 512 - %339 = torch.prim.ListConstruct %int4_541, %int512_542, %int512_542 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int4_543 = torch.constant.int 4 - %int512_544 = torch.constant.int 512 - %340 = torch.prim.ListConstruct %int4_543, %int512_544, %int512_544 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %341 = torch.aten.reshape %120, %340 : !torch.vtensor<[2,2,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> - %int0_545 = torch.constant.int 0 - %int2_546 = torch.constant.int 2 - %int1_547 = torch.constant.int 1 - %342 = torch.prim.ListConstruct %int0_545, %int2_546, %int1_547 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int0_548 = torch.constant.int 0 - %int2_549 = torch.constant.int 2 - %int1_550 = torch.constant.int 1 - %343 = torch.prim.ListConstruct %int0_548, %int2_549, %int1_550 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %344 = torch.aten.permute %341, %343 : !torch.vtensor<[4,512,512],f32>, !torch.list -> !torch.vtensor<[4,512,512],f32> - %345 = torch.aten.bmm %344, %274 : !torch.vtensor<[4,512,512],f32>, !torch.vtensor<[4,512,16],f32> -> !torch.vtensor<[4,512,16],f32> - %int2_551 = torch.constant.int 2 - %int512_552 = torch.constant.int 512 - %int16_553 = torch.constant.int 16 - %346 = torch.prim.ListConstruct %int2_551, %int2_551, %int512_552, %int16_553 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_554 = torch.constant.int 2 - %int512_555 = torch.constant.int 512 - %int16_556 = torch.constant.int 16 - %347 = torch.prim.ListConstruct %int2_554, %int2_554, %int512_555, %int16_556 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %348 = torch.aten.reshape %345, %347 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_557 = torch.constant.int 0 - %int2_558 = torch.constant.int 2 - %int1_559 = torch.constant.int 1 - %int3_560 = torch.constant.int 3 - %349 = torch.prim.ListConstruct %int0_557, %int2_558, %int1_559, %int3_560 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_561 = torch.constant.int 0 - %int2_562 = torch.constant.int 2 - %int1_563 = torch.constant.int 1 - %int3_564 = torch.constant.int 3 - %350 = torch.prim.ListConstruct %int0_561, %int2_562, %int1_563, %int3_564 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %351 = torch.aten.permute %348, %350 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_565 = torch.constant.int 2 - %int512_566 = torch.constant.int 512 - %int32_567 = torch.constant.int 32 - %352 = torch.prim.ListConstruct %int2_565, %int512_566, %int32_567 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_568 = torch.constant.int 2 - %int512_569 = torch.constant.int 512 - %int32_570 = torch.constant.int 32 - %353 = torch.prim.ListConstruct %int2_568, %int512_569, %int32_570 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %354 = torch.aten.reshape %351, %353 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_571 = torch.constant.int 1024 - %int32_572 = torch.constant.int 32 - %355 = torch.prim.ListConstruct %int1024_571, %int32_572 : (!torch.int, !torch.int) -> !torch.list - %int1024_573 = torch.constant.int 1024 - %int32_574 = torch.constant.int 32 - %356 = torch.prim.ListConstruct %int1024_573, %int32_574 : (!torch.int, !torch.int) -> !torch.list - %357 = torch.aten.reshape %354, %356 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %358 = torch.aten.mm %357, %338 : !torch.vtensor<[1024,32],f32>, !torch.vtensor<[32,32],f32> -> !torch.vtensor<[1024,32],f32> - %int2_575 = torch.constant.int 2 - %int512_576 = torch.constant.int 512 - %int32_577 = torch.constant.int 32 - %359 = torch.prim.ListConstruct %int2_575, %int512_576, %int32_577 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_578 = torch.constant.int 2 - %int512_579 = torch.constant.int 512 - %int32_580 = torch.constant.int 32 - %360 = torch.prim.ListConstruct %int2_578, %int512_579, %int32_580 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %361 = torch.aten.reshape %358, %360 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %362 = torch.aten.add.Tensor %result0_402, %361, %arg66 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %363 = torch.aten.add.Tensor %362, %332, %arg65 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %364 = torch.aten.add.Tensor %363, %300, %arg15 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int32_581 = torch.constant.int 32 - %365 = torch.prim.ListConstruct %int32_581 : (!torch.int) -> !torch.list - %true_582 = torch.constant.bool true - %366 = torch.prim.ListConstruct %true_582, %true_582, %true_582 : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %result0_583, %result1_584, %result2_585 = torch.aten.native_layer_norm_backward %364, %5, %365, %result1, %result2, %arg7, %arg6, %366 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[2,512,1],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> - %367 = torch.prim.TupleConstruct %result0_583, %result1_584, %result2_585 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32> -> !torch.tuple, vtensor<[32],f32>, vtensor<[32],f32>> - %int28996 = torch.constant.int 28996 - %int0_586 = torch.constant.int 0 - %false_587 = torch.constant.bool false - %368 = torch.aten.embedding_dense_backward %result0_583, %arg5, %int28996, %int0_586, %false_587 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[28996,32],f32> - %369 = torch.aten.add.Tensor %arg67, %368, %arg4 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.int -> !torch.vtensor<[28996,32],f32> - %370 = torch.aten.mul.Tensor %arg69, %arg68 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> - %371 = torch.aten.addcmul %370, %369, %369, %arg3 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> - %372 = torch.aten.sqrt %371 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> - %373 = torch.aten.add.Tensor %372, %arg2, %arg1 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[28996,32],f32> - %374 = torch.aten.mul.Tensor %arg72, %arg71 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[28996,32],f32> - %375 = torch.aten.add.Tensor %374, %369, %arg70 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> - %376 = torch.aten.addcdiv %arg14, %375, %373, %arg0 : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.float -> !torch.vtensor<[28996,32],f32> - %int0_588 = torch.constant.int 0 - %377 = torch.prim.ListConstruct %int0_588 : (!torch.int) -> !torch.list - %true_589 = torch.constant.bool true - %none_590 = torch.constant.none - %378 = torch.aten.sum.dim_IntList %result0_583, %377, %true_589, %none_590 : !torch.vtensor<[2,512,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,512,32],f32> - %int512_591 = torch.constant.int 512 - %int-1_592 = torch.constant.int -1 - %false_593 = torch.constant.bool false - %379 = torch.aten.embedding_dense_backward %378, %0, %int512_591, %int-1_592, %false_593 : !torch.vtensor<[1,512,32],f32>, !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[512,32],f32> - %380 = torch.aten.add.Tensor %arg77, %379, %arg76 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.int -> !torch.vtensor<[512,32],f32> - %381 = torch.aten.mul.Tensor %arg78, %arg68 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> - %382 = torch.aten.addcmul %381, %380, %380, %arg75 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> - %383 = torch.aten.sqrt %382 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> - %384 = torch.aten.add.Tensor %383, %arg2, %arg74 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[512,32],f32> - %385 = torch.aten.mul.Tensor %arg80, %arg71 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[512,32],f32> - %386 = torch.aten.add.Tensor %385, %380, %arg79 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> - %387 = torch.aten.addcdiv %arg10, %386, %384, %arg73 : !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.float -> !torch.vtensor<[512,32],f32> - %int2_594 = torch.constant.int 2 - %int-1_595 = torch.constant.int -1 - %false_596 = torch.constant.bool false - %388 = torch.aten.embedding_dense_backward %result0_583, %arg12, %int2_594, %int-1_595, %false_596 : !torch.vtensor<[2,512,32],f32>, !torch.vtensor<[2,512],si64>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[2,32],f32> - %389 = torch.aten.add.Tensor %arg85, %388, %arg84 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> - %390 = torch.aten.mul.Tensor %arg86, %arg68 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %391 = torch.aten.addcmul %390, %389, %389, %arg83 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %392 = torch.aten.sqrt %391 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %393 = torch.aten.add.Tensor %392, %arg2, %arg82 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> - %394 = torch.aten.mul.Tensor %arg88, %arg71 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %395 = torch.aten.add.Tensor %394, %389, %arg87 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %396 = torch.aten.addcdiv %arg13, %395, %393, %arg81 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %397 = torch.aten.add.Tensor %arg93, %result1_584, %arg92 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %398 = torch.aten.mul.Tensor %arg94, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %399 = torch.aten.addcmul %398, %397, %397, %arg91 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %400 = torch.aten.sqrt %399 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %401 = torch.aten.add.Tensor %400, %arg2, %arg90 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %402 = torch.aten.mul.Tensor %arg96, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %403 = torch.aten.add.Tensor %402, %397, %arg95 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %404 = torch.aten.addcdiv %arg7, %403, %401, %arg89 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %405 = torch.aten.add.Tensor %arg101, %result2_585, %arg100 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %406 = torch.aten.mul.Tensor %arg102, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %407 = torch.aten.addcmul %406, %405, %405, %arg99 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %408 = torch.aten.sqrt %407 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %409 = torch.aten.add.Tensor %408, %arg2, %arg98 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %410 = torch.aten.mul.Tensor %arg104, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %411 = torch.aten.add.Tensor %410, %405, %arg103 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %412 = torch.aten.addcdiv %arg6, %411, %409, %arg97 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_597 = torch.constant.int 1024 - %int32_598 = torch.constant.int 32 - %413 = torch.prim.ListConstruct %int1024_597, %int32_598 : (!torch.int, !torch.int) -> !torch.list - %int1024_599 = torch.constant.int 1024 - %int32_600 = torch.constant.int 32 - %414 = torch.prim.ListConstruct %int1024_599, %int32_600 : (!torch.int, !torch.int) -> !torch.list - %415 = torch.aten.reshape %result0, %414 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_601 = torch.constant.int 1 - %int0_602 = torch.constant.int 0 - %416 = torch.prim.ListConstruct %int1_601, %int0_602 : (!torch.int, !torch.int) -> !torch.list - %int1_603 = torch.constant.int 1 - %int0_604 = torch.constant.int 0 - %417 = torch.prim.ListConstruct %int1_603, %int0_604 : (!torch.int, !torch.int) -> !torch.list - %418 = torch.aten.permute %415, %417 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %419 = torch.aten.mm %418, %296 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_605 = torch.constant.int 1 - %int0_606 = torch.constant.int 0 - %420 = torch.prim.ListConstruct %int1_605, %int0_606 : (!torch.int, !torch.int) -> !torch.list - %int1_607 = torch.constant.int 1 - %int0_608 = torch.constant.int 0 - %421 = torch.prim.ListConstruct %int1_607, %int0_608 : (!torch.int, !torch.int) -> !torch.list - %422 = torch.aten.permute %419, %421 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %423 = torch.aten.add.Tensor %arg109, %422, %arg108 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %424 = torch.aten.mul.Tensor %arg110, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %425 = torch.aten.addcmul %424, %423, %423, %arg107 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %426 = torch.aten.sqrt %425 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %427 = torch.aten.add.Tensor %426, %arg2, %arg106 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %428 = torch.aten.mul.Tensor %arg112, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %429 = torch.aten.add.Tensor %428, %423, %arg111 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %430 = torch.aten.addcdiv %arg16, %429, %427, %arg105 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_609 = torch.constant.int 0 - %431 = torch.prim.ListConstruct %int0_609 : (!torch.int) -> !torch.list - %true_610 = torch.constant.bool true - %none_611 = torch.constant.none - %432 = torch.aten.sum.dim_IntList %296, %431, %true_610, %none_611 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_612 = torch.constant.int 32 - %433 = torch.prim.ListConstruct %int32_612 : (!torch.int) -> !torch.list - %int32_613 = torch.constant.int 32 - %434 = torch.prim.ListConstruct %int32_613 : (!torch.int) -> !torch.list - %435 = torch.aten.reshape %432, %434 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %436 = torch.aten.add.Tensor %arg117, %435, %arg116 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %437 = torch.aten.mul.Tensor %arg118, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %438 = torch.aten.addcmul %437, %436, %436, %arg115 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %439 = torch.aten.sqrt %438 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %440 = torch.aten.add.Tensor %439, %arg2, %arg114 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %441 = torch.aten.mul.Tensor %arg120, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %442 = torch.aten.add.Tensor %441, %436, %arg119 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %443 = torch.aten.addcdiv %arg29, %442, %440, %arg113 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_614 = torch.constant.int 1024 - %int32_615 = torch.constant.int 32 - %444 = torch.prim.ListConstruct %int1024_614, %int32_615 : (!torch.int, !torch.int) -> !torch.list - %int1024_616 = torch.constant.int 1024 - %int32_617 = torch.constant.int 32 - %445 = torch.prim.ListConstruct %int1024_616, %int32_617 : (!torch.int, !torch.int) -> !torch.list - %446 = torch.aten.reshape %result0, %445 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_618 = torch.constant.int 1 - %int0_619 = torch.constant.int 0 - %447 = torch.prim.ListConstruct %int1_618, %int0_619 : (!torch.int, !torch.int) -> !torch.list - %int1_620 = torch.constant.int 1 - %int0_621 = torch.constant.int 0 - %448 = torch.prim.ListConstruct %int1_620, %int0_621 : (!torch.int, !torch.int) -> !torch.list - %449 = torch.aten.permute %446, %448 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %450 = torch.aten.mm %449, %328 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_622 = torch.constant.int 1 - %int0_623 = torch.constant.int 0 - %451 = torch.prim.ListConstruct %int1_622, %int0_623 : (!torch.int, !torch.int) -> !torch.list - %int1_624 = torch.constant.int 1 - %int0_625 = torch.constant.int 0 - %452 = torch.prim.ListConstruct %int1_624, %int0_625 : (!torch.int, !torch.int) -> !torch.list - %453 = torch.aten.permute %450, %452 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %454 = torch.aten.add.Tensor %arg125, %453, %arg124 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %455 = torch.aten.mul.Tensor %arg126, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %456 = torch.aten.addcmul %455, %454, %454, %arg123 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %457 = torch.aten.sqrt %456 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %458 = torch.aten.add.Tensor %457, %arg2, %arg122 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %459 = torch.aten.mul.Tensor %arg128, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %460 = torch.aten.add.Tensor %459, %454, %arg127 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %461 = torch.aten.addcdiv %arg19, %460, %458, %arg121 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_626 = torch.constant.int 0 - %462 = torch.prim.ListConstruct %int0_626 : (!torch.int) -> !torch.list - %true_627 = torch.constant.bool true - %none_628 = torch.constant.none - %463 = torch.aten.sum.dim_IntList %328, %462, %true_627, %none_628 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_629 = torch.constant.int 32 - %464 = torch.prim.ListConstruct %int32_629 : (!torch.int) -> !torch.list - %int32_630 = torch.constant.int 32 - %465 = torch.prim.ListConstruct %int32_630 : (!torch.int) -> !torch.list - %466 = torch.aten.reshape %463, %465 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %467 = torch.aten.add.Tensor %arg133, %466, %arg132 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %468 = torch.aten.mul.Tensor %arg134, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %469 = torch.aten.addcmul %468, %467, %467, %arg131 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %470 = torch.aten.sqrt %469 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %471 = torch.aten.add.Tensor %470, %arg2, %arg130 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %472 = torch.aten.mul.Tensor %arg136, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %473 = torch.aten.add.Tensor %472, %467, %arg135 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %474 = torch.aten.addcdiv %arg20, %473, %471, %arg129 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_631 = torch.constant.int 1024 - %int32_632 = torch.constant.int 32 - %475 = torch.prim.ListConstruct %int1024_631, %int32_632 : (!torch.int, !torch.int) -> !torch.list - %int1024_633 = torch.constant.int 1024 - %int32_634 = torch.constant.int 32 - %476 = torch.prim.ListConstruct %int1024_633, %int32_634 : (!torch.int, !torch.int) -> !torch.list - %477 = torch.aten.reshape %result0, %476 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_635 = torch.constant.int 1 - %int0_636 = torch.constant.int 0 - %478 = torch.prim.ListConstruct %int1_635, %int0_636 : (!torch.int, !torch.int) -> !torch.list - %int1_637 = torch.constant.int 1 - %int0_638 = torch.constant.int 0 - %479 = torch.prim.ListConstruct %int1_637, %int0_638 : (!torch.int, !torch.int) -> !torch.list - %480 = torch.aten.permute %477, %479 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %481 = torch.aten.mm %480, %357 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_639 = torch.constant.int 1 - %int0_640 = torch.constant.int 0 - %482 = torch.prim.ListConstruct %int1_639, %int0_640 : (!torch.int, !torch.int) -> !torch.list - %int1_641 = torch.constant.int 1 - %int0_642 = torch.constant.int 0 - %483 = torch.prim.ListConstruct %int1_641, %int0_642 : (!torch.int, !torch.int) -> !torch.list - %484 = torch.aten.permute %481, %483 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %485 = torch.aten.add.Tensor %arg141, %484, %arg140 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %486 = torch.aten.mul.Tensor %arg142, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %487 = torch.aten.addcmul %486, %485, %485, %arg139 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %488 = torch.aten.sqrt %487 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %489 = torch.aten.add.Tensor %488, %arg2, %arg138 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %490 = torch.aten.mul.Tensor %arg144, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %491 = torch.aten.add.Tensor %490, %485, %arg143 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %492 = torch.aten.addcdiv %arg32, %491, %489, %arg137 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_643 = torch.constant.int 0 - %493 = torch.prim.ListConstruct %int0_643 : (!torch.int) -> !torch.list - %true_644 = torch.constant.bool true - %none_645 = torch.constant.none - %494 = torch.aten.sum.dim_IntList %357, %493, %true_644, %none_645 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_646 = torch.constant.int 32 - %495 = torch.prim.ListConstruct %int32_646 : (!torch.int) -> !torch.list - %int32_647 = torch.constant.int 32 - %496 = torch.prim.ListConstruct %int32_647 : (!torch.int) -> !torch.list - %497 = torch.aten.reshape %494, %496 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %498 = torch.aten.add.Tensor %arg149, %497, %arg148 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %499 = torch.aten.mul.Tensor %arg150, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %500 = torch.aten.addcmul %499, %498, %498, %arg147 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %501 = torch.aten.sqrt %500 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %502 = torch.aten.add.Tensor %501, %arg2, %arg146 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %503 = torch.aten.mul.Tensor %arg152, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %504 = torch.aten.add.Tensor %503, %498, %arg151 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %505 = torch.aten.addcdiv %arg33, %504, %502, %arg145 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int2_648 = torch.constant.int 2 - %int512_649 = torch.constant.int 512 - %int16_650 = torch.constant.int 16 - %506 = torch.prim.ListConstruct %int2_648, %int2_648, %int512_649, %int16_650 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int2_651 = torch.constant.int 2 - %int512_652 = torch.constant.int 512 - %int16_653 = torch.constant.int 16 - %507 = torch.prim.ListConstruct %int2_651, %int2_651, %int512_652, %int16_653 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %508 = torch.aten.reshape %124, %507 : !torch.vtensor<[4,512,16],f32>, !torch.list -> !torch.vtensor<[2,2,512,16],f32> - %int0_654 = torch.constant.int 0 - %int2_655 = torch.constant.int 2 - %int1_656 = torch.constant.int 1 - %int3_657 = torch.constant.int 3 - %509 = torch.prim.ListConstruct %int0_654, %int2_655, %int1_656, %int3_657 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %int0_658 = torch.constant.int 0 - %int2_659 = torch.constant.int 2 - %int1_660 = torch.constant.int 1 - %int3_661 = torch.constant.int 3 - %510 = torch.prim.ListConstruct %int0_658, %int2_659, %int1_660, %int3_661 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %511 = torch.aten.permute %508, %510 : !torch.vtensor<[2,2,512,16],f32>, !torch.list -> !torch.vtensor<[2,512,2,16],f32> - %int2_662 = torch.constant.int 2 - %int512_663 = torch.constant.int 512 - %int32_664 = torch.constant.int 32 - %512 = torch.prim.ListConstruct %int2_662, %int512_663, %int32_664 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %int2_665 = torch.constant.int 2 - %int512_666 = torch.constant.int 512 - %int32_667 = torch.constant.int 32 - %513 = torch.prim.ListConstruct %int2_665, %int512_666, %int32_667 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %514 = torch.aten.reshape %511, %513 : !torch.vtensor<[2,512,2,16],f32>, !torch.list -> !torch.vtensor<[2,512,32],f32> - %int1024_668 = torch.constant.int 1024 - %int32_669 = torch.constant.int 32 - %515 = torch.prim.ListConstruct %int1024_668, %int32_669 : (!torch.int, !torch.int) -> !torch.list - %int1024_670 = torch.constant.int 1024 - %int32_671 = torch.constant.int 32 - %516 = torch.prim.ListConstruct %int1024_670, %int32_671 : (!torch.int, !torch.int) -> !torch.list - %517 = torch.aten.reshape %514, %516 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_672 = torch.constant.int 1 - %int0_673 = torch.constant.int 0 - %518 = torch.prim.ListConstruct %int1_672, %int0_673 : (!torch.int, !torch.int) -> !torch.list - %int1_674 = torch.constant.int 1 - %int0_675 = torch.constant.int 0 - %519 = torch.prim.ListConstruct %int1_674, %int0_675 : (!torch.int, !torch.int) -> !torch.list - %520 = torch.aten.permute %517, %519 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %521 = torch.aten.mm %520, %261 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_676 = torch.constant.int 1 - %int0_677 = torch.constant.int 0 - %522 = torch.prim.ListConstruct %int1_676, %int0_677 : (!torch.int, !torch.int) -> !torch.list - %int1_678 = torch.constant.int 1 - %int0_679 = torch.constant.int 0 - %523 = torch.prim.ListConstruct %int1_678, %int0_679 : (!torch.int, !torch.int) -> !torch.list - %524 = torch.aten.permute %521, %523 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %525 = torch.aten.add.Tensor %arg157, %524, %arg156 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %526 = torch.aten.mul.Tensor %arg158, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %527 = torch.aten.addcmul %526, %525, %525, %arg155 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %528 = torch.aten.sqrt %527 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %529 = torch.aten.add.Tensor %528, %arg2, %arg154 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %530 = torch.aten.mul.Tensor %arg160, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %531 = torch.aten.add.Tensor %530, %525, %arg159 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %532 = torch.aten.addcdiv %arg34, %531, %529, %arg153 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_680 = torch.constant.int 0 - %533 = torch.prim.ListConstruct %int0_680 : (!torch.int) -> !torch.list - %true_681 = torch.constant.bool true - %none_682 = torch.constant.none - %534 = torch.aten.sum.dim_IntList %261, %533, %true_681, %none_682 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_683 = torch.constant.int 32 - %535 = torch.prim.ListConstruct %int32_683 : (!torch.int) -> !torch.list - %int32_684 = torch.constant.int 32 - %536 = torch.prim.ListConstruct %int32_684 : (!torch.int) -> !torch.list - %537 = torch.aten.reshape %534, %536 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %538 = torch.aten.add.Tensor %arg165, %537, %arg164 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %539 = torch.aten.mul.Tensor %arg166, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %540 = torch.aten.addcmul %539, %538, %538, %arg163 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %541 = torch.aten.sqrt %540 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %542 = torch.aten.add.Tensor %541, %arg2, %arg162 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %543 = torch.aten.mul.Tensor %arg168, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %544 = torch.aten.add.Tensor %543, %538, %arg167 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %545 = torch.aten.addcdiv %arg40, %544, %542, %arg161 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %546 = torch.aten.add.Tensor %arg173, %result1_403, %arg172 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %547 = torch.aten.mul.Tensor %arg174, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %548 = torch.aten.addcmul %547, %546, %546, %arg171 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %549 = torch.aten.sqrt %548 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %550 = torch.aten.add.Tensor %549, %arg2, %arg170 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %551 = torch.aten.mul.Tensor %arg176, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %552 = torch.aten.add.Tensor %551, %546, %arg175 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %553 = torch.aten.addcdiv %arg36, %552, %550, %arg169 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %554 = torch.aten.add.Tensor %arg181, %result2_404, %arg180 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %555 = torch.aten.mul.Tensor %arg182, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %556 = torch.aten.addcmul %555, %554, %554, %arg179 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %557 = torch.aten.sqrt %556 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %558 = torch.aten.add.Tensor %557, %arg2, %arg178 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %559 = torch.aten.mul.Tensor %arg184, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %560 = torch.aten.add.Tensor %559, %554, %arg183 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %561 = torch.aten.addcdiv %arg35, %560, %558, %arg177 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_685 = torch.constant.int 1024 - %int32_686 = torch.constant.int 32 - %562 = torch.prim.ListConstruct %int1024_685, %int32_686 : (!torch.int, !torch.int) -> !torch.list - %int1024_687 = torch.constant.int 1024 - %int32_688 = torch.constant.int 32 - %563 = torch.prim.ListConstruct %int1024_687, %int32_688 : (!torch.int, !torch.int) -> !torch.list - %564 = torch.aten.reshape %result0_238, %563 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_689 = torch.constant.int 1 - %int0_690 = torch.constant.int 0 - %565 = torch.prim.ListConstruct %int1_689, %int0_690 : (!torch.int, !torch.int) -> !torch.list - %int1_691 = torch.constant.int 1 - %int0_692 = torch.constant.int 0 - %566 = torch.prim.ListConstruct %int1_691, %int0_692 : (!torch.int, !torch.int) -> !torch.list - %567 = torch.aten.permute %564, %566 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %568 = torch.aten.mm %567, %250 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_693 = torch.constant.int 1 - %int0_694 = torch.constant.int 0 - %569 = torch.prim.ListConstruct %int1_693, %int0_694 : (!torch.int, !torch.int) -> !torch.list - %int1_695 = torch.constant.int 1 - %int0_696 = torch.constant.int 0 - %570 = torch.prim.ListConstruct %int1_695, %int0_696 : (!torch.int, !torch.int) -> !torch.list - %571 = torch.aten.permute %568, %570 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %572 = torch.aten.add.Tensor %arg189, %571, %arg188 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %573 = torch.aten.mul.Tensor %arg190, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %574 = torch.aten.addcmul %573, %572, %572, %arg187 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %575 = torch.aten.sqrt %574 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %576 = torch.aten.add.Tensor %575, %arg2, %arg186 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %577 = torch.aten.mul.Tensor %arg192, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %578 = torch.aten.add.Tensor %577, %572, %arg191 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %579 = torch.aten.addcdiv %arg42, %578, %576, %arg185 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_697 = torch.constant.int 0 - %580 = torch.prim.ListConstruct %int0_697 : (!torch.int) -> !torch.list - %true_698 = torch.constant.bool true - %none_699 = torch.constant.none - %581 = torch.aten.sum.dim_IntList %250, %580, %true_698, %none_699 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_700 = torch.constant.int 32 - %582 = torch.prim.ListConstruct %int32_700 : (!torch.int) -> !torch.list - %int32_701 = torch.constant.int 32 - %583 = torch.prim.ListConstruct %int32_701 : (!torch.int) -> !torch.list - %584 = torch.aten.reshape %581, %583 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %585 = torch.aten.add.Tensor %arg197, %584, %arg196 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %586 = torch.aten.mul.Tensor %arg198, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %587 = torch.aten.addcmul %586, %585, %585, %arg195 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %588 = torch.aten.sqrt %587 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %589 = torch.aten.add.Tensor %588, %arg2, %arg194 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %590 = torch.aten.mul.Tensor %arg200, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %591 = torch.aten.add.Tensor %590, %585, %arg199 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %592 = torch.aten.addcdiv %arg45, %591, %589, %arg193 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1024_702 = torch.constant.int 1024 - %int32_703 = torch.constant.int 32 - %593 = torch.prim.ListConstruct %int1024_702, %int32_703 : (!torch.int, !torch.int) -> !torch.list - %int1024_704 = torch.constant.int 1024 - %int32_705 = torch.constant.int 32 - %594 = torch.prim.ListConstruct %int1024_704, %int32_705 : (!torch.int, !torch.int) -> !torch.list - %595 = torch.aten.reshape %169, %594 : !torch.vtensor<[2,512,32],f32>, !torch.list -> !torch.vtensor<[1024,32],f32> - %int1_706 = torch.constant.int 1 - %int0_707 = torch.constant.int 0 - %596 = torch.prim.ListConstruct %int1_706, %int0_707 : (!torch.int, !torch.int) -> !torch.list - %int1_708 = torch.constant.int 1 - %int0_709 = torch.constant.int 0 - %597 = torch.prim.ListConstruct %int1_708, %int0_709 : (!torch.int, !torch.int) -> !torch.list - %598 = torch.aten.permute %595, %597 : !torch.vtensor<[1024,32],f32>, !torch.list -> !torch.vtensor<[32,1024],f32> - %599 = torch.aten.mm %598, %242 : !torch.vtensor<[32,1024],f32>, !torch.vtensor<[1024,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_710 = torch.constant.int 1 - %int0_711 = torch.constant.int 0 - %600 = torch.prim.ListConstruct %int1_710, %int0_711 : (!torch.int, !torch.int) -> !torch.list - %int1_712 = torch.constant.int 1 - %int0_713 = torch.constant.int 0 - %601 = torch.prim.ListConstruct %int1_712, %int0_713 : (!torch.int, !torch.int) -> !torch.list - %602 = torch.aten.permute %599, %601 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %603 = torch.aten.add.Tensor %arg205, %602, %arg204 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %604 = torch.aten.mul.Tensor %arg206, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %605 = torch.aten.addcmul %604, %603, %603, %arg203 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %606 = torch.aten.sqrt %605 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %607 = torch.aten.add.Tensor %606, %arg2, %arg202 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %608 = torch.aten.mul.Tensor %arg208, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %609 = torch.aten.add.Tensor %608, %603, %arg207 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %610 = torch.aten.addcdiv %arg46, %609, %607, %arg201 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_714 = torch.constant.int 0 - %611 = torch.prim.ListConstruct %int0_714 : (!torch.int) -> !torch.list - %true_715 = torch.constant.bool true - %none_716 = torch.constant.none - %612 = torch.aten.sum.dim_IntList %242, %611, %true_715, %none_716 : !torch.vtensor<[1024,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_717 = torch.constant.int 32 - %613 = torch.prim.ListConstruct %int32_717 : (!torch.int) -> !torch.list - %int32_718 = torch.constant.int 32 - %614 = torch.prim.ListConstruct %int32_718 : (!torch.int) -> !torch.list - %615 = torch.aten.reshape %612, %614 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %616 = torch.aten.add.Tensor %arg213, %615, %arg212 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %617 = torch.aten.mul.Tensor %arg214, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %618 = torch.aten.addcmul %617, %616, %616, %arg211 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %619 = torch.aten.sqrt %618 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %620 = torch.aten.add.Tensor %619, %arg2, %arg210 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %621 = torch.aten.mul.Tensor %arg216, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %622 = torch.aten.add.Tensor %621, %616, %arg215 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %623 = torch.aten.addcdiv %arg52, %622, %620, %arg209 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %624 = torch.aten.add.Tensor %arg221, %result1_377, %arg220 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %625 = torch.aten.mul.Tensor %arg222, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %626 = torch.aten.addcmul %625, %624, %624, %arg219 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %627 = torch.aten.sqrt %626 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %628 = torch.aten.add.Tensor %627, %arg2, %arg218 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %629 = torch.aten.mul.Tensor %arg224, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %630 = torch.aten.add.Tensor %629, %624, %arg223 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %631 = torch.aten.addcdiv %arg48, %630, %628, %arg217 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %632 = torch.aten.add.Tensor %arg229, %result2_378, %arg228 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %633 = torch.aten.mul.Tensor %arg230, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %634 = torch.aten.addcmul %633, %632, %632, %arg227 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %635 = torch.aten.sqrt %634 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %636 = torch.aten.add.Tensor %635, %arg2, %arg226 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %637 = torch.aten.mul.Tensor %arg232, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %638 = torch.aten.add.Tensor %637, %632, %arg231 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %639 = torch.aten.addcdiv %arg47, %638, %636, %arg225 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int0_719 = torch.constant.int 0 - %int0_720 = torch.constant.int 0 - %int2_721 = torch.constant.int 2 - %int1_722 = torch.constant.int 1 - %640 = torch.aten.slice.Tensor %result0_287, %int0_719, %int0_720, %int2_721, %int1_722 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int0_723 = torch.constant.int 0 - %int0_724 = torch.constant.int 0 - %int2_725 = torch.constant.int 2 - %int1_726 = torch.constant.int 1 - %641 = torch.aten.slice.Tensor %640, %int0_723, %int0_724, %int2_725, %int1_726 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,512,32],f32> - %int1_727 = torch.constant.int 1 - %int0_728 = torch.constant.int 0 - %int1_729 = torch.constant.int 1 - %int1_730 = torch.constant.int 1 - %642 = torch.aten.slice.Tensor %641, %int1_727, %int0_728, %int1_729, %int1_730 : !torch.vtensor<[2,512,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> - %int2_731 = torch.constant.int 2 - %int0_732 = torch.constant.int 0 - %int32_733 = torch.constant.int 32 - %int1_734 = torch.constant.int 1 - %643 = torch.aten.slice.Tensor %642, %int2_731, %int0_732, %int32_733, %int1_734 : !torch.vtensor<[2,1,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1,32],f32> - %int2_735 = torch.constant.int 2 - %int32_736 = torch.constant.int 32 - %644 = torch.prim.ListConstruct %int2_735, %int32_736 : (!torch.int, !torch.int) -> !torch.list - %int2_737 = torch.constant.int 2 - %int32_738 = torch.constant.int 32 - %645 = torch.prim.ListConstruct %int2_737, %int32_738 : (!torch.int, !torch.int) -> !torch.list - %646 = torch.aten.reshape %643, %645 : !torch.vtensor<[2,1,32],f32>, !torch.list -> !torch.vtensor<[2,32],f32> - %int1_739 = torch.constant.int 1 - %int0_740 = torch.constant.int 0 - %647 = torch.prim.ListConstruct %int1_739, %int0_740 : (!torch.int, !torch.int) -> !torch.list - %int1_741 = torch.constant.int 1 - %int0_742 = torch.constant.int 0 - %648 = torch.prim.ListConstruct %int1_741, %int0_742 : (!torch.int, !torch.int) -> !torch.list - %649 = torch.aten.permute %646, %648 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %650 = torch.aten.mm %649, %222 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,32],f32> -> !torch.vtensor<[32,32],f32> - %int1_743 = torch.constant.int 1 - %int0_744 = torch.constant.int 0 - %651 = torch.prim.ListConstruct %int1_743, %int0_744 : (!torch.int, !torch.int) -> !torch.list - %int1_745 = torch.constant.int 1 - %int0_746 = torch.constant.int 0 - %652 = torch.prim.ListConstruct %int1_745, %int0_746 : (!torch.int, !torch.int) -> !torch.list - %653 = torch.aten.permute %650, %652 : !torch.vtensor<[32,32],f32>, !torch.list -> !torch.vtensor<[32,32],f32> - %654 = torch.aten.add.Tensor %arg237, %653, %arg236 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.int -> !torch.vtensor<[32,32],f32> - %655 = torch.aten.mul.Tensor %arg238, %arg68 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %656 = torch.aten.addcmul %655, %654, %654, %arg235 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %657 = torch.aten.sqrt %656 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %658 = torch.aten.add.Tensor %657, %arg2, %arg234 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32,32],f32> - %659 = torch.aten.mul.Tensor %arg240, %arg71 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32,32],f32> - %660 = torch.aten.add.Tensor %659, %654, %arg239 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %661 = torch.aten.addcdiv %arg53, %660, %658, %arg233 : !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.float -> !torch.vtensor<[32,32],f32> - %int0_747 = torch.constant.int 0 - %662 = torch.prim.ListConstruct %int0_747 : (!torch.int) -> !torch.list - %true_748 = torch.constant.bool true - %none_749 = torch.constant.none - %663 = torch.aten.sum.dim_IntList %222, %662, %true_748, %none_749 : !torch.vtensor<[2,32],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,32],f32> - %int32_750 = torch.constant.int 32 - %664 = torch.prim.ListConstruct %int32_750 : (!torch.int) -> !torch.list - %int32_751 = torch.constant.int 32 - %665 = torch.prim.ListConstruct %int32_751 : (!torch.int) -> !torch.list - %666 = torch.aten.reshape %663, %665 : !torch.vtensor<[1,32],f32>, !torch.list -> !torch.vtensor<[32],f32> - %667 = torch.aten.add.Tensor %arg245, %666, %arg244 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.int -> !torch.vtensor<[32],f32> - %668 = torch.aten.mul.Tensor %arg246, %arg68 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %669 = torch.aten.addcmul %668, %667, %667, %arg243 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %670 = torch.aten.sqrt %669 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %671 = torch.aten.add.Tensor %670, %arg2, %arg242 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[32],f32> - %672 = torch.aten.mul.Tensor %arg248, %arg71 : !torch.vtensor<[32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[32],f32> - %673 = torch.aten.add.Tensor %672, %667, %arg247 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %674 = torch.aten.addcdiv %arg56, %673, %671, %arg241 : !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.float -> !torch.vtensor<[32],f32> - %int1_752 = torch.constant.int 1 - %int0_753 = torch.constant.int 0 - %675 = torch.prim.ListConstruct %int1_752, %int0_753 : (!torch.int, !torch.int) -> !torch.list - %int1_754 = torch.constant.int 1 - %int0_755 = torch.constant.int 0 - %676 = torch.prim.ListConstruct %int1_754, %int0_755 : (!torch.int, !torch.int) -> !torch.list - %677 = torch.aten.permute %197, %676 : !torch.vtensor<[2,32],f32>, !torch.list -> !torch.vtensor<[32,2],f32> - %678 = torch.aten.mm %677, %220 : !torch.vtensor<[32,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[32,2],f32> - %int1_756 = torch.constant.int 1 - %int0_757 = torch.constant.int 0 - %679 = torch.prim.ListConstruct %int1_756, %int0_757 : (!torch.int, !torch.int) -> !torch.list - %int1_758 = torch.constant.int 1 - %int0_759 = torch.constant.int 0 - %680 = torch.prim.ListConstruct %int1_758, %int0_759 : (!torch.int, !torch.int) -> !torch.list - %681 = torch.aten.permute %678, %680 : !torch.vtensor<[32,2],f32>, !torch.list -> !torch.vtensor<[2,32],f32> - %682 = torch.aten.add.Tensor %arg253, %681, %arg252 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.int -> !torch.vtensor<[2,32],f32> - %683 = torch.aten.mul.Tensor %arg254, %arg68 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %684 = torch.aten.addcmul %683, %682, %682, %arg251 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %685 = torch.aten.sqrt %684 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %686 = torch.aten.add.Tensor %685, %arg2, %arg250 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2,32],f32> - %687 = torch.aten.mul.Tensor %arg256, %arg71 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2,32],f32> - %688 = torch.aten.add.Tensor %687, %682, %arg255 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %689 = torch.aten.addcdiv %arg57, %688, %686, %arg249 : !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.float -> !torch.vtensor<[2,32],f32> - %int0_760 = torch.constant.int 0 - %690 = torch.prim.ListConstruct %int0_760 : (!torch.int) -> !torch.list - %true_761 = torch.constant.bool true - %none_762 = torch.constant.none - %691 = torch.aten.sum.dim_IntList %220, %690, %true_761, %none_762 : !torch.vtensor<[2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,2],f32> - %int2_763 = torch.constant.int 2 - %692 = torch.prim.ListConstruct %int2_763 : (!torch.int) -> !torch.list - %int2_764 = torch.constant.int 2 - %693 = torch.prim.ListConstruct %int2_764 : (!torch.int) -> !torch.list - %694 = torch.aten.reshape %691, %693 : !torch.vtensor<[1,2],f32>, !torch.list -> !torch.vtensor<[2],f32> - %695 = torch.aten.add.Tensor %arg261, %694, %arg260 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32> - %696 = torch.aten.mul.Tensor %arg262, %arg68 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> - %697 = torch.aten.addcmul %696, %695, %695, %arg259 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> - %698 = torch.aten.sqrt %697 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> - %699 = torch.aten.add.Tensor %698, %arg2, %arg258 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[2],f32> - %700 = torch.aten.mul.Tensor %arg264, %arg71 : !torch.vtensor<[2],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[2],f32> - %701 = torch.aten.add.Tensor %700, %695, %arg263 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> - %702 = torch.aten.addcdiv %arg60, %701, %699, %arg257 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float -> !torch.vtensor<[2],f32> - %703 = torch.aten.zero.functional %695 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> - %704 = torch.aten.zero.functional %682 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %705 = torch.aten.zero.functional %667 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %706 = torch.aten.zero.functional %654 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %707 = torch.aten.zero.functional %624 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %708 = torch.aten.zero.functional %632 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %709 = torch.aten.zero.functional %616 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %710 = torch.aten.zero.functional %603 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %711 = torch.aten.zero.functional %585 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %712 = torch.aten.zero.functional %572 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %713 = torch.aten.zero.functional %546 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %714 = torch.aten.zero.functional %554 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %715 = torch.aten.zero.functional %538 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %716 = torch.aten.zero.functional %525 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %717 = torch.aten.zero.functional %498 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %718 = torch.aten.zero.functional %485 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %719 = torch.aten.zero.functional %467 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %720 = torch.aten.zero.functional %454 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %721 = torch.aten.zero.functional %436 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %722 = torch.aten.zero.functional %423 : !torch.vtensor<[32,32],f32> -> !torch.vtensor<[32,32],f32> - %723 = torch.aten.zero.functional %397 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %724 = torch.aten.zero.functional %405 : !torch.vtensor<[32],f32> -> !torch.vtensor<[32],f32> - %725 = torch.aten.zero.functional %380 : !torch.vtensor<[512,32],f32> -> !torch.vtensor<[512,32],f32> - %726 = torch.aten.zero.functional %389 : !torch.vtensor<[2,32],f32> -> !torch.vtensor<[2,32],f32> - %727 = torch.aten.zero.functional %369 : !torch.vtensor<[28996,32],f32> -> !torch.vtensor<[28996,32],f32> - return %376, %387, %396, %404, %412, %430, %443, %461, %474, %492, %505, %532, %545, %553, %561, %579, %592, %610, %623, %631, %639, %661, %674, %689, %702, %703, %704, %705, %706, %707, %708, %709, %710, %711, %712, %713, %714, %715, %716, %717, %718, %719, %720, %721, %722, %723, %724, %725, %726, %727, %375, %371, %386, %382, %395, %391, %403, %399, %411, %407, %429, %425, %442, %438, %460, %456, %473, %469, %491, %487, %504, %500, %531, %527, %544, %540, %552, %548, %560, %556, %578, %574, %591, %587, %609, %605, %622, %618, %630, %626, %638, %634, %660, %656, %673, %669, %688, %684, %701, %697, %arg61, %arg5, %arg12, %arg25, %207, %output : !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[28996,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[512,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32,32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2,32],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,512],si64>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[],f32> -} diff --git a/e2e_testing/lazy_tensor_core/main.py b/e2e_testing/lazy_tensor_core/main.py deleted file mode 100644 index 02f9d3dfef0..00000000000 --- a/e2e_testing/lazy_tensor_core/main.py +++ /dev/null @@ -1,89 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import os -import pathlib -import unittest - -import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend -from numpy.testing import assert_almost_equal - -# Example models -import ltc_backend_bert -import ltc_backend_mnist - - -class LTCNumericTests(unittest.TestCase): - """ - This test suite validates numerics by comparing the output of the models when - executed using the MLIR LTC backend and ensuring they match the results on CPU. - """ - - def assert_tensors_almost_equal(self, tensor_a, tensor_b, message): - a, b = tensor_a.cpu().detach().numpy(), tensor_b.cpu().detach().numpy() - - # Ensure tensors match up to 7 decimals of precision. - assert_almost_equal(a, b, 7, message) - - def run_test(self, run_model): - model_torch_mlir, loss_torch_mlir = run_model('lazy') - model_cpu, loss_cpu = run_model('cpu') - - # Check losses match. - self.assertEqual(len(loss_torch_mlir), len(loss_cpu)) - for idx in range(len(loss_torch_mlir)): - self.assert_tensors_almost_equal(loss_torch_mlir[idx], loss_cpu[idx], - f'Losses at index {idx} do not match!') - - # Check that number of parameters match. - torch_mlir_params, cpu_params = [list(model.named_parameters()) for model in (model_torch_mlir, model_cpu)] - self.assertEqual(len(torch_mlir_params), len(cpu_params)) - - # Check that names of parameters match. - torch_mlir_keys = [] - for name, param in torch_mlir_params: - torch_mlir_keys.append(name) - - cpu_keys = [] - for name, param in cpu_params: - cpu_keys.append(name) - - self.assertEqual(torch_mlir_keys, cpu_keys) - - # Check contents of parameters match. - for idx in range(len(torch_mlir_params)): - self.assert_tensors_almost_equal(torch_mlir_params[idx][1], cpu_params[idx][1], - f'Parameters {torch_mlir_keys[idx]} do not match!') - - def test_bert(self): - self.run_test(ltc_backend_bert.main) - - def test_mnist(self): - self.run_test(ltc_backend_mnist.main) - - -class LTCMlirTests(unittest.TestCase): - """ - This test suite validates that the emitted MLIR matches a known good output. - """ - - def run_test(self, run_model, mlir_path): - run_model() - - # Compare the generated MLIR with a known good output. - with open(os.path.join(pathlib.Path(__file__).parent.resolve(), mlir_path), 'r') as file: - self.assertEqual(ltc_backend.get_latest_computation().to_string(), file.read()) - - def test_bert(self): - self.run_test(ltc_backend_bert.main, 'bert.mlir') - - def test_mnist(self): - self.run_test(ltc_backend_mnist.main, 'mnist.mlir') - - -if __name__ == '__main__': - ltc_backend._initialize() - - unittest.main() diff --git a/e2e_testing/lazy_tensor_core/mnist.mlir b/e2e_testing/lazy_tensor_core/mnist.mlir deleted file mode 100644 index 26e9874c914..00000000000 --- a/e2e_testing/lazy_tensor_core/mnist.mlir +++ /dev/null @@ -1,59 +0,0 @@ -func.func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.vtensor<[10,5],f32>, %arg7: !torch.vtensor<[10],f32>, %arg8: !torch.vtensor<[1],si64>, %arg9: !torch.vtensor<[],f32>, %arg10: !torch.vtensor<[10,5],f32>, %arg11: !torch.float, %arg12: !torch.int, %arg13: !torch.vtensor<[10],f32>) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32>) { - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list - %int1_0 = torch.constant.int 1 - %int0_1 = torch.constant.int 0 - %1 = torch.prim.ListConstruct %int1_0, %int0_1 : (!torch.int, !torch.int) -> !torch.list - %2 = torch.aten.permute %arg6, %1 : !torch.vtensor<[10,5],f32>, !torch.list -> !torch.vtensor<[5,10],f32> - %3 = torch.aten.addmm %arg7, %arg0, %2, %arg5, %arg4 : !torch.vtensor<[10],f32>, !torch.vtensor<[1,5],f32>, !torch.vtensor<[5,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,10],f32> - %4 = torch.aten.relu %3 : !torch.vtensor<[1,10],f32> -> !torch.vtensor<[1,10],f32> - %int1_2 = torch.constant.int 1 - %false = torch.constant.bool false - %5 = torch.aten._log_softmax %4, %int1_2, %false : !torch.vtensor<[1,10],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,10],f32> - %none = torch.constant.none - %int1_3 = torch.constant.int 1 - %int-100 = torch.constant.int -100 - %output, %total_weight = torch.aten.nll_loss_forward %5, %arg8, %none, %int1_3, %int-100 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> - %6 = torch.prim.TupleConstruct %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.tuple, vtensor<[],f32>> - %none_4 = torch.constant.none - %int1_5 = torch.constant.int 1 - %int-100_6 = torch.constant.int -100 - %7 = torch.aten.nll_loss_backward %arg9, %5, %arg8, %none_4, %int1_5, %int-100_6, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[1],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[1,10],f32> - %int1_7 = torch.constant.int 1 - %int6 = torch.constant.int 6 - %8 = torch.aten._log_softmax_backward_data %7, %5, %int1_7, %int6 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,10],f32> - %9 = torch.aten.threshold_backward %8, %4, %arg3 : !torch.vtensor<[1,10],f32>, !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32> - %int1_8 = torch.constant.int 1 - %int0_9 = torch.constant.int 0 - %10 = torch.prim.ListConstruct %int1_8, %int0_9 : (!torch.int, !torch.int) -> !torch.list - %int1_10 = torch.constant.int 1 - %int0_11 = torch.constant.int 0 - %11 = torch.prim.ListConstruct %int1_10, %int0_11 : (!torch.int, !torch.int) -> !torch.list - %12 = torch.aten.permute %arg0, %11 : !torch.vtensor<[1,5],f32>, !torch.list -> !torch.vtensor<[5,1],f32> - %13 = torch.aten.mm %12, %9 : !torch.vtensor<[5,1],f32>, !torch.vtensor<[1,10],f32> -> !torch.vtensor<[5,10],f32> - %int1_12 = torch.constant.int 1 - %int0_13 = torch.constant.int 0 - %14 = torch.prim.ListConstruct %int1_12, %int0_13 : (!torch.int, !torch.int) -> !torch.list - %int1_14 = torch.constant.int 1 - %int0_15 = torch.constant.int 0 - %15 = torch.prim.ListConstruct %int1_14, %int0_15 : (!torch.int, !torch.int) -> !torch.list - %16 = torch.aten.permute %13, %15 : !torch.vtensor<[5,10],f32>, !torch.list -> !torch.vtensor<[10,5],f32> - %17 = torch.aten.zero.functional %arg10 : !torch.vtensor<[10,5],f32> -> !torch.vtensor<[10,5],f32> - %18 = torch.aten.add.Tensor %17, %16, %arg2 : !torch.vtensor<[10,5],f32>, !torch.vtensor<[10,5],f32>, !torch.int -> !torch.vtensor<[10,5],f32> - %19 = torch.aten.add.Tensor %arg6, %18, %arg1 : !torch.vtensor<[10,5],f32>, !torch.vtensor<[10,5],f32>, !torch.float -> !torch.vtensor<[10,5],f32> - %int0_16 = torch.constant.int 0 - %20 = torch.prim.ListConstruct %int0_16 : (!torch.int) -> !torch.list - %true = torch.constant.bool true - %none_17 = torch.constant.none - %21 = torch.aten.sum.dim_IntList %9, %20, %true, %none_17 : !torch.vtensor<[1,10],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,10],f32> - %int10 = torch.constant.int 10 - %22 = torch.prim.ListConstruct %int10 : (!torch.int) -> !torch.list - %int10_18 = torch.constant.int 10 - %23 = torch.prim.ListConstruct %int10_18 : (!torch.int) -> !torch.list - %24 = torch.aten.reshape %21, %23 : !torch.vtensor<[1,10],f32>, !torch.list -> !torch.vtensor<[10],f32> - %25 = torch.aten.zero.functional %arg13 : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> - %26 = torch.aten.add.Tensor %25, %24, %arg12 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32> - %27 = torch.aten.add.Tensor %arg7, %26, %arg11 : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> - return %arg0, %19, %27, %26, %18, %4, %output : !torch.vtensor<[1,5],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[1,10],f32>, !torch.vtensor<[],f32> -} From 923a9fe8b13f46993b9cd44ac8c7bd4ee46ce775 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Thu, 9 Jun 2022 11:57:37 -0400 Subject: [PATCH 11/13] Added LTC option to torchscript e2e --- .github/workflows/buildAndTest.yml | 5 ++++ e2e_testing/torchscript/main.py | 8 +++++-- .../torchscript/configs/__init__.py | 1 + .../torchscript/configs/lazy_tensor_core.py | 23 +++++++++++++++++++ 4 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index e99c8c2477b..a24ee750a1c 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -56,6 +56,11 @@ jobs: cd $GITHUB_WORKSPACE export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" python -m e2e_testing.torchscript.main --config=tosa -v + - name: Lazy Tensor Core - TorchScript end-to-end tests + run: | + cd $GITHUB_WORKSPACE + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v build-out-of-tree: name: Build out-of-tree (Release Asserts) diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 08880f4e20c..bc52d86c27c 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -15,7 +15,7 @@ # Available test configs. from torch_mlir_e2e_test.torchscript.configs import ( - LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig + LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend @@ -28,7 +28,7 @@ register_all_tests() def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode'] + config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -40,6 +40,7 @@ def _get_argparse(): "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "eager_mode": run through torch-mlir's eager mode frontend, using RefBackend for execution. +"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. ''') parser.add_argument('-f', '--filter', default='.*', help=''' Regular expression specifying which tests to include in this run. @@ -86,6 +87,9 @@ def main(): elif args.config == 'eager_mode': config = EagerModeTestConfig() xfail_set = EAGER_MODE_XFAIL_SET + elif args.config == 'lazy_tensor_core': + config = LazyTensorCoreTestConfig() + xfail_set = {} # Find the selected tests, and emit a diagnostic if none are found. tests = [ diff --git a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py index 14c2f48c36c..63d9a733940 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +from .lazy_tensor_core import LazyTensorCoreTestConfig from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig diff --git a/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py new file mode 100644 index 00000000000..4d375f6ed06 --- /dev/null +++ b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py @@ -0,0 +1,23 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace + + +class LazyTensorCoreTestConfig(TestConfig): + """TestConfig that runs torch.nn.Module thru the Lazy Tensor Core frontend for Torch MLIR""" + + def __init__(self): + super().__init__() + + def compile(self, program: torch.nn.Module) -> torch.nn.Module: + return program + + def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: + result: Trace = [] + + return result From 890979c9529cd6fd3feeafe6d4260fab441d9076 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Thu, 9 Jun 2022 15:00:36 -0400 Subject: [PATCH 12/13] Implement compile and run for LTC e2e test --- .../torchscript/configs/lazy_tensor_core.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py index 4d375f6ed06..9c5b90cda84 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py @@ -3,9 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import torch - -from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem class LazyTensorCoreTestConfig(TestConfig): @@ -13,11 +13,22 @@ class LazyTensorCoreTestConfig(TestConfig): def __init__(self): super().__init__() + ltc_backend._initialize() def compile(self, program: torch.nn.Module) -> torch.nn.Module: - return program + return program.to('lazy') def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: result: Trace = [] + for item in trace: + # We need to move all the inputs to the lazy device before running in LTC. + lazy_inputs = [arg.to('lazy') for arg in item.inputs] + output = getattr(artifact, item.symbol)(*lazy_inputs) + + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output.to('cpu'))) + return result From 08101ebef649e48ffbb5152248fbec6e30d053fc Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Thu, 9 Jun 2022 15:12:06 -0400 Subject: [PATCH 13/13] xfail all tests that use ops that aren't currently supported --- e2e_testing/torchscript/main.py | 4 +- e2e_testing/torchscript/xfail_sets.py | 309 ++++++++++++++++++++++++++ 2 files changed, 311 insertions(+), 2 deletions(-) diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index bc52d86c27c..10e86a89663 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -21,7 +21,7 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET +from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests @@ -89,7 +89,7 @@ def main(): xfail_set = EAGER_MODE_XFAIL_SET elif args.config == 'lazy_tensor_core': config = LazyTensorCoreTestConfig() - xfail_set = {} + xfail_set = LTC_XFAIL_SET # Find the selected tests, and emit a diagnostic if none are found. tests = [ diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 111fd161372..cc84872d9b7 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -154,3 +154,312 @@ "TestMultipleTensorReturn_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", } + +LTC_XFAIL_SET = { + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "AddIntModule_basic", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeDtypeFloatModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeIntModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeNegativeStartIntModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartIntModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartNegativeStepIntModule_basic", + "ArangeStartStepFloatModule_basic", + "ArangeStartStepIntModule_basic", + "ArangeZeroElementOutputModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + "BernoulliFloatModule_basic", + "BernoulliModule_basic", + "BernoulliOnesModule_basic", + "BernoulliTensorModule_basic", + "BernoulliZerosModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "CeilFloatModule_basic", + "DivFloatModule_basic", + "DropoutTrainModule_basic", + "ElementwiseAtenLogicalOrOpBrodcastModule_basic", + "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", + "ElementwiseAtenLogicalOrOpModule_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", + "ElementwiseAtenLogicalOrOpRandomModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseWhereScalarModule_basic", + "ElementwiseWhereScalarOtherModule_basic", + "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereSelfModule_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EqIntModule_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64_basic", + "FullLikeModuleDefaultDtype_basic", + "FullLikeModuleFalsePinMemory_basic", + "FullLikeModuleFloat2D_basic", + "FullLikeModuleFloat3DStatic_basic", + "FullLikeModuleFloat3D_basic", + "FullLikeModuleInt2DStatic_basic", + "FullLikeModuleInt2D_basic", + "FullLikeModuleInt3D_basic", + "FullModuleDefaultDtype_basic", + "FullModuleFalsePinMemory_basic", + "FullModuleFloat2D_basic", + "FullModuleFloat3D_basic", + "FullModuleInt2D_basic", + "FullModuleInt3D_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HBC_basic", + "HardTanhIntModule_basic", + "HardTanhModule_basic", + "HardswishModule_basic", + "HardswishRandomModule_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "IndexTensorModule_basic", + "MaskedFillScalarDefaultModule_basic", + "MaskedFillScalarFloatValueModule_basic", + "MaskedFillScalarIntValueModule_basic", + "Matmul_dot", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dStaticModule_basic", + "MaxPool2dWith3dInputModule_basic", + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MaxPool2dWithIndicesWith3dInputModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MeanDimAllReduceModule_basic", + "MeanDimDtypeModule_basic", + "MeanDimKeepdimModule_basic", + "MeanDimModule_basic", + "MeanDimNegativeModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "MobilenetV3Module_basic", + "MulIntModule_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeLayerNormDynamicModule_basic", + "NativeLayerNormModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleInt2D_basic", + "NewEmptyModuleInt3D_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleFalsePinMemory_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + "NllLossModule_1D_basic", + "NllLossModule_basic", + "NllLossModule_ignore_index_out_of_bounds_basic", + "NllLossModule_mean_basic", + "NllLossModule_sum_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "OnesModuleDefaultDtype_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleFloat_basic", + "OnesModuleInt_basic", + "QuantizedMLP_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxNegativeDim_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReturnThreeTensorFloat32_basic", + "ReturnTwoTensorF32I64_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "SelectIntModule_basic", + "SliceEndSleStartModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceSingleIdxModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStartEqEndModule_basic", + "SliceWholeTensorModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "StdBiasedModule_basic", + "StdUnbiasedModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TModuleRank0_basic", + "TModuleRank1_basic", + "TableBatchEmbeddingModule_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TensorsConcatModule_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "TestMultipleTensorReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dFloatModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dFloatModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "UniformModule_basic", + "UniformStaticModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "VarBiasedModule_basic", + "VarUnbiasedModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleFalsePinMemory_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", +}