Skip to content

Commit

Permalink
feat: add RMul, RDiv
Browse files Browse the repository at this point in the history
* `And` has the default type index
* impl `Dimension` traits
  • Loading branch information
mtshiba committed Sep 20, 2024
1 parent 4651a38 commit ff53af0
Show file tree
Hide file tree
Showing 17 changed files with 323 additions and 118 deletions.
2 changes: 1 addition & 1 deletion crates/els/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn comp_item_kind(t: &Type, muty: Mutability) -> CompletionItemKind {
CompletionItemKind::VARIABLE
}
}
Type::And(tys) => {
Type::And(tys, _) => {
for k in tys.iter().map(|t| comp_item_kind(t, muty)) {
if k != CompletionItemKind::VARIABLE {
return k;
Expand Down
31 changes: 25 additions & 6 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ impl Context {
(lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)),
// Hash and Eq :> HashEq and ... == true
// Add(T) and Eq :> Add(Int) and Eq == true
(And(l), And(r)) => {
(And(l, _), And(r, _)) => {
if r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))) {
return true;
}
Expand All @@ -843,9 +843,9 @@ impl Context {
false
}
// (Num and Show) :> Show == false
(And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)),
(And(ands, _), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)),
// Show :> (Num and Show) == true
(lhs, And(ands)) => ands.iter().any(|and| self.supertype_of(lhs, and)),
(lhs, And(ands, _)) => ands.iter().any(|and| self.supertype_of(lhs, and)),
// Not(Eq) :> Float == !(Eq :> Float) == true
(Not(_), Obj) => false,
(Not(l), rhs) => !self.supertype_of(l, rhs),
Expand Down Expand Up @@ -1097,6 +1097,23 @@ impl Context {
Variance::Covariant => self.supertype_of(sup, sub),
Variance::Invariant => self.same_type_of(sup, sub),
},
(TyParam::Type(sup), TyParam::Value(ValueObj::Type(sub))) => match variance {
Variance::Contravariant => self.subtype_of(sup, sub.typ()),
Variance::Covariant => self.supertype_of(sup, sub.typ()),
Variance::Invariant => self.same_type_of(sup, sub.typ()),
},
(TyParam::Value(ValueObj::Type(sup)), TyParam::Type(sub)) => match variance {
Variance::Contravariant => self.subtype_of(sup.typ(), sub),
Variance::Covariant => self.supertype_of(sup.typ(), sub),
Variance::Invariant => self.same_type_of(sup.typ(), sub),
},
(TyParam::Value(ValueObj::Type(sup)), TyParam::Value(ValueObj::Type(sub))) => {
match variance {
Variance::Contravariant => self.subtype_of(sup.typ(), sub.typ()),
Variance::Covariant => self.supertype_of(sup.typ(), sub.typ()),
Variance::Invariant => self.same_type_of(sup.typ(), sub.typ()),
}
}
(
TyParam::App { name, args },
TyParam::App {
Expand Down Expand Up @@ -1485,7 +1502,7 @@ impl Context {
},
(other, or @ Or(_)) | (or @ Or(_), other) => self.union_add(or, other),
// (A and B) or C ==> (A or C) and (B or C)
(And(ands), other) | (other, And(ands)) => {
(And(ands, _), other) | (other, And(ands, _)) => {
let mut t = Type::Obj;
for branch in ands.iter() {
let union = self.union(branch, other);
Expand Down Expand Up @@ -1705,7 +1722,9 @@ impl Context {
(_, Not(r)) => self.diff(lhs, r),
(Not(l), _) => self.diff(rhs, l),
// A and B and A == A and B
(other, and @ And(_)) | (and @ And(_), other) => self.intersection_add(and, other),
(other, and @ And(_, _)) | (and @ And(_, _), other) => {
self.intersection_add(and, other)
}
// (A or B) and C == (A and C) or (B and C)
(Or(ors), other) | (other, Or(ors)) => {
if ors.iter().any(|t| t.has_unbound_var()) {
Expand Down Expand Up @@ -2003,7 +2022,7 @@ impl Context {
Or(ors) => ors
.iter()
.fold(Obj, |l, r| self.intersection(&l, &self.complement(r))),
And(ands) => ands
And(ands, _) => ands
.iter()
.fold(Never, |l, r| self.union(&l, &self.complement(r))),
other => not(other.clone()),
Expand Down
2 changes: 1 addition & 1 deletion crates/erg_compiler/context/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2285,7 +2285,7 @@ impl Context {
Err((t, errs))
}
}
Type::And(ands) => {
Type::And(ands, _) => {
let mut new_ands = set! {};
for and in ands.into_iter() {
match self.eval_t_params(and, level, t_loc) {
Expand Down
6 changes: 3 additions & 3 deletions crates/erg_compiler/context/generalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,13 @@ impl Generalizer {
}
proj_call(lhs, attr_name, args)
}
And(ands) => {
And(ands, idx) => {
// not `self.intersection` because types are generalized
let ands = ands
.into_iter()
.map(|t| self.generalize_t(t, uninit))
.collect();
Type::checked_and(ands)
Type::checked_and(ands, idx)
}
Or(ors) => {
// not `self.union` because types are generalized
Expand Down Expand Up @@ -1049,7 +1049,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
let pred = self.deref_pred(*refine.pred)?;
Ok(refinement(refine.var, t, pred))
}
And(ands) => {
And(ands, _) => {
let mut new_ands = vec![];
for t in ands.into_iter() {
new_ands.push(self.deref_tyvar(t)?);
Expand Down
2 changes: 1 addition & 1 deletion crates/erg_compiler/context/hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl Context {
return Some(hint);
}
}
(Type::And(tys), found) if tys.len() == 2 => {
(Type::And(tys, _), found) if tys.len() == 2 => {
let mut iter = tys.iter();
let l = iter.next().unwrap();
let r = iter.next().unwrap();
Expand Down
55 changes: 55 additions & 0 deletions crates/erg_compiler/context/initialize/classes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4119,6 +4119,39 @@ impl Context {
Immutable,
Visibility::BUILTIN_PUBLIC,
);
let mut dimension_to_float = Self::builtin_methods(Some(mono(TO_FLOAT)), 1);
dimension_to_float.register_builtin_erg_impl(
FUNDAMENTAL_FLOAT,
fn0_met(dimension_t.clone(), Float).quantify(),
Immutable,
Visibility::BUILTIN_PUBLIC,
);
dimension.register_trait_methods(dimension_t.clone(), dimension_to_float);
let mut dimension_to_int = Self::builtin_methods(Some(mono(TO_INT)), 1);
dimension_to_int.register_builtin_erg_impl(
FUNDAMENTAL_INT,
fn0_met(dimension_t.clone(), Int).quantify(),
Immutable,
Visibility::BUILTIN_PUBLIC,
);
dimension.register_trait_methods(dimension_t.clone(), dimension_to_int);
let mut dimension_eq = Self::builtin_methods(Some(mono(EQ)), 2);
let t = fn1_met(dimension_t.clone(), dimension_t.clone(), Bool).quantify();
dimension_eq.register_builtin_erg_impl(OP_EQ, t, Immutable, Visibility::BUILTIN_PUBLIC);
dimension.register_trait_methods(dimension_t.clone(), dimension_eq);
let mut dimension_partial_ord = Self::builtin_methods(Some(mono(PARTIAL_ORD)), 2);
dimension_partial_ord.register_builtin_erg_impl(
OP_PARTIAL_CMP,
fn1_met(
dimension_t.clone(),
dimension_t.clone(),
mono(ORDERING) | NoneType,
)
.quantify(),
Const,
Visibility::BUILTIN_PUBLIC,
);
dimension.register_trait_methods(dimension_t.clone(), dimension_partial_ord);
let mut dimension_add =
Self::builtin_methods(Some(poly(ADD, vec![ty_tp(dimension_t.clone())])), 2);
let t = fn1_met(
Expand Down Expand Up @@ -4181,6 +4214,17 @@ impl Context {
ValueObj::builtin_class(dimension_added_t),
);
dimension.register_trait_methods(dimension_t.clone(), dimension_mul);
let mut dimension_rmul =
Self::builtin_methods(Some(poly(RMUL, vec![ty_tp(Ty.clone())])), 2);
let t = fn1_met(dimension_t.clone(), Ty.clone(), dimension_t.clone()).quantify();
dimension_rmul.register_builtin_erg_impl(OP_RMUL, t, Immutable, Visibility::BUILTIN_PUBLIC);
dimension_rmul.register_builtin_const(
OUTPUT,
Visibility::BUILTIN_PUBLIC,
None,
ValueObj::builtin_class(dimension_t.clone()),
);
dimension.register_trait_methods(dimension_t.clone(), dimension_rmul);
let mut dimension_div =
Self::builtin_methods(Some(poly(DIV, vec![ty_tp(dimension2_t.clone())])), 2);
let dimension_subtracted_t = poly(
Expand Down Expand Up @@ -4210,6 +4254,17 @@ impl Context {
ValueObj::builtin_class(dimension_subtracted_t),
);
dimension.register_trait_methods(dimension_t.clone(), dimension_div);
let mut dimension_rdiv =
Self::builtin_methods(Some(poly(RDIV, vec![ty_tp(Ty.clone())])), 2);
let t = fn1_met(dimension_t.clone(), Ty.clone(), dimension_t.clone()).quantify();
dimension_rdiv.register_builtin_erg_impl(OP_RDIV, t, Immutable, Visibility::BUILTIN_PUBLIC);
dimension_rdiv.register_builtin_const(
OUTPUT,
Visibility::BUILTIN_PUBLIC,
None,
ValueObj::builtin_class(dimension_t.clone()),
);
dimension.register_trait_methods(dimension_t.clone(), dimension_rdiv);
let mut base_exception = Self::builtin_mono_class(BASE_EXCEPTION, 2);
base_exception.register_superclass(Obj, &obj);
base_exception.register_builtin_erg_impl(
Expand Down
12 changes: 10 additions & 2 deletions crates/erg_compiler/context/initialize/funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,11 @@ impl Context {
Some(FUNC_SUB),
);
let L = mono_q(TY_L, subtypeof(poly(MUL, params.clone())));
let op_t = bin_op(L.clone(), R.clone(), proj(L, OUTPUT)).quantify();
let L2 = type_q(TY_L);
let R2 = mono_q(TY_R, subtypeof(poly(RMUL, vec![ty_tp(L2.clone())])));
let op_t = (bin_op(L.clone(), R.clone(), proj(L, OUTPUT)).quantify()
& bin_op(L2.clone(), R2.clone(), proj(R2, OUTPUT)).quantify())
.with_default_intersec_index(0);
self.register_builtin_py_impl(
OP_MUL,
op_t,
Expand All @@ -1014,7 +1018,11 @@ impl Context {
Some(FUNC_MUL),
);
let L = mono_q(TY_L, subtypeof(poly(DIV, params.clone())));
let op_t = bin_op(L.clone(), R.clone(), proj(L, OUTPUT)).quantify();
let L2 = type_q(TY_L);
let R2 = mono_q(TY_R, subtypeof(poly(RDIV, vec![ty_tp(L2.clone())])));
let op_t = (bin_op(L.clone(), R.clone(), proj(L, OUTPUT)).quantify()
& bin_op(L2.clone(), R2.clone(), proj(R2, OUTPUT)).quantify())
.with_default_intersec_index(0);
self.register_builtin_py_impl(
OP_DIV,
op_t,
Expand Down
4 changes: 4 additions & 0 deletions crates/erg_compiler/context/initialize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ const ATTR_TRACEBACK: &str = "traceback";
const ADD: &str = "Add";
const SUB: &str = "Sub";
const MUL: &str = "Mul";
const RMUL: &str = "RMul";
const RDIV: &str = "RDiv";
const DIV: &str = "Div";
const FLOOR_DIV: &str = "FloorDiv";
const POS: &str = "Pos";
Expand Down Expand Up @@ -564,6 +566,8 @@ const OP_SUB: &str = "__sub__";
const OP_MUL: &str = "__mul__";
const OP_DIV: &str = "__div__";
const OP_FLOOR_DIV: &str = "__floordiv__";
const OP_RMUL: &str = "__rmul__";
const OP_RDIV: &str = "__rdiv__";
const OP_ABS: &str = "__abs__";
const OP_PARTIAL_CMP: &str = "__partial_cmp__";
const OP_AND: &str = "__and__";
Expand Down
27 changes: 26 additions & 1 deletion crates/erg_compiler/context/initialize/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,23 @@ impl Context {
let op_t = fn0_met(_Slf.clone(), proj(_Slf, OUTPUT)).quantify();
neg.register_builtin_erg_decl(OP_NEG, op_t, Visibility::BUILTIN_PUBLIC);
neg.register_builtin_erg_decl(OUTPUT, Type, Visibility::BUILTIN_PUBLIC);
/* RMul */
let L = mono_q(TY_L, instanceof(Type));
let rparams = vec![PS::t(TY_L, false, WithDefault)];
let ty_rparams = vec![ty_tp(L.clone())];
let mut rmul = Self::builtin_poly_trait(RMUL, rparams.clone(), 2);
rmul.register_superclass(poly(OUTPUT, vec![ty_tp(L.clone())]), &output);
let RSlf = mono_q(SELF, subtypeof(poly(RMUL, ty_rparams.clone())));
let op_t = fn1_met(RSlf.clone(), L.clone(), proj(RSlf, OUTPUT)).quantify();
rmul.register_builtin_erg_decl(OP_RMUL, op_t, Visibility::BUILTIN_PUBLIC);
rmul.register_builtin_erg_decl(OUTPUT, Type, Visibility::BUILTIN_PUBLIC);
/* RDiv */
let mut rdiv = Self::builtin_poly_trait(RDIV, rparams.clone(), 2);
rdiv.register_superclass(poly(OUTPUT, vec![ty_tp(L.clone())]), &output);
let RSlf = mono_q(SELF, subtypeof(poly(RDIV, ty_rparams.clone())));
let op_t = fn1_met(RSlf.clone(), L.clone(), proj(RSlf, OUTPUT)).quantify();
rdiv.register_builtin_erg_decl(OP_RDIV, op_t, Visibility::BUILTIN_PUBLIC);
rdiv.register_builtin_erg_decl(OUTPUT, Type, Visibility::BUILTIN_PUBLIC);
/* Num */
let num = Self::builtin_mono_trait(NUM, 2);
// num.register_superclass(poly(ADD, vec![]), &add);
Expand Down Expand Up @@ -832,7 +849,15 @@ impl Context {
None,
);
self.register_builtin_type(mono(POS), pos, vis.clone(), Const, Some(POS));
self.register_builtin_type(mono(NEG), neg, vis, Const, Some(NEG));
self.register_builtin_type(mono(NEG), neg, vis.clone(), Const, Some(NEG));
self.register_builtin_type(
poly(RMUL, ty_rparams.clone()),
rmul,
vis.clone(),
Const,
Some(RMUL),
);
self.register_builtin_type(poly(RDIV, ty_rparams), rdiv, vis.clone(), Const, Some(RDIV));
self.register_const_param_defaults(
ADD,
vec![ConstTemplate::Obj(ValueObj::builtin_type(Slf.clone()))],
Expand Down
17 changes: 10 additions & 7 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,9 @@ impl Context {
}
}
}
if let Some(default) = instance.default_intersection_type() {
return Ok(default.clone());
}
let Type::Subr(subr_t) = input_t else {
unreachable!()
};
Expand Down Expand Up @@ -1952,7 +1955,7 @@ impl Context {
res
}
}
Type::And(_) => {
Type::And(_, _) => {
let instance = self.resolve_overload(
obj,
instance.clone(),
Expand Down Expand Up @@ -2852,11 +2855,11 @@ impl Context {
pub(crate) fn type_params_variance(&self) -> Vec<Variance> {
let match_tp_name = |tp: &TyParam, name: &VarName| -> bool {
if let Ok(free) = <&FreeTyParam>::try_from(tp) {
if let Some(prev) = free.get_previous() {
if let Some(prev) = free.get_undoable_root() {
return prev.unbound_name().as_ref() == Some(name.inspect());
}
} else if let Ok(free) = <&FreeTyVar>::try_from(tp) {
if let Some(prev) = free.get_previous() {
if let Some(prev) = free.get_undoable_root() {
return prev.unbound_name().as_ref() == Some(name.inspect());
}
}
Expand Down Expand Up @@ -3014,7 +3017,7 @@ impl Context {
self.get_nominal_super_type_ctxs(&Type)
}
}
Type::And(tys) => {
Type::And(tys, _) => {
let mut acc = vec![];
for ctxs in tys
.iter()
Expand Down Expand Up @@ -3366,7 +3369,7 @@ impl Context {
match trait_ {
// And(Add, Sub) == intersection({Int <: Add(Int), Bool <: Add(Bool) ...}, {Int <: Sub(Int), ...})
// == {Int <: Add(Int) and Sub(Int), ...}
Type::And(tys) => {
Type::And(tys, _) => {
let impls = tys
.iter()
.flat_map(|ty| self.get_trait_impls(ty))
Expand Down Expand Up @@ -3956,7 +3959,7 @@ impl Context {

pub fn is_class(&self, typ: &Type) -> bool {
match typ {
Type::And(_) => false,
Type::And(_, _) => false,
Type::Never => true,
Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()),
Type::FreeVar(_) => false,
Expand All @@ -3983,7 +3986,7 @@ impl Context {
Type::Never => false,
Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()),
Type::FreeVar(_) => false,
Type::And(tys) => tys.iter().any(|t| self.is_trait(t)),
Type::And(tys, _) => tys.iter().any(|t| self.is_trait(t)),
Type::Or(tys) => tys.iter().all(|t| self.is_trait(t)),
Type::Proj { lhs, rhs } => self
.get_proj_candidates(lhs, rhs)
Expand Down
8 changes: 4 additions & 4 deletions crates/erg_compiler/context/instantiate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ impl Context {
let t = self.instantiate_t_inner(*t, tmp_tv_cache, loc)?;
Ok(t.structuralize())
}
And(tys) => {
And(tys, _) => {
let mut new_tys = vec![];
for ty in tys.iter().cloned() {
new_tys.push(self.instantiate_t_inner(ty, tmp_tv_cache, loc)?);
Expand Down Expand Up @@ -1004,7 +1004,7 @@ impl Context {
let t = fv.crack().clone();
self.instantiate(t, callee)
}
And(tys) => {
And(tys, _idx) => {
let tys = tys
.into_iter()
.map(|t| self.instantiate(t, callee))
Expand Down Expand Up @@ -1036,7 +1036,7 @@ impl Context {
)?;
}
}
Type::And(tys) => {
Type::And(tys, _) => {
for ty in tys {
if let Some(self_t) = ty.self_t() {
self.sub_unify(
Expand Down Expand Up @@ -1068,7 +1068,7 @@ impl Context {
let t = fv.crack().clone();
self.instantiate_dummy(t)
}
And(tys) => {
And(tys, _) => {
let tys = tys
.into_iter()
.map(|t| self.instantiate_dummy(t))
Expand Down
Loading

0 comments on commit ff53af0

Please sign in to comment.