Skip to content

Commit

Permalink
feat: Option / Result helpers (#1481)
Browse files Browse the repository at this point in the history
Adds helpers for defining fallible types, so we ensure we always use the
same definition. This change does _not_ modify the serialisation format,
it's only helpers on top of `Sum`.

On the rust side:

- Adds prelude definitions for `option` and `result` types (special
cases of `Sum`s),
including constant values for `some`/`none`, `ok`/`err`.

On the python side:

- Adds an `Option` and `None_` subtype of `tys.Sum`, and
`Some`/`None`/`Ok`/`Err` value definitions (also subtypes of `val.Sum`).
These implement nicer `__repr__`, but the classes get lost on a
serialisation roundtrip. I'm not sure if we want to auto-detect the
special cases during deserialisation?
  
The names used are rusticisms. Feel free to bikeshed them.

Closes #1473

---------

Co-authored-by: Seyon Sivarajah <seyon.sivarajah@quantinuum.com>
  • Loading branch information
aborgna-q and ss2165 authored Aug 29, 2024
1 parent 3ca56f4 commit 9698420
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 26 deletions.
169 changes: 143 additions & 26 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
//! Prelude extension - available in all contexts, defining common types,
//! operations and constants.
use itertools::Itertools;
use lazy_static::lazy_static;

use crate::extension::simple_op::MakeOpDef;
use crate::ops::constant::{CustomCheckFailure, ValueName};
use crate::ops::{ExtensionOp, OpName};
use crate::types::{FuncValueType, SumType, TypeName, TypeRV};
use crate::{
extension::{ExtensionId, TypeDefBound},
ops::constant::CustomConst,
type_row,
types::{
type_param::{TypeArg, TypeParam},
CustomType, PolyFuncTypeRV, Signature, Type, TypeBound,
},
Extension,
use crate::extension::const_fold::fold_out_row;
use crate::extension::simple_op::{
try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
};
use crate::extension::{
ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound,
};
use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName};
use crate::ops::{ExtensionOp, NamedOp, OpName, Value};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{
CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound,
TypeName, TypeRV, TypeRow, TypeRowRV,
};
use crate::utils::sorted_consts;
use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::{
extension::{
const_fold::fold_out_row,
simple_op::{try_from_name, MakeExtensionOp, MakeRegisteredOp, OpLoadError},
ConstFold, ExtensionSet, OpDef, SignatureError, SignatureFunc,
},
ops::{NamedOp, Value},
types::{PolyFuncType, TypeRow},
utils::sorted_consts,
};

use super::{ExtensionRegistry, SignatureFromArgs};
struct ArrayOpCustom;

Expand Down Expand Up @@ -255,8 +247,102 @@ pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE);
pub const ERROR_TYPE_NAME: TypeName = TypeName::new_inline("error");

