Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Make Either::Right the "success" case #1489

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,24 +246,24 @@ pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE);
/// The string name of the error type.
pub const ERROR_TYPE_NAME: TypeName = TypeName::new_inline("error");

/// Return a Sum type with the first variant as the given type and the second an Error.
/// Return a Sum type with the second variant as the given type and the first an Error.
pub fn sum_with_error(ty: impl Into<TypeRowRV>) -> SumType {
either_type(ty, ERROR_TYPE)
either_type(ERROR_TYPE, ty)
}

/// An optional type, i.e. a Sum type with the first variant as the given type and the second as an empty tuple.
/// An optional type, i.e. a Sum type with the second variant as the given type and the first as an empty tuple.
#[inline]
pub fn option_type(ty: impl Into<TypeRowRV>) -> SumType {
either_type(ty, TypeRow::new())
either_type(TypeRow::new(), ty)
}

/// An "either" type, i.e. a Sum type with a "left" and a "right" variant.
///
/// When used as a fallible value, the "left" variant represents a successful computation,
/// and the "right" variant represents a failure.
/// When used as a fallible value, the "right" variant represents a successful computation,
/// and the "left" variant represents a failure.
#[inline]
pub fn either_type(ty_ok: impl Into<TypeRowRV>, ty_err: impl Into<TypeRowRV>) -> SumType {
SumType::new([ty_ok.into(), ty_err.into()])
pub fn either_type(ty_left: impl Into<TypeRowRV>, ty_right: impl Into<TypeRowRV>) -> SumType {
SumType::new([ty_left.into(), ty_right.into()])
}

/// A constant optional value with a given value.
Expand All @@ -279,19 +279,19 @@ pub fn const_some(value: Value) -> Value {
///
/// See [option_type].
pub fn const_some_tuple(values: impl IntoIterator<Item = Value>) -> Value {
const_left_tuple(values, TypeRow::new())
const_right_tuple(TypeRow::new(), values)
}

/// A constant optional value with no value.
///
/// See [option_type].
pub fn const_none(ty: impl Into<TypeRowRV>) -> Value {
const_right_tuple(ty, [])
const_left_tuple([], ty)
}

/// A constant Either value with a left variant.
///
/// In fallible computations, this represents a successful result.
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_left(value: Value, ty_right: impl Into<TypeRowRV>) -> Value {
Expand All @@ -300,7 +300,7 @@ pub fn const_left(value: Value, ty_right: impl Into<TypeRowRV>) -> Value {

/// A constant Either value with a row of left values.
///
/// In fallible computations, this represents a successful result.
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_left_tuple(
Expand All @@ -319,7 +319,7 @@ pub fn const_left_tuple(

/// A constant Either value with a right variant.
///
/// In fallible computations, this represents a failure.
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_right(ty_left: impl Into<TypeRowRV>, value: Value) -> Value {
Expand All @@ -328,7 +328,7 @@ pub fn const_right(ty_left: impl Into<TypeRowRV>, value: Value) -> Value {

/// A constant Either value with a row of right values.
///
/// In fallible computations, this represents a failure.
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_right_tuple(
Expand All @@ -345,6 +345,40 @@ pub fn const_right_tuple(
Value::sum(1, values, typ).unwrap()
}

/// A constant Either value with a success variant.
///
/// Alias for [const_right].
pub fn const_ok(value: Value, ty_fail: impl Into<TypeRowRV>) -> Value {
const_right(ty_fail, value)
}

/// A constant Either with a row of success values.
///
/// Alias for [const_right_tuple].
pub fn const_ok_tuple(
values: impl IntoIterator<Item = Value>,
ty_fail: impl Into<TypeRowRV>,
) -> Value {
const_right_tuple(ty_fail, values)
}

/// A constant Either value with a failure variant.
///
/// Alias for [const_left].
pub fn const_fail(value: Value, ty_ok: impl Into<TypeRowRV>) -> Value {
const_left(value, ty_ok)
}

/// A constant Either with a row of failure values.
///
/// Alias for [const_left_tuple].
pub fn const_fail_tuple(
values: impl IntoIterator<Item = Value>,
ty_ok: impl Into<TypeRowRV>,
) -> Value {
const_left_tuple(values, ty_ok)
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstUsize(u64);
Expand Down Expand Up @@ -397,6 +431,14 @@ impl ConstError {
message: message.to_string(),
}
}

/// Returns an "either" value with a failure variant.
///
/// args:
/// ty_ok: The type of the success variant.
pub fn as_either(self, ty_ok: impl Into<TypeRowRV>) -> Value {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

const_fail(self.into(), ty_ok)
}
}

#[typetag::serde]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::INT_TYPES;
use crate::{
extension::{
prelude::{sum_with_error, ConstError},
prelude::{const_ok, ConstError, ERROR_TYPE},
ConstFold, ConstFoldResult, OpDef,
},
ops,
Expand Down Expand Up @@ -40,21 +40,19 @@ fn fold_trunc(
};
let log_width = get_log_width(arg).ok()?;
let int_type = INT_TYPES[log_width as usize].to_owned();
let sum_type = sum_with_error(int_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Can't truncate non-finite float".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_type.clone())
};
let out_const: ops::Value = if !f.is_finite() {
err_value()
} else {
let cv = convert(f, log_width);
if let Ok(cv) = cv {
Value::sum(0, [cv], sum_type).unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
const_ok(cv, ERROR_TYPE)
} else {
err_value()
}
Expand Down
58 changes: 23 additions & 35 deletions hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
Value,
},
std_extensions::arithmetic::int_types::{get_log_width, ConstInt, INT_TYPES},
types::{SumType, Type, TypeArg},
types::{Type, TypeArg},
IncomingPort,
};

Expand Down Expand Up @@ -132,9 +132,9 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
};
let n0val: u64 = n0.value_u();
let out_const: Value = if n0val >> (1 << logwidth1) != 0 {
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
mk_out_const(0, Ok(INARROW_ERROR_VALUE.clone()))
} else {
mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into))
mk_out_const(1, ConstInt::new_u(logwidth1, n0val).map(Into::into))
};
Some(vec![(0.into(), out_const)])
},
Expand All @@ -160,9 +160,9 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let n0val: i64 = n0.value_s();
let ub = 1i64 << ((1 << logwidth1) - 1);
let out_const: Value = if n0val >= ub || n0val < -ub {
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
mk_out_const(0, Ok(INARROW_ERROR_VALUE.clone()))
} else {
mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into))
mk_out_const(1, ConstInt::new_s(logwidth1, n0val).map(Into::into))
};
Some(vec![(0.into(), out_const)])
},
Expand Down Expand Up @@ -631,14 +631,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let q_type = INT_TYPES[logwidth0 as usize].to_owned();
let r_type = q_type.clone();
let qr_type: Type = Type::new_tuple(vec![q_type, r_type]);
let sum_type: SumType = sum_with_error(qr_type);
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(qr_type)
};
let nval = n.value_u();
let mval = m.value_u();
Expand Down Expand Up @@ -694,14 +692,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let q_type = INT_TYPES[logwidth0 as usize].to_owned();
let r_type = INT_TYPES[logwidth0 as usize].to_owned();
let qr_type: Type = Type::new_tuple(vec![q_type, r_type]);
let sum_type: SumType = sum_with_error(qr_type);
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(qr_type)
};
let nval = n.value_s();
let mval = m.value_u();
Expand Down Expand Up @@ -754,14 +750,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_u();
let mval = m.value_u();
Expand Down Expand Up @@ -808,14 +802,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_u();
let mval = m.value_u();
Expand Down Expand Up @@ -862,14 +854,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_s();
let mval = m.value_u();
Expand Down Expand Up @@ -918,14 +908,12 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
None
} else {
let int_out_type = INT_TYPES[logwidth0 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
ConstError {
signal: 0,
message: "Division by zero".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
}
.as_either(int_out_type.clone())
};
let nval = n.value_s();
let mval = m.value_u();
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ mod test {
use rstest::rstest;

use crate::extension::prelude::{
const_left_tuple, const_none, const_right_tuple, const_some_tuple,
const_fail_tuple, const_none, const_ok_tuple, const_some_tuple,
};
use crate::ops::OpTrait;
use crate::PortIndex;
Expand Down Expand Up @@ -467,11 +467,11 @@ mod test {
TestVal::None(tr) => const_none(tr.clone()),
TestVal::Ok(l, tr) => {
let elems = l.iter().map(TestVal::to_value);
const_left_tuple(elems, tr.clone())
const_ok_tuple(elems, tr.clone())
}
TestVal::Err(tr, l) => {
let elems = l.iter().map(TestVal::to_value);
const_right_tuple(tr.clone(), elems)
const_fail_tuple(elems, tr.clone())
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions hugr-core/src/std_extensions/collections/list_fold.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Folding definitions for list operations.

use crate::extension::prelude::{
const_left, const_left_tuple, const_none, const_right, const_some, ConstUsize,
const_fail, const_none, const_ok, const_ok_tuple, const_some, ConstUsize,
};
use crate::extension::{ConstFold, ConstFoldResult, OpDef};
use crate::ops::Value;
Expand Down Expand Up @@ -96,9 +96,9 @@ impl ConstFold for SetFold {
let res_elem: Value = match list.0.get_mut(idx) {
Some(old_elem) => {
std::mem::swap(old_elem, &mut elem);
const_left(elem, list.1.clone())
const_ok(elem, list.1.clone())
}
None => const_right(list.1.clone(), elem),
None => const_fail(elem, list.1.clone()),
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
Expand All @@ -118,9 +118,9 @@ impl ConstFold for InsertFold {
let elem = elem.clone();
let res_elem: Value = if list.0.len() > idx {
list.0.insert(idx, elem);
const_left_tuple([], list.1.clone())
const_ok_tuple([], list.1.clone())
} else {
const_right(Type::UNIT, elem)
const_fail(elem, Type::UNIT)
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
Expand Down
Loading
Loading