diff --git a/src/base/Cashflow.ml b/src/base/Cashflow.ml index 4731c8dc0..5630b3989 100644 --- a/src/base/Cashflow.ml +++ b/src/base/Cashflow.ml @@ -441,7 +441,7 @@ struct new_map | _ -> targ_tag_map) | PrimType _ | PolyFun (_, _) | Unit -> targ_tag_map - | Address _ -> (* TODO *) targ_tag_map + | Address _ | ProcType _ -> (* TODO *) targ_tag_map in let tvar_tag_map, _ = List.fold_left arg_typs ~init:(init_targ_to_tag_map, arg_tags) @@ -519,7 +519,7 @@ struct let ctr_arg_filter targ = let open CFType in match targ with - | PrimType _ | MapType _ | FunType _ -> true + | PrimType _ | MapType _ | FunType _ | ProcType _ -> true | ADT _ (* TODO: Detect induction, and ignore only when inductive *) | TypeVar _ (* TypeVars tagged at type level *) | PolyFun _ (* Ignore *) | Unit -> diff --git a/src/base/Disambiguate.ml b/src/base/Disambiguate.ml index 5b5b5918e..80658265a 100644 --- a/src/base/Disambiguate.ml +++ b/src/base/Disambiguate.ml @@ -203,6 +203,9 @@ module ScillaDisambiguation (SR : Rep) (ER : Rep) = struct let%bind dis_arg_t = recurse arg_t in let%bind dis_res_t = recurse res_t in pure @@ PostDisType.FunType (dis_arg_t, dis_res_t) + | ProcType (p, args) -> + let%bind dis_args = mapM args ~f:recurse in + pure @@ PostDisType.ProcType (p, dis_args) | ADT (t_name, targs) -> let%bind dis_t_name = disambiguate_identifier typ_dict t_name (get_rep t_name) diff --git a/src/base/JSON.ml b/src/base/JSON.ml index 6de10a65b..40439583d 100644 --- a/src/base/JSON.ml +++ b/src/base/JSON.ml @@ -90,7 +90,7 @@ let build_prim_lit_exn t v = match t with | PrimType pt -> build_prim_literal_of_type pt v | Address _ -> build_prim_literal_of_type (Bystrx_typ Type.address_length) v - | MapType _ | FunType _ | ADT _ | TypeVar _ | PolyFun _ | Unit -> + | MapType _ | FunType _ | ADT _ | TypeVar _ | PolyFun _ | ProcType _ | Unit -> raise (exn ()) (****************************************************************) @@ -240,7 +240,7 @@ and json_to_lit_exn t v = | PrimType _ | Address _ -> let tv = build_prim_lit_exn t (to_string_exn v) in tv - | FunType _ | TypeVar _ | PolyFun _ | Unit -> + | FunType _ | TypeVar _ | PolyFun _ | ProcType _ | Unit -> let exn () = mk_invalid_json ~kind:"Invalid type in JSON" ~inst:(pp_typ t) in diff --git a/src/base/Recursion.ml b/src/base/Recursion.ml index 9e7365bbe..27a2b3664 100644 --- a/src/base/Recursion.ml +++ b/src/base/Recursion.ml @@ -68,6 +68,7 @@ module ScillaRecursion (SR : Rep) (ER : Rep) = struct | MapType (t1, t2) | FunType (t1, t2) -> let%bind () = walk t1 in walk t2 + | ProcType (_p, args) -> forallM ~f:walk args | ADT (s, targs) -> let%bind () = is_adt_in_scope s in forallM ~f:walk targs @@ -288,6 +289,7 @@ module ScillaRecursion (SR : Rep) (ER : Rep) = struct | MapType (t1, t2) | FunType (t1, t2) -> let%bind _ = walk t1 in walk t2 + | ProcType (_, args) -> forallM args ~f:walk | ADT (s, targs) -> (* Only allow ADTs that are already in scope. This prevents mutually inductive definitions. *) let%bind _ = is_adt_in_scope s in diff --git a/src/base/SanityChecker.ml b/src/base/SanityChecker.ml index 3f6577595..0273a30b6 100644 --- a/src/base/SanityChecker.ml +++ b/src/base/SanityChecker.ml @@ -755,7 +755,7 @@ struct | LibVar _ | LibTyp _ -> None) in (* Returns an array with information about matched [Option] arguments - [Some(args)] if the [fun_name] is a procedure. *) + [Some(args)] if [fun_name] is a procedure. *) let handle_comp (cmod : cmodule) option_args_matches fun_name = let get_comp_args comp = match comp.comp_type with @@ -769,8 +769,8 @@ struct match param_ty with | ADT (id, _targs) when is_option_name id -> Map.set m ~key:(get_id param_id) ~data:i - | ADT _ | PrimType _ | MapType _ | FunType _ | TypeVar _ - | PolyFun _ | Unit | Address _ -> + | ADT _ | PrimType _ | MapType _ | FunType _ | ProcType _ + | TypeVar _ | PolyFun _ | Unit | Address _ -> m) in (* Mark Option arguments that matches inside the body. *) diff --git a/src/base/Syntax.ml b/src/base/Syntax.ml index bb5b35702..88707d075 100644 --- a/src/base/Syntax.ml +++ b/src/base/Syntax.ml @@ -374,7 +374,10 @@ module ScillaSyntax (SR : Rep) (ER : Rep) (Lit : ScillaLiteral) = struct ER.rep SIdentifier.t option * SR.rep SIdentifier.t * ER.rep SIdentifier.t list - (** [CallProc(I, P, [A1, ... An])] is a procedure call: [I = P A1 ... An] *) + (** [CallProc(I, P, [A1, ... An])] is a procedure call, when all the + arguments are specified: [I = P A1 ... An]. Otherwise, it is or + partial application of procedure that creates a new local variable + [I] that has the [ProcType] type: [I = P A1 ... An]. *) | Throw of ER.rep SIdentifier.t option (** [Throw(I)] represents: [throw I] *) | GasStmt of SGasCharge.gas_charge diff --git a/src/base/SyntaxAnnotMapper.ml b/src/base/SyntaxAnnotMapper.ml index 65ebc1a9c..9fbc20c97 100644 --- a/src/base/SyntaxAnnotMapper.ml +++ b/src/base/SyntaxAnnotMapper.ml @@ -64,6 +64,7 @@ struct | PrimType _ | Unit | TypeVar _ -> ty | MapType (ty1, ty2) -> MapType (walk ty1, walk ty2) | FunType (ty1, ty2) -> FunType (walk ty1, walk ty2) + | ProcType (p, args) -> ProcType (p, List.map args ~f:walk) | PolyFun (tv, ty) -> PolyFun (tv, walk ty) | ADT (tid, tys) -> ADT (map_id fl tid, List.map tys ~f:(fun ty -> walk ty)) diff --git a/src/base/Type.ml b/src/base/Type.ml index 1cb68397f..985a158c8 100644 --- a/src/base/Type.ml +++ b/src/base/Type.ml @@ -129,14 +129,21 @@ module type ScillaType = sig [@@deriving sexp] type t = - | PrimType of PrimType.t - | MapType of t * t + | PrimType of PrimType.t (** [PrimType(T)] represents a primary type *) + | MapType of t * t (** [MapType(K, V)] is a mapping type: [K |-> V] *) | FunType of t * t + (** [FunType(Ty1, Ty2)] is a function type: [Ty1 => Ty2] *) | ADT of loc TIdentifier.t * t list - | TypeVar of string + (** [ADT(N, Args)] represents ADT type [N] with type parameters [Args] *) + | TypeVar of string (** [TypeVar(T)] is a type variable *) | PolyFun of string * t - | Unit - | Address of t addr_kind + (** [PolyFun('A, T)] represents a polymorphic function type where + ['A] is a type parameter. For example: [forall 'A. List 'a -> List 'A] *) + | ProcType of string * t list + (** [ProcType(P, Args)] is a type of partial application of procedure + [P] which has formal arguments with types [Args] *) + | Unit (** [Unit] is a unit type *) + | Address of t addr_kind (** [Address(A)] represents address *) [@@deriving sexp, to_yojson] val pp_typ : t -> string @@ -223,6 +230,7 @@ module MkType (I : ScillaIdentifier) = struct | ADT of loc TIdentifier.t * t list | TypeVar of string | PolyFun of string * t + | ProcType of string * t list | Unit | Address of (t addr_kind[@to_yojson fun _ -> `String "Address"]) [@@deriving sexp, to_yojson] @@ -239,6 +247,9 @@ module MkType (I : ScillaIdentifier) = struct in String.concat ~sep:" " elems | FunType (at, vt) -> sprintf "%s -> %s" (with_paren at) (recurser vt) + | ProcType (p, args_tys) -> + sprintf "%s (%s)" p + (List.map args_tys ~f:recurser |> String.concat ~sep:", ") | TypeVar tv -> tv | PolyFun (tv, bt) -> sprintf "forall %s. %s" tv (recurser bt) | Unit -> sprintf "()" @@ -280,6 +291,8 @@ module MkType (I : ScillaIdentifier) = struct | PrimType _ | Unit -> acc | MapType (kt, vt) -> go kt acc |> go vt | FunType (at, rt) -> go at acc |> go rt + | ProcType (_, args) -> + List.fold_left args ~init:[] ~f:(fun acc arg -> go arg acc) | TypeVar n -> add acc n | ADT (_, ts) -> List.fold_left ts ~init:acc ~f:(Fn.flip go) | PolyFun (arg, bt) -> @@ -314,6 +327,9 @@ module MkType (I : ScillaIdentifier) = struct let ats = subst_type_in_type tvar tp at in let rts = subst_type_in_type tvar tp rt in FunType (ats, rts) + | ProcType (p, args_tys) -> + let args_tyss = List.map args_tys ~f:(subst_type_in_type tvar tp) in + ProcType (p, args_tyss) | TypeVar n -> if String.(tvar = n) then tp else tm | ADT (s, ts) -> let ts' = List.map ts ~f:(subst_type_in_type tvar tp) in @@ -337,6 +353,8 @@ module MkType (I : ScillaIdentifier) = struct match t with | MapType (kt, vt) -> MapType (kt, recursor vt taken) | FunType (at, rt) -> FunType (recursor at taken, recursor rt taken) + | ProcType (p, args_tys) -> + ProcType (p, List.map args_tys ~f:(fun ty -> recursor ty taken)) | ADT (n, ts) -> let ts' = List.map ts ~f:(fun w -> recursor w taken) in ADT (n, ts') diff --git a/src/base/TypeChecker.ml b/src/base/TypeChecker.ml index b0adcbe68..6cd081d43 100644 --- a/src/base/TypeChecker.ml +++ b/src/base/TypeChecker.ml @@ -147,6 +147,7 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct match t with | PrimType _ | Unit | TypeVar _ -> 1 | PolyFun (_, t) -> 1 + type_size t + | ProcType (_, args) -> 1 + List.length args | MapType (t1, t2) | FunType (t1, t2) -> 1 + type_size t1 + type_size t2 | ADT (_, ts) -> List.fold_left ts ~init:1 ~f:(fun acc t -> acc + type_size t) @@ -163,7 +164,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct | MapType (_, _) | FunType (_, _) | ADT (_, _) - | PolyFun (_, _) -> + | PolyFun (_, _) + | ProcType (_, _) -> 1 | TypeVar n -> if String.(n = tvar) then tp_size else 1 | Address AnyAddr | Address LibAddr | Address CodeAddr -> 1 @@ -201,6 +203,9 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct else let%bind res = recurser t' in pure (PolyFun (arg, res)) + | ProcType (pname, args) -> + let%bind args' = mapM args ~f:recurser in + pure (ProcType (pname, args')) | Address AnyAddr | Address LibAddr | Address CodeAddr -> pure t | Address (ContrAddr fts) -> let%bind fts_res = @@ -551,15 +556,37 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct (* Typing statements *) (**************************************************************) - (* Auxiliary structure for types of fields and BC components *) type stmt_tenv = { pure : TEnv.t; fields : TEnv.t; procedures : (TCName.t * (TCType.t list * TCType.t option)) list; } + (** Auxiliary structure for types of fields and BC components *) + (** Looks up a procedure with name [pname] or a local binding to a partial + application of procedure. *) let lookup_proc env pname = - List.Assoc.find env.procedures ~equal:[%equal: TCName.t] (get_id pname) + let proc = + List.Assoc.find env.procedures ~equal:[%equal: TCName.t] (get_id pname) + in + match proc with + | None -> ( + (* Lookup local bind to a partial application *) + let get_proc_type_args (rr : resolve_result) = + match (rr_typ rr).tp with + | ProcType (_proc_name, arg_tys) -> Some arg_tys + | _ -> None + in + match TEnv.resolveT env.pure (get_id pname) ~lopt:None with + | Ok bind -> + get_proc_type_args bind + |> Option.value_map + ~f:(fun arg_tys -> + (* Partially applied procedures never have a return type *) + Some (arg_tys, None)) + ~default:None + | _ -> None) + | res -> res let type_map_access_helper env maptype keys = let rec helper maptype keys = @@ -985,8 +1012,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct @@ add_stmt_to_stmts_env_gas (TypedSyntax.CreateEvnt typed_i, rep) checked_stmts - | CallProc (id_opt, p, args) -> - let%bind arg_typs, ret_ty_opt = + | CallProc (id_opt, p, actual_args) -> + let%bind formal_args, ret_ty_opt = match lookup_proc env p with | Some (arg_typs, ret_ty_opt) -> pure (arg_typs, ret_ty_opt) | None -> @@ -995,41 +1022,76 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct ~inst:(as_error_string p) (SR.get_loc (get_rep p))) in - let%bind typed_args = - let%bind targs, typed_actuals = type_actuals env.pure args in + let%bind targs, typed_actuals = type_actuals env.pure actual_args in + let is_partial_application = + Option.is_some id_opt + && List.length formal_args > List.length targs + in + if is_partial_application then let%bind _ = fromR_TE - @@ proc_type_applies arg_typs targs ~lc:(SR.get_loc rep) + @@ partial_proc_type_applies formal_args targs + ~lc:(SR.get_loc rep) in - pure typed_actuals - in - let%bind typed_id_opt, checked_stmts = - match id_opt with - | None -> - let%bind checked_stmts = type_stmts comp sts get_loc env in - pure @@ (None, checked_stmts) - | Some id -> ( - match ret_ty_opt with - | Some ret_ty -> - let typed_id = add_type_to_ident id (mk_qual_tp ret_ty) in - let%bind checked_stmts = - with_extended_env env get_tenv_pure - [ (id, ret_ty) ] - [] - (type_stmts comp sts get_loc) - in - pure @@ (Some typed_id, checked_stmts) - | None -> - fail - (mk_type_error1 - ~kind:"Procedure does not return a value" - ~inst:(as_error_string p) - (SR.get_loc (get_rep p)))) - in - pure - @@ add_stmt_to_stmts_env_gas - (TypedSyntax.CallProc (typed_id_opt, p, typed_args), rep) - checked_stmts + let id = Option.value_exn id_opt in + let proc_name = + SIdentifier.Name.as_string (SIdentifier.get_id p) + in + let unapplied_formal_args = + List.sub formal_args + ~pos:(List.length targs - 1) + ~len:(List.length formal_args - List.length targs) + in + let partial_applied_type = + ProcType (proc_name, unapplied_formal_args) + in + let typed_id = + add_type_to_ident id (mk_qual_tp partial_applied_type) + in + let%bind checked_stmts = + with_extended_env env get_tenv_pure + [ (id, partial_applied_type) ] + [] + (type_stmts comp sts get_loc) + in + pure + @@ add_stmt_to_stmts_env_gas + (TypedSyntax.CallProc (Some typed_id, p, typed_actuals), rep) + checked_stmts + else + let%bind _ = + fromR_TE + @@ proc_type_applies formal_args targs ~lc:(SR.get_loc rep) + in + let%bind typed_id_opt, checked_stmts = + match id_opt with + | None -> + let%bind checked_stmts = type_stmts comp sts get_loc env in + pure @@ (None, checked_stmts) + | Some id -> ( + match ret_ty_opt with + | Some ret_ty -> + let typed_id = + add_type_to_ident id (mk_qual_tp ret_ty) + in + let%bind checked_stmts = + with_extended_env env get_tenv_pure + [ (id, ret_ty) ] + [] + (type_stmts comp sts get_loc) + in + pure @@ (Some typed_id, checked_stmts) + | None -> + fail + (mk_type_error1 + ~kind:"Procedure does not return a value" + ~inst:(as_error_string p) + (SR.get_loc (get_rep p)))) + in + pure + @@ add_stmt_to_stmts_env_gas + (TypedSyntax.CallProc (typed_id_opt, p, typed_actuals), rep) + checked_stmts | Iterate (l, p) -> ( let%bind lt = fromR_TE diff --git a/src/base/TypeUtil.ml b/src/base/TypeUtil.ml index 7574438a9..1ba0faf49 100644 --- a/src/base/TypeUtil.ml +++ b/src/base/TypeUtil.ml @@ -216,6 +216,7 @@ functor | FunType (at, rt) -> let%bind () = is_wf_typ' at tb in is_wf_typ' rt tb + | ProcType (_, args) -> forallM args ~f:(fun ty -> is_wf_typ' ty tb) | ADT (n, ts) -> let open Datatypes.DataTypeDictionary in let%bind adt = lookup_name ~sloc:(get_rep n) (get_id n) in @@ -348,8 +349,8 @@ module TypeUtilities = struct | _ -> true let rec is_legal_type_helper ~allow_maps ~allow_messages_events - ~allow_closures ~allow_polymorphism ~allow_unit ~check_addresses t - seen_adts seen_tvars = + ~allow_closures ~allow_polymorphism ~allow_partial_procedures ~allow_unit + ~check_addresses t seen_adts seen_tvars = let rec recurser t seen_adts seen_tvars = match t with | FunType (a, r) -> @@ -358,6 +359,7 @@ module TypeUtilities = struct && recurser r seen_adts seen_tvars | PolyFun (tvar, t) -> allow_polymorphism && recurser t seen_adts (tvar :: seen_tvars) + | ProcType _ -> allow_partial_procedures | Unit -> allow_unit | MapType (kt, vt) -> allow_maps @@ -402,47 +404,54 @@ module TypeUtilities = struct and is_legal_field_type_helper t seen_adts seen_tvars = (* Maps are allowed. Address values should be checked for storable field types. *) is_legal_type_helper ~allow_maps:true ~allow_messages_events:false - ~allow_closures:false ~allow_polymorphism:false ~allow_unit:false - ~check_addresses:true t seen_adts seen_tvars + ~allow_closures:false ~allow_polymorphism:false + ~allow_partial_procedures:false ~allow_unit:false ~check_addresses:true t + seen_adts seen_tvars let is_legal_message_field_type t = (* Maps are not allowed. Address values are considered ByStr20 when used as message field value. *) is_legal_type_helper ~allow_maps:false ~allow_messages_events:false - ~allow_closures:false ~allow_polymorphism:false ~allow_unit:false - ~check_addresses:false t [] [] + ~allow_closures:false ~allow_polymorphism:false + ~allow_partial_procedures:false ~allow_unit:false ~check_addresses:false t + [] [] let is_legal_transition_parameter_type t = (* Maps are not allowed. Address values should be checked for storable field types. *) is_legal_type_helper ~allow_maps:false ~allow_messages_events:false - ~allow_closures:false ~allow_polymorphism:false ~allow_unit:false - ~check_addresses:true t [] [] + ~allow_closures:false ~allow_polymorphism:false + ~allow_partial_procedures:false ~allow_unit:false ~check_addresses:true t + [] [] let is_legal_procedure_parameter_type t = - (* Like transition parametes, except messages, events and closures are allowed, + (* Like transition parameters, except messages, events and closures are allowed, since parameters do not need to be serializable. *) is_legal_type_helper ~allow_maps:false ~allow_messages_events:true - ~allow_closures:false ~allow_polymorphism:false ~allow_unit:false - ~check_addresses:true t [] [] + ~allow_closures:false ~allow_polymorphism:false + ~allow_partial_procedures:false ~allow_unit:false ~check_addresses:true t + [] [] let is_legal_contract_parameter_type t = (* Like transitions parameters, except maps are allowed (due to an exploited bug). Address values should be checked for storable field types. *) is_legal_type_helper ~allow_maps:true ~allow_messages_events:false - ~allow_closures:false ~allow_polymorphism:false ~allow_unit:false - ~check_addresses:true t [] [] + ~allow_closures:false ~allow_polymorphism:false + ~allow_partial_procedures:false ~allow_unit:false ~check_addresses:true t + [] [] let is_legal_field_type t = is_legal_field_type_helper t [] [] let is_legal_hash_argument_type t = (* Only closures and type closures are disallowed. Addresses behave like ByStr20. *) is_legal_type_helper ~allow_maps:true ~allow_messages_events:true - ~allow_closures:false ~allow_polymorphism:false ~allow_unit:false - ~check_addresses:false t [] [] + ~allow_closures:false ~allow_polymorphism:false + ~allow_partial_procedures:false ~allow_unit:false ~check_addresses:false t + [] [] let is_legal_map_key_type t = (* Only primitive (non-message and non-event) types are allowed. Addresses behave like ByStr20, and are thus allowed. *) is_legal_type_helper ~allow_maps:false ~allow_messages_events:false - ~allow_closures:false ~allow_polymorphism:false ~allow_unit:false - ~check_addresses:false t [] [] + ~allow_closures:false ~allow_polymorphism:false + ~allow_partial_procedures:false ~allow_unit:false ~check_addresses:false t + [] [] let get_msgevnt_type m lc = let open ContractUtil.MessagePayload in @@ -542,6 +551,13 @@ module TypeUtilities = struct mk_error1 ~kind:"Incorrect number of arguments to procedure" ?inst:None lc) + let partial_proc_type_applies ~lc formals actuals = + if List.length formals < List.length actuals then + fail (mk_error1 ~kind:"Extra arguments in procedure call" ?inst:None lc) + else + let formals' = List.sub ~pos:0 ~len:(List.length actuals) formals in + proc_type_applies ~lc formals' actuals + let rec elab_tfun_with_args_no_gas tf args = match (tf, args) with | (PolyFun _ as pf), a :: args' -> diff --git a/src/base/TypeUtil.mli b/src/base/TypeUtil.mli index 27a711a61..8307fbd8f 100644 --- a/src/base/TypeUtil.mli +++ b/src/base/TypeUtil.mli @@ -200,6 +200,14 @@ module TypeUtilities : sig TUType.t list -> (unit list, scilla_error list) result + (* Checks if types of the specified [actuals] arguments are assignable to the + corresponding types of [formals]. *) + val partial_proc_type_applies : + lc:ErrorUtils.loc -> + TUType.t list -> + TUType.t list -> + (unit list, scilla_error list) result + (* Applying a type function without gas charge (for builtins) *) val elab_tfun_with_args_no_gas : TUType.t -> TUType.t list -> (TUType.t, scilla_error list) result diff --git a/src/eval/Disambiguator.ml b/src/eval/Disambiguator.ml index f14242069..ff383b6b8 100644 --- a/src/eval/Disambiguator.ml +++ b/src/eval/Disambiguator.ml @@ -234,6 +234,9 @@ let disambiguate_type t this_address = let dis_arg_t = recurse arg_t in let dis_res_t = recurse res_t in OutputType.FunType (dis_arg_t, dis_res_t) + | ProcType (p, args_tys) -> + let dis_args_tys = List.map args_tys ~f:recurse in + OutputType.ProcType (p, dis_args_tys) | ADT (t_name, targs) -> let dis_nm = disambiguate_adt_name (InputIdentifier.get_id t_name) this_address diff --git a/src/eval/Eval.ml b/src/eval/Eval.ml index 3fffd86b7..4e4562cba 100644 --- a/src/eval/Eval.ml +++ b/src/eval/Eval.ml @@ -83,6 +83,48 @@ let sanitize_literal l = if is_legal_message_field_type t then pure l else fail0 ~kind:"Cannot serialize literal" ~inst:(pp_literal l) +(** Looks up for a partial application of procedure bounded to a local variable + named [p]. *) +let find_partial_appication (conf : Configuration.t) (p : loc SIdentifier.t) : + ( component * component list, + ErrorUtils.scilla_error list, + 'b -> 'c ) + CPSMonad.t = + let process_local_bind (local_bind : SLiteral.t) = + match literal_type local_bind with + | Ok (ProcType (proc_name, _args)) -> ( + match Configuration.lookup_procedure conf p with + | Some (proc, p_rest) -> pure @@ (proc, p_rest) + | None -> + fail0 + ~kind:(Printf.sprintf "Cannot find procedure %s" proc_name) + ~inst:(as_error_string p)) + | Error e -> (* Literal typing error *) fail e + | _ -> + fail0 + ~kind: + (Printf.sprintf "%s is not a procedure and cannot be called" + (SIdentifier.Name.as_string (SIdentifier.get_id p))) + ~inst:(as_error_string p) + in + match Env.lookup conf.env p with + | Ok local_bind -> process_local_bind local_bind + | Error _ -> + fail0 + ~kind: + (Printf.sprintf "Identifier %s is not found" + (SIdentifier.Name.as_string (SIdentifier.get_id p))) + ~inst:(as_error_string p) + +(** Applies [p] as is implemented in [apply]. [p] may be a procedure name or a + local bind to a partial application of procedure. *) +let apply_procedure ~apply conf p = + match Configuration.lookup_procedure conf p with + | Some (proc, p_rest) -> apply proc p_rest + | None -> + let%bind proc, p_rest = find_partial_appication conf p in + apply proc p_rest + let eval_gas_charge env g = let open MonadUtil in let open Result.Let_syntax in @@ -164,6 +206,7 @@ let replace_address_types l = | Address _ -> bystrx_typ Type.address_length | MapType (kt, vt) -> MapType (replace_in_type kt, replace_in_type vt) | FunType (t1, t2) -> FunType (replace_in_type t1, replace_in_type t2) + | ProcType (p, args) -> ProcType (p, List.map args ~f:replace_in_type) | ADT (tname, targs) -> ADT (tname, List.map targs ~f:replace_in_type) in let replace_in_literal l = @@ -486,32 +529,36 @@ let rec stmt_eval conf stmts = let%bind conf' = Configuration.create_event conf eparams_resolved in stmt_eval conf' sts | CallProc (id_opt, p, actuals) -> - (* Resolve the actuals *) - let%bind args = - mapM actuals ~f:(fun arg -> fromR @@ Env.lookup conf.env arg) + let apply proc p_rest = + (* Resolve the actuals *) + let%bind args = + mapM actuals ~f:(fun arg -> fromR @@ Env.lookup conf.env arg) + in + (* Apply procedure. No gas charged for the application *) + let%bind conf' = try_apply_as_procedure conf proc p_rest args in + (* Bind the return of [p] if it returns. *) + let%bind conf' = + match id_opt with + | Some id -> Configuration.procedure_return conf' id + | None -> pure @@ conf' + in + stmt_eval conf' sts in - let%bind proc, p_rest = Configuration.lookup_procedure conf p in - (* Apply procedure. No gas charged for the application *) - let%bind conf' = try_apply_as_procedure conf proc p_rest args in - (* Bind the return of [p] if it returns. *) - let%bind conf' = - match id_opt with - | Some id -> Configuration.procedure_return conf' id - | None -> pure @@ conf' - in - stmt_eval conf' sts + apply_procedure ~apply conf p | Iterate (l, p) -> - let%bind l_actual = fromR @@ Env.lookup conf.env l in - let%bind l' = fromR @@ Datatypes.scilla_list_to_ocaml l_actual in - let%bind proc, p_rest = Configuration.lookup_procedure conf p in - let%bind conf' = - foldM l' ~init:conf ~f:(fun confacc arg -> - let%bind conf' = - try_apply_as_procedure confacc proc p_rest [ arg ] - in - pure conf') + let apply proc p_rest = + let%bind l_actual = fromR @@ Env.lookup conf.env l in + let%bind l' = fromR @@ Datatypes.scilla_list_to_ocaml l_actual in + let%bind conf' = + foldM l' ~init:conf ~f:(fun confacc arg -> + let%bind conf' = + try_apply_as_procedure confacc proc p_rest [ arg ] + in + pure conf') + in + stmt_eval conf' sts in - stmt_eval conf' sts + apply_procedure ~apply conf p | Throw eopt -> let%bind estr = match eopt with diff --git a/src/eval/EvalUtil.ml b/src/eval/EvalUtil.ml index 808562f67..9bf11a517 100644 --- a/src/eval/EvalUtil.ml +++ b/src/eval/EvalUtil.ml @@ -445,10 +445,9 @@ module Configuration = struct let rec finder procs = match procs with | p :: p_rest when EvalIdentifier.equal p.comp_name proc_name -> - pure (p, p_rest) + Some (p, p_rest) | _ :: p_rest -> finder p_rest - | [] -> - fail0 ~kind:"Procedure not found" ~inst:(as_error_string proc_name) + | [] -> None in finder st.procedures diff --git a/src/eval/Runner.ml b/src/eval/Runner.ml index f7c22ed2f..64c0bcd04 100644 --- a/src/eval/Runner.ml +++ b/src/eval/Runner.ml @@ -188,7 +188,7 @@ let assert_no_address_type_in_type t gas_remaining = let () = recurser kt in recurser vt | ADT (_, ts) -> List.iter ts ~f:recurser - | FunType _ | TypeVar _ | PolyFun _ | Unit -> + | FunType _ | ProcType _ | TypeVar _ | PolyFun _ | Unit -> fatal_error_gas_scale Gas.scale_factor (mk_error0 ~kind:"Illegal type in json file" ~inst:(pp_typ t)) gas_remaining diff --git a/src/formatter/Formatter.ml b/src/formatter/Formatter.ml index 178beb84e..b314ff47c 100644 --- a/src/formatter/Formatter.ml +++ b/src/formatter/Formatter.ml @@ -149,6 +149,8 @@ struct (* TODO: test MapType and PolyFun pretty-printing *) | MapType (kt, vt) -> parens_if (p > 0) @@ map_kwd ^//^ (of_type_with_prec 1 kt) ^/^ (of_type_with_prec 1 vt) + | ProcType _ -> + failwith "ProcType annotation cannot appear in source code" | PolyFun (tv, bt) -> parens_if (p > 0) @@ forall_kwd ^^^ !^tv ^^ dot ^//^ (of_type_with_prec 0 bt) | ADT (tid, tys) -> diff --git a/src/merge/Merge.ml b/src/merge/Merge.ml index a7594aafc..501e22df9 100644 --- a/src/merge/Merge.ml +++ b/src/merge/Merge.ml @@ -236,6 +236,18 @@ module ScillaMerger (SR : Rep) (ER : Rep) = struct rename_local renames_map id (PIdentifier.get_rep id |> Lazy.from_val) |> add_loc + (** Renames [name] using [renames_map]. *) + let rename_name (name : string) renames_map = + let keys = Map.keys renames_map in + List.find keys ~f:(fun n -> + String.equal (SIdentifier.Name.as_string n) name) + |> Option.value_map ~default:name ~f:(fun k -> + let possible_names = Map.find_exn renames_map k in + let new_name = + choose_name possible_names name (lazy ErrorUtils.dummy_loc) + in + SIdentifier.Name.as_string new_name) + (** Renames user-defined ADTs from [renames_map]. *) let rec rename_ty renames_map (ty : SType.t) : SType.t = match ty with @@ -243,6 +255,12 @@ module ScillaMerger (SR : Rep) (ER : Rep) = struct MapType (rename_ty renames_map key_ty, rename_ty renames_map val_ty) | FunType (arg_ty, ret_ty) -> FunType (rename_ty renames_map arg_ty, rename_ty renames_map ret_ty) + | ProcType (pname, args_tys) -> + let pname' = rename_name pname renames_map in + let args_tys' = + List.map args_tys ~f:(fun ty -> rename_ty renames_map ty) + in + SType.ProcType (pname', args_tys') | ADT (id, tys) -> let id' = rename_local_loc renames_map id in let tys' = List.map tys ~f:(fun ty -> rename_ty renames_map ty) in @@ -262,6 +280,12 @@ module ScillaMerger (SR : Rep) (ER : Rep) = struct let arg_ty', renames_map' = qualify_ty renames_map arg_ty in let ret_ty', renames_map' = qualify_ty renames_map' ret_ty in (FunType (arg_ty', ret_ty'), renames_map') + | ProcType (pname, args_tys) -> + let pname' = rename_name pname renames_map in + let args_tys' = + List.map args_tys ~f:(fun ty -> rename_ty renames_map ty) + in + (ProcType (pname', args_tys'), renames_map) | ADT (id, tys) -> ( let name = PIdentifier.get_id id in match Map.find renames_map name with diff --git a/tests/runner/Testcontracts.ml b/tests/runner/Testcontracts.ml index bdd8a73ef..c27ca8062 100644 --- a/tests/runner/Testcontracts.ml +++ b/tests/runner/Testcontracts.ml @@ -499,6 +499,9 @@ let contract_tests env = >::: build_contract_tests env "ark" succ_code 1 1 []; "return-1" >::: build_contract_tests env "return-1" succ_code 1 1 []; + "partial-application-1" + >::: build_contract_tests env "partial-application-1" succ_code + 1 1 []; ]; "these_tests_must_FAIL" >::: [ diff --git a/tests/runner/partial-application-1/blockchain_1.json b/tests/runner/partial-application-1/blockchain_1.json new file mode 100644 index 000000000..d962cce79 --- /dev/null +++ b/tests/runner/partial-application-1/blockchain_1.json @@ -0,0 +1 @@ +[ { "vname": "BLOCKNUMBER", "type": "BNum", "value": "100" } ] diff --git a/tests/runner/partial-application-1/init.json b/tests/runner/partial-application-1/init.json new file mode 100644 index 000000000..1ce7ec6fb --- /dev/null +++ b/tests/runner/partial-application-1/init.json @@ -0,0 +1,17 @@ +[ + { + "vname" : "_scilla_version", + "type" : "Uint32", + "value" : "0" + }, + { + "vname" : "_this_address", + "type" : "ByStr20", + "value" : "0xabfeccdc9012345678901234567890f777567890" + }, + { + "vname" : "_creation_block", + "type" : "BNum", + "value" : "1" + } +] diff --git a/tests/runner/partial-application-1/message_1.json b/tests/runner/partial-application-1/message_1.json new file mode 100644 index 000000000..bb81e832c --- /dev/null +++ b/tests/runner/partial-application-1/message_1.json @@ -0,0 +1,37 @@ +{ + "_tag": "test", + "_amount": "0", + "_sender": "0x1234567890123456789012345678901234567890", + "params": [ + { + "vname": "l", + "type": "List Int32", + "value": { + "constructor": "Cons", + "argtypes": [ + "Int32" + ], + "arguments": [ + "2", + { + "constructor": "Cons", + "argtypes": [ + "Int32" + ], + "arguments": [ + "1", + { + "constructor": "Nil", + "argtypes": [ + "Int32" + ], + "arguments": [] + } + ] + } + ] + } + } + ], + "_origin": "0x1234567890123456789012345678901234567890" +} diff --git a/tests/runner/partial-application-1/output_1.json b/tests/runner/partial-application-1/output_1.json new file mode 100644 index 000000000..86f225852 --- /dev/null +++ b/tests/runner/partial-application-1/output_1.json @@ -0,0 +1,42 @@ +{ + "gas_remaining": "7967", + "errors": [ + { + "error_message": "Exception thrown: (Message [(_exception : (String \"Error\")) ; (code : (Int32 -1))])", + "start_location": { + "file": "tests/contracts/exception-position.scilla", + "line": 9, + "column": 3 + }, + "end_location": { "file": "", "line": 0, "column": 0 } + }, + { + "error_message": "Raised from CheckExn", + "start_location": { + "file": "tests/contracts/exception-position.scilla", + "line": 7, + "column": 11 + }, + "end_location": { "file": "", "line": 0, "column": 0 } + }, + { + "error_message": "Raised from CheckNestedExn", + "start_location": { + "file": "tests/contracts/exception-position.scilla", + "line": 16, + "column": 11 + }, + "end_location": { "file": "", "line": 0, "column": 0 } + }, + { + "error_message": "Raised from Test", + "start_location": { + "file": "tests/contracts/exception-position.scilla", + "line": 20, + "column": 12 + }, + "end_location": { "file": "", "line": 0, "column": 0 } + } + ], + "warnings": [] +} \ No newline at end of file diff --git a/tests/runner/partial-application-1/state_1.json b/tests/runner/partial-application-1/state_1.json new file mode 100644 index 000000000..ad24288b7 --- /dev/null +++ b/tests/runner/partial-application-1/state_1.json @@ -0,0 +1,7 @@ +[ + { + "vname": "_balance", + "type" : "Uint128", + "value": "0" + } +]