R-caret: как использовать веса классов вместе с downSample для решения проблемы дисбаланса классов?
У меня очень несбалансированный набор данных. Чтобы справиться с этой проблемой, я пробовал по отдельности разные методы дисбаланса классов: downSample, веса классов, настройка порогов. Среди них настройка порога оказалась наименее эффективной. Используя только downSample или только веса классов, мне не удалось получить достаточно хороших результатов: либо слишком много FalsePositives, либо FalseNegatives. Поэтому я хотел бы объединить две техники. Вот что мне надоело:
# produce some re-producible imbalanced data
set.seed(12345)
y <- as.factor(sample(c("M", "F"),
prob = c(0.1, 0.9),
size = 10000,
replace = TRUE))
x <- rnorm(10000)
DATA <- data.frame(y = as.factor(y), x)
set.seed(12345)
folds <- createFolds(dataSet$y, k = 10,
list = TRUE, returnTrain = TRUE)
# class weights
k <- 0.5
classWeights <- ifelse(DATA$y == "M",
(1/table(DATA$y)[1]) * k,
(1/table(DATA$y)[2]) * (1-k))
поэтому, когда я не ставлю sampling
аргумент в controlTrain
:
# select algorithm
algorithm <- "bayesglm"
# train parameters
set.seed(12345)
traincontrol <- trainControl(method = "loocv", # resampling method
number = 10,
index = folds,
classProbs = TRUE,
summaryFunction = twoClassSummary,
savePredictions = TRUE,
# sampling = "down"
)
fitModel <- train(y ~ .,
data = DATA,
trControl = traincontrol,
method = algorithm,
metric = "ROC",
weights = classWeights,
)
он работает, и ошибки нет. но когда я добавляю аргумент выборки trainControl как
# train parameters
set.seed(12345)
traincontrol <- trainControl(method = "loocv", # resampling method
number = 10,
index = folds,
classProbs = TRUE,
summaryFunction = twoClassSummary,
savePredictions = TRUE,
sampling = "down"
)
fitModel <- train(y ~ .,
data = DATA,
trControl = traincontrol,
method = algorithm,
metric = "ROC",
weights = classWeights,
)
Я получаю эту ошибку, которая понятна:
Error in model.frame.default(formula = .outcome ~ ., data = list(x = c(-0.0640913631047556, :
variable lengths differ (found for '(weights)')
In addition: There were 11 warnings (use warnings() to see them)
Timing stopped at: 0.112 0.001 0.115
Есть ли способ сделать это в caret
? Спасибо заранее.