Линейная регрессия с использованием нейронной сети

Я работаю над проблемой регрессии со следующими примерами обучающих данных.

Как показано, у меня есть вход только 4 параметра с изменением только одного из них, который является Z, так что остальные не имеют реального значения, в то время как выход из 124 параметров, обозначенных от O1 до O124. Обратите внимание, что O1 изменяется с постоянной скоростью 20 [1000, то 1020, затем 1040 ...], в то время как O2 изменяется с другой скоростью, которая равна 30, но все еще остается постоянной и одинаковой для всех 124 выходов, все изменения линейно изменяются постоянным образом.

Я полагал, что это тривиальная проблема, и очень простая модель нейронной сети достигнет 100% точности при тестировании данных, но результаты оказались противоположными.

  • Я достиг 100% точности теста с использованием линейного регрессора и 99,9999% точности теста с использованием регрессора KNN
  • Я достиг 41% точности тестовых данных в 10-уровневой нейронной сети с использованием активации Relu, в то время как все остальные функции активации не сработали, а также неглубокий Relu
  • Используя простую нейронную сеть с линейной функцией активации и без скрытых слоев, я достиг 92% по данным испытаний

Мой вопрос: как я могу заставить нейронную сеть получать 100% на тестовых данных, таких как линейный регрессор? Предполагается, что использование мелкой сети с линейной активацией эквивалентно линейному регрессору, но результаты разные, я что-то упустил?

1 ответ

Если вы используете линейную активацию, глубокая модель в принципе такая же, как линейная регрессия / NN с 1 слоем. Например, для глубокого NN с линейной активацией прогноз задается как y = W_3(W_2(W_1 x))), который можно переписать как y = (W_3 (W_2 W_1))x, что совпадает с y = (W_4 x), который является линейной регрессией.

Учитывая это, проверьте, сходится ли ваш NN без скрытого слоя к тем же параметрам, что и ваша линейная регрессия. Если это не так, значит, ваша реализация ошибочна. Если это так, то ваш большой NN, вероятно, сходится к какому-то решению проблемы, где точность теста просто хуже. Попробуйте разные случайные семена.

Другие вопросы по тегам