diff --git a/src/Expander/Error.hs b/src/Expander/Error.hs index a25ea8b2..2e9d559f 100644 --- a/src/Expander/Error.hs +++ b/src/Expander/Error.hs @@ -48,13 +48,19 @@ data ExpansionErr | ReaderError Text | WrongMacroContext Syntax MacroContext MacroContext | NotValidType Syntax - | TypeMismatch (Maybe SrcLoc) Ty Ty - | OccursCheckFailed MetaPtr Ty + | TypeCheckError TypeCheckError | WrongArgCount Syntax Constructor Int Int | NotAConstructor Syntax | WrongDatatypeArity Syntax Datatype Natural Int deriving (Show) +data TypeCheckError + = TypeMismatch (Maybe SrcLoc) Ty Ty (Maybe (Ty, Ty)) + -- ^ Blame for constraint, expected, got, and specific part that doesn't match + | OccursCheckFailed MetaPtr Ty + deriving (Show) + + data MacroContext = ExpressionCtx | TypeCtx @@ -154,23 +160,7 @@ instance Pretty VarInfo ExpansionErr where ] pp env (NotValidType stx) = hang 2 $ group $ vsep [text "Not a type:", pp env stx] - pp env (TypeMismatch loc expected got) = - group $ vsep [ group $ hang 2 $ vsep [ text "Type mismatch at" - , maybe (text "unknown location") (pp env) loc <> text "." - ] - , group $ vsep [ group $ hang 2 $ vsep [ text "Expected" - , pp env expected - ] - , group $ hang 2 $ vsep [ text "but got" - , pp env got - ] - ] - ] - - pp env (OccursCheckFailed ptr ty) = - hang 2 $ group $ vsep [ text "Occurs check failed:" - , group (vsep [viaShow ptr, "≠", pp env ty]) - ] + pp env (TypeCheckError err) = pp env err pp env (WrongArgCount stx ctor wanted got) = hang 2 $ vsep [ text "Wrong number of arguments for constructor" <+> pp env ctor @@ -187,6 +177,36 @@ instance Pretty VarInfo ExpansionErr where , text "In" <+> align (pp env stx) ] +instance Pretty VarInfo TypeCheckError where + pp env (TypeMismatch loc expected got specifically) = + group $ vsep [ group $ hang 2 $ vsep [ text "Type mismatch at" + , maybe (text "unknown location") (pp env) loc <> text "." + ] + , group $ vsep $ + [ group $ hang 2 $ vsep [ text "Expected" + , pp env expected + ] + , group $ hang 2 $ vsep [ text "but got" + , pp env got + ] + ] ++ + case specifically of + Nothing -> [] + Just (expected', got') -> + [ hang 2 $ group $ vsep [text "Specifically," + , group (vsep [ pp env expected' + , text "doesn't match" <+> pp env got' + ]) + ] + ] + ] + + pp env (OccursCheckFailed ptr ty) = + hang 2 $ group $ vsep [ text "Occurs check failed:" + , group (vsep [viaShow ptr, "≠", pp env ty]) + ] + + instance Pretty VarInfo MacroContext where pp _env ExpressionCtx = text "an expression" pp _env ModuleCtx = text "a module" diff --git a/src/Expander/TC.hs b/src/Expander/TC.hs index 4fd1ad7a..171fe7b9 100644 --- a/src/Expander/TC.hs +++ b/src/Expander/TC.hs @@ -1,7 +1,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE ViewPatterns #-} -module Expander.TC where +module Expander.TC (unify, freshMeta, inst, specialize, varType, makeTypeMetas, generalizeType, normType) where import Control.Lens hiding (indices) import Control.Monad.Except @@ -73,7 +73,7 @@ occursCheck ptr t = do if ptr `elem` free then do t' <- normAll t - throwError $ OccursCheckFailed ptr t' + throwError $ TypeCheckError $ OccursCheckFailed ptr t' else pure () pruneLevel :: Traversable f => BindingLevel -> f MetaPtr -> Expand () @@ -193,9 +193,12 @@ instance UnificationErrorBlame SrcLoc where instance UnificationErrorBlame SplitCorePtr where getBlameLoc ptr = view (expanderOriginLocations . at ptr) <$> getState --- The expected type is first, the received is second unify :: UnificationErrorBlame blame => blame -> Ty -> Ty -> Expand () -unify blame t1 t2 = do +unify loc t1 t2 = unifyWithBlame (loc, t1, t2) 0 t1 t2 + +-- The expected type is first, the received is second +unifyWithBlame :: UnificationErrorBlame blame => (blame, Ty, Ty) -> Natural -> Ty -> Ty -> Expand () +unifyWithBlame blame depth t1 t2 = do t1' <- normType t1 t2' <- normType t2 unify' (unTy t1') (unTy t2') @@ -206,10 +209,10 @@ unify blame t1 t2 = do unify' TType TType = pure () unify' TSyntax TSyntax = pure () unify' TSignal TSignal = pure () - unify' (TFun a c) (TFun b d) = unify blame b a >> unify blame c d - unify' (TMacro a) (TMacro b) = unify blame a b + unify' (TFun a c) (TFun b d) = unifyWithBlame blame (depth + 1) b a >> unifyWithBlame blame (depth + 1) c d + unify' (TMacro a) (TMacro b) = unifyWithBlame blame (depth + 1) a b unify' (TDatatype dt1 args1) (TDatatype dt2 args2) - | dt1 == dt2 = traverse_ (uncurry (unify blame)) (zip args1 args2) + | dt1 == dt2 = traverse_ (uncurry (unifyWithBlame blame (depth + 1))) (zip args1 args2) -- Flex-flex unify' (TMetaVar ptr1) (TMetaVar ptr2) = do @@ -225,5 +228,13 @@ unify blame t1 t2 = do -- Mismatch unify' expected received = do - loc <- getBlameLoc blame - throwError $ TypeMismatch loc (Ty expected) (Ty received) + let (here, outerExpected, outerReceived) = blame + loc <- getBlameLoc here + e' <- normAll $ Ty expected + r' <- normAll $ Ty received + if depth == 0 + then throwError $ TypeCheckError $ TypeMismatch loc e' r' Nothing + else do + outerE' <- normAll outerExpected + outerR' <- normAll outerReceived + throwError $ TypeCheckError $ TypeMismatch loc outerE' outerR' (Just (e', r')) diff --git a/tests/Test.hs b/tests/Test.hs index 22bbbd5d..344712a8 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -450,7 +450,7 @@ moduleTests = testGroup "Module tests" [ shouldWork, shouldn'tWork ] ) , ( "examples/non-examples/type-errors.kl" , \case - TypeMismatch (Just _) _ _ -> True + TypeCheckError (TypeMismatch (Just _) _ _ _) -> True _ -> False ) ]