Skip to content

Commit

Permalink
feat(compiler): adds support for dynamic luts in fhelinalg
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed Sep 14, 2023
1 parent d16ce81 commit 46f439e
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,13 @@ mlir::LogicalResult ApplyLookupTableEintOp::verify() {
// Check the shape of lut argument
auto tEltwidth = tEltTy.getWidth();
mlir::SmallVector<int64_t, 1> expectedShape{1 << tEltwidth};
if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) {
this->emitOpError()
<< "should have as operand #2 a tensor<2^pxi64>, where p is the width "
"of the encrypted integer of the operand #1,"
<< "expect tensor <" << expectedShape[0] << "xi64>";
if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isSignlessInteger() ||
lutEltTy.getIntOrFloatBitWidth() > 64) {
this->emitOpError() << "should have as operand #2 a "
"tensor<2^pxi{8,16,32,64}>, where p is the width "
"of the encrypted integer of the operand #1,"
<< "expect tensor <" << expectedShape[0]
<< "xi{8,16,32,64}>";
return mlir::failure();
}
if (!resultTy.hasStaticShape(tTy.getShape())) {
Expand All @@ -308,12 +310,14 @@ mlir::LogicalResult ApplyMultiLookupTableEintOp::verify() {
// Check the shape of luts argument
auto lut_size = lutTy.getShape()[lutTy.getShape().size() - 1];
auto expected_lut_size = 1 << tEltTy.getWidth();
if (lut_size != expected_lut_size || !lutEltTy.isInteger(64)) {
this->emitOpError() << "should have as operand #2 a "
"tensor<DMx...xD1X2^pxi64>, where p is the width "
"of the encrypted integer of the operand #1,"
<< "expect tensor <DMx...xD1X" << expected_lut_size
<< "xi64>";
if (lut_size != expected_lut_size || !lutEltTy.isSignlessInteger() ||
lutEltTy.getIntOrFloatBitWidth() > 64) {
this->emitOpError()
<< "should have as operand #2 a "
"tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width "
"of the encrypted integer of the operand #1,"
<< "expect tensor <DMx...xD1X" << expected_lut_size
<< "xi{8,16,32,64}>";
return mlir::failure();
}
if (!resultTy.hasStaticShape(tTy.getShape())) {
Expand Down Expand Up @@ -380,9 +384,14 @@ mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op,

mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
auto t = this->getT();
auto tTy = this->getT().getType().cast<mlir::RankedTensorType>();
auto tEltTy =
tTy.getElementType().cast<mlir::concretelang::FHE::FheIntegerInterface>();
auto luts = this->getLuts();
auto map = this->getMap();
auto result = this->getResult();
auto lutTy = this->getLuts().getType().cast<mlir::RankedTensorType>();
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();

auto t_shape = getTensorType(t).getShape();
if (!getTensorType(result).hasStaticShape(t_shape)) {
Expand All @@ -397,6 +406,17 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
return mlir::failure();
}

auto expected_lut_size = 1 << tEltTy.getWidth();
if (!lutEltTy.isSignlessInteger() || lutEltTy.getIntOrFloatBitWidth() > 64) {
this->emitOpError()
<< "should have as operand #2 a "
"tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width "
"of the encrypted integer of the operand #1,"
<< "expect tensor <DMx...xD1X" << expected_lut_size
<< "xi{8,16,32,64}>";
return mlir::failure();
}

return mlir::success(verifyMapHasRightShape(*this, t, map).succeeded() &&
verifyLutsSize(*this, t, luts).succeeded());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,16 @@ func.func @main(%a0: tensor<2x3x4x!FHE.eint<2>>, %a1: tensor<2x3x4x!FHE.eint<3>>
// FHELinalg.apply_lookup_table
/////////////////////////////////////////////////

func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi32>) -> tensor<2x3x4x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi32>) -> (tensor<2x3x4x!FHE.eint<2>>)
func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<4xi65>) -> tensor<2x3x4x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi{8,16,32,64}>}}
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<4xi65>) -> (tensor<2x3x4x!FHE.eint<2>>)
return %1: tensor<2x3x4x!FHE.eint<2>>
}

// -----

func.func @apply_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<12xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
// expected-error @+1 {{'FHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi{8,16,32,64}>}}
%1 = "FHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<12xi64>) -> (tensor<2x3x4x!FHE.eint<2>>)
return %1: tensor<2x3x4x!FHE.eint<2>>
}
Expand All @@ -193,13 +193,21 @@ func.func @apply_lookup_table(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4xi
/////////////////////////////////////////////////

func.func @apply_multi_lookup_table(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<2x6xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X4xi64>}}
// expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X4xi{8,16,32,64}>}}
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<2x6xi64>) -> (tensor<2x3x4x!FHE.eint<2>>)
return %1: tensor<2x3x4x!FHE.eint<2>>
}

// -----

func.func @apply_multi_lookup_table_bad_prec(%arg0: tensor<2x3x4x!FHE.eint<2>>, %arg1: tensor<2x4xi65>) -> tensor<2x3x4x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.apply_multi_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X4xi{8,16,32,64}>}}
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!FHE.eint<2>>, tensor<2x4xi65>) -> (tensor<2x3x4x!FHE.eint<2>>)
return %1: tensor<2x3x4x!FHE.eint<2>>
}

// -----


