From ac8f1d8c805d2bbab293c9351f8f2b0eaecbfe68 Mon Sep 17 00:00:00 2001 From: damirpolat Date: Fri, 22 Dec 2023 23:36:39 -0700 Subject: [PATCH] fix: issue #40 --- R/LearnerClustAffinityPropagation.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/R/LearnerClustAffinityPropagation.R b/R/LearnerClustAffinityPropagation.R index e8d9c863..3f486af5 100644 --- a/R/LearnerClustAffinityPropagation.R +++ b/R/LearnerClustAffinityPropagation.R @@ -61,9 +61,10 @@ LearnerClustAP = R6Class("LearnerClustAP", private = list( .train = function(task) { pv = self$param_set$get_values(tags = "train") - m = invoke(apcluster::apcluster, x = task$data(), .args = pv) + d = task$data() + m = invoke(apcluster::apcluster, x = d, .args = pv) # add data points corresponding to examplars - attributes(m)$exemplar_data = task$data()[m@exemplars, ] + attributes(m)$exemplar_data = d[m@exemplars, ] if (self$save_assignments) { self$assignments = apcluster::labels(m, type = "enum") @@ -75,8 +76,9 @@ LearnerClustAP = R6Class("LearnerClustAP", sim_func = self$param_set$values$s exemplar_data = attributes(self$model)$exemplar_data - sim_mat = sim_func(rbind(exemplar_data, task$data()), - sel = (1:nrow(task$data())) + + d = task$data() + sim_mat = sim_func(rbind(exemplar_data, d), + sel = (1:nrow(d)) + nrow(exemplar_data))[1:nrow(exemplar_data), ] partition = unname(apply(sim_mat, 2, which.max)) PredictionClust$new(task = task, partition = partition)