sklearn RandomizedSearchCV извлекает путаницу матрицы для разных сгибов
Я пытаюсь вычислить агрегированную матрицу путаницы, чтобы оценить мою модель:
cv_results = cross_validate(estimator, dataset.data, dataset.target, scoring=scoring,
cv=Config.CROSS_VALIDATION_FOLDS, n_jobs=N_CPUS, return_train_score=False)
Но я не знаю, как извлечь одиночные матрицы из разных сгибов. В бомбардире я могу вычислить это:
scoring = {
'cm': make_scorer(confusion_matrix)
}
, но я не могу вернуть матрицу слияния, потому что она должна возвращать число вместо массива. Если я попробую это, я получу следующую ошибку:
ValueError: scoring must return a number, got [[...]] (<class 'numpy.ndarray'>) instead. (scorer=cm)
Интересно, возможно ли сохранить матрицы путаницы в глобальной переменной, но безуспешно при использовании
global cm_list
cm_list.append(confusion_matrix(y_true,y_pred))
в кастомном бомбардире.
Спасибо заранее за любые советы.
2 ответа
Проблема состояла в том, что я не мог получить доступ к оценщику после того, как RandomizedSearchCV был закончен, потому что я не знал, что RandomizedSearchCV реализует метод прогнозирования. Вот мое личное решение:
r_search = RandomizedSearchCV(estimator=estimator, param_distributions=param_distributions,
n_iter=n_iter, cv=cv, scoring=scorer, n_jobs=n_cpus,
refit=next(iter(scorer)))
r_search.fit(X, y_true)
y_pred = r_search.predict(X)
cm = confusion_matrix(y_true, y_pred)
Чтобы вернуть матрицу путаницы для каждого сгиба, вы можете вызывать confusion_matrix из модулей метрик на каждой итерации (сгиб), что даст вам массив в качестве вывода. В качестве входных данных будут использоваться значения y_true и y_predict, полученные для каждого сгиба.
from sklearn import metrics
print metrics.confusion_matrix(y_true,y_predict)
array([[327582, 264313],
[167523, 686735]])
Кроме того, если вы используете панд, то у панд есть модуль кросс-таблицы
df_conf = pd.crosstab(y_true,y_predict,rownames=['Actual'],colnames=['Predicted'],margins=True)
print df_conf
Predicted 0 1 All
Actual
0 332553 58491 391044
1 97283 292623 389906
All 429836 351114 780950