Skip to content

Commit

Permalink
fix: refinement subtyping bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Nov 3, 2024
1 parent 806c5a9 commit 87fb4cf
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 51 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,7 @@ path = "src/lib.rs"

# [profile.release]
# panic = 'abort'

[profile.opt-with-dbg]
inherits = "release"
debug = true
5 changes: 4 additions & 1 deletion crates/erg_common/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,10 @@ impl SubMessage {
gutter_color,
);
cxt.push_str(&" ".repeat(lineno.to_string().len()));
cxt.push_str_with_color(mark.repeat(cmp::max(1, codes[i].len())), err_color);
cxt.push_str_with_color(
mark.repeat(cmp::max(1, codes.get(i).map_or(1, |code| code.len()))),
err_color,
);
cxt.push_str("\n");
}
cxt.push_str("\n");
Expand Down
68 changes: 38 additions & 30 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,40 +413,43 @@ impl Context {
.return_t
.clone()
.replace_params(rs.param_names(), ls.param_names());
let return_t_judge = self.supertype_of(&ls.return_t, &rhs_ret); // covariant
let non_defaults_judge = if let Some(r_var) = rs.var_params.as_deref() {
ls.non_default_params
.iter()
.zip(repeat(r_var))
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
} else {
let rs_params = if !ls.is_method() && rs.is_method() {
rs.non_default_params
let return_t_judge = || self.supertype_of(&ls.return_t, &rhs_ret); // covariant
let non_defaults_judge = || {
if let Some(r_var) = rs.var_params.as_deref() {
ls.non_default_params
.iter()
.skip(1)
.chain(&rs.default_params)
.zip(repeat(r_var))
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
} else {
#[allow(clippy::iter_skip_zero)]
rs.non_default_params
let rs_params = if !ls.is_method() && rs.is_method() {
rs.non_default_params
.iter()
.skip(1)
.chain(&rs.default_params)
} else {
#[allow(clippy::iter_skip_zero)]
rs.non_default_params
.iter()
.skip(0)
.chain(&rs.default_params)
};
ls.non_default_params
.iter()
.skip(0)
.chain(&rs.default_params)
};
ls.non_default_params
.iter()
.zip(rs_params)
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
.zip(rs_params)
.all(|(l, r)| self.subtype_of(l.typ(), r.typ()))
}
};
let var_params_judge = || {
ls.var_params
.as_ref()
.zip(rs.var_params.as_ref())
.map(|(l, r)| self.subtype_of(l.typ(), r.typ()))
.unwrap_or(true)
};
let var_params_judge = ls
.var_params
.as_ref()
.zip(rs.var_params.as_ref())
.map(|(l, r)| self.subtype_of(l.typ(), r.typ()))
.unwrap_or(true);
len_judge
&& return_t_judge
&& non_defaults_judge
&& var_params_judge
&& return_t_judge()
&& non_defaults_judge()
&& var_params_judge()
&& default_check() // contravariant
}
// {Int} <: Obj -> Int
Expand All @@ -461,7 +464,8 @@ impl Context {
return false;
};
if let Some((_, __call__)) = ctx.get_class_attr("__call__") {
self.supertype_of(lhs, &__call__.t)
let call_t = __call__.t.clone().undoable_root();
self.supertype_of(lhs, &call_t)
} else {
false
}
Expand Down Expand Up @@ -771,10 +775,14 @@ impl Context {
// Bool :> {2} == false
// [2, 3]: {A: List(Nat) | A.prod() == 6}
// List({1, 2}, _) :> {[3, 4]} == false
// T :> {None} == T :> NoneType
(l, Refinement(r)) => {
// Type / {S: Set(Str) | S == {"a", "b"}}
// TODO: GeneralEq
if let Pred::Equal { rhs, .. } = r.pred.as_ref() {
if rhs.is_none() {
return self.supertype_of(lhs, &NoneType);
}
if self.subtype_of(l, &Type) && self.convert_tp_into_type(rhs.clone()).is_ok() {
return true;
}
Expand Down
43 changes: 40 additions & 3 deletions crates/erg_compiler/context/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2939,6 +2939,34 @@ impl Context {
}
Type::Quantified(quant) => self.convert_singular_type_into_value(*quant),
Type::Subr(subr) => self.convert_singular_type_into_value(*subr.return_t),
Type::Proj { lhs, rhs } => {
let old = lhs.clone().proj(rhs.clone());
let evaled = self.eval_proj(*lhs, rhs, 0, &()).map_err(|_| old.clone())?;
if old != evaled {
self.convert_singular_type_into_value(evaled)
} else {
Err(old)
}
}
Type::ProjCall {
lhs,
attr_name,
args,
} => {
let old = Type::ProjCall {
lhs: lhs.clone(),
attr_name: attr_name.clone(),
args: args.clone(),
};
let evaled = self
.eval_proj_call_t(*lhs, attr_name, args, 0, &())
.map_err(|_| old.clone())?;
if old != evaled {
self.convert_singular_type_into_value(evaled)
} else {
Err(old)
}
}
Type::Failure => Ok(ValueObj::Failure),
_ => Err(typ),
}
Expand Down Expand Up @@ -2993,7 +3021,15 @@ impl Context {
let end = fields["end"].clone();
Ok(closed_range(start.class(), start, end))
}
other => Err(other),
// TODO:
ValueObj::DataClass { .. }
| ValueObj::Int(_)
| ValueObj::Nat(_)
| ValueObj::Bool(_)
| ValueObj::Float(_)
| ValueObj::Code(_)
| ValueObj::Str(_)
| ValueObj::None => Err(val),
}
}

Expand Down Expand Up @@ -3083,7 +3119,8 @@ impl Context {
})
}
ValueObj::Failure => Ok(TyParam::Failure),
_ => Err(TyParam::Value(value)),
ValueObj::Subr(_) => Err(TyParam::Value(value)),
mono_value_pattern!(-Failure) => Err(TyParam::Value(value)),
}
}

