-
Notifications
You must be signed in to change notification settings - Fork 305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Several Updates in SMT verification module (part 1) #10437
Changes from 3 commits
6cb7fb2
3910c9e
855a907
c6ab952
d59ead6
ab5fd16
82fe2bf
c474fb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,32 +103,32 @@ size_t StandardCircuit::prepare_gates(size_t cursor) | |
// TODO(alex): Test the effect of this relaxation after the tests are merged. | ||
if (univariate_flag) { | ||
if ((q_m == 1) && (q_1 == 0) && (q_2 == 0) && (q_3 == -1) && (q_c == 0)) { | ||
(Bool(symbolic_vars[w_l]) == | ||
(Bool(this->symbolic_vars[w_l]) == | ||
Bool(STerm(0, this->solver, this->type)) | // STerm(0, this->solver, this->type)) | | ||
Bool(symbolic_vars[w_l]) == | ||
Bool(this->symbolic_vars[w_l]) == | ||
Bool(STerm(1, this->solver, this->type))) // STerm(1, this->solver, this->type))) | ||
.assert_term(); | ||
} else { | ||
this->handle_univariate_constraint(q_m, q_1, q_2, q_3, q_c, w_l); | ||
} | ||
} else { | ||
STerm eq = symbolic_vars[0]; | ||
STerm eq = this->symbolic_vars[this->variable_names_inverse["zero"]]; | ||
|
||
// mul selector | ||
if (q_m != 0) { | ||
eq += symbolic_vars[w_l] * symbolic_vars[w_r] * q_m; | ||
eq += this->symbolic_vars[w_l] * this->symbolic_vars[w_r] * q_m; | ||
} | ||
// left selector | ||
if (q_1 != 0) { | ||
eq += symbolic_vars[w_l] * q_1; | ||
eq += this->symbolic_vars[w_l] * q_1; | ||
} | ||
// right selector | ||
if (q_2 != 0) { | ||
eq += symbolic_vars[w_r] * q_2; | ||
eq += this->symbolic_vars[w_r] * q_2; | ||
} | ||
// out selector | ||
if (q_3 != 0) { | ||
eq += symbolic_vars[w_o] * q_3; | ||
eq += this->symbolic_vars[w_o] * q_3; | ||
} | ||
// constant selector | ||
if (q_c != 0) { | ||
|
@@ -157,7 +157,7 @@ void StandardCircuit::handle_univariate_constraint( | |
bb::fr b = q_1 + q_2 + q_3; | ||
|
||
if (q_m == 0) { | ||
symbolic_vars[w] == -q_c / b; | ||
this->symbolic_vars[w] == -q_c / b; | ||
return; | ||
} | ||
|
||
|
@@ -169,10 +169,10 @@ void StandardCircuit::handle_univariate_constraint( | |
bb::fr x2 = (-b - d.second) / (bb::fr(2) * q_m); | ||
|
||
if (d.second == 0) { | ||
symbolic_vars[w] == STerm(x1, this->solver, type); | ||
this->symbolic_vars[w] == STerm(x1, this->solver, type); | ||
} else { | ||
((Bool(symbolic_vars[w]) == Bool(STerm(x1, this->solver, this->type))) | | ||
(Bool(symbolic_vars[w]) == Bool(STerm(x2, this->solver, this->type)))) | ||
((Bool(this->symbolic_vars[w]) == Bool(STerm(x1, this->solver, this->type))) | | ||
(Bool(this->symbolic_vars[w]) == Bool(STerm(x2, this->solver, this->type)))) | ||
.assert_term(); | ||
} | ||
} | ||
|
@@ -285,8 +285,6 @@ size_t StandardCircuit::handle_logic_constraint(size_t cursor) | |
} | ||
} | ||
|
||
// TODO(alex): Figure out if I need to create range constraint here too or it'll be | ||
// created anyway in any circuit | ||
if (res != static_cast<size_t>(-1)) { | ||
CircuitProps xor_props = get_standard_logic_circuit(res, true); | ||
CircuitProps and_props = get_standard_logic_circuit(res, false); | ||
|
@@ -307,6 +305,45 @@ size_t StandardCircuit::handle_logic_constraint(size_t cursor) | |
STerm right = this->symbolic_vars[right_idx]; | ||
STerm out = this->symbolic_vars[out_idx]; | ||
|
||
// Simulate the logic constraint circuit using the bitwise operations | ||
size_t num_bits = res; | ||
size_t processed_gates = 0; | ||
for (size_t i = num_bits - 1; i < num_bits; i -= 2) { | ||
// 8 here is the number of gates we have to skip to get proper indices | ||
processed_gates += 8; | ||
uint32_t left_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
uint32_t left_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]]; | ||
uint32_t left_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
processed_gates += 1; | ||
uint32_t right_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
uint32_t right_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]]; | ||
uint32_t right_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
processed_gates += 1; | ||
uint32_t out_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
uint32_t out_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]]; | ||
uint32_t out_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
processed_gates += 1; | ||
uint32_t old_left_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
processed_gates += 1; | ||
uint32_t old_right_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
processed_gates += 1; | ||
uint32_t old_out_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
processed_gates += 1; | ||
|
||
this->symbolic_vars[old_left_acc_idx] == (left >> static_cast<uint32_t>(i - 1)); | ||
this->symbolic_vars[left_quad_idx] == (this->symbolic_vars[old_left_acc_idx] & 3); | ||
this->symbolic_vars[left_lo_idx] == (this->symbolic_vars[left_quad_idx] & 1); | ||
this->symbolic_vars[left_hi_idx] == (this->symbolic_vars[left_quad_idx] >> 1); | ||
this->symbolic_vars[old_right_acc_idx] == (right >> static_cast<uint32_t>(i - 1)); | ||
this->symbolic_vars[right_quad_idx] == (this->symbolic_vars[old_right_acc_idx] & 3); | ||
this->symbolic_vars[right_lo_idx] == (this->symbolic_vars[right_quad_idx] & 1); | ||
this->symbolic_vars[right_hi_idx] == (this->symbolic_vars[right_quad_idx] >> 1); | ||
this->symbolic_vars[old_out_acc_idx] == (out >> static_cast<uint32_t>(i - 1)); | ||
this->symbolic_vars[out_quad_idx] == (this->symbolic_vars[old_out_acc_idx] & 3); | ||
this->symbolic_vars[out_lo_idx] == (this->symbolic_vars[out_quad_idx] & 1); | ||
this->symbolic_vars[out_hi_idx] == (this->symbolic_vars[out_quad_idx] >> 1); | ||
} | ||
|
||
if (logic_flag) { | ||
(left ^ right) == out; | ||
} else { | ||
|
@@ -422,19 +459,41 @@ size_t StandardCircuit::handle_range_constraint(size_t cursor) | |
// we need this because even right shifts do not create | ||
// any additional gates and therefore are undetectible | ||
|
||
// TODO(alex): I think I should simulate the whole subcircuit at that point | ||
// Otherwise optimized out variables are not correct in the final witness | ||
// And I can't fix them by hand each time | ||
size_t num_accs = range_props.gate_idxs.size() - 1; | ||
for (size_t j = 1; j < num_accs + 1 && (this->type == TermType::BVTerm); j++) { | ||
size_t acc_gate = range_props.gate_idxs[j]; | ||
uint32_t acc_gate_idx = range_props.idxs[j]; | ||
// Simulate the range constraint circuit using the bitwise operations | ||
size_t num_bits = res; | ||
size_t num_quads = num_bits >> 1; | ||
num_quads += num_bits & 1; | ||
uint32_t processed_gates = 0; | ||
|
||
for (size_t i = num_quads - 1; i < num_quads; i--) { | ||
uint32_t lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
processed_gates += 1; | ||
uint32_t quad_idx = 0; | ||
uint32_t old_accumulator_idx = 0; | ||
uint32_t hi_idx = 0; | ||
|
||
if (i == num_quads - 1 && ((num_bits & 1) == 1)) { | ||
quad_idx = lo_idx; | ||
} else { | ||
hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
processed_gates += 1; | ||
quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
processed_gates += 1; | ||
} | ||
|
||
uint32_t acc_idx = this->real_variable_index[this->wires_idxs[cursor + acc_gate][acc_gate_idx]]; | ||
if (i == num_quads - 1) { | ||
old_accumulator_idx = quad_idx; | ||
} else { | ||
old_accumulator_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
processed_gates += 1; | ||
} | ||
|
||
this->symbolic_vars[acc_idx] == (left >> static_cast<uint32_t>(2 * j)); | ||
// I think the following is worse. The name of the variable is lost after that | ||
// this->symbolic_vars[acc_idx] = (left >> static_cast<uint32_t>(2 * j)); | ||
this->symbolic_vars[old_accumulator_idx] == (left >> static_cast<uint32_t>(2 * i)); | ||
this->symbolic_vars[quad_idx] == (this->symbolic_vars[old_accumulator_idx] & 3); | ||
this->symbolic_vars[lo_idx] == (this->symbolic_vars[quad_idx] & 1); | ||
if (i != (num_quads - 1) || ((num_bits)&1) != 1) { | ||
this->symbolic_vars[hi_idx] == (this->symbolic_vars[quad_idx] >> 1); | ||
} | ||
} | ||
|
||
left <= (bb::fr(2).pow(res) - 1); | ||
|
@@ -545,8 +604,35 @@ size_t StandardCircuit::handle_shr_constraint(size_t cursor) | |
STerm left = this->symbolic_vars[left_idx]; | ||
STerm out = this->symbolic_vars[out_idx]; | ||
|
||
STerm shled = left >> nr.second; | ||
out == shled; | ||
// Simulate the shr circuit using bitwise ops | ||
uint32_t shift = nr.second; | ||
if ((shift & 1) == 1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where are these formulas from? An explanation wouldn't hurt. It is very hard to understand what's happening here without context |
||
size_t processed_gates = 0; | ||
uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3); | ||
STerm delta = this->symbolic_vars[delta_idx]; | ||
processed_gates += 1; | ||
uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
|
||
// this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7); | ||
this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } }); | ||
|
||
processed_gates += 1; | ||
uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[r1_idx] == (delta >> 1) * 6; | ||
processed_gates += 1; | ||
uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[r2_idx] == (left >> shift) * 6; | ||
processed_gates += 1; | ||
uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
|
||
// this->symbolic_vars[temp_idx] == -6 * out; | ||
this->post_process.insert({ temp_idx, { out_idx, out_idx, 0, -6, 0, 0 } }); | ||
} | ||
|
||
STerm shred = left >> nr.second; | ||
out == shred; | ||
|
||
// You have to mark these arguments so they won't be optimized out | ||
optimized[left_idx] = false; | ||
|
@@ -652,7 +738,35 @@ size_t StandardCircuit::handle_shl_constraint(size_t cursor) | |
STerm left = this->symbolic_vars[left_idx]; | ||
STerm out = this->symbolic_vars[out_idx]; | ||
|
||
STerm shled = (left << nr.second) & (bb::fr(2).pow(nr.first) - 1); | ||
// Simulate the shr circuit using bitwise ops | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shift left |
||
uint32_t num_bits = nr.first; | ||
uint32_t shift = nr.second; | ||
if ((shift & 1) == 1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you provide an explanation of what you are doing here? |
||
size_t processed_gates = 0; | ||
uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3); | ||
STerm delta = this->symbolic_vars[delta_idx]; | ||
processed_gates += 1; | ||
uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
|
||
// this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7); | ||
this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } }); | ||
|
||
processed_gates += 1; | ||
uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[r1_idx] == (delta >> 1) * 6; | ||
processed_gates += 1; | ||
uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[r2_idx] == (left >> (num_bits - shift)) * 6; | ||
processed_gates += 1; | ||
uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
|
||
// this->symbolic_vraiables[temp_idx] == -6 * r2 | ||
this->post_process.insert({ temp_idx, { r2_idx, r2_idx, 0, -1, 0, 0 } }); | ||
} | ||
|
||
STerm shled = (left << shift) & (bb::fr(2).pow(num_bits) - 1); | ||
out == shled; | ||
|
||
// You have to mark these arguments so they won't be optimized out | ||
|
@@ -760,7 +874,35 @@ size_t StandardCircuit::handle_ror_constraint(size_t cursor) | |
STerm left = this->symbolic_vars[left_idx]; | ||
STerm out = this->symbolic_vars[out_idx]; | ||
|
||
STerm rored = ((left >> nr.second) | (left << (nr.first - nr.second))) & (bb::fr(2).pow(nr.first) - 1); | ||
// Simulate the ror circuit using bitwise ops | ||
uint32_t num_bits = nr.first; | ||
uint32_t rotation = nr.second; | ||
if ((rotation & 1) == 1) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please explain how this works |
||
size_t processed_gates = 0; | ||
uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; | ||
uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3); | ||
STerm delta = this->symbolic_vars[delta_idx]; | ||
processed_gates += 1; | ||
uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
|
||
// this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7); | ||
this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } }); | ||
|
||
processed_gates += 1; | ||
uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[r1_idx] == (delta >> 1) * 6; | ||
processed_gates += 1; | ||
uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
this->symbolic_vars[r2_idx] == (left >> rotation) * 6; | ||
processed_gates += 1; | ||
uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; | ||
|
||
// this->symbolic_vraiables[temp_idx] == -6 * r2 | ||
this->post_process.insert({ temp_idx, { r2_idx, r2_idx, 0, -1, 0, 0 } }); | ||
} | ||
|
||
STerm rored = ((left >> rotation) | (left << (num_bits - rotation))) & (bb::fr(2).pow(num_bits) - 1); | ||
out == rored; | ||
|
||
// You have to mark these arguments so they won't be optimized out | ||
|
@@ -909,4 +1051,4 @@ std::pair<StandardCircuit, StandardCircuit> StandardCircuit::unique_witness(Circ | |
} | ||
return { c1, c2 }; | ||
} | ||
}; // namespace smt_circuit | ||
}; // namespace smt_circuit |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add a few more comments describing what you are doing?