Пропускать запрещенные комбинации параметров при использовании GridSearchCV

Я хочу жадно искать во всем пространстве параметров моего классификатора опорных векторов, используя GridSearchCV. Однако некоторые комбинации параметров запрещены LinearSVC и выдают исключение. В частности, существуют взаимоисключающие комбинации dual, penalty, а также loss параметры:

Например, этот код:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'dual':[True, False], 'penalty' : ['l1', 'l2'], \
              'loss': ['hinge', 'squared_hinge']}
svc = svm.LinearSVC()
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)

Возвращает ValueError: Unsupported set of arguments: The combination of penalty='l2' and loss='hinge' are not supported when dual=False, Parameters: penalty='l2', loss='hinge', dual=False

Мой вопрос: возможно ли заставить GridSearchCV пропускать комбинации параметров, которые запрещены моделью? Если нет, то есть ли простой способ построить пространство параметров, которое не будет нарушать правила?

1 ответ

Решение

Я решил эту проблему, пройдя error_score=0.0 в GridSearchCV:

error_score: 'повысить' (по умолчанию) или числовой

Значение, присваиваемое баллу в случае ошибки при подборе оценки. Если установлено значение "поднимать", ошибка возникает. Если задано числовое значение, вызывается FitFailedWarning. Этот параметр не влияет на шаг восстановления, что всегда вызывает ошибку.

Если вы хотите полностью избежать изучения конкретных комбинаций (не дожидаясь ошибок), вы должны построить сетку самостоятельно. GridSearchCV может принимать список диктов, где исследуются сетки, охватываемые каждым словарем в списке.

В этом случае условная логика была не такой уж плохой, но это было бы действительно утомительно для чего-то более сложного:

from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from itertools import product

iris = datasets.load_iris()

duals = [True, False]
penaltys = ['l1', 'l2']
losses = ['hinge', 'squared_hinge']
all_params = list(product(duals, penaltys, losses))
filtered_params = [{'dual': [dual], 'penalty' : [penalty], 'loss': [loss]}
                   for dual, penalty, loss in all_params
                   if not (penalty == 'l1' and loss == 'hinge') 
                   and not ((penalty == 'l1' and loss == 'squared_hinge' and dual is True))
                  and not ((penalty == 'l2' and loss == 'hinge' and dual is False))]

svc = svm.LinearSVC()
clf = GridSearchCV(svc, filtered_params)
clf.fit(iris.data, iris.target)
Другие вопросы по тегам