Skip to content

Commit

Permalink
chore(compiler): Bump MLIR fork to version including extended canonic…
Browse files Browse the repository at this point in the history
…alization of tensor.insert_slice

Also, invoke the canonicalizer to include the new canonicalization
pattern after batching in order to eliminate unnecessary copies due to
redundant insertions completely overwriting empty tensors.
  • Loading branch information
andidr committed Jun 18, 2024
1 parent 2700f60 commit 570abf7
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ mlir::LogicalResult batchTFHE(mlir::MLIRContext &context,
pm, mlir::concretelang::createCollapseParallelLoops(), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createBatchingPass(maxBatchSize), enablePass);
addPotentiallyNestedPass(pm, mlir::createCanonicalizerPass(), enablePass);

return pm.run(module.getOperation());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ func.func @batch_offset_extract_keyswitch(%arg0: tensor<99x2x3x4x99x99x!TFHE.glw
%c97 = arith.constant 97 : index

%0 = bufferization.alloc_tensor() : tensor<2x3x4x!TFHE.glwe<sk<1,1,750>>>
// CHECK: %[[VDROP1DIMS:.*]] = tensor.collapse_shape [[ARG:.*]] {{\[\[0, 1\], \[2\], \[3, 4, 5\]\]}} : tensor<1x2x3x4x1x1x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>> into tensor<2x3x4x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>
// CHECK: %[[V0:.*]] = tensor.collapse_shape %[[VDROP1DIMS]] {{\[\[0, 1, 2\]\]}} : tensor<2x3x4x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>> into tensor<24x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>
// CHECK: %[[V0:.*]] = tensor.collapse_shape %[[SLICE:.*]] {{\[\[0, 1, 2, 3, 4, 5\]\]}} : tensor<1x2x3x4x1x1x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>> into tensor<24x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>
// CHECK: %[[V1:.*]] = "TFHE.batched_keyswitch_glwe"(%[[V0]]) {key = #TFHE<ksk{{\[}}[[KSK:.*]]{{\]}}<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>, sk{{\[}}[[SK_OUT]]{{\]}}<1,750>, 3, 4>>} : (tensor<24x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>) -> tensor<24x!TFHE.glwe<sk{{\[}}[[SK_OUT]]{{\]}}<1,750>>>
// CHECK: %[[V2:.*]] = tensor.expand_shape %[[V1]] {{\[\[0, 1, 2\]\]}} : tensor<24x!TFHE.glwe<sk{{\[}}[[SK_OUT]]{{\]}}<1,750>>> into tensor<2x3x4x!TFHE.glwe<sk{{\[}}[[SK_OUT]]{{\]}}<1,750>>>
// CHECK: return %[[V2]]
Expand Down Expand Up @@ -161,8 +160,7 @@ func.func @batch_offset_shifted_bounds_nonunitstep_extract_keyswitch(%arg0: tens
%0 = bufferization.alloc_tensor() : tensor<2x2x2x!TFHE.glwe<sk<1,1,750>>>

// CHECK: %[[V1:.*]] = tensor.extract_slice %arg0{{\[0, 3, 7, 9, 97, 1\] \[1, 2, 2, 2, 1, 1\] \[1, 2, 1, 7, 1, 1\]}} : tensor<99x20x30x40x99x99x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>> to tensor<1x2x2x2x1x1x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>
// CHECK-NEXT: %[[V2:.*]] = tensor.collapse_shape %[[V1]] {{\[\[0, 1\], \[2\], \[3, 4, 5\]\]}} : tensor<1x2x2x2x1x1x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>> into tensor<2x2x2x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>
// CHECK-NEXT: %[[V3:.*]] = tensor.collapse_shape %[[V2]] {{\[\[0, 1, 2\]\]}} : tensor<2x2x2x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>> into tensor<8x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>
// CHECK-NEXT: %[[V3:.*]] = tensor.collapse_shape %[[V1]] {{\[\[0, 1, 2, 3, 4, 5\]\]}} : tensor<1x2x2x2x1x1x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>> into tensor<8x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>
// CHECK-NEXT: %[[V4:.*]] = "TFHE.batched_keyswitch_glwe"(%[[V3]]) {key = #TFHE<ksk{{\[}}[[KSK:.*]]{{\]}}<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>, sk{{\[}}[[SK_OUT]]{{\]}}<1,750>, 3, 4>>} : (tensor<8x!TFHE.glwe<sk{{\[}}[[SK_IN]]{{\]}}<1,2048>>>) -> tensor<8x!TFHE.glwe<sk{{\[}}[[SK_OUT]]{{\]}}<1,750>>>
// CHECK-NEXT: %[[V5:.*]] = tensor.expand_shape %[[V4]] {{\[\[0, 1, 2\]\]}} : tensor<8x!TFHE.glwe<sk{{\[}}[[SK_OUT]]{{\]}}<1,750>>> into tensor<2x2x2x!TFHE.glwe<sk{{\[}}[[SK_OUT]]{{\]}}<1,750>>>
// CHECK-NEXT: return %[[V5]] : tensor<2x2x2x!TFHE.glwe<sk{{\[}}[[SK_OUT]]{{\]}}<1,750>>>
Expand Down
2 changes: 1 addition & 1 deletion third_party/llvm-project

0 comments on commit 570abf7

Please sign in to comment.