Skip to content
This repository has been archived by the owner on May 27, 2024. It is now read-only.

Commit

Permalink
unify: implement more
Browse files Browse the repository at this point in the history
  • Loading branch information
HoshinoTented authored and ice1000 committed Mar 8, 2024
1 parent f8f6d6e commit d6a84ec
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 33 deletions.
11 changes: 8 additions & 3 deletions base/src/main/java/org/aya/tyck/ExprTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ yield subscoped(() -> {
case Expr.LitInt litInt -> throw new UnsupportedOperationException("TODO");
case Expr.LitString litString -> throw new UnsupportedOperationException("TODO");
case Expr.Ref(var ref) -> checkApplication(ref, ImmutableSeq.empty());
case Expr.Sigma sigma -> throw new UnsupportedOperationException("TODO");
case Expr.Pi(var param, var body) -> {
case Expr.Sigma _ -> {
var ty = ty(expr);
// TODO: type level
yield new Result.Default(ty, new SortTerm(SortKind.Type, 0));
}
case Expr.Pi _ -> {
var ty = ty(expr);
// TODO: type level
yield new Result.Default(ty, new SortTerm(SortKind.Type, 0));
}
case Expr.Sort _ -> {
Expand All @@ -104,7 +109,7 @@ yield subscoped(() -> {
};
}

private @NotNull Result checkApplication(AnyVar f, ImmutableSeq<Expr.NamedArg> args) {
private @NotNull Result checkApplication(@NotNull AnyVar f, @NotNull ImmutableSeq<Expr.NamedArg> args) {
return switch (f) {
case LocalVar lVar -> args.foldLeft(new Result.Default(new FreeTerm(lVar), localCtx().get(lVar)), (acc, arg) -> {
if (arg.name() != null || !arg.explicit()) throw new UnsupportedOperationException("TODO: named arg");
Expand Down
160 changes: 140 additions & 20 deletions base/src/main/java/org/aya/tyck/unify/TermComparator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,44 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck.unify;

import kala.collection.immutable.ImmutableSeq;
import org.aya.generic.SortKind;
import org.aya.syntax.core.def.TeleDef;
import org.aya.syntax.core.term.*;
import org.aya.syntax.core.term.call.ConCallLike;
import org.aya.syntax.core.term.call.DataCall;
import org.aya.syntax.core.term.call.FnCall;
import org.aya.tyck.tycker.StateBased;
import org.aya.util.Ordering;
import org.aya.util.Pair;
import org.aya.util.error.InternalException;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.function.Supplier;

public abstract class TermComparator implements StateBased {
protected final @NotNull Ordering cmp;

public TermComparator(@NotNull Ordering cmp) {
this.cmp = cmp;
}

public boolean compare(@NotNull Term lhs, @NotNull Term rhs, @Nullable Term type) {
// TODO
if (type == null) return compareUntyped(lhs, rhs) != null;
return doCompareTyped(lhs, rhs, type);
}

/**
* Compare {@param lhs} and {@param rhs} with {@param type} information
* Compare whnf {@param lhs} and whnf {@param rhs} with {@param type} information
*
* @param type the whnf type.
* @return whether they are 'the same' and their types are {@param type}
*/
private boolean doCompareTyped(@NotNull Term lhs, @NotNull Term rhs, @NotNull Term type) {
boolean ret = switch (type) {
return switch (type) {
// TODO: ClassCall
case LamTerm _ -> throw new InternalException("LamTerm is never type");
case ConCallLike _ -> throw new InternalException("ConCall is never type");
case TupTerm _ -> throw new InternalException("TupTerm is never type");
Expand All @@ -37,28 +52,18 @@ case Pair(LamTerm(var lbody), LamTerm(var rbody)) -> state().dCtx().with(pi.para
default -> false;
};
case SigmaTerm(var paramSeq) -> {
// We use view since we need to instantiate the remaining params after tyck some component.
var params = paramSeq.view();
var size = paramSeq.size();
for (var i = 0; i < size; ++ i) {
var l = ProjTerm.make(lhs, i);
var r = ProjTerm.make(rhs, i);
var param = params.getFirst();
if (! compare(l, r, param)) yield false;
params = params.drop(1).mapIndexed((j, term) ->
term.replace(j, l));
}
yield true;
var list = ImmutableSeq.fill(size, i -> ProjTerm.make(lhs, i));
var rist = ImmutableSeq.fill(size, i -> ProjTerm.make(rhs, i));

yield compareMany(list, rist, paramSeq);
}
default -> false;
default -> compareUntyped(lhs, rhs) != null;
};

// TODO
throw new UnsupportedOperationException("TODO");
}

/**
* Compare {@param lhs} and {@param rhs} without type information.
* Compare whnf {@param lhs} and whnf {@param rhs} without type information.
*
* @return the type of {@param lhs} and {@param rhs} if they are 'the same', null otherwise.
*/
Expand All @@ -68,15 +73,27 @@ case SigmaTerm(var paramSeq) -> {
}

private @Nullable Term doCompareUntyped(@NotNull Term lhs, @NotNull Term rhs) {
// TODO: return correct type level
if (lhs instanceof Formation form) return doCompareType(form, rhs) ? SortTerm.Set0 : null;
return switch (lhs) {
case AppTerm(var f, var a) -> {
if (!(rhs instanceof AppTerm(var g, var b))) yield null;
var fTy = compareUntyped(f, g);
if (fTy == null) yield null;
if (!(whnf(fTy) instanceof PiTerm pi)) yield null;
if (!(fTy instanceof PiTerm pi)) yield null;
if (!compare(a, b, pi.param())) yield null;
yield pi.body().instantiate(a);
}
case ProjTerm(var lof, var ldx) -> {
// Since the {lhs} and {rhs} are whnf, at this point, {lof} is unable to evaluate.
// Thus the only thing we can do is check whether {lof} and {rhs.of(}} (if rhs is ProjTerm) is 'the same'.
if (!(rhs instanceof ProjTerm(var rof, var rdx))) yield null;
if (!(compareUntyped(lof, rof) instanceof SigmaTerm(var params))) yield null;
if (ldx != rdx) yield null;
// Make type
var spine = ImmutableSeq.fill(ldx /* ldx is 0-based */, i -> ProjTerm.make(lof, i)); // 0 = lof.0, 1 = lof.1, ...
// however, for {lof.ldx}, the nearest(0) element is {lof.(idx - 1)}, so we need to reverse the spine.
yield params.get(ldx).instantiateAll(spine.view().reversed());
}
case FreeTerm(var lvar) -> {
if (rhs instanceof FreeTerm(var rvar) && lvar == rvar) yield state().ctx().get(lvar);
yield null;
Expand All @@ -85,6 +102,7 @@ case LocalTerm(var ldx) -> {
if (rhs instanceof LocalTerm(var rdx) && ldx == rdx) yield state().dCtx().get(ldx);
yield null;
}
case FnCall _ -> throw new InternalException("FnCall is compared in compareApprox");
default -> throw new UnsupportedOperationException("TODO");
};
}
Expand All @@ -100,4 +118,106 @@ private boolean compareLambda(@NotNull LamTerm lambda, @NotNull Term rhs, @NotNu
return compare(lhsBody, rhsBody, type.body());
});
}

private boolean compareMany(
@NotNull ImmutableSeq<Term> list,
@NotNull ImmutableSeq<Term> rist,
@NotNull ImmutableSeq<Term> types
) {
assert list.sizeEquals(rist);
assert rist.sizeEquals(types);

var typeView = types.view();
for (var i = 0; i < list.size(); ++i) {
var l = list.get(i);
var r = rist.get(i);
var ty = typeView.getFirst();
if (!compare(l, r, ty)) return false;
typeView = typeView.drop(1).mapIndexed((j, x) -> x.replace(j, l));
}

return true;
}

private <R> R compareTypeWith(
@NotNull Term lTy,
@NotNull Term rTy,
@NotNull Supplier<R> onFailed,
@NotNull Supplier<R> continuation
) {
if (!compare(lTy, rTy, null)) return onFailed.get();
return state().dCtx().with(lTy, continuation);
}

/**
* Compare types and run the {@param continuation} with those types in context (reverse order).
* @param onFailed run while failed (size doesn't match or compare failed)
*/
private <R> R compareTypesWith(
@NotNull ImmutableSeq<Term> list,
@NotNull ImmutableSeq<Term> rist,
@NotNull Supplier<R> onFailed,
@NotNull Supplier<R> continuation
) {
if (! list.sizeEquals(rist)) return onFailed.get();
if (list.isEmpty()) return continuation.get();
return compareTypeWith(list.getFirst(), rist.getFirst(), onFailed, () ->
compareTypesWith(list.drop(1), rist.drop(1), onFailed, continuation));
}

private boolean sortLt(@NotNull SortTerm l, @NotNull SortTerm r) {
var lift = l.lift();
var rift = r.lift();
// ISet <= Set0
// Set i <= Set j if i <= j
// Type i <= Type j if i <= j
return switch (l.kind()) {
case Type -> r.kind() == SortKind.Type && lift <= rift;
case ISet -> r.kind() == SortKind.Set || r.kind() == SortKind.ISet;
case Set -> r.kind() == SortKind.Set && lift <= rift;
};
}

private boolean compareSort(@NotNull SortTerm l, @NotNull SortTerm r) {
return switch (cmp) {
case Gt -> {
if (! sortLt(r, l)) {
// TODO: report error
yield false;
} else yield true;
}
case Eq -> {
if (! (l.kind() == r.kind() && l.lift() == r.lift())) {
// TODO: report error
yield false;
} else yield true;
}
case Lt -> {
if (! sortLt(l, r)) {
// TODO: report error
yield false;
} else yield true;
}
};
}

/**
* Compare two type formation
* Note: don't confuse with {@link TermComparator#doCompareTyped(Term, Term, Term)}
*/
private boolean doCompareType(@NotNull Formation preLhs, @NotNull Term preRhs) {
if (preLhs.getClass() != preRhs.getClass()) return false;
return switch (new Pair<>(preLhs, (Formation) preRhs)) {
case Pair(DataCall lhs, DataCall rhs) -> {
if (lhs.ref() != rhs.ref()) yield false;
yield compareMany(lhs.args(), rhs.args(), TeleDef.defTele(lhs.ref())
.map(x -> x.type().elevate(lhs.ulift())));
}
case Pair(PiTerm lhs, PiTerm rhs) -> compareTypeWith(lhs.param(), rhs.param(), () -> false, () ->
compare(lhs.body(), rhs.body(), null));
case Pair(SigmaTerm lhs, SigmaTerm rhs) -> compareTypesWith(lhs.params(), rhs.params(), () -> false, () -> true);
case Pair(SortTerm lhs, SortTerm rhs) -> compareSort(lhs, rhs);
default -> false;
};
}
}
4 changes: 4 additions & 0 deletions syntax/src/main/java/org/aya/syntax/core/term/Term.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ public sealed interface Term extends Serializable, AyaDocile
*/
@NotNull Term descent(@NotNull IndexedFunction<Term, Term> f);

default @NotNull Term elevate(int level) {
return descent((_, t) -> t.elevate(level));
}

record Matching(@NotNull SourcePos sourcePos, @NotNull ImmutableSeq<Arg<Pat>> patterns, @NotNull Term body) {
public @NotNull Matching update(@NotNull ImmutableSeq<Arg<Pat>> patterns, @NotNull Term body) {
return body == body() && patterns.sameElements(patterns(), true) ? this
Expand Down
10 changes: 0 additions & 10 deletions syntax/src/main/java/org/aya/syntax/core/term/call/Callable.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,4 @@ sealed interface Common extends Callable permits Tele {
@Override @NotNull DefVar<? extends TeleDef, ? extends Decl> ref();
int ulift();
}

/** This exists solely for simplifying code in the tycker. */
@FunctionalInterface
interface Factory<D extends TeleDef, S extends Decl> {
@Contract(pure = true, value = "_,_,_->new") @NotNull Callable make(
DefVar<D, S> defVar,
int ulift,
ImmutableSeq<@NotNull Term> args
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,8 @@ public ConCall(

return new ConCall(newHead, newArgs);
}

@Override public Tele elevate(int level) {
return new ConCall(new Head(head.dataRef(), head.ref(), head.ulift() + level, head.dataArgs()), conArgs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ public record DataCall(
return new DataCall(ref, ulift, args.appended(arg));
}

@Override public Tele elevate(int level) {
return new DataCall(ref, ulift + level, args);
}

// public @NotNull ConCall.Head conHead(@NotNull DefVar<CtorDef, TeleDecl.DataCtor> ctorRef) {
// return new ConCall.Head(ref, ctorRef, ulift, args);
// }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ public record FnCall(
public @NotNull Tele applyTo(@NotNull Term arg) {
return new FnCall(ref, ulift, args);
}

@Override public Tele elevate(int level) {
return new FnCall(ref, ulift + level, args);
}
}

0 comments on commit d6a84ec

Please sign in to comment.