Skip to content

Commit

Permalink
simplify indexing of Challenges through Index trait
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Aug 3, 2023
1 parent 25964ae commit bbbc52c
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 162 deletions.
2 changes: 1 addition & 1 deletion constraint-evaluation-generator/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ fn get_binding_name<II: InputIndicator>(circuit: &ConstraintCircuit<II>) -> Toke
CircuitExpression::Input(idx) => quote!(#idx),
CircuitExpression::Challenge(challenge) => {
let challenge_ident = format_ident!("{challenge}");
quote!(challenges.get_challenge(#challenge_ident))
quote!(challenges[#challenge_ident])
}
CircuitExpression::BinaryOperation(_, _, _) => {
let node_ident = format_ident!("node_{}", circuit.id);
Expand Down
10 changes: 5 additions & 5 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1315,15 +1315,15 @@ pub(crate) mod triton_stark_tests {
let ine = EvalArg::compute_terminal(
&claim.input,
EvalArg::default_initial(),
all_challenges.get_challenge(StandardInputIndeterminate),
all_challenges[StandardInputIndeterminate],
);
assert_eq!(ptie, ine, "The input evaluation arguments do not match.");

let ptoe = processor_table_last_row[OutputTableEvalArg.ext_table_index()];
let oute = EvalArg::compute_terminal(
&claim.output,
EvalArg::default_initial(),
all_challenges.get_challenge(StandardOutputIndeterminate),
all_challenges[StandardOutputIndeterminate],
);
assert_eq!(ptoe, oute, "The output evaluation arguments do not match.");
}
Expand Down Expand Up @@ -1353,20 +1353,20 @@ pub(crate) mod triton_stark_tests {
let processor_table = master_ext_table.table(ProcessorTable);
let processor_table_last_row = processor_table.slice(s![-1, ..]);
assert_eq!(
challenges.get_challenge(StandardInputTerminal),
challenges[StandardInputTerminal],
processor_table_last_row[InputTableEvalArg.ext_table_index()],
"The input terminal must match for TASM snippet #{code_idx}."
);
assert_eq!(
challenges.get_challenge(StandardOutputTerminal),
challenges[StandardOutputTerminal],
processor_table_last_row[OutputTableEvalArg.ext_table_index()],
"The output terminal must match for TASM snippet #{code_idx}."
);

let lookup_table = master_ext_table.table(LookupTable);
let lookup_table_last_row = lookup_table.slice(s![-1, ..]);
assert_eq!(
challenges.get_challenge(LookupTablePublicTerminal),
challenges[LookupTablePublicTerminal],
lookup_table_last_row[PublicEvaluationArgument.ext_table_index()],
"The lookup's terminal must match for TASM snippet #{code_idx}."
);
Expand Down
12 changes: 6 additions & 6 deletions triton-vm/src/table/cascade_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ impl CascadeTable {

let two_pow_8 = BFieldElement::new(1 << 8);

let hash_indeterminate = challenges.get_challenge(HashCascadeLookupIndeterminate);
let hash_input_weight = challenges.get_challenge(HashCascadeLookInWeight);
let hash_output_weight = challenges.get_challenge(HashCascadeLookOutWeight);
let hash_indeterminate = challenges[HashCascadeLookupIndeterminate];
let hash_input_weight = challenges[HashCascadeLookInWeight];
let hash_output_weight = challenges[HashCascadeLookOutWeight];

let lookup_indeterminate = challenges.get_challenge(CascadeLookupIndeterminate);
let lookup_input_weight = challenges.get_challenge(LookupTableInputWeight);
let lookup_output_weight = challenges.get_challenge(LookupTableOutputWeight);
let lookup_indeterminate = challenges[CascadeLookupIndeterminate];
let lookup_input_weight = challenges[LookupTableInputWeight];
let lookup_output_weight = challenges[LookupTableOutputWeight];

for row_idx in 0..base_table.nrows() {
let base_row = base_table.row(row_idx);
Expand Down
28 changes: 25 additions & 3 deletions triton-vm/src/table/challenges.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Index;
use std::ops::Range;
use std::ops::RangeInclusive;

use strum::EnumCount;
use strum_macros::Display;
Expand Down Expand Up @@ -292,10 +295,29 @@ impl Challenges {
let stand_in_challenges = random_elements(Self::num_challenges_to_sample());
Self::new(stand_in_challenges, claim)
}
}

impl Index<ChallengeId> for Challenges {
type Output = XFieldElement;

fn index(&self, id: ChallengeId) -> &Self::Output {
&self.challenges[id.index()]
}
}

impl Index<Range<ChallengeId>> for Challenges {
type Output = [XFieldElement];

fn index(&self, indices: Range<ChallengeId>) -> &Self::Output {
&self.challenges[indices.start.index()..indices.end.index()]
}
}

impl Index<RangeInclusive<ChallengeId>> for Challenges {
type Output = [XFieldElement];

#[inline(always)]
pub fn get_challenge(&self, id: ChallengeId) -> XFieldElement {
self.challenges[id.index()]
fn index(&self, indices: RangeInclusive<ChallengeId>) -> &Self::Output {
&self.challenges[indices.start().index()..=indices.end().index()]
}
}

Expand Down
2 changes: 1 addition & 1 deletion triton-vm/src/table/constraint_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ impl<II: InputIndicator> ConstraintCircuit<II> {
XConstant(xfe) => xfe,
BConstant(bfe) => bfe.lift(),
Input(input) => input.evaluate(base_table, ext_table),
Challenge(challenge_id) => challenges.get_challenge(challenge_id),
Challenge(challenge_id) => challenges[challenge_id],
BinaryOperation(binop, lhs, rhs) => {
let lhs_value = lhs
.as_ref()
Expand Down
33 changes: 10 additions & 23 deletions triton-vm/src/table/hash_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1580,13 +1580,12 @@ impl HashTable {
assert_eq!(EXT_WIDTH, ext_table.ncols());
assert_eq!(base_table.nrows(), ext_table.nrows());

let ci_weight = challenges.get_challenge(HashCIWeight);
let hash_digest_eval_indeterminate = challenges.get_challenge(HashDigestIndeterminate);
let hash_input_eval_indeterminate = challenges.get_challenge(HashInputIndeterminate);
let sponge_eval_indeterminate = challenges.get_challenge(SpongeIndeterminate);
let cascade_indeterminate = challenges.get_challenge(HashCascadeLookupIndeterminate);
let send_chunk_indeterminate =
challenges.get_challenge(ProgramAttestationSendChunkIndeterminate);
let ci_weight = challenges[HashCIWeight];
let hash_digest_eval_indeterminate = challenges[HashDigestIndeterminate];
let hash_input_eval_indeterminate = challenges[HashInputIndeterminate];
let sponge_eval_indeterminate = challenges[SpongeIndeterminate];
let cascade_indeterminate = challenges[HashCascadeLookupIndeterminate];
let send_chunk_indeterminate = challenges[ProgramAttestationSendChunkIndeterminate];

let mut hash_input_running_evaluation = EvalArg::default_initial();
let mut hash_digest_running_evaluation = EvalArg::default_initial();
Expand Down Expand Up @@ -1675,19 +1674,7 @@ impl HashTable {
]
};

let state_weights = [
challenges.get_challenge(HashStateWeight0),
challenges.get_challenge(HashStateWeight1),
challenges.get_challenge(HashStateWeight2),
challenges.get_challenge(HashStateWeight3),
challenges.get_challenge(HashStateWeight4),
challenges.get_challenge(HashStateWeight5),
challenges.get_challenge(HashStateWeight6),
challenges.get_challenge(HashStateWeight7),
challenges.get_challenge(HashStateWeight8),
challenges.get_challenge(HashStateWeight9),
];

let state_weights = &challenges[HashStateWeight0..HashStateWeight10];
let compressed_row = |row: ArrayView1<BFieldElement>| -> XFieldElement {
rate_registers(row)
.iter()
Expand All @@ -1696,8 +1683,8 @@ impl HashTable {
.sum()
};

let cascade_look_in_weight = challenges.get_challenge(HashCascadeLookInWeight);
let cascade_look_out_weight = challenges.get_challenge(HashCascadeLookOutWeight);
let cascade_look_in_weight = challenges[HashCascadeLookInWeight];
let cascade_look_out_weight = challenges[HashCascadeLookOutWeight];

let log_derivative_summand =
|row: ArrayView1<BFieldElement>,
Expand Down Expand Up @@ -1725,7 +1712,7 @@ impl HashTable {
let compressed_chunk_of_instructions = EvalArg::compute_terminal(
&rate_registers(row),
EvalArg::default_initial(),
challenges.get_challenge(ProgramAttestationPrepareChunkIndeterminate),
challenges[ProgramAttestationPrepareChunkIndeterminate],
);
receive_chunk_running_evaluation = receive_chunk_running_evaluation
* send_chunk_indeterminate
Expand Down
14 changes: 7 additions & 7 deletions triton-vm/src/table/jump_stack_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,14 @@ impl JumpStackTable {
assert_eq!(EXT_WIDTH, ext_table.ncols());
assert_eq!(base_table.nrows(), ext_table.nrows());

let clk_weight = challenges.get_challenge(JumpStackClkWeight);
let ci_weight = challenges.get_challenge(JumpStackCiWeight);
let jsp_weight = challenges.get_challenge(JumpStackJspWeight);
let jso_weight = challenges.get_challenge(JumpStackJsoWeight);
let jsd_weight = challenges.get_challenge(JumpStackJsdWeight);
let perm_arg_indeterminate = challenges.get_challenge(JumpStackIndeterminate);
let clk_weight = challenges[JumpStackClkWeight];
let ci_weight = challenges[JumpStackCiWeight];
let jsp_weight = challenges[JumpStackJspWeight];
let jso_weight = challenges[JumpStackJsoWeight];
let jsd_weight = challenges[JumpStackJsdWeight];
let perm_arg_indeterminate = challenges[JumpStackIndeterminate];
let clock_jump_difference_lookup_indeterminate =
challenges.get_challenge(ClockJumpDifferenceLookupIndeterminate);
challenges[ClockJumpDifferenceLookupIndeterminate];

let mut running_product = PermArg::default_initial();
let mut clock_jump_diff_lookup_log_derivative = LookupArg::default_initial();
Expand Down
8 changes: 4 additions & 4 deletions triton-vm/src/table/lookup_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ impl LookupTable {
assert_eq!(EXT_WIDTH, ext_table.ncols());
assert_eq!(base_table.nrows(), ext_table.nrows());

let look_in_weight = challenges.get_challenge(LookupTableInputWeight);
let look_out_weight = challenges.get_challenge(LookupTableOutputWeight);
let cascade_indeterminate = challenges.get_challenge(CascadeLookupIndeterminate);
let public_indeterminate = challenges.get_challenge(LookupTablePublicIndeterminate);
let look_in_weight = challenges[LookupTableInputWeight];
let look_out_weight = challenges[LookupTableOutputWeight];
let cascade_indeterminate = challenges[CascadeLookupIndeterminate];
let public_indeterminate = challenges[LookupTablePublicIndeterminate];

let mut cascade_table_running_sum_log_derivative = LookupArg::default_initial();
let mut public_running_evaluation = EvalArg::default_initial();
Expand Down
12 changes: 6 additions & 6 deletions triton-vm/src/table/op_stack_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,13 @@ impl OpStackTable {
assert_eq!(EXT_WIDTH, ext_table.ncols());
assert_eq!(base_table.nrows(), ext_table.nrows());

let clk_weight = challenges.get_challenge(OpStackClkWeight);
let ib1_weight = challenges.get_challenge(OpStackIb1Weight);
let osp_weight = challenges.get_challenge(OpStackOspWeight);
let osv_weight = challenges.get_challenge(OpStackOsvWeight);
let perm_arg_indeterminate = challenges.get_challenge(OpStackIndeterminate);
let clk_weight = challenges[OpStackClkWeight];
let ib1_weight = challenges[OpStackIb1Weight];
let osp_weight = challenges[OpStackOspWeight];
let osv_weight = challenges[OpStackOsvWeight];
let perm_arg_indeterminate = challenges[OpStackIndeterminate];
let clock_jump_difference_lookup_indeterminate =
challenges.get_challenge(ClockJumpDifferenceLookupIndeterminate);
challenges[ClockJumpDifferenceLookupIndeterminate];

let mut running_product = PermArg::default_initial();
let mut clock_jump_diff_lookup_log_derivative = LookupArg::default_initial();
Expand Down
Loading

0 comments on commit bbbc52c

Please sign in to comment.