Можно ли применить форму sklearn cross_val_score() к neupy NN с таким дополнением, как Weigth Elmination?
Я пытаюсь применить cross_val_score() к следующему алгоритму:
cgnet = algorithms.LevenbergMarquardt(
connection=[
layers.Input(XTrain.shape[1]),
layers.Linear(6),
layers.Linear(1)],
mu_update_factor=2,
mu=0.1,
shuffle_data=True,
verbose=True,
decay_rate=0.1,
addons=[algorithms.WeightElimination])
kfold = KFold(n_splits=5, shuffle=True, random_state=7)
scores=cross_val_score(cgnet, XTrainScaled,yTrainScaled,scoring='neg_mean_absolute_error',cv=kfold,verbose=10)
print scores
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))
И это сообщение об ошибке, которое я получаю:
TypeError: Cannot create a consistent method resolution
order (MRO) for bases LevenbergMarquardtWeightElimination, WeightElimination
Без WeightElimination или любого другого дополнения, cross_val_score(), работает нормально... Есть ли другой способ сделать это? Спасибо
1 ответ
Решение
Похоже на функцию cross_val_score
не будет работать в Neupy, но вы можете запустить тот же код немного по-другому.
import numpy as np
from neupy import algorithms, layers
from sklearn.model_selection import *
from sklearn import metrics
XTrainScaled = XTrain = np.random.random((10, 2))
yTrainScaled = np.random.random((10, 1))
kfold = KFold(n_splits=5, shuffle=True, random_state=7)
scores = []
for train, test in kfold.split(XTrainScaled):
x_train, x_test = XTrainScaled[train], XTrainScaled[test]
y_train, y_test = yTrainScaled[train], yTrainScaled[test]
cgnet = algorithms.LevenbergMarquardt(
connection=[
layers.Input(XTrain.shape[1]),
layers.Linear(6),
layers.Linear(1)
],
mu_update_factor=2,
mu=0.1,
shuffle_data=True,
verbose=True,
decay_rate=0.1,
addons=[algorithms.WeightElimination]
)
cgnet.train(x_train, y_train, epochs=5)
y_predicted = cgnet.predict(x_test)
score = metrics.mean_absolute_error(y_test, y_predicted)
scores.append(score)
print(scores)
scores = np.array(scores)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))