diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index 6e9c9184e0..d37f12077f 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -453,7 +453,9 @@ let create_function_inline_map adt l = | None -> accum | Some fdbody -> ( let create_data propto = - ( Option.map ~f:(fun x -> Type.Unsized x) fdrt + ( Option.map + ~f:(fun x -> Type.Unsized x) + (UnsizedType.returntype_to_type_opt fdrt) , List.map ~f:(fun (_, name, _) -> name) fdargs , inline_function_statement propto adt accum fdbody ) in match Middle.Utils.with_unnormalized_suffix fdname with diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 100490cff8..7b4df865f1 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -587,8 +587,7 @@ let trans_fun_def ud_dists (ts : Ast.typed_statement) = match ts.stmt with | Ast.FunDef {returntype; funname; arguments; body} -> [ Program. - { fdrt= - (match returntype with Void -> None | ReturnType ut -> Some ut) + { fdrt= returntype ; fdname= funname.name ; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto) ; fdargs= List.map ~f:trans_arg arguments @@ -785,7 +784,8 @@ let gather_data (p : Ast.typed_program) = | _ -> [] ) let trans_prog filename (p : Ast.typed_program) : Program.Typed.t = - let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = p in + let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = + Deprecation_analysis.remove_unneeded_forward_decls p in let map f list_op = Option.value_map ~default:[] ~f:(fun {Ast.stmts; _} -> List.concat_map ~f stmts) diff --git a/src/frontend/Canonicalize.ml b/src/frontend/Canonicalize.ml index a2cdf0ea81..dd7f6b5f10 100644 --- a/src/frontend/Canonicalize.ml +++ b/src/frontend/Canonicalize.ml @@ -261,16 +261,7 @@ let repair_syntax program settings = let canonicalize_program program settings : typed_program = let program = if settings.deprecations then - let fundefs = userdef_functions program in - let drop_forwarddecl = function - | {stmt= FunDef {body= {stmt= Skip; _}; funname; arguments; _}; _} - when is_redundant_forwarddecl fundefs funname arguments -> - false - | _ -> true in - { program with - functionblock= - Option.map program.functionblock ~f:(fun x -> - {x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) } + remove_unneeded_forward_decls program |> map_program (replace_deprecated_stmt (collect_userdef_distributions program)) else program in diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index 0f58e7ca77..aaef2cf710 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -47,18 +47,17 @@ let rename_deprecated map name = let userdef_functions program = match program.functionblock with - | None -> [] + | None -> Hash_set.Poly.create () | Some {stmts; _} -> List.filter_map stmts ~f:(function | {stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> None | {stmt= FunDef {funname; arguments; _}; _} -> Some (funname.name, Ast.type_of_arguments arguments) | _ -> None ) + |> Hash_set.Poly.of_list let is_redundant_forwarddecl fundefs funname arguments = - let equal (id1, a1) (id2, a2) = - String.equal id1 id2 && UnsizedType.equal_argumentlist a1 a2 in - List.mem ~equal fundefs (funname.name, Ast.type_of_arguments arguments) + Hash_set.mem fundefs (funname.name, Ast.type_of_arguments arguments) let userdef_distributions stmts = let open String in @@ -235,3 +234,15 @@ let collect_userdef_distributions program = let collect_warnings (program : typed_program) = let fundefs = userdef_functions program in fold_program (collect_deprecated_stmt fundefs) [] program + +let remove_unneeded_forward_decls program = + let fundefs = userdef_functions program in + let drop_forwarddecl = function + | {stmt= FunDef {body= {stmt= Skip; _}; funname; arguments; _}; _} + when is_redundant_forwarddecl fundefs funname arguments -> + false + | _ -> true in + { program with + functionblock= + Option.map program.functionblock ~f:(fun x -> + {x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) } diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index 28c8091afd..eeabc62b7c 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -18,16 +18,6 @@ val is_deprecated_distribution : string -> bool val deprecated_distributions : (string * string) String.Map.t val deprecated_functions : (string * string) String.Map.t val rename_deprecated : (string * string) String.Map.t -> string -> string - -val userdef_functions : - ('a, 'b, 'c, 'd) statement_with program - -> (string * Middle.UnsizedType.argumentlist) list - -val is_redundant_forwarddecl : - (string * Middle.UnsizedType.argumentlist) list - -> identifier - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * 'a) list - -> bool - val userdef_distributions : untyped_statement block option -> string list val collect_warnings : typed_program -> Warnings.t list +val remove_unneeded_forward_decls : typed_program -> typed_program diff --git a/src/middle/Program.ml b/src/middle/Program.ml index 7e1212cfff..4c2c79d803 100644 --- a/src/middle/Program.ml +++ b/src/middle/Program.ml @@ -6,12 +6,13 @@ type fun_arg_decl = (UnsizedType.autodifftype * string * UnsizedType.t) list [@@deriving sexp, hash, map] type 'a fun_def = - { fdrt: UnsizedType.t option + { fdrt: UnsizedType.returntype ; fdname: string ; fdsuffix: unit Fun_kind.suffix ; fdargs: (UnsizedType.autodifftype * string * UnsizedType.t) list - (* If fdbody is None, this is a function declaration without body. *) ; fdbody: 'a option + (* If fdbody is None, this is an external function declaration + (forward decls are removed during AST lowering) *) ; fdloc: (Location_span.t[@sexp.opaque] [@compare.ignore]) } [@@deriving compare, hash, sexp, map, fold] @@ -49,20 +50,17 @@ let pp_fun_arg_decl ppf (autodifftype, name, unsizedtype) = Fmt.pf ppf "%a%a %s" UnsizedType.pp_autodifftype autodifftype UnsizedType.pp unsizedtype name -let pp_fun_def pp_s ppf = function - | {fdrt; fdname; fdargs; fdbody; _} -> ( - let pp_body_opt ppf = function - | None -> Fmt.pf ppf ";" - | Some body -> pp_s ppf body in - match fdrt with - | Some rt -> - Fmt.pf ppf "@[%a %s%a {@ %a@]@ }" UnsizedType.pp rt fdname - Fmt.(list pp_fun_arg_decl ~sep:comma |> parens) - fdargs pp_body_opt fdbody - | None -> - Fmt.pf ppf "@[void %s%a {@ %a@]@ }" fdname - Fmt.(list pp_fun_arg_decl ~sep:comma |> parens) - fdargs pp_body_opt fdbody ) +let pp_fun_def pp_s ppf {fdrt; fdname; fdargs; fdbody; _} = + match fdbody with + | None -> + Fmt.pf ppf "@[extern %a %s%a;@]" UnsizedType.pp_returntype fdrt fdname + Fmt.(list pp_fun_arg_decl ~sep:comma |> parens) + fdargs + | Some s -> + Fmt.pf ppf "@[%a %s%a {@ %a@]@ }" UnsizedType.pp_returntype fdrt + fdname + Fmt.(list pp_fun_arg_decl ~sep:comma |> parens) + fdargs pp_s s let pp_io_block ppf = function | Parameters -> Fmt.string ppf "parameters" diff --git a/src/middle/UnsizedType.ml b/src/middle/UnsizedType.ml index 7437c9c15f..6df9f46cab 100644 --- a/src/middle/UnsizedType.ml +++ b/src/middle/UnsizedType.ml @@ -24,6 +24,8 @@ and autodifftype = DataOnly | AutoDiffable and returntype = Void | ReturnType of t [@@deriving compare, hash, sexp, equal] +let returntype_to_type_opt = function Void -> None | ReturnType t -> Some t + let pp_autodifftype ppf = function | DataOnly -> Fmt.string ppf "data " | AutoDiffable -> () diff --git a/src/stan_math_backend/Lower_functions.ml b/src/stan_math_backend/Lower_functions.ml index 7abf7f4b0c..81d7405dd0 100644 --- a/src/stan_math_backend/Lower_functions.ml +++ b/src/stan_math_backend/Lower_functions.ml @@ -88,9 +88,10 @@ promotion.*) let lower_returntype arg_types rt = let scalar = lower_promoted_scalar arg_types in match rt with - | Some ut when UnsizedType.is_int_type ut -> lower_type ut Int - | Some ut -> lower_type ut scalar - | None -> Void + | UnsizedType.ReturnType ut when UnsizedType.is_int_type ut -> + lower_type ut Int + | ReturnType ut -> lower_type ut scalar + | Void -> Void let lower_eigen_args_to_ref arg_types = let lower_ref name = @@ -269,7 +270,7 @@ let collect_functors_functions (p : Program.Numbered.t) : defn list = List.equal UnsizedType.equal (List.map ~f:(fun (_, _, t) -> t) fdargs) arg_types in - let declare_and_define (d : _ Program.fun_def) = + let register_functors (d : _ Program.fun_def) = let functors = Map.find_multi functor_required d.fdname |> List.stable_dedup @@ -281,20 +282,21 @@ let collect_functors_functions (p : Program.Numbered.t) : defn list = Hashtbl.update structs s.struct_name ~f:(function | Some x -> {x with body= x.body @ s.body} | None -> s ) ) ; - let decl, defn = Cpp.split_fun_decl_defn fn in - (FunDef decl, FunDef defn) in + fn in let fun_decls, fun_defns = p.functions_block |> List.filter_map ~f:(fun d -> - (* FIXME: external functions don't need decls - but they do need structs (when used in HOF) *) - if Option.is_none d.fdbody then None else Some (declare_and_define d) ) + let fn = register_functors d in + if Option.is_none d.fdbody then None + else + let decl, defn = Cpp.split_fun_decl_defn fn in + Some (FunDef decl, FunDef defn) ) |> List.unzip in let structs = Hashtbl.data structs |> List.map ~f:(fun s -> Struct s) in fun_decls @ structs @ fun_defns let lower_standalone_fun_def namespace_fun - Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} = + Program.{fdname; fdsuffix; fdargs; fdrt; _} = let extra, extra_templates = match fdsuffix with | Fun_kind.FnTarget -> @@ -313,24 +315,21 @@ let lower_standalone_fun_def namespace_fun let mark_function_comment = GlobalComment "[[stan::function]]" in let return_type, return_stmt = match fdrt with - | None -> (Void, fun e -> Expression e) + | Void -> (Void, fun e -> Expression e) | _ -> (Auto, fun e -> Return (Some e)) in let fn_sig = make_fun_defn ~name:fdname ~return_type ~args:all_args in - match fdbody with - | None -> [] - | Some _ -> - let internal_fname = namespace_fun ^ "::" ^ fdname in - let template = - match fdsuffix with - | FnLpdf _ | FnTarget -> [TypeLiteral "false"] - | FnRng | FnPlain -> [] in - let call_args = - List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"] - |> List.map ~f:Exprs.to_var in - let ret = - return_stmt (Exprs.templated_fun_call internal_fname template call_args) - in - [mark_function_comment; FunDef (fn_sig ~body:[ret] ())] + let internal_fname = namespace_fun ^ "::" ^ fdname in + let template = + match fdsuffix with + | FnLpdf _ | FnTarget -> [TypeLiteral "false"] + | FnRng | FnPlain -> [] in + let call_args = + List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"] + |> List.map ~f:Exprs.to_var in + let ret = + return_stmt (Exprs.templated_fun_call internal_fname template call_args) + in + [mark_function_comment; FunDef (fn_sig ~body:[ret] ())] module Testing = struct (* Testing code *) @@ -347,7 +346,7 @@ module Testing = struct let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - { fdrt= None + { fdrt= Void ; fdname= "sars" ; fdsuffix= FnPlain ; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)] @@ -404,7 +403,7 @@ module Testing = struct let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - { fdrt= Some UMatrix + { fdrt= ReturnType UMatrix ; fdname= "sars" ; fdsuffix= FnPlain ; fdargs= diff --git a/src/stan_math_backend/Transform_Mir.ml b/src/stan_math_backend/Transform_Mir.ml index 71466cdc02..f453319581 100644 --- a/src/stan_math_backend/Transform_Mir.ml +++ b/src/stan_math_backend/Transform_Mir.ml @@ -39,6 +39,90 @@ let eigen_block_expr_fns = ["head"; "tail"; "segment"; "col"; "row"; "block"; "sub_row"; "sub_col"] |> String.Set.of_list +(** Eval indexed eigen types in UDF calls to prevent + infinite template expansion if the call is recursive + + Infinite expansion can happen only when the call graph is cyclic. + The strategy here is to build the call graph one edge at the time + and check if adding that edge creates a cycle. If it does, insert + an [eval()] to stop template expansion. + + All relevant function calls are recorded transitively in [callgraph], + meaning if [A] calls [B] and [B] calls [C] then [callgraph[A] = {B,C}]. + In the worst case every function calls every other, [callgraph] has + size O(n^2) and this algorithm is O(n^3) so it's important to track + only the function calls that really can propagate eigen templates. +*) +let break_eigen_cycles functions_block = + let callgraph = String.Table.create () in + let eval_eigen_cycles fun_args calls (f : _ Program.fun_def) = + let open Expr.Fixed in + let rec is_potentially_recursive = function + | {pattern= Var name; _} -> Set.mem fun_args name + | {pattern= Indexed (e, _); _} -> is_potentially_recursive e + | {pattern= FunApp (StanLib (fname, _, _), e :: _); _} -> + Set.mem eigen_block_expr_fns fname && is_potentially_recursive e + | _ -> false in + let rec map_args name args = + let args = List.map ~f:rewrite_expr args in + let can_recurse, eval_args = + List.fold_map ~init:false + ~f:(fun is_rec e -> + if is_potentially_recursive e then + ( true + , {e with pattern= FunApp (StanLib ("eval", FnPlain, AoS), [e])} + ) + else (is_rec, e) ) + args in + if not can_recurse then args + else if name = f.fdname then eval_args + else + match Hashtbl.find callgraph name with + | Some nested when Hash_set.mem nested f.fdname -> eval_args + | Some nested -> + (* [calls] records all functions reachable from the current function *) + Hash_set.add calls name ; + Hash_set.iter nested ~f:(Hash_set.add calls) ; + args + | None -> Hash_set.add calls name ; args + and rewrite_expr : Expr.Typed.t -> Expr.Typed.t = function + | {pattern= FunApp ((UserDefined (name, _) as kind), args); _} as e -> + {e with pattern= FunApp (kind, map_args name args)} + | { pattern= + FunApp + ( (StanLib (_, _, _) as kind) + , ({pattern= Var name; meta= {type_= UFun _; _}} as f) :: args ) + ; _ } as e -> + (* higher-order function -- just pretend it's a direct call *) + {e with pattern= FunApp (kind, f :: map_args name args)} + | e -> {e with pattern= Pattern.map rewrite_expr e.pattern} in + let rec rewrite_stmt s = + let open Stmt.Fixed in + match s with + | {pattern= Pattern.NRFunApp ((UserDefined (name, _) as kind), args); _} + as s -> + {s with pattern= NRFunApp (kind, map_args name args)} + | s -> {s with pattern= Pattern.map rewrite_expr rewrite_stmt s.pattern} + in + Program.map_fun_def rewrite_stmt f in + let break_cycles (Program.{fdname; fdargs; _} as fd) = + let fun_args = + List.filter_map fdargs ~f:(fun (_, n, t) -> + if UnsizedType.is_eigen_type t then Some n else None ) + |> String.Set.of_list in + if Set.is_empty fun_args then fd + else + let calls = String.Hash_set.create () in + let fndef = eval_eigen_cycles fun_args calls fd in + if not (Hash_set.is_empty calls) then ( + (* update [callgraph] with the call paths going through the current function *) + Hashtbl.map_inplace callgraph ~f:(fun x -> + if Hash_set.mem x fdname then Hash_set.union calls x else x ) ; + Hashtbl.update callgraph fdname + ~f:(Option.value_map ~f:(Hash_set.union calls) ~default:calls) ) ; + fndef in + List.map ~f:break_cycles functions_block + let opencl_trigger_restrictions = String.Map.of_alist_exn [ ( "bernoulli_lpmf" @@ -464,88 +548,7 @@ let trans_prog (p : Program.Typed.t) = ; functions_block= List.map ~f:rename_func p.functions_block } |> map translate_funapps_and_kwrds map_stmt |> map Fn.id change_kwrds_stmts) in - (* Eval indexed eigen types in UDF calls to prevent - infinite template expansion if the call is recursive - - Infinite expansion can happen only when the call graph is cyclic. - The strategy here is to build the call graph one edge at the time - and check if adding that edge creates a cycle. If it does, insert - an `eval()` to stop template expansion. - - All relevant function calls are recorded transitively in `callgraph`, - meaning if `A` calls `B` and `B` calls `C` then `callgraph[A] = {B,C}`. - In the worst case every function calls every other, `callgraph` has - size O(n^2) and this algorithm is O(n^3) so it's important to track - only the function calls that really can propagate eigen templates. - *) - let callgraph = String.Table.create () in - let eval_eigen_cycles fun_args calls (f : _ Program.fun_def) = - let open Expr.Fixed in - let rec is_potentially_recursive = function - | {pattern= Var name; _} -> Set.mem fun_args name - | {pattern= Indexed (e, _); _} -> is_potentially_recursive e - | {pattern= FunApp (StanLib (fname, _, _), e :: _); _} -> - Set.mem eigen_block_expr_fns fname && is_potentially_recursive e - | _ -> false in - let rec map_args name args = - let args = List.map ~f:rewrite_expr args in - let can_recurse, eval_args = - List.fold_map ~init:false - ~f:(fun is_rec e -> - if is_potentially_recursive e then - ( true - , {e with pattern= FunApp (StanLib ("eval", FnPlain, AoS), [e])} - ) - else (is_rec, e) ) - args in - if not can_recurse then args - else if name = f.fdname then eval_args - else - match Hashtbl.find callgraph name with - | Some nested when Hash_set.mem nested f.fdname -> eval_args - | Some nested -> - (* `calls` records all functions reachable from the current function *) - Hash_set.add calls name ; - Hash_set.iter nested ~f:(Hash_set.add calls) ; - args - | None -> Hash_set.add calls name ; args - and rewrite_expr : Expr.Typed.t -> Expr.Typed.t = function - | {pattern= FunApp ((UserDefined (name, _) as kind), args); _} as e -> - {e with pattern= FunApp (kind, map_args name args)} - | { pattern= - FunApp - ( (StanLib (_, _, _) as kind) - , ({pattern= Var name; meta= {type_= UFun _; _}} as f) :: args ) - ; _ } as e -> - (* higher-order function -- just pretend it's a direct call *) - {e with pattern= FunApp (kind, f :: map_args name args)} - | e -> {e with pattern= Pattern.map rewrite_expr e.pattern} in - let rec rewrite_stmt s = - let open Stmt.Fixed in - match s with - | {pattern= Pattern.NRFunApp ((UserDefined (name, _) as kind), args); _} - as s -> - {s with pattern= NRFunApp (kind, map_args name args)} - | s -> {s with pattern= Pattern.map rewrite_expr rewrite_stmt s.pattern} - in - Program.map_fun_def rewrite_stmt f in - let break_cycles (Program.{fdname; fdargs; _} as fd) = - let fun_args = - List.filter_map fdargs ~f:(fun (_, n, t) -> - if UnsizedType.is_eigen_type t then Some n else None ) - |> String.Set.of_list in - if Set.is_empty fun_args then fd - else - let calls = String.Hash_set.create () in - let fndef = eval_eigen_cycles fun_args calls fd in - if not (Hash_set.is_empty calls) then ( - (* update `callgraph` with the call paths going through the current function *) - Hashtbl.map_inplace callgraph ~f:(fun x -> - if Hash_set.mem x fdname then Hash_set.union calls x else x ) ; - Hashtbl.update callgraph fdname - ~f:(Option.value_map ~f:(Hash_set.union calls) ~default:calls) ) ; - fndef in - let p = {p with functions_block= List.map ~f:break_cycles p.functions_block} in + let p = {p with functions_block= break_eigen_cycles p.functions_block} in let init_pos = [ Stmt.Fixed.Pattern.Decl { decl_adtype= DataOnly diff --git a/test/integration/cli-args/allow-undefined/cpp.expected b/test/integration/cli-args/allow-undefined/cpp.expected new file mode 100644 index 0000000000..d55fd77b01 --- /dev/null +++ b/test/integration/cli-args/allow-undefined/cpp.expected @@ -0,0 +1,307 @@ + $ ../../../../../install/default/bin/stanc --allow-undefined --print-cpp external.stan +// Code generated by %%NAME%% %%VERSION%% +#include +namespace external_model_namespace { +using stan::model::model_base_crtp; +using namespace stan::math; +stan::math::profile_map profiles__; +static constexpr std::array locations_array__ = + {" (found before start of program)", + " (in 'external.stan', line 10, column 4 to column 66)", + " (in 'external.stan', line 9, column 63 to line 11, column 3)"}; +double +internal_fun(const std::vector>& a, + const std::vector>& d, std::ostream* pstream__); +struct external_map_rectable_functor__ { + template , + stan::is_vt_not_complex, + stan::is_col_vector, + stan::is_vt_not_complex>* = nullptr> + Eigen::Matrix, + stan::base_type_t>,-1,1> + operator()(const T0__& phi, const T1__& theta, const std::vector& + x_r, const std::vector& x_i, std::ostream* pstream__) const { + return external_map_rectable(phi, theta, x_r, x_i, pstream__); + } +}; +double +internal_fun(const std::vector>& a, + const std::vector>& d, std::ostream* pstream__) { + using local_scalar_t__ = double; + int current_statement__ = 0; + static constexpr bool propto__ = true; + // suppress unused var warning + (void) propto__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + // suppress unused var warning + (void) DUMMY_VAR__; + try { + current_statement__ = 1; + return stan::math::sum( + stan::math::map_rect<1, external_map_rectable_functor__>( + Eigen::Matrix(0), + std::vector>{Eigen::Matrix(0)}, + a, d, pstream__)); + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } +} +class external_model final : public model_base_crtp { + private: + + public: + ~external_model() {} + external_model(stan::io::var_context& context__, unsigned int + random_seed__ = 0, std::ostream* pstream__ = nullptr) + : model_base_crtp(0) { + int current_statement__ = 0; + using local_scalar_t__ = double; + boost::ecuyer1988 base_rng__ = + stan::services::util::create_rng(random_seed__, 0); + // suppress unused var warning + (void) base_rng__; + static constexpr const char* function__ = + "external_model_namespace::external_model"; + // suppress unused var warning + (void) function__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + // suppress unused var warning + (void) DUMMY_VAR__; + try { + int pos__ = std::numeric_limits::min(); + pos__ = 1; + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + num_params_r__ = 0U; + } + inline std::string model_name() const final { + return "external_model"; + } + inline std::vector model_compile_info() const noexcept { + return std::vector{"stanc_version = %%NAME%%3 %%VERSION%%", + "stancflags = --allow-undefined --print-cpp"}; + } + template * = nullptr, + stan::require_vector_like_vt* = nullptr> + inline stan::scalar_type_t + log_prob_impl(VecR& params_r__, VecI& params_i__, std::ostream* + pstream__ = nullptr) const { + using T__ = stan::scalar_type_t; + using local_scalar_t__ = T__; + T__ lp__(0.0); + stan::math::accumulator lp_accum__; + stan::io::deserializer in__(params_r__, params_i__); + int current_statement__ = 0; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + // suppress unused var warning + (void) DUMMY_VAR__; + static constexpr const char* function__ = + "external_model_namespace::log_prob"; + // suppress unused var warning + (void) function__; + try { + + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + lp_accum__.add(lp__); + return lp_accum__.sum(); + } + template * = nullptr, stan::require_vector_like_vt* = nullptr, stan::require_vector_vt* = nullptr> + inline void + write_array_impl(RNG& base_rng__, VecR& params_r__, VecI& params_i__, + VecVar& vars__, const bool + emit_transformed_parameters__ = true, const bool + emit_generated_quantities__ = true, std::ostream* + pstream__ = nullptr) const { + using local_scalar_t__ = double; + stan::io::deserializer in__(params_r__, params_i__); + stan::io::serializer out__(vars__); + static constexpr bool propto__ = true; + // suppress unused var warning + (void) propto__; + double lp__ = 0.0; + // suppress unused var warning + (void) lp__; + int current_statement__ = 0; + stan::math::accumulator lp_accum__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + // suppress unused var warning + (void) DUMMY_VAR__; + constexpr bool jacobian__ = false; + static constexpr const char* function__ = + "external_model_namespace::write_array"; + // suppress unused var warning + (void) function__; + try { + if (stan::math::logical_negation( + (stan::math::primitive_value(emit_transformed_parameters__) || + stan::math::primitive_value(emit_generated_quantities__)))) { + return ; + } + if (stan::math::logical_negation(emit_generated_quantities__)) { + return ; + } + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + } + template * = nullptr, + stan::require_vector_like_vt* = nullptr> + inline void + transform_inits_impl(VecVar& params_r__, VecI& params_i__, VecVar& vars__, + std::ostream* pstream__ = nullptr) const { + using local_scalar_t__ = double; + stan::io::deserializer in__(params_r__, params_i__); + stan::io::serializer out__(vars__); + int current_statement__ = 0; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + // suppress unused var warning + (void) DUMMY_VAR__; + try { + int pos__ = std::numeric_limits::min(); + pos__ = 1; + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + } + inline void + get_param_names(std::vector& names__, const bool + emit_transformed_parameters__ = true, const bool + emit_generated_quantities__ = true) const { + names__ = std::vector{}; + if (emit_transformed_parameters__) {} + if (emit_generated_quantities__) {} + } + inline void + get_dims(std::vector>& dimss__, const bool + emit_transformed_parameters__ = true, const bool + emit_generated_quantities__ = true) const { + dimss__ = std::vector>{}; + if (emit_transformed_parameters__) {} + if (emit_generated_quantities__) {} + } + inline void + constrained_param_names(std::vector& param_names__, bool + emit_transformed_parameters__ = true, bool + emit_generated_quantities__ = true) const final { + if (emit_transformed_parameters__) {} + if (emit_generated_quantities__) {} + } + inline void + unconstrained_param_names(std::vector& param_names__, bool + emit_transformed_parameters__ = true, bool + emit_generated_quantities__ = true) const final { + if (emit_transformed_parameters__) {} + if (emit_generated_quantities__) {} + } + inline std::string get_constrained_sizedtypes() const { + return std::string("[]"); + } + inline std::string get_unconstrained_sizedtypes() const { + return std::string("[]"); + } + // Begin method overload boilerplate + template inline void + write_array(RNG& base_rng, Eigen::Matrix& params_r, + Eigen::Matrix& vars, const bool + emit_transformed_parameters = true, const bool + emit_generated_quantities = true, std::ostream* + pstream = nullptr) const { + const size_t num_params__ = 0; + const size_t num_transformed = emit_transformed_parameters * (0); + const size_t num_gen_quantities = emit_generated_quantities * (0); + const size_t num_to_write = num_params__ + num_transformed + + num_gen_quantities; + std::vector params_i; + vars = Eigen::Matrix::Constant(num_to_write, + std::numeric_limits::quiet_NaN()); + write_array_impl(base_rng, params_r, params_i, vars, + emit_transformed_parameters, emit_generated_quantities, pstream); + } + template inline void + write_array(RNG& base_rng, std::vector& params_r, std::vector& + params_i, std::vector& vars, bool + emit_transformed_parameters = true, bool + emit_generated_quantities = true, std::ostream* + pstream = nullptr) const { + const size_t num_params__ = 0; + const size_t num_transformed = emit_transformed_parameters * (0); + const size_t num_gen_quantities = emit_generated_quantities * (0); + const size_t num_to_write = num_params__ + num_transformed + + num_gen_quantities; + vars = std::vector(num_to_write, + std::numeric_limits::quiet_NaN()); + write_array_impl(base_rng, params_r, params_i, vars, + emit_transformed_parameters, emit_generated_quantities, pstream); + } + template inline T_ + log_prob(Eigen::Matrix& params_r, std::ostream* pstream = nullptr) const { + Eigen::Matrix params_i; + return log_prob_impl(params_r, params_i, pstream); + } + template inline T_ + log_prob(std::vector& params_r, std::vector& params_i, + std::ostream* pstream = nullptr) const { + return log_prob_impl(params_r, params_i, pstream); + } + inline void + transform_inits(const stan::io::var_context& context, + Eigen::Matrix& params_r, std::ostream* + pstream = nullptr) const final { + std::vector params_r_vec(params_r.size()); + std::vector params_i; + transform_inits(context, params_i, params_r_vec, pstream); + params_r = Eigen::Map>(params_r_vec.data(), + params_r_vec.size()); + } + inline void + transform_inits(const stan::io::var_context& context, std::vector& + params_i, std::vector& vars, std::ostream* + pstream__ = nullptr) const { + constexpr std::array names__{}; + const std::array constrain_param_sizes__{}; + const auto num_constrained_params__ = + std::accumulate(constrain_param_sizes__.begin(), + constrain_param_sizes__.end(), 0); + std::vector params_r_flat__(num_constrained_params__); + Eigen::Index size_iter__ = 0; + Eigen::Index flat_iter__ = 0; + for (auto&& param_name__: names__) { + const auto param_vec__ = context.vals_r(param_name__); + for (Eigen::Index i = 0; i < constrain_param_sizes__[size_iter__]; ++i) { + params_r_flat__[flat_iter__] = param_vec__[i]; + ++flat_iter__; + } + ++size_iter__; + } + vars.resize(num_params_r__); + transform_inits_impl(params_r_flat__, params_i, vars, pstream__); + } +}; +} +using stan_model = external_model_namespace::external_model; +#ifndef USING_R +// Boilerplate +stan::model::model_base& +new_model(stan::io::var_context& data_context, unsigned int seed, + std::ostream* msg_stream) { + stan_model* m = new stan_model(data_context, seed, msg_stream); + return *m; +} +stan::math::profile_map& get_stan_profile_data() { + return external_model_namespace::profiles__; +} +#endif +STAN_REGISTER_MAP_RECT(1, external_model_namespace::external_map_rectable_functor__) +Warning in 'external.stan', line 7, column 7: Functions do not need to be + declared before definition; all user defined function names are always in + scope regardless of defintion order. diff --git a/test/integration/cli-args/allow-undefined/dune b/test/integration/cli-args/allow-undefined/dune new file mode 100644 index 0000000000..3ca558c406 --- /dev/null +++ b/test/integration/cli-args/allow-undefined/dune @@ -0,0 +1,37 @@ +(rule + (targets standalone-cpp.output) + (deps + (package stanc) + (:stanfiles + (glob_files *.stan*))) + (action + (with-stdout-to + %{targets} + (run + %{bin:run_bin_on_args} + "%{bin:stanc} --standalone-functions --allow-undefined --print-cpp" + %{stanfiles})))) + +(rule + (alias runtest) + (action + (diff standalone-cpp.expected standalone-cpp.output))) + +(rule + (targets cpp.output) + (deps + (package stanc) + (:stanfiles + (glob_files *.stan*))) + (action + (with-stdout-to + %{targets} + (run + %{bin:run_bin_on_args} + "%{bin:stanc} --allow-undefined --print-cpp" + %{stanfiles})))) + +(rule + (alias runtest) + (action + (diff cpp.expected cpp.output))) diff --git a/test/integration/cli-args/allow-undefined/external.stan b/test/integration/cli-args/allow-undefined/external.stan new file mode 100644 index 0000000000..905325068f --- /dev/null +++ b/test/integration/cli-args/allow-undefined/external.stan @@ -0,0 +1,12 @@ +functions { + real external_fun(real a); + + vector external_map_rectable(vector phi, vector theta, + data array[] real x_r, data array[] int x_i); + + real internal_fun(data array[,] real a, data array[,] int d); + // forward decl, shouldn't duplicate in output + real internal_fun(data array[,] real a, data array[,] int d) { + return sum(map_rect(external_map_rectable, []', {[]'}, a, d)); + } +} diff --git a/test/integration/cli-args/allow-undefined/standalone-cpp.expected b/test/integration/cli-args/allow-undefined/standalone-cpp.expected new file mode 100644 index 0000000000..8da7ec2014 --- /dev/null +++ b/test/integration/cli-args/allow-undefined/standalone-cpp.expected @@ -0,0 +1,73 @@ + $ ../../../../../install/default/bin/stanc --standalone-functions --allow-undefined --print-cpp external.stan +// Code generated by %%NAME%% %%VERSION%% +#include +namespace external_model_namespace { +using stan::model::model_base_crtp; +using namespace stan::math; +stan::math::profile_map profiles__; +static constexpr std::array locations_array__ = + {" (found before start of program)", + " (in 'external.stan', line 10, column 4 to column 66)", + " (in 'external.stan', line 9, column 63 to line 11, column 3)"}; +double +internal_fun(const std::vector>& a, + const std::vector>& d, std::ostream* pstream__); +struct external_map_rectable_functor__ { + template , + stan::is_vt_not_complex, + stan::is_col_vector, + stan::is_vt_not_complex>* = nullptr> + Eigen::Matrix, + stan::base_type_t>,-1,1> + operator()(const T0__& phi, const T1__& theta, const std::vector& + x_r, const std::vector& x_i, std::ostream* pstream__) const { + return external_map_rectable(phi, theta, x_r, x_i, pstream__); + } +}; +double +internal_fun(const std::vector>& a, + const std::vector>& d, std::ostream* pstream__) { + using local_scalar_t__ = double; + int current_statement__ = 0; + static constexpr bool propto__ = true; + // suppress unused var warning + (void) propto__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + // suppress unused var warning + (void) DUMMY_VAR__; + try { + current_statement__ = 1; + return stan::math::sum( + stan::math::map_rect<1, external_map_rectable_functor__>( + Eigen::Matrix(0), + std::vector>{Eigen::Matrix(0)}, + a, d, pstream__)); + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } +} +} +// [[stan::function]] +auto external_fun(const double& a, std::ostream* pstream__ = nullptr) { + return external_model_namespace::external_fun(a, pstream__); +} +// [[stan::function]] +auto +external_map_rectable(const Eigen::Matrix& phi, + const Eigen::Matrix& theta, + const std::vector& x_r, const std::vector& + x_i, std::ostream* pstream__ = nullptr) { + return external_model_namespace::external_map_rectable(phi, theta, x_r, + x_i, pstream__); +} +// [[stan::function]] +auto +internal_fun(const std::vector>& a, + const std::vector>& d, std::ostream* + pstream__ = nullptr) { + return external_model_namespace::internal_fun(a, d, pstream__); +} +Warning in 'external.stan', line 7, column 7: Functions do not need to be + declared before definition; all user defined function names are always in + scope regardless of defintion order. diff --git a/test/integration/good/code-gen/mir.expected b/test/integration/good/code-gen/mir.expected index 5fbfd70c1e..fc1e6e3fcf 100644 --- a/test/integration/good/code-gen/mir.expected +++ b/test/integration/good/code-gen/mir.expected @@ -1,6 +1,6 @@ $ ../../../../../install/default/bin/stanc --debug-mir mother.stan ((functions_block - (((fdrt (UInt)) (fdname foo) (fdsuffix FnPlain) + (((fdrt (ReturnType UInt)) (fdname foo) (fdsuffix FnPlain) (fdargs ((AutoDiffable n UInt))) (fdbody (((pattern @@ -44,7 +44,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UReal))) (fdname sho) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UReal))) (fdname sho) (fdsuffix FnPlain) (fdargs ((AutoDiffable t UReal) (AutoDiffable y (UArray UReal)) (AutoDiffable theta (UArray UReal)) (DataOnly x (UArray UReal)) @@ -141,7 +141,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_bar0) (fdsuffix FnPlain) (fdargs ()) + ((fdrt (ReturnType UReal)) (fdname foo_bar0) (fdsuffix FnPlain) + (fdargs ()) (fdbody (((pattern (Block @@ -152,7 +153,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_bar1) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_bar1) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UReal))) (fdbody (((pattern @@ -164,7 +165,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_bar2) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_bar2) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UReal) (AutoDiffable y UReal))) (fdbody (((pattern @@ -176,7 +177,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lpmf) (fdsuffix (FnLpdf ())) + ((fdrt (ReturnType UReal)) (fdname foo_lpmf) (fdsuffix (FnLpdf ())) (fdargs ((AutoDiffable y UInt) (AutoDiffable lambda UReal))) (fdbody (((pattern @@ -188,7 +189,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lcdf) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_lcdf) (fdsuffix FnPlain) (fdargs ((AutoDiffable y UInt) (AutoDiffable lambda UReal))) (fdbody (((pattern @@ -200,7 +201,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lccdf) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_lccdf) (fdsuffix FnPlain) (fdargs ((AutoDiffable y UInt) (AutoDiffable lambda UReal))) (fdbody (((pattern @@ -212,7 +213,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_rng) (fdsuffix FnRng) + ((fdrt (ReturnType UReal)) (fdname foo_rng) (fdsuffix FnRng) (fdargs ((AutoDiffable mu UReal) (AutoDiffable sigma UReal))) (fdbody (((pattern @@ -231,7 +232,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname unit_normal_lp) (fdsuffix FnTarget) + ((fdrt Void) (fdname unit_normal_lp) (fdsuffix FnTarget) (fdargs ((AutoDiffable u UReal))) (fdbody (((pattern @@ -275,7 +276,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UInt)) (fdname foo_1) (fdsuffix FnPlain) + ((fdrt (ReturnType UInt)) (fdname foo_1) (fdsuffix FnPlain) (fdargs ((AutoDiffable a UInt))) (fdbody (((pattern @@ -1136,7 +1137,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UInt)) (fdname foo_2) (fdsuffix FnPlain) + ((fdrt (ReturnType UInt)) (fdname foo_2) (fdsuffix FnPlain) (fdargs ((AutoDiffable a UInt))) (fdbody (((pattern @@ -1207,7 +1208,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UReal))) (fdname foo_3) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UReal))) (fdname foo_3) (fdsuffix FnPlain) (fdargs ((AutoDiffable t UReal) (AutoDiffable n UInt))) (fdbody (((pattern @@ -1227,7 +1228,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lp) (fdsuffix FnTarget) + ((fdrt (ReturnType UReal)) (fdname foo_lp) (fdsuffix FnTarget) (fdargs ((AutoDiffable x UReal))) (fdbody (((pattern @@ -1246,7 +1247,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname foo_4) (fdsuffix FnPlain) + ((fdrt Void) (fdname foo_4) (fdsuffix FnPlain) (fdargs ((AutoDiffable x (UArray UReal)))) (fdbody (((pattern @@ -1274,7 +1275,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname relative_diff) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname relative_diff) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UReal) (AutoDiffable y UReal) (AutoDiffable max_ UReal) (AutoDiffable min_ UReal))) @@ -1429,7 +1430,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname foo_5) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname foo_5) (fdsuffix FnPlain) (fdargs ((AutoDiffable shared_params UVector) (AutoDiffable job_params UVector) (DataOnly data_r (UArray UReal)) (DataOnly data_i (UArray UInt)))) @@ -1472,7 +1473,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_five_args) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_five_args) (fdsuffix FnPlain) (fdargs ((AutoDiffable x1 UReal) (AutoDiffable x2 UReal) (AutoDiffable x3 UReal) (AutoDiffable x4 UReal) (AutoDiffable x5 UReal))) @@ -1486,7 +1487,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_five_args_lp) (fdsuffix FnTarget) + ((fdrt (ReturnType UReal)) (fdname foo_five_args_lp) (fdsuffix FnTarget) (fdargs ((AutoDiffable x1 UReal) (AutoDiffable x2 UReal) (AutoDiffable x3 UReal) (AutoDiffable x4 UReal) (AutoDiffable x5 UReal) @@ -1501,7 +1502,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname covsqrt2corsqrt) (fdsuffix FnPlain) + ((fdrt (ReturnType UMatrix)) (fdname covsqrt2corsqrt) (fdsuffix FnPlain) (fdargs ((AutoDiffable mat UMatrix) (AutoDiffable invert UInt))) (fdbody (((pattern @@ -1600,7 +1601,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname f0) (fdsuffix FnPlain) + ((fdrt Void) (fdname f0) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1620,7 +1621,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UInt)) (fdname f1) (fdsuffix FnPlain) + ((fdrt (ReturnType UInt)) (fdname f1) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1640,7 +1641,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UInt))) (fdname f2) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UInt))) (fdname f2) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1661,7 +1662,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UInt)))) (fdname f3) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UInt)))) (fdname f3) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1683,7 +1684,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname f4) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname f4) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1703,7 +1704,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UReal))) (fdname f5) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UReal))) (fdname f5) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1725,7 +1726,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UReal)))) (fdname f6) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UReal)))) (fdname f6) + (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1747,7 +1749,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname f7) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname f7) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1767,7 +1769,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UVector))) (fdname f8) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UVector))) (fdname f8) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1789,7 +1791,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UVector)))) (fdname f9) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UVector)))) (fdname f9) + (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1811,7 +1814,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname f10) (fdsuffix FnPlain) + ((fdrt (ReturnType UMatrix)) (fdname f10) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1831,7 +1834,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UMatrix))) (fdname f11) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UMatrix))) (fdname f11) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1853,7 +1856,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UMatrix)))) (fdname f12) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UMatrix)))) (fdname f12) + (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1875,7 +1879,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname foo_6) (fdsuffix FnPlain) (fdargs ()) + ((fdrt Void) (fdname foo_6) (fdsuffix FnPlain) (fdargs ()) (fdbody (((pattern (Block @@ -1936,7 +1940,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname matfoo) (fdsuffix FnPlain) (fdargs ()) + ((fdrt (ReturnType UMatrix)) (fdname matfoo) (fdsuffix FnPlain) + (fdargs ()) (fdbody (((pattern (Block @@ -2200,7 +2205,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname vecfoo) (fdsuffix FnPlain) (fdargs ()) + ((fdrt (ReturnType UVector)) (fdname vecfoo) (fdsuffix FnPlain) + (fdargs ()) (fdbody (((pattern (Block @@ -2296,7 +2302,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname vecmufoo) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname vecmufoo) (fdsuffix FnPlain) (fdargs ((AutoDiffable mu UReal))) (fdbody (((pattern @@ -2328,7 +2334,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname vecmubar) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname vecmubar) (fdsuffix FnPlain) (fdargs ((AutoDiffable mu UReal))) (fdbody (((pattern @@ -2480,7 +2486,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname algebra_system) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname algebra_system) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UVector) (AutoDiffable y UVector) (AutoDiffable dat (UArray UReal)) (AutoDiffable dat_int (UArray UInt)))) @@ -2567,7 +2573,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname binomialf) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname binomialf) (fdsuffix FnPlain) (fdargs ((AutoDiffable phi UVector) (AutoDiffable theta UVector) (DataOnly x_r (UArray UReal)) (DataOnly x_i (UArray UInt)))) diff --git a/test/integration/good/code-gen/standalone_functions/cpp.expected b/test/integration/good/code-gen/standalone_functions/cpp.expected index 81dee2deeb..61aaef2f35 100644 --- a/test/integration/good/code-gen/standalone_functions/cpp.expected +++ b/test/integration/good/code-gen/standalone_functions/cpp.expected @@ -1,4 +1,4 @@ - $ ../../../../../../install/default/bin/stanc --standalone-functions --print-cpp basic.stan + $ ../../../../../../install/default/bin/stanc --standalone-functions --allow-undefined --print-cpp basic.stan // Code generated by %%NAME%% %%VERSION%% #include namespace basic_model_namespace { @@ -281,7 +281,7 @@ auto test_lpdf(const double& a, const double& b, std::ostream* pstream__ = nullptr) { return basic_model_namespace::test_lpdf(a, b, pstream__); } - $ ../../../../../../install/default/bin/stanc --standalone-functions --print-cpp basic.stanfunctions + $ ../../../../../../install/default/bin/stanc --standalone-functions --allow-undefined --print-cpp basic.stanfunctions // Code generated by %%NAME%% %%VERSION%% #include namespace basic_model_namespace { @@ -564,7 +564,7 @@ auto test_lpdf(const double& a, const double& b, std::ostream* pstream__ = nullptr) { return basic_model_namespace::test_lpdf(a, b, pstream__); } - $ ../../../../../../install/default/bin/stanc --standalone-functions --print-cpp integrate.stan + $ ../../../../../../install/default/bin/stanc --standalone-functions --allow-undefined --print-cpp integrate.stan // Code generated by %%NAME%% %%VERSION%% #include namespace integrate_model_namespace { diff --git a/test/integration/good/code-gen/standalone_functions/dune b/test/integration/good/code-gen/standalone_functions/dune index 14987f158b..96b54d6819 100644 --- a/test/integration/good/code-gen/standalone_functions/dune +++ b/test/integration/good/code-gen/standalone_functions/dune @@ -9,7 +9,7 @@ %{targets} (run %{bin:run_bin_on_args} - "%{bin:stanc} --standalone-functions --print-cpp" + "%{bin:stanc} --standalone-functions --allow-undefined --print-cpp" %{stanfiles})))) (rule diff --git a/test/integration/good/code-gen/transformed_mir.expected b/test/integration/good/code-gen/transformed_mir.expected index fdc4b98f90..332fff6a22 100644 --- a/test/integration/good/code-gen/transformed_mir.expected +++ b/test/integration/good/code-gen/transformed_mir.expected @@ -1,6 +1,6 @@ $ ../../../../../install/default/bin/stanc --debug-transformed-mir mother.stan ((functions_block - (((fdrt (UInt)) (fdname foo) (fdsuffix FnPlain) + (((fdrt (ReturnType UInt)) (fdname foo) (fdsuffix FnPlain) (fdargs ((AutoDiffable n UInt))) (fdbody (((pattern @@ -47,7 +47,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UReal))) (fdname sho) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UReal))) (fdname sho) (fdsuffix FnPlain) (fdargs ((AutoDiffable t UReal) (AutoDiffable y (UArray UReal)) (AutoDiffable theta (UArray UReal)) (DataOnly x (UArray UReal)) @@ -144,7 +144,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_bar0) (fdsuffix FnPlain) (fdargs ()) + ((fdrt (ReturnType UReal)) (fdname foo_bar0) (fdsuffix FnPlain) + (fdargs ()) (fdbody (((pattern (Block @@ -155,7 +156,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_bar1) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_bar1) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UReal))) (fdbody (((pattern @@ -167,7 +168,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_bar2) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_bar2) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UReal) (AutoDiffable y UReal))) (fdbody (((pattern @@ -179,7 +180,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lpmf) (fdsuffix (FnLpdf ())) + ((fdrt (ReturnType UReal)) (fdname foo_lpmf) (fdsuffix (FnLpdf ())) (fdargs ((AutoDiffable y UInt) (AutoDiffable lambda UReal))) (fdbody (((pattern @@ -191,7 +192,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lcdf) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_lcdf) (fdsuffix FnPlain) (fdargs ((AutoDiffable y UInt) (AutoDiffable lambda UReal))) (fdbody (((pattern @@ -203,7 +204,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lccdf) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_lccdf) (fdsuffix FnPlain) (fdargs ((AutoDiffable y UInt) (AutoDiffable lambda UReal))) (fdbody (((pattern @@ -215,7 +216,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_rng) (fdsuffix FnRng) + ((fdrt (ReturnType UReal)) (fdname foo_rng) (fdsuffix FnRng) (fdargs ((AutoDiffable mu UReal) (AutoDiffable sigma UReal))) (fdbody (((pattern @@ -234,7 +235,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname unit_normal_lp) (fdsuffix FnTarget) + ((fdrt Void) (fdname unit_normal_lp) (fdsuffix FnTarget) (fdargs ((AutoDiffable u UReal))) (fdbody (((pattern @@ -278,7 +279,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UInt)) (fdname foo_1) (fdsuffix FnPlain) + ((fdrt (ReturnType UInt)) (fdname foo_1) (fdsuffix FnPlain) (fdargs ((AutoDiffable a UInt))) (fdbody (((pattern @@ -1154,7 +1155,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UInt)) (fdname foo_2) (fdsuffix FnPlain) + ((fdrt (ReturnType UInt)) (fdname foo_2) (fdsuffix FnPlain) (fdargs ((AutoDiffable a UInt))) (fdbody (((pattern @@ -1225,7 +1226,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UReal))) (fdname foo_3) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UReal))) (fdname foo_3) (fdsuffix FnPlain) (fdargs ((AutoDiffable t UReal) (AutoDiffable n UInt))) (fdbody (((pattern @@ -1245,7 +1246,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_lp) (fdsuffix FnTarget) + ((fdrt (ReturnType UReal)) (fdname foo_lp) (fdsuffix FnTarget) (fdargs ((AutoDiffable x UReal))) (fdbody (((pattern @@ -1264,7 +1265,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname foo_4) (fdsuffix FnPlain) + ((fdrt Void) (fdname foo_4) (fdsuffix FnPlain) (fdargs ((AutoDiffable x (UArray UReal)))) (fdbody (((pattern @@ -1292,7 +1293,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname relative_diff) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname relative_diff) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UReal) (AutoDiffable y UReal) (AutoDiffable max_ UReal) (AutoDiffable min_ UReal))) @@ -1461,7 +1462,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname foo_5) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname foo_5) (fdsuffix FnPlain) (fdargs ((AutoDiffable shared_params UVector) (AutoDiffable job_params UVector) (DataOnly data_r (UArray UReal)) (DataOnly data_i (UArray UInt)))) @@ -1504,7 +1505,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_five_args) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname foo_five_args) (fdsuffix FnPlain) (fdargs ((AutoDiffable x1 UReal) (AutoDiffable x2 UReal) (AutoDiffable x3 UReal) (AutoDiffable x4 UReal) (AutoDiffable x5 UReal))) @@ -1518,7 +1519,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname foo_five_args_lp) (fdsuffix FnTarget) + ((fdrt (ReturnType UReal)) (fdname foo_five_args_lp) (fdsuffix FnTarget) (fdargs ((AutoDiffable x1 UReal) (AutoDiffable x2 UReal) (AutoDiffable x3 UReal) (AutoDiffable x4 UReal) (AutoDiffable x5 UReal) @@ -1533,7 +1534,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname covsqrt2corsqrt) (fdsuffix FnPlain) + ((fdrt (ReturnType UMatrix)) (fdname covsqrt2corsqrt) (fdsuffix FnPlain) (fdargs ((AutoDiffable mat UMatrix) (AutoDiffable invert UInt))) (fdbody (((pattern @@ -1632,7 +1633,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname f0) (fdsuffix FnPlain) + ((fdrt Void) (fdname f0) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1652,7 +1653,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UInt)) (fdname f1) (fdsuffix FnPlain) + ((fdrt (ReturnType UInt)) (fdname f1) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1672,7 +1673,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UInt))) (fdname f2) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UInt))) (fdname f2) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1693,7 +1694,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UInt)))) (fdname f3) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UInt)))) (fdname f3) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1715,7 +1716,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname f4) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname f4) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1735,7 +1736,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UReal))) (fdname f5) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UReal))) (fdname f5) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1757,7 +1758,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UReal)))) (fdname f6) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UReal)))) (fdname f6) + (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1779,7 +1781,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname f7) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname f7) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1799,7 +1801,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UVector))) (fdname f8) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UVector))) (fdname f8) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1821,7 +1823,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UVector)))) (fdname f9) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UVector)))) (fdname f9) + (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1843,7 +1846,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname f10) (fdsuffix FnPlain) + ((fdrt (ReturnType UMatrix)) (fdname f10) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1863,7 +1866,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray UMatrix))) (fdname f11) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray UMatrix))) (fdname f11) (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1885,7 +1888,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ((UArray (UArray UMatrix)))) (fdname f12) (fdsuffix FnPlain) + ((fdrt (ReturnType (UArray (UArray UMatrix)))) (fdname f12) + (fdsuffix FnPlain) (fdargs ((AutoDiffable a1 UInt) (AutoDiffable a2 (UArray UInt)) (AutoDiffable a3 (UArray (UArray UInt))) (AutoDiffable a4 UReal) @@ -1907,7 +1911,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt ()) (fdname foo_6) (fdsuffix FnPlain) (fdargs ()) + ((fdrt Void) (fdname foo_6) (fdsuffix FnPlain) (fdargs ()) (fdbody (((pattern (Block @@ -1968,7 +1972,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname matfoo) (fdsuffix FnPlain) (fdargs ()) + ((fdrt (ReturnType UMatrix)) (fdname matfoo) (fdsuffix FnPlain) + (fdargs ()) (fdbody (((pattern (Block @@ -2232,7 +2237,8 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname vecfoo) (fdsuffix FnPlain) (fdargs ()) + ((fdrt (ReturnType UVector)) (fdname vecfoo) (fdsuffix FnPlain) + (fdargs ()) (fdbody (((pattern (Block @@ -2328,7 +2334,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname vecmufoo) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname vecmufoo) (fdsuffix FnPlain) (fdargs ((AutoDiffable mu UReal))) (fdbody (((pattern @@ -2360,7 +2366,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname vecmubar) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname vecmubar) (fdsuffix FnPlain) (fdargs ((AutoDiffable mu UReal))) (fdbody (((pattern @@ -2512,7 +2518,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname algebra_system) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname algebra_system) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UVector) (AutoDiffable y UVector) (AutoDiffable dat (UArray UReal)) (AutoDiffable dat_int (UArray UInt)))) @@ -2599,7 +2605,7 @@ (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname binomialf) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname binomialf) (fdsuffix FnPlain) (fdargs ((AutoDiffable phi UVector) (AutoDiffable theta UVector) (DataOnly x_r (UArray UReal)) (DataOnly x_i (UArray UInt)))) diff --git a/test/integration/good/compiler-optimizations/mem_patterns/transformed_mir.expected b/test/integration/good/compiler-optimizations/mem_patterns/transformed_mir.expected index 409eb19c57..2c9654251f 100644 --- a/test/integration/good/compiler-optimizations/mem_patterns/transformed_mir.expected +++ b/test/integration/good/compiler-optimizations/mem_patterns/transformed_mir.expected @@ -6194,7 +6194,7 @@ vector[N] tp_soa_single_assign_from_soa: SoA vector[N] tp_aos_loop_vec_v_uni_idx: AoS vector[N] tp_aos_loop_vec_v_multi_uni_idx: AoS ((functions_block - (((fdrt (UInt)) (fdname mask_fun) (fdsuffix FnPlain) + (((fdrt (ReturnType UInt)) (fdname mask_fun) (fdsuffix FnPlain) (fdargs ((AutoDiffable i UInt))) (fdbody (((pattern @@ -6206,7 +6206,7 @@ vector[N] tp_aos_loop_vec_v_multi_uni_idx: AoS (meta ))))) (meta )))) (fdloc )) - ((fdrt (UVector)) (fdname udf_fun) (fdsuffix FnPlain) + ((fdrt (ReturnType UVector)) (fdname udf_fun) (fdsuffix FnPlain) (fdargs ((AutoDiffable A UVector))) (fdbody (((pattern @@ -11368,7 +11368,7 @@ matrix[5, 10] aos_y: AoS matrix[5, 10] tp_matrix_aos_from_mix: AoS matrix[5, 10] tp_matrix_from_udf_reduced_soa: AoS ((functions_block - (((fdrt (UMatrix)) (fdname nono_func) (fdsuffix FnPlain) + (((fdrt (ReturnType UMatrix)) (fdname nono_func) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UMatrix))) (fdbody (((pattern @@ -11380,7 +11380,7 @@ matrix[5, 10] tp_matrix_from_udf_reduced_soa: AoS (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname okay_reduction) (fdsuffix FnPlain) + ((fdrt (ReturnType UMatrix)) (fdname okay_reduction) (fdsuffix FnPlain) (fdargs ((AutoDiffable sum_x UReal) (AutoDiffable y UMatrix))) (fdbody (((pattern @@ -12166,7 +12166,8 @@ matrix[10, 10] inner_empty_user_func_aos: AoS matrix[10, 10] int_aos_mul_aos: AoS matrix[10, 10] mul_two_aos: AoS ((functions_block - (((fdrt (UMatrix)) (fdname empty_user_func) (fdsuffix FnPlain) (fdargs ()) + (((fdrt (ReturnType UMatrix)) (fdname empty_user_func) (fdsuffix FnPlain) + (fdargs ()) (fdbody (((pattern (Block @@ -12188,7 +12189,7 @@ matrix[10, 10] mul_two_aos: AoS (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname mat_ret_user_func) (fdsuffix FnPlain) + ((fdrt (ReturnType UMatrix)) (fdname mat_ret_user_func) (fdsuffix FnPlain) (fdargs ((AutoDiffable A UMatrix))) (fdbody (((pattern @@ -13485,7 +13486,7 @@ matrix[5, 10] first_pass_soa_x: AoS matrix[5, 10] aos_x: AoS matrix[5, 10] tp_matrix_aos: AoS ((functions_block - (((fdrt (UMatrix)) (fdname nono_func) (fdsuffix FnPlain) + (((fdrt (ReturnType UMatrix)) (fdname nono_func) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UMatrix))) (fdbody (((pattern @@ -13497,7 +13498,7 @@ matrix[5, 10] tp_matrix_aos: AoS (meta ))))) (meta )))) (fdloc )) - ((fdrt (UMatrix)) (fdname okay_reduction) (fdsuffix FnPlain) + ((fdrt (ReturnType UMatrix)) (fdname okay_reduction) (fdsuffix FnPlain) (fdargs ((AutoDiffable sum_x UReal) (AutoDiffable y UMatrix))) (fdbody (((pattern diff --git a/test/unit/Optimize.ml b/test/unit/Optimize.ml index db9bbe1a9d..669a6c7c58 100644 --- a/test/unit/Optimize.ml +++ b/test/unit/Optimize.ml @@ -233,7 +233,7 @@ let%expect_test "list collapsing" = [%expect {| ((functions_block - (((fdrt ()) (fdname f) (fdsuffix FnPlain) + (((fdrt Void) (fdname f) (fdsuffix FnPlain) (fdargs ((AutoDiffable x UInt) (AutoDiffable y UMatrix))) (fdbody (((pattern @@ -250,7 +250,7 @@ let%expect_test "list collapsing" = (meta ))))) (meta )))) (fdloc )) - ((fdrt (UReal)) (fdname g) (fdsuffix FnPlain) + ((fdrt (ReturnType UReal)) (fdname g) (fdsuffix FnPlain) (fdargs ((AutoDiffable z UInt))) (fdbody (((pattern @@ -382,18 +382,20 @@ let%expect_test "list collapsing" = (transform_inits ()) (output_vars ()) (prog_name "") (prog_path "")) |}] -let%expect_test "do not inline recursive functions" = +let%expect_test "recursive functions" = let mir = reset_and_mir_of_string {| functions { - real g(int z); - real g(int z) { - return z^2; + int fib(int n) { + if (n == 0 || n == 1) { + return n; + } + return fib(n - 1) + fib(n - 2); } } model { - reject(g(53)); + reject(fib(5)); } |} in @@ -402,12 +404,12 @@ let%expect_test "do not inline recursive functions" = [%expect {| functions { - real g(int z) { - ; - } - real g(int z) { + int fib(int n) { { - return (z ^ 2); + if((n == 0) || (n == 1)) { + return n; + } + return (fib((n - 1)) + fib((n - 2))); } } } @@ -416,7 +418,23 @@ let%expect_test "do not inline recursive functions" = log_prob { { - FnReject__(g(53)); + int inline_fib_return_sym1__; + data int inline_fib_early_ret_check_sym2__; + inline_fib_early_ret_check_sym2__ = 0; + for(inline_fib_iterator_sym3__ in 1:1) { + if((5 == 0)) ; else { + + } + if((5 == 0) || (5 == 1)) { + inline_fib_early_ret_check_sym2__ = 1; + inline_fib_return_sym1__ = 5; + break; + } + inline_fib_early_ret_check_sym2__ = 1; + inline_fib_return_sym1__ = (fib((5 - 1)) + fib((5 - 2))); + break; + } + FnReject__(inline_fib_return_sym1__); } } @@ -428,6 +446,45 @@ let%expect_test "do not inline recursive functions" = if(PNot__(emit_generated_quantities__)) return; } |}] +let%expect_test "do not try to inline extern functions" = + let before = !Frontend.Typechecker.check_that_all_functions_have_definition in + Frontend.Typechecker.check_that_all_functions_have_definition := false ; + let mir = + reset_and_mir_of_string + {| + functions { + int fib(int n); + } + model { + reject(fib(5)); + } + |} + in + Frontend.Typechecker.check_that_all_functions_have_definition := before ; + let mir = function_inlining mir in + Fmt.str "@[%a@]" Program.Typed.pp mir |> print_endline ; + [%expect + {| + functions { + extern int fib(int n); + } + + + + log_prob { + { + FnReject__(fib(5)); + } + } + + generate_quantities { + if(emit_transformed_parameters__) ; else { + + } + if(PNot__(emit_transformed_parameters__ || emit_generated_quantities__)) return; + if(PNot__(emit_generated_quantities__)) return; + } |}] + let%expect_test "inline function in for loop" = let mir = reset_and_mir_of_string