Skip to content

Commit

Permalink
Special cased export of array, float and int constants as terms.
Browse files Browse the repository at this point in the history
  • Loading branch information
zrho committed Jan 21, 2025
1 parent be3961d commit 39dafb9
Show file tree
Hide file tree
Showing 17 changed files with 244 additions and 54 deletions.
44 changes: 44 additions & 0 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ use crate::{
extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc},
hugr::{IdentList, NodeMetadataMap},
ops::{constant::CustomSerialized, DataflowBlock, OpName, OpTrait, OpType, Value},
std_extensions::{
arithmetic::{float_types::ConstF64, int_types::ConstInt},
collections::array::ArrayValue,
},
types::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
Expand Down Expand Up @@ -1018,6 +1022,46 @@ impl<'a> Context<'a> {
fn export_value(&mut self, value: &'a Value) -> model::TermId {
match value {
Value::Extension { e } => {
// NOTE: We have special cased arrays, integers, and floats for now.
// TODO: Allow arbitrary extension values to be exported as terms.

if let Some(array) = e.value().downcast_ref::<ArrayValue>() {
let len = self.make_term(model::Term::Nat(array.get_contents().len() as u64));
let element_type = self.export_type(array.get_element_type());
let mut contents =
BumpVec::with_capacity_in(array.get_contents().len(), self.bump);

for element in array.get_contents() {
contents.push(model::ListPart::Item(self.export_value(element)));
}

let contents = self.make_term(model::Term::List {
parts: contents.into_bump_slice(),
});

let symbol = self.resolve_symbol(ArrayValue::CTR_NAME);
let args = self.bump.alloc_slice_copy(&[len, element_type, contents]);
return self.make_term(model::Term::ApplyFull { symbol, args });
}

if let Some(v) = e.value().downcast_ref::<ConstInt>() {
let bitwidth = self.make_term(model::Term::Nat(v.log_width() as u64));
let literal = self.make_term(model::Term::Nat(v.value_u()));

let symbol = self.resolve_symbol(ConstInt::CTR_NAME);
let args = self.bump.alloc_slice_copy(&[bitwidth, literal]);
return self.make_term(model::Term::ApplyFull { symbol, args });
}

if let Some(v) = e.value().downcast_ref::<ConstF64>() {
let literal = self.make_term(model::Term::Float {
value: v.value().into(),
});
let symbol = self.resolve_symbol(ConstF64::CTR_NAME);
let args = self.bump.alloc_slice_copy(&[literal]);
return self.make_term(model::Term::ApplyFull { symbol, args });
}

let json = match e.value().downcast_ref::<CustomSerialized>() {
Some(custom) => serde_json::to_string(custom.value()).unwrap(),
None => serde_json::to_string(e.value())
Expand Down
71 changes: 70 additions & 1 deletion hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ use crate::{
ExitBlock, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, Module, OpType, OpaqueOp,
Output, Tag, TailLoop, Value, CFG, DFG,
},
std_extensions::{
arithmetic::{float_types::ConstF64, int_types::ConstInt},
collections::array::ArrayValue,
},
types::{
type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV,
PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, Type, TypeArg, TypeBase, TypeBound,
Expand Down Expand Up @@ -922,6 +926,7 @@ impl<'a> Context<'a> {
model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")),
model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")),
model::Term::BytesType { .. } => Err(error_unsupported!("`bytes` as `TypeParam`")),
model::Term::FloatType { .. } => Err(error_unsupported!("`float` as `TypeParam`")),
model::Term::Const { .. } => Err(error_unsupported!("`(const ...)` as `TypeParam`")),
model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")),

Expand All @@ -947,6 +952,7 @@ impl<'a> Context<'a> {
| model::Term::ConstFunc { .. }
| model::Term::Bytes { .. }
| model::Term::Meta
| model::Term::Float { .. }
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),

model::Term::ControlType => {
Expand Down Expand Up @@ -1000,8 +1006,10 @@ impl<'a> Context<'a> {
model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")),
model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")),
model::Term::BytesType => Err(error_unsupported!("`bytes` as `TypeArg`")),
model::Term::FloatType => Err(error_unsupported!("`float` as `TypeArg`")),
model::Term::Bytes { .. } => Err(error_unsupported!("`(bytes ..)` as `TypeArg`")),
model::Term::Const { .. } => Err(error_unsupported!("`const` as `TypeArg`")),
model::Term::Float { .. } => Err(error_unsupported!("float literal as `TypeArg`")),
model::Term::ConstAdt { .. } => Err(error_unsupported!("adt constant as `TypeArg`")),
model::Term::ConstFunc { .. } => {
Err(error_unsupported!("function constant as `TypeArg`"))
Expand Down Expand Up @@ -1136,6 +1144,8 @@ impl<'a> Context<'a> {
| model::Term::NonLinearConstraint { .. }
| model::Term::Bytes { .. }
| model::Term::BytesType
| model::Term::FloatType
| model::Term::Float { .. }
| model::Term::ConstFunc { .. }
| model::Term::Meta
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),
Expand Down Expand Up @@ -1356,7 +1366,64 @@ impl<'a> Context<'a> {
}
}

Err(error_unsupported!("constant value that is not JSON data"))
// NOTE: We have special cased arrays, integers, and floats for now.
// TODO: Allow arbitrary extension values to be imported from terms.

if symbol_name == ArrayValue::CTR_NAME {
let element_type_term =
args.get(1).ok_or(model::ModelError::TypeError(term_id))?;
let element_type = self.import_type(*element_type_term)?;

let contents = {
let contents = args.get(2).ok_or(model::ModelError::TypeError(term_id))?;
let contents = self.import_closed_list(*contents)?;
contents
.iter()
.map(|item| self.import_value(*item, *element_type_term))
.collect::<Result<Vec<_>, _>>()?
};

return Ok(ArrayValue::new(element_type, contents).into());
}

if symbol_name == ConstInt::CTR_NAME {
let bitwidth = {
let bitwidth = args.first().ok_or(model::ModelError::TypeError(term_id))?;
let model::Term::Nat(bitwidth) = self.get_term(*bitwidth)? else {
return Err(model::ModelError::TypeError(term_id).into());
};
if *bitwidth > 6 {
return Err(model::ModelError::TypeError(term_id).into());
}
*bitwidth as u8
};

let value = {
let value = args.get(1).ok_or(model::ModelError::TypeError(term_id))?;
let model::Term::Nat(value) = self.get_term(*value)? else {
return Err(model::ModelError::TypeError(term_id).into());
};
*value
};

return Ok(ConstInt::new_u(bitwidth, value)
.map_err(|_| model::ModelError::TypeError(term_id))?
.into());
}

if symbol_name == ConstF64::CTR_NAME {
let value = {
let value = args.first().ok_or(model::ModelError::TypeError(term_id))?;
let model::Term::Float { value } = self.get_term(*value)? else {
return Err(model::ModelError::TypeError(term_id).into());
};
value.into_inner()
};

return Ok(ConstF64::new(value).into());
}

Err(error_unsupported!("unknown custom constant value"))
// TODO: This should ultimately include the following cases:
// - function definitions
// - custom constructors for values
Expand All @@ -1381,6 +1448,8 @@ impl<'a> Context<'a> {
| model::Term::Bytes { .. }
| model::Term::BytesType
| model::Term::Meta
| model::Term::Float { .. }
| model::Term::FloatType
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
Expand Down
3 changes: 3 additions & 0 deletions hugr-core/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ impl std::ops::Deref for ConstF64 {
}

impl ConstF64 {
/// Name of the constructor for creating constant 64bit floats.
pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const-f64";

/// Create a new [`ConstF64`]
pub fn new(value: f64) -> Self {
// This function can't be `const` because `is_finite()` is not yet stable as a const function.
Expand Down
3 changes: 3 additions & 0 deletions hugr-core/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ pub struct ConstInt {
}

impl ConstInt {
/// Name of the constructor for creating constant integers.
pub(crate) const CTR_NAME: &'static str = "arithmetic.int.const";

/// Create a new [`ConstInt`] with a given width and unsigned value
pub fn new_u(log_width: u8, value: u64) -> Result<Self, ConstTypeError> {
if !is_valid_log_width(log_width) {
Expand Down
3 changes: 3 additions & 0 deletions hugr-core/src/std_extensions/collections/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ pub struct ArrayValue {
}

impl ArrayValue {
/// Name of the constructor for creating constant arrays.
pub(crate) const CTR_NAME: &'static str = "collections.array.const";

/// Create a new [CustomConst] for an array of values of type `typ`.
/// That all values are of type `typ` is not checked here.
pub fn new(typ: Type, contents: impl IntoIterator<Item = Value>) -> Self {
Expand Down
43 changes: 25 additions & 18 deletions hugr-core/tests/snapshots/model__roundtrip_const.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,20 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
---
(hugr 0)

(import collections.array.array)

(import collections.array.const)

(import compat.const-json)

(import arithmetic.float.types.float64)

(import arithmetic.int.const)

(import arithmetic.int.types.int)

(import arithmetic.float.const-f64)

(define-func example.bools
[] [(adt [[] []]) (adt [[] []])] (ext)
(dfg
Expand All @@ -19,40 +29,41 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
(define-func example.make-pair
[]
[(adt
[[(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)]])]
[[(@ collections.array.array 5 (@ arithmetic.int.types.int 6))
(@ arithmetic.float.types.float64)]])]
(ext)
(dfg
[] [%0]
(signature
(->
[]
[(adt
[[(@ arithmetic.float.types.float64)
[[(@ collections.array.array 5 (@ arithmetic.int.types.int 6))
(@ arithmetic.float.types.float64)]])]
(ext)))
(const
(tag
0
[(@
compat.const-json
(@ arithmetic.float.types.float64)
"{\"c\":\"ConstF64\",\"v\":{\"value\":2.0}}"
(ext arithmetic.float.types))
(@
compat.const-json
(@ arithmetic.float.types.float64)
"{\"c\":\"ConstF64\",\"v\":{\"value\":3.0}}"
(ext arithmetic.float.types))])
collections.array.const
5
(@ arithmetic.int.types.int 6)
[(@ arithmetic.int.const 6 1)
(@ arithmetic.int.const 6 2)
(@ arithmetic.int.const 6 3)
(@ arithmetic.int.const 6 4)
(@ arithmetic.int.const 6 5)])
(@ arithmetic.float.const-f64 -3.0)])
[] [%0]
(signature
(->
[]
[(adt
[[(@ arithmetic.float.types.float64)
[[(@ collections.array.array 5 (@ arithmetic.int.types.int 6))
(@ arithmetic.float.types.float64)]])]
(ext))))))

(define-func example.f64
(define-func example.f64-json
[] [(@ arithmetic.float.types.float64)] (ext)
(dfg
[] [%0 %1]
Expand All @@ -62,11 +73,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
[(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)]
(ext)))
(const
(@
compat.const-json
(@ arithmetic.float.types.float64)
"{\"c\":\"ConstF64\",\"v\":{\"value\":1.0}}"
(ext arithmetic.float.types))
(@ arithmetic.float.const-f64 1.0)
[] [%0]
(signature (-> [] [(@ arithmetic.float.types.float64)] (ext))))
(const
Expand Down
1 change: 1 addition & 0 deletions hugr-model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ capnp = "0.20.1"
derive_more = { version = "1.0.0", features = ["display"] }
fxhash.workspace = true
indexmap.workspace = true
ordered-float = "4.6.0"
pest = "2.7.12"
pest_derive = "2.7.12"
pretty = "0.12.3"
Expand Down
2 changes: 2 additions & 0 deletions hugr-model/capnp/hugr-v0.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ struct Term {
bytes @24 :Data;
bytesType @25 :Void;
meta @26 :Void;
float @27 :Float64;
floatType @28 :Void;
}

struct Apply {
Expand Down
5 changes: 5 additions & 0 deletions hugr-model/src/v0/binary/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,11 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult
data: bump.alloc_slice_copy(bytes?),
},
Which::BytesType(()) => model::Term::BytesType,

Which::Float(value) => model::Term::Float {
value: value.into(),
},
Which::FloatType(()) => model::Term::FloatType,
})
}

Expand Down
2 changes: 2 additions & 0 deletions hugr-model/src/v0/binary/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) {
model::Term::Meta => {
builder.set_meta(());
}
model::Term::Float { value } => builder.set_float(value.into_inner()),
model::Term::FloatType => builder.set_float_type(()),
}
}

Expand Down
10 changes: 10 additions & 0 deletions hugr-model/src/v0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
//! [#1546]: https://github.com/CQCL/hugr/issues/1546
//! [#1553]: https://github.com/CQCL/hugr/issues/1553
//! [#1554]: https://github.com/CQCL/hugr/issues/1554
use ordered_float::OrderedFloat;
use smol_str::SmolStr;
use thiserror::Error;

Expand Down Expand Up @@ -718,6 +719,15 @@ pub enum Term<'a> {

/// The type of metadata.
Meta,

/// A literal floating-point number.
Float {
/// The value of the floating-point number.
value: OrderedFloat<f64>,
},

/// The type of floating-point numbers.
FloatType,
}

/// A part of a list term.
Expand Down
Loading

0 comments on commit 39dafb9

Please sign in to comment.