Skip to content

Commit

Permalink
new model definitions! with sir example
Browse files Browse the repository at this point in the history
  • Loading branch information
stevencarlislewalker committed Nov 3, 2023
1 parent 02f8c1d commit d9699e7
Show file tree
Hide file tree
Showing 20 changed files with 756 additions and 362 deletions.
16 changes: 14 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
S3method(Index,Index)
S3method(Index,Partition)
S3method(Index,data.frame)
S3method(Vector,Index)
S3method(Vector,data.frame)
S3method(Vector,numeric)
S3method(as.data.frame,Index)
S3method(as.data.frame,Link)
S3method(as.matrix,Vector)
S3method(c,String)
S3method(c,StringData)
S3method(head,Link)
S3method(labelling_names,Index)
S3method(labelling_names,Link)
S3method(labels,Index)
S3method(length,Vector)
S3method(mp_index,character)
S3method(mp_index,data.frame)
S3method(mp_labels,Index)
S3method(mp_labels,Link)
S3method(mp_union,Index)
Expand All @@ -22,7 +28,6 @@ S3method(names,Link)
S3method(names,MatsList)
S3method(names,MethList)
S3method(names,Partition)
S3method(print,FormulaData)
S3method(print,Index)
S3method(print,Link)
S3method(print,MathExpression)
Expand Down Expand Up @@ -76,6 +81,7 @@ export(FlowExpander)
export(Flows)
export(Formula)
export(Index)
export(IndexedExpressions)
export(Infection)
export(IntVecs)
export(JSONReader)
Expand Down Expand Up @@ -138,21 +144,27 @@ export(mp_choose_out)
export(mp_decompose)
export(mp_expr_binop)
export(mp_expr_group_sum)
export(mp_expr_list)
export(mp_formula_data)
export(mp_group)
export(mp_index)
export(mp_indexed_exprs)
export(mp_indicator)
export(mp_indices)
export(mp_join)
export(mp_labels)
export(mp_linear)
export(mp_rbind)
export(mp_rename)
export(mp_select)
export(mp_set_numbers)
export(mp_setdiff)
export(mp_square)
export(mp_subset)
export(mp_symmetric)
export(mp_tmb_simulator)
export(mp_triangle)
export(mp_union)
export(mp_vector)
export(mp_zero_vector)
export(nlist)
export(not_all_equal)
Expand Down
11 changes: 11 additions & 0 deletions R/expr_list.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ ExprList = function(
if (is.null(nms)) nms = rep("", length(self$formula_list()))
nms
}
self$all_formula_vars = function() {
(self$formula_list()
|> lapply(formula_components)
|> lapply(getElement, "variables")
|> unlist(use.names = FALSE, recursive = FALSE)
|> unique()
)
}

