Skip to content

Commit

Permalink
coq: Unify treatment of circuit signatures and typed signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
cpitclaudel committed Jan 14, 2020
1 parent 9969b5e commit b01dde5
Show file tree
Hide file tree
Showing 19 changed files with 97 additions and 117 deletions.
10 changes: 5 additions & 5 deletions coq/CircuitCorrectness.v
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Section PrimCompilerCorrectness.

Context {REnv: Env reg_t}.
Context (cr: REnv.(env_t) (fun idx => bits (CR_of_R R idx))).
Context (csigma: forall f, CSigma_of_Sigma Sigma f).
Context (csigma: forall f, CSig_denote (CSigma_of_Sigma Sigma f)).

Notation interp_circuit := (interp_circuit (rule_name_t := rule_name_t) cr csigma).

Expand Down Expand Up @@ -76,8 +76,8 @@ Section CompilerCorrectness.

Notation rwdata := (rwdata (rule_name_t := rule_name_t) R Sigma).

Context (sigma: forall f, Sigma f).
Context (csigma: forall f, CSigma f).
Context (sigma: forall f, Sig_denote (Sigma f)).
Context (csigma: forall f, CSig_denote (CSigma f)).
Context {csigma_correct: csigma_spec sigma csigma}.
Context (lco: (@local_circuit_optimizer
rule_name_t reg_t ext_fn_t CR CSigma
Expand Down Expand Up @@ -1622,7 +1622,7 @@ Section CircuitInit.
Context {REnv: Env reg_t}.

Context (r: REnv.(env_t) R).
Context (sigma: forall f, Sigma f).
Context (sigma: forall f, Sig_denote (Sigma f)).

Lemma circuit_env_equiv_CReadRegister :
forall (csigma: forall f, CSig_denote (CSigma_of_Sigma Sigma f)),
Expand All @@ -1647,7 +1647,7 @@ Section Thm.
Context {Show_rule_name_t : Show rule_name_t}.

Context (r: ContextEnv.(env_t) R).
Context (sigma: forall f, Sigma f).
Context (sigma: forall f, Sig_denote (Sigma f)).
Context (lco: (@local_circuit_optimizer
rule_name_t reg_t ext_fn_t
(CR_of_R R) (CSigma_of_Sigma Sigma)
Expand Down
2 changes: 1 addition & 1 deletion coq/CircuitProperties.v
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Section Circuits.
Context {Show_rule_name_t : Show rule_name_t}.

Definition csigma_spec (sigma: forall f, Sig_denote (Sigma f)) csigma :=
forall fn (a: (Sigma fn).(arg1Type)),
forall fn (a: (Sigma fn).(arg1Sig)),
csigma fn (bits_of_value a) = bits_of_value (sigma fn a).

Lemma csigma_spec_csigma_of_sigma :
Expand Down
34 changes: 17 additions & 17 deletions coq/Circuits.v
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ Section Circuit.
| CMux {sz} (select: circuit 1) (c1 c2: circuit sz): circuit sz
| CConst {sz} (cst: bits sz): circuit sz
| CReadRegister (reg: reg_t): circuit (CR reg)
| CUnop (fn: fbits1) (a1: circuit (CSigma1 fn).(arg1Size))
: circuit (CSigma1 fn).(retSize)
| CBinop (fn: fbits2) (a1: circuit (CSigma2 fn).(arg1Size)) (a2: circuit (CSigma2 fn).(arg2Size))
: circuit (CSigma2 fn).(retSize)
| CUnop (fn: fbits1) (a1: circuit (CSigma1 fn).(arg1Sig))
: circuit (CSigma1 fn).(retSig)
| CBinop (fn: fbits2) (a1: circuit (CSigma2 fn).(arg1Sig)) (a2: circuit (CSigma2 fn).(arg2Sig))
: circuit (CSigma2 fn).(retSig)
| CExternal (idx: ext_fn_t)
(a: circuit (CSigma idx).(arg1Size))
: circuit (CSigma idx).(retSize)
(a: circuit (CSigma idx).(arg1Sig))
: circuit (CSigma idx).(retSig)
| CBundleRef {sz} (name: rule_name_t) (regs: list reg_t)
(bundle: context (fun r => rwdata (CR r)) regs)
(field: rwcircuit_field) (c: circuit sz): circuit sz
Expand Down Expand Up @@ -620,10 +620,10 @@ Section CircuitCompilation.
Defined.

Section Action.
Definition compile_unop (fn: fn1) (a: circuit (type_sz (PrimSignatures.Sigma1 fn).(arg1Type))):
circuit (type_sz (PrimSignatures.Sigma1 fn).(retType)) :=
let cArg1 fn := circuit (type_sz (PrimSignatures.Sigma1 fn).(arg1Type)) in
let cRet fn := circuit (type_sz (PrimSignatures.Sigma1 fn).(retType)) in
Definition compile_unop (fn: fn1) (a: circuit (type_sz (PrimSignatures.Sigma1 fn).(arg1Sig))):
circuit (type_sz (PrimSignatures.Sigma1 fn).(retSig)) :=
let cArg1 fn := circuit (type_sz (PrimSignatures.Sigma1 fn).(arg1Sig)) in
let cRet fn := circuit (type_sz (PrimSignatures.Sigma1 fn).(retSig)) in
match fn return cArg1 fn -> cRet fn with
| Display fn => fun _ => CConst Ob
| Conv tau fn => fun a =>
Expand Down Expand Up @@ -663,12 +663,12 @@ Section CircuitCompilation.
end a.

Definition compile_binop (fn: fn2)
(a1: circuit (type_sz (PrimSignatures.Sigma2 fn).(arg1Type)))
(a2: circuit (type_sz (PrimSignatures.Sigma2 fn).(arg2Type))):
circuit (type_sz (PrimSignatures.Sigma2 fn).(retType)) :=
let cArg1 fn := circuit (type_sz (PrimSignatures.Sigma2 fn).(arg1Type)) in
let cArg2 fn := circuit (type_sz (PrimSignatures.Sigma2 fn).(arg2Type)) in
let cRet fn := circuit (type_sz (PrimSignatures.Sigma2 fn).(retType)) in
(a1: circuit (type_sz (PrimSignatures.Sigma2 fn).(arg1Sig)))
(a2: circuit (type_sz (PrimSignatures.Sigma2 fn).(arg2Sig))):
circuit (type_sz (PrimSignatures.Sigma2 fn).(retSig)) :=
let cArg1 fn := circuit (type_sz (PrimSignatures.Sigma2 fn).(arg1Sig)) in
let cArg2 fn := circuit (type_sz (PrimSignatures.Sigma2 fn).(arg2Sig)) in
let cRet fn := circuit (type_sz (PrimSignatures.Sigma2 fn).(retSig)) in
match fn return cArg1 fn -> cArg2 fn -> cRet fn with
| Eq tau negate => fun a1 a2 => CBinop (EqBits (type_sz tau) negate) a1 a2
| Bits2 fn => fun a1 a2 => CBinop fn a1 a2
Expand Down Expand Up @@ -967,7 +967,7 @@ Section Helpers.

Context {REnv: Env reg_t}.
Context (r: REnv.(env_t) R).
Context (sigma: forall f, Sigma f).
Context (sigma: forall f, Sig_denote (Sigma f)).

Definition interp_circuits
(circuits: register_update_circuitry rule_name_t R Sigma REnv) :=
Expand Down
2 changes: 1 addition & 1 deletion coq/Desugaring.v
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Section Desugaring.
(f: action -> action') (fn: InternalFunction fn_name_t var_t action) :=
{| int_name := fn.(int_name);
int_argspec := fn.(int_argspec);
int_retType := fn.(int_retType);
int_retSig := fn.(int_retSig);
int_body := f fn.(int_body) |}.

Fixpoint desugar_action' {reg_t' ext_fn_t'} (pos: pos_t)
Expand Down
2 changes: 1 addition & 1 deletion coq/ExtractionSetup.v
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Require Koika.Circuits

(* The following commands work around problems due to incorrect extraction: *)
Extraction Inline Koika.Circuits.retVal.
Extraction Inline Types.argTypes Types.argSizes.
Extraction Inline Types.argSigs.

Extract Constant Vect.index => int.
Extract Inductive Vect.index' => int [ "0" "Pervasives.succ" ]
Expand Down
2 changes: 1 addition & 1 deletion coq/OneRuleAtATime.v
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Section Proof.
Context {Sigma: ext_fn_t -> ExternalSignature}.
Context {REnv: Env reg_t}.
Context (r: REnv.(env_t) R).
Context (sigma: forall f, Sigma f).
Context (sigma: forall f, Sig_denote (Sigma f)).

Notation Log := (Log R REnv).
Notation action := (action pos_t var_t R Sigma).
Expand Down
4 changes: 2 additions & 2 deletions coq/Parsing.v
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ Notation "'`' a '`'" := ( a) (in custom koika at level 99, a constr at level 99)
Notation "'fun' args ':' ret '=>' body" :=
{| int_name := "";
int_argspec := args;
int_retType := ret;
int_retSig := ret;
int_body := body |}
(in custom koika at level 99, args custom koika_types, ret constr at level 0, body custom koika at level 99, format "'[v' 'fun' args ':' ret '=>' '/' body ']'").
Notation "'fun' '_' ':' ret '=>' body" :=
{| int_name := "";
int_argspec := nil;
int_retType := ret;
int_retSig := ret;
int_body := body |}
(in custom koika at level 99, ret constr at level 0, body custom koika at level 99, format "'[v' 'fun' '_' ':' ret '=>' '/' body ']'").

Expand Down
2 changes: 1 addition & 1 deletion coq/Primitives.v
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ End PrimTypeInference.

Module CircuitSignatures.
Import PrimTyped.
Import CSigNotations.
Import SigNotations.

Definition CSigma1 (fn: fbits1) : CSig 1 :=
match fn with
Expand Down
10 changes: 5 additions & 5 deletions coq/SyntaxMacros.v
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ Module Display.
Definition empty_printer : InternalFunction fn_name_t var_t uaction :=
{| int_name := "";
int_argspec := [];
int_retType := unit_t;
int_retSig := unit_t;
int_body := USugar USkip |}.

Definition display_utf8 s : uaction :=
Expand All @@ -160,27 +160,27 @@ Module Display.
Definition nl_printer : InternalFunction fn_name_t var_t uaction :=
{| int_name := "";
int_argspec := [];
int_retType := unit_t;
int_retSig := unit_t;
int_body := display_utf8 "\n" |}.

Fixpoint extend_printer f (offset: nat) (printer: intfun) : intfun :=
let opts :=
{| display_newline := false; display_strings := false; display_style := dFull |} in
let display_value arg :=
UUnop (UDisplay (UDisplayValue opts)) (UVar arg) in
let '(Build_InternalFunction int_name int_argspec int_retType int_body) :=
let '(Build_InternalFunction int_name int_argspec int_retSig int_body) :=
printer in
match f with
| Str s =>
{| int_name := int_name;
int_argspec := int_argspec;
int_retType := int_retType;
int_retSig := int_retSig;
int_body := (USeq (display_utf8 s) int_body) |}
| Value tau =>
let arg := String.append "arg" (show offset) in
{| int_name := int_name;
int_argspec := (arg, tau) :: int_argspec;
int_retType := unit_t;
int_retSig := unit_t;
int_body := (USeq (display_value arg) int_body) |}
end.

Expand Down
10 changes: 5 additions & 5 deletions coq/TypeInference.v
Original file line number Diff line number Diff line change
Expand Up @@ -151,27 +151,27 @@ Section TypeInference.
let tc_args_w_pos := List.combine arg_positions tc_args in
let/res arg_ctx := assert_argtypes e fn.(int_name) pos fn.(int_argspec) tc_args_w_pos in
let/res fn_body' := type_action (actpos pos fn.(int_body)) fn.(int_argspec) fn.(int_body) in
let/res fn_body' := cast_action (actpos pos fn.(int_body)) fn.(int_retType) (``fn_body') in
let/res fn_body' := cast_action (actpos pos fn.(int_body)) fn.(int_retSig) (``fn_body') in
Success (EX (InternalCall sig fn.(int_argspec) fn_body' arg_ctx))
| UUnop fn arg1 =>
let pos1 := actpos pos arg1 in
let/res arg1' := type_action pos sig arg1 in
let/res fn := lift_fn1_tc_result pos1 ``arg1' (PrimTypeInference.tc1 fn `arg1') in
let/res arg1' := cast_action pos1 (PrimSignatures.Sigma1 fn).(arg1Type) (``arg1') in
let/res arg1' := cast_action pos1 (PrimSignatures.Sigma1 fn).(arg1Sig) (``arg1') in
Success (EX (Unop fn arg1'))
| UBinop fn arg1 arg2 =>
let pos1 := actpos pos arg1 in
let pos2 := actpos pos arg2 in
let/res arg1' := type_action pos sig arg1 in
let/res arg2' := type_action pos sig arg2 in
let/res fn := lift_fn2_tc_result pos1 ``arg1' pos2 ``arg2' (PrimTypeInference.tc2 fn `arg1' `arg2') in
let/res arg1' := cast_action pos1 (PrimSignatures.Sigma2 fn).(arg1Type) (``arg1') in
let/res arg2' := cast_action pos2 (PrimSignatures.Sigma2 fn).(arg2Type) (``arg2') in
let/res arg1' := cast_action pos1 (PrimSignatures.Sigma2 fn).(arg1Sig) (``arg1') in
let/res arg2' := cast_action pos2 (PrimSignatures.Sigma2 fn).(arg2Sig) (``arg2') in
Success (EX (Binop fn arg1' arg2'))
| UExternalCall fn arg1 =>
let pos1 := actpos pos arg1 in
let/res arg1' := type_action pos1 sig arg1 in
let/res arg1' := cast_action pos1 (Sigma fn).(arg1Type) (``arg1') in
let/res arg1' := cast_action pos1 (Sigma fn).(arg1Sig) (``arg1') in
Success (EX (ExternalCall fn arg1'))
| UAPos pos e =>
let/res e := type_action pos sig e in
Expand Down
14 changes: 7 additions & 7 deletions coq/TypedSyntax.v
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,17 @@ Section TypedSyntax.
(value: action sig (R idx)) : action sig unit_t
| Unop {sig}
(fn: fn1)
(arg1: action sig (Sigma1 fn).(arg1Type))
: action sig (Sigma1 fn).(retType)
(arg1: action sig (Sigma1 fn).(arg1Sig))
: action sig (Sigma1 fn).(retSig)
| Binop {sig}
(fn: fn2)
(arg1: action sig (Sigma2 fn).(arg1Type))
(arg2: action sig (Sigma2 fn).(arg2Type))
: action sig (Sigma2 fn).(retType)
(arg1: action sig (Sigma2 fn).(arg1Sig))
(arg2: action sig (Sigma2 fn).(arg2Sig))
: action sig (Sigma2 fn).(retSig)
| ExternalCall {sig}
(fn: ext_fn_t)
(arg: action sig (Sigma fn).(arg1Type))
: action sig (Sigma fn).(retType)
(arg: action sig (Sigma fn).(arg1Sig))
: action sig (Sigma fn).(retSig)
| APos {sig tau} (pos: pos_t) (a: action sig tau)
: action sig tau.

Expand Down
4 changes: 2 additions & 2 deletions coq/TypedSyntaxTools.v
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,8 @@ Section TypedSyntaxTools.
| @If _ _ _ _ _ _ _ tau _ _ _ => Some tau
| @Read _ _ _ _ _ _ _ _ _ => None
| @Write _ _ _ _ _ _ _ _ _ _ => Some unit_t
| @Unop _ _ _ _ _ _ _ fn _ => Some (PrimSignatures.Sigma1 fn).(retType)
| @Binop _ _ _ _ _ _ _ fn _ _ => Some (PrimSignatures.Sigma2 fn).(retType)
| @Unop _ _ _ _ _ _ _ fn _ => Some (PrimSignatures.Sigma1 fn).(retSig)
| @Binop _ _ _ _ _ _ _ fn _ _ => Some (PrimSignatures.Sigma2 fn).(retSig)
| @ExternalCall _ _ _ _ _ _ _ _ _ => None
| @APos _ _ _ _ _ _ _ tau _ _ => Some tau
end.
Expand Down
78 changes: 28 additions & 50 deletions coq/Types.v
Original file line number Diff line number Diff line change
Expand Up @@ -237,68 +237,46 @@ Coercion type_denote : type >-> Sortclass.

(** * Anonymous function signatures **)

(* Example ufn := {{{ "A" | "x" :: unit_t ~> bits_t 5 | tt }}}. *)

Record CSig {n: nat} := { argSizes : vect nat n; retSize : nat }.
Arguments CSig : clear implicits.

Fixpoint CSig_denote' {n} (args: vect nat n) (ret: nat) :=
match n return vect nat n -> Type with
| 0 => fun _ => bits ret
| S n => fun arg => bits (vect_hd arg) -> CSig_denote' (vect_tl arg) ret
Record _Sig {argKind: Type} {nArgs: nat} :=
{ argSigs : vect argKind nArgs; retSig : argKind }.
Arguments _Sig : clear implicits.

Fixpoint _Sig_denote {nArgs argKind} (type_of_argKind: argKind -> Type)
(args: vect argKind nArgs) (ret: argKind) :=
match nArgs return vect argKind nArgs -> Type with
| 0 => fun _ => type_of_argKind ret
| S n => fun arg => type_of_argKind (vect_hd arg) ->
_Sig_denote type_of_argKind (vect_tl arg) ret
end args.

Definition CSig_denote {n} (sg: CSig n) :=
CSig_denote' sg.(argSizes) sg.(retSize).

Coercion CSig_denote: CSig >-> Sortclass.
Notation Sig n := (_Sig type n).
Notation CSig n := (_Sig nat n).

Notation arg1Size fsig := (vect_hd fsig.(argSizes)).
Notation arg2Size fsig := (vect_hd (vect_tl fsig.(argSizes))).

Module CSigNotations.
Notation "{$ a1 ~> ret $}" :=
{| argSizes := vect_cons a1 vect_nil;
retSize := ret |}.

Notation "{$ a1 ~> a2 ~> ret $}" :=
{| argSizes := vect_cons a1 (vect_cons a2 vect_nil);
retSize := ret |}.
End CSigNotations.
Definition CSig_denote {n} (sg: CSig n) :=
_Sig_denote (@Bits.bits) sg.(argSigs) sg.(retSig).

Record Sig {n: nat} := { argTypes : vect type n; retType : type }.
Arguments Sig : clear implicits.
Definition Sig_denote {n} (sg: Sig n) :=
_Sig_denote type_denote sg.(argSigs) sg.(retSig).

Definition CSig_of_Sig {n} (sig: Sig n) : CSig n :=
{| argSizes := vect_map type_sz sig.(argTypes);
retSize := type_sz sig.(retType) |}.
{| argSigs := vect_map type_sz sig.(argSigs);
retSig := type_sz sig.(retSig) |}.

Definition Sig_of_CSig {n} (sig: CSig n) : Sig n :=
{| argTypes := vect_map bits_t sig.(argSizes);
retType := bits_t sig.(retSize) |}.

Fixpoint Sig_denote' {n} (args: vect type n) (ret: type) :=
match n return vect type n -> Type with
| 0 => fun _ => ret
| S n => fun arg => vect_hd arg -> Sig_denote' (vect_tl arg) ret
end args.

Definition Sig_denote {n} (sg: Sig n) :=
Sig_denote' sg.(argTypes) sg.(retType).

Coercion Sig_denote: Sig >-> Sortclass.
{| argSigs := vect_map bits_t sig.(argSigs);
retSig := bits_t sig.(retSig) |}.

Notation arg1Type fsig := (vect_hd fsig.(argTypes)).
Notation arg2Type fsig := (vect_hd (vect_tl fsig.(argTypes))).
Notation arg1Sig fsig := (vect_hd fsig.(argSigs)).
Notation arg2Sig fsig := (vect_hd (vect_tl fsig.(argSigs))).

Module SigNotations.
Notation "{$ a1 ~> ret $}" :=
{| argTypes := vect_cons a1 vect_nil;
retType := ret |}.
{| argSigs := vect_cons a1 vect_nil;
retSig := ret |}.

Notation "{$ a1 ~> a2 ~> ret $}" :=
{| argTypes := vect_cons a1 (vect_cons a2 vect_nil);
retType := ret |}.
{| argSigs := vect_cons a1 (vect_cons a2 vect_nil);
retSig := ret |}.
End SigNotations.

(** * External functions **)
Expand All @@ -313,11 +291,11 @@ Definition tsig var_t := list (var_t * type).
Record InternalFunction {fn_name_t var_t action: Type} :=
{ int_name : fn_name_t;
int_argspec : tsig var_t;
int_retType : type;
int_retSig : type;
int_body : action }.
Arguments InternalFunction : clear implicits.
Arguments Build_InternalFunction {fn_name_t var_t action}
int_name int_argspec int_retType int_body : assert.
int_name int_argspec int_retSig int_body : assert.

Record arg_sig {var_t} :=
{ arg_name: var_t;
Expand Down
Loading

0 comments on commit b01dde5

Please sign in to comment.