/// Return a Sum type with the first variant as the given type and the second an Error.
pub fn sum_with_error(ty: Type) -> SumType {
SumType::new([ty, ERROR_TYPE])
pub fn sum_with_error(ty: impl Into<TypeRowRV>) -> SumType {
either_type(ty, ERROR_TYPE)
}

/// An optional type, i.e. a Sum type with the first variant as the given type and the second as an empty tuple.
#[inline]
pub fn option_type(ty: impl Into<TypeRowRV>) -> SumType {
either_type(ty, TypeRow::new())
}

/// An "either" type, i.e. a Sum type with a "left" and a "right" variant.
///
/// When used as a fallible value, the "left" variant represents a successful computation,
/// and the "right" variant represents a failure.
#[inline]
pub fn either_type(ty_ok: impl Into<TypeRowRV>, ty_err: impl Into<TypeRowRV>) -> SumType {
SumType::new([ty_ok.into(), ty_err.into()])
}

/// A constant optional value with a given value.
///
/// See [option_type].
pub fn const_some(value: Value) -> Value {
const_some_tuple([value])
}

/// A constant optional value with a row of values.
///
/// For single values, use [const_some].
///
/// See [option_type].
pub fn const_some_tuple(values: impl IntoIterator<Item = Value>) -> Value {
const_left_tuple(values, TypeRow::new())
}

/// A constant optional value with no value.
///
/// See [option_type].
pub fn const_none(ty: impl Into<TypeRowRV>) -> Value {
const_right_tuple(ty, [])
}

/// A constant Either value with a left variant.
///
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_left(value: Value, ty_right: impl Into<TypeRowRV>) -> Value {
const_left_tuple([value], ty_right)
}

/// A constant Either value with a row of left values.
///
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_left_tuple(
values: impl IntoIterator<Item = Value>,
ty_right: impl Into<TypeRowRV>,
) -> Value {
let values = values.into_iter().collect_vec();
let types: TypeRowRV = values
.iter()
.map(|v| TypeRV::from(v.get_type()))
.collect_vec()
.into();
let typ = either_type(types, ty_right);
Value::sum(0, values, typ).unwrap()
}

/// A constant Either value with a right variant.
///
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_right(ty_left: impl Into<TypeRowRV>, value: Value) -> Value {
const_right_tuple(ty_left, [value])
}

/// A constant Either value with a row of right values.
///
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_right_tuple(
ty_left: impl Into<TypeRowRV>,
values: impl IntoIterator<Item = Value>,
) -> Value {
let values = values.into_iter().collect_vec();
let types: TypeRowRV = values
.iter()
.map(|v| TypeRV::from(v.get_type()))
.collect_vec()
.into();
let typ = either_type(ty_left, types);
Value::sum(1, values, typ).unwrap()
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -806,6 +892,8 @@ impl MakeRegisteredOp for Lift {

#[cfg(test)]
mod test {
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::{
builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Expand Down Expand Up @@ -897,6 +985,35 @@ mod test {
b.finish_prelude_hugr_with_outputs(out.outputs()).unwrap();
}

#[test]
fn test_option() {
let typ: Type = option_type(BOOL_T).into();
let const_val1 = const_some(Value::true_val());
let const_val2 = const_none(BOOL_T);

let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap();

let some = b.add_load_value(const_val1);
let none = b.add_load_value(const_val2);

b.finish_prelude_hugr_with_outputs([some, none]).unwrap();
}

#[test]
fn test_result() {
let typ: Type = either_type(BOOL_T, FLOAT64_TYPE).into();
let const_bool = const_left(Value::true_val(), FLOAT64_TYPE);
let const_float = const_right(BOOL_T, ConstF64::new(0.5).into());

let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap();

let bool = b.add_load_value(const_bool);
let float = b.add_load_value(const_float);

b.finish_hugr_with_outputs([bool, float], &FLOAT_OPS_REGISTRY)
.unwrap();
}

#[test]
/// test the prelude error type and panic op.
fn test_error_type() {
Expand Down
48 changes: 48 additions & 0 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from hugr.utils import ser_it

if TYPE_CHECKING:
from collections.abc import Iterable

from hugr import ext


Expand Down Expand Up @@ -303,6 +305,52 @@ def __repr__(self) -> str:
return f"Tuple{tuple(self.variant_rows[0])}"


@dataclass(eq=False)
class Option(Sum):
"""Optional tuple of elements.
Instances of this type correspond to :class:`Sum` with two variants.
The first variant is the tuple of elements, the second is empty.
"""

def __init__(self, *tys: Type):
self.variant_rows = [list(tys), []]

def __repr__(self) -> str:
return f"Option({', '.join(map(repr, self.variant_rows[0]))})"


@dataclass(eq=False)
class Either(Sum):
"""Two-variant tuple of elements.
Instances of this type correspond to :class:`Sum` with a Left and a Right variant.
In fallible contexts, the Left variant is used to represent success, and the
Right variant is used to represent failure.
Example:
>>> either = Either([Bool, Bool], [Bool])
>>> either
Either(left=[Bool, Bool], right=[Bool])
>>> str(either)
'Either((Bool, Bool), Bool)'
"""

def __init__(self, left: Iterable[Type], right: Iterable[Type]):
self.variant_rows = [list(left), list(right)]

def __repr__(self) -> str: # pragma: no cover
left, right = self.variant_rows
return f"Either(left={left}, right={right})"

def __str__(self) -> str:
left, right = self.variant_rows
left_str = left[0] if len(left) == 1 else tuple(left)
right_str = right[0] if len(right) == 1 else tuple(right)
return f"Either({left_str}, {right_str})"


@dataclass(frozen=True)
class Variable(Type):
"""A type variable with a given bound, identified by index."""
Expand Down
Loading

0 comments on commit 9698420

Please sign in to comment.