Skip to content

Commit

Permalink
Merge 3f0447a into 85c72e0
Browse files Browse the repository at this point in the history
  • Loading branch information
njtierney authored Mar 31, 2022
2 parents 85c72e0 + 3f0447a commit 1ab817d
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 6 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ S3method(rowMeans,default)
S3method(rowMeans,greta_array)
S3method(rowSums,default)
S3method(rowSums,greta_array)
S3method(sd,default)
S3method(sd,greta_array)
S3method(sign,greta_array)
S3method(simulate,greta_model)
S3method(sin,greta_array)
Expand Down Expand Up @@ -232,6 +234,7 @@ export(rms_prop)
export(rowMeans)
export(rowSums)
export(rwmh)
export(sd)
export(simplex_variable)
export(slice)
export(slsqp)
Expand Down
34 changes: 30 additions & 4 deletions R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#' prod(..., na.rm = TRUE)
#' min(..., na.rm = TRUE)
#' max(..., na.rm = TRUE)
#' sd(..., na.rm = TRUE)
#'
#' # cumulative operations
#' cumsum(x)
Expand All @@ -77,10 +78,10 @@
#'
#' # miscellaneous operations
#' aperm(x, perm)
#' apply(x, MARGIN, FUN = c("sum", "max", "mean", "min",
#' apply(x, MARGIN, FUN = c("sum", "max", "mean", "min", "sd,
#' "prod", "cumsum", "cumprod"))
#' sweep(x, MARGIN, STATS, FUN = c('-', '+', '/', '*'))
#' tapply(X, INDEX, FUN = c("sum", "max", "mean", "min", "prod"), ...)
#' tapply(X, INDEX, FUN = c("sum", "max", "mean", "min", "sd, "prod"), ...)
#'
#' }
#'
Expand Down Expand Up @@ -590,6 +591,31 @@ mean.greta_array <- function(x, trim = 0, na.rm = TRUE, ...) { # nolint
)
}

# need to define sd as a generic since it isn't actually a generic
#' @rdname overloaded
#' @export
sd <- function(x, ...) UseMethod("sd", x)

# setting default and setting arguments for it so it passes package check
#' @export
sd.default <- function(x, na.rm = FALSE, ...) {
sd_result <- stats::sd(x = x,
na.rm = na.rm)
formals(sd.default) <- c(formals(sd.default), alist(... =))

sd_result
}

#' @export
sd.greta_array <- function(x, na.rm = TRUE, ...) { # nolint

# calculate SD on greta array
op("sd", x,
dim = c(1, 1),
tf_operation = "tf_sd"
)
}

#' @export
max.greta_array <- function(..., na.rm = TRUE) { # nolint

Expand Down Expand Up @@ -1073,7 +1099,7 @@ apply.default <- function(X, MARGIN, FUN, ...) { # nolint
#' @export
apply.greta_array <- function(X, MARGIN,
FUN = c(
"sum", "max", "mean", "min", "prod",
"sum", "max", "mean", "min", "sd", "prod",
"cumsum", "cumprod"
),
...) {
Expand Down Expand Up @@ -1169,7 +1195,7 @@ tapply.default <- function(X, INDEX, FUN = NULL, ...,
# nolint start
#' @export
tapply.greta_array <- function(X, INDEX,
FUN = c("sum", "max", "mean", "min", "prod"),
FUN = c("sum", "max", "mean", "min", "prod", "sd"),
...) {
# nolint end

Expand Down
19 changes: 19 additions & 0 deletions R/tf_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,25 @@ tf_mean <- function(x, drop = FALSE) {
skip_dim("reduce_mean", x, drop)
}

# need to create a "reduce_sd" function
# which

tf_sd <- function(x, drop = FALSE){

n_dim <- length(dim(x))
reduction_dims <- seq_len(n_dim - 1)

# replace these parts with tf_sum and friends?
x_mean_sq <- tf_mean(x, drop = drop) * tf_mean(x, drop = drop)
total_ss <- tf_sum(x - x_mean_sq, drop = drop)
n_denom <- prod(dim(x)[reduction_dims + 1])
var <- total_ss / fl(n_denom - 1)
sd_result <- tf$math$sqrt(var)

sd_result
}


tf_max <- function(x, drop = FALSE) {
skip_dim("reduce_max", x, drop)
}
Expand Down
5 changes: 3 additions & 2 deletions man/functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/overloaded.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tests/testthat/test_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ test_that("simple functions work as expected", {
check_op(mean, x)
check_op(sqrt, exp(x))
check_op(sign, x)
check_op(sd, x)

# rounding of numbers
check_op(ceiling, x)
Expand Down Expand Up @@ -230,6 +231,7 @@ test_that("apply works as expected", {
check_apply(a, margin, "prod")
check_apply(a, margin, "cumsum")
check_apply(a, margin, "cumprod")
check_apply(a, margin, "sd")
}
})

Expand All @@ -243,6 +245,7 @@ test_that("tapply works as expected", {
check_expr(tapply(x, rep(1:5, each = 3), "mean"))
check_expr(tapply(x, rep(1:5, each = 3), "min"))
check_expr(tapply(x, rep(1:5, each = 3), "prod"))
check_expr(tapply(x, rep(1:5, each = 3), "sd"))
})

test_that("cumulative functions error as expected", {
Expand Down

0 comments on commit 1ab817d

Please sign in to comment.