Skip to content

Commit

Permalink
feat: add parameter to call_data attribute (#5599)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Kernel circuits can use the databus to optimise public inputs.
In practice, we have kernel circuits calling the app kernel and
recursively the private kernel, leading to two different inputs that
need to be matched differently (e.g the private kernel input need to
match the previous private kernel output)

## Summary\*
The call_data attribute now requires a parameter: e.g `call_data(0)`
which will group all inputs with the same value into the same call_data
array. If you use several values; `call_data(0)`, `call_data(1)`, this
will generate several arrays of type call_data.


## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [X] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
guipublic authored Jul 26, 2024
1 parent 0bb8372 commit e8bb341
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 73 deletions.
2 changes: 1 addition & 1 deletion compiler/noirc_driver/src/abi_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn to_abi_visibility(value: Visibility) -> AbiVisibility {
match value {
Visibility::Public => AbiVisibility::Public,
Visibility::Private => AbiVisibility::Private,
Visibility::DataBus => AbiVisibility::DataBus,
Visibility::CallData(_) | Visibility::ReturnData => AbiVisibility::DataBus,
}
}

Expand Down
58 changes: 35 additions & 23 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,12 @@ impl<'a> Context<'a> {
let (return_vars, return_warnings) =
self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?;

let call_data_arrays: Vec<ValueId> =
self.data_bus.call_data.iter().map(|cd| cd.array_id).collect();
for call_data_array in call_data_arrays {
self.ensure_array_is_initialized(call_data_array, dfg)?;
}

// TODO: This is a naive method of assigning the return values to their witnesses as
// we're likely to get a number of constraints which are asserting one witness to be equal to another.
//
Expand Down Expand Up @@ -1263,20 +1269,23 @@ impl<'a> Context<'a> {
let res_typ = dfg.type_of_value(results[0]);

// Get operations to call-data parameters are replaced by a get to the call-data-bus array
if let Some(call_data) = self.data_bus.call_data {
if self.data_bus.call_data_map.contains_key(&array) {
// TODO: the block_id of call-data must be notified to the backend
// TODO: should we do the same for return-data?
let type_size = res_typ.flattened_size();
let type_size =
self.acir_context.add_constant(FieldElement::from(type_size as i128));
let offset = self.acir_context.mul_var(var_index, type_size)?;
let bus_index = self
.acir_context
.add_constant(FieldElement::from(self.data_bus.call_data_map[&array] as i128));
let new_index = self.acir_context.add_var(offset, bus_index)?;
return self.array_get(instruction, call_data, new_index, dfg, index_side_effect);
}
if let Some(call_data) =
self.data_bus.call_data.iter().find(|cd| cd.index_map.contains_key(&array))
{
let type_size = res_typ.flattened_size();
let type_size = self.acir_context.add_constant(FieldElement::from(type_size as i128));
let offset = self.acir_context.mul_var(var_index, type_size)?;
let bus_index = self
.acir_context
.add_constant(FieldElement::from(call_data.index_map[&array] as i128));
let new_index = self.acir_context.add_var(offset, bus_index)?;
return self.array_get(
instruction,
call_data.array_id,
new_index,
dfg,
index_side_effect,
);
}

// Compiler sanity check
Expand Down Expand Up @@ -1707,17 +1716,20 @@ impl<'a> Context<'a> {
len: usize,
value: Option<AcirValue>,
) -> Result<(), InternalError> {
let databus = if self.data_bus.call_data.is_some()
&& self.block_id(&self.data_bus.call_data.unwrap()) == array
{
BlockType::CallData
} else if self.data_bus.return_data.is_some()
let mut databus = BlockType::Memory;
if self.data_bus.return_data.is_some()
&& self.block_id(&self.data_bus.return_data.unwrap()) == array
{
BlockType::ReturnData
} else {
BlockType::Memory
};
databus = BlockType::ReturnData;
}
for array_id in self.data_bus.call_data_array() {
if self.block_id(&array_id) == array {
assert!(databus == BlockType::Memory);
databus = BlockType::CallData;
break;
}
}

self.acir_context.initialize_array(array, len, value, databus)?;
self.initialized_arrays.insert(array);
Ok(())
Expand Down
100 changes: 71 additions & 29 deletions compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::rc::Rc;

use crate::ssa::ir::{types::Type, value::ValueId};
Expand All @@ -8,6 +9,12 @@ use noirc_frontend::hir_def::function::FunctionSignature;

use super::FunctionBuilder;

#[derive(Clone)]
pub(crate) enum DatabusVisibility {
None,
CallData(u32),
ReturnData,
}
/// Used to create a data bus, which is an array of private inputs
/// replacing public inputs
pub(crate) struct DataBusBuilder {
Expand All @@ -27,15 +34,16 @@ impl DataBusBuilder {
}
}

/// Generates a boolean vector telling which (ssa) parameter from the given function signature
/// Generates a vector telling which (ssa) parameters from the given function signature
/// are tagged with databus visibility
pub(crate) fn is_databus(main_signature: &FunctionSignature) -> Vec<bool> {
pub(crate) fn is_databus(main_signature: &FunctionSignature) -> Vec<DatabusVisibility> {
let mut params_is_databus = Vec::new();

for param in &main_signature.0 {
let is_databus = match param.2 {
ast::Visibility::Public | ast::Visibility::Private => false,
ast::Visibility::DataBus => true,
ast::Visibility::Public | ast::Visibility::Private => DatabusVisibility::None,
ast::Visibility::CallData(id) => DatabusVisibility::CallData(id),
ast::Visibility::ReturnData => DatabusVisibility::ReturnData,
};
let len = param.1.field_count() as usize;
params_is_databus.extend(vec![is_databus; len]);
Expand All @@ -44,34 +52,51 @@ impl DataBusBuilder {
}
}

#[derive(Clone, Debug)]
pub(crate) struct CallData {
pub(crate) array_id: ValueId,
pub(crate) index_map: HashMap<ValueId, usize>,
}

#[derive(Clone, Default, Debug)]
pub(crate) struct DataBus {
pub(crate) call_data: Option<ValueId>,
pub(crate) call_data_map: HashMap<ValueId, usize>,
pub(crate) call_data: Vec<CallData>,
pub(crate) return_data: Option<ValueId>,
}

impl DataBus {
/// Updates the databus values with the provided function
pub(crate) fn map_values(&self, mut f: impl FnMut(ValueId) -> ValueId) -> DataBus {
let mut call_data_map = HashMap::default();
for (k, v) in self.call_data_map.iter() {
call_data_map.insert(f(*k), *v);
}
DataBus {
call_data: self.call_data.map(&mut f),
call_data_map,
return_data: self.return_data.map(&mut f),
}
let call_data = self
.call_data
.iter()
.map(|cd| {
let mut call_data_map = HashMap::default();
for (k, v) in cd.index_map.iter() {
call_data_map.insert(f(*k), *v);
}
CallData { array_id: f(cd.array_id), index_map: call_data_map }
})
.collect();
DataBus { call_data, return_data: self.return_data.map(&mut f) }
}

pub(crate) fn call_data_array(&self) -> Vec<ValueId> {
self.call_data.iter().map(|cd| cd.array_id).collect()
}
/// Construct a databus from call_data and return_data data bus builders
pub(crate) fn get_data_bus(call_data: DataBusBuilder, return_data: DataBusBuilder) -> DataBus {
DataBus {
call_data: call_data.databus,
call_data_map: call_data.map,
return_data: return_data.databus,
pub(crate) fn get_data_bus(
call_data: Vec<DataBusBuilder>,
return_data: DataBusBuilder,
) -> DataBus {
let mut call_data_args = Vec::new();
for call_data_item in call_data {
if let Some(array_id) = call_data_item.databus {
call_data_args.push(CallData { array_id, index_map: call_data_item.map });
}
}

DataBus { call_data: call_data_args, return_data: return_data.databus }
}
}

Expand Down Expand Up @@ -129,19 +154,36 @@ impl FunctionBuilder {
}

/// Generate the data bus for call-data, based on the parameters of the entry block
/// and a boolean vector telling which ones are call-data
pub(crate) fn call_data_bus(&mut self, is_params_databus: Vec<bool>) -> DataBusBuilder {
/// and a vector telling which ones are call-data
pub(crate) fn call_data_bus(
&mut self,
is_params_databus: Vec<DatabusVisibility>,
) -> Vec<DataBusBuilder> {
//filter parameters of the first block that have call-data visibility
let first_block = self.current_function.entry_block();
let params = self.current_function.dfg[first_block].parameters();
let mut databus_param = Vec::new();
for (param, is_databus) in params.iter().zip(is_params_databus) {
if is_databus {
databus_param.push(param.to_owned());
let mut databus_param: BTreeMap<u32, Vec<ValueId>> = BTreeMap::new();
for (param, databus_attribute) in params.iter().zip(is_params_databus) {
match databus_attribute {
DatabusVisibility::None | DatabusVisibility::ReturnData => continue,
DatabusVisibility::CallData(call_data_id) => {
if let std::collections::btree_map::Entry::Vacant(e) =
databus_param.entry(call_data_id)
{
e.insert(vec![param.to_owned()]);
} else {
databus_param.get_mut(&call_data_id).unwrap().push(param.to_owned());
}
}
}
}
// create the call-data-bus from the filtered list
let call_data = DataBusBuilder::new();
self.initialize_data_bus(&databus_param, call_data)
// create the call-data-bus from the filtered lists
let mut result = Vec::new();
for id in databus_param.keys() {
let builder = DataBusBuilder::new();
let call_databus = self.initialize_data_bus(&databus_param[id], builder);
result.push(call_databus);
}
result
}
}
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/die.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ impl Ssa {
/// of its instructions are needed elsewhere.
fn dead_instruction_elimination(function: &mut Function) {
let mut context = Context::default();
if let Some(call_data) = function.dfg.data_bus.call_data {
context.mark_used_instruction_results(&function.dfg, call_data);
for call_data in &function.dfg.data_bus.call_data {
context.mark_used_instruction_results(&function.dfg, call_data.array_id);
}

let blocks = PostOrder::with_function(function);
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub(crate) fn generate_ssa(
// see which parameter has call_data/return_data attribute
let is_databus = DataBusBuilder::is_databus(&program.main_function_signature);

let is_return_data = matches!(program.return_visibility, Visibility::DataBus);
let is_return_data = matches!(program.return_visibility, Visibility::ReturnData);

let return_location = program.return_location;
let context = SharedContext::new(program);
Expand Down
7 changes: 5 additions & 2 deletions compiler/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,15 +390,18 @@ pub enum Visibility {
Private,
/// DataBus is public input handled as private input. We use the fact that return values are properly computed by the program to avoid having them as public inputs
/// it is useful for recursion and is handled by the proving system.
DataBus,
/// The u32 value is used to group inputs having the same value.
CallData(u32),
ReturnData,
}

impl std::fmt::Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Public => write!(f, "pub"),
Self::Private => write!(f, "priv"),
Self::DataBus => write!(f, "databus"),
Self::CallData(id) => write!(f, "calldata{id}"),
Self::ReturnData => write!(f, "returndata"),
}
}
}
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/parser/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub enum ParserErrorReason {
Lexer(LexerErrorKind),
#[error("The only supported numeric generic types are `u1`, `u8`, `u16`, and `u32`")]
ForbiddenNumericGenericType,
#[error("Invalid call data identifier, must be a number. E.g `call_data(0)`")]
InvalidCallDataIdentifier,
}

/// Represents a parsing error, or a parsing error in the making.
Expand Down
30 changes: 20 additions & 10 deletions compiler/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use crate::ast::{
use crate::lexer::{lexer::from_spanned_token_result, Lexer};
use crate::parser::{force, ignore_then_commit, statement_recovery};
use crate::token::{Keyword, Token, TokenKind};
use acvm::AcirField;

use chumsky::prelude::*;
use iter_extended::vecmap;
Expand Down Expand Up @@ -645,19 +646,28 @@ where
})
}

fn call_data() -> impl NoirParser<Visibility> {
keyword(Keyword::CallData).then(parenthesized(literal())).validate(|token, span, emit| {
match token {
(_, ExpressionKind::Literal(Literal::Integer(x, _))) => {
let id = x.to_u128() as u32;
Visibility::CallData(id)
}
_ => {
emit(ParserError::with_reason(ParserErrorReason::InvalidCallDataIdentifier, span));
Visibility::CallData(0)
}
}
})
}

fn optional_visibility() -> impl NoirParser<Visibility> {
keyword(Keyword::Pub)
.or(keyword(Keyword::CallData))
.or(keyword(Keyword::ReturnData))
.map(|_| Visibility::Public)
.or(call_data())
.or(keyword(Keyword::ReturnData).map(|_| Visibility::ReturnData))
.or_not()
.map(|opt| match opt {
Some(Token::Keyword(Keyword::Pub)) => Visibility::Public,
Some(Token::Keyword(Keyword::CallData)) | Some(Token::Keyword(Keyword::ReturnData)) => {
Visibility::DataBus
}
None => Visibility::Private,
_ => unreachable!("unexpected token found"),
})
.map(|opt| opt.unwrap_or(Visibility::Private))
}

pub fn expression() -> impl ExprParser {
Expand Down
2 changes: 1 addition & 1 deletion test_programs/execution_success/databus/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
fn main(mut x: u32, y: call_data u32, z: call_data [u32; 4]) -> return_data u32 {
fn main(mut x: u32, y: call_data(0) u32, z: call_data(0) [u32; 4]) -> return_data u32 {
let a = z[x];
a + foo(y)
}
Expand Down
7 changes: 4 additions & 3 deletions tooling/nargo_fmt/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ impl HasItem for Param {
fn format(self, visitor: &FmtVisitor, shape: Shape) -> String {
let pattern = visitor.slice(self.pattern.span());
let visibility = match self.visibility {
Visibility::Public => "pub",
Visibility::Private => "",
Visibility::DataBus => "call_data",
Visibility::Public => "pub".to_string(),
Visibility::Private => "".to_string(),
Visibility::CallData(x) => format!("call_data({x})"),
Visibility::ReturnData => "return_data".to_string(),
};

if self.pattern.is_synthesized() || self.typ.is_synthesized() {
Expand Down
5 changes: 4 additions & 1 deletion tooling/nargo_fmt/src/visitor/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ impl super::FmtVisitor<'_> {

let visibility = match func.def.return_visibility {
Visibility::Public => "pub",
Visibility::DataBus => "return_data",
Visibility::ReturnData => "return_data",
Visibility::Private => "",
Visibility::CallData(_) => {
unreachable!("call_data cannot be used for return value")
}
};
result.push_str(&append_space_if_nonempty(visibility.into()));

Expand Down

0 comments on commit e8bb341

Please sign in to comment.