Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: extract handling of RecForms #38

Merged
merged 1 commit into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 5 additions & 39 deletions Qpf/Macro/Data/Ind.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Qpf.Macro.Data.RecForm
import Qpf.Macro.Data.View
import Qpf.Macro.Common
import Mathlib.Data.QPF.Multivariate.Constructions.Fix
Expand All @@ -7,30 +8,8 @@ open Lean.Parser (Parser)
open Lean Meta Elab.Command Elab.Term Parser.Term
open Lean.Parser.Tactic (inductionAlt)

/--
The recursive form encodes how a function argument is recursive.

Examples ty R α:

α → R α → List (R α) → R α
[nonRec, directRec, composed ]
-/
inductive RecursionForm :=
| nonRec (stx: Term)
| directRec
-- | composed -- Not supported yet
deriving Repr, BEq

partial def getArgTypes (v : Term) : List Term := match v.raw with
| .node _ ``arrow #[arg, _, deeper] =>
⟨arg⟩ :: getArgTypes ⟨deeper⟩
| rest => [⟨rest⟩]

def flattenForArg (n : Name) := Name.str .anonymous $ n.toStringWithSep "_" true

def containsStx (top : Term) (search : Term) : Bool :=
(top.raw.find? (· == search)).isSome

/-- Both `bracketedBinder` and `matchAlts` have optional arguments,
which cause them to not by recognized as parsers in quotation syntax
(that is, ``` `(bracketedBinder| ...) ``` does not work).
Expand All @@ -54,22 +33,6 @@ def addShapeToName : Name → Name
section
variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m]

/-- Extract takes a constructor and extracts its recursive forms.

This function assumes the pre-processor has run
It also assumes you don't have polymorphic recursive types such as
data Ql α | nil | l : α → Ql Bool → Ql α -/
def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm :=
(view.declName.replacePrefix topName .anonymous , ·) <$> (do
let some type := view.type? | pure []
let type_ls := (getArgTypes ⟨type⟩).dropLast

type_ls.mapM fun v =>
if v == rec_type then pure .directRec
else if containsStx v rec_type then
throwErrorAt v.raw "Cannot handle composed recursive types"
else pure $ .nonRec v)

/-- Generate the binders for the different recursors -/
def mkRecursorBinder
(rec_type : Term) (name : Name)
Expand All @@ -87,6 +50,7 @@ def mkRecursorBinder
let ty ← form.foldlM (fun acc => (match · with
| ⟨.nonRec x, name⟩ => `(($name : $x) → $acc)
| ⟨.directRec, name⟩ => `(($name : $rec_type) → $acc)
| ⟨.composed x, _⟩ => throwErrorAt x "Cannot handle recursive forms"
)) out

`(bb | ($(mkIdent $ flattenForArg name) : $ty))
Expand Down Expand Up @@ -174,13 +138,15 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive
match f with
| .directRec => `(⟨_, $nm⟩)
| .nonRec _ => `(_)
| .composed _ => throwError "Cannot handle composed"

let nonMotiveArgs ← names.mapM fun _ => `(_)
let motiveArgs ← if includeMotive then
names.filterMapM fun ⟨nm, f⟩ =>
match f with
| .directRec => some <$> `($nm)
| .nonRec _ => pure none
| .composed _ => throwError "Cannot handle composed"
else pure #[]


Expand All @@ -194,7 +160,7 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive
def genRecursors (view : DataView) : CommandElabM Unit := do
let rec_type := view.getExpectedType

let mapped view.ctors.mapM (extract view.declName · rec_type)
let mapped := view.ctors.map (RecursionForm.extractWithName view.declName · rec_type)

let ih_types ← mapped.mapM fun ⟨name, base⟩ =>
mkRecursorBinder (rec_type) (name) base true
Expand Down
62 changes: 62 additions & 0 deletions Qpf/Macro/Data/RecForm.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import Qpf.Macro.Data.Replace

open Lean.Parser (Parser)
open Lean Meta Elab.Command Elab.Term Parser.Term
open Lean.Parser.Tactic (inductionAlt)

/--
The recursive form encodes how a function argument is recursive.
Examples ty R α:
α → R α → List (R α) → R α
[nonRec, directRec, composed ]
-/
inductive RecursionForm :=
| nonRec (stx : Term)
| directRec
| composed (stx : Term) -- Not supported yet
deriving Repr, BEq

namespace RecursionForm

variable {m} [Monad m] [MonadQuotation m]

private def containsStx (top : Term) (search : Term) : Bool :=
(top.raw.find? (· == search)).isSome

partial def getArgTypes (v : Term) : List Term := match v.raw with
| .node _ ``arrow #[arg, _, deeper] =>
⟨arg⟩ :: getArgTypes ⟨deeper⟩
| rest => [⟨rest⟩]

partial def toType (retTy : Term) : List Term → m Term
| [] => pure retTy
| hd :: tl => do `($hd → $(← toType retTy tl))

/-- Extract takes a constructor and extracts its recursive forms.
This function assumes the pre-processor has run
It also assumes you don't have polymorphic recursive types such as
data Ql α | nil | l : α → Ql Bool → Ql α -/
def extract (view : CtorView) (rec_type : Term) : List RecursionForm := do
if let some type := view.type? then
let type_ls := (getArgTypes ⟨type⟩).dropLast

type_ls.map fun v =>
if v == rec_type then .directRec
else if containsStx v rec_type then
.composed v
else .nonRec v
else []

def extractWithName (topName : Name) (view : CtorView) (rec_type : Term) : Name × List RecursionForm :=
(view.declName.replacePrefix topName .anonymous , extract view rec_type)

def replaceRec (old new : Term) : RecursionForm → Term
| .nonRec x => x
| .directRec => new
| .composed x => ⟨Replace.replaceAllStx old new x⟩

def toTerm (recType : Term) : RecursionForm → Term
| .nonRec x | .composed x => x
| .directRec => recType

end RecursionForm
Loading