-
Notifications
You must be signed in to change notification settings - Fork 7
/
model-class-imbalance.Rmd
563 lines (421 loc) · 22.8 KB
/
model-class-imbalance.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
---
layout: page
title: xwMOOC 모형
subtitle: 클래스 불균형(Class imbalance)
output:
html_document:
toc: yes
toc_float: true
highlight: tango
code_folding: show
number_section: true
self_contained: true
editor_options:
chunk_output_type: console
---
``` {r, include=FALSE}
source("tools/chunk-options.R")
knitr::opts_chunk$set(echo = TRUE, warning=FALSE, message=FALSE)
library(igraph)
library(tidyverse)
library(readxl)
library(ggpubr)
library(extrafont)
loadfonts()
```
# 기계학습 클래스 불균형 [^ml-class-imbalance-svds] [^freesearch-class-imbalance] [^analytic-vidhya-class-imbalance] {#ml-class-imbalance-problem}
[^freesearch-class-imbalance]: [예측 모형에서의 클래스 불균형(class imbalance) 문제](http://freesearch.pe.kr/archives/4506)
[^analytic-vidhya-class-imbalance]: [Practical Guide to deal with Imbalanced Classification Problems in R](https://www.analyticsvidhya.com/blog/2016/03/practical-guide-deal-imbalanced-classification-problems/)
[^ml-class-imbalance-svds]: [Learning from Imbalanced Classes, AUGUST 25TH, 2016](https://svds.com/learning-imbalanced-classes/)
기계학습에서 관심있는 예측변수의 클래스가 매우 적은 경우가 흔하다.
- 매년 약 2% 정도 신용카드가 도용되고 있다. 일반적으로 사기탐지(fraud detection) 분야는 2% 보다도 훨씬 적다.
- 의학에서 질병검사도 건강한 일반인을 대상으로 하기 때문에 희귀하다. 예를 들어 미국에서 후천성면역결핍증(HIV) 발병율은 0.4% 정도에 불과하다.
- 하드 디스크 고장율도 매년 약 1% 정도된다.
- 온라인 광고의 전환율은 대략 $10^{-3} ~ 10^{-6}$ 정도 된다.
- 자동화된 생산공장의 불량율도 대략 0.1% 정도다.
이와 같이 관심을 갖고 예측하고자 하는 것이 매우 드문 경우가 빈번하기 때문에 지난 수십년동안 수많은 기계학습 분야에서
수많은 석박사를 배출했고, 따라서 예측에 빈번히 발생하는 **클래스 불균형(class imbalance)** 문제를 처리할 수 있는 방법이 많이 소개되었다.
대응방법은 다음과 같다.
- 아무 것도 하지 않는다.
- 훈련 데이터를 보정한다.
- 과대표집(Over-Sampling)
- 과소표집(Under-Sampling)
- 소수 표본 데이터를 조합해서 생성시킴
- 소수 표본 데이터를 버리고 비정상행위 탐지(anomaly detection framework)로 문제를 바꿔 접근한다.
- 알고리즘 수준에서 미세 조정을 취한다.
- 클래스 가중치(오분류 비용)를 조정
- 컷오프 기준을 조정
- 소수 표본 데이터에 좀더 예민하게 반응하도록 알고리즘을 조정
## 환경설정 {#class-imbalance-import-setup}
[Practical Guide to deal with Imbalanced Classification Problems in R](https://www.analyticsvidhya.com/blog/2016/03/practical-guide-deal-imbalanced-classification-problems/)에서 소개된 방법을 따라
환경을 설정히고 `hacide` 데이터를 준비한다. `ROSE` 팩키지에 포함되어 있는 가공된 데이터로 클래스 불균형 문제를 시작하는데 적절한 데이터로 사료된다.
``` {r class-imbalance-setup}
# 0. 환경설정 --------------
library(ROSE)
library(tidyverse)
library(rpart)
library(caret)
library(ggpubr)
library(extrafont)
loadfonts()
library(plotROC)
# 1. 데이터 가져오기 --------
data(hacide)
hacide.train <- hacide.train %>%
mutate(cls = factor(cls, labels= c("no", "yes")))
hacide.test <- hacide.test %>%
mutate(cls = factor(cls, labels= c("no", "yes")))
```
## 데이터 살펴보기 {#class-imbalance-EDA}
`hacide` 데이터를 시각화를 통해 이해한다.
데이터는 `cls`가 0과 1일 경우 다른 방식으로 생성되었는데 자세한 내용은
[hacide - Half Circle Filled Data](https://www.rdocumentation.org/packages/ROSE/versions/0.0-3/topics/hacide) 웹사이트를 참고한다.
``` {r class-imbalance-EDA}
# 2. 탐색적 데이터 분석 --------
## 2.1. 데이터 시각화
hacide.train %>%
ggplot(aes(x=x1, y=x2, color=cls)) +
geom_point() +
theme_pubr(base_family = "NanumGothic") +
theme(legend.position = "top") +
labs(color = "종속변수(cls)") +
scale_color_manual(values = c("lightblue", "red"))
## 2.2. 데이터 장표
hacide.train %>% count(cls) %>%
mutate(비율 = scales::percent(n/ sum(n)))
```
# 클래스 불균형 극복전략 {#class-imbalance-countermeasure}
클래스 불균형 문제에 대한 극복방법에 대해서 크게 4가지 방법이 제시되고 있다. 물론 `ROSE` 방법론을 옹호하는 입장에서 그렇다.
클래스 불균형 문제를 인식하고 이에 대해 체계적으로 접근할 수 있는 가장 손쉬운 시작점으로 이해하면 좋다.
- 과대표집(Over-Sampling)
- 과소표집(Under-Sampling)
- 양쪽 표집(Both-Sampling)
- 로즈 표집(ROSE Sampling)
클래스 불균형 극복전략을 예측모형과 연관하여 부츠트랩 표본을 생성하고 나서 각 부츠트랩 표본에서 과소표집(Down-sampling)을 통해 클래스 불균형을 해소하고
단순한 예측모형을 적합시킨 후에 다수결 원칙에 의거하여 최종 예측모형을 완성하는 과정을 거친다.
<img src="fig/class-imbalance-framework.png" alt="클래스 불균형 대응 예측모형" width="77%" />
``` {r class-imbalance-countermeasure}
# 3. 클래스 불균형(class imbalance) 극복전략 -----
balanced_over_sampling_df <- ovun.sample(cls ~ ., data = hacide.train, method = "over", N = 1960)$data
balanced_under_sampling_df <- ovun.sample(cls ~ ., data = hacide.train, method = "under", N = 40, seed = 1)$data
balanced_both_sampling_df <- ovun.sample(cls ~ ., data = hacide.train, method = "both", p=0.5, N=1000, seed = 1)$data
rose_df <- ROSE(cls ~ ., data = hacide.train)$data
```
# 클래스 불균형 극복전략 성능비교 {#class-imbalance-countermeasure-performance}
과대표집(Over-Sampling), 과소표집(Under-Sampling), 양쪽 표집(Both-Sampling), 로즈 표집(ROSE Sampling) 그리고 클래스 불균형 극복전략이 없는
경우 포함하여 총 5가지 전략에 대해서 성능을 비교해 보자.
## AUC 성능비교 {#class-imbalance-countermeasure-performance-auc}
AUC 곡선 비교하면 `hacide` 데이터에는 로즈 방법론이 가장 좋은 성능을 나타내고 있다.
``` {r class-imbalance-countermeasure-auc}
# 4. 재귀분할(rpart) 나무모형 적합 --------
raw_rpart <- rpart(cls ~ ., data = hacide.train)
over_rpart <- rpart(cls ~ ., data = balanced_over_sampling_df)
under_rpart <- rpart(cls ~ ., data = balanced_under_sampling_df)
both_rpart <- rpart(cls ~ ., data = balanced_both_sampling_df)
rose_rpart <- rpart(cls ~ ., data = rose_df)
# 5. 클래스 불균형 재귀분할(rpart) 나무모형 평가 --------
## 5.1. 검증데이터 적용 예측
pred_raw_rpart <- predict(raw_rpart , newdata = hacide.test)
pred_over_rpart <- predict(over_rpart , newdata = hacide.test)
pred_under_rpart <- predict(under_rpart, newdata = hacide.test)
pred_both_rpart <- predict(both_rpart , newdata = hacide.test)
pred_rose_rpart <- predict(rose_rpart , newdata = hacide.test)
## 5.2. AUC
roc.curve(hacide.test$cls, pred_raw_rpart[,2], plot=FALSE)
roc.curve(hacide.test$cls, pred_over_rpart[,2], plot=FALSE)
roc.curve(hacide.test$cls, pred_under_rpart[,2], plot=FALSE)
roc.curve(hacide.test$cls, pred_both_rpart[,2], plot=FALSE)
roc.curve(hacide.test$cls, pred_rose_rpart[,2], plot=FALSE)
```
## AUC 성능비교 `plotROC` 시각화 {#class-imbalance-countermeasure-performance-auc-plot}
`ggplot`을 통해 5가지 예측모형의 성능을 살펴본다.
``` {r class-imbalance-countermeasure-auc-plot}
## 5.3. ggplot ROC 데이터
raw_roc_df <- tibble(cls = hacide.test[,1], pred=pred_raw_rpart[,2], sampling="원데이터")
over_roc_df <- tibble(cls = hacide.test[,1], pred=pred_over_rpart[,2], sampling="과대 표집")
under_roc_df <- tibble(cls = hacide.test[,1], pred=pred_under_rpart[,2], sampling="과소 표집")
both_roc_df <- tibble(cls = hacide.test[,1], pred=pred_both_rpart[,2], sampling="양쪽 표집")
rose_roc_df <- tibble(cls = hacide.test[,1], pred=pred_rose_rpart[,2], sampling="ROSE 표집")
hacide_roc_df <- bind_rows(raw_roc_df, over_roc_df) %>%
bind_rows(under_roc_df) %>%
bind_rows(both_roc_df) %>%
bind_rows(rose_roc_df)
## 5.4. ggplot ROC 시각화
ggplot(hacide_roc_df, aes(d = cls, m = pred, color=sampling)) +
geom_roc(labels =FALSE) +
style_roc() +
theme_pubr(base_family="NanumGothic") +
theme(legend.position = "top") +
labs(color="클래스 불균형 해소방법: ")
```
# `caret` 구현 [^ml-class-imbalance-shiring] {#ml-class-imbalance-problem-caret}
[^ml-class-imbalance-shiring]: [Shiring, Dealing with unbalanced data in machine learning](https://shiring.github.io/machine_learning/2017/04/02/unbalanced)
## `caret` 클래스 불균형 대응 {#ml-class-imbalance-caret-counter-measure}
`caret` 팩키지 클래스 불균형 대응 구현된 기능을 활용하여 예측모형 성능을 비교평가한다.
예측모형은 동일하게 `randomForest`를 사용하고 앞서 언급된 클래스 불균형 대응 알고리즘을 반영한다.
``` {r class-imbalance-countermeasure-caret-fit, warning=FALSE, message=FALSE}
# 2. 모형 적합 --------
## 2.1. 원데이터 ------
hacide_ctrl <- trainControl(method = "repeatedcv",
number = 5,
repeats = 5,
verboseIter = FALSE)
model_rf <- train(cls ~ .,
data = hacide.train,
method = "rf",
preProcess = c("scale", "center"),
trControl = hacide_ctrl)
## 2.2. Under-sampling ------
hacide_under_ctrl <- trainControl(method = "repeatedcv",
number = 5,
repeats = 5,
verboseIter = FALSE,
sampling = "down")
model_under_rf <- train(cls ~ .,
data = hacide.train,
method = "rf",
preProcess = c("scale", "center"),
trControl = hacide_under_ctrl)
## 2.3. Up-sampling ------
hacide_up_ctrl <- trainControl(method = "repeatedcv",
number = 5,
repeats = 5,
verboseIter = FALSE,
sampling = "up")
model_up_rf <- train(cls ~ .,
data = hacide.train,
method = "rf",
preProcess = c("scale", "center"),
trControl = hacide_up_ctrl)
## 2.4. ROSE ------
hacide_rose_ctrl <- trainControl(method = "repeatedcv",
number = 5,
repeats = 5,
verboseIter = FALSE,
sampling = "rose")
model_rose_rf <- train(cls ~ .,
data = hacide.train,
method = "rf",
preProcess = c("scale", "center"),
trControl = hacide_rose_ctrl)
## 2.5. SMOTE ------
hacide_smote_ctrl <- trainControl(method = "repeatedcv",
number = 5,
repeats = 5,
verboseIter = FALSE,
sampling = "smote")
model_smote_rf <- train(cls ~ .,
data = hacide.train,
method = "rf",
preProcess = c("scale", "center"),
trControl = hacide_smote_ctrl)
```
## `caret` 클래스 불균형 성능평가 {#ml-class-imbalance-caret-counter-measure-performance}
`AUC` 값을 성능평가 기준으로 삼아 예측모형 성능을 비교한다.
``` {r class-imbalance-countermeasure-caret-performance}
# 3. 성능 비교 ------
## 3.1. 일반적인 성능 평가 지표
hacide_models <- list(original = model_rf,
under = model_under_rf,
over = model_up_rf,
smote = model_smote_rf,
rose = model_rose_rf)
hacide_resampling <- resamples(hacide_models)
bwplot(hacide_resampling)
## 3.2. AUC 지표
pred_rf <- predict(model_rf, newdata = hacide.test)
pred_under_rf <- predict(model_under_rf, newdata = hacide.test)
pred_up_rf <- predict(model_up_rf, newdata = hacide.test)
pred_smote_rf <- predict(model_smote_rf, newdata = hacide.test)
pred_rose_rf <- predict(model_rose_rf, newdata = hacide.test)
roc.curve(hacide.test$cls, pred_rf, plot=FALSE)
roc.curve(hacide.test$cls, pred_under_rf, plot=FALSE)
roc.curve(hacide.test$cls, pred_up_rf, plot=FALSE)
roc.curve(hacide.test$cls, pred_smote_rf, plot=FALSE)
roc.curve(hacide.test$cls, pred_rose_rf, plot=FALSE)
```
## `caret` 클래스 불균형 성능평가 시각화 {#ml-class-imbalance-caret-counter-measure-viz}
`ggplot`으로 ROC 곡선을 도식화하여 성능을 비교한다.
``` {r class-imbalance-countermeasure-caret-performance-viz}
## 3.3. ggplot ROC 데이터
pred_rf <- predict(model_rf, newdata = hacide.test, type="prob")
pred_under_rf <- predict(model_under_rf, newdata = hacide.test, type="prob")
pred_up_rf <- predict(model_up_rf, newdata = hacide.test, type="prob")
pred_smote_rf <- predict(model_smote_rf, newdata = hacide.test, type="prob")
pred_rose_rf <- predict(model_rose_rf, newdata = hacide.test, type="prob")
raw_roc_df <- tibble(cls = hacide.test$cls, pred= pred_rf[,2], sampling="원데이터")
under_roc_df <- tibble(cls = hacide.test$cls, pred=pred_under_rf[,2], sampling="과소 표집")
up_roc_df <- tibble(cls = hacide.test$cls, pred=pred_up_rf[,2], sampling="과대 표집")
smote_roc_df <- tibble(cls = hacide.test$cls, pred=pred_smote_rf[,2], sampling="SMOTE 표집")
rose_roc_df <- tibble(cls = hacide.test$cls, pred=pred_rose_rf[,2], sampling="ROSE 표집")
hacide_caret_roc_df <- bind_rows(raw_roc_df, under_roc_df) %>%
bind_rows(up_roc_df) %>%
bind_rows(smote_roc_df) %>%
bind_rows(rose_roc_df)
## 4.4. ggplot ROC 시각화
ggplot(hacide_caret_roc_df, aes(d = cls, m = pred, color=sampling)) +
geom_roc(labels =FALSE) +
style_roc() +
theme_pubr(base_family="NanumGothic") +
theme(legend.position = "top") +
labs(color="클래스 불균형 해소방법: ")
```
# 향상도(Lift) [^caret-lift] {#ml-class-imbalance-problem-lift}
[^caret-lift]: [caret 17 Measuring Performance, lift](https://topepo.github.io/caret/measuring-performance.html#lift)
향상도(Lift)를 통해 희소한, 관심있는 클래스를 탐지하는데 예측모형에서 나온 표본 중 얼마를 탐지해야 유의미한지 확인한다.
``` {r ml-class-imbalance-lift}
# 5. lift ------------
hacide.test$pred <- predict(model_rose_rf, newdata = hacide.test, type="prob")[, "yes"]
hacide_lift <- caret::lift(cls ~ pred, data = hacide.test, cuts = 100, class="yes")
ggplot(hacide_lift, values=80) +
geom_line(color="blue") +
theme_pubr(base_family = "NanumGothic") +
labs(title = "hacide 데이터에 대한 향상도(lift)",
subtitle = "희귀한 2% 사례 80%를 탐지하기 하는데 약 5% 정도 노력만 필요",
x="탐색해야 될 표본비율(% Samples Tested)", y="탐색된 표본비율(% Samples Found)") +
scale_x_continuous(breaks = seq(0,100,10))
```
# 캐글 사례 [^kaggle-class-imbalance] {#kaggle-casestudy}
[^kaggle-class-imbalance]: [Andrew B. Collier(2018-04-21), "Classification: Get the Balance Right"](https://datawookie.netlify.com/blog/2018/04/classification-get-the-balance-right/)
[Medical Appointment No Shows - Why do 30% of patients miss their scheduled appointments?](https://www.kaggle.com/joniarroba/noshowappointments) 병원약속을 했으나
나타나지 않는 노쇼(No Show) 데이터를 바탕으로 클래스 불균형이 극심한 경우 이를 예측할 수 있는 모형을 개발해보자.
## 데이터 다운로드 및 정제작업 {#kaggle-noshow-data}
먼저 데이터를 다운로드 받아... 변수명에 대한 전처리 작업을 수행한다. `janitor` 팩키지 `clean_names()` 함수를 사용해도 유사한 효과를 기대할 수 있다.
문자형 변수를 요인형 변수로 변환시키고, 날짜 변수에서 유용한 몇가지 피쳐를 추출하고 예측모형에 불필요한 변수는 제거하여
예측모형 데이터프레임을 생성시킨다.
```{r kaggle-noshow-data}
library(lubridate)
## 데이터 가져오기
ns_dat <- read_csv(file = "data/KaggleV2-May-2016.csv")
## 변수명 변환
ns_dat <- ns_dat %>%
setNames(names(.) %>% str_to_lower() %>% str_replace("[.-]", "_")) %>%
dplyr::rename(
hypertension = hipertension,
handicap = handcap
)
## 데이터 정제
ns_df <- ns_dat %>%
mutate_at(vars(gender, neighbourhood:no_show), factor) %>% # 자료형 변환: 문자 --> 요인
select(-patientid, -appointmentid, -neighbourhood) %>% # ID 변수명 제거 및 많은 수준을 갖는 변수 제거
mutate(scheduleddow = wday(scheduledday) %>% factor(),
hour = hour(scheduledday) + (minute(scheduledday) + second(scheduledday) / 60) / 60,
appointmentdow = wday(appointmentday) %>% factor(),
advance = difftime(scheduledday, appointmentday, units = "hours") %>% as.numeric()) %>% # 날짜 데이터에서 피처 추출
select(-scheduledday, -appointmentday) %>%
mutate(no_show = relevel(no_show, "Yes")) %>%
select(no_show, everything())
ns_df %>%
sample_n(100) %>%
DT::datatable()
```
## 예측 모형 개발 {#kaggle-caret-no-show}
훈련/시험 데이터로 데이터를 분할하고, 훈련데이터를 다시 훈련/검증 데이터로 분할시켜 각 예측모형 아키텍처에서 최적의 모형이 개발되도록 한다.
그리고 나서 윈도우 환경에서 `doSNOW` 팩키지를 통해 병렬처리를 위한 클러스터를 생성시키고 나서,
GLM, RF, GBM 모형을 적합시켜 최적의 모형을 추출해 낸다.
```{r kaggle-caret-no-show}
library(caret)
ns_m_df <- ns_df %>%
sample_frac(0.1)
ns_index <- createDataPartition(ns_m_df$no_show, times =1, p=0.3, list=FALSE)
train_df <- ns_m_df[ns_index, ]
test_df <- ns_m_df[-ns_index, ]
## 2.2. 모형 개발/검증 데이터셋 준비 ------
cv_folds <- createMultiFolds(train_df$no_show, k = 10, times = 3)
cv_cntrl <- trainControl(method = "repeatedcv", number = 10,
repeats = 1,
index = cv_folds,
classProbs = TRUE,
summaryFunction = twoClassSummary)
library(doSNOW)
# 실행시간
start.time <- Sys.time()
cl <- makeCluster(8, type = "SOCK")
registerDoSNOW(cl)
ns_glm <- train(no_show ~ ., data = train_df,
method = "glm",
family = "binomial",
metric = "Sens",
trControl = cv_cntrl,
tuneLength = 3)
# ns_rf <- train(no_show ~ ., data = train_df,
# method = "rf",
# metric = "Sens",
# importance = TRUE,
# trControl = cv_cntrl,
# tuneLength = 7)
ns_gbm <- train(no_show ~ ., data = train_df,
method = "xgbTree",
metric = "Sens",
trControl = cv_cntrl,
tuneLength = 1)
stopCluster(cl)
total.time <- Sys.time() - start.time
total.time
```
클래스 불균형으로 예측모형의 성능의 민감도(Sensitivity)가 낮게 나와 이를 클래스를 균형있게 잡아 보정해보자.
```{r kaggle-noshow-class-imbalance}
ns_df %>%
count(no_show) %>%
mutate(pcnt = scales::percent(n / sum(n)))
confusionMatrix(predict(ns_gbm, test_df), test_df$no_show)
```
## 클래스 불균형 보정 [^class-imbalance-caret-book] {#kaggle-caret-no-show-class-balance}
[^class-imbalance-caret-book]: [Max Kuhn 2018-05-26, "The caret Package: Subsampling For Class Imbalances"](https://topepo.github.io/caret/subsampling-for-class-imbalances.html)
DMwR 팩키지 `smote`, ROSE 팩키지 `rose`를 `trainControl()`에 적용시켜,
기존 upsamping, downsampling과 함께 클래스 불균형에 따른 예측모형 성능저하를 보완할 수 있다. [^SMOTE-error]
[^SMOTE-error]: [stackoverflow - "SMOTE length of 'dimnames' [2] not equal to array extent"](https://stackoverflow.com/questions/38616260/smote-length-of-dimnames-2-not-equal-to-array-extent)
```{r kaggle-caret-no-show-smote}
table(train_df$no_show)
smote_train <- DMwR::SMOTE(no_show ~ ., data = as.data.frame(train_df), perc.over = 100, perc.under = 200)
table(smote_train$no_show)
# 실행시간
start.time <- Sys.time()
cl <- makeCluster(8, type = "SOCK")
registerDoSNOW(cl)
ns_balance_glm <- train(no_show ~ ., data = train_df,
method = "glm",
family = "binomial",
metric = "Sens",
trControl = cv_cntrl,
tuneLength = 3)
# ns_balance_rf <- train(no_show ~ ., data = train_df,
# method = "rf",
# metric = "Sens",
# importance = TRUE,
# trControl = cv_cntrl,
# tuneLength = 7)
ns_balance_gbm <- train(no_show ~ ., data = train_df,
method = "xgbTree",
metric = "Sens",
trControl = cv_cntrl,
tuneLength = 3)
stopCluster(cl)
total.time <- Sys.time() - start.time
total.time
```
`rose`, `smote` 클래스 불균형 보완전략을 적용시켜 예측모형의 성능을 파악할 수 있다.
관심있는 것이 전반적인 정확도 보다는 노쇼에 대한 예측이기 때문에 민감도(Sensitivity)를 높일 수 있는 측도에 방점을 두고
예측모형을 개발해 나간다.
```{r smote-implementation}
confusionMatrix(predict(ns_balance_gbm, test_df), test_df$no_show)
```
클래스 보완한 것만으로 전반적인 예측(정확도)은 조금 하락(했으나, 노쇼에 대한 예측은 일정부분 개선된 것을 확인할 수 있다.
```{r smote-class-imbalance-comparison}
orig_conf <- confusionMatrix(predict(ns_gbm, test_df), test_df$no_show)
balance_conf <- confusionMatrix(predict(ns_balance_gbm, test_df), test_df$no_show)
## 민감도 측도 변경
orig_conf_df <- orig_conf$byClass %>% as.data.frame %>%
rownames_to_column(var="측도") %>%
rename(원데이터 = ".") %>%
add_row(측도 = "Accuracy", 원데이터 = orig_conf$overall["Accuracy"])
balance_conf_df <- balance_conf$byClass %>% as.data.frame %>%
rownames_to_column(var="측도") %>%
rename(SMOTE = ".") %>%
add_row(측도 = "Accuracy", SMOTE = balance_conf$overall["Accuracy"])
comp_df <- inner_join(orig_conf_df, balance_conf_df)
comp_df %>%
filter(측도 %in% c("Sensitivity", "Specificity", "Accuracy")) %>%
DT::datatable() %>%
DT::formatRound(c(2:3), digits=3)
```