From b033d2432d35d222b858d781a8ee01c44cb89d42 Mon Sep 17 00:00:00 2001 From: lemastero Date: Sun, 5 May 2024 11:20:12 +0200 Subject: [PATCH] handle multiple named function arguments - named arguments, single clause --- examples/adts.agda | 30 +++++-- examples/adts.scala | 6 +- src/Agda/Compiler/Scala/AgdaToScalaExpr.hs | 95 +++++++++++++--------- src/Agda/Compiler/Scala/PrintScalaExpr.hs | 4 +- test/PrintScalaExprTest.hs | 12 +-- 5 files changed, 92 insertions(+), 55 deletions(-) diff --git a/examples/adts.agda b/examples/adts.agda index b516473..754c4f5 100644 --- a/examples/adts.agda +++ b/examples/adts.agda @@ -1,7 +1,6 @@ module examples.adts where --- simple product type no arguments - sealed trait + case objects - +-- simple sum type no arguments - sealed trait + case objects data Rgb : Set where Red : Rgb Green : Rgb @@ -13,11 +12,12 @@ data Bool : Set where False : Bool {-# COMPILE AGDA2SCALA Bool #-} --- trivial function with single argument +-- simple sum type with arguments - sealed trait + case class -idRgb : Rgb -> Rgb -idRgb x = x -{-# COMPILE AGDA2SCALA idRgb #-} +data Color : Set where + Light : Rgb -> Color + Dark : Rgb -> Color +-- TODO {-# COMPILE AGDA2SCALA Color #-} -- simple sum type - case class @@ -27,3 +27,21 @@ record RgbPair : Set where fst : Rgb snd : Bool {-# COMPILE AGDA2SCALA RgbPair #-} + +-- trivial function with single argument + +idRgb : Rgb -> Rgb +idRgb theArg = theArg +{-# COMPILE AGDA2SCALA idRgb #-} + +-- const function with one named argument + +rgbConstTrue1 : (rgb : Rgb) → Bool +rgbConstTrue1 rgb = True -- TODO produce function body +-- TODO {-# COMPILE AGDA2SCALA rgbConstTrue1 #-} + +-- function with multiple named arguments + +and0 : (rgbPairArg : RgbPair) -> (rgbArg : Rgb) -> RgbPair +and0 rgbPairArg rgbArg = rgbPairArg +{-# COMPILE AGDA2SCALA and0 #-} diff --git a/examples/adts.scala b/examples/adts.scala index 1fc04f2..ab69aed 100644 --- a/examples/adts.scala +++ b/examples/adts.scala @@ -9,7 +9,9 @@ sealed trait Bool case object True extends Bool case object False extends Bool -def idRgb(x: Rgb): Rgb = x - final case class RgbPair(snd: Bool, fst: Rgb) + +def idRgb(theArg: Rgb): Rgb = theArg + +def and0(rgbArg: Rgb, rgbPairArg: RgbPair): RgbPair = rgbPairArg } diff --git a/src/Agda/Compiler/Scala/AgdaToScalaExpr.hs b/src/Agda/Compiler/Scala/AgdaToScalaExpr.hs index 1b70741..df42bbe 100644 --- a/src/Agda/Compiler/Scala/AgdaToScalaExpr.hs +++ b/src/Agda/Compiler/Scala/AgdaToScalaExpr.hs @@ -1,6 +1,4 @@ -module Agda.Compiler.Scala.AgdaToScalaExpr ( - compileDefn - ) where +module Agda.Compiler.Scala.AgdaToScalaExpr ( compileDefn ) where import Agda.Compiler.Backend ( funCompiled, funClauses, Defn(..), RecordData(..)) import Agda.Syntax.Abstract.Name ( QName ) @@ -22,8 +20,8 @@ compileDefn :: QName -> Defn -> ScalaExpr compileDefn defName theDef = case theDef of Datatype{dataCons = dataCons} -> compileDataType defName dataCons - Function{funCompiled = funDef, funClauses = fc} -> - compileFunction defName funDef fc + Function{funCompiled = funCompiled, funClauses = funClauses} -> + compileFunction defName funCompiled funClauses RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) -> compileRecord defName recFields recTel other -> @@ -42,32 +40,52 @@ compileFunction :: QName -> Maybe CompiledClauses -> [Clause] -> ScalaExpr -compileFunction defName funDef fc = +compileFunction defName funCompiled funClauses = SeFun - (fromQName defName) - [SeVar (compileFunctionArgument fc) (compileFunctionArgType fc)] -- TODO many function arguments - (compileFunctionResultType fc) - (compileFunctionBody funDef) - -compileFunctionArgument :: [Clause] -> ScalaName -compileFunctionArgument [] = "" -compileFunctionArgument [fc] = fromDeBruijnPattern (namedThing (unArg (head (namedClausePats fc)))) -compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs) -- show xs - -compileFunctionArgType :: [Clause] -> ScalaType -compileFunctionArgType [ Clause{clauseTel = ct} ] = fromTelescope ct -compileFunctionArgType xs = error "unsupported compileFunctionArgType" ++ (show xs) - -fromTelescope :: Telescope -> ScalaName -- TODO PP probably parent should be different, use fold on telescope above -fromTelescope tel = case tel of - ExtendTel a _ -> fromDom a - other -> error ("unhandled fromType" ++ show other) + (fromQName defName) -- ++ "\n FULL FUNCTION DEFINITION \n[\n" ++ (show theDef) ++ "\n]\n") + (funArgs funClauses) + (compileFunctionResultType funClauses) + -- you can get body of the function using: + -- - FunctionData _funCompiled + -- - FunctionData _funClauses Clause clauseBody + -- see: + -- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-TypeChecking-Monad-Base.html#t:FunctionData + -- https://hackage.haskell.org/package/Agda/docs/Agda-Syntax-Internal.html#t:Clause + -- at this point both contain the same info (at least in simple cases) + (compileFunctionBody funCompiled) + +funArgs :: [Clause] -> [SeVar] +funArgs [] = [] +funArgs (c : cs) = funArgsFromClause c + +funArgsFromClause :: Clause -> [SeVar] +funArgsFromClause c@Clause{clauseTel = clauseTel} = case parsedArgs of + [(SeVar "" varType)] -> [SeVar (hackyFunArgNameFromClause c) varType] + args -> args + where + parsedArgs = foldl varsFromTelescope [] clauseTel + +-- this is extremely hacky way to get function argument name +-- for identity function +-- I apparently do not understand enough how this works +-- or perhaps this is bug in Agda compiler :) +hackyFunArgNameFromClause :: Clause -> ScalaName +hackyFunArgNameFromClause fc = hackyFunArgNameFromDeBruijnPattern (namedThing (unArg + (head -- TODO perhaps iterate here + (namedClausePats fc)))) + +hackyFunArgNameFromDeBruijnPattern :: DeBruijnPattern -> ScalaName +hackyFunArgNameFromDeBruijnPattern d = case d of + VarP a b -> (dbPatVarName b) + a@(ConP x y z) -> "\n hackyFunArgNameFromDeBruijnPattern \n[\n" ++ show a ++ "\n]\n" + other -> error ("hackyFunArgNameFromDeBruijnPattern " ++ show other) nameFromDom :: Dom Type -> ScalaName nameFromDom dt = case (domName dt) of - Nothing -> error ("nameFromDom" ++ show dt) + Nothing -> "" Just a -> namedNameToStr a +-- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-Syntax-Common.html#t:NamedName namedNameToStr :: NamedName -> ScalaName namedNameToStr n = rangedThing (woThing n) @@ -76,39 +94,38 @@ fromDom x = fromType (unDom x) compileFunctionResultType :: [Clause] -> ScalaType compileFunctionResultType [Clause{clauseType = ct}] = fromMaybeType ct -compileFunctionResultType other = error ("unhandled compileFunctionResultType" ++ show other) +compileFunctionResultType (Clause{clauseType = ct} : xs) = fromMaybeType ct +compileFunctionResultType other = error "Fatal error - function has not clause." fromMaybeType :: Maybe (Arg Type) -> ScalaName fromMaybeType (Just argType) = fromArgType argType -fromMaybeType other = error ("unhandled fromMaybeType" ++ show other) +fromMaybeType other = error ("\nunhandled fromMaybeType \n[" ++ show other ++ "]\n") fromArgType :: Arg Type -> ScalaName fromArgType arg = fromType (unArg arg) fromType :: Type -> ScalaName fromType t = case t of - a@(El _ ue) -> fromTerm ue - other -> error ("unhandled fromType" ++ show other) + El _ ue -> fromTerm ue + other -> error ("unhandled fromType [" ++ show other ++ "]") +-- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-Syntax-Internal.html#t:Term fromTerm :: Term -> ScalaName fromTerm t = case t of - Def qname el -> fromQName qname - other -> error ("unhandled fromTerm" ++ show other) - -fromDeBruijnPattern :: DeBruijnPattern -> ScalaName -fromDeBruijnPattern d = case d of - VarP a b -> (dbPatVarName b) - a@(ConP x y z) -> show a - other -> error ("unhandled fromDeBruijnPattern" ++ show other) + Def qName elims -> fromQName qName + Var n elims -> "\nunhandled fromTerm Var \n[" ++ show t ++ "]\n" + other -> error ("\nunhandled fromTerm [" ++ show other ++ "]\n") compileFunctionBody :: Maybe CompiledClauses -> FunBody compileFunctionBody (Just funDef) = fromCompiledClauses funDef -compileFunctionBody funDef = error ("unhandled compileFunctionBody " ++ show funDef) +compileFunctionBody funDef = error "Fatal error - function body is not compiled." +-- https://hackage.haskell.org/package/Agda/docs/Agda-TypeChecking-CompiledClause.html#t:CompiledClauses fromCompiledClauses :: CompiledClauses -> FunBody fromCompiledClauses cc = case cc of + (Case argInt caseCompiledClauseTerm) -> "WIP" --"\nCase fromCompiledClauses\n[\n" ++ (show cc) ++ "\n]\n" (Done (x:xs) term) -> fromArgName x - other -> error ("unhandled fromCompiledClauses " ++ show other) + other -> "\nunhandled fromCompiledClauses \n\n[" ++ show other ++ "]\n" fromArgName :: Arg ArgName -> FunBody fromArgName = unArg diff --git a/src/Agda/Compiler/Scala/PrintScalaExpr.hs b/src/Agda/Compiler/Scala/PrintScalaExpr.hs index 8c5b60c..97e25cc 100644 --- a/src/Agda/Compiler/Scala/PrintScalaExpr.hs +++ b/src/Agda/Compiler/Scala/PrintScalaExpr.hs @@ -25,11 +25,11 @@ printScalaExpr def = case def of <> defsSeparator (SeFun fName args resType funBody) -> "def" <> exprSeparator <> fName - <> "(" <> combineLines (map printVar args) <> ")" + <> "(" <> combineThem (map printVar args) <> ")" <> ":" <> exprSeparator <> resType <> exprSeparator <> "=" <> exprSeparator <> funBody <> defsSeparator - (SeProd name args) -> printCaseClass name args + (SeProd name args) -> printCaseClass name args <> defsSeparator (Unhandled "" payload) -> "" (Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload) other -> "unsupported printScalaExpr " ++ (show other) diff --git a/test/PrintScalaExprTest.hs b/test/PrintScalaExprTest.hs index 5ff38c5..5791b15 100644 --- a/test/PrintScalaExprTest.hs +++ b/test/PrintScalaExprTest.hs @@ -23,11 +23,11 @@ testPrintSealedTrait = TestCase "sealed trait Color" (printSealedTrait "Color")) ---testPrintPackage :: Test ---testPrintPackage = TestCase --- (assertEqual "printPackage" --- "package adts" --- (printPackage "adts")) +testPrintPackage :: Test +testPrintPackage = TestCase + (assertEqual "printPackage" + "object adts" + (printPackage "adts")) testCombineLines :: Test testCombineLines = TestCase @@ -57,7 +57,7 @@ printScalaTests :: Test printScalaTests = TestList [ TestLabel "printCaseObject" testPrintCaseObject , TestLabel "printSealedTrait" testPrintSealedTrait --- , TestLabel "printPackage" testPrintPackage + , TestLabel "printPackage" testPrintPackage , TestLabel "combineLines" testCombineLines , TestLabel "printCaseClass" testPrintCaseClass , TestLabel "printScalaExpr" testPrintScalaExpr