Expand All @@ -3094,7 +3131,7 @@ impl Context {
self.convert_type_to_dict_type(t)
}
Type::Refinement(refine) => self.convert_type_to_dict_type(*refine.t),
Type::Poly { name, params } if &name[..] == "Dict" || &name[..] == "Dict!" => {
Type::Poly { params, .. } if ty.is_dict() || ty.is_dict_mut() => {
let dict = Dict::try_from(params[0].clone())?;
let mut new_dict = dict! {};
for (k, v) in dict.into_iter() {
Expand Down
61 changes: 56 additions & 5 deletions crates/erg_compiler/context/initialize/const_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,10 @@ pub(crate) fn __dict_getitem__(mut args: ValueArgs, ctx: &Context) -> EvalValueR
}
}

/// `{Str: Int, Int: Float}.keys() == Str or Int`
/// ```erg
/// {"a": 1, "b": 2}.keys() == ["a", "b"]
/// {Str: Int, Int: Float}.keys() == Str or Int
/// ```
pub(crate) fn dict_keys(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let slf = args
.remove_left_or_key("Self")
Expand All @@ -390,7 +393,10 @@ pub(crate) fn dict_keys(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<T
}
}

/// `{Str: Int, Int: Float}.values() == Int or Float`
/// ```erg
/// {"a": 1, "b": 2}.values() == [1, 2]
/// {Str: Int, Int: Float}.values() == Int or Float
/// ```
pub(crate) fn dict_values(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let slf = args
.remove_left_or_key("Self")
Expand All @@ -417,7 +423,10 @@ pub(crate) fn dict_values(mut args: ValueArgs, ctx: &Context) -> EvalValueResult
}
}