self$data_arg = function() {
r = c(
Expand Down Expand Up @@ -230,3 +238,6 @@ ExprList = function(

return_object(self, "ExprList")
}

#' @export
mp_expr_list = ExprList
120 changes: 114 additions & 6 deletions R/formula_data.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,120 @@
FormulaData = function(frame, reference_index_list, labelling_names_list) {
FormulaData = function(...) {
self = Base()
self$frame = frame
self$reference_index_list = reference_index_list
self$labelling_names_list = labelling_names_list
self$link_list = list(...)

labelling_names_list = (self$link_list
|> lapply(getElement, "labelling_names_list")
|> unname()
|> unique()
)
stopifnot(length(labelling_names_list) == 1L)
self$labelling_names_list = labelling_names_list[[1L]]

reference_index_list = (self$link_list
|> lapply(getElement, "reference_index_list")
|> unname()
|> unique()
)
stopifnot(length(reference_index_list) == 1L)
self$reference_index_list = reference_index_list[[1L]]

table_names = (self$link_list
|> method_apply("table_names")
|> unname()
|> unique()
)
stopifnot(length(table_names) == 1L)
self$table_names = table_names[[1L]]

self$labels_frame = function() {
(self$link_list
|> method_apply("labels_frame")
|> bind_rows()
)
}

self$positions_frame = function(zero_based = FALSE) {
positions_list = list()
for (i in seq_along(self$link_list)) {
positions_list[[i]] = list()
for (d in self$table_names) {
positions_list[[i]][[d]] = self$link_list[[i]]$positions_for[[d]](zero_based)
}
positions_list[[i]] = as.data.frame(positions_list[[i]])
}
bind_rows(positions_list)
}

return_object(self, "FormulaData")
}

#' #' @export
#' print.FormulaData = function(x, ...) {
#' print(x$frame, row.names = FALSE)
#' }


#' @export
print.FormulaData = function(x, ...) {
print(x$frame, row.names = FALSE)
mp_formula_data = function(...) FormulaData(...)

#' Indexed Expressions
#'
#' @param ... Formula objects that reference the columns in the
#' \code{index_data}, the vectors in \code{vector_list} and the matrices
#' in \code{unstructured_matrix_list}.
#' @param index_data An object produced using \code{\link{mp_formula_data}}.
#' @param vector_list Named list of objected produced using
#' \code{\link{mp_vector}}.
#' @param unstructured_matrix_list Named list of objects that can be coerced
#' to a matrix.
#'
#' @export
IndexedExpressions = function(...
, index_data
, vector_list = list()
, unstructured_matrix_list = list()
) {
self = Base()
self$formulas = list(...)
self$index_data = index_data
self$vector_list = vector_list
self$unstructured_matrix_list = unstructured_matrix_list
self$int_vecs = function(zero_based = FALSE) {
self$index_data$positions_frame(zero_based) |> as.list()
}
self$mats_list = function() {
all_vars = (self$formulas
|> lapply(macpan2:::formula_components)
|> lapply(getElement, "variables")
|> unlist(use.names = FALSE, recursive = FALSE)
|> unique()
)
derived = setdiff(all_vars, c(
names(self$vector_list),
self$index_data$table_names,
names(self$unstructured_matrix_list)
))
derived = (empty_matrix
|> list()
|> rep(length(derived))
|> setNames(derived)
)
vectors = method_apply(self$vector_list, "numbers")
unstruc = self$unstructured_matrix_list
c(vectors, unstruc, derived)
}
self$simulate = function(time_steps = 1L) {
simple_sims(
self$formulas,
time_steps,
self$int_vecs(zero_based = TRUE),
self$mats_list()
)
}
return_object(self, "IndexedExpressions")
}


#' @export
mp_indexed_exprs = IndexedExpressions

25 changes: 24 additions & 1 deletion R/index.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ Index.Partition = function(partition
## Standard Methods
self$labels = function() self$partition$select(self$labelling_names)$labels()
self$partial_labels = function(...) self$partition$partial_labels(...)
self$reference_labels = function() {
self$reference_index()$partial_labels(self$labelling_names)
}
self$reference_positions = function(zero_based = FALSE) {
i = match(self$reference_labels(), self$labels())
if (zero_based) i = i - 1L
i
}
self$positions = function(zero_based = FALSE) {
i = match(self$labels(), self$reference_labels())
if (zero_based) i = i - 1L
i
}

return_object(self, "Index")
}
Expand Down Expand Up @@ -154,8 +167,18 @@ labels.Index = function(x, ...) x$labels()
#' )
#'
#' @export
mp_index = function(..., labelling_names) {
mp_index = function(..., labelling_names) UseMethod("mp_index")

#' @export
mp_index.character = function(..., labelling_names) {
f = data.frame(...)
if (missing(labelling_names)) labelling_names = names(f)
Index(f, to_names(labelling_names))
}

#' @export
mp_index.data.frame = function(..., labelling_names) {
f = list(...)[[1L]]
if (missing(labelling_names)) labelling_names = names(f)
Index(f, to_names(labelling_names))
}
45 changes: 45 additions & 0 deletions R/index_to_tmb.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#' @export
mp_tmb_simulator = function(expr_list = ExprList()
, index_data = list()
, indexed_vecs = list()
, unstruc_mats = list()
, time_steps = 0L
, mats_to_save = names(indexed_vecs)
, mats_to_return = names(indexed_vecs)
, ...
) {
int_vecs = (index_data
|> method_apply("positions_frame", zero_based = TRUE)
|> lapply(as.list)
|> unname()
|> unlist(recursive = FALSE)
)
indexed_mats = lapply(indexed_vecs, as.matrix)

all_vars = expr_list$all_formula_vars()

derived_nms = setdiff(all_vars, c(
names(int_vecs), names(indexed_mats), names(unstruc_mats)
))

derived_mats = (empty_matrix
|> list()
|> rep(length(derived_nms))
|> setNames(derived_nms)
)

mats = c(indexed_mats, unstruc_mats, derived_mats)
mats_list_options = list(
.mats_to_save = mats_to_save,
.mats_to_return = mats_to_return
)
engine_methods = EngineMethods(int_vecs = do.call(IntVecs, int_vecs))
tmb_model = TMBModel(
init_mats = do.call(MatsList, c(mats, mats_list_options))
, expr_list = expr_list
, engine_methods = engine_methods
, time_steps = Time(time_steps)
, ...
)
tmb_model$simulator()
}
42 changes: 34 additions & 8 deletions R/link.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,21 @@ Link = function(frame, column_map, reference_index_list, labelling_names_list) {
l |> as.data.frame()
}
self$frame_for = list()
self$labels_for = list()
self$partition_for = list()
self$index_for = list()
self$labels_for = list()
self$reference_labels_for = list()
self$positions_for = list()
self$reference_positions_for = list()
for (d in names(self$column_map)) {
getter = FrameGetter(self, d)
self$frame_for[[d]] = getter$get_frame
self$labels_for[[d]] = getter$get_labels
self$partition_for[[d]] = getter$get_partition
self$index_for[[d]] = getter$get_index
self$reference_labels_for[[d]] = getter$get_reference_labels
self$positions_for[[d]] = getter$get_positions
self$reference_positions_for[[d]] = getter$get_reference_positions
}
self$column_by_dim = list()
for (d in names(self$column_map)) {
Expand All @@ -69,6 +75,10 @@ Link = function(frame, column_map, reference_index_list, labelling_names_list) {
, labelling_names_list = self$labelling_names_list
)
}
self$expr = function(condition) {
substitute(condition)
eval(condition, envir = c(self$column_by_dim, self$frame))
}
self$partition = self$partition_for[[1L]]() ## hack! should probably have a method and then change the partition field in Index to a method as well
return_object(self, "Link")
}
Expand Down Expand Up @@ -102,17 +112,33 @@ FrameGetter = function(link, dimension_name) {
)
}
self$get_partition = function() self$get_frame() |> Partition()
self$get_index = function() Index(
self$get_partition(),
self$link$labelling_names_list[[self$dimension_name]]
)
self$get_index = function() {
Index(
self$get_partition(),
self$link$labelling_names_list[[self$dimension_name]],
self$link$reference_index_list[[self$dimension_name]]
)
}
self$get_labels = function() {
i = self$link$labelling_names_list[[self$dimension_name]]
f = self$get_frame()[, i, drop = FALSE]
l = as.list(f)
paste_args = c(l, sep = ".")
do.call(paste, paste_args)
}
self$get_reference_labels = function() {
self$get_index()$reference_labels()
}
self$get_positions = function(zero_based = FALSE) {
i = match(self$get_labels(), self$get_reference_labels())
if (zero_based) i = i - 1L
i
}
self$get_reference_positions = function(zero_based = FALSE) {
i = match(self$get_reference_labels(), self$get_labels())
if (zero_based) i = i - 1L
i
}
return_object(self, "FrameGetter")
}

Expand Down Expand Up @@ -280,9 +306,9 @@ explicit_provenance = function(x, col_nm) {
m = x$column_map
implicit = is_provenance_implicit(x, col_nm)
if (length(implicit) == 0L) {
macpan2:::msg_colon(
macpan2:::msg("Column", col_nm, "not found in any of the original tables"),
macpan2:::msg_indent(names(m))
msg_colon(
msg("Column", col_nm, "not found in any of the original tables"),
msg_indent(names(m))
) |> stop()
}
if (!any(implicit)) return(x)
Expand Down
Loading

0 comments on commit d9699e7

Please sign in to comment.