Потеря индивидуального шарнира против потери шарнира sklearn

Я написал специальную функцию потерь шарнира на основе формулы потерь шарнира и протестировал ее на наборе данных, а также на sklearn.metrics шарнире_loss. Но результаты настолько разные.

Может ли кто-нибудь взглянуть и сообщить мне, что я здесь делаю не так?

Итак, прежде всего это данные:

      from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

from sklearn.svm import SVC
from sklearn.metrics import hinge_loss

features = data.data
labels = label.reshape(-1,1)

feat_train, feat_test,labels_train,labels_test = train_test_split(features,labels,test_size=0.2)

sample_weight = np.random.rand(len(labels_test))
clf = SVC(kernel='linear', C=1.0, random_state=42)
clf.fit(feat_train, labels_train)
y_pred = clf.predict(feat_test)
loss = hinge_loss(labels_test, y_pred, sample_weight=sample_weight) 

print("Hinge loss:", loss)

Результат:

Потеря шарнира: 0,09835019117405303

Теперь это моя пользовательская функция:

      def hinge_loss_full(feature_matrix, labels, theta, theta_0):
    lf = 0,1 - labels*(np.sum(feature_matrix * theta, axis = 1) + theta_0
    max_lf =  max(lf)
    result = np.mean(max_lf)
  
    return result

И если я подставлю те же значения для параметров, которые использовались для sklearn, то получится следующее:

      hinge_loss_full(labels_test, y_pred, theta=sample_weight,theta_0=1)

2.3944007220115626

Почему такая разница?

1 ответ

  1. Форма массива numpy должна быть одинаковой.

  2. В каждой библиотеке есть свой способ расчета функции потерь. Следовательно, формулы потерь шарнира CS231 и потерь шарнира Scikit-Learn различны (например, сигмовидная функция Кераса).

  3. Функция с тем же значением, что и функция Scikit-Learn, выглядит следующим образом.

      def hinge_loss(y_true, y_pred,  sample_weight):
    margin = y_true * y_pred
    loss = np.maximum(0, 1 - margin)
    return np.average(loss, weights=sample_weight)
    
hinge_loss(np.squeeze(labels_test), np.squeeze(y_pred), sample_weight)
Другие вопросы по тегам