/// `{Str: Int, Int: Float}.items() == (Str, Int) or (Int, Float)`
/// ```erg
/// {"a": 1, "b": 2}.items() == [("a", 1), ("b", 2)]
/// {Str: Int, Int: Float}.items() == (Str, Int) or (Int, Float)
/// ```
pub(crate) fn dict_items(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<TyParam> {
let slf = args
.remove_left_or_key("Self")
Expand All @@ -434,8 +443,8 @@ pub(crate) fn dict_items(mut args: ValueArgs, ctx: &Context) -> EvalValueResult<
})
.collect::<Result<Dict<_, _>, ValueObj>>();
if let Ok(slf) = dict_type {
let union = slf.iter().fold(Type::Never, |union, (k, v)| {
ctx.union(&union, &tuple_t(vec![k.clone(), v.clone()]))
let union = slf.into_iter().fold(Type::Never, |union, (k, v)| {
ctx.union(&union, &tuple_t(vec![k, v]))
});
// let items = poly(DICT_ITEMS, vec![ty_tp(union)]);
Ok(ValueObj::builtin_type(union).into())
Expand Down Expand Up @@ -1653,3 +1662,45 @@ pub(crate) fn classof_func(mut args: ValueArgs, _ctx: &Context) -> EvalValueResu
.ok_or_else(|| not_passed("obj"))?;
Ok(TyParam::t(val.class()))
}

#[cfg(test)]
mod tests {
use erg_common::config::ErgConfig;

use crate::module::SharedCompilerResource;
use crate::ty::constructors::*;

use super::*;

#[test]
fn test_dict_items() {
let cfg = ErgConfig::default();
let shared = SharedCompilerResource::default();
Context::init_builtins(cfg, shared.clone());
let ctx = &shared.raw_ref_builtins_ctx().unwrap().context;

let a = singleton(Type::Str, TyParam::value("a"));
let b = singleton(Type::Str, TyParam::value("b"));
let c = singleton(Type::Str, TyParam::value("c"));
// FIXME: order dependency
let k_union = ctx.union(&ctx.union(&c, &a), &b);
let g_dic = singleton(Type::Type, TyParam::t(mono("GenericDict")));
let g_lis = singleton(Type::Type, TyParam::t(mono("GenericList")));
let sub = func0(singleton(Type::NoneType, TyParam::value(ValueObj::None)));
let v_union = ctx.union(&ctx.union(&g_lis, &sub), &g_dic);
let dic = dict! {
a.clone() => sub.clone(),
b.clone() => g_dic.clone(),
c.clone() => g_lis.clone(),
};
match ctx.eval_proj_call_t(dic.clone().into(), "items".into(), vec![], 1, &()) {
Ok(t) => {
let items = tuple_t(vec![k_union, v_union]);
assert_eq!(t, items, "{t} != {items}");
}
Err(e) => {
panic!("ERR: {e}");
}
}
}
}
2 changes: 2 additions & 0 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3261,6 +3261,8 @@ impl Context {
/// get_nominal_type_ctx({ .x = Int }) == Some(<RecordMetaType>)
/// get_nominal_type_ctx(?T(<: Int)) == Some(<Int>)
/// get_nominal_type_ctx(?T(:> Int)) == None # you need to coerce it to Int
/// get_nominal_type_ctx(T or U) == Some(<Or>)
/// get_nominal_type_ctx(T and U) == None
/// ```
pub(crate) fn get_nominal_type_ctx<'a>(&'a self, typ: &Type) -> Option<&'a TypeContext> {
match typ {
Expand Down
2 changes: 1 addition & 1 deletion crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
}
if let Some((sub, mut supe)) = super_fv.get_subsup() {
if supe.is_structural() || !supe.is_recursive() {
if !supe.is_recursive() {
self.sub_unify(maybe_sub, &supe)?;
}
let mut new_sub = self.ctx.union(maybe_sub, &sub);
Expand Down
8 changes: 1 addition & 7 deletions crates/erg_compiler/ty/free.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,13 +962,7 @@ impl HasLevel for Free<Type> {
fn level(&self) -> Option<Level> {
match &*self.borrow() {
FreeKind::Unbound { lev, .. } | FreeKind::NamedUnbound { lev, .. } => Some(*lev),
FreeKind::Linked(t) | FreeKind::UndoableLinked { t, .. } => {
if t.is_recursive() {
None
} else {
t.level()
}
}
FreeKind::Linked(t) | FreeKind::UndoableLinked { t, .. } => t.level(),
}
}
}
Expand Down
29 changes: 25 additions & 4 deletions crates/erg_compiler/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2590,7 +2590,7 @@ impl Type {
}

pub fn has_proj(&self) -> bool {
self.is_proj() || self.has_type_satisfies(|t| t.is_proj())
self.is_proj() || self.has_type_satisfies(|t| t.has_proj())
}

pub fn is_proj_call(&self) -> bool {
Expand All @@ -2602,7 +2602,7 @@ impl Type {
}

pub fn has_proj_call(&self) -> bool {
self.is_proj_call() || self.has_type_satisfies(|t| t.is_proj_call())
self.is_proj_call() || self.has_type_satisfies(|t| t.has_proj_call())
}

pub fn is_intersection_type(&self) -> bool {
Expand Down Expand Up @@ -3190,7 +3190,7 @@ impl Type {
}

pub fn has_refinement(&self) -> bool {
self.is_refinement() || self.has_type_satisfies(|t| t.is_refinement())
self.is_refinement() || self.has_type_satisfies(|t| t.has_refinement())
}

pub fn is_recursive(&self) -> bool {
Expand Down Expand Up @@ -5132,6 +5132,13 @@ impl Type {
}
}

pub(crate) fn undoable_root(self) -> Self {
if let Some(free) = self.as_free().and_then(|fv| fv.get_undoable_root()) {
return Self::FreeVar(FreeTyVar::new(free.clone()));
}
self.map(&mut |t| t.undoable_root(), &SharedFrees::new())
}

/// interior-mut
///
/// `inc/dec_undo_count` due to the number of `substitute_typarams/undo_typarams` must be matched
Expand Down Expand Up @@ -6089,7 +6096,8 @@ impl TypePair {

#[cfg(test)]
mod tests {
use constructors::type_q;
use constructors::*;
use erg_common::dict;

use super::*;

Expand All @@ -6107,4 +6115,17 @@ mod tests {
Type::Or(set! {t, Type::Never})
);
}

#[test]
fn test_recursive() {
let t = type_q("T");
let u = type_q("U");
let subr = fn1_met(Type::Failure, u.clone(), t.clone());
let sup = Type::Record(dict! {Field::private("x".into()) => subr.clone()}).structuralize();
let constr = Constraint::new_subtype_of(sup.clone());
t.update_constraint(constr, None, true);
assert!(t.is_recursive());
assert!(subr.is_recursive());
assert!(sup.is_recursive());
}
}
7 changes: 7 additions & 0 deletions crates/erg_compiler/ty/typaram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2096,6 +2096,13 @@ impl TyParam {
pub fn as_free(&self) -> Option<&FreeTyParam> {
<&FreeTyParam>::try_from(self).ok()
}

pub fn is_none(&self) -> bool {
match self {
Self::Value(val) => val.is_none(),
_ => false,
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down
Loading

0 comments on commit 87fb4cf

Please sign in to comment.