Возможная ошибка в функции Caret Предсказание.gb()?

Мне кажется, что я обнаружил ошибку в выполнении функции предиката () для method=gbm в пакете Caret в R. Мне любопытно узнать, согласны ли другие, или у кого-то есть объяснение поведения этой функции.

1. Генерация данных

library(caret)

x1 <- rnorm(100)

x2 <- rnorm(100, 2)

y <- x1 + x2 + rnorm(100)

df <- data.frame(x1=x1, x2=x2,  y=y)

2. Прогнозирование с использованием method="lm"

Следующий код работает должным образом: используя method="lm", два предсказанных значения совпадают. В первом случае p1, "y" включены в новые данные, во втором случае p2 это не так.

tempd <- df[1:99, c("y", "x1", "x2") ]

newdata <- df[100, c("y", "x1", "x2")]

lm.fit <- train(y~x1 + x2, data=tempd, method="lm")

p1 <- predict(lm.fit$finalModel, newdata=newdata)

newdata <- df[100, c("x1", "x2")]

p2 <- predict(lm.fit$finalModel, newdata=newdata)

р1 должен равняться р2 и делает:

p1==p2

3. Прогнозирование с использованием method="gbm"

Этот код не работает должным образом: при использовании method="gbm" с одинаковой настройкой два прогнозируемых значения не совпадают.

tempd <- df[1:99, c("y","x1","x2")]

newdata <- df[100, c("y","x1","x2")]

gbm.fit <- train(y~x1+x2 , data=tempd, method="gbm", verbose=F)

p1 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                       
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

newdata <- df[100, c("x1","x2")]

p2 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

В этом случае p1 не равно p2:

p1==p2

4. Прогнозирование с использованием method="gbm" с другой настройкой

НО, как ни странно, с одним небольшим изменением - без явного именования переменных в операции подмножества - это работает:

tempd <- df[1:99, ]

newdata <- df[100, ]

gbm.fit <- train(y~x1+x2 , data=tempd, method="gbm", verbose=F)

p1 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                                         
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

newdata <- df[100, c("x1","x2")]

p2 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

p1==p2

Заранее спасибо за наши мысли.

Джефф

1 ответ

Как отметил @Pascal, вы пропускаете важный шаг. Вместо того, чтобы звонить predict() на значение finalModel, вы должны звонить predict на gmb.fit объект напрямую. Заметка

class(gbm.fit)
# [1] "train"         "train.formula"
class(gbm.fit$finalModel)
# [1] "gbm"

Поскольку эти объекты имеют разные классы, они запускают разные базовые функции прогнозирования. Важной частью является то, что predict.train перекраивает newdata в правильный формат для gbm предсказатель. Без изменения этих данных вы получите неверные результаты (предиктор ожидает, что столбцы будут в определенном порядке)

соблюдать

newdata1 <- df[100, c("y","x1","x2")]
newdata2 <- df[100, c("x1","x2")]
newdata3 <- df[100, ]

predict(gbm.fit, newdata1)
# [1] 1.427069
predict(gbm.fit, newdata2)
# [1] 1.427069
predict(gbm.fit, newdata3)
# [1] 1.427069

predict(gbm.fit$finalModel, newdata=newdata1,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 2.166468
predict(gbm.fit$finalModel, newdata=newdata2,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
predict(gbm.fit$finalModel, newdata=newdata3,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069

Так что если вы собираетесь использовать train() функции, чтобы соответствовать вашей модели, обязательно используйте правильную predict.train функция, чтобы правильно делать прогнозы из модели.

Другие вопросы по тегам