Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary forward decls from MIR, code-gen appropriately #1298

Merged
merged 7 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,9 @@ let create_function_inline_map adt l =
| None -> accum
| Some fdbody -> (
let create_data propto =
( Option.map ~f:(fun x -> Type.Unsized x) fdrt
( Option.map
~f:(fun x -> Type.Unsized x)
(UnsizedType.returntype_to_type_opt fdrt)
, List.map ~f:(fun (_, name, _) -> name) fdargs
, inline_function_statement propto adt accum fdbody ) in
match Middle.Utils.with_unnormalized_suffix fdname with
Expand Down
6 changes: 3 additions & 3 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,7 @@ let trans_fun_def ud_dists (ts : Ast.typed_statement) =
match ts.stmt with
| Ast.FunDef {returntype; funname; arguments; body} ->
[ Program.
{ fdrt=
(match returntype with Void -> None | ReturnType ut -> Some ut)
{ fdrt= returntype
; fdname= funname.name
; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto)
; fdargs= List.map ~f:trans_arg arguments
Expand Down Expand Up @@ -785,7 +784,8 @@ let gather_data (p : Ast.typed_program) =
| _ -> [] )

let trans_prog filename (p : Ast.typed_program) : Program.Typed.t =
let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = p in
let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} =
Deprecation_analysis.remove_unneeded_forward_decls p in
let map f list_op =
Option.value_map ~default:[]
~f:(fun {Ast.stmts; _} -> List.concat_map ~f stmts)
Expand Down
11 changes: 1 addition & 10 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,7 @@ let repair_syntax program settings =
let canonicalize_program program settings : typed_program =
let program =
if settings.deprecations then
let fundefs = userdef_functions program in
let drop_forwarddecl = function
| {stmt= FunDef {body= {stmt= Skip; _}; funname; arguments; _}; _}
when is_redundant_forwarddecl fundefs funname arguments ->
false
| _ -> true in
{ program with
functionblock=
Option.map program.functionblock ~f:(fun x ->
{x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) }
remove_unneeded_forward_decls program
|> map_program
(replace_deprecated_stmt (collect_userdef_distributions program))
else program in
Expand Down
19 changes: 15 additions & 4 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,17 @@ let rename_deprecated map name =

let userdef_functions program =
match program.functionblock with
| None -> []
| None -> Hash_set.Poly.create ()
| Some {stmts; _} ->
List.filter_map stmts ~f:(function
| {stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> None
| {stmt= FunDef {funname; arguments; _}; _} ->
Some (funname.name, Ast.type_of_arguments arguments)
| _ -> None )
|> Hash_set.Poly.of_list

let is_redundant_forwarddecl fundefs funname arguments =
let equal (id1, a1) (id2, a2) =
String.equal id1 id2 && UnsizedType.equal_argumentlist a1 a2 in
List.mem ~equal fundefs (funname.name, Ast.type_of_arguments arguments)
Hash_set.mem fundefs (funname.name, Ast.type_of_arguments arguments)

let userdef_distributions stmts =
let open String in
Expand Down Expand Up @@ -235,3 +234,15 @@ let collect_userdef_distributions program =
let collect_warnings (program : typed_program) =
let fundefs = userdef_functions program in
fold_program (collect_deprecated_stmt fundefs) [] program

let remove_unneeded_forward_decls program =
let fundefs = userdef_functions program in
let drop_forwarddecl = function
| {stmt= FunDef {body= {stmt= Skip; _}; funname; arguments; _}; _}
when is_redundant_forwarddecl fundefs funname arguments ->
false
| _ -> true in
{ program with
functionblock=
Option.map program.functionblock ~f:(fun x ->
{x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) }
12 changes: 1 addition & 11 deletions src/frontend/Deprecation_analysis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,6 @@ val is_deprecated_distribution : string -> bool
val deprecated_distributions : (string * string) String.Map.t
val deprecated_functions : (string * string) String.Map.t
val rename_deprecated : (string * string) String.Map.t -> string -> string

val userdef_functions :
('a, 'b, 'c, 'd) statement_with program
-> (string * Middle.UnsizedType.argumentlist) list

val is_redundant_forwarddecl :
(string * Middle.UnsizedType.argumentlist) list
-> identifier
-> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * 'a) list
-> bool

val userdef_distributions : untyped_statement block option -> string list
val collect_warnings : typed_program -> Warnings.t list
val remove_unneeded_forward_decls : typed_program -> typed_program
30 changes: 14 additions & 16 deletions src/middle/Program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ type fun_arg_decl = (UnsizedType.autodifftype * string * UnsizedType.t) list
[@@deriving sexp, hash, map]

type 'a fun_def =
{ fdrt: UnsizedType.t option
{ fdrt: UnsizedType.returntype
; fdname: string
; fdsuffix: unit Fun_kind.suffix
; fdargs: (UnsizedType.autodifftype * string * UnsizedType.t) list
(* If fdbody is None, this is a function declaration without body. *)
; fdbody: 'a option
(* If fdbody is None, this is an external function declaration
(forward decls are removed during AST lowering) *)
; fdloc: (Location_span.t[@sexp.opaque] [@compare.ignore]) }
[@@deriving compare, hash, sexp, map, fold]

Expand Down Expand Up @@ -49,20 +50,17 @@ let pp_fun_arg_decl ppf (autodifftype, name, unsizedtype) =
Fmt.pf ppf "%a%a %s" UnsizedType.pp_autodifftype autodifftype UnsizedType.pp
unsizedtype name

let pp_fun_def pp_s ppf = function
| {fdrt; fdname; fdargs; fdbody; _} -> (
let pp_body_opt ppf = function
| None -> Fmt.pf ppf ";"
| Some body -> pp_s ppf body in
match fdrt with
| Some rt ->
Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" UnsizedType.pp rt fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_body_opt fdbody
| None ->
Fmt.pf ppf "@[<v2>void %s%a {@ %a@]@ }" fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_body_opt fdbody )
let pp_fun_def pp_s ppf {fdrt; fdname; fdargs; fdbody; _} =
match fdbody with
| None ->
Fmt.pf ppf "@[<v2>extern %a %s%a;@]" UnsizedType.pp_returntype fdrt fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs
| Some s ->
Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" UnsizedType.pp_returntype fdrt
fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_s s

let pp_io_block ppf = function
| Parameters -> Fmt.string ppf "parameters"
Expand Down
2 changes: 2 additions & 0 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and autodifftype = DataOnly | AutoDiffable
and returntype = Void | ReturnType of t
[@@deriving compare, hash, sexp, equal]

let returntype_to_type_opt = function Void -> None | ReturnType t -> Some t

let pp_autodifftype ppf = function
| DataOnly -> Fmt.string ppf "data "
| AutoDiffable -> ()
Expand Down
55 changes: 27 additions & 28 deletions src/stan_math_backend/Lower_functions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ promotion.*)
let lower_returntype arg_types rt =
let scalar = lower_promoted_scalar arg_types in
match rt with
| Some ut when UnsizedType.is_int_type ut -> lower_type ut Int
| Some ut -> lower_type ut scalar
| None -> Void
| UnsizedType.ReturnType ut when UnsizedType.is_int_type ut ->
lower_type ut Int
| ReturnType ut -> lower_type ut scalar
| Void -> Void

let lower_eigen_args_to_ref arg_types =
let lower_ref name =
Expand Down Expand Up @@ -269,7 +270,7 @@ let collect_functors_functions (p : Program.Numbered.t) : defn list =
List.equal UnsizedType.equal
(List.map ~f:(fun (_, _, t) -> t) fdargs)
arg_types in
let declare_and_define (d : _ Program.fun_def) =
let register_functors (d : _ Program.fun_def) =
let functors =
Map.find_multi functor_required d.fdname
|> List.stable_dedup
Expand All @@ -281,20 +282,21 @@ let collect_functors_functions (p : Program.Numbered.t) : defn list =
Hashtbl.update structs s.struct_name ~f:(function
| Some x -> {x with body= x.body @ s.body}
| None -> s ) ) ;
let decl, defn = Cpp.split_fun_decl_defn fn in
(FunDef decl, FunDef defn) in
fn in
let fun_decls, fun_defns =
p.functions_block
|> List.filter_map ~f:(fun d ->
(* FIXME: external functions don't need decls
but they do need structs (when used in HOF) *)
if Option.is_none d.fdbody then None else Some (declare_and_define d) )
let fn = register_functors d in
if Option.is_none d.fdbody then None
else
let decl, defn = Cpp.split_fun_decl_defn fn in
Some (FunDef decl, FunDef defn) )
|> List.unzip in
let structs = Hashtbl.data structs |> List.map ~f:(fun s -> Struct s) in
fun_decls @ structs @ fun_defns

let lower_standalone_fun_def namespace_fun
Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} =
Program.{fdname; fdsuffix; fdargs; fdrt; _} =
let extra, extra_templates =
match fdsuffix with
| Fun_kind.FnTarget ->
Expand All @@ -313,24 +315,21 @@ let lower_standalone_fun_def namespace_fun
let mark_function_comment = GlobalComment "[[stan::function]]" in
let return_type, return_stmt =
match fdrt with
| None -> (Void, fun e -> Expression e)
| Void -> (Void, fun e -> Expression e)
| _ -> (Auto, fun e -> Return (Some e)) in
let fn_sig = make_fun_defn ~name:fdname ~return_type ~args:all_args in
match fdbody with
| None -> []
| Some _ ->
let internal_fname = namespace_fun ^ "::" ^ fdname in
let template =
match fdsuffix with
| FnLpdf _ | FnTarget -> [TypeLiteral "false"]
| FnRng | FnPlain -> [] in
let call_args =
List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
|> List.map ~f:Exprs.to_var in
let ret =
return_stmt (Exprs.templated_fun_call internal_fname template call_args)
in
[mark_function_comment; FunDef (fn_sig ~body:[ret] ())]
let internal_fname = namespace_fun ^ "::" ^ fdname in
let template =
match fdsuffix with
| FnLpdf _ | FnTarget -> [TypeLiteral "false"]
| FnRng | FnPlain -> [] in
let call_args =
List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
|> List.map ~f:Exprs.to_var in
let ret =
return_stmt (Exprs.templated_fun_call internal_fname template call_args)
in
[mark_function_comment; FunDef (fn_sig ~body:[ret] ())]

module Testing = struct
(* Testing code *)
Expand All @@ -347,7 +346,7 @@ module Testing = struct
let with_no_loc stmt =
Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in
let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in
{ fdrt= None
{ fdrt= Void
; fdname= "sars"
; fdsuffix= FnPlain
; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)]
Expand Down Expand Up @@ -404,7 +403,7 @@ module Testing = struct
let with_no_loc stmt =
Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in
let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in
{ fdrt= Some UMatrix
{ fdrt= ReturnType UMatrix
; fdname= "sars"
; fdsuffix= FnPlain
; fdargs=
Expand Down
Loading