-
Notifications
You must be signed in to change notification settings - Fork 2
/
cv.MLKNN.R
36 lines (31 loc) · 1.22 KB
/
cv.MLKNN.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Cross-validate an ML-KNN multi-label classifier.
#
# The rows of `train` are split into `cv` folds with caret::createFolds().
# For each fold, MLKNN() is fit on the remaining rows, used to score the
# held-out rows, and five multi-label metrics are computed on the held-out
# predictions.
#
# NOTE: this function used to be named `MLKNN` as well. That clobbered the
# single-split MLKNN() it calls below, and made the inner call recurse into
# itself with an unused `test` argument. It is now named after its file
# (cv.MLKNN.R), so both functions can coexist.
#
# Arguments:
#   train          matrix (or data frame) of features, one row per instance.
#   train.label    matrix (or data frame) of labels, with the same rows as
#                  `train`. Presumably one 0/1 column per label; confirm
#                  against MLKNN().
#   cv             number of cross-validation folds (default 3).
#   k              number of nearest neighbours passed to MLKNN() (default 5).
#   smoothing      Laplace smoothing constant passed to MLKNN() (default 1).
#   ignore.nearest passed through to MLKNN(): whether to ignore the nearest
#                  instance (default TRUE).
#
# Returns a one-column matrix of metric values, one row per fold/metric pair.
# Row names have the form "<fold>.<metric>", e.g. "Fold1.hamming_loss".
# Within each fold the metric order is: hamming_loss, one_error, coverage,
# average_precision, ranking_loss.
cv.MLKNN <- function(train = NULL, train.label = NULL, cv = 3, k = 5,
                     smoothing = 1, ignore.nearest = TRUE) {
  if (is.null(train) || is.null(train.label)) {
    stop("train set and train labels cannot be NULL", call. = FALSE)
  }
  n <- nrow(train)
  if (!identical(as.integer(NROW(train.label)), as.integer(n))) {
    stop("train and train.label must have the same number of rows",
         call. = FALSE)
  }

  folds <- caret::createFolds(seq_len(n), k = cv)

  final <- lapply(folds, function(test.idx) {
    # drop = FALSE keeps single-row folds and single-column inputs as
    # matrices instead of silently collapsing them to vectors.
    cv.train <- train[-test.idx, , drop = FALSE]
    label.cv.train <- train.label[-test.idx, , drop = FALSE]
    cv.test <- train[test.idx, , drop = FALSE]
    label.cv.test <- train.label[test.idx, , drop = FALSE]

    model <- MLKNN(train = cv.train, train.label = label.cv.train,
                   test = cv.test, k = k, smoothing = smoothing,
                   ignore.nearest = ignore.nearest)

    # Hamming loss is computed on hard 0/1 predictions (scores >= 0.5 map
    # to 1). The ranking-based metrics use the raw scores. ifelse() keeps
    # the dim attributes of `model`.
    model.cutoff <- ifelse(model >= 0.5, 1, 0)

    list(
      hamming_loss = HammingLoss(label.cv.test, model.cutoff),
      one_error = OneError(label.cv.test, model),
      coverage = Coverage(label.cv.test, model),
      average_precision = AveragePrecision(label.cv.test, model),
      ranking_loss = RankingLoss(label.cv.test, model)
    )
  })

  as.matrix(unlist(final))
}