Scikit Learn Предсказание одного наблюдения
Вероятно, это действительно глупый вопрос, но почему следующее дает разные результаты?
X == array([ 7.84682988e-01, 3.80109225e-17, 8.06386582e-01,
1.00000000e+00, 5.71428571e-01, 4.44189342e+00])
model.predict_proba(X)[1] # gives array([ 0.35483244, 0.64516756])
model.predict_proba(X[1]) # gives an error
model.predict_proba(list(X[1])) # gives array([[ 0.65059327, 0.34940673]])
Model
это LGBMClassifier
из библиотеки lightgbm.
1 ответ
Давайте разберем его на простые шаги для анализа:
1) model.predict_proba(X)[1]
Это эквивалентно
probas = model.predict_proba(X)
probas[1]
Таким образом, это сначала выводит вероятности всех классов для всех выборок. Допустим, ваш X содержит 5 строк и 4 объекта с двумя разными классами.
Таким образом, проба будет что-то вроде этого:
Prob of class 0, prob of class 1
For sample1 [[0.1, 0.9],
For sample2 [0.8, 0.2],
For sample3 [0.85, 0.15],
For sample4 [0.4, 0.6],
For sample5 [0.01, 0.99]]
probas[1]
просто выведет вероятности для второго столбца вашего probas
вывод, т.е. вероятность 1 класса.
Output [0.9, 0.2, 0.15, 0.6, 0.99]
Две другие строки кода зависят от реализации и версии того, как обрабатывать одномерный массив. Например, Scikit v18 показывает только предупреждение и рассматривает его как одну строку. Но v19 (основная ветка) выдает ошибку.
РЕДАКТИРОВАТЬ: Обновлено для LGBMClassifier
2) model.predict_proba(X[1])
Это эквивалентно:
X_new = X[1]
model.predict_proba(X_new)
Здесь вы выбираете только второй ряд, который приводит к форме [n_features, ]
, Но LGBMClassifier требует, чтобы двумерные данные имели форму [n_samples, n_features]
, Это может быть возможным источником ошибки, как указано выше. Вы можете изменить форму данного массива, чтобы иметь 1 вместо n_samples:
model.predict_proba(X[1].reshape(1, -1))
# Будет работать правильно
3) model.predict_proba(list(X[1]))
Это можно разбить на:
X_new = list(X[1])
model.predict_proba(X_new)
Это также в основном то же, что 2-й, просто X_new
теперь список вместо массива numpy и автоматически обрабатывается как одна строка (так же, как X[1].reshape(1, -1)
во 2-м случае) вместо выкидывания ошибки.
Таким образом, учитывая приведенный выше пример, результат будет только
For sample2 [0.8, 0.2],