Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement copy on write optimization for arrays in brillig #3118

Closed
wants to merge 12 commits into from
2 changes: 2 additions & 0 deletions acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ impl From<usize> for RegisterIndex {
pub struct HeapArray {
pub pointer: RegisterIndex,
pub size: usize,
pub reference_count: RegisterIndex,
jfecher marked this conversation as resolved.
Show resolved Hide resolved
}

/// A register-sized vector passed starting from a Brillig register memory location and with a register-held size
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)]
pub struct HeapVector {
pub pointer: RegisterIndex,
pub size: RegisterIndex,
pub reference_count: RegisterIndex,
}

/// Lays out various ways an external foreign call's input and output data may be interpreted inside Brillig.
Expand Down
49 changes: 34 additions & 15 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,30 +221,32 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
"Function result size does not match brillig bytecode (expected 1 result)"
),
},
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => {
RegisterOrMemory::HeapArray(HeapArray { pointer, size, reference_count }) => {
match output {
ForeignCallParam::Array(values) => {
if values.len() != *size {
invalid_foreign_call_result = true;
break;
}
// Convert the destination pointer to a usize
let destination = self.registers.get(*pointer_index).to_usize();
let destination = self.registers.get(*pointer).to_usize();
// Write to our destination memory
self.memory.write_slice(destination, values);
self.registers.set(*reference_count, Value::from(1usize));
}
_ => {
unreachable!("Function result size does not match brillig bytecode size")
}
}
}
RegisterOrMemory::HeapVector(HeapVector { pointer: pointer_index, size: size_index }) => {
RegisterOrMemory::HeapVector(HeapVector { pointer, size, reference_count }) => {
match output {
ForeignCallParam::Array(values) => {
// Set our size in the size register
self.registers.set(*size_index, Value::from(values.len()));
self.registers.set(*size, Value::from(values.len()));
self.registers.set(*reference_count, Value::from(1_usize));
// Convert the destination pointer to a usize
let destination = self.registers.get(*pointer_index).to_usize();
let destination = self.registers.get(*pointer).to_usize();
// Write to our destination memory
self.memory.write_slice(destination, values);
}
Expand Down Expand Up @@ -337,16 +339,13 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> ForeignCallParam {
match input {
RegisterOrMemory::RegisterIndex(value_index) => self.registers.get(value_index).into(),
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => {
let start = self.registers.get(pointer_index);
RegisterOrMemory::HeapArray(HeapArray { pointer, size, reference_count: _ }) => {
let start = self.registers.get(pointer);
self.memory.read_slice(start.to_usize(), size).to_vec().into()
}
RegisterOrMemory::HeapVector(HeapVector {
pointer: pointer_index,
size: size_index,
}) => {
let start = self.registers.get(pointer_index);
let size = self.registers.get(size_index);
RegisterOrMemory::HeapVector(HeapVector { pointer, size, reference_count: _ }) => {
let start = self.registers.get(pointer);
let size = self.registers.get(size);
self.memory.read_slice(start.to_usize(), size.to_usize()).to_vec().into()
}
}
Expand Down Expand Up @@ -943,11 +942,14 @@ mod tests {
// Ensure the foreign call counter has been incremented
assert_eq!(vm.foreign_call_counter, 1);
}

#[test]
fn foreign_call_opcode_memory_result() {
let r_input = RegisterIndex::from(0);
let r_output = RegisterIndex::from(1);

let reference_count = RegisterIndex::from(2);

// Define a simple 2x2 matrix in memory
let initial_matrix =
vec![Value::from(1u128), Value::from(2u128), Value::from(3u128), Value::from(4u128)];
Expand All @@ -967,10 +969,12 @@ mod tests {
destinations: vec![RegisterOrMemory::HeapArray(HeapArray {
pointer: r_output,
size: initial_matrix.len(),
reference_count,
})],
inputs: vec![RegisterOrMemory::HeapArray(HeapArray {
pointer: r_input,
size: initial_matrix.len(),
reference_count,
})],
},
];
Expand Down Expand Up @@ -1008,9 +1012,11 @@ mod tests {
fn foreign_call_opcode_vector_input_and_output() {
let r_input_pointer = RegisterIndex::from(0);
let r_input_size = RegisterIndex::from(1);
let r_input_rc = RegisterIndex::from(2);
// We need to pass a location of appropriate size
let r_output_pointer = RegisterIndex::from(2);
let r_output_size = RegisterIndex::from(3);
let r_output_pointer = RegisterIndex::from(3);
let r_output_size = RegisterIndex::from(4);
let r_output_rc = RegisterIndex::from(5);

// Our first string to use the identity function with
let input_string =
Expand Down Expand Up @@ -1040,10 +1046,12 @@ mod tests {
destinations: vec![RegisterOrMemory::HeapVector(HeapVector {
pointer: r_output_pointer,
size: r_output_size,
reference_count: r_output_rc,
})],
inputs: vec![RegisterOrMemory::HeapVector(HeapVector {
pointer: r_input_pointer,
size: r_input_size,
reference_count: r_input_rc,
})],
},
];
Expand Down Expand Up @@ -1083,6 +1091,9 @@ mod tests {
let r_input = RegisterIndex::from(0);
let r_output = RegisterIndex::from(1);

let r_input_rc = RegisterIndex::from(2);
let r_output_rc = RegisterIndex::from(3);

// Define a simple 2x2 matrix in memory
let initial_matrix =
vec![Value::from(1u128), Value::from(2u128), Value::from(3u128), Value::from(4u128)];
Expand All @@ -1102,10 +1113,12 @@ mod tests {
destinations: vec![RegisterOrMemory::HeapArray(HeapArray {
pointer: r_output,
size: initial_matrix.len(),
reference_count: r_output_rc,
})],
inputs: vec![RegisterOrMemory::HeapArray(HeapArray {
pointer: r_input,
size: initial_matrix.len(),
reference_count: r_input_rc,
})],
},
];
Expand Down Expand Up @@ -1148,6 +1161,9 @@ mod tests {
let r_input_b = RegisterIndex::from(1);
let r_output = RegisterIndex::from(2);

let r_input_rc = RegisterIndex::from(3);
let r_output_rc = RegisterIndex::from(4);

// Define a simple 2x2 matrix in memory
let matrix_a =
vec![Value::from(1u128), Value::from(2u128), Value::from(3u128), Value::from(4u128)];
Expand Down Expand Up @@ -1180,15 +1196,18 @@ mod tests {
destinations: vec![RegisterOrMemory::HeapArray(HeapArray {
pointer: r_output,
size: matrix_a.len(),
reference_count: r_output_rc,
})],
inputs: vec![
RegisterOrMemory::HeapArray(HeapArray {
pointer: r_input_a,
size: matrix_a.len(),
reference_count: r_input_rc,
}),
RegisterOrMemory::HeapArray(HeapArray {
pointer: r_input_b,
size: matrix_b.len(),
reference_count: r_input_rc,
}),
],
},
Expand Down
Loading