sklearn RandomizedSearchCV извлекает путаницу матрицы для разных сгибов

Я пытаюсь вычислить агрегированную матрицу путаницы, чтобы оценить мою модель:

cv_results = cross_validate(estimator, dataset.data, dataset.target, scoring=scoring,
                cv=Config.CROSS_VALIDATION_FOLDS, n_jobs=N_CPUS, return_train_score=False)

Но я не знаю, как извлечь одиночные матрицы из разных сгибов. В бомбардире я могу вычислить это:

scoring = {
'cm': make_scorer(confusion_matrix)
}

, но я не могу вернуть матрицу слияния, потому что она должна возвращать число вместо массива. Если я попробую это, я получу следующую ошибку:

ValueError: scoring must return a number, got [[...]] (<class 'numpy.ndarray'>) instead. (scorer=cm)

Интересно, возможно ли сохранить матрицы путаницы в глобальной переменной, но безуспешно при использовании

global cm_list
cm_list.append(confusion_matrix(y_true,y_pred))

в кастомном бомбардире.

Спасибо заранее за любые советы.

2 ответа

Решение

Проблема состояла в том, что я не мог получить доступ к оценщику после того, как RandomizedSearchCV был закончен, потому что я не знал, что RandomizedSearchCV реализует метод прогнозирования. Вот мое личное решение:

r_search = RandomizedSearchCV(estimator=estimator, param_distributions=param_distributions,
                          n_iter=n_iter, cv=cv, scoring=scorer, n_jobs=n_cpus,
                          refit=next(iter(scorer)))
r_search.fit(X, y_true)
y_pred = r_search.predict(X)
cm = confusion_matrix(y_true, y_pred)

Чтобы вернуть матрицу путаницы для каждого сгиба, вы можете вызывать confusion_matrix из модулей метрик на каждой итерации (сгиб), что даст вам массив в качестве вывода. В качестве входных данных будут использоваться значения y_true и y_predict, полученные для каждого сгиба.

from sklearn import metrics
print metrics.confusion_matrix(y_true,y_predict)
array([[327582, 264313],
       [167523, 686735]])

Кроме того, если вы используете панд, то у панд есть модуль кросс-таблицы

df_conf = pd.crosstab(y_true,y_predict,rownames=['Actual'],colnames=['Predicted'],margins=True)
print df_conf

Predicted       0       1     All
Actual                           
  0          332553   58491  391044
  1           97283  292623  389906
  All        429836  351114  780950 
Другие вопросы по тегам