Shap статистика

Я использовал shap определить важность признаков для множественной регрессии с коррелированными признаками.

import numpy as np
import pandas as pd  
from sklearn.linear_model import LinearRegression
from sklearn.datasets import load_boston
import shap


boston = load_boston()
regr = pd.DataFrame(boston.data)
regr.columns = boston.feature_names
regr['MEDV'] = boston.target

X = regr.drop('MEDV', axis = 1)
Y = regr['MEDV']

fit = LinearRegression().fit(X, Y)

explainer = shap.LinearExplainer(fit, X, feature_dependence = 'independent')
# I used 'independent' because the result is consistent with the ordinary 
# shapely values where `correlated' is not

shap_values = explainer.shap_values(X)

shap.summary_plot(shap_values, X, plot_type = 'bar')

shap предлагает график, чтобы получить значения Shap. Есть ли статистика? Я заинтересован в точных значениях Shap. Я прочитал Github-репозиторий и документацию, но ничего не нашел по этой теме.

1 ответ

Решение

Когда мы смотрим на shap_values мы видим, что он содержит некоторые положительные и отрицательные числа, а его размеры равны размерам boston набор данных. Линейная регрессия - это алгоритм ML, который вычисляет оптимальный y = wx + b, где y это MEDV, x вектор функции и w это вектор весов. По моему мнению, shap_values магазины wx - матрица со значением каждого feauture, умноженным на вектор весов, рассчитанный с помощью линейной регрессии.

Поэтому, чтобы вычислить требуемую статистику, я сначала извлек абсолютные значения, а затем усреднил их. Порядок важен! Далее я использовал исходные имена столбцов и сортировал их от наибольшего эффекта к наименьшему. С этим, я надеюсь, я ответил на ваш вопрос!:)

from matplotlib import pyplot as plt


#rataining only the size of effect
shap_values_abs = np.absolute(shap_values)

#dividing to get good numbers
means_norm = shap_values_abs.mean(axis = 0)/1e-15

#sorting values and names
idx = np.argsort(means_norm)
means = np.array(means_norm)[idx]
names = np.array(boston.feature_names)[idx]

#plotting
plt.figure(figsize=(10,10))
plt.barh(names, means)

Другие вопросы по тегам