Skip to content

Commit

Permalink
Clean up MIR function return type, printing
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Mar 23, 2023
1 parent 5ba584f commit 3732675
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 117 deletions.
4 changes: 3 additions & 1 deletion src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,9 @@ let create_function_inline_map adt l =
| None -> accum
| Some fdbody -> (
let create_data propto =
( Option.map ~f:(fun x -> Type.Unsized x) fdrt
( Option.map
~f:(fun x -> Type.Unsized x)
(UnsizedType.returntype_to_type_opt fdrt)
, List.map ~f:(fun (_, name, _) -> name) fdargs
, inline_function_statement propto adt accum fdbody ) in
match Middle.Utils.with_unnormalized_suffix fdname with
Expand Down
3 changes: 1 addition & 2 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,7 @@ let trans_fun_def ud_dists (ts : Ast.typed_statement) =
match ts.stmt with
| Ast.FunDef {returntype; funname; arguments; body} ->
[ Program.
{ fdrt=
(match returntype with Void -> None | ReturnType ut -> Some ut)
{ fdrt= returntype
; fdname= funname.name
; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto)
; fdargs= List.map ~f:trans_arg arguments
Expand Down
27 changes: 12 additions & 15 deletions src/middle/Program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ type fun_arg_decl = (UnsizedType.autodifftype * string * UnsizedType.t) list
[@@deriving sexp, hash, map]

type 'a fun_def =
{ fdrt: UnsizedType.t option
{ fdrt: UnsizedType.returntype
; fdname: string
; fdsuffix: unit Fun_kind.suffix
; fdargs: (UnsizedType.autodifftype * string * UnsizedType.t) list
Expand Down Expand Up @@ -50,20 +50,17 @@ let pp_fun_arg_decl ppf (autodifftype, name, unsizedtype) =
Fmt.pf ppf "%a%a %s" UnsizedType.pp_autodifftype autodifftype UnsizedType.pp
unsizedtype name

let pp_fun_def pp_s ppf = function
| {fdrt; fdname; fdargs; fdbody; _} -> (
let pp_body_opt ppf = function
| None -> Fmt.pf ppf ";"
| Some body -> pp_s ppf body in
match fdrt with
| Some rt ->
Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" UnsizedType.pp rt fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_body_opt fdbody
| None ->
Fmt.pf ppf "@[<v2>void %s%a {@ %a@]@ }" fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_body_opt fdbody )
let pp_fun_def pp_s ppf {fdrt; fdname; fdargs; fdbody; _} =
match fdbody with
| None ->
Fmt.pf ppf "@[<v2>extern %a %s%a;@]" UnsizedType.pp_returntype fdrt fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs
| Some s ->
Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" UnsizedType.pp_returntype fdrt
fdname
Fmt.(list pp_fun_arg_decl ~sep:comma |> parens)
fdargs pp_s s

let pp_io_block ppf = function
| Parameters -> Fmt.string ppf "parameters"
Expand Down
2 changes: 2 additions & 0 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and autodifftype = DataOnly | AutoDiffable
and returntype = Void | ReturnType of t
[@@deriving compare, hash, sexp, equal]

let returntype_to_type_opt = function Void -> None | ReturnType t -> Some t

let pp_autodifftype ppf = function
| DataOnly -> Fmt.string ppf "data "
| AutoDiffable -> ()
Expand Down
13 changes: 7 additions & 6 deletions src/stan_math_backend/Lower_functions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ promotion.*)
let lower_returntype arg_types rt =
let scalar = lower_promoted_scalar arg_types in
match rt with
| Some ut when UnsizedType.is_int_type ut -> lower_type ut Int
| Some ut -> lower_type ut scalar
| None -> Void
| UnsizedType.ReturnType ut when UnsizedType.is_int_type ut ->
lower_type ut Int
| ReturnType ut -> lower_type ut scalar
| Void -> Void

let lower_eigen_args_to_ref arg_types =
let lower_ref name =
Expand Down Expand Up @@ -314,7 +315,7 @@ let lower_standalone_fun_def namespace_fun
let mark_function_comment = GlobalComment "[[stan::function]]" in
let return_type, return_stmt =
match fdrt with
| None -> (Void, fun e -> Expression e)
| Void -> (Void, fun e -> Expression e)
| _ -> (Auto, fun e -> Return (Some e)) in
let fn_sig = make_fun_defn ~name:fdname ~return_type ~args:all_args in
let internal_fname = namespace_fun ^ "::" ^ fdname in
Expand Down Expand Up @@ -345,7 +346,7 @@ module Testing = struct
let with_no_loc stmt =
Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in
let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in
{ fdrt= None
{ fdrt= Void
; fdname= "sars"
; fdsuffix= FnPlain
; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)]
Expand Down Expand Up @@ -402,7 +403,7 @@ module Testing = struct
let with_no_loc stmt =
Stmt.Fixed.{pattern= stmt; meta= Numbering.no_span_num} in
let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in
{ fdrt= Some UMatrix
{ fdrt= ReturnType UMatrix
; fdname= "sars"
; fdsuffix= FnPlain
; fdargs=
Expand Down
Loading

0 comments on commit 3732675

Please sign in to comment.