Skip to content

Commit

Permalink
added test for assign_cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
osorensen committed Dec 5, 2023
1 parent da724ba commit c1f94cf
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
9 changes: 2 additions & 7 deletions R/assign_cluster.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,11 @@ assign_cluster <- function(
if (is.null(burnin)) {
stop("Please specify the burnin.")
}
if (is.null(model_fit$cluster_assignment)) {
stop("No cluster assignments.")
}
stopifnot(burnin < model_fit$nmc)

df <- model_fit$cluster_assignment[model_fit$cluster_assignment$iteration > burnin, , drop = FALSE]
df <- model_fit$cluster_assignment[
model_fit$cluster_assignment$iteration > burnin, , drop = FALSE]

# Compute the probability of each iteration
df <- aggregate(
list(count = df$iteration),
list(assessor = df$assessor, cluster = df$value),
Expand All @@ -67,7 +64,6 @@ assign_cluster <- function(
x
}))

# Compute the MAP estimate per assessor
map <- do.call(rbind, lapply(split(df, f = df$assessor), function(x) {
x <- x[x$probability == max(x$probability), , drop = FALSE]
x <- x[1, , drop = FALSE] # in case of ties
Expand All @@ -76,7 +72,6 @@ assign_cluster <- function(
x
}))

# Join map back onto df
df <- merge(df, map, by = "assessor")

if (!soft) {
Expand Down
48 changes: 48 additions & 0 deletions tests/testthat/test-assign_cluster.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
test_that("assign_cluster fails properly", {
mod <- compute_mallows(
setup_rank_data(potato_visual),
compute_options = set_compute_options(nmc = 10)
)

expect_error(assign_cluster(mod), "Please specify the burnin.")
mod$burnin <- 11
expect_error(assign_cluster(mod), "burnin < model_fit")

})

test_that("assign_cluster works", {
set.seed(123)
mod <- compute_mallows(
setup_rank_data(cluster_data),
model_options = set_model_options(n_clusters = 3),
compute_options = set_compute_options(nmc = 300, burnin = 50)
)

a1 <- assign_cluster(mod, soft = FALSE, expand = FALSE)
expect_equal(dim(a1), c(60, 3))
agg1 <- aggregate(assessor ~ map_cluster, a1, length)
expect_equal(agg1$assessor, c(17, 24, 19))

a2 <- assign_cluster(mod, soft = TRUE, expand = FALSE)
expect_equal(ncol(a2), 4)
agg2 <- aggregate(probability ~ assessor, a2, sum)
expect_equal(mean(agg2$probability), 1)

expect_equal(
dim(assign_cluster(mod, soft = FALSE, expand = TRUE)),
c(60, 3))

a3 <- assign_cluster(mod, soft = TRUE, expand = TRUE)
agg3 <- aggregate(probability ~ assessor, a3, sum)
expect_equal(mean(agg2$probability), 1)

mod <- compute_mallows(
setup_rank_data(cluster_data),
model_options = set_model_options(n_clusters = 3),
compute_options = set_compute_options(nmc = 2, burnin = 1)
)

expect_equal(dim(assign_cluster(mod)), c(60, 4))
expect_equal(dim(assign_cluster(mod, expand = TRUE)), c(180, 4))

})

0 comments on commit c1f94cf

Please sign in to comment.