Возможная ошибка в функции Caret Предсказание.gb()?
Мне кажется, что я обнаружил ошибку в выполнении функции предиката () для method=gbm в пакете Caret в R. Мне любопытно узнать, согласны ли другие, или у кого-то есть объяснение поведения этой функции.
1. Генерация данных
library(caret)
x1 <- rnorm(100)
x2 <- rnorm(100, 2)
y <- x1 + x2 + rnorm(100)
df <- data.frame(x1=x1, x2=x2, y=y)
2. Прогнозирование с использованием method="lm"
Следующий код работает должным образом: используя method="lm", два предсказанных значения совпадают. В первом случае p1, "y" включены в новые данные, во втором случае p2 это не так.
tempd <- df[1:99, c("y", "x1", "x2") ]
newdata <- df[100, c("y", "x1", "x2")]
lm.fit <- train(y~x1 + x2, data=tempd, method="lm")
p1 <- predict(lm.fit$finalModel, newdata=newdata)
newdata <- df[100, c("x1", "x2")]
p2 <- predict(lm.fit$finalModel, newdata=newdata)
р1 должен равняться р2 и делает:
p1==p2
3. Прогнозирование с использованием method="gbm"
Этот код не работает должным образом: при использовании method="gbm" с одинаковой настройкой два прогнозируемых значения не совпадают.
tempd <- df[1:99, c("y","x1","x2")]
newdata <- df[100, c("y","x1","x2")]
gbm.fit <- train(y~x1+x2 , data=tempd, method="gbm", verbose=F)
p1 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
newdata <- df[100, c("x1","x2")]
p2 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
В этом случае p1 не равно p2:
p1==p2
4. Прогнозирование с использованием method="gbm" с другой настройкой
НО, как ни странно, с одним небольшим изменением - без явного именования переменных в операции подмножества - это работает:
tempd <- df[1:99, ]
newdata <- df[100, ]
gbm.fit <- train(y~x1+x2 , data=tempd, method="gbm", verbose=F)
p1 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
newdata <- df[100, c("x1","x2")]
p2 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
p1==p2
Заранее спасибо за наши мысли.
Джефф
1 ответ
Как отметил @Pascal, вы пропускаете важный шаг. Вместо того, чтобы звонить predict()
на значение finalModel, вы должны звонить predict
на gmb.fit
объект напрямую. Заметка
class(gbm.fit)
# [1] "train" "train.formula"
class(gbm.fit$finalModel)
# [1] "gbm"
Поскольку эти объекты имеют разные классы, они запускают разные базовые функции прогнозирования. Важной частью является то, что predict.train
перекраивает newdata
в правильный формат для gbm
предсказатель. Без изменения этих данных вы получите неверные результаты (предиктор ожидает, что столбцы будут в определенном порядке)
соблюдать
newdata1 <- df[100, c("y","x1","x2")]
newdata2 <- df[100, c("x1","x2")]
newdata3 <- df[100, ]
predict(gbm.fit, newdata1)
# [1] 1.427069
predict(gbm.fit, newdata2)
# [1] 1.427069
predict(gbm.fit, newdata3)
# [1] 1.427069
predict(gbm.fit$finalModel, newdata=newdata1,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 2.166468
predict(gbm.fit$finalModel, newdata=newdata2,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
predict(gbm.fit$finalModel, newdata=newdata3,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
Так что если вы собираетесь использовать train()
функции, чтобы соответствовать вашей модели, обязательно используйте правильную predict.train
функция, чтобы правильно делать прогнозы из модели.