KeyError: 'base_score' while fitting XGBClassifier
Используя Gridsearch, я нахожу наиболее оптимальные гиперпараметры после подбора моих данных обучения:
model_xgb = XGBClassifier()
n_estimators = [50, 100, 150, 200]
max_depth = [2, 4, 6, 8]
param_grid = dict(max_depth=max_depth, n_estimators=n_estimators)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model_xgb, param_grid, scoring="neg_log_loss", n_jobs=-1, cv=kfold, verbose=1)
grid_result = grid_search.fit(train_X, y_train)
Наилучший ответ дает использование {'max_depth': 4, 'n_estimators': 50}
. Вот почему я создаю новую модель с этими гиперпараметрами:
model_xgb_tn = XGBClassifier(n_estimators=50,max_depth=4,objective='multi:softprob')
Когда я пытаюсь подогнать модель под свои данные: model_xgb_tn.fit(train_X,y_train)
, Я получаю KeyError: 'base_score'
. Я просто не мог понять, почему у меня возникла ошибка KeyError, когда я даже не использовал гиперпараметр.
Ниже приведен код ошибки:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~\Anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj, include, exclude)
968
969 if method is not None:
--> 970 return method(include=include, exclude=exclude)
971 return None
972 else:
~\Anaconda3\lib\site-packages\sklearn\base.py in _repr_mimebundle_(self, **kwargs)
461 def _repr_mimebundle_(self, **kwargs):
462 """Mime bundle used by jupyter kernels to display estimator"""
--> 463 output = {"text/plain": repr(self)}
464 if get_config()["display"] == 'diagram':
465 output["text/html"] = estimator_html_repr(self)
~\Anaconda3\lib\site-packages\sklearn\base.py in __repr__(self, N_CHAR_MAX)
277 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
278
--> 279 repr_ = pp.pformat(self)
280
281 # Use bruteforce ellipsis when there are a lot of non-blank characters
~\Anaconda3\lib\pprint.py in pformat(self, object)
142 def pformat(self, object):
143 sio = _StringIO()
--> 144 self._format(object, sio, 0, 0, {}, 0)
145 return sio.getvalue()
146
~\Anaconda3\lib\pprint.py in _format(self, object, stream, indent, allowance, context, level)
159 self._readable = False
160 return
--> 161 rep = self._repr(object, context, level)
162 max_width = self._width - indent - allowance
163 if len(rep) > max_width:
~\Anaconda3\lib\pprint.py in _repr(self, object, context, level)
391 def _repr(self, object, context, level):
392 repr, readable, recursive = self.format(object, context.copy(),
--> 393 self._depth, level)
394 if not readable:
395 self._readable = False
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in format(self, object, context, maxlevels, level)
168 def format(self, object, context, maxlevels, level):
169 return _safe_repr(object, context, maxlevels, level,
--> 170 changed_only=self._changed_only)
171
172 def _pprint_estimator(self, object, stream, indent, allowance, context,
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
412 recursive = False
413 if changed_only:
--> 414 params = _changed_params(object)
415 else:
416 params = object.get_params(deep=False)
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _changed_params(estimator)
96 init_params = {name: param.default for name, param in init_params.items()}
97 for k, v in params.items():
---> 98 if (repr(v) != repr(init_params[k]) and
99 not (is_scalar_nan(init_params[k]) and is_scalar_nan(v))):
100 filtered_params[k] = v
KeyError: 'base_score'
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~\Anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj)
700 type_pprinters=self.type_printers,
701 deferred_pprinters=self.deferred_printers)
--> 702 printer.pretty(obj)
703 printer.flush()
704 return stream.getvalue()
~\Anaconda3\lib\site-packages\IPython\lib\pretty.py in pretty(self, obj)
400 if cls is not object \
401 and callable(cls.__dict__.get('__repr__')):
--> 402 return _repr_pprint(obj, self, cycle)
403
404 return _default_pprint(obj, self, cycle)
~\Anaconda3\lib\site-packages\IPython\lib\pretty.py in _repr_pprint(obj, p, cycle)
695 """A pprint that just redirects to the normal repr function."""
696 # Find newlines and replace them with p.break_()
--> 697 output = repr(obj)
698 for idx,output_line in enumerate(output.splitlines()):
699 if idx:
~\Anaconda3\lib\site-packages\sklearn\base.py in __repr__(self, N_CHAR_MAX)
277 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
278
--> 279 repr_ = pp.pformat(self)
280
281 # Use bruteforce ellipsis when there are a lot of non-blank characters
~\Anaconda3\lib\pprint.py in pformat(self, object)
142 def pformat(self, object):
143 sio = _StringIO()
--> 144 self._format(object, sio, 0, 0, {}, 0)
145 return sio.getvalue()
146
~\Anaconda3\lib\pprint.py in _format(self, object, stream, indent, allowance, context, level)
159 self._readable = False
160 return
--> 161 rep = self._repr(object, context, level)
162 max_width = self._width - indent - allowance
163 if len(rep) > max_width:
~\Anaconda3\lib\pprint.py in _repr(self, object, context, level)
391 def _repr(self, object, context, level):
392 repr, readable, recursive = self.format(object, context.copy(),
--> 393 self._depth, level)
394 if not readable:
395 self._readable = False
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in format(self, object, context, maxlevels, level)
168 def format(self, object, context, maxlevels, level):
169 return _safe_repr(object, context, maxlevels, level,
--> 170 changed_only=self._changed_only)
171
172 def _pprint_estimator(self, object, stream, indent, allowance, context,
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
412 recursive = False
413 if changed_only:
--> 414 params = _changed_params(object)
415 else:
416 params = object.get_params(deep=False)
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _changed_params(estimator)
96 init_params = {name: param.default for name, param in init_params.items()}
97 for k, v in params.items():
---> 98 if (repr(v) != repr(init_params[k]) and
99 not (is_scalar_nan(init_params[k]) and is_scalar_nan(v))):
100 filtered_params[k] = v
KeyError: 'base_score'
1 ответ
Вам необходимо указать базовый параметр оценки, для первой итерации повышения градиента вы можете рассматривать его как начальный вес для начала. Для регрессии это среднее значение вашего целевого столбца, а для задач классификации - 1/(количество классов). Вы можете обратиться к документации xgboost для получения дополнительной информации об этом гиперпараметре.