Skip to content

Commit

Permalink
Add support for dts in all functions where that makes sense (issue #76).
Browse files Browse the repository at this point in the history
  • Loading branch information
fabrice-rossi committed Mar 27, 2024
1 parent d680ffd commit fa1c3a8
Show file tree
Hide file tree
Showing 45 changed files with 941 additions and 148 deletions.
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ S3method(contexts,vlmc)
S3method(contexts,vlmc_cpp)
S3method(counts,ctx_node)
S3method(counts,ctx_node_cpp)
S3method(covlmc,default)
S3method(covlmc,dts)
S3method(ctx_tree,default)
S3method(ctx_tree,dts)
S3method(cutoff,covlmc)
S3method(cutoff,ctx_node)
S3method(cutoff,vlmc)
Expand Down Expand Up @@ -126,6 +130,7 @@ S3method(prune,vlmc_cpp)
S3method(restore_model,ctx_tree_cpp)
S3method(restore_model,vlmc_cpp)
S3method(rev,ctx_node)
S3method(rev,dts)
S3method(simulate,covlmc)
S3method(simulate,vlmc)
S3method(simulate,vlmc_cpp)
Expand All @@ -142,6 +147,8 @@ S3method(trim,covlmc)
S3method(trim,ctx_tree)
S3method(trim,vlmc)
S3method(trim,vlmc_cpp)
S3method(vlmc,default)
S3method(vlmc,dts)
export(as_covlmc)
export(as_sequence)
export(as_vlmc)
Expand Down
10 changes: 10 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
by `draw()` to display a model

## New features
### Discrete time series class (`dts`)
In order to ease the introduction of multiple time series support, a single
discrete time series can now be represented by the `dts` class, via the `dts()`
function (see issue #76). All functions that use discrete time series now accept
objects of this class in addition to simple vectors of a supported type
(`integer`, `factor`, `character` and `logical`). This applies to model
estimation functions such as `vlmc()` or `covlmc()`, to model selection functions
(e.g. `tune_vlmc()`), but also to functions that use new data such as
`loglikelihood()` and `predict.vlmc()`.

### Model representation (with `draw()`)
A major change of `draw()` is the support of multiple output formats. This is
done via a `format` parameter. It supports currently:
Expand Down
58 changes: 48 additions & 10 deletions R/covlmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,8 @@ covlmc_control <- function(pseudo_obs = 1) {
#' This function fits a Variable Length Markov Chain with covariates (coVLMC)
#' to a discrete time series coupled with a time series of covariates.
#'
#' @param x a discrete time series; can be numeric, character, factor or logical.
#' @param x an object that can be interpreted as a discrete time series, such
#' as an integer vector or a `dts` object (see [dts()]).
#' @param covariate a data frame of covariates.
#' @param alpha number in (0,1) (default: 0.05) cut off value in the pruning
#' phase (in quantile scale).
Expand Down Expand Up @@ -645,7 +646,9 @@ covlmc_control <- function(pseudo_obs = 1) {
#' @export
#' @examples
#' pc <- powerconsumption[powerconsumption$week == 5, ]
#' rdts <- cut(pc$active_power, breaks = c(0, quantile(pc$active_power, probs = c(1 / 3, 2 / 3, 1))))
#' rdts <- cut(pc$active_power, breaks = c(0, quantile(pc$active_power,
#' probs = c(1 / 3, 2 / 3, 1)
#' )))
#' rdts_cov <- data.frame(day_night = (pc$hour >= 7 & pc$hour <= 17))
#' m_cov <- covlmc(rdts, rdts_cov, min_size = 15)
#' draw(m_cov)
Expand All @@ -655,17 +658,52 @@ covlmc_control <- function(pseudo_obs = 1) {
#' )
#' draw(m_cov_nnet)
#' @seealso [cutoff.covlmc()] and [prune.covlmc()] for post-pruning.
covlmc <- function(x, covariate, alpha = 0.05, min_size = 5L, max_depth = 100L, keep_data = TRUE, control = covlmc_control(...), ...) {
covlmc <- function(x,
                   covariate,
                   alpha = 0.05,
                   min_size = 5L,
                   max_depth = 100L,
                   keep_data = TRUE,
                   control = covlmc_control(...),
                   ...) {
  # S3 generic: the method that runs is selected from the class of `x`.
  UseMethod("covlmc")
}

#' @export
#' @param x a numeric, character, factor or logical vector
#' @inherit covlmc
covlmc.default <- function(x,
                           covariate,
                           alpha = 0.05,
                           min_size = 5L,
                           max_depth = 100L,
                           keep_data = TRUE,
                           control = covlmc_control(...),
                           ...) {
  # Convert the raw vector with dts(), then hand its components
  # (ix, fx and vals) to the fitting routine shared with covlmc.dts().
  series <- dts(x)
  covlmc_internal(
    ix = series$ix,
    fx = series$fx,
    vals = series$vals,
    covariate = covariate,
    alpha = alpha,
    min_size = min_size,
    max_depth = max_depth,
    keep_data = keep_data,
    control = control
  )
}

#' @export
#' @param x a discrete time series represented by a `dts` object as created by
#'   [dts()]
#' @inherit covlmc
#' @examples
#' pc <- powerconsumption[powerconsumption$week == 5, ]
#' power_breaks <- c(0, quantile(pc$active_power, probs = c(1 / 3, 2 / 3, 1)))
#' power_dts <- dts(cut(pc$active_power, breaks = power_breaks))
#' power_cov <- data.frame(day_night = (pc$hour >= 7 & pc$hour <= 17))
#' m_cov <- covlmc(power_dts, power_cov, min_size = 15)
#' draw(m_cov)
covlmc.dts <- function(x,
                       covariate,
                       alpha = 0.05,
                       min_size = 5L,
                       max_depth = 100L,
                       keep_data = TRUE,
                       control = covlmc_control(...),
                       ...) {
  # `x` is already a dts object: pass its ix, fx and vals components
  # straight to the fitting routine shared with covlmc.default().
  covlmc_internal(
    ix = x$ix,
    fx = x$fx,
    vals = x$vals,
    covariate = covariate,
    alpha = alpha,
    min_size = min_size,
    max_depth = max_depth,
    keep_data = keep_data,
    control = control
  )
}

covlmc_internal <- function(ix, fx, vals, covariate, alpha, min_size, max_depth,
keep_data, control) {
assertthat::assert_that(is.data.frame(covariate))
assertthat::assert_that(nrow(covariate) == length(x))
assertthat::assert_that(nrow(covariate) == length(ix))
if (is.null(alpha) || !is.numeric(alpha) || alpha <= 0 || alpha > 1) {
stop("the alpha parameter must be in (0, 1]")
}
# data conversion
nx <- to_dts(x)
ix <- nx$ix
vals <- nx$vals
if (length(vals) > max(10, 0.05 * length(x))) {
if (length(vals) > max(10, 0.05 * length(ix))) {
warning(paste0("x as numerous unique values (", length(vals), ")"))
}
## covariate preparation
Expand All @@ -681,7 +719,7 @@ covlmc <- function(x, covariate, alpha = 0.05, min_size = 5L, max_depth = 100L,
covsize = desc$cov_size, keep_match = TRUE, all_children = TRUE
)
if (length(vals) > 2) {
x <- nx$fx
x <- fx
} else {
x <- ix
}
Expand Down
20 changes: 9 additions & 11 deletions R/covlmc_likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ logLik.covlmc <- function(object, initial = c("truncated", "specific", "extended
#' attributes(ll_new)
#'
#' @export
loglikelihood.covlmc <- function(vlmc, newdata, initial = c("truncated", "specific", "extended"),
loglikelihood.covlmc <- function(vlmc, newdata,
initial = c("truncated", "specific", "extended"),
ignore, newcov, ...) {
initial <- match.arg(initial)
if (missing(ignore)) {
Expand Down Expand Up @@ -243,9 +244,7 @@ loglikelihood.covlmc <- function(vlmc, newdata, initial = c("truncated", "specif
if (isTRUE(vlmc$trimmed == "full")) {
stop("loglikelihood calculation for new data is not supported by fully trimmed covlmc")
}
assertthat::assert_that((typeof(newdata) == typeof(vlmc$vals)) && methods::is(newdata, class(vlmc$vals)),
msg = "newdata is not compatible with the model state space"
)
newdata <- convert_with_check(newdata, vlmc$vals, "newdata")
assertthat::assert_that(!missing(newcov),
msg = "Need new covariate values (newcov) with new data (newdata)"
)
Expand All @@ -256,23 +255,22 @@ loglikelihood.covlmc <- function(vlmc, newdata, initial = c("truncated", "specif
assertthat::assert_that(nrow(newcov) == length(newdata))
data_size <- length(newdata)
newcov <- validate_covariate(vlmc, newcov)
nx <- to_dts(newdata, vlmc$vals)
ncovlmc <- match_ctx(vlmc, nx$ix, keep_match = TRUE)
ncovlmc <- match_ctx(vlmc, newdata$ix, keep_match = TRUE)
if (length(vlmc$vals) > 2) {
newdata <- nx$fx
nx <- newdata$fx
} else {
newdata <- nx$ix
nx <- newdata$ix
}
ignore_counts <- ignore
if (initial == "specific" && ignore < depth(vlmc)) {
ignore <- depth(vlmc)
}
res <- rec_loglikelihood_covlmc_newdata(ncovlmc, 0, length(vlmc$vals), newdata, newcov)
res <- rec_loglikelihood_covlmc_newdata(ncovlmc, 0, length(vlmc$vals), nx, newcov)
if (ignore > 0) {
icovlmc <- match_ctx(vlmc, nx$ix[1:min(ignore, length(newdata))], keep_match = TRUE)
icovlmc <- match_ctx(vlmc, newdata$ix[1:min(ignore, length(newdata))], keep_match = TRUE)
delta_res <- rec_loglikelihood_covlmc_newdata(
icovlmc, 0, length(vlmc$vals),
newdata[1:min(ignore, length(newdata))],
nx[1:min(ignore, length(newdata))],
newcov[1:min(ignore, length(newdata)), , drop = FALSE]
)
res <- res - delta_res
Expand Down
7 changes: 1 addition & 6 deletions R/covlmc_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,7 @@ predict.covlmc <- function(object, newdata, newcov, type = c("raw", "probs"),
if (missing(newdata) || is.null(newdata)) {
stop("newdata must be provided.")
}
assertthat::assert_that(
(typeof(newdata) == typeof(object$vals)) &&
methods::is(newdata, class(object$vals)),
msg = "newdata is not compatible with the model state space"
)
nx <- to_dts(newdata, object$vals)
nx <- convert_with_check(newdata, object$vals, "newdata")
x <- nx$ix + 1
if (missing(newcov) || is.null(newcov)) {
stop("newcov must be provided.")
Expand Down
8 changes: 3 additions & 5 deletions R/covlmc_simulate.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ match_context_co <- function(tree, ctx) {
#' @param nsim length of the simulated time series (defaults to 1).
#' @param seed an optional random seed (see the dedicated section).
#' @param covariate values of the covariates.
#' @param init an optional initial sequence for the time series.
#' @param init an optional initial sequence for the time series given by an object
#' that can be interpreted as a discrete time series.
#' @param ... additional arguments.
#'
#' @section Extended contexts:
Expand Down Expand Up @@ -90,11 +91,8 @@ simulate.covlmc <- function(object, nsim = 1, seed = NULL, covariate, init = NUL
}
int_vals <- seq_along(object$vals)
if (!is.null(init)) {
assertthat::assert_that((typeof(init) == typeof(object$vals)) && methods::is(init, class(object$vals)),
msg = "init is not compatible with the model state space"
)
assertthat::assert_that(length(init) <= nsim, msg = "too many initial values")
init_dts <- to_dts(init, object$vals)
init_dts <- convert_with_check(init, object$vals, "init")
ctx <- rev(init_dts$ix)[1:(min(max_depth, length(init)))] + 1
istart <- 1 + length(init)
} else {
Expand Down
20 changes: 15 additions & 5 deletions R/covlmc_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#' collection, including the initial complex tree, as the one that minimizes the
#' chosen information criterion.
#'
#' @param x a discrete time series; can be numeric, character, factor and
#' logical.
#' @param x an object that can be interpreted as a discrete time series, such
#' as an integer vector or a `dts` object (see [dts()]).
#' @param covariate a data frame of covariates.
#' @param criterion criterion used to select the best model. Either `"BIC"`
#' (default) or `"AIC"` (see details).
Expand Down Expand Up @@ -102,6 +102,10 @@ tune_covlmc <- function(x, covariate, criterion = c("BIC", "AIC"),
save <- match.arg(save)
criterion <- match.arg(criterion)
trimming <- match.arg(trimming)
## make sure that x is a dts
if (!is_dts(x)) {
x <- dts(x)
}
if (criterion == "BIC") {
f_criterion <- stats::BIC
} else {
Expand All @@ -119,15 +123,21 @@ tune_covlmc <- function(x, covariate, criterion = c("BIC", "AIC"),
cat("Fitting a covlmc with max_depth=", max_depth, "and alpha=", alpha, "\n")
}
saved_models <- list()
base_model <- covlmc(x, covariate, alpha = alpha, min_size = min_size, max_depth = max_depth)
base_model <- covlmc(x, covariate,
alpha = alpha, min_size = min_size,
max_depth = max_depth
)
while (base_model$max_depth) {
n_max_depth <- min(2 * max_depth, length(x) - 1)
if (n_max_depth > max_depth) {
if (verbose > 0) {
cat("Max depth reached, increasing it to", n_max_depth, "\n")
}
max_depth <- n_max_depth
base_model <- covlmc(x, covariate, alpha = alpha, min_size = min_size, max_depth = max_depth)
base_model <- covlmc(x, covariate,
alpha = alpha, min_size = min_size,
max_depth = max_depth
)
} else {
warning("cannot find a suitable value for max_depth")
break
Expand All @@ -146,7 +156,7 @@ tune_covlmc <- function(x, covariate, criterion = c("BIC", "AIC"),
repeat {
if (initial == "truncated") {
ll <- loglikelihood(model,
newdata = x, initial = "truncated",
newdata = dts_data(x), initial = "truncated",
ignore = max_order, newcov = covariate
)
} else {
Expand Down
5 changes: 1 addition & 4 deletions R/ctx_node.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,10 @@ find_sequence.ctx_tree <- function(ct, ctx, reverse = FALSE, ...) {
}
new_ctx_node(ctx, ct, ct, reverse)
} else {
assertthat::assert_that((typeof(ctx) == typeof(ct$vals)) && methods::is(ctx, class(ct$vals)),
msg = "ctx is not compatible with the model state space"
)
if (!reverse) {
ctx <- rev(ctx)
}
nx <- to_dts(ctx, ct$vals)
nx <- convert_with_check(ctx, ct$vals, "ctx")
current <- ct
for (k in seq_along(ctx)) {
if (is.null(current$children)) {
Expand Down
5 changes: 1 addition & 4 deletions R/ctx_node_covlmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,10 @@ find_sequence.covlmc <- function(ct, ctx, reverse = FALSE, ...) {
}
new_ctx_node(ctx, ct, ct, reverse, merged = FALSE, class = "ctx_node_covlmc")
} else {
assertthat::assert_that((typeof(ctx) == typeof(ct$vals)) && methods::is(ctx, class(ct$vals)),
msg = "ctx is not compatible with the model state space"
)
if (!reverse) {
ctx <- rev(ctx)
}
nx <- to_dts(ctx, ct$vals)
nx <- convert_with_check(ctx, ct$vals, "ctx")
current <- ct
## first part
for (k in seq_along(ctx[-length(ctx)])) {
Expand Down
7 changes: 2 additions & 5 deletions R/ctx_node_cpp.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ restore_ctx_node_cpp <- function(node) {
restore_model(node$tree)
if (extptr_is_null(node$node_env$node)) {
if (length(node$sequence) > 0) {
nx <- to_dts(node$sequence, node$tree$vals)
nx <- dts(node$sequence, node$tree$vals)
node$node_env$node <- node$tree$root$raw_find_sequence(nx$ix)
} else {
node$node_env$node <- node$tree$root$raw_find_sequence(integer())
Expand Down Expand Up @@ -64,13 +64,10 @@ find_sequence.ctx_tree_cpp <- function(ct, ctx, reverse = FALSE, ...) {
root <- ct$root$raw_find_sequence(integer())
new_ctx_node_cpp(ctx, ct, root, reverse)
} else {
assertthat::assert_that((typeof(ctx) == typeof(ct$vals)) && methods::is(ctx, class(ct$vals)),
msg = "ctx is not compatible with the model state space"
)
if (!reverse) {
ctx <- rev(ctx)
}
nx <- to_dts(ctx, ct$vals)
nx <- convert_with_check(ctx, ct$vals, "ctx")
node <- ct$root$raw_find_sequence(nx$ix)
if (extptr_is_null(node)) {
NULL
Expand Down
Loading

0 comments on commit fa1c3a8

Please sign in to comment.