Skip to content

Commit

Permalink
feat: Added indirect const instruction (#8065)
Browse files Browse the repository at this point in the history
Adds indirect const since the AVM supports it, and uses it to reduce a
bunch bytecode sizes when initializing constant arrays.
  • Loading branch information
sirasistant authored and signorecello committed Aug 26, 2024
1 parent 8d28cb9 commit 131f403
Show file tree
Hide file tree
Showing 14 changed files with 296 additions and 57 deletions.
45 changes: 34 additions & 11 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ pub fn brillig_to_avm(
});
}
BrilligOpcode::Const { destination, value, bit_size } => {
handle_const(&mut avm_instrs, destination, value, bit_size);
handle_const(&mut avm_instrs, destination, value, bit_size, false);
}
BrilligOpcode::IndirectConst { destination_pointer, value, bit_size } => {
handle_const(&mut avm_instrs, destination_pointer, value, bit_size, true);
}
BrilligOpcode::Mov { destination, source } => {
avm_instrs.push(generate_mov_instruction(
Expand Down Expand Up @@ -371,7 +374,7 @@ fn handle_cast(
);
avm_instrs.extend([
// We cast to Field to be able to use toradix.
generate_cast_instruction(source_offset, dest_offset, AvmTypeTag::FIELD),
generate_cast_instruction(source_offset, false, dest_offset, false, AvmTypeTag::FIELD),
// Toradix with radix 2 and 1 limb is the same as modulo 2.
// We need to insert an instruction explicitly because we want to fine-tune 'indirect'.
AvmInstruction {
Expand All @@ -386,11 +389,11 @@ fn handle_cast(
],
},
// Then we cast back to u8 (which is what we use for u1).
generate_cast_instruction(dest_offset, dest_offset, AvmTypeTag::UINT8),
generate_cast_instruction(dest_offset, false, dest_offset, false, AvmTypeTag::UINT8),
]);
} else {
let tag = tag_from_bit_size(bit_size);
avm_instrs.push(generate_cast_instruction(source_offset, dest_offset, tag));
avm_instrs.push(generate_cast_instruction(source_offset, false, dest_offset, false, tag));
}
}

Expand Down Expand Up @@ -667,30 +670,36 @@ fn handle_const(
destination: &MemoryAddress,
value: &FieldElement,
bit_size: &BitSize,
indirect: bool,
) {
let tag = tag_from_bit_size(*bit_size);
let dest = destination.to_usize() as u32;

if !matches!(tag, AvmTypeTag::FIELD) {
avm_instrs.push(generate_set_instruction(tag, dest, value.to_u128()));
avm_instrs.push(generate_set_instruction(tag, dest, value.to_u128(), indirect));
} else {
// We can't fit a field in an instruction. This should've been handled in Brillig.
let field = value;
if field.num_bits() > 128 {
panic!("SET: Field value doesn't fit in 128 bits, that's not supported!");
}
avm_instrs.extend([
generate_set_instruction(AvmTypeTag::UINT128, dest, field.to_u128()),
generate_cast_instruction(dest, dest, AvmTypeTag::FIELD),
generate_set_instruction(AvmTypeTag::UINT128, dest, field.to_u128(), indirect),
generate_cast_instruction(dest, indirect, dest, indirect, AvmTypeTag::FIELD),
]);
}
}

/// Generates an AVM SET instruction.
fn generate_set_instruction(tag: AvmTypeTag, dest: u32, value: u128) -> AvmInstruction {
fn generate_set_instruction(
tag: AvmTypeTag,
dest: u32,
value: u128,
indirect: bool,
) -> AvmInstruction {
AvmInstruction {
opcode: AvmOpcode::SET,
indirect: Some(ALL_DIRECT),
indirect: if indirect { Some(ZEROTH_OPERAND_INDIRECT) } else { Some(ALL_DIRECT) },
tag: Some(tag),
operands: vec![
// const
Expand All @@ -709,10 +718,23 @@ fn generate_set_instruction(tag: AvmTypeTag, dest: u32, value: u128) -> AvmInstr
}

/// Generates an AVM CAST instruction.
fn generate_cast_instruction(source: u32, destination: u32, dst_tag: AvmTypeTag) -> AvmInstruction {
fn generate_cast_instruction(
source: u32,
source_indirect: bool,
destination: u32,
destination_indirect: bool,
dst_tag: AvmTypeTag,
) -> AvmInstruction {
let mut indirect_flags = ALL_DIRECT;
if source_indirect {
indirect_flags |= ZEROTH_OPERAND_INDIRECT;
}
if destination_indirect {
indirect_flags |= FIRST_OPERAND_INDIRECT;
}
AvmInstruction {
opcode: AvmOpcode::CAST,
indirect: Some(ALL_DIRECT),
indirect: Some(indirect_flags),
tag: Some(dst_tag),
operands: vec![AvmOperand::U32 { value: source }, AvmOperand::U32 { value: destination }],
}
Expand Down Expand Up @@ -1107,6 +1129,7 @@ pub fn map_brillig_pcs_to_avm_pcs(brillig_bytecode: &[BrilligOpcode<FieldElement
for i in 0..brillig_bytecode.len() - 1 {
let num_avm_instrs_for_this_brillig_instr = match &brillig_bytecode[i] {
BrilligOpcode::Const { bit_size: BitSize::Field, .. } => 2,
BrilligOpcode::IndirectConst { bit_size: BitSize::Field, .. } => 2,
BrilligOpcode::Cast { bit_size: BitSize::Integer(IntegerBitSize::U1), .. } => 3,
_ => 1,
};
Expand Down
68 changes: 68 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,16 @@ struct BrilligOpcode {
static Const bincodeDeserialize(std::vector<uint8_t>);
};

struct IndirectConst {
Program::MemoryAddress destination_pointer;
Program::BitSize bit_size;
std::string value;

friend bool operator==(const IndirectConst&, const IndirectConst&);
std::vector<uint8_t> bincodeSerialize() const;
static IndirectConst bincodeDeserialize(std::vector<uint8_t>);
};

struct Return {
friend bool operator==(const Return&, const Return&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -748,6 +758,7 @@ struct BrilligOpcode {
CalldataCopy,
Call,
Const,
IndirectConst,
Return,
ForeignCall,
Mov,
Expand Down Expand Up @@ -6465,6 +6476,63 @@ Program::BrilligOpcode::Const serde::Deserializable<Program::BrilligOpcode::Cons

namespace Program {

inline bool operator==(const BrilligOpcode::IndirectConst& lhs, const BrilligOpcode::IndirectConst& rhs)
{
if (!(lhs.destination_pointer == rhs.destination_pointer)) {
return false;
}
if (!(lhs.bit_size == rhs.bit_size)) {
return false;
}
if (!(lhs.value == rhs.value)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BrilligOpcode::IndirectConst::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::IndirectConst>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::IndirectConst BrilligOpcode::IndirectConst::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::IndirectConst>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BrilligOpcode::IndirectConst>::serialize(
const Program::BrilligOpcode::IndirectConst& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.destination_pointer)>::serialize(obj.destination_pointer, serializer);
serde::Serializable<decltype(obj.bit_size)>::serialize(obj.bit_size, serializer);
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::BrilligOpcode::IndirectConst serde::Deserializable<Program::BrilligOpcode::IndirectConst>::deserialize(
Deserializer& deserializer)
{
Program::BrilligOpcode::IndirectConst obj;
obj.destination_pointer = serde::Deserializable<decltype(obj.destination_pointer)>::deserialize(deserializer);
obj.bit_size = serde::Deserializable<decltype(obj.bit_size)>::deserialize(deserializer);
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::Return& lhs, const BrilligOpcode::Return& rhs)
{
return true;
Expand Down
56 changes: 55 additions & 1 deletion noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,16 @@ namespace Program {
static Const bincodeDeserialize(std::vector<uint8_t>);
};

struct IndirectConst {
Program::MemoryAddress destination_pointer;
Program::BitSize bit_size;
std::string value;

friend bool operator==(const IndirectConst&, const IndirectConst&);
std::vector<uint8_t> bincodeSerialize() const;
static IndirectConst bincodeDeserialize(std::vector<uint8_t>);
};

struct Return {
friend bool operator==(const Return&, const Return&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -717,7 +727,7 @@ namespace Program {
static Stop bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, ConditionalMov, Load, Store, BlackBox, Trap, Stop> value;
std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, IndirectConst, Return, ForeignCall, Mov, ConditionalMov, Load, Store, BlackBox, Trap, Stop> value;

friend bool operator==(const BrilligOpcode&, const BrilligOpcode&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -5390,6 +5400,50 @@ Program::BrilligOpcode::Const serde::Deserializable<Program::BrilligOpcode::Cons
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::IndirectConst &lhs, const BrilligOpcode::IndirectConst &rhs) {
if (!(lhs.destination_pointer == rhs.destination_pointer)) { return false; }
if (!(lhs.bit_size == rhs.bit_size)) { return false; }
if (!(lhs.value == rhs.value)) { return false; }
return true;
}

inline std::vector<uint8_t> BrilligOpcode::IndirectConst::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::IndirectConst>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::IndirectConst BrilligOpcode::IndirectConst::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::IndirectConst>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BrilligOpcode::IndirectConst>::serialize(const Program::BrilligOpcode::IndirectConst &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.destination_pointer)>::serialize(obj.destination_pointer, serializer);
serde::Serializable<decltype(obj.bit_size)>::serialize(obj.bit_size, serializer);
serde::Serializable<decltype(obj.value)>::serialize(obj.value, serializer);
}

template <>
template <typename Deserializer>
Program::BrilligOpcode::IndirectConst serde::Deserializable<Program::BrilligOpcode::IndirectConst>::deserialize(Deserializer &deserializer) {
Program::BrilligOpcode::IndirectConst obj;
obj.destination_pointer = serde::Deserializable<decltype(obj.destination_pointer)>::deserialize(deserializer);
obj.bit_size = serde::Deserializable<decltype(obj.bit_size)>::deserialize(deserializer);
obj.value = serde::Deserializable<decltype(obj.value)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::Return &lhs, const BrilligOpcode::Return &rhs) {
Expand Down
28 changes: 14 additions & 14 deletions noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ fn simple_brillig_foreign_call() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 80, 49, 10, 192, 32, 12, 52, 45, 45, 165, 155, 63,
209, 31, 248, 25, 7, 23, 7, 17, 223, 175, 96, 2, 65, 162, 139, 30, 132, 203, 221, 65, 72,
2, 170, 227, 107, 5, 216, 63, 200, 164, 57, 200, 115, 200, 102, 15, 22, 206, 205, 50, 124,
223, 107, 108, 128, 155, 106, 113, 217, 141, 252, 10, 25, 225, 103, 121, 136, 197, 167,
188, 250, 213, 76, 75, 158, 22, 178, 10, 176, 188, 242, 119, 164, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 80, 49, 10, 192, 32, 12, 52, 45, 45, 133, 110, 190,
68, 127, 224, 103, 28, 92, 28, 68, 124, 191, 130, 9, 4, 137, 46, 122, 16, 46, 119, 7, 33,
9, 168, 142, 175, 21, 96, 255, 32, 147, 230, 32, 207, 33, 155, 61, 88, 56, 55, 203, 240,
125, 175, 177, 1, 110, 170, 197, 101, 55, 242, 43, 100, 132, 159, 229, 33, 22, 159, 242,
234, 87, 51, 45, 121, 90, 200, 42, 48, 209, 35, 111, 164, 1, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down Expand Up @@ -307,15 +307,15 @@ fn complex_brillig_foreign_call() {

let bytes = Program::serialize_program(&program);
let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 77, 90, 199, 145, 217,
205, 13, 6, 102, 14, 208, 241, 4, 222, 69, 220, 41, 186, 244, 248, 90, 140, 24, 159, 5, 23,
86, 208, 7, 37, 253, 228, 243, 146, 144, 50, 77, 200, 198, 197, 178, 127, 136, 52, 34, 253,
189, 165, 53, 102, 221, 66, 164, 59, 134, 63, 199, 243, 229, 206, 226, 104, 110, 192, 209,
158, 192, 145, 84, 255, 47, 216, 239, 152, 125, 137, 90, 63, 27, 152, 159, 132, 166, 249,
74, 229, 252, 20, 153, 97, 161, 189, 145, 161, 237, 224, 173, 128, 19, 235, 189, 126, 192,
17, 97, 4, 177, 75, 162, 101, 154, 187, 84, 113, 97, 136, 255, 82, 89, 150, 109, 211, 213,
85, 111, 65, 21, 233, 126, 213, 254, 7, 239, 12, 118, 104, 171, 161, 63, 176, 144, 46, 7,
244, 246, 124, 191, 105, 41, 241, 92, 246, 1, 235, 222, 207, 212, 69, 5, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 77, 90, 199, 17, 102, 55,
39, 24, 152, 57, 64, 199, 19, 120, 23, 113, 167, 232, 210, 227, 107, 49, 98, 124, 22, 92,
88, 65, 31, 148, 244, 147, 207, 75, 66, 202, 52, 33, 27, 23, 203, 254, 33, 210, 136, 244,
247, 150, 214, 152, 117, 11, 145, 238, 24, 254, 28, 207, 151, 59, 139, 163, 185, 1, 71,
123, 2, 71, 82, 253, 191, 96, 191, 99, 246, 37, 106, 253, 108, 96, 126, 18, 154, 230, 43,
149, 243, 83, 100, 134, 133, 246, 70, 134, 182, 131, 183, 2, 78, 172, 247, 250, 1, 71, 132,
17, 196, 46, 137, 150, 105, 238, 82, 197, 133, 33, 254, 75, 101, 89, 182, 77, 87, 87, 189,
5, 85, 164, 251, 85, 251, 31, 188, 51, 216, 161, 173, 134, 254, 192, 66, 186, 28, 208, 219,
243, 253, 166, 165, 196, 115, 217, 7, 253, 216, 100, 109, 69, 5, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `complex_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 77, 90, 199, 145, 217, 205, 13, 6, 102, 14, 208, 241,
4, 222, 69, 220, 41, 186, 244, 248, 90, 140, 24, 159, 5, 23, 86, 208, 7, 37, 253, 228, 243, 146, 144, 50, 77, 200,
198, 197, 178, 127, 136, 52, 34, 253, 189, 165, 53, 102, 221, 66, 164, 59, 134, 63, 199, 243, 229, 206, 226, 104, 110,
192, 209, 158, 192, 145, 84, 255, 47, 216, 239, 152, 125, 137, 90, 63, 27, 152, 159, 132, 166, 249, 74, 229, 252, 20,
153, 97, 161, 189, 145, 161, 237, 224, 173, 128, 19, 235, 189, 126, 192, 17, 97, 4, 177, 75, 162, 101, 154, 187, 84,
113, 97, 136, 255, 82, 89, 150, 109, 211, 213, 85, 111, 65, 21, 233, 126, 213, 254, 7, 239, 12, 118, 104, 171, 161,
63, 176, 144, 46, 7, 244, 246, 124, 191, 105, 41, 241, 92, 246, 1, 235, 222, 207, 212, 69, 5, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 75, 10, 132, 48, 12, 77, 90, 199, 17, 102, 55, 39, 24, 152, 57, 64, 199,
19, 120, 23, 113, 167, 232, 210, 227, 107, 49, 98, 124, 22, 92, 88, 65, 31, 148, 244, 147, 207, 75, 66, 202, 52, 33,
27, 23, 203, 254, 33, 210, 136, 244, 247, 150, 214, 152, 117, 11, 145, 238, 24, 254, 28, 207, 151, 59, 139, 163, 185,
1, 71, 123, 2, 71, 82, 253, 191, 96, 191, 99, 246, 37, 106, 253, 108, 96, 126, 18, 154, 230, 43, 149, 243, 83, 100,
134, 133, 246, 70, 134, 182, 131, 183, 2, 78, 172, 247, 250, 1, 71, 132, 17, 196, 46, 137, 150, 105, 238, 82, 197,
133, 33, 254, 75, 101, 89, 182, 77, 87, 87, 189, 5, 85, 164, 251, 85, 251, 31, 188, 51, 216, 161, 173, 134, 254, 192,
66, 186, 28, 208, 219, 243, 253, 166, 165, 196, 115, 217, 7, 253, 216, 100, 109, 69, 5, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000001'],
Expand Down
8 changes: 4 additions & 4 deletions noir/noir-repo/acvm-repo/acvm_js/test/shared/foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `simple_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 80, 49, 10, 192, 32, 12, 52, 45, 45, 165, 155, 63, 209, 31, 248, 25, 7, 23, 7,
17, 223, 175, 96, 2, 65, 162, 139, 30, 132, 203, 221, 65, 72, 2, 170, 227, 107, 5, 216, 63, 200, 164, 57, 200, 115,
200, 102, 15, 22, 206, 205, 50, 124, 223, 107, 108, 128, 155, 106, 113, 217, 141, 252, 10, 25, 225, 103, 121, 136,
197, 167, 188, 250, 213, 76, 75, 158, 22, 178, 10, 176, 188, 242, 119, 164, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 80, 49, 10, 192, 32, 12, 52, 45, 45, 133, 110, 190, 68, 127, 224, 103, 28, 92,
28, 68, 124, 191, 130, 9, 4, 137, 46, 122, 16, 46, 119, 7, 33, 9, 168, 142, 175, 21, 96, 255, 32, 147, 230, 32, 207,
33, 155, 61, 88, 56, 55, 203, 240, 125, 175, 177, 1, 110, 170, 197, 101, 55, 242, 43, 100, 132, 159, 229, 33, 22, 159,
242, 234, 87, 51, 45, 121, 90, 200, 42, 48, 209, 35, 111, 164, 1, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000005'],
Expand Down
5 changes: 5 additions & 0 deletions noir/noir-repo/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ pub enum BrilligOpcode<F> {
bit_size: BitSize,
value: F,
},
IndirectConst {
destination_pointer: MemoryAddress,
bit_size: BitSize,
value: F,
},
Return,
/// Used to get data from an outside source.
/// Also referred to as an Oracle. However, we don't use that name as
Expand Down
Loading

0 comments on commit 131f403

Please sign in to comment.