Skip to content

Commit

Permalink
Merge pull request #1298 from stan-dev/fix/remove-unnecessary-forward…
Browse files Browse the repository at this point in the history
…-decls-mir

Remove unnecessary forward decls from MIR, code-gen appropriately
  • Loading branch information
WardBrian authored Mar 23, 2023
2 parents a490a09 + 3732675 commit 655d9e8
Show file tree
Hide file tree
Showing 19 changed files with 755 additions and 260 deletions.
4 changes: 3 additions & 1 deletion src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,9 @@ let create_function_inline_map adt l =
| None -> accum
| Some fdbody -> (
let create_data propto =
( Option.map ~f:(fun x -> Type.Unsized x) fdrt
( Option.map
~f:(fun x -> Type.Unsized x)
(UnsizedType.returntype_to_type_opt fdrt)
, List.map ~f:(fun (_, name, _) -> name) fdargs
, inline_function_statement propto adt accum fdbody ) in
match Middle.Utils.with_unnormalized_suffix fdname with
Expand Down
6 changes: 3 additions & 3 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,7 @@ let trans_fun_def ud_dists (ts : Ast.typed_statement) =
match ts.stmt with
| Ast.FunDef {returntype; funname; arguments; body} ->
[ Program.
{ fdrt=
(match returntype with Void -> None | ReturnType ut -> Some ut)
{ fdrt= returntype
; fdname= funname.name
; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto)
; fdargs= List.map ~f:trans_arg arguments
Expand Down Expand Up @@ -785,7 +784,8 @@ let gather_data (p : Ast.typed_program) =
| _ -> [] )

let trans_prog filename (p : Ast.typed_program) : Program.Typed.t =
let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = p in
let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} =
Deprecation_analysis.remove_unneeded_forward_decls p in
let map f list_op =
Option.value_map ~default:[]
~f:(fun {Ast.stmts; _} -> List.concat_map ~f stmts)
Expand Down
11 changes: 1 addition & 10 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,7 @@ let repair_syntax program settings =
let canonicalize_program program settings : typed_program =
let program =
if settings.deprecations then
let fundefs = userdef_functions program in
let drop_forwarddecl = function
| {stmt= FunDef {body= {stmt= Skip; _}; funname; arguments; _}; _}
when is_redundant_forwarddecl fundefs funname arguments ->
false
| _ -> true in
{ program with
functionblock=
Option.map program.functionblock ~f:(fun x ->
{x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) }
remove_unneeded_forward_decls program
|> map_program
(replace_deprecated_stmt (collect_userdef_distributions program))
else program in
Expand Down
19 changes: 15 additions & 4 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,17 @@ let rename_deprecated map name =

let userdef_functions program =
match program.functionblock with
| None -> []
| None -> Hash_set.Poly.create ()
| Some {stmts; _} ->
List.filter_map stmts ~f:(function
| {stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> None
| {stmt= FunDef {funname; arguments; _}; _} ->
Some (funname.name, Ast.type_of_arguments arguments)
| _ -> None )
|> Hash_set.Poly.of_list

let is_redundant_forwarddecl fundefs funname arguments =
let equal (id1, a1) (id2, a2) =
String.equal id1 id2 && UnsizedType.equal_argumentlist a1 a2 in
List.mem ~equal fundefs (funname.name, Ast.type_of_arguments arguments)
Hash_set.mem fundefs (funname.name, Ast.type_of_arguments arguments)

let userdef_distributions stmts =
let open String in
Expand Down Expand Up @@ -235,3 +234,15 @@ let collect_userdef_distributions program =
let collect_warnings (program : typed_program) =
let fundefs = userdef_functions program in
fold_program (collect_deprecated_stmt fundefs) [] program

let remove_unneeded_forward_decls program =
let fundefs = userdef_functions program in
let drop_forwarddecl = function
| {stmt= FunDef {body= {stmt= Skip; _}; funname; arguments; _}; _}
when is_redundant_forwarddecl fundefs funname arguments ->
false
| _ -> true in
{ program with
functionblock=
Option.map program.functionblock ~f:(fun x ->
{x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) }
12 changes: 1 addition & 11 deletions src/frontend/Deprecation_analysis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,6 @@ val is_deprecated_distribution : string -> bool
val deprecated_distributions : (string * string) String.Map.t
val deprecated_functions : (string * string) String.Map.t
val rename_deprecated : (string * string) String.Map.t -> string -> string

val userdef_functions :
('a, 'b, 'c, 'd) statement_with program
-> (string * Middle.UnsizedType.argumentlist) list

val is_redundant_forwarddecl :
(string * Middle.UnsizedType.argumentlist) list
-> identifier
-> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t * 'a) list
-> bool

val userdef_distributions : untyped_statement block option -> string list
val collect_warnings : typed_program -> Warnings.t list
val remove_unneeded_forward_decls : typed_program -> typed_program
30 changes: 14 additions & 16 deletions src/middle/Program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ type fun_arg_decl = (UnsizedType.autodifftype * string * UnsizedType.t) list
[@@deriving sexp, hash, map]

type 'a fun_def =
{ fdrt: UnsizedType.t option
{ fdrt: UnsizedType.returntype
; fdname: string
; fdsuffix: unit Fun_kind.suffix
; fdargs: (UnsizedType.autodifftype * string * UnsizedType.t) list
(* If fdbody is None, this is a function declaration without body. *)
; fdbody: 'a option
(* If fdbody is None, this is an external function declaration
(forward decls are removed during AST lowering) *)
; fdloc: (Location_span.t[@sexp.opaque] [@compare.ignore]) }
[@@deriving compare, hash, sexp, map, fold]

Expand Down Expand Up @@ -49,20 +50,17 @@ let pp_fun_arg_decl ppf (autodifftype, name, unsizedtype) =
Fmt.pf ppf "%a%a %s" UnsizedType.pp_autodifftype autodifftype UnsizedType.pp
unsizedtype name

let pp_fun_def pp_s ppf = function
| {fdrt; fdname; fdargs; fdbody; _} -> (
let pp_body_opt ppf = function
| None -> Fmt.pf ppf ";"
| Some body -> pp_s ppf body in
match fdrt with
| Some rt ->
Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" UnsizedType.pp rt fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_body_opt fdbody
| None ->
Fmt.pf ppf "@[<v2>void %s%a {@ %a@]@ }" fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_body_opt fdbody )
let pp_fun_def pp_s ppf {fdrt; fdname; fdargs; fdbody; _} =
match fdbody with
| None ->
Fmt.pf ppf "@[<v2>extern %a %s%a;@]" UnsizedType.pp_returntype fdrt fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs
| Some s ->
Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" UnsizedType.pp_returntype fdrt
fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_s s

let pp_io_block ppf = function
| Parameters -> Fmt.string ppf "parameters"
Expand Down
2 changes: 2 additions & 0 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and autodifftype = DataOnly | AutoDiffable
and returntype = Void | ReturnType of t
[@@deriving compare, hash, sexp, equal]

let returntype_to_type_opt = function Void -> None | ReturnType t -> Some t

let pp_autodifftype ppf = function
| DataOnly -> Fmt.string ppf "data "
| AutoDiffable -> ()
Expand Down
55 changes: 27 additions & 28 deletions src/stan_math_backend/Lower_functions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ promotion.*)
let lower_returntype arg_types rt =
let scalar = lower_promoted_scalar arg_types in
match rt with
| Some ut when UnsizedType.is_int_type ut -> lower_type ut Int
| Some ut -> lower_type ut scalar
| None -> Void
| UnsizedType.ReturnType ut when UnsizedType.is_int_type ut ->
lower_type ut Int
| ReturnType ut -> lower_type ut scalar
| Void -> Void

let lower_eigen_args_to_ref arg_types =
let lower_ref name =
Expand Down Expand Up @@ -269,7 +270,7 @@ let collect_functors_functions (p : Program.Numbered.t) : defn list =
List.equal UnsizedType.equal
(List.map ~f:(fun (_, _, t) -> t) fdargs)
arg_types in
let declare_and_define (d : _ Program.fun_def) =
let register_functors (d : _ Program.fun_def) =
let functors =
Map.find_multi functor_required d.fdname
|> List.stable_dedup
Expand All @@ -281,20 +282,21 @@ let collect_functors_functions (p : Program.Numbered.t) : defn list =
Hashtbl.update structs s.struct_name ~f:(function
| Some x -> {x with body= x.body @ s.body}
| None -> s ) ) ;
let decl, defn = Cpp.split_fun_decl_defn fn in
(FunDef decl, FunDef defn) in
fn in
let fun_decls, fun_defns =
p.functions_block
|> List.filter_map ~f:(fun d ->
(* FIXME: external functions don't need decls
but they do need structs (when used in HOF) *)
if Option.is_none d.fdbody then None else Some (declare_and_define d) )
let fn = register_functors d in
if Option.is_none d.fdbody then None
else
let decl, defn = Cpp.split_fun_decl_defn fn in
Some (FunDef decl, FunDef defn) )
|> List.unzip in
let structs = Hashtbl.data structs |> List.map ~f:(fun s -> Struct s) in
fun_decls @ structs @ fun_defns

let lower_standalone_fun_def namespace_fun
Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} =
Program.{fdname; fdsuffix; fdargs; fdrt; _} =
let extra, extra_templates =
match fdsuffix with
| Fun_kind.FnTarget ->
Expand All @@ -313,24 +315,21 @@ let lower_standalone_fun_def namespace_fun
let mark_function_comment = GlobalComment "[[stan::function]]" in
let return_type, return_stmt =
match fdrt with
| None -> (Void, fun e -> Expression e)
| Void -> (Void, fun e -> Expression e)
| _ -> (Auto, fun e -> Return (Some e)) in
let fn_sig = make_fun_defn ~name:fdname ~return_type ~args:all_args in
match fdbody with
| None -> []
| Some _ ->
let internal_fname = namespace_fun ^ "::" ^ fdname in
let template =
match fdsuffix with
| FnLpdf _ | FnTarget -> [TypeLiteral "false"]
| FnRng | FnPlain -> [] in
let call_args =
List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
|> List.map ~f:Exprs.to_var in
let ret =
return_stmt (Exprs.templated_fun_call internal_fname template call_args)
in
[mark_function_comment; FunDef (fn_sig ~body:[ret] ())]
let internal_fname = namespace_fun ^ "::" ^ fdname in
let template =
match fdsuffix with
| FnLpdf _ | FnTarget -> [TypeLiteral "false"]
| FnRng | FnPlain -> [] in
let call_args =
List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
|> List.map ~f:Exprs.to_var in
let ret =
return_stmt (Exprs.templated_fun_call internal_fname template call_args)
in
[mark_function_comment; FunDef (fn_sig ~body:[ret] ())]

module Testing = struct
(* Testing code *)
Expand All @@ -347,7 +346,7 @@ module Testing = struct
let with_no_loc stmt =
Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in
let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in
{ fdrt= None
{ fdrt= Void
; fdname= "sars"
; fdsuffix= FnPlain
; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)]
Expand Down Expand Up @@ -404,7 +403,7 @@ module Testing = struct
let with_no_loc stmt =
Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in
let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in
{ fdrt= Some UMatrix
{ fdrt= ReturnType UMatrix
; fdname= "sars"
; fdsuffix= FnPlain
; fdargs=
Expand Down
Loading

0 comments on commit 655d9e8

Please sign in to comment.