From e1eae17c4db6504ff0ae71d39708e439421ef66b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pit-Claudel?= Date: Wed, 22 Jan 2020 11:22:00 -0500 Subject: [PATCH] coq, ocaml: Add function calls in typed AST --- README.rst | 6 +-- coq/Desugaring.v | 18 ++----- coq/Environments.v | 97 ++++++++++++++++++++++++++++++++++ coq/Frontend.v | 8 +-- coq/Interop.v | 23 +++++--- coq/Lowering.v | 17 +++--- coq/Member.v | 13 +++++ coq/Parsing.v | 22 ++++---- coq/Std.v | 34 ++++++------ coq/Syntax.v | 4 +- coq/SyntaxFunctions.v | 29 +++------- coq/SyntaxMacros.v | 90 +++++++++++++++++-------------- coq/TypeInference.v | 23 ++++---- coq/TypedSemantics.v | 45 +++++++++++++--- coq/TypedSyntax.v | 12 +++-- coq/TypedSyntaxFunctions.v | 67 +++++++++++++++-------- coq/Types.v | 14 +++-- examples/collatz.v | 2 +- examples/conflicts_modular.v | 6 +-- examples/function_call.v | 4 +- examples/method_call.v | 4 +- examples/rv/RVCore.v | 29 +++++----- examples/rv/Scoreboard.v | 10 ++-- ocaml/backends/cpp.ml | 37 +++++++++---- ocaml/cuttlebone/cuttlebone.ml | 25 +++++---- ocaml/cuttlec.ml | 5 +- ocaml/frontends/lv.ml | 6 ++- 27 files changed, 424 insertions(+), 226 deletions(-) diff --git a/README.rst b/README.rst index c4af4cda..0005cccc 100644 --- a/README.rst +++ b/README.rst @@ -606,19 +606,19 @@ The following (excerpted from ``_) defines a ``Que end. Definition dequeue0: UInternalFunction reg_t empty_ext_fn_t := - {{ fun _ : bits_t 32 => + {{ fun dequeue0 () : bits_t 32 => guard(!read0(empty)); write0(empty, Ob~1); read0(data) }}. Definition enqueue0: UInternalFunction reg_t empty_ext_fn_t := - {{ fun (val: bits_t 32) : unit_t => + {{ fun enqueue0 (val: bits_t 32) : unit_t => guard(read0(empty)); write0(empty, Ob~0); write0(data, val) }}. Definition dequeue1: UInternalFunction reg_t empty_ext_fn_t := - {{ fun _ : bits_t 32 => + {{ fun dequeue1 () : bits_t 32 => guard(!read1(empty)); write1(empty, Ob~1); read1(data) }}. diff --git a/coq/Desugaring.v b/coq/Desugaring.v index 9bb48b04..a8238c60 100644 --- a/coq/Desugaring.v +++ b/coq/Desugaring.v @@ -13,13 +13,6 @@ Section Desugaring. Import PrimUntyped. - Definition map_int_fn_body {fn_name_t var_t action action': Type} - (f: action -> action') (fn: InternalFunction fn_name_t var_t action) := - {| int_name := fn.(int_name); - int_argspec := fn.(int_argspec); - int_retSig := fn.(int_retSig); - int_body := f fn.(int_body) |}. - Fixpoint desugar_action' {reg_t' ext_fn_t'} (pos: pos_t) (fR: reg_t' -> reg_t) (fSigma: ext_fn_t' -> ext_fn_t) (a: uaction reg_t' ext_fn_t') {struct a} @@ -39,7 +32,8 @@ Section Desugaring. | UUnop fn arg => UUnop fn (d arg) | UBinop fn arg1 arg2 => UBinop fn (d arg1) (d arg2) | UExternalCall fn arg => UExternalCall (fSigma fn) (d arg) - | UInternalCall fn args => UInternalCall (map_int_fn_body d fn) (List.map d args) + | UInternalCall fn args => + UInternalCall (map_intf_body d fn) (List.map d args) | UAPos p e => UAPos p (d e) | USugar s => desugar pos fR fSigma s end @@ -71,9 +65,8 @@ Section Desugaring. | UWhen cond body => UIf (d cond) (d body) (UFail (bits_t 0)) (* FIXME infer the type of the second branch? *) | UStructInit sig fields => - let empty := SyntaxMacros.uinit (struct_t sig) in - let usubst f := UBinop (UStruct2 (USubstField f)) in - List.fold_left (fun acc '(f, a) => (usubst f) acc (d a)) fields empty + let fields := List.map (fun '(f, a) => (f, d a)) fields in + SyntaxMacros.ustruct_init sig fields | UArrayInit tau elements => let sig := {| array_type := tau; array_len := List.length elements |} in let usubst pos := UBinop (UArray2 (USubstElement pos)) in @@ -84,8 +77,7 @@ Section Desugaring. SyntaxMacros.uswitch (d var) (d default) branches | UCallModule fR' fSigma' fn args => let df body := desugar_action' pos (fun r => fR (fR' r)) (fun fn => fSigma (fSigma' fn)) body in - let args := List.map d args in - UInternalCall (map_int_fn_body df fn) args + UInternalCall (map_intf_body df fn) (List.map d args) end. Definition desugar_action (pos: pos_t) (a: uaction reg_t ext_fn_t) diff --git a/coq/Environments.v b/coq/Environments.v index 4ae078bc..63e8ebc1 100644 --- a/coq/Environments.v +++ b/coq/Environments.v @@ -159,6 +159,20 @@ Section Contexts. + rewrite Heq in *. destruct eqn. reflexivity. + rewrite IHm; intuition congruence. Qed. + + Fixpoint capp {sig sig'} (ctx: context sig) (ctx': context sig'): context (sig ++ sig') := + match sig return context sig -> context (sig ++ sig') with + | [] => fun _ => ctx' + | k :: sig => fun ctx => CtxCons k (chd ctx) (capp (ctl ctx) ctx') + end ctx. + + Fixpoint csplit {sig sig'} (ctx: context (sig ++ sig')): (context sig * context sig') := + match sig return context (sig ++ sig') -> (context sig * context sig') with + | [] => fun ctx => (CtxEmpty, ctx) + | k :: sig => fun ctx => + let (l, r) := csplit (ctl ctx) in + (CtxCons k (chd ctx) l, r) + end ctx. End Contexts. Arguments context {K} V sig : assert. @@ -207,6 +221,89 @@ Section Maps. Qed. End Maps. + +Section ValueMaps. + Context {K: Type}. + Context {V: K -> Type} {V': K -> Type}. + Context (fV: forall k, V k -> V' k). + + Fixpoint cmapv {sig} (ctx: context V sig) {struct ctx} : context V' sig := + match ctx in context _ sig return context V' sig with + | CtxEmpty => CtxEmpty + | CtxCons k v ctx => CtxCons k (fV k v) (cmapv ctx) + end. + + Lemma cmapv_creplace : + forall {sig} (ctx: context V sig) {k} (m: member k sig) v, + cmapv (creplace m v ctx) = + creplace m (fV k v) (cmapv ctx). + Proof. + induction ctx; cbn; intros. + - destruct (mdestruct m). + - destruct (mdestruct m) as [(-> & ->) | (? & ->)]; cbn in *. + + reflexivity. + + rewrite IHctx; reflexivity. + Qed. + + Lemma cmapv_cassoc : + forall {sig} (ctx: context V sig) {k} (m: member k sig), + cassoc m (cmapv ctx) = + fV k (cassoc m ctx). + Proof. + induction ctx; cbn; intros. + - destruct (mdestruct m). + - destruct (mdestruct m) as [(-> & ->) | (? & ->)]; cbn in *. + + reflexivity. + + rewrite IHctx; reflexivity. + Qed. + + Lemma cmapv_ctl : + forall {k sig} (ctx: context V (k :: sig)), + cmapv (ctl ctx) = ctl (cmapv ctx). + Proof. + intros; rewrite (ceqn ctx); reflexivity. + Qed. +End ValueMaps. + +Section Folds. + Context {K: Type}. + Context {V: K -> Type}. + + Section foldl. + Context {T: Type}. + Context (f: forall (k: K) (v: V k) (acc: T), T). + + Fixpoint cfoldl {sig} (ctx: context V sig) (init: T) := + match ctx with + | CtxEmpty => init + | CtxCons k v ctx => cfoldl ctx (f k v init) + end. + + Fixpoint cfoldl' {sig} (ctx: context V sig) (init: T) := + match sig return context V sig -> T with + | [] => fun _ => init + | k :: sig => fun ctx => cfoldl (ctl ctx) (f k (chd ctx) init) + end ctx. + End foldl. + + Section foldr. + Context {T: list K -> Type}. + Context (f: forall (sg: list K) (k: K) (v: V k), T sg -> T (k :: sg)). + + Fixpoint cfoldr {sig} (ctx: context V sig) (init: T []) := + match ctx with + | CtxEmpty => init + | CtxCons k v ctx => f _ k v (cfoldr ctx init) + end. + + Fixpoint cfoldr' {sig} (ctx: context V sig) (init: T []) := + match sig return context V sig -> T sig with + | [] => fun _ => init + | k :: sig => fun ctx => f sig k (chd ctx) (cfoldr' (ctl ctx) init) + end ctx. + End foldr. +End Folds. + Notation esig K := (forall k: K, Type). Record Env {K: Type} := diff --git a/coq/Frontend.v b/coq/Frontend.v index af5b4ae9..cf9df0b8 100644 --- a/coq/Frontend.v +++ b/coq/Frontend.v @@ -54,13 +54,13 @@ Definition var_t := string. Definition fn_name_t := string. Notation uaction := (uaction pos_t var_t fn_name_t). -Notation action := (action pos_t var_t). -Notation rule := (rule pos_t var_t). +Notation action := (action pos_t var_t fn_name_t). +Notation rule := (rule pos_t var_t fn_name_t). Notation scheduler := (scheduler pos_t _). -Notation UInternalFunction reg_t ext_fn_t := (InternalFunction fn_name_t var_t (uaction reg_t ext_fn_t)). -Notation InternalFunction R Sigma sig tau := (InternalFunction fn_name_t var_t (action R Sigma sig tau)). +Notation UInternalFunction reg_t ext_fn_t := (InternalFunction var_t fn_name_t (uaction reg_t ext_fn_t)). +Notation InternalFunction R Sigma sig tau := (InternalFunction var_t fn_name_t (action R Sigma sig tau)). Notation register_update_circuitry R Sigma := (register_update_circuitry _ R Sigma ContextEnv). diff --git a/coq/Interop.v b/coq/Interop.v index 8fb0490d..a51238de 100644 --- a/coq/Interop.v +++ b/coq/Interop.v @@ -25,6 +25,10 @@ Section Packages. Typically [string]. *) Context {var_t: Type}. + (** [fn_name_t]: The type of function names. + Typically [string]. **) + Context {fn_name_t: Type}. + (** [rule_name_t]: The type of rule names. Typically an inductive [rule1 | rule2 | …]. **) Context {rule_name_t: Type}. @@ -41,6 +45,8 @@ Section Packages. { (** [koika_var_names]: These names are used to generate readable code. *) koika_var_names: Show var_t; + (** [koika_fn_names]: These names are used to generate readable code. *) + koika_fn_names: Show fn_name_t; (** [koika_reg_names]: These names are used to generate readable code. *) koika_reg_names: Show reg_t; @@ -57,7 +63,7 @@ Section Packages. (** [koika_rules]: The rules of the program. **) koika_rules: forall _: rule_name_t, - TypedSyntax.rule pos_t var_t koika_reg_types koika_ext_fn_types; + TypedSyntax.rule pos_t var_t fn_name_t koika_reg_types koika_ext_fn_types; (** [koika_rule_external]: Whether a rule will be replaced by a native implementation. **) koika_rule_external: rule_name_t -> bool; @@ -186,7 +192,7 @@ Section TypeConv. End TypeConv. Section Helpers. - Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}. + Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}. Context {R: reg_t -> type}. Context {Sigma: ext_fn_t -> ExternalSignature}. @@ -206,7 +212,7 @@ Section Helpers. Context (opt: forall {sz}, circuit sz -> circuit sz). Definition compile_scheduler - (rules: rule_name_t -> rule pos_t var_t R Sigma) + (rules: rule_name_t -> rule pos_t var_t fn_name_t R Sigma) (external: rule_name_t -> bool) (s: scheduler pos_t rule_name_t) : register_update_circuitry rule_name_t CR CSigma _ := @@ -225,10 +231,10 @@ Section Helpers. End Helpers. Section Compilation. - Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}. + Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}. Definition compile_koika_package - (s: @koika_package_t pos_t var_t rule_name_t reg_t ext_fn_t) + (s: @koika_package_t pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t) (opt: let circuit sz := circuit (lower_R s.(koika_reg_types)) (lower_Sigma s.(koika_ext_fn_types)) sz in forall {sz}, circuit sz -> circuit sz) @@ -242,10 +248,11 @@ End Compilation. Record interop_package_t := { pos_t := unit; var_t := string; + fn_name_t := string; ip_reg_t : Type; ip_rule_name_t : Type; ip_ext_fn_t : Type; - ip_koika : @koika_package_t pos_t var_t ip_rule_name_t ip_reg_t ip_ext_fn_t; + ip_koika : @koika_package_t pos_t var_t fn_name_t ip_rule_name_t ip_reg_t ip_ext_fn_t; ip_verilog : @verilog_package_t ip_ext_fn_t; ip_sim : @sim_package_t ip_ext_fn_t }. @@ -253,8 +260,8 @@ Require Import Koika.ExtractionSetup. Module Backends. Section Backends. - Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}. - Notation koika_package_t := (@koika_package_t pos_t var_t rule_name_t reg_t ext_fn_t). + Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}. + Notation koika_package_t := (@koika_package_t pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t). Notation verilog_package_t := (@verilog_package_t ext_fn_t). Notation sim_package_t := (@sim_package_t ext_fn_t). diff --git a/coq/Lowering.v b/coq/Lowering.v index 1b0d690b..288bb2fb 100644 --- a/coq/Lowering.v +++ b/coq/Lowering.v @@ -1,20 +1,17 @@ (*! Language | Compilation from typed ASTs to lowered ASTs !*) Require Export Koika.Common Koika.Environments. -Require Import Koika.Syntax Koika.TypedSyntaxFunctions. -Require Koika.TypedSyntax Koika.LoweredSyntax. +Require Import Koika.Syntax Koika.TypedSyntaxFunctions Koika.SyntaxMacros. +Require Koika.SyntaxMacros Koika.TypedSyntax Koika.LoweredSyntax. Import PrimTyped CircuitSignatures. Section Lowering. - Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}. + Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}. Context {R: reg_t -> type}. Context {Sigma: ext_fn_t -> ExternalSignature}. Context {REnv: Env reg_t}. - Definition lsig_of_tsig (sig: tsig var_t) : lsig := - List.map (fun k_tau => type_sz (snd k_tau)) sig. - Definition lower_R idx := type_sz (R idx). Notation lR := lower_R. @@ -31,7 +28,7 @@ Section Lowering. : forall f, CSig_denote (lSigma f) := fun f => fun bs => bits_of_value (sigma f (value_of_bits bs)). - Notation typed_action := (TypedSyntax.action pos_t var_t R Sigma). + Notation typed_action := (TypedSyntax.action pos_t var_t fn_name_t R Sigma). Notation low_action := (LoweredSyntax.action pos_t var_t lR lSigma). Section Action. @@ -97,13 +94,13 @@ Section Lowering. match a with | TypedSyntax.Fail tau => LoweredSyntax.Fail (type_sz tau) - | @TypedSyntax.Var _ _ _ _ _ _ _ k _ m => + | @TypedSyntax.Var _ _ _ _ _ _ _ _ k _ m => LoweredSyntax.Var k (lower_member m) | TypedSyntax.Const cst => LoweredSyntax.Const (bits_of_value cst) | TypedSyntax.Seq r1 r2 => LoweredSyntax.Seq (l r1) (l r2) - | @TypedSyntax.Assign _ _ _ _ _ _ _ k _ m ex => + | @TypedSyntax.Assign _ _ _ _ _ _ _ _ k _ m ex => LoweredSyntax.Assign k (lower_member m) (l ex) | TypedSyntax.Bind var ex body => LoweredSyntax.Bind var (l ex) (l body) @@ -119,6 +116,8 @@ Section Lowering. lower_binop fn (l a1) (l a2) | TypedSyntax.ExternalCall fn a => LoweredSyntax.ExternalCall fn (l a) + | TypedSyntax.InternalCall fn args body => + SyntaxMacros.InternalCall (cmapv (fun _ a => l a) args) (l body) | TypedSyntax.APos p a => LoweredSyntax.APos p (l a) end. diff --git a/coq/Member.v b/coq/Member.v index fd7b4762..727e9d74 100644 --- a/coq/Member.v +++ b/coq/Member.v @@ -203,3 +203,16 @@ Proof. + exact (MemberHd k' (sig ++ infix ++ sig')). + exact (MemberTl k k' (sig ++ infix ++ sig') (mshift' _ infix sig sig' k m')). Defined. + +Fixpoint mshift_pair {K sig} (k: K) (p: {k': K & member k' sig}) + : {k': K & member k' (k :: sig)} := + let '(existT _ k' m) := p in + existT _ k' (MemberTl k' k _ m). + +Fixpoint all_members {K} (sig: list K): list { k: K & member k sig } := + match sig with + | [] => [] + | k :: sig => let ms := all_members sig in + let ms := List.map (mshift_pair k) ms in + (existT _ k (MemberHd k sig)) :: ms + end. diff --git a/coq/Parsing.v b/coq/Parsing.v index 0f5022e4..e1db4cd2 100644 --- a/coq/Parsing.v +++ b/coq/Parsing.v @@ -128,18 +128,18 @@ Notation "a '[' b ']'" := (UBinop (UBits2 USel) a b) (in custom koika at level 7 Notation "a '[' b ':+' c ']'" := (UBinop (UBits2 (UIndexedSlice c)) a b) (in custom koika at level 75, c constr at level 0, format "'[' a [ b ':+' c ] ']'"). Notation "'`' a '`'" := ( a) (in custom koika at level 99, a constr at level 99). -Notation "'fun' args ':' ret '=>' body" := - {| int_name := ""; +Notation "'fun' nm args ':' ret '=>' body" := + {| int_name := nm%string; int_argspec := args; int_retSig := ret; int_body := body |} - (in custom koika at level 99, args custom koika_types, ret constr at level 0, body custom koika at level 99, format "'[v' 'fun' args ':' ret '=>' '/' body ']'"). -Notation "'fun' '_' ':' ret '=>' body" := - {| int_name := ""; + (in custom koika at level 99, nm custom koika_var at level 0, args custom koika_types, ret constr at level 0, body custom koika at level 99, format "'[v' 'fun' nm args ':' ret '=>' '/' body ']'"). +Notation "'fun' nm '()' ':' ret '=>' body" := + {| int_name := nm%string; int_argspec := nil; int_retSig := ret; int_body := body |} - (in custom koika at level 99, ret constr at level 0, body custom koika at level 99, format "'[v' 'fun' '_' ':' ret '=>' '/' body ']'"). + (in custom koika at level 99, nm custom koika_var at level 0, ret constr at level 0, body custom koika at level 99, format "'[v' 'fun' nm '()' ':' ret '=>' '/' body ']'"). (* Deprecated *) Notation "'call' instance method args" := @@ -152,7 +152,7 @@ Notation "'extcall' method '(' arg ')'" := (UExternalCall method arg) (in custom koika at level 98, method constr at level 0, arg custom koika). Notation "'call0' instance method " := - (UCallModule instance id method nil) + (USugar (UCallModule instance id method nil)) (in custom koika at level 98, instance constr at level 0, method constr). Notation "'funcall0' method " := (UInternalCall method nil) @@ -242,10 +242,10 @@ Module Type Tests. (* Notation "'{&' a '&}'" := (a) (a custom koika_types at level 200). *) (* Definition test_21 := {& "yoyo" : bits_t 2 &}. *) (* Definition test_22 := {& "yoyo" : bits_t 2 , "yoyo" : bits_t 2 &}. *) - Definition test_23 : InternalFunction string string (uaction reg_t) := {{ fun (arg1 : (bits_t 3)) (arg2 : bits_t 2) : bits_t 4 => magic }}. - Definition test_24 : nat -> InternalFunction string string (uaction reg_t) := (fun sz => {{ fun (arg1 : bits_t sz) (arg1 : bits_t sz) : bits_t sz => magic}}). - Definition test_25 : nat -> InternalFunction string string (uaction reg_t) := (fun sz => {{fun (arg1 : bits_t sz ) : bits_t sz => let oo := magic >> magic in magic}}). - Definition test_26 : nat -> InternalFunction string string (uaction reg_t) := (fun sz => {{ fun _ : bits_t sz => magic }}). + Definition test_23 : InternalFunction string string (uaction reg_t) := {{ fun test (arg1 : (bits_t 3)) (arg2 : bits_t 2) : bits_t 4 => magic }}. + Definition test_24 : nat -> InternalFunction string string (uaction reg_t) := (fun sz => {{ fun test (arg1 : bits_t sz) (arg1 : bits_t sz) : bits_t sz => magic}}). + Definition test_25 : nat -> InternalFunction string string (uaction reg_t) := (fun sz => {{fun test (arg1 : bits_t sz ) : bits_t sz => let oo := magic >> magic in magic}}). + Definition test_26 : nat -> InternalFunction string string (uaction reg_t) := (fun sz => {{ fun test () : bits_t sz => magic }}). Definition test_27 : uaction reg_t := {{ (if (!read0(data0)) then diff --git a/coq/Std.v b/coq/Std.v index 0c95910f..80fcfdea 100644 --- a/coq/Std.v +++ b/coq/Std.v @@ -29,7 +29,7 @@ Module Fifo1 (f: Fifo). Definition enq : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (data : T) : bits_t 0 => + {{ fun enq (data : T) : bits_t 0 => if (!read1(valid0)) then write1(data0, data); write1(valid0, #Ob~1) @@ -38,7 +38,7 @@ Module Fifo1 (f: Fifo). Definition deq : UInternalFunction reg_t empty_ext_fn_t := - {{ fun _ : T => + {{ fun deq () : T => if (read0(valid0)) then write0(valid0, Ob~0) else @@ -74,7 +74,7 @@ Module Fifo1Bypass (f: Fifo). Definition enq : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (data : T) : bits_t 0 => + {{ fun enq (data : T) : bits_t 0 => if (!read0(valid0)) then write0(data0, data); write0(valid0, #Ob~1) @@ -83,7 +83,7 @@ Module Fifo1Bypass (f: Fifo). Definition deq : UInternalFunction reg_t empty_ext_fn_t := - {{ fun _ : T => + {{ fun deq () : T => if (read1(valid0)) then write1(valid0, Ob~0) else @@ -102,7 +102,7 @@ Definition Maybe tau := Notation maybe tau := (struct_t (Maybe tau)). Definition valid {reg_t fn} (tau:type) : UInternalFunction reg_t fn := - {{ fun (x: tau) : maybe tau => + {{ fun valid (x: tau) : maybe tau => struct (Maybe tau) {| valid := (#(Bits.of_nat 1 1)) ; data := x @@ -110,7 +110,7 @@ Definition valid {reg_t fn} (tau:type) : UInternalFunction reg_t fn := }}. Definition invalid {reg_t fn} (tau:type) : UInternalFunction reg_t fn := - {{ fun _ : maybe tau => + {{ fun invalid () : maybe tau => struct (Maybe tau) {| valid := (#(Bits.of_nat 1 0)) |} }}. @@ -142,22 +142,22 @@ Module RfPow2 (s: RfPow2_sig). end. Definition read_0 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t s.idx_sz) : s.T => + {{ fun read_0 (idx : bits_t s.idx_sz) : s.T => `UCompleteSwitch s.read_style s.idx_sz "idx" (fun idx => {{ read0(rData idx) }})` }}. Definition write_0 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t s.idx_sz) (val: s.T) : unit_t => + {{ fun write_0 (idx : bits_t s.idx_sz) (val: s.T) : unit_t => `UCompleteSwitch s.write_style s.idx_sz "idx" (fun idx => {{ write0(rData idx, val) }})` }}. Definition read_1 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t s.idx_sz) : s.T => + {{ fun read_1 (idx : bits_t s.idx_sz) : s.T => `UCompleteSwitch s.read_style s.idx_sz "idx" (fun idx => {{ read1(rData idx) }})` }}. Definition write_1 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t s.idx_sz) (val: s.T) : unit_t => + {{ fun write_1 (idx : bits_t s.idx_sz) (val: s.T) : unit_t => `UCompleteSwitch s.write_style s.idx_sz "idx" (fun idx => {{ write1(rData idx, val) }})` }}. End RfPow2. @@ -190,7 +190,7 @@ Module Rf (s: Rf_sig). end. Definition read : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t log_sz) : s.T => + {{ fun read (idx : bits_t log_sz) : s.T => `USugar (USwitch {{idx}} @@ -206,7 +206,7 @@ Module Rf (s: Rf_sig). (List.seq 0 sz))) ` }}. Definition write : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t log_sz) (val: s.T) : unit_t => + {{ fun write (idx : bits_t log_sz) (val: s.T) : unit_t => `USugar (USwitch {{idx}} @@ -223,7 +223,7 @@ Module Rf (s: Rf_sig). End Rf. Definition signExtend {reg_t} (n:nat) (m:nat) : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (arg : bits_t n) : bits_t (m+n) => sext(arg, m + n) }}. + {{ fun signExtend (arg : bits_t n) : bits_t (m+n) => sext(arg, m + n) }}. Module RfEhr (s: Rf_sig). @@ -248,7 +248,7 @@ Module RfEhr (s: Rf_sig). end. Definition read_0 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t log_sz) : s.T => + {{ fun read_0 (idx : bits_t log_sz) : s.T => `USugar (USwitch {{idx}} @@ -264,7 +264,7 @@ Module RfEhr (s: Rf_sig). (List.seq 0 sz))) ` }}. Definition read_1 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t log_sz) : s.T => + {{ fun read_1 (idx : bits_t log_sz) : s.T => `USugar (USwitch {{idx}} @@ -280,7 +280,7 @@ Module RfEhr (s: Rf_sig). (List.seq 0 sz))) ` }}. Definition write_0 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t log_sz) (val: s.T) : unit_t => + {{ fun write_0 (idx : bits_t log_sz) (val: s.T) : unit_t => `USugar (USwitch {{idx}} @@ -296,7 +296,7 @@ Module RfEhr (s: Rf_sig). (List.seq 0 sz))) ` }}. Definition write_1 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (idx : bits_t log_sz) (val: s.T) : unit_t => + {{ fun write_1 (idx : bits_t log_sz) (val: s.T) : unit_t => `USugar (USwitch {{idx}} diff --git a/coq/Syntax.v b/coq/Syntax.v index 29d4b62c..7b979f22 100644 --- a/coq/Syntax.v +++ b/coq/Syntax.v @@ -18,7 +18,7 @@ Section Syntax. | UUnop (ufn1: PrimUntyped.ufn1) (arg1: uaction) | UBinop (ufn2: PrimUntyped.ufn2) (arg1: uaction) (arg2: uaction) | UExternalCall (ufn: ext_fn_t) (arg: uaction) - | UInternalCall (ufn: InternalFunction fn_name_t var_t uaction) (args: list uaction) + | UInternalCall (ufn: InternalFunction var_t fn_name_t uaction) (args: list uaction) | UAPos (p: pos_t) (e: uaction) | USugar (s: usugar) with usugar {reg_t ext_fn_t} := @@ -38,7 +38,7 @@ Section Syntax. | UCallModule {module_reg_t module_ext_fn_t: Type} (fR: module_reg_t -> reg_t) (fSigma: module_ext_fn_t -> ext_fn_t) - (fn: InternalFunction fn_name_t var_t (@uaction module_reg_t module_ext_fn_t)) + (fn: InternalFunction var_t fn_name_t (@uaction module_reg_t module_ext_fn_t)) (args: list uaction). Inductive scheduler := diff --git a/coq/SyntaxFunctions.v b/coq/SyntaxFunctions.v index 49777432..f1efbbd9 100644 --- a/coq/SyntaxFunctions.v +++ b/coq/SyntaxFunctions.v @@ -39,13 +39,9 @@ Section SyntaxFunctions. | UBinop ufn2 arg1 arg2 => UBinop ufn2 (r 0 arg1) (r 1 arg2) | UExternalCall ufn arg => UExternalCall ufn (r 0 arg) - | UInternalCall ufn args => - let ufn := - {| int_name := ufn.(int_name); - int_argspec := ufn.(int_argspec); - int_retSig := ufn.(int_retSig); - int_body := r 0 ufn.(int_body) |} in - let args := snd (foldi (fun n a args => (r n a :: args)) 1 [] args) in + | UInternalCall ufn arg => + let ufn := map_intf_body (r 0) ufn in + let args := snd (foldi (fun n a args => (r n a :: args)) 1 [] arg) in UInternalCall ufn args | UAPos _ e => (r 0 e) | USugar s => USugar (reposition_sugar p s) @@ -82,11 +78,7 @@ Section SyntaxFunctions. foldi (fun n a elements => (r n a) :: elements) 0 [] elements in UArrayInit tau elements | UCallModule fR fSigma ufn args => - let ufn := - {| int_name := ufn.(int_name); - int_argspec := ufn.(int_argspec); - int_retSig := ufn.(int_retSig); - int_body := r 0 ufn.(int_body) |} in + let ufn := map_intf_body (r 0) ufn in let args := snd (foldi (fun n a args => (r n a :: args)) 1 [] args) in UCallModule fR fSigma ufn args end. @@ -166,10 +158,7 @@ Section SyntaxFunctions. let ufn := if fbody then (* Only unfold the body if the error is in it *) - {| int_name := ufn.(int_name); - int_argspec := ufn.(int_argspec); - int_retSig := ufn.(int_retSig); - int_body := body |} + map_intf_body (fun _ => body) ufn else ufn in let '(n, (fargs, args)) := foldi (fun n arg '(fargs, args) => @@ -241,12 +230,8 @@ Section SyntaxFunctions. | UCallModule fR fSigma ufn args => let '(fbody, body) := pe 0 ufn.(int_body) in let ufn := - if fbody then - (* Only unfold the body if the error is in it *) - {| int_name := ufn.(int_name); - int_argspec := ufn.(int_argspec); - int_retSig := ufn.(int_retSig); - int_body := body |} + if fbody then (* Only unfold the body if the error is in it *) + map_intf_body (fun _ => body) ufn else ufn in let '(n, (fargs, args)) := foldi (fun n arg '(fargs, args) => diff --git a/coq/SyntaxMacros.v b/coq/SyntaxMacros.v index d4fe7fad..e15d42d5 100644 --- a/coq/SyntaxMacros.v +++ b/coq/SyntaxMacros.v @@ -1,5 +1,5 @@ (*! Frontend | Macros used in untyped programs !*) -Require Import Koika.Common Koika.Syntax Koika.TypedSyntax Koika.Primitives. +Require Import Koika.Common Koika.Types Koika.Syntax Koika.TypedSyntax Koika.TypedSyntax Koika.Primitives. Import PrimUntyped. Section SyntaxMacros. @@ -33,6 +33,11 @@ Section SyntaxMacros. let zeroes := UConst (tau := bits_t _) (Bits.zeroes (type_sz tau)) in UUnop (UConv (UUnpack tau)) zeroes. + Definition ustruct_init (sig: struct_sig) (fields: list (string * uaction)) : uaction := + let empty := SyntaxMacros.uinit (struct_t sig) in + let usubst f := UBinop (UStruct2 (USubstField f)) in + List.fold_left (fun acc '(f, a) => (usubst f) acc a) fields empty. + Fixpoint uswitch (var: uaction) (default: uaction) (branches: list (uaction * uaction)) : uaction := match branches with @@ -146,10 +151,10 @@ Module Display. | Str (s: string) | Value (tau: type). - Notation intfun := (InternalFunction fn_name_t var_t uaction). + Notation intfun := (InternalFunction var_t fn_name_t uaction). - Definition empty_printer : InternalFunction fn_name_t var_t uaction := - {| int_name := ""; + Definition empty_printer : InternalFunction var_t fn_name_t uaction := + {| int_name := "print"; int_argspec := []; int_retSig := unit_t; int_body := USugar USkip |}. @@ -157,8 +162,8 @@ Module Display. Definition display_utf8 s : uaction := UUnop (UDisplay (UDisplayUtf8)) (USugar (UConstString s)). - Definition nl_printer : InternalFunction fn_name_t var_t uaction := - {| int_name := ""; + Definition nl_printer : InternalFunction var_t fn_name_t uaction := + {| int_name := "print_nl"; int_argspec := []; int_retSig := unit_t; int_body := display_utf8 "\n" |}. @@ -198,23 +203,25 @@ Module Display. End Display. End Display. -Section TypedSyntaxMacros. - Context {pos_t var_t reg_t ext_fn_t: Type}. - Context {R: reg_t -> type} - {Sigma: ext_fn_t -> ExternalSignature}. +Require Import Koika.LoweredSyntax. + +Section LoweredSyntaxMacros. + Context {pos_t var_t fn_name_t reg_t ext_fn_t: Type}. + Context {CR: reg_t -> nat} + {CSigma: ext_fn_t -> CExternalSignature}. - Notation action := (action pos_t var_t R Sigma). + Notation action := (action pos_t var_t CR CSigma). - Fixpoint infix_action (infix: tsig var_t) {sig sig': tsig var_t} {tau} (a: action (sig ++ sig') tau) + Fixpoint infix_action (infix: lsig) {sig sig': lsig} {tau} (a: action (sig ++ sig') tau) : action (sig ++ infix ++ sig') tau. Proof. remember (sig ++ sig'); destruct a; subst. - - exact (Fail tau). - - exact (Var (mshift' infix m)). + - exact (Fail sz). + - exact (Var k (mshift' infix m)). - exact (Const cst). - - exact (Assign (mshift' infix m) (infix_action infix _ _ _ a)). + - exact (Assign k (mshift' infix m) (infix_action infix _ _ _ a)). - exact (Seq (infix_action infix _ _ _ a1) (infix_action infix _ _ _ a2)). - - exact (Bind var (infix_action infix _ _ _ a1) (infix_action infix (_ :: sig) sig' _ a2)). + - exact (Bind k (infix_action infix _ _ _ a1) (infix_action infix (_ :: sig) sig' _ a2)). - exact (If (infix_action infix _ _ _ a1) (infix_action infix _ _ _ a2) (infix_action infix _ _ _ a3)). - exact (Read port idx). - exact (Write port idx (infix_action infix _ _ _ a)). @@ -224,44 +231,47 @@ Section TypedSyntaxMacros. - exact (infix_action infix _ _ _ a). Defined. - Definition prefix_action (prefix: tsig var_t) {sig: tsig var_t} {tau} (a: action sig tau) - : action (prefix ++ sig) tau := + Definition prefix_action (prefix: lsig) {sig: lsig} {sz} (a: action sig sz) + : action (prefix ++ sig) sz := infix_action prefix (sig := []) a. Fixpoint suffix_action_eqn {A} (l: list A) {struct l}: l ++ [] = l. Proof. destruct l; cbn; [ | f_equal ]; eauto. Defined. - Definition suffix_action (suffix: tsig var_t) {sig: tsig var_t} {tau} (a: action sig tau) - : action (sig ++ suffix) tau. + Definition suffix_action (suffix: lsig) {sig: lsig} {sz} (a: action sig sz) + : action (sig ++ suffix) sz. Proof. rewrite <- (suffix_action_eqn suffix); apply infix_action; rewrite (suffix_action_eqn sig); exact a. Defined. + Definition lsig_of_tsig (sig: tsig var_t) : lsig := + List.map (fun k_tau => type_sz (snd k_tau)) sig. + Fixpoint InternalCall' - {tau: type} - (sig: tsig var_t) + {sz: nat} + (sig: lsig) (fn_sig: tsig var_t) - (fn_body: action (fn_sig ++ sig) tau) - (args: context (fun '(_, tau) => action sig tau) fn_sig) - : action sig tau := - match fn_sig return action (fn_sig ++ sig) tau -> - context (fun '(_, tau) => action sig tau) fn_sig -> - action sig tau with + (args: context (fun k_tau => action sig (type_sz (snd k_tau))) fn_sig) + (fn_body: action (lsig_of_tsig fn_sig ++ sig) sz) + : action sig sz := + match fn_sig return context (fun k_tau => action sig (type_sz (snd k_tau))) fn_sig -> + action ((lsig_of_tsig fn_sig) ++ sig) sz -> + action sig sz with | [] => - fun fn_body _ => + fun _ fn_body => fn_body | (k, tau) :: fn_sig => - fun fn_body args => + fun args fn_body => InternalCall' sig fn_sig - (Bind k (prefix_action fn_sig (chd args)) fn_body) (ctl args) - end fn_body args. + (Bind k (prefix_action (lsig_of_tsig fn_sig) (chd args)) fn_body) + end args fn_body. Fixpoint InternalCall - {tau: type} - (sig: tsig var_t) - (fn_sig: tsig var_t) - (fn_body: action fn_sig tau) - (args: context (fun '(_, tau) => action sig tau) fn_sig) - : action sig tau := - InternalCall' sig fn_sig (suffix_action sig fn_body) args. -End TypedSyntaxMacros. + {sz: nat} + {sig: lsig} + {fn_sig: tsig var_t} + (args: context (fun k_tau => action sig (type_sz (snd k_tau))) fn_sig) + (fn_body: action (lsig_of_tsig fn_sig) sz) + : action sig sz := + InternalCall' sig fn_sig args (suffix_action sig fn_body). +End LoweredSyntaxMacros. diff --git a/coq/TypeInference.v b/coq/TypeInference.v index 7803f153..ba4b835e 100644 --- a/coq/TypeInference.v +++ b/coq/TypeInference.v @@ -10,7 +10,7 @@ Section ErrorReporting. {Sigma: ext_fn_t -> ExternalSignature}. Definition lift_basic_error_message - (pos: pos_t) {sig tau} (e: action pos_t var_t R Sigma sig tau) + (pos: pos_t) {sig tau} (e: action pos_t var_t fn_name_t R Sigma sig tau) (err: basic_error_message) : error pos_t var_t fn_name_t := {| epos := pos; emsg := BasicError err; @@ -18,15 +18,15 @@ Section ErrorReporting. Definition lift_fn1_tc_result {A sig tau} - pos (e: action pos_t var_t R Sigma sig tau) + pos (e: action pos_t var_t fn_name_t R Sigma sig tau) (r: result A fn_tc_error) : result A (error pos_t var_t fn_name_t) := result_map_failure (fun '(side, err) => lift_basic_error_message pos e err) r. Definition lift_fn2_tc_result {A sig1 tau1 sig2 tau2} - pos1 (e1: action pos_t var_t R Sigma sig1 tau1) - pos2 (e2: action pos_t var_t R Sigma sig2 tau2) + pos1 (e1: action pos_t var_t fn_name_t R Sigma sig1 tau1) + pos2 (e2: action pos_t var_t fn_name_t R Sigma sig2 tau2) (r: result A fn_tc_error) : result A (error pos_t var_t fn_name_t) := result_map_failure (fun '(side, err) => @@ -38,6 +38,7 @@ End ErrorReporting. Section TypeInference. Context {pos_t rule_name_t fn_name_t var_t reg_t ext_fn_t: Type}. + Context {var_t_eq_dec: EqDec var_t}. Context (R: reg_t -> type). @@ -46,8 +47,8 @@ Section TypeInference. Notation usugar := (usugar pos_t var_t fn_name_t). Notation uaction := (uaction pos_t var_t fn_name_t). - Notation action := (action pos_t var_t R Sigma). - Notation rule := (rule pos_t var_t R Sigma). + Notation action := (action pos_t var_t fn_name_t R Sigma). + Notation rule := (rule pos_t var_t fn_name_t R Sigma). Notation scheduler := (scheduler pos_t rule_name_t). Section Action. @@ -87,7 +88,7 @@ Section TypeInference. Fixpoint assert_argtypes' {T} {sig} (src: T) nexpected (fn_name: fn_name_t) pos (args_desc: tsig var_t) (args: list (pos_t * {tau : type & action sig tau})) - : result (context (K := (var_t * type)) (fun '(_, tau) => action sig tau) args_desc) := + : result (context (K := (var_t * type)) (fun k_tau => action sig (snd k_tau)) args_desc) := match args_desc, args with | [], [] => Success CtxEmpty | [], _ => Failure (mkerror pos (TooManyArguments fn_name nexpected (List.length args)) src) @@ -103,7 +104,7 @@ Section TypeInference. (fn_name: fn_name_t) pos (args_desc: tsig var_t) (args: list (pos_t * {tau : type & action sig tau})) - : result (context (K := (var_t * type)) (fun '(_, tau) => action sig tau) args_desc) := + : result (context (K := (var_t * type)) (fun k_tau => action sig (snd k_tau)) args_desc) := assert_argtypes' src (List.length args_desc) fn_name pos args_desc args. Fixpoint type_action @@ -148,10 +149,12 @@ Section TypeInference. let/res tc_args := result_list_map (type_action pos sig) args in let arg_positions := List.map (actpos pos) args in let tc_args_w_pos := List.combine arg_positions tc_args in - let/res arg_ctx := assert_argtypes e fn.(int_name) pos fn.(int_argspec) tc_args_w_pos in + let/res args_ctx := assert_argtypes e fn.(int_name) pos fn.(int_argspec) tc_args_w_pos in + let/res fn_body' := type_action (actpos pos fn.(int_body)) fn.(int_argspec) fn.(int_body) in let/res fn_body' := cast_action (actpos pos fn.(int_body)) fn.(int_retSig) (``fn_body') in - Success (EX (InternalCall sig fn.(int_argspec) fn_body' arg_ctx)) + + Success (EX (TypedSyntax.InternalCall fn.(int_name) args_ctx fn_body')) | UUnop fn arg1 => let pos1 := actpos pos arg1 in let/res arg1' := type_action pos sig arg1 in diff --git a/coq/TypedSemantics.v b/coq/TypedSemantics.v index 9a47347e..d1cdbd32 100644 --- a/coq/TypedSemantics.v +++ b/coq/TypedSemantics.v @@ -2,7 +2,7 @@ Require Export Koika.Common Koika.Environments Koika.Vect Koika.Logs Koika.Syntax Koika.TypedSyntax. Section Interp. - Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}. + Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}. Context {reg_t_eq_dec: EqDec reg_t}. Context {R: reg_t -> type}. @@ -14,14 +14,42 @@ Section Interp. Notation Log := (Log R REnv). - Notation rule := (rule pos_t var_t R Sigma). - Notation action := (action pos_t var_t R Sigma). + Notation rule := (rule pos_t var_t fn_name_t R Sigma). + Notation action := (action pos_t var_t fn_name_t R Sigma). Notation scheduler := (scheduler pos_t rule_name_t). Definition tcontext (sig: tsig var_t) := context (fun k_tau => type_denote (snd k_tau)) sig. + Definition acontext (sig argspec: tsig var_t) := + context (fun k_tau => action sig (snd k_tau)) argspec. + Section Action. + Section Args. + Context (interp_action: + forall {sig: tsig var_t} {tau} + (Gamma: tcontext sig) + (sched_log: Log) (action_log: Log) + (a: action sig tau), + option (Log * type_denote tau * (tcontext sig))). + + Fixpoint interp_args + {sig: tsig var_t} + (Gamma: tcontext sig) + (sched_log: Log) + (action_log: Log) + {argspec: tsig var_t} + (args: acontext sig argspec) + : option (Log * tcontext argspec * (tcontext sig)) := + match args with + | CtxEmpty => Some (action_log, CtxEmpty, Gamma) + | @CtxCons _ _ argspec k_tau arg args => + let/opt3 action_log, ctx, Gamma := interp_args Gamma sched_log action_log args in + let/opt3 action_log, v, Gamma := interp_action _ _ Gamma sched_log action_log arg in + Some (action_log, CtxCons k_tau v ctx, Gamma) + end. + End Args. + Fixpoint interp_action {sig: tsig var_t} {tau} @@ -29,8 +57,9 @@ Section Interp. (sched_log: Log) (action_log: Log) (a: action sig tau) + {struct a} : option (Log * tau * (tcontext sig)) := - match a in TypedSyntax.action _ _ _ _ ts tau return (tcontext ts -> option (Log * tau * (tcontext ts))) with + match a in TypedSyntax.action _ _ _ _ _ ts tau return (tcontext ts -> option (Log * tau * (tcontext ts))) with | Fail tau => fun _ => None | Var m => fun Gamma => @@ -40,10 +69,10 @@ Section Interp. | Seq r1 r2 => fun Gamma => let/opt3 action_log, _, Gamma := interp_action Gamma sched_log action_log r1 in interp_action Gamma sched_log action_log r2 - | @Assign _ _ _ _ _ _ _ k tau m ex => fun Gamma => + | @Assign _ _ _ _ _ _ _ _ k tau m ex => fun Gamma => let/opt3 action_log, v, Gamma := interp_action Gamma sched_log action_log ex in Some (action_log, Ob, creplace m v Gamma) - | @Bind _ _ _ _ _ _ sig tau tau' var ex body => fun (Gamma : tcontext sig) => + | @Bind _ _ _ _ _ _ _ sig tau tau' var ex body => fun (Gamma : tcontext sig) => let/opt3 action_log1, v, Gamma := interp_action Gamma sched_log action_log ex in let/opt3 action_log2, v, Gamma := interp_action (CtxCons (var, tau) v Gamma) sched_log action_log1 body in Some (action_log2, v, ctl Gamma) @@ -80,6 +109,10 @@ Section Interp. | ExternalCall fn arg1 => fun Gamma => let/opt3 action_log, arg1, Gamma := interp_action Gamma sched_log action_log arg1 in Some (action_log, sigma fn arg1, Gamma) + | InternalCall name args body => fun Gamma => + let/opt3 action_log, results, Gamma := interp_args (@interp_action) Gamma sched_log action_log args in + let/opt3 action_log, v, _ := interp_action results sched_log action_log body in + Some (action_log, v, Gamma) | APos _ a => fun Gamma => interp_action Gamma sched_log action_log a end Gamma. diff --git a/coq/TypedSyntax.v b/coq/TypedSyntax.v index 6d1654f9..4042c104 100644 --- a/coq/TypedSyntax.v +++ b/coq/TypedSyntax.v @@ -4,7 +4,7 @@ Require Export Koika.Common Koika.Environments Koika.Types Koika.Primitives. Import PrimTyped PrimSignatures. Section Syntax. - Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}. + Context {pos_t var_t rule_name_t fn_name_t reg_t ext_fn_t: Type}. Context {R: reg_t -> type}. Context {Sigma: ext_fn_t -> ExternalSignature}. @@ -45,11 +45,17 @@ Section Syntax. (fn: ext_fn_t) (arg: action sig (Sigma fn).(arg1Sig)) : action sig (Sigma fn).(retSig) + | InternalCall {sig tau} + (fn : fn_name_t) + {argspec : tsig var_t} + (args: context (fun k_tau => action sig (snd k_tau)) argspec) + (body : action argspec tau) + : action sig tau | APos {sig tau} (pos: pos_t) (a: action sig tau) : action sig tau. Definition rule := action nil unit_t. End Syntax. -Arguments rule pos_t var_t {reg_t ext_fn_t} R Sigma : assert. -Arguments action pos_t var_t {reg_t ext_fn_t} R Sigma sig tau : assert. +Arguments action pos_t var_t fn_name_t {reg_t ext_fn_t} R Sigma sig tau : assert. +Arguments rule pos_t var_t fn_name_t {reg_t ext_fn_t} R Sigma : assert. diff --git a/coq/TypedSyntaxFunctions.v b/coq/TypedSyntaxFunctions.v index 9f9ad490..1dd3e8b2 100644 --- a/coq/TypedSyntaxFunctions.v +++ b/coq/TypedSyntaxFunctions.v @@ -2,13 +2,13 @@ Require Import Koika.Member Koika.TypedSyntax Koika.Primitives Koika.TypedSemantics. Section TypedSyntaxFunctions. - Context {pos_t var_t rule_name_t reg_t ext_fn_t: Type}. + Context {pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t: Type}. Context {R: reg_t -> type} {Sigma: ext_fn_t -> ExternalSignature}. Context {REnv : Env reg_t}. - Notation rule := (rule pos_t var_t R Sigma). - Notation action := (action pos_t var_t R Sigma). + Notation rule := (rule pos_t var_t fn_name_t R Sigma). + Notation action := (action pos_t var_t fn_name_t R Sigma). Notation scheduler := (scheduler pos_t rule_name_t). Fixpoint scheduler_rules (s: scheduler) := @@ -20,7 +20,7 @@ Section TypedSyntaxFunctions. end. Fixpoint unannot {sig tau} (a: action sig tau) := - match a in TypedSyntax.action _ _ _ _ sig tau return action sig tau with + match a in TypedSyntax.action _ _ _ _ _ sig tau return action sig tau with | APos _ a => unannot a | a => a end. @@ -32,7 +32,7 @@ Section TypedSyntaxFunctions. Section Footprint. Notation footprint_t := (list (reg_t * event_t)). - Fixpoint action_footprint' {sig tau} (acc: footprint_t) (a: action sig tau) := + Fixpoint action_footprint' {sig tau} (acc: footprint_t) (a: action sig tau) {struct a} := match a with | Fail _ | Var _ | Const _ => acc | Assign m ex => action_footprint' acc ex @@ -44,6 +44,9 @@ Section TypedSyntaxFunctions. | Unop fn arg1 => action_footprint' acc arg1 | Binop fn arg1 arg2 => action_footprint' (action_footprint' acc arg1) arg2 | ExternalCall fn arg => action_footprint' acc arg + | InternalCall fn args body => + let acc := cfoldl (fun _ arg acc => action_footprint' acc arg) args acc in + action_footprint' acc body | APos _ a => action_footprint' acc a end. @@ -152,7 +155,7 @@ Section TypedSyntaxFunctions. | HistoryAnnot (rh: reg_history_map). Notation annotated_action sig tau := - (TypedSyntax.action register_annotation var_t R Sigma sig tau). + (TypedSyntax.action register_annotation var_t fn_name_t R Sigma sig tau). Definition join_tribools t1 t2 := match t1, t2 with @@ -252,6 +255,16 @@ Section TypedSyntaxFunctions. | ExternalCall fn arg => let '(env, arg) := annotate_action_register_histories env arg in (env, ExternalCall fn arg) + | InternalCall fn args body => + let '(env, args) := + cfoldr (fun sg k arg cont => + fun env => + let '(env, arg) := annotate_action_register_histories env arg in + let '(env, args) := cont env in + (env, CtxCons k arg args)) args + (fun env => (env, CtxEmpty)) env in + let '(env, body) := annotate_action_register_histories env body in + (env, InternalCall fn args body) | APos pos a => let '(env, a) := annotate_action_register_histories env a in (env, APos (PosAnnot pos) a) @@ -340,7 +353,7 @@ Section TypedSyntaxFunctions. Context (s: scheduler). Definition annotated_rule := - TypedSyntax.action register_annotation var_t R Sigma [] unit_t. + TypedSyntax.action register_annotation var_t fn_name_t R Sigma [] unit_t. Definition compute_register_histories : RLEnv.(env_t) (fun _ => (reg_history_map * annotated_rule)%type) * @@ -377,6 +390,9 @@ Section TypedSyntaxFunctions. | Unop fn arg1 => rule_max_log_size arg1 | Binop fn arg1 arg2 => rule_max_log_size arg1 + rule_max_log_size arg2 | ExternalCall fn arg => rule_max_log_size arg + | InternalCall fn args body => + cfoldl (fun k arg acc => acc + rule_max_log_size arg) args 0 + + rule_max_log_size body | APos pos a => rule_max_log_size a end. End StaticAnalysis. @@ -399,6 +415,9 @@ Section TypedSyntaxFunctions. | Unop fn a => existsb_subterm f a | Binop fn a1 a2 => existsb_subterm f a1 || existsb_subterm f a2 | ExternalCall fn arg => existsb_subterm f arg + | InternalCall fn args body => + cfoldl (fun k arg acc => acc || existsb_subterm f arg) args false + || existsb_subterm f body | APos _ a => existsb_subterm f a end. @@ -417,7 +436,7 @@ Section TypedSyntaxFunctions. Fixpoint action_mentions_var {EQ: EqDec var_t} {sig tau} (k: var_t) (a: action sig tau) := existsb_subterm (fun a => match a with - | AnyAction (@Var _ _ _ _ _ _ _ k' _ m) => beq_dec k k' + | AnyAction (@Var _ _ _ _ _ _ _ _ k' _ m) => beq_dec k k' | _ => false end) a. @@ -435,6 +454,9 @@ Section TypedSyntaxFunctions. | Unop fn arg1 => is_pure arg1 | Binop fn arg1 arg2 => is_pure arg1 && is_pure arg2 | ExternalCall fn arg => false + | InternalCall fn args body => + cfoldl (fun k arg acc => acc && is_pure arg) args true + && is_pure body | APos pos a => is_pure a end. @@ -452,24 +474,26 @@ Section TypedSyntaxFunctions. | Unop fn arg1 => false | Binop fn arg1 arg2 => false | ExternalCall fn arg => false + | InternalCall fn args body => returns_zero body | APos pos a => returns_zero a end. Definition action_type {sig tau} (a: action sig tau) : option type := match a with - | @Fail _ _ _ _ _ _ _ tau => Some tau - | @Var _ _ _ _ _ _ _ _ tau _ => Some tau - | @Const _ _ _ _ _ _ _ tau cst => Some tau - | @Assign _ _ _ _ _ _ _ _ _ _ _ => Some unit_t - | @Seq _ _ _ _ _ _ _ tau _ _ => Some tau - | @Bind _ _ _ _ _ _ _ _ tau' _ _ _ => Some tau' - | @If _ _ _ _ _ _ _ tau _ _ _ => Some tau - | @Read _ _ _ _ _ _ _ _ _ => None - | @Write _ _ _ _ _ _ _ _ _ _ => Some unit_t - | @Unop _ _ _ _ _ _ _ fn _ => Some (PrimSignatures.Sigma1 fn).(retSig) - | @Binop _ _ _ _ _ _ _ fn _ _ => Some (PrimSignatures.Sigma2 fn).(retSig) - | @ExternalCall _ _ _ _ _ _ _ _ _ => None - | @APos _ _ _ _ _ _ _ tau _ _ => Some tau + | @Fail _ _ _ _ _ _ _ _ tau => Some tau + | @Var _ _ _ _ _ _ _ _ _ tau _ => Some tau + | @Const _ _ _ _ _ _ _ _ tau cst => Some tau + | @Assign _ _ _ _ _ _ _ _ _ _ _ _ => Some unit_t + | @Seq _ _ _ _ _ _ _ _ tau _ _ => Some tau + | @Bind _ _ _ _ _ _ _ _ _ tau' _ _ _ => Some tau' + | @If _ _ _ _ _ _ _ _ tau _ _ _ => Some tau + | @Read _ _ _ _ _ _ _ _ _ _ => None + | @Write _ _ _ _ _ _ _ _ _ _ _ => Some unit_t + | @Unop _ _ _ _ _ _ _ _ fn _ => Some (PrimSignatures.Sigma1 fn).(retSig) + | @Binop _ _ _ _ _ _ _ _ fn _ _ => Some (PrimSignatures.Sigma2 fn).(retSig) + | @ExternalCall _ _ _ _ _ _ _ _ _ _ => None + | @InternalCall _ _ _ _ _ _ _ _ tau _ _ _ _ => Some tau + | @APos _ _ _ _ _ _ _ _ tau _ _ => Some tau end. Definition is_tt {sig tau} (a: action sig tau) := @@ -494,6 +518,7 @@ Section TypedSyntaxFunctions. let/opt r2 := interp_arithmetic arg2 in Some (PrimSpecs.sigma2 fn r1 r2) | ExternalCall fn arg => None + | InternalCall fn args body => None | APos pos a => interp_arithmetic a end. End TypedSyntaxFunctions. diff --git a/coq/Types.v b/coq/Types.v index 2f9191f7..bcc078f5 100644 --- a/coq/Types.v +++ b/coq/Types.v @@ -102,7 +102,7 @@ Inductive Port := (** * Denotations *) Definition struct_denote' (type_denote: type -> Type) (fields: list (string * type)) := - List.fold_right (fun '(_, tau) acc => type_denote tau * acc)%type unit fields. + List.fold_right (fun k_tau acc => type_denote (snd k_tau) * acc)%type unit fields. Fixpoint type_denote tau : Type := match tau with @@ -289,18 +289,26 @@ Definition CExternalSignature := CSig 1. Definition tsig var_t := list (var_t * type). Definition lsig := list nat. -Record InternalFunction {fn_name_t var_t action: Type} := +Record InternalFunction {var_t fn_name_t action: Type} := { int_name : fn_name_t; int_argspec : tsig var_t; int_retSig : type; int_body : action }. Arguments InternalFunction : clear implicits. -Arguments Build_InternalFunction {fn_name_t var_t action} +Arguments Build_InternalFunction {var_t fn_name_t action} int_name int_argspec int_retSig int_body : assert. +Definition map_intf_body {var_t fn_name_t action action': Type} + (f: action -> action') (fn: InternalFunction var_t fn_name_t action) := + {| int_name := fn.(int_name); + int_argspec := fn.(int_argspec); + int_retSig := fn.(int_retSig); + int_body := f fn.(int_body) |}. + Record arg_sig {var_t} := { arg_name: var_t; arg_type: type }. +Arguments arg_sig : clear implicits. Definition prod_of_argsig {var_t} (a: @arg_sig var_t) := (a.(arg_name), a.(arg_type)). diff --git a/examples/collatz.v b/examples/collatz.v index b3a994b1..2c5bbee7 100644 --- a/examples/collatz.v +++ b/examples/collatz.v @@ -19,7 +19,7 @@ Module Collatz. end. Definition times_three : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (bs: bits_t 16) : bits_t 16 => + {{ fun times_three (bs: bits_t 16) : bits_t 16 => (bs << Ob~1) + bs }}. Definition _divide : uaction reg_t empty_ext_fn_t := diff --git a/examples/conflicts_modular.v b/examples/conflicts_modular.v index e85f2cde..bd039c5f 100644 --- a/examples/conflicts_modular.v +++ b/examples/conflicts_modular.v @@ -10,15 +10,15 @@ Module Import Queue32. end. Definition dequeue0: UInternalFunction reg_t empty_ext_fn_t := - {{ fun _ : bits_t 32 => + {{ fun dequeue0 () : bits_t 32 => guard(!read0(empty)); write0(empty, Ob~1); read0(data) }}. Definition enqueue0: UInternalFunction reg_t empty_ext_fn_t := - {{ fun (val: bits_t 32) : unit_t => + {{ fun enqueue0 (val: bits_t 32) : unit_t => guard(read0(empty)); write0(empty, Ob~0); write0(data, val) }}. Definition dequeue1: UInternalFunction reg_t empty_ext_fn_t := - {{ fun _ : bits_t 32 => + {{ fun dequeue1 () : bits_t 32 => guard(!read1(empty)); write1(empty, Ob~1); read1(data) }}. End Queue32. diff --git a/examples/function_call.v b/examples/function_call.v index 7533263f..18e75708 100644 --- a/examples/function_call.v +++ b/examples/function_call.v @@ -34,7 +34,7 @@ Definition Sigma (fn: ext_fn_t) : ExternalSignature := end. Definition nth_instr_intfun : UInternalFunction reg_t ext_fn_t := - {{ fun (addr: bits_t 3) : bits_t 32 => + {{ fun nth_instr_intfun (addr: bits_t 3) : bits_t 32 => `UCompleteSwitch NestedSwitch 3 "addr" (List_nth instructions)` }}. @@ -47,7 +47,7 @@ Definition _fetch_external : uaction reg_t ext_fn_t := write1(next_instr, extcall nth_instr_external(addr)) }}. Definition plus4 : UInternalFunction reg_t ext_fn_t := - {{ fun (v: bits_t 5) : bits_t 5 => + {{ fun plus4 (v: bits_t 5) : bits_t 5 => v + |5`d4| }}. Definition _incr_pc : uaction reg_t ext_fn_t := diff --git a/examples/method_call.v b/examples/method_call.v index c8bf95ae..29c163db 100644 --- a/examples/method_call.v +++ b/examples/method_call.v @@ -12,7 +12,7 @@ Module Delay. (* Declaration of a Koika function in a module, called a method *) Definition swap tau: UInternalFunction reg_t ext_fn_t := {{ - fun (arg1 : tau) : tau => + fun swap (arg1 : tau) : tau => write0(buffer, arg1); read0(buffer) }}. @@ -43,7 +43,7 @@ Definition r reg : R reg := (* Declaration of a family of Koika function indexed by a coq integer *) Definition nor (sz: nat) : UInternalFunction reg_t ext_fn_t := {{ - fun (arg1 : bits_t sz) (arg2 : bits_t sz) : bits_t sz => + fun nor (arg1 : bits_t sz) (arg2 : bits_t sz) : bits_t sz => !(arg1 || arg2) }}. diff --git a/examples/rv/RVCore.v b/examples/rv/RVCore.v index 2ea37c6c..f4279f2c 100644 --- a/examples/rv/RVCore.v +++ b/examples/rv/RVCore.v @@ -50,7 +50,7 @@ Section RV32IHelpers. Definition getFields : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (inst : bits_t 32) : struct_t inst_field => + fun getFields (inst : bits_t 32) : struct_t inst_field => let res := struct inst_field {| opcode := inst[|5`d0| :+ 7]; @@ -84,7 +84,7 @@ Section RV32IHelpers. Definition isLegalInstruction : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (inst : bits_t 32) : bits_t 1 => + fun isLegalInstruction (inst : bits_t 32) : bits_t 1 => let fields := getFields (inst) in match get(fields, opcode) with | #opcode_LOAD => @@ -166,7 +166,7 @@ Section RV32IHelpers. Definition getImmediateType : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (inst : bits_t 32) : maybe (enum_t imm_type) => + fun getImmediateType (inst : bits_t 32) : maybe (enum_t imm_type) => match (inst[|5`d2|:+5]) with | #opcode_LOAD[|3`d2|:+5] => {valid (enum_t imm_type)}(enum imm_type {| ImmI |}) | #opcode_OP_IMM[|3`d2|:+5] => {valid (enum_t imm_type)}(enum imm_type {| ImmI |}) @@ -182,7 +182,7 @@ Section RV32IHelpers. Definition usesRS1 : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (inst : bits_t 32) : bits_t 1 => + fun usesRS1 (inst : bits_t 32) : bits_t 1 => match (inst[Ob~0~0~0~1~0 :+ 5]) with | Ob~1~1~0~0~0 => Ob~1 (* // bge, bne, bltu, blt, bgeu, beq *) | Ob~0~0~0~0~0 => Ob~1 (* // lh, ld, lw, lwu, lbu, lhu, lb *) @@ -197,7 +197,7 @@ Section RV32IHelpers. Definition usesRS2 : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (inst : bits_t 32) : bits_t 1 => + fun usesRS2 (inst : bits_t 32) : bits_t 1 => match (inst[Ob~0~0~0~1~0 :+ 5]) with | Ob~1~1~0~0~0 => Ob~1 (* // bge, bne, bltu, blt, bgeu, beq *) | Ob~0~1~0~0~0 => Ob~1 (* // sh, sb, sw, sd *) @@ -209,7 +209,7 @@ Section RV32IHelpers. Definition usesRD : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (inst : bits_t 32) : bits_t 1 => + fun usesRD (inst : bits_t 32) : bits_t 1 => match (inst[Ob~0~0~0~1~0 :+ 5]) with | Ob~0~1~1~0~1 => Ob~1 (* // lui*) | Ob~1~1~0~1~1 => Ob~1 (* // jal*) @@ -223,7 +223,7 @@ Section RV32IHelpers. }}. Definition decode_fun : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (arg_inst : bits_t 32) : struct_t decoded_sig => + {{ fun decode_fun (arg_inst : bits_t 32) : struct_t decoded_sig => struct decoded_sig {| valid_rs1 := usesRS1 (arg_inst); valid_rs2 := usesRS2 (arg_inst); @@ -236,7 +236,7 @@ Section RV32IHelpers. Definition getImmediate : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (dInst: struct_t decoded_sig) : bits_t 32 => + fun getImmediate (dInst: struct_t decoded_sig) : bits_t 32 => let imm_type_v := get(dInst, immediateType) in if (get(imm_type_v, valid) == Ob~1) then let fields := getFields (get(dInst,inst)) in @@ -253,7 +253,7 @@ Section RV32IHelpers. }}. Definition alu32 : UInternalFunction reg_t empty_ext_fn_t := - {{ fun (funct3 : bits_t 3) + {{ fun alu32 (funct3 : bits_t 3) (inst_30 : bits_t 1) (a : bits_t 32) (b : bits_t 32) @@ -274,9 +274,8 @@ Section RV32IHelpers. Definition execALU32 : UInternalFunction reg_t empty_ext_fn_t := - {{ - fun (inst : bits_t 32) + fun execALU32 (inst : bits_t 32) (rs1_val : bits_t 32) (rs2_val : bits_t 32) (imm_val : bits_t 32) @@ -311,7 +310,7 @@ Section RV32IHelpers. Definition execControl32 : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (inst : bits_t 32) + fun execControl32 (inst : bits_t 32) (rs1_val : bits_t 32) (rs2_val : bits_t 32) (imm_val : bits_t 32) @@ -550,17 +549,17 @@ Module RV32ICore. tc_action R empty_Sigma decode. (* Useful for debugging *) - Arguments Var {pos_t var_t reg_t ext_fn_t R Sigma sig} k {tau m} : assert. + Arguments Var {pos_t var_t fn_name_t reg_t ext_fn_t R Sigma sig} k {tau m} : assert. Definition isMemoryInst : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (dInst: struct_t decoded_sig) : bits_t 1 => + fun isMemoryInst (dInst: struct_t decoded_sig) : bits_t 1 => (get(dInst,inst)[|5`d6|] == Ob~0) && (get(dInst,inst)[|5`d3|:+2] == Ob~0~0) }}. Definition isControlInst : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (dInst: struct_t decoded_sig) : bits_t 1 => + fun isControlInst (dInst: struct_t decoded_sig) : bits_t 1 => get(dInst,inst)[|5`d4| :+ 3] == Ob~1~1~0 }}. diff --git a/examples/rv/Scoreboard.v b/examples/rv/Scoreboard.v index bb5140ca..613300c6 100644 --- a/examples/rv/Scoreboard.v +++ b/examples/rv/Scoreboard.v @@ -44,14 +44,14 @@ Module Scoreboard (s:Scoreboard_sig). (* Internal functions *) Definition sat_incr : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (a: bits_t logScore) : bits_t logScore => + fun sat_incr (a: bits_t logScore) : bits_t logScore => (* if ( a == #(Bits.of_nat logScore s.maxScore)) then fail(logScore) *) (* else *) a + #(Bits.of_nat logScore 1) }}. Definition sat_decr : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (a: bits_t logScore) : bits_t logScore => + fun sat_decr (a: bits_t logScore) : bits_t logScore => (* if (a == |logScore`d0|) then fail(logScore) *) (* else *) (a - |logScore`d1|) }}. @@ -59,7 +59,7 @@ Module Scoreboard (s:Scoreboard_sig). (* Interface: *) Definition insert : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (idx: bits_t sz) : bits_t 0 => + fun insert (idx: bits_t sz) : bits_t 0 => let old_score := Scores.(Rf.read_1)(idx) in let new_score := sat_incr(old_score) in Scores.(Rf.write_1)(idx, new_score) @@ -67,7 +67,7 @@ Module Scoreboard (s:Scoreboard_sig). Definition remove : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (idx: bits_t sz) : bits_t 0 => + fun remove (idx: bits_t sz) : bits_t 0 => let old_score := Scores.(Rf.read_0)(idx) in let new_score := sat_decr(old_score) in Scores.(Rf.write_0)(idx, new_score) @@ -75,7 +75,7 @@ Module Scoreboard (s:Scoreboard_sig). Definition search : UInternalFunction reg_t empty_ext_fn_t := {{ - fun (idx: bits_t sz) : bits_t logScore => + fun search (idx: bits_t sz) : bits_t logScore => Scores.(Rf.read_1)(idx) }}. End Scoreboard. diff --git a/ocaml/backends/cpp.ml b/ocaml/backends/cpp.ml index 6c71f224..4726f4eb 100644 --- a/ocaml/backends/cpp.ml +++ b/ocaml/backends/cpp.ml @@ -16,23 +16,25 @@ let add_line_pragmas = false let use_dynamic_logs = false let use_offsets_in_dynamic_log = false -type ('pos_t, 'var_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_rule_t = { +type ('pos_t, 'var_t, 'fn_name_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_rule_t = { rl_external: bool; rl_name: 'rule_name_t; rl_footprint: 'reg_t array; - rl_body: (('pos_t, 'reg_t) Extr.register_annotation, 'var_t, 'reg_t, 'ext_fn_t) Extr.rule; + rl_body: (('pos_t, 'reg_t) Extr.register_annotation, + 'var_t, 'fn_name_t, 'reg_t, 'ext_fn_t) Extr.rule; } -type ('pos_t, 'var_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_input_t = { +type ('pos_t, 'var_t, 'fn_name_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_input_t = { cpp_classname: string; cpp_module_name: string; cpp_pos_of_pos: 'pos_t -> Pos.t; cpp_var_names: 'var_t -> string; + cpp_fn_names: 'fn_name_t -> string; cpp_rule_names: ?prefix:string -> 'rule_name_t -> string; cpp_scheduler: ('pos_t, 'rule_name_t) Extr.scheduler; - cpp_rules: ('pos_t, 'var_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_rule_t list; + cpp_rules: ('pos_t, 'var_t, 'fn_name_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_rule_t list; cpp_registers: 'reg_t array; cpp_register_sigs: 'reg_t -> reg_signature; @@ -142,6 +144,7 @@ module Mangling = struct { u with (* The prefixes are needed to prevent collisions with ‘prims::’ *) cpp_classname = mangle_name ~prefix:"module" u.cpp_classname; cpp_var_names = (fun v -> v |> u.cpp_var_names |> mangle_name); + cpp_fn_names = (fun v -> v |> u.cpp_fn_names |> mangle_name); cpp_rule_names = (fun ?prefix rl -> rl |> u.cpp_rule_names |> mangle_name ~prefix:(default "rule" prefix)); cpp_register_sigs = (fun r -> r |> u.cpp_register_sigs |> mangle_register_sig); @@ -330,8 +333,8 @@ let assignment_result_to_string (d: assignment_result) = | PureExpr s -> sprintf "PureExpr %s" s | ImpureExpr s -> sprintf "ImpureExpr %s" s -let compile (type pos_t var_t rule_name_t reg_t ext_fn_t) - (hpp: (pos_t, var_t, rule_name_t, reg_t, ext_fn_t) cpp_input_t) = +let compile (type pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t) + (hpp: (pos_t, var_t, fn_name_t, rule_name_t, reg_t, ext_fn_t) cpp_input_t) = let buffer = ref (Buffer.create 0) in let hpp = Mangling.mangle_unit hpp in @@ -812,7 +815,7 @@ let compile (type pos_t var_t rule_name_t reg_t ext_fn_t) | Range (filename, { rbeg = { line; _ }; _ }) -> p "#line %d \"%s\"" line (sp_escaped_string filename) in - let p_rule (rule: (pos_t, var_t, rule_name_t, reg_t, ext_fn_t) cpp_rule_t) = + let p_rule (rule: (pos_t, var_t, fn_name_t, rule_name_t, reg_t, ext_fn_t) cpp_rule_t) = gensym_reset (); let footprint_sz = @@ -986,7 +989,8 @@ let compile (type pos_t var_t rule_name_t reg_t ext_fn_t) hpp.cpp_rule_names ~prefix:"" rule.rl_name in let rec p_action (pos: Pos.t) (target: assignment_target) - (rl: ((pos_t, reg_t) Extr.register_annotation, var_t, reg_t, _) Extr.action) = + (rl: ((pos_t, reg_t) Extr.register_annotation, + var_t, fn_name_t, reg_t, _) Extr.action) = p_pos pos; match rl with | Extr.APos (_, _, Extr.HistoryAnnot reg_histories, @@ -1071,6 +1075,13 @@ let compile (type pos_t var_t rule_name_t reg_t ext_fn_t) let a = p_action pos (gensym_target ffi.ffi_argtype "x") a in Hashtbl.replace program_info.pi_ext_funcalls ffi (); ImpureExpr (cpp_ext_funcall ffi.ffi_name (must_value a)) + | Extr.InternalCall (_, _, fn, argspec, args, body) -> + p_declare_target target; + p_scoped (sprintf "/* Call to %s */" @@ hpp.cpp_fn_names fn) (fun () -> + Extr.cfoldl (fun (arg_name, arg_type) arg () -> + p_bound_var_assign pos arg_type arg_name arg) + argspec args (); + p_assign_expr target (p_action pos target body)) | Extr.APos (_, _, Extr.PosAnnot pos, a) -> p_action (hpp.cpp_pos_of_pos pos) target a | Extr.Fail (_, _) -> while true do () done; failwith "Missing annotation on fail" @@ -1328,12 +1339,15 @@ let input_of_compile_unit (cu: 'f Cuttlebone.Compilation.compile_unit) = Cuttlebone.Util.compute_register_histories Cuttlebone.Compilation._R cu.c_registers (List.map fst cu.c_rules) rulemap cu.c_scheduler in - let swap_body (name, (kind, _)) = + let swap_body (name, (kind, _)) + : (_ * (_ * (_, Common.var_t, Common.fn_name_t, Common.reg_signature, 'c) + Cuttlebone.Extr.annotated_rule)) = (name, (kind, annotated_rules name)) in { cpp_classname = cu.c_modname; cpp_module_name = cu.c_modname; cpp_pos_of_pos = cu.c_pos_of_pos; cpp_var_names = (fun x -> x); + cpp_fn_names = (fun x -> x); cpp_rule_names = (fun ?prefix:_ rl -> rl); cpp_scheduler = cu.c_scheduler; cpp_rules = List.map (cpp_rule_of_action cu.c_registers << swap_body) cu.c_rules; @@ -1350,9 +1364,9 @@ let cpp_rule_of_koika_package_rule (kp: _ Extr.koika_package_t) (rn, (kind rn, annotated_rules rn)) let input_of_sim_package - (kp: ('pos_t, 'var_t, 'rule_name_t, 'reg_t, 'ext_fn_t) Extr.koika_package_t) + (kp: ('pos_t, 'var_t, 'fn_name_t, 'rule_name_t, 'reg_t, 'ext_fn_t) Extr.koika_package_t) (sp: ('ext_fn_t) Extr.sim_package_t) - : ('pos_t, 'var_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_input_t = + : ('pos_t, 'var_t, 'fn_name_t, 'rule_name_t, 'reg_t, 'ext_fn_t) cpp_input_t = let rule_names = Extr.scheduler_rules kp.koika_scheduler |> Cuttlebone.Util.dedup in let annotated_rules, register_kinds = @@ -1366,6 +1380,7 @@ let input_of_sim_package cpp_module_name = classname; cpp_pos_of_pos = (fun _ -> Pos.Unknown); cpp_var_names = (fun x -> Cuttlebone.Util.string_of_coq_string (kp.koika_var_names.show0 x)); + cpp_fn_names = (fun x -> Cuttlebone.Util.string_of_coq_string (kp.koika_fn_names.show0 x)); cpp_rule_names = (fun ?prefix:_ rn -> Cuttlebone.Util.string_of_coq_string (kp.koika_rule_names.show0 rn)); cpp_scheduler = kp.koika_scheduler; diff --git a/ocaml/cuttlebone/cuttlebone.ml b/ocaml/cuttlebone/cuttlebone.ml index 8e5551d5..00b8cc5b 100644 --- a/ocaml/cuttlebone/cuttlebone.ml +++ b/ocaml/cuttlebone/cuttlebone.ml @@ -126,6 +126,9 @@ module Util = struct let any_eq_dec = { Extr.eq_dec = fun (s1: 'a) (s2: 'a) -> s1 = s2 } + let string_show = + { Extr.show0 = coq_string_of_string } + type 'var_t extr_error_message = | ExplicitErrorInAst | SugaredConstructorInAst @@ -263,13 +266,13 @@ module Util = struct List.to_seq l |> Seq.map (fun x -> (x, ())) |> Hashtbl.of_seq |> Hashtbl.to_seq_keys |> List.of_seq - let compute_register_histories (type reg_t rule_name_t) + let compute_register_histories (type reg_t fn_name_t rule_name_t) (_R: reg_t -> extr_type) (registers: reg_t list) (rule_names: rule_name_t list) - (rules: rule_name_t -> ('pos_t, 'var_t, reg_t, 'fn_t) Extr.rule) + (rules: rule_name_t -> ('pos_t, 'var_t, fn_name_t, reg_t, 'ext_fn_t) Extr.rule) (scheduler: ('pos_t, rule_name_t) Extr.scheduler) - : (rule_name_t -> ('pos_t, 'var_t, reg_t, 'fn_t) Extr.annotated_rule) + : (rule_name_t -> ('pos_t, 'var_t, fn_name_t, reg_t, 'fn_t) Extr.annotated_rule) * (reg_t -> Extr.register_kind) = (* Taking in a list of rules allows us to ensure that we annotate all rules, not just those mentioned in the scheduler. *) @@ -289,11 +292,11 @@ module Compilation = struct | P1 -> Extr.P1 type 'f extr_uaction = - ('f, fn_name_t, var_t, reg_signature, ffi_signature) Extr.uaction + ('f, var_t, fn_name_t, reg_signature, ffi_signature) Extr.uaction type 'f extr_scheduler = ('f, rule_name_t) Extr.scheduler - type 'f extr_action = ('f, var_t, reg_signature, ffi_signature) Extr.action + type 'f extr_action = ('f, var_t, fn_name_t, reg_signature, ffi_signature) Extr.action type 'f extr_rule = [ `ExternalRule | `InternalRule ] * 'f extr_action let _R = fun rs -> Util.extr_type_of_typ (reg_type rs) @@ -316,8 +319,9 @@ module Compilation = struct | Extr.Failure (err: _ Extr.error) -> Error (err.epos, Util.translate_extr_error_message err.emsg) let typecheck_rule pos (ast: 'f extr_uaction) : ('f extr_action, ('f * _)) result = - Extr.type_rule Util.string_eq_dec _R _Sigma pos (Extr.desugar_action pos ast) - |> result_of_type_result + let desugared = Extr.desugar_action pos ast in + let typed = Extr.type_rule Util.string_eq_dec _R _Sigma pos desugared in + result_of_type_result typed let rec extr_circuit_equivb sz (c1: _ extr_circuit) (c2: _ extr_circuit) = let eqb = extr_circuit_equivb sz in @@ -357,12 +361,11 @@ module Compilation = struct let compile (cu: 'f compile_unit) : (reg_signature -> compiled_circuit) = let finiteType = Util.finiteType_of_list cu.c_registers in - let show_string = { Extr.show0 = fun (rl: string) -> Util.coq_string_of_string rl } in let rules r = List.assoc r cu.c_rules |> snd in let externalp r = (List.assoc r cu.c_rules |> fst) = `ExternalRule in let rEnv = Extr.contextEnv finiteType in let env = Extr.compile_scheduler _R _Sigma finiteType - show_string show_string + Util.string_show Util.string_show (opt _R _Sigma) rules externalp cu.c_scheduler in (fun r -> Extr.getenv rEnv env r) end @@ -595,8 +598,8 @@ module Graphs = struct di_circuits = Compilation.compile cu; di_strip_annotations = strip_annotations } - let graph_of_verilog_package (type pos_t var_t rule_name_t reg_t ext_fn_t) - (kp: (pos_t, var_t, rule_name_t, reg_t, ext_fn_t) Extr.koika_package_t) + let graph_of_verilog_package (type pos_t var_t fn_name_t rule_name_t reg_t ext_fn_t) + (kp: (pos_t, var_t, fn_name_t, rule_name_t, reg_t, ext_fn_t) Extr.koika_package_t) (vp: ext_fn_t Extr.verilog_package_t) : circuit_graph = let di_regs = diff --git a/ocaml/cuttlec.ml b/ocaml/cuttlec.ml index d0fa5d48..5bc5c33f 100644 --- a/ocaml/cuttlec.ml +++ b/ocaml/cuttlec.ml @@ -58,10 +58,11 @@ type config = { cnf_dst_dpath: string; } -type ('pos_t, 'var_t, 'rule_name_t, 'reg_t, 'ext_fn_t) package = { +type ('pos_t, 'var_t, 'fn_name_t, 'rule_name_t, 'reg_t, 'ext_fn_t) package = { pkg_modname: string; pkg_lv: Lv.resolved_unit lazy_t; - pkg_cpp: ('pos_t, 'var_t, 'rule_name_t, 'reg_t, 'ext_fn_t) Backends.Cpp.cpp_input_t lazy_t; + pkg_cpp: ('pos_t, 'var_t, 'fn_name_t, 'rule_name_t, 'reg_t, 'ext_fn_t) + Backends.Cpp.cpp_input_t lazy_t; pkg_graph: Cuttlebone.Graphs.circuit_graph lazy_t; } diff --git a/ocaml/frontends/lv.ml b/ocaml/frontends/lv.ml index 55f463c3..0c0c787b 100644 --- a/ocaml/frontends/lv.ml +++ b/ocaml/frontends/lv.ml @@ -113,9 +113,11 @@ module ResolvedAST = struct | Write (port, reg, v) -> Extr.UWrite (translate_port port, reg.lcnt, translate_action v) | Unop { fn; arg } -> UUnop (fn.lcnt, translate_action arg) | Binop { fn; a1; a2 } -> UBinop (fn.lcnt, translate_action a1, translate_action a2) - | InternalCall { fn; args } -> - Extr.UInternalCall (Util.extr_intfun_of_intfun translate_action fn, List.map translate_action args) | ExternalCall { fn; arg } -> UExternalCall (fn.lcnt, translate_action arg) + | InternalCall { fn; args } -> + UInternalCall + (Util.extr_intfun_of_intfun translate_action fn, + List.map translate_action args) | Sugar u -> Extr.USugar (match u with