Skip to content

Commit

Permalink
Fix GQ not using correct max value (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Apr 11, 2024
1 parent fa90fba commit a273133
Show file tree
Hide file tree
Showing 14 changed files with 166 additions and 57 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ S3method(as_stan_list,default)
S3method(autoplot,LongitudinalQuantities)
S3method(autoplot,SurvivalQuantities)
S3method(brierScore,SurvivalQuantities)
S3method(coalesceGridTime,GridFixed)
S3method(coalesceGridTime,GridGrouped)
S3method(coalesceGridTime,default)
S3method(compileStanModel,JointModel)
S3method(dim,Quantities)
S3method(enableLink,LongitudinalGSF)
Expand Down
17 changes: 15 additions & 2 deletions R/GridFixed.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ GridFixed <- function(subjects = NULL, times = NULL) {
#' @rdname Quant-Dev
#' @export
as.QuantityGenerator.GridFixed <- function(object, data, ...) {

assert_class(data, "DataJoint")
data_list <- as.list(data)
subjects <- unlist(as.list(object, data = data), use.names = FALSE)
time_grid <- expand_time_grid(object@times, max(data_list[["tumour_time"]]))

validate_time_grid(object@times)
pt_times <- expand.grid(
pt = subjects,
time = time_grid,
time = object@times,
stringsAsFactors = FALSE
)

Expand All @@ -58,3 +59,15 @@ as.QuantityCollapser.GridFixed <- function(object, data, ...) {
as.list.GridFixed <- function(x, data, ...) {
subjects_to_list(x@subjects, data)
}

#' @rdname coalesceGridTime
#' @export
coalesceGridTime.GridFixed <- function(object, times, ...) {
if (is.null(object@times)) {
object <- GridFixed(
subjects = object@subjects,
times = times
)
}
object
}
17 changes: 15 additions & 2 deletions R/GridGrouped.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ as.QuantityCollapser.GridGrouped <- function(object, data, ...) {
assert_that(
all(unique(unlist(object@groups)) %in% names(data_list$subject_to_index))
)
time_grid <- expand_time_grid(object@times, max(data_list[["tumour_time"]]))

validate_time_grid(object@times)

group_grid <- expand.grid(
group = names(object@groups),
time = time_grid,
time = object@times,
stringsAsFactors = FALSE
)

Expand Down Expand Up @@ -99,3 +100,15 @@ as.QuantityCollapser.GridGrouped <- function(object, data, ...) {
as.list.GridGrouped <- function(x, ...) {
x@groups
}

#' @rdname coalesceGridTime
#' @export
coalesceGridTime.GridGrouped <- function(object, times, ...) {
if (is.null(object@times)) {
object <- GridGrouped(
groups = object@groups,
times = times
)
}
object
}
8 changes: 8 additions & 0 deletions R/LongitudinalQuantities.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ LongitudinalQuantities <- function(
assert_class(object, "JointModelSamples")
assert_class(grid, "Grid")

time_grid <- seq(
from = 0,
to = max(as.list(object@data)[["tumour_time"]]),
length = 201
)

grid <- coalesceGridTime(grid, time_grid)

gq <- generateQuantities(
object,
generator = as.QuantityGenerator(grid, object@data),
Expand Down
8 changes: 8 additions & 0 deletions R/SurvivalQuantities.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ SurvivalQuantities <- function(
assert_class(object, "JointModelSamples")
assert_class(grid, "Grid")

time_grid <- seq(
from = 0,
to = max(as.list(object@data)[["event_times"]]),
length = 201
)

grid <- coalesceGridTime(grid, time_grid)

generator <- as.QuantityGenerator(grid, object@data)

assert_that(
Expand Down
18 changes: 18 additions & 0 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,21 @@ as.QuantityGenerator <- function(object, ...) {
as.QuantityCollapser <- function(object, ...) {
UseMethod("as.QuantityCollapser")
}


#' Coalesce Time
#'
#' @param object ([`Grid`]) \cr object to coalesce time for.
#' @param times (`numeric`) \cr the times to coalesce to.
#' @param ... Not used
#'
#' Method used to replace NULL times on grid objects (if appropriate)
#'
#' @keywords internal
coalesceGridTime <- function(object, times, ...) {
UseMethod("coalesceGridTime")
}
#' @export
coalesceGridTime.default <- function(object, times, ...) {
object
}
27 changes: 12 additions & 15 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,23 +205,20 @@ is_windows <- function() {
return(sysname == "Windows")
}

#' `expand_time_grid`
#' `validate_time_grid`
#'
#' This function expands a given time grid by setting a default grid if one hasn't been provided
#' and then verifying it's properties.
#' The grid must be finite, sorted, and contain unique values.
#' Validate that the provided time grid is:
#' - finite
#' - numeric
#' - non-missing
#' - sorted
#' - unique
#'
#' @param time_grid (`numeric`)\cr A vector of times which quantities will be
#' evaluated at.
#'
#' @param time_grid (`numeric` or `NULL`)\cr A vector of times which quantities will be
#' evaluated at. If NULL, a default grid will be created as a length 201 vector spanning
#' from 0 to `max_time`.
#' @param max_time (`numeric``)\cr Specifies the maximum time to be used in creating the default grid.
#' @return Returns the expanded time_grid.
#' @keywords internal
expand_time_grid <- function(time_grid, max_time) {
default_grid <- seq(from = 0, to = max_time, length = 201)
if (is.null(time_grid)) {
time_grid <- default_grid
}
validate_time_grid <- function(time_grid) {
assert_that(
!any(is.na(time_grid)),
is.numeric(time_grid),
Expand All @@ -231,7 +228,7 @@ expand_time_grid <- function(time_grid, max_time) {
all(is.finite(time_grid)),
msg = "`time_grid` needs to be finite, sorted, unique valued numeric vector"
)
time_grid
invisible(time_grid)
}


Expand Down
27 changes: 27 additions & 0 deletions man/coalesceGridTime.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 0 additions & 24 deletions man/expand_time_grid.Rd

This file was deleted.

23 changes: 23 additions & 0 deletions man/validate_time_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/test-Grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,22 @@ test_that("subjects_to_list works as expected", {
regex = "Not all subjects exist within the data object"
)
})


test_that("coalesceGridTime() works as expected", {
grid <- GridFixed("A")
grid2 <- coalesceGridTime(grid, c(1, 2, 3))
expect_equal(grid2@times, c(1, 2, 3))

grid <- GridFixed("A", 5)
grid2 <- coalesceGridTime(grid, c(1, 2, 3))
expect_equal(grid2@times, 5)

grid <- GridGrouped(list("A" = "A"))
grid2 <- coalesceGridTime(grid, c(1, 2, 3))
expect_equal(grid2@times, c(1, 2, 3))

grid <- GridGrouped(list("A" = "A"), 5)
grid2 <- coalesceGridTime(grid, c(1, 2, 3))
expect_equal(grid2@times, 5)
})
1 change: 1 addition & 0 deletions tests/testthat/test-LongitudinalQuantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ test_that("Test that LongitudinalQuantities works as expected", {
preds <- summary(longsamps)
expect_equal(nrow(preds), 2 * 201) # 201 default time points for 2 subjects
expect_equal(names(preds), expected_column_names)
expect_equal(max(preds$time), max(test_data_1$dat_lm$time))
})


Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-SurvivalQuantities.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ test_that("SurvivalQuantities and autoplot.SurvivalQuantities works as expected"
test_data_1$jsamples,
grid = GridGrouped(groups = list("pt_0001" = "pt_0001", "pt_0003" = "pt_0003"))
)
preds <- preds <- summary(survsamps)
preds <- summary(survsamps)
expect_equal(nrow(preds), 2 * 201) # 201 default time points for 2 subjects
expect_equal(names(preds), expected_column_names)

expect_equal(max(preds$time), max(test_data_1$dat_os$time))

# Check that the relationship between the quantitites is preservered e.g.
# that `surv = exp(-cumhaz)`
Expand Down
27 changes: 15 additions & 12 deletions tests/testthat/test-utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,38 +146,41 @@ test_that("samples_median_ci works with a custom credibility level", {



test_that("expand_time_grid() works as expected", {
test_that("validate_time_grid() works as expected", {

## Smoke test of basic usage
expect_equal(
expand_time_grid(NULL, 5),
seq(0, 5, length.out = 201)
validate_time_grid(c(1, 2, 3)),
c(1, 2, 3)
)
expect_equal(
expand_time_grid(c(1, 2, 3), 5),
c(1, 2, 3)
validate_time_grid(c(1)),
c(1)
)
expect_equal(
expand_time_grid(c(1, 2, 3)),
c(1, 2, 3)
validate_time_grid(c(1, 2, 30000.3)),
c(1, 2, 30000.3)
)
expect_equal(
validate_time_grid(c(1L, 2L, 4L)),
c(1L, 2L, 4L)
)


## Error handling
expect_error(
expand_time_grid(c(1, 1, 2)),
validate_time_grid(c(1, 1, 2)),
regexp = "`time_grid`"
)
expect_error(
expand_time_grid(c(2, 1, 3)),
validate_time_grid(c(2, 1, 3)),
regexp = "`time_grid`"
)
expect_error(
expand_time_grid(c(1, 3, NA)),
validate_time_grid(c(1, 3, NA)),
regexp = "`time_grid`"
)
expect_error(
expand_time_grid(c(1, 3, -Inf)),
validate_time_grid(c(1, 3, -Inf)),
regexp = "`time_grid`"
)
})
Expand Down

0 comments on commit a273133

Please sign in to comment.