-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a function that computes a context tree for a collection of time …
…series (issue #30).
- Loading branch information
1 parent
6f88992
commit 4723305
Showing
4 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
## Context tree for multiple series | ||
|
||
## insert a new dts into an existing context tree | ||
insert_dts <- function(tree, x, vals, max_depth) { | ||
recurse_insert_dts <- function(tree, x, nb_vals, d, from, f_by) { | ||
if (d < max_depth) { | ||
fmatch <- forward_match_all_ctx_counts(x, nb_vals, d, from) | ||
nb_children <- 0L | ||
if (is.null(tree) || is.null(tree[["children"]])) { | ||
## we are in a leaf of the current context tree | ||
children <- vector(mode = "list", nb_vals) | ||
} else { | ||
children <- tree$children | ||
} | ||
d_max <- FALSE | ||
for (v in 1:nb_vals) { | ||
if (sum(fmatch$counts[v, ]) > 0) { | ||
children[[v]] <- recurse_insert_dts( | ||
tree$children[[v]], x, nb_vals, d + 1L, | ||
fmatch$positions[[v]], fmatch$counts[v, ] | ||
) | ||
nb_children <- nb_children + 1 | ||
} else { | ||
## nothing to do we keep the current children[[v]] | ||
if (!is.null(children[[v]])) { | ||
nb_children <- nb_children + 1 | ||
} else { | ||
## make sure to avoid null content | ||
children[[v]] <- list() | ||
} | ||
} | ||
if (isTRUE(children[[v]]$max_depth)) { | ||
d_max <- TRUE | ||
children[[v]]$max_depth <- NULL | ||
} | ||
} | ||
result <- list( | ||
children = children, | ||
f_by = f_by | ||
) | ||
if (d_max) { | ||
result$max_depth <- TRUE | ||
} | ||
} else { | ||
result <- list(f_by = f_by, max_depth = TRUE) | ||
} | ||
if (!is.null(tree[["f_by"]])) { | ||
result$f_by <- f_by + tree[["f_by"]] | ||
} | ||
result | ||
} | ||
recurse_insert_dts(tree, x, length(vals), 0L, NULL, table(x)) | ||
} | ||
|
||
## min_size based pruning | ||
prune_multi_ctx_tree <- function(tree, min_size) { | ||
if (!is.null(tree[["children"]])) { | ||
nb_pruned <- 0L | ||
subtrees <- vector(mode = "list", length(tree$children)) | ||
for (v in seq_along(tree$children)) { | ||
subtrees[[v]] <- prune_multi_ctx_tree(tree$children[[v]], min_size) | ||
if (length(subtrees[[v]]) == 0) { | ||
nb_pruned <- nb_pruned + 1L | ||
} | ||
} | ||
if (nb_pruned < length(tree$children)) { | ||
tree$children <- subtrees | ||
} else { | ||
tree$children <- NULL | ||
} | ||
} | ||
if (!is.null(tree[["f_by"]])) { | ||
if (sum(tree[["f_by"]]) < min_size) { | ||
list() | ||
} else { | ||
tree | ||
} | ||
} else { | ||
tree | ||
} | ||
} | ||
|
||
#' Build a context tree for a collection of discrete time series | ||
#' | ||
#' This function builds a context tree for a collection of time series. | ||
#' | ||
#' The tree represents all the sequences of symbols/states of length smaller | ||
#' than `max_depth` that appear at least `min_size` times in collection of the | ||
#' time series and stores the frequencies of the states that follow each | ||
#' context. | ||
#' | ||
#' Owing to the iterative nature of construction, this function may use a large | ||
#' quantity of memory as pruning infrequent contexts is only done after | ||
#' computing all of them. It is therefore recommend to avoid large depths. | ||
#' | ||
#' @param xs a list of discrete times series | ||
#' @param min_size integer >= 1 (default: 2). Minimum number of observations for | ||
#' a context to be included in the tree (counted over the full collection of | ||
#' time series, see details) | ||
#' @param max_depth integer >= 1 (default: 100). Maximum length of a context to | ||
#' be included in the tree. | ||
#' @param keep_position logical (default: FALSE). Should the context tree keep | ||
#' the position of the contexts. | ||
#' | ||
#' @return a context tree (of class that inherits from `multi_ctx_tree`). | ||
#' @export | ||
#' | ||
#' @examples | ||
#' dts <- c(0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0) | ||
#' dts2 <- c(0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0) | ||
#' mdts <- list(dts, dts2) | ||
#' mctx <- multi_ctx_tree(mdts, max_depth = 4) | ||
multi_ctx_tree <- function(xs, min_size = 2L, max_depth = 100L, keep_position = FALSE) { | ||
## keep_position = TRUE is not supported currently | ||
assertthat::assert_that(!keep_position) | ||
assertthat::assert_that(is.list(xs)) | ||
assertthat::assert_that(length(xs) >= 1) | ||
nx_1 <- to_dts(xs[[1]]) | ||
ix_1 <- nx_1$ix | ||
vals <- nx_1$vals | ||
if (length(vals) > max(10, 0.05 * length(xs[[1]]))) { | ||
warning(paste0("x[[1]] as numerous unique values (", length(vals), ")")) | ||
} | ||
## we cannot use the original min_size for individual time series | ||
pre_result <- grow_ctx_tree(ix_1, vals, | ||
min_size = 1L, max_depth = max_depth, keep_match = keep_position, | ||
compute_stats = FALSE | ||
) | ||
for (k in 2:length(xs)) { | ||
nx <- to_dts(xs[[k]], vals = vals) | ||
pre_result <- insert_dts(pre_result, nx$ix, vals, max_depth = max_depth) | ||
} | ||
## let use post process the tree to remove rare contexts | ||
if (min_size > 1L) { | ||
pre_result <- prune_multi_ctx_tree(pre_result, min_size) | ||
} | ||
new_ctx_tree(vals, pre_result, compute_stats = TRUE, class = "multi_ctx_tree") | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
test_that("multi_ctx_free finds correct contexts in basic cases", { | ||
dts <- c(0, 1, 1, 1, 0, 0, 1, 0, 1, 0) | ||
## use twice the same dts, so that contexts are identical | ||
mdts <- list(dts, dts) | ||
ctx <- ctx_tree(dts, min_size = 1, max_depth = 4) | ||
mctx <- multi_ctx_tree(mdts, min_size = 1, max_depth = 4) | ||
expect_true(compare_ctx(contexts(ctx), contexts(mctx))) | ||
ctx <- ctx_tree(dts, min_size = 2, max_depth = 4) | ||
mctx <- multi_ctx_tree(mdts, min_size = 4, max_depth = 4) | ||
expect_true(compare_ctx(contexts(ctx), contexts(mctx))) | ||
}) | ||
|
||
test_that("multi_ctx_free obeys its basic contract", { | ||
withr::local_seed(5) | ||
nb_dts <- 10L | ||
dts_bsize <- 20L | ||
mdts <- vector(mode = "list", length = nb_dts) | ||
for (k in seq_along(mdts)) { | ||
mdts[[k]] <- sample(c(1L, 2L), dts_bsize + sample(1:5, 1), replace = TRUE) | ||
} | ||
for (d in 2:6) { | ||
mctx <- multi_ctx_tree(mdts, min_size = 2, max_depth = d) | ||
expect_equal(depth(mctx), d) | ||
} | ||
}) | ||
|
||
test_that("multi_ctx_free finds correct contexts in more complex cases", { | ||
withr::local_seed(0) | ||
nb_dts <- 10L | ||
dts_bsize <- 20L | ||
mdts <- vector(mode = "list", length = nb_dts) | ||
for (k in seq_along(mdts)) { | ||
mdts[[k]] <- sample(c(1L, 2L), dts_bsize + sample(1:5, 1), replace = TRUE) | ||
} | ||
mctx <- multi_ctx_tree(mdts, min_size = 2, max_depth = 4) | ||
## check that each context is indeed present with the correct f_by | ||
mctx_ctx <- contexts(mctx, frequency = "detailed") | ||
for (k in seq_along(mctx_ctx$context)) { | ||
expect_equal( | ||
as.integer(mctx_ctx[k, 3:4]), | ||
multi_count_f_by(mdts, mctx_ctx$context[[k]], states(mctx)) | ||
) | ||
} | ||
}) | ||
|
||
test_that("multi_ctx_free finds all contexts", { | ||
withr::local_seed(42) | ||
nb_dts <- 10L | ||
dts_bsize <- 10L | ||
mdts <- vector(mode = "list", length = nb_dts) | ||
for (k in seq_along(mdts)) { | ||
mdts[[k]] <- sample(c(1L, 2L), dts_bsize + sample(1:5, 1), replace = TRUE) | ||
} | ||
mctx <- multi_ctx_tree(mdts, min_size = 2, max_depth = 10) | ||
m_ctxs <- contexts(mctx, sequence = TRUE)$context | ||
## any context found in individual sequences must appear at least as the | ||
## suffix of a context in the tree | ||
for (k in seq_along(mdts)) { | ||
base_ctx_tree <- ctx_tree(mdts[[k]], min_size = 1, max_depth = 10) | ||
base_ctxs <- contexts(base_ctx_tree, sequence = TRUE)$context | ||
all_true <- TRUE | ||
for (l in seq_along(base_ctxs)) { | ||
the_f_by <- multi_count_f_by(mdts, base_ctxs[[l]], states(mctx)) | ||
if (sum(the_f_by) >= 2) { | ||
pos_in <- Position(\(x) ends_with(x, base_ctxs[[l]]), m_ctxs, nomatch = 0) | ||
if (pos_in == 0) { | ||
all_true <- FALSE | ||
break | ||
} | ||
} | ||
} | ||
expect_true(all_true) | ||
} | ||
}) |