Skip to content

Commit

Permalink
Merge pull request #16 from IDSIA/unbalanced_hier
Browse files Browse the repository at this point in the history
Unbalanced hier
  • Loading branch information
dazzimonti authored Aug 28, 2024
2 parents 49fbfb8 + 2114b50 commit df536b7
Show file tree
Hide file tree
Showing 34 changed files with 1,011 additions and 952 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: bayesRecon
Type: Package
Date: 2024-05-29
Date: 2024-08-28
Title: Probabilistic Reconciliation via Conditioning
Version: 0.3.0
Version: 0.3.1
Authors@R: c(person(given = "Dario",
family = "Azzimonti",
role = c("aut","cre"),
Expand Down Expand Up @@ -42,7 +42,7 @@ Imports:
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown=TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Suggests:
knitr,
rmarkdown,
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# bayesRecon 0.3.1

* IMPORTANT CHANGE IN THE API OF THE `reconc_*` functions: they now require the aggregation matrix A and not the summing matrix S.

* The examples section of `reconc_TDcond` now contains an example showing how to handle the case of an unbalanced hierarchy.

# bayesRecon 0.3.0

* Added `reconc_MixCond`, the implementation of Mixed conditioning for the reconciliation of mixed-type hierarchical forecasts.
Expand Down
10 changes: 10 additions & 0 deletions R/PMF.R
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,16 @@ PMF.bottom_up = function(l_pmf, toll=.TOLL, Rtoll=.RTOLL, return_all=FALSE,
if (smoothing) l_pmf = lapply(l_pmf, PMF.smoothing,
alpha=al_smooth, laplace=lap_smooth)

# In case we have an upper which is a duplicate of a bottom,
# the bottom up is simply that bottom.
if(length(l_pmf)==1){
if (return_all) {
return(list(l_pmf))
} else {
return(l_pmf[[1]])
}
}

# Doesn't do convolutions sequentially
# Instead, for each iteration (while) it creates a new list of vectors
# by doing convolution between 1 and 2, 3 and 4, ...
Expand Down
41 changes: 27 additions & 14 deletions R/hierarchy.R
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,12 @@ get_reconc_matrices <- function(agg_levels, h) {

# Get A from S
.get_A_from_S <- function(S) {
bottom_idxs = which(rowSums(S) == 1)
# Bottom rows are those with a single 1; if there are replicated bottom rows,
# only one is treated as bottom, the copies will be upper
bottom_idxs = which(rowSums(S) == 1 & !duplicated(S))
if (length(bottom_idxs) < ncol(S)) {
stop("Check S: some bottom rows are missing")
}
upper_idxs = setdiff(1:nrow(S), bottom_idxs)
A = matrix(S[upper_idxs, ], ncol=ncol(S))
out = list(A = A,
Expand Down Expand Up @@ -331,7 +336,7 @@ get_reconc_matrices <- function(agg_levels, h) {
for (i in 1:k) {
for (j in 1:k) {
if (i < j) {
cond1 = A[i,] %*% A[j,] != 0 # Upper i and j have some common descendants
cond1 = c(A[i,] %*% A[j,] != 0) # Upper i and j have some common descendants
cond2 = any(A[j,] > A[i,]) # Upper j is not a descendant of upper i
cond3 = any(A[i,] > A[j,]) # Upper i is not a descendant of upper j
if (cond1 & cond2 & cond3) {
Expand Down Expand Up @@ -371,30 +376,40 @@ get_reconc_matrices <- function(agg_levels, h) {

if (!.check_hierarchical(A)) stop("Matrix A is not hierarchical")

k = nrow(A)
m = ncol(A)
# First, only keep unique rows of A
A_uni = unique(A)

rows = c()
k = nrow(A_uni)
m = ncol(A_uni)

low_rows_A_uni = c()
for (i in 1:k) {
rows = c(rows, i)
low_rows_A_uni = c(low_rows_A_uni, i)
for (j in 1:k) {
if (i != j) {
# If upper j is a descendant of upper i, remove i and exit loop
if (all(A[j,] <= A[i,])) {
rows = rows[-length(rows)]
if (all(A_uni[j,] <= A_uni[i,])) {
low_rows_A_uni = low_rows_A_uni[-length(low_rows_A_uni)]
break
}
}
}
}
# At this point only the rows that have no descendants among the uppers are kept

# Now, change the indices of the lowest rows to match with A (instead of A_uni)
# If there are duplicated rows for some lowest rows, only take one copy
low_rows_A = (1:nrow(A))[!duplicated(A)][low_rows_A_uni]

# The sum of the rows corresponding to the lowest level should be a vector of 1
if (any(colSums(A[rows,,drop=FALSE])!=1)) {
stop("The hierarchy is not balanced")
if (any(colSums(A[low_rows_A,,drop=FALSE])!=1)) {
unbal_bott = which(colSums(A[low_rows_A,,drop=FALSE])!=1)
err_mess = "It is impossible to find the lowest upper level. Probably the hierarchy is unbalanced, the following bottom should be duplicated (see example): "
err_mess = paste0(c(err_mess, unbal_bott), collapse = " ")
stop(err_mess)
}

return(rows)
return(low_rows_A)
}


Expand All @@ -409,15 +424,13 @@ get_reconc_matrices <- function(agg_levels, h) {
}

A_ = A[-lowest_rows,,drop=FALSE]
n_bott = ncol(A_)
n_upp_u = nrow(A_)
n_bott_u = length(lowest_rows)
A_u = matrix(nrow=n_upp_u, ncol=n_bott_u)
for (j in 1:n_bott_u) {
l = lowest_rows[[j]]
mask = A[l,]==1
for (i in 1:n_upp_u) {
A_u[i,j] = sum(A_[i, mask]==1) == sum(mask) # check that is a vector of 1
A_u[i,j] = all(A[l,] <= A_[i,]) # check that "lower upper" j is a descendant of "upper upper" i
}
}

Expand Down
69 changes: 41 additions & 28 deletions R/reconc_BUIS.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@
#'
#' @details
#'
#' The parameter `base_forecast` is a list containing n elements where the i-th element depends on
#' the values of `in_type[[i]]` and `distr[[i]]`.
#' The parameter `base_forecast` is a list containing n = n_upper + n_bottom elements.
#' The first n_upper elements of the list are the upper base forecasts, in the order given by the rows of A.
#' The elements from n_upper+1 until the end of the list are the bottom base forecasts, in the order given by the columns of A.
#'
#' The i-th element depends on the values of `in_type[[i]]` and `distr[[i]]`.
#'
#' If `in_type[[i]]`='samples', then `base_forecast[[i]]` is a vector containing samples from the base forecast distribution.
#'
Expand All @@ -74,8 +77,6 @@
#' * size and prob (or mu) for the negative binomial base forecast if `distr[[i]]`='nbinom', see \link[stats]{NegBinomial}.
#'
#' See the description of the parameters `in_type` and `distr` for more details.
#'
#' The order of the `base_forecast` list is given by the order of the time series in the summing matrix.
#'
#' Warnings are triggered from the Importance Sampling step if:
#'
Expand All @@ -86,16 +87,16 @@
#' Note that warnings are an indication that the base forecasts might have issues.
#' Please check the base forecasts in case of warnings.
#'
#' @param S Summing matrix (n x n_bottom).
#' @param A aggregation matrix (n_upper x n_bottom).
#' @param base_forecasts A list containing the base_forecasts, see details.
#' @param in_type A string or a list of length n. If it is a list the i-th element is a string with two possible values:
#' @param in_type A string or a list of length n_upper + n_bottom. If it is a list the i-th element is a string with two possible values:
#'
#' * 'samples' if the i-th base forecasts are in the form of samples;
#' * 'params' if the i-th base forecasts are in the form of estimated parameters.
#'
#' If `in_type` is a string, it is assumed that all base forecasts are of the same type.
#'
#' @param distr A string or a list of length n describing the type of base forecasts.
#' @param distr A string or a list of length n_upper + n_bottom describing the type of base forecasts.
#' If it is a list the i-th element is a string with two possible values:
#'
#' * 'continuous' or 'discrete' if `in_type[[i]]`='samples';
Expand Down Expand Up @@ -123,6 +124,7 @@
#'
#'# Create a minimal hierarchy with 2 bottom and 1 upper variable
#'rec_mat <- get_reconc_matrices(agg_levels=c(1,2), h=2)
#'A <- rec_mat$A
#'S <- rec_mat$S
#'
#'
Expand All @@ -140,21 +142,21 @@
#'sigmas <- c(sigmaY,sigma1,sigma2)
#'
#'base_forecasts = list()
#'for (i in 1:nrow(S)) {
#'for (i in 1:length(mus)) {
#' base_forecasts[[i]] = list(mean = mus[[i]], sd = sigmas[[i]])
#'}
#'
#'
#'#Sample from the reconciled forecast distribution using the BUIS algorithm
#'buis <- reconc_BUIS(S, base_forecasts, in_type="params",
#'buis <- reconc_BUIS(A, base_forecasts, in_type="params",
#' distr="gaussian", num_samples=100000, seed=42)
#'
#'samples_buis <- buis$reconciled_samples
#'
#'#In the Gaussian case, the reconciled distribution is still Gaussian and can be
#'#computed in closed form
#'Sigma <- diag(sigmas^2) #transform into covariance matrix
#'analytic_rec <- reconc_gaussian(S, base_forecasts.mu = mus,
#'analytic_rec <- reconc_gaussian(A, base_forecasts.mu = mus,
#' base_forecasts.Sigma = Sigma)
#'
#'#Compare the reconciled means obtained analytically and via BUIS
Expand All @@ -171,12 +173,12 @@
#'lambdas <- c(lambdaY,lambda1,lambda2)
#'
#'base_forecasts <- list()
#'for (i in 1:nrow(S)) {
#'for (i in 1:length(lambdas)) {
#' base_forecasts[[i]] = list(lambda = lambdas[i])
#'}
#'
#'#Sample from the reconciled forecast distribution using the BUIS algorithm
#'buis <- reconc_BUIS(S, base_forecasts, in_type="params",
#'buis <- reconc_BUIS(A, base_forecasts, in_type="params",
#' distr="poisson", num_samples=100000, seed=42)
#'samples_buis <- buis$reconciled_samples
#'
Expand All @@ -194,7 +196,7 @@
#' [reconc_gaussian()]
#'
#' @export
reconc_BUIS <- function(S,
reconc_BUIS <- function(A,
base_forecasts,
in_type,
distr,
Expand All @@ -203,21 +205,33 @@ reconc_BUIS <- function(S,
seed = NULL) {

if (!is.null(seed)) set.seed(seed)

n_upper = nrow(A)
n_bottom = ncol(A)
n_tot <- length(base_forecasts)

# Transform distr and in_type into lists
if (!is.list(distr)) {
distr = rep(list(distr), nrow(S))
distr = rep(list(distr), n_tot)
}
if (!is.list(in_type)) {
in_type = rep(list(in_type), nrow(S))
in_type = rep(list(in_type), n_tot)
}

# Ensure that data inputs are valid
.check_input_BUIS(S, base_forecasts, in_type, distr)
.check_input_BUIS(A, base_forecasts, in_type, distr)

# Split bottoms, uppers
split_hierarchy.res = .split_hierarchy(S, base_forecasts)
A = split_hierarchy.res$A
# the first nrow(A) elements of base_forecasts are upper
# the remaining ncol(A) elements of base_forecasts are bottom

split_hierarchy.res = list(
A = A,
upper = base_forecasts[1:nrow(A)],
bottom = base_forecasts[(nrow(A)+1):n_tot],
upper_idxs = 1:nrow(A),
bottom_idxs = (nrow(A)+1):n_tot
)
upper_base_forecasts = split_hierarchy.res$upper
bottom_base_forecasts = split_hierarchy.res$bottom

Expand Down Expand Up @@ -249,8 +263,7 @@ reconc_BUIS <- function(S,
}

# Reconciliation using BUIS
n_upper = nrow(A)
n_bottom = ncol(A)

# 1. Bottom samples
B = list()
in_type_bottom = in_type[split_hierarchy.res$bottom_idxs]
Expand All @@ -276,13 +289,13 @@ reconc_BUIS <- function(S,
in_type_ = in_typeH[[hi]],
distr_ = distr_H[[hi]]
)
check_weights.res = .check_weigths(weights)
check_weights.res = .check_weights(weights)
if (check_weights.res$warning & !suppress_warnings) {
warning_msg = check_weights.res$warning_msg
# add information to the warning message
upper_fromS_i = which(lapply(seq_len(nrow(S)), function(i) sum(abs(S[i,] - c))) == 0)
upper_fromA_i = which(lapply(seq_len(nrow(A)), function(i) sum(abs(A[i,] - c))) == 0)
for (wmsg in warning_msg) {
wmsg = paste(wmsg, paste0("Check the upper forecast at index: ", upper_fromS_i,"."))
wmsg = paste(wmsg, paste0("Check the upper forecast at index: ", upper_fromA_i,"."))
warning(wmsg)
}
}
Expand All @@ -304,18 +317,18 @@ reconc_BUIS <- function(S,
distr_ = distr_G[[gi]]
)
}
check_weights.res = .check_weigths(weights)
check_weights.res = .check_weights(weights)
if (check_weights.res$warning & !suppress_warnings) {
warning_msg = check_weights.res$warning_msg
# add information to the warning message
upper_fromS_i = c()
upper_fromA_i = c()
for (gi in 1:nrow(G)) {
c = G[gi, ]
upper_fromS_i = c(upper_fromS_i,
which(lapply(seq_len(nrow(S)), function(i) sum(abs(S[i,] - c))) == 0))
upper_fromA_i = c(upper_fromA_i,
which(lapply(seq_len(nrow(A)), function(i) sum(abs(A[i,] - c))) == 0))
}
for (wmsg in warning_msg) {
wmsg = paste(wmsg, paste0("Check the upper forecasts at index: ", paste0("{",paste(upper_fromS_i, collapse = ","), "}.")))
wmsg = paste(wmsg, paste0("Check the upper forecasts at index: ", paste0("{",paste(upper_fromA_i, collapse = ","), "}.")))
warning(wmsg)
}
}
Expand Down
Loading

0 comments on commit df536b7

Please sign in to comment.