/////////////////////////////////////////////////
// FHELinalg.apply_mapped_lookup_table
Expand Down Expand Up @@ -240,6 +248,18 @@ func.func @apply_mapped_lookup_table_bad_map_elmt_type(

// -----

func.func @apply_mapped_lookup_table_bad_lut_prec(
%input: tensor<2x3x4x!FHE.eint<7>>,
%luts: tensor<128xi65>,
%map: tensor<2x3x4xindex>
) -> tensor<2x3x4x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.apply_mapped_lookup_table' op should have as operand #2 a tensor<DMx...xD1X2^pxi{8,16,32,64}>, where p is the width of the encrypted integer of the operand #1,expect tensor <DMx...xD1X128xi{8,16,32,64}>}}
%0 = "FHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!FHE.eint<7>>, tensor<128xi65>, tensor<2x3x4xindex>) -> (tensor<2x3x4x!FHE.eint<7>>)
return %0: tensor<2x3x4x!FHE.eint<7>>
}

// -----

/////////////////////////////////////////////////
// FHELinalg.conv2d
/////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@

PRECISION_FORCE_CRT = 9

def get_lut_integer_type(p):
if p <= 8:
return "i8"
if p <= 16:
return "i16"
if p <= 32:
return "i32"
if p <= 64:
return "i64"
else:
raise Exception("Unexpected precision")

def generate(args):
print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
print("# /!\ THIS FILE HAS BEEN GENERATED")
Expand All @@ -16,15 +28,15 @@ def generate(args):
for n_lut in args.n_lut:
max_value = (2 ** p) - 1
random_lut = np.random.randint(max_value+1, size=2**p)
itype = get_lut_integer_type(p)
# identity_apply_lookup_table
print(f"description: apply_lookup_table_{p}bits_{n_ct}ct_{n_lut}layer")
print("program: |")
print(
f" func.func @main(%0: tensor<{n_ct}x!FHE.eint<{p}>>) -> tensor<{n_ct}x!FHE.eint<{p}>> {{")
print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>")
f" func.func @main(%0: tensor<{n_ct}x!FHE.eint<{p}>>, %tlu: tensor<{2**p}x{itype}>) -> tensor<{n_ct}x!FHE.eint<{p}>> {{")
for i in range(0, n_lut):
print(f" %{i+1} = \"FHELinalg.apply_lookup_table\"(%{i}, %tlu):")
print(f" (tensor<{n_ct}x!FHE.eint<{p}>>, tensor<{2**p}xi64>) -> (tensor<{n_ct}x!FHE.eint<{p}>>)")
print(f" (tensor<{n_ct}x!FHE.eint<{p}>>, tensor<{2**p}x{itype}>) -> (tensor<{n_ct}x!FHE.eint<{p}>>)")
print(f" return %{n_lut}: tensor<{n_ct}x!FHE.eint<{p}>>")
print(" }")
if p >= PRECISION_FORCE_CRT:
Expand All @@ -35,6 +47,8 @@ def generate(args):
print(" - inputs:")
print(f" - tensor: [{','.join(map(str, random_input))}]")
print(f" shape: [{n_ct}]")
print(f" - tensor: [{','.join(map(str, random_lut))}]")
print(f" shape: [{2**p}]")
outputs = random_input
for i in range(0, n_lut):
outputs = [random_lut[v] for v in outputs]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1034,8 +1034,8 @@ program: |
// [3,0,1] lut [1,3,5,7] = [7,1,3]
// [2,3,0] [5,7,1]
func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> {
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi8>
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi8>) -> tensor<3x3x!FHE.eint<3>>
return %res : tensor<3x3x!FHE.eint<3>>
}
tests:
Expand All @@ -1050,8 +1050,8 @@ tests:
description: apply_lookup_table_batched
program: |
func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> {
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi8>
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi8>) -> tensor<3x3x!FHE.eint<3>>
return %res : tensor<3x3x!FHE.eint<3>>
}
tests:
Expand All @@ -1066,8 +1066,8 @@ tests:
description: apply_multi_lookup_table
program: |
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a 3x3 matrix of tables of size 4=2² of clear integers.
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<2>> {
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<2>>
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>> {
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi8>) -> tensor<3x3x!FHE.eint<2>>
return %1: tensor<3x3x!FHE.eint<2>>
}
tests:
Expand All @@ -1084,8 +1084,8 @@ tests:
description: apply_multi_lookup_table_with_boradcast
program: |
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 on a vector of 3 tables of size 4=2² of clear integers.
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<3x3x!FHE.eint<2>> {
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x3x!FHE.eint<2>>
func.func @main(%arg0: tensor<3x3x!FHE.eint<2>>, %arg1: tensor<3x4xi8>) -> tensor<3x3x!FHE.eint<2>> {
%1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<3x3x!FHE.eint<2>>, tensor<3x4xi8>) -> tensor<3x3x!FHE.eint<2>>
return %1: tensor<3x3x!FHE.eint<2>>
}
tests:
Expand All @@ -1103,9 +1103,9 @@ tests:
description: apply_mapped_lookup_table_sequential
program: |
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi8>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
%1 = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi8>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
return %1: tensor<3x3x!FHE.eint<2>>
}
tests:
Expand All @@ -1124,9 +1124,9 @@ tests:
description: apply_mapped_lookup_table_same_lut
program: |
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
func.func @main(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<9x4xi8>, %map: tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>> {
%1 = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
(tensor<3x3x!FHE.eint<2>>, tensor<9x4xi8>, tensor<3x3xindex>) -> tensor<3x3x!FHE.eint<2>>
return %1: tensor<3x3x!FHE.eint<2>>
}
tests:
Expand Down

0 comments on commit 46f439e

Please sign in to comment.