Skip to content

Commit

Permalink
coq, ocaml: Add function calls in typed AST
Browse files Browse the repository at this point in the history
  • Loading branch information
cpitclaudel committed Jan 22, 2020
1 parent 28bd4b4 commit e1eae17
Show file tree
Hide file tree
Showing 27 changed files with 424 additions and 226 deletions.
6 changes: 3 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -606,19 +606,19 @@ The following (excerpted from `<examples/conflicts_modular.v>`_) defines a ``Que
end.
Definition dequeue0: UInternalFunction reg_t empty_ext_fn_t :=
{{ fun _ : bits_t 32 =>
{{ fun dequeue0 () : bits_t 32 =>
guard(!read0(empty));
write0(empty, Ob~1);
read0(data) }}.
Definition enqueue0: UInternalFunction reg_t empty_ext_fn_t :=
{{ fun (val: bits_t 32) : unit_t =>
{{ fun enqueue0 (val: bits_t 32) : unit_t =>
guard(read0(empty));
write0(empty, Ob~0);
write0(data, val) }}.
Definition dequeue1: UInternalFunction reg_t empty_ext_fn_t :=
{{ fun _ : bits_t 32 =>
{{ fun dequeue1 () : bits_t 32 =>
guard(!read1(empty));
write1(empty, Ob~1);
read1(data) }}.
Expand Down
18 changes: 5 additions & 13 deletions coq/Desugaring.v
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@ Section Desugaring.

Import PrimUntyped.

Definition map_int_fn_body {fn_name_t var_t action action': Type}
(f: action -> action') (fn: InternalFunction fn_name_t var_t action) :=
{| int_name := fn.(int_name);
int_argspec := fn.(int_argspec);
int_retSig := fn.(int_retSig);
int_body := f fn.(int_body) |}.

Fixpoint desugar_action' {reg_t' ext_fn_t'} (pos: pos_t)
(fR: reg_t' -> reg_t) (fSigma: ext_fn_t' -> ext_fn_t)
(a: uaction reg_t' ext_fn_t') {struct a}
Expand All @@ -39,7 +32,8 @@ Section Desugaring.
| UUnop fn arg => UUnop fn (d arg)
| UBinop fn arg1 arg2 => UBinop fn (d arg1) (d arg2)
| UExternalCall fn arg => UExternalCall (fSigma fn) (d arg)
| UInternalCall fn args => UInternalCall (map_int_fn_body d fn) (List.map d args)
| UInternalCall fn args =>
UInternalCall (map_intf_body d fn) (List.map d args)
| UAPos p e => UAPos p (d e)
| USugar s => desugar pos fR fSigma s
end
Expand Down Expand Up @@ -71,9 +65,8 @@ Section Desugaring.
| UWhen cond body =>
UIf (d cond) (d body) (UFail (bits_t 0)) (* FIXME infer the type of the second branch? *)
| UStructInit sig fields =>
let empty := SyntaxMacros.uinit (struct_t sig) in
let usubst f := UBinop (UStruct2 (USubstField f)) in
List.fold_left (fun acc '(f, a) => (usubst f) acc (d a)) fields empty
let fields := List.map (fun '(f, a) => (f, d a)) fields in
SyntaxMacros.ustruct_init sig fields
| UArrayInit tau elements =>
let sig := {| array_type := tau; array_len := List.length elements |} in
let usubst pos := UBinop (UArray2 (USubstElement pos)) in
Expand All @@ -84,8 +77,7 @@ Section Desugaring.
SyntaxMacros.uswitch (d var) (d default) branches
| UCallModule fR' fSigma' fn args =>
let df body := desugar_action' pos (fun r => fR (fR' r)) (fun fn => fSigma (fSigma' fn)) body in
let args := List.map d args in
UInternalCall (map_int_fn_body df fn) args
UInternalCall (map_intf_body df fn) (List.map d args)
end.

Definition desugar_action (pos: pos_t) (a: uaction reg_t ext_fn_t)
Expand Down
97 changes: 97 additions & 0 deletions coq/Environments.v
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,20 @@ Section Contexts.
+ rewrite Heq in *. destruct eqn. reflexivity.
+ rewrite IHm; intuition congruence.
Qed.

Fixpoint capp {sig sig'} (ctx: context sig) (ctx': context sig'): context (sig ++ sig') :=
match sig return context sig -> context (sig ++ sig') with
| [] => fun _ => ctx'
| k :: sig => fun ctx => CtxCons k (chd ctx) (capp (ctl ctx) ctx')
end ctx.

Fixpoint csplit {sig sig'} (ctx: context (sig ++ sig')): (context sig * context sig') :=
match sig return context (sig ++ sig') -> (context sig * context sig') with
| [] => fun ctx => (CtxEmpty, ctx)
| k :: sig => fun ctx =>
let (l, r) := csplit (ctl ctx) in
(CtxCons k (chd ctx) l, r)
end ctx.
End Contexts.

Arguments context {K} V sig : assert.
Expand Down Expand Up @@ -207,6 +221,89 @@ Section Maps.
Qed.
End Maps.


Section ValueMaps.
Context {K: Type}.
Context {V: K -> Type} {V': K -> Type}.
Context (fV: forall k, V k -> V' k).

Fixpoint cmapv {sig} (ctx: context V sig) {struct ctx} : context V' sig :=
match ctx in context _ sig return context V' sig with
| CtxEmpty => CtxEmpty
| CtxCons k v ctx => CtxCons k (fV k v) (cmapv ctx)
end.

Lemma cmapv_creplace :
forall {sig} (ctx: context V sig) {k} (m: member k sig) v,
cmapv (creplace m v ctx) =
creplace m (fV k v) (cmapv ctx).
Proof.
induction ctx; cbn; intros.
- destruct (mdestruct m).
- destruct (mdestruct m) as [(-> & ->) | (? & ->)]; cbn in *.
+ reflexivity.
+ rewrite IHctx; reflexivity.
Qed.

Lemma cmapv_cassoc :
forall {sig} (ctx: context V sig) {k} (m: member k sig),
cassoc m (cmapv ctx) =
fV k (cassoc m ctx).
Proof.
induction ctx; cbn; intros.
- destruct (mdestruct m).
- destruct (mdestruct m) as [(-> & ->) | (? & ->)]; cbn in *.
+ reflexivity.
+ rewrite IHctx; reflexivity.
Qed.

Lemma cmapv_ctl :
forall {k sig} (ctx: context V (k :: sig)),
cmapv (ctl ctx) = ctl (cmapv ctx).
Proof.
intros; rewrite (ceqn ctx); reflexivity.
Qed.
End ValueMaps.

Section Folds.
Context {K: Type}.
Context {V: K -> Type}.

Section foldl.
Context {T: Type}.
Context (f: forall (k: K) (v: V k) (acc: T), T).

Fixpoint cfoldl {sig} (ctx: context V sig) (init: T) :=
match ctx with
| CtxEmpty => init
| CtxCons k v ctx => cfoldl ctx (f k v init)
end.

Fixpoint cfoldl' {sig} (ctx: context V sig) (init: T) :=
match sig return context V sig -> T with
| [] => fun _ => init
| k :: sig => fun ctx => cfoldl (ctl ctx) (f k (chd ctx) init)
end ctx.
End foldl.

Section foldr.
Context {T: list K -> Type}.
Context (f: forall (sg: list K) (k: K) (v: V k), T sg -> T (k :: sg)).

Fixpoint cfoldr {sig} (ctx: context V sig) (init: T []) :=
match ctx with
| CtxEmpty => init
| CtxCons k v ctx => f _ k v (cfoldr ctx init)
end.

Fixpoint cfoldr' {sig} (ctx: context V sig) (init: T []) :=
match sig return context V sig -> T sig with
| [] => fun _ => init
| k :: sig => fun ctx => f sig k (chd ctx) (cfoldr' (ctl ctx) init)
end ctx.
End foldr.
End Folds.

Notation esig K := (forall k: K, Type).

Record Env {K: Type} :=
Expand Down
8 changes: 4 additions & 4 deletions coq/Frontend.v
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ Definition var_t := string.
Definition fn_name_t := string.

Notation uaction := (uaction pos_t var_t fn_name_t).
Notation action := (action pos_t var_t).
Notation rule := (rule pos_t var_t).
Notation action := (action pos_t var_t fn_name_t).
Notation rule := (rule pos_t var_t fn_name_t).

Notation scheduler := (scheduler pos_t _).

Notation UInternalFunction reg_t ext_fn_t := (InternalFunction fn_name_t var_t (uaction reg_t ext_fn_t)).
Notation InternalFunction R Sigma sig tau := (InternalFunction fn_name_t var_t (action R Sigma sig tau)).
Notation UInternalFunction reg_t ext_fn_t := (InternalFunction var_t fn_name_t (uaction reg_t ext_fn_t)).
Notation InternalFunction R Sigma sig tau := (InternalFunction var_t fn_name_t (action R Sigma sig tau)).

Notation register_update_circuitry R Sigma := (register_update_circuitry _ R Sigma ContextEnv).

Expand Down
23 changes: 15 additions & 8 deletions coq/Interop.v
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Section Packages.
Typically [string]. *)
Context {var_t: Type}.

(** [fn_name_t]: The type of function names.
Typically [string]. **)
Context {fn_name_t: Type}.

(** [rule_name_t]: The type of rule names.
Typically an inductive [rule1 | rule2 | …]. **)
Context {rule_name_t: Type}.
Expand All @@ -41,6 +45,8 @@ Section Packages.
{
(** [koika_var_names]: These names are used to generate readable code. *)
koika_var_names: Show var_t;
(** [koika_fn_names]: These names are used to generate readable code. *)
koika_fn_names: Show fn_name_t;

(** [koika_reg_names]: These names are used to generate readable code. *)
koika_reg_names: Show reg_t;
Expand All @@ -57,7 +63,7 @@ Section Packages.

(** [koika_rules]: The rules of the program. **)
koika_rules: forall _: rule_name_t,
TypedSyntax.rule pos_t var_t koika_reg_types koika_ext_fn_types;
TypedSyntax.rule pos_t var_t fn_name_t koika_reg_types koika_ext_fn_types;
(** [koika_rule_external]: Whether a rule will be replaced by a native
implementation. **)
koika_rule_external: rule_name_t -> bool;
Expand Down Expand Up @@ -186,7 +192,7 @@ Section TypeConv.
End TypeConv.

Section Helpers.
Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}.
Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}.

Context {R: reg_t -> type}.
Context {Sigma: ext_fn_t -> ExternalSignature}.
Expand All @@ -206,7 +212,7 @@ Section Helpers.
Context (opt: forall {sz}, circuit sz -> circuit sz).

Definition compile_scheduler
(rules: rule_name_t -> rule pos_t var_t R Sigma)
(rules: rule_name_t -> rule pos_t var_t fn_name_t R Sigma)
(external: rule_name_t -> bool)
(s: scheduler pos_t rule_name_t)
: register_update_circuitry rule_name_t CR CSigma _ :=
Expand All @@ -225,10 +231,10 @@ Section Helpers.
End Helpers.

Section Compilation.
Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}.
Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}.

Definition compile_koika_package
(s: @koika_package_t pos_t var_t rule_name_t reg_t ext_fn_t)
(s: @koika_package_t pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t)
(opt: let circuit sz := circuit (lower_R s.(koika_reg_types))
(lower_Sigma s.(koika_ext_fn_types)) sz in
forall {sz}, circuit sz -> circuit sz)
Expand All @@ -242,19 +248,20 @@ End Compilation.
Record interop_package_t :=
{ pos_t := unit;
var_t := string;
fn_name_t := string;
ip_reg_t : Type;
ip_rule_name_t : Type;
ip_ext_fn_t : Type;
ip_koika : @koika_package_t pos_t var_t ip_rule_name_t ip_reg_t ip_ext_fn_t;
ip_koika : @koika_package_t pos_t var_t fn_name_t ip_rule_name_t ip_reg_t ip_ext_fn_t;
ip_verilog : @verilog_package_t ip_ext_fn_t;
ip_sim : @sim_package_t ip_ext_fn_t }.

Require Import Koika.ExtractionSetup.

Module Backends.
Section Backends.
Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}.
Notation koika_package_t := (@koika_package_t pos_t var_t rule_name_t reg_t ext_fn_t).
Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}.
Notation koika_package_t := (@koika_package_t pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t).
Notation verilog_package_t := (@verilog_package_t ext_fn_t).
Notation sim_package_t := (@sim_package_t ext_fn_t).

Expand Down
17 changes: 8 additions & 9 deletions coq/Lowering.v
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
(*! Language | Compilation from typed ASTs to lowered ASTs !*)
Require Export Koika.Common Koika.Environments.
Require Import Koika.Syntax Koika.TypedSyntaxFunctions.
Require Koika.TypedSyntax Koika.LoweredSyntax.
Require Import Koika.Syntax Koika.TypedSyntaxFunctions Koika.SyntaxMacros.
Require Koika.SyntaxMacros Koika.TypedSyntax Koika.LoweredSyntax.

Import PrimTyped CircuitSignatures.

Section Lowering.
Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}.
Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}.

Context {R: reg_t -> type}.
Context {Sigma: ext_fn_t -> ExternalSignature}.
Context {REnv: Env reg_t}.

Definition lsig_of_tsig (sig: tsig var_t) : lsig :=
List.map (fun k_tau => type_sz (snd k_tau)) sig.

Definition lower_R idx :=
type_sz (R idx).
Notation lR := lower_R.
Expand All @@ -31,7 +28,7 @@ Section Lowering.
: forall f, CSig_denote (lSigma f) :=
fun f => fun bs => bits_of_value (sigma f (value_of_bits bs)).

Notation typed_action := (TypedSyntax.action pos_t var_t R Sigma).
Notation typed_action := (TypedSyntax.action pos_t var_t fn_name_t R Sigma).
Notation low_action := (LoweredSyntax.action pos_t var_t lR lSigma).

Section Action.
Expand Down Expand Up @@ -97,13 +94,13 @@ Section Lowering.
match a with
| TypedSyntax.Fail tau =>
LoweredSyntax.Fail (type_sz tau)
| @TypedSyntax.Var _ _ _ _ _ _ _ k _ m =>
| @TypedSyntax.Var _ _ _ _ _ _ _ _ k _ m =>
LoweredSyntax.Var k (lower_member m)
| TypedSyntax.Const cst =>
LoweredSyntax.Const (bits_of_value cst)
| TypedSyntax.Seq r1 r2 =>
LoweredSyntax.Seq (l r1) (l r2)
| @TypedSyntax.Assign _ _ _ _ _ _ _ k _ m ex =>
| @TypedSyntax.Assign _ _ _ _ _ _ _ _ k _ m ex =>
LoweredSyntax.Assign k (lower_member m) (l ex)
| TypedSyntax.Bind var ex body =>
LoweredSyntax.Bind var (l ex) (l body)
Expand All @@ -119,6 +116,8 @@ Section Lowering.
lower_binop fn (l a1) (l a2)
| TypedSyntax.ExternalCall fn a =>
LoweredSyntax.ExternalCall fn (l a)
| TypedSyntax.InternalCall fn args body =>
SyntaxMacros.InternalCall (cmapv (fun _ a => l a) args) (l body)
| TypedSyntax.APos p a =>
LoweredSyntax.APos p (l a)
end.
Expand Down
13 changes: 13 additions & 0 deletions coq/Member.v
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,16 @@ Proof.
+ exact (MemberHd k' (sig ++ infix ++ sig')).
+ exact (MemberTl k k' (sig ++ infix ++ sig') (mshift' _ infix sig sig' k m')).
Defined.

Fixpoint mshift_pair {K sig} (k: K) (p: {k': K & member k' sig})
: {k': K & member k' (k :: sig)} :=
let '(existT _ k' m) := p in
existT _ k' (MemberTl k' k _ m).

Fixpoint all_members {K} (sig: list K): list { k: K & member k sig } :=
match sig with
| [] => []
| k :: sig => let ms := all_members sig in
let ms := List.map (mshift_pair k) ms in
(existT _ k (MemberHd k sig)) :: ms
end.
Loading

0 comments on commit e1eae17

Please sign in to comment.