Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed improvements #206

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
14 changes: 6 additions & 8 deletions R/IRWLS.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ solveIRWLS.weights <-function(S,B,nUMI, OLS=FALSE, constrain = TRUE, verbose = F
#solution <- runif(length(solution))*2 / length(solution) # random initialization
names(solution) <- colnames(S)

S_mat <<- matrix(0,nrow = dim(S)[1],ncol = dim(S)[2]*(dim(S)[2] + 1)/2)
counter = 1
for(i in 1:dim(S)[2])
for(j in i:dim(S)[2]) {
S_mat[,counter] <<- S[,i] * S[,j] # depends on n^2
counter <- counter + 1
}
numCols <- ncol(S)
Index <- which(upper.tri(matrix(0, ncol = numCols, nrow = numCols), diag = TRUE), arr.ind = TRUE)
Index <- Index[order(Index[, 1], Index[, 2]), ,drop=F]
S_mat <<- S[, Index[, 1]] * S[, Index[, 2]]


iterations<-0 #now use dampened WLS, iterate weights until convergence
changes<-c()
Expand Down Expand Up @@ -83,7 +81,7 @@ solveWLS<-function(S,B,initialSol, nUMI, bulk_mode = F, constrain = F){
threshold = max(1e-4, nUMI * 1e-7)
prediction[prediction < threshold] <- threshold
gene_list = rownames(S)
derivatives <- get_der_fast(S, B, gene_list, prediction, bulk_mode = bulk_mode)
derivatives <- get_der_fast(S, B, gene_list, prediction[,1], bulk_mode = bulk_mode)
d_vec <- -derivatives$grad
D_mat <- psd(derivatives$hess)
norm_factor <- norm(D_mat,"2")
Expand Down
127 changes: 127 additions & 0 deletions R/platform_effect_normalization.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,130 @@ choose_sigma_c <- function(RCTD) {
RCTD@internal_vars$X_vals <- X_vals
return(RCTD)
}

#' Estimates sigma_c by maximum likelihood (multi-core)
#'
#' @param RCTD an \code{\linkS4class{RCTD}} object after running the \code{\link{fitBulk}} function.
#' @return Returns an \code{\linkS4class{RCTD}} with the estimated \code{sigma_c}.
#' @export
choose_sigma_mc<-function(RCTD)
{
message('Step 2/4: Choose Sigma')
puck <- RCTD@spatialRNA
MIN_UMI <- RCTD@config$UMI_min_sigma
sigma <- 100

Q1 <- readRDS(system.file("extdata", "Qmat/Q_mat_1.rds", package = "spacexrHD"))
Q2 <- readRDS(system.file("extdata", "Qmat/Q_mat_2.rds", package = "spacexrHD"))
Q3 <- readRDS(system.file("extdata", "Qmat/Q_mat_3.rds", package = "spacexrHD"))
Q4 <- readRDS(system.file("extdata", "Qmat/Q_mat_4.rds", package = "spacexrHD"))
Q5 <- readRDS(system.file("extdata", "Qmat/Q_mat_5.rds", package = "spacexrHD"))

Q_mat_all <- c(Q1, Q2, Q3, Q4, Q5)
sigma_vals <- names(Q_mat_all)

X_vals <- readRDS(system.file("extdata", "Qmat/X_vals.rds", package = "spacexrHD"))

#get initial classification
N_fit = min(RCTD@config$N_fit,sum(puck@nUMI > MIN_UMI))
if(N_fit == 0) {
stop(paste('choose_sigma_c determined a N_fit of 0! This is probably due to unusually low UMI counts per bead in your dataset. Try decreasing the parameter UMI_min_sigma. It currently is',MIN_UMI,'but none of the beads had counts larger than that.'))
}

fit_ind = sample(names(puck@nUMI[puck@nUMI > MIN_UMI]), N_fit)
beads = t(puck@counts[RCTD@internal_vars$gene_list_reg,fit_ind])

#message(paste('chooseSigma: using initial Q_mat with sigma = ',sigma/100))
#print(paste0("N_epoch: ",RCTD@config$N_epoch))

nUMI <- puck@nUMI[fit_ind]
cell_type_means <- RCTD@cell_type_info$renorm[[1]]
gene_list <- RCTD@internal_vars$gene_list_reg
constrain <- FALSE
max_cores <- RCTD@config$max_cores

if(max_cores > 1)
{
message(paste0("Multicore enabled using ", max_cores," cores"))
registerDoParallel(cores=max_cores)
}

NN<-nrow(beads)
pb <- txtProgressBar(min = 0, max = RCTD@config$N_epoch, style = 3)

for(iter in 1:RCTD@config$N_epoch)
{
set_likelihood_vars(Q_mat_all[[as.character(sigma)]], X_vals)

if(max_cores>1)
{

results<- foreach(i = 1:NN) %dopar% {

#set_likelihood_vars(Q_mat_all[[as.character(sigma)]], X_vals)
weights <- solveIRWLS.weights(data.matrix(RCTD@cell_type_info$renorm[[1]][RCTD@internal_vars$gene_list_reg,]*nUMI[i]),
beads[i,],
nUMI[i],
OLS = FALSE,
constrain = FALSE,
verbose = FALSE,
n.iter = 50,
MIN_CHANGE = 0.001,
bulk_mode = FALSE)

return(weights)
}


}else{

results<-vector("list",length=nrow(beads))

for(i in 1:nrow(beads))
{
set_likelihood_vars(Q_mat_all[[as.character(sigma)]], X_vals)
weights <- solveIRWLS.weights(data.matrix(RCTD@cell_type_info$renorm[[1]][RCTD@internal_vars$gene_list_reg,]*nUMI[i]),
beads[i,],
nUMI[i],
OLS = FALSE,
constrain = FALSE,
verbose = FALSE,
n.iter = 50,
MIN_CHANGE = 0.001,
bulk_mode = FALSE)
results[[i]]<-weights

}


}

weights<- do.call(rbind,lapply(results,function(X){return(X$weights)}))
weights<-as(weights,"dgCMatrix")
rownames(weights) <- fit_ind
colnames(weights) <- RCTD@cell_type_info$renorm[[2]]
prediction <- sweep(as.matrix(RCTD@cell_type_info$renorm[[1]][RCTD@internal_vars$gene_list_reg,]) %*% t(as.matrix(weights)), 2, puck@nUMI[fit_ind], '*')
#message(paste('Likelihood value:',calc_log_l_vec(as.vector(prediction), as.vector(t(beads)))))
sigma_prev <- sigma
sigma <- chooseSigma(prediction, t(beads), Q_mat_all, X_vals, sigma)

if(sigma == sigma_prev)
{
message(paste0(RCTD@config$N_epoch,"/",RCTD@config$N_epoch))
break
}

setTxtProgressBar(pb, iter)
}

setTxtProgressBar(pb, iter)

close(pb)
RCTD@internal_vars$sigma <- sigma/100
RCTD@internal_vars$Q_mat <- Q_mat_all[[as.character(sigma)]]
RCTD@internal_vars$X_vals <- X_vals

return(RCTD)


}
65 changes: 35 additions & 30 deletions R/postProcessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,44 @@

# Collects RCTD results
gather_results <- function(RCTD, results) {
cell_type_names = RCTD@cell_type_info$renorm[[2]]

message('Step 4/4: Gather Results')
pb <- txtProgressBar(max = 3, style = 3)

cell_type_names <- RCTD@cell_type_info$renorm[[2]]
barcodes <- colnames(RCTD@spatialRNA@counts)
N <- length(results)
weights = Matrix(0, nrow = N, ncol = length(cell_type_names))
weights_doublet = Matrix(0, nrow = N, ncol = 2)
rownames(weights) = barcodes; rownames(weights_doublet) = barcodes
colnames(weights) = cell_type_names; colnames(weights_doublet) = c('first_type', 'second_type')
empty_cell_types = factor(character(N),levels = cell_type_names)

empty_cell_types <- factor(character(N),levels = cell_type_names)
spot_levels <- c("reject", "singlet", "doublet_certain", "doublet_uncertain")
results_df <- data.frame(spot_class = factor(character(N),levels=spot_levels),
first_type = empty_cell_types, second_type = empty_cell_types,
first_class = logical(N), second_class = logical(N),
min_score = numeric(N), singlet_score = numeric(N),
conv_all = logical(N), conv_doublet = logical(N))
score_mat <- list()
singlet_scores <- list()
for(i in 1:N) {
if(i %% 1000 == 0)
print(paste("gather_results: finished",i))
weights_doublet[i,] = results[[i]]$doublet_weights
weights[i,] = results[[i]]$all_weights
results_df[i, "spot_class"] = results[[i]]$spot_class
results_df[i, "first_type"] = results[[i]]$first_type
results_df[i, "second_type"] = results[[i]]$second_type
results_df[i, "first_class"] = results[[i]]$first_class
results_df[i, "second_class"] = results[[i]]$second_class
results_df[i, "min_score"] = results[[i]]$min_score
results_df[i, "singlet_score"] = results[[i]]$singlet_score
results_df[i, "conv_all"] = results[[i]]$conv_all
results_df[i, "conv_doublet"] = results[[i]]$conv_doublet
score_mat[[i]] <- results[[i]]$score_mat
singlet_scores[[i]] <- results[[i]]$singlet_scores
}

setTxtProgressBar(pb, 1)

results_df <- data.frame(spot_class = factor(sapply(results,function(X){return(X$spot_class)}),levels=spot_levels),
first_type = sapply(results,function(X){return(X$first_type)}),
scond_type = sapply(results,function(X){return(X$second_type)}),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a typo of 'second_type' here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching that!

first_class = sapply(results,function(X){return(X$first_class)}),
second_class = sapply(results,function(X){return(X$second_class)}),
min_score = sapply(results,function(X){return(X$min_score)}),
singlet_score = sapply(results,function(X){return(X$singlet_score)}),
conv_all = sapply(results,function(X){return(X$conv_all)}),
conv_doublet = sapply(results,function(X){return(X$conv_doublet)}))

setTxtProgressBar(pb, 2)

weights_doublet <- do.call(rbind,lapply(results,function(X){return(X$doublet_weights)}))
weights <- do.call(rbind,lapply(results,function(X){return(X$all_weights)}))

rownames(weights) <- barcodes
rownames(weights_doublet) <- barcodes
colnames(weights) <- cell_type_names
colnames(weights_doublet) <- c('first_type', 'second_type')

score_mat <- lapply(results,function(X){return(X$score_mat)})
singlet_scores <- lapply(results,function(X){return(X$singlet_scores)})

setTxtProgressBar(pb, 3)

rownames(results_df) = barcodes
RCTD@results <- list(results_df = results_df, weights = weights, weights_doublet = weights_doublet,
score_mat = score_mat, singlet_scores = singlet_scores)
Expand Down