Градиентный спуск не сходится

Вот моя собственная реализация алгоритма градиентного спуска на языке Matlab

 m = height(data_training); % number of samples
cols = {'x1', 'x2', 'x3', 'x4', 'x5', 'x6',...
    'x7', 'x8','x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15'}; 

y = data_training{:, {'y'}}';
X = [ones(m,1) data_training{:,cols}]'; 

theta = zeros(1,width(data_training));

alpha = 1e-2; % learning rate
iter = 400;

dJ = zeros(1,width(data_training));

J_seq = zeros(1, iter);

for n = 1:iter

    err = (theta*X - y);

    for j = 1:width(data_training)
        dJ(j) = 1/m*sum(err*X(j,:)');
    end

    J = 1/(2*m)*sum((theta*X-y).^2);

    theta = theta - alpha.*dJ;

    J_seq(n) = J;

    if mod(n,100) == 0
        plot(1:iter, J_seq);
    end
end

РЕДАКТИРОВАТЬ РАБОЧИЙ АЛГОРИТМ

Я применил этот алгоритм к следующему набору данных. Последний столбец является выходной переменной. Здесь у нас есть 15 различных функций.

По неизвестной мне причине, когда я строю функцию стоимости J после 50 итераций, чтобы проверить, идет ли она к сходимости, я вижу, что она не сходима. Можете ли вы помочь мне понять? Это неправильная реализация или я должен что-то сделать?

36    27    71     8.1    3.34    11.4    81.5    3243     8.8    42.6    11.7     21     15     59    59     921.87
35    23    72    11.1    3.14      11    78.8    4281     3.6    50.7    14.4      8     10     39    57     997.88
44    29    74    10.4    3.21     9.8    81.6    4260     0.8    39.4    12.4      6      6     33    54     962.35
47    45    79     6.5    3.41    11.1    77.5    3125    27.1    50.2    20.6     18      8     24    56     982.29
43    35    77     7.6    3.44     9.6    84.6    6441    24.4    43.7    14.3     43     38    206    55     1071.3
53    45    80     7.7    3.45    10.2    66.8    3325    38.5    43.1    25.5     30     32     72    54     1030.4
43    30    74    10.9    3.23    12.1    83.9    4679     3.5    49.2    11.3     21     32     62    56      934.7
45    30    73     9.3    3.29    10.6      86    2140     5.3    40.4    10.5      6      4      4    56     899.53
36    24    70       9    3.31    10.5    83.2    6582     8.1    42.5    12.6     18     12     37    61     1001.9
36    27    72     9.5    3.36    10.7    79.3    4213     6.7      41    13.2     12      7     20    59     912.35
52    42    79     7.7    3.39     9.6    69.2    2302    22.2    41.3    24.2     18      8     27    56     1017.6
33    26    76     8.6     3.2    10.9    83.4    6122    16.3    44.9    10.7     88     63    278    58     1024.9
40    34    77     9.2    3.21    10.2      77    4101      13    45.7    15.1     26     26    146    57     970.47
35    28    71     8.8    3.29    11.1    86.3    3042    14.7    44.6    11.4     31     21     64    60     985.95
37    31    75       8    3.26    11.9    78.4    4259    13.1    49.6    13.9     23      9     15    58     958.84
35    46    85     7.1    3.22    11.8    79.9    1441    14.8    51.2    16.1      1      1      1    54      860.1
36    30    75     7.5    3.35    11.4    81.9    4029    12.4      44      12      6      4     16    58     936.23
15    30    73     8.2    3.15    12.2    84.2    4824     4.7    53.1    12.7     17      8     28    38     871.77
31    27    74     7.2    3.44    10.8      87    4834    15.8    43.5    13.6     52     35    124    59     959.22
30    24    72     6.5    3.53    10.8    79.5    3694    13.1    33.8    12.4     11      4     11    61     941.18
31    45    85     7.3    3.22    11.4    80.7    1844    11.5    48.1    18.5      1      1      1    53     891.71
31    24    72       9    3.37    10.9    82.8    3226     5.1    45.2    12.3      5      3     10    61     871.34
42    40    77     6.1    3.45    10.4    71.8    2269    22.7    41.4    19.5      8      3      5    53     971.12
43    27    72       9    3.25    11.5    87.1    2909     7.2    51.6     9.5      7      3     10    56     887.47
46    55    84     5.6    3.35    11.4    79.7    2647      21    46.9    17.9      6      5      1    59     952.53
39    29    76     8.7    3.23    11.4    78.6    4412    15.6    46.6    13.2     13      7     33    60     968.66
35    31    81     9.2     3.1      12    78.3    3262    12.6    48.6    13.9      7      4      4    55     919.73
43    32    74    10.1    3.38     9.5    79.2    3214     2.9    43.7      12     11      7     32    54     844.05
11    53    68     9.2    2.99    12.1    90.6    4700     7.8    48.9    12.3    648    319    130    47     861.83
30    35    71     8.3    3.37     9.9    77.4    4474    13.1    42.6    17.7     38     37    193    57     989.26
50    42    82     7.3    3.49    10.4    72.5    3497    36.7    43.3    26.4     15     10     34    59     1006.5
60    67    82      10    2.98    11.5    88.6    4657    13.6    47.3    22.4      3      1      1    60     861.44
30    20    69     8.8    3.26    11.1    85.4    2934     5.8      44     9.4     33     23    125    64     929.15
25    12    73     9.2    3.28    12.1    83.1    2095       2    51.9     9.8     20     11     26    50     857.62
45    40    80     8.3    3.32    10.1    70.3    2682      21    46.1    24.1     17     14     78    56     961.01
46    30    72    10.2    3.16    11.3    83.2    3327     8.8    45.3    12.2      4      3      8    58     923.23

1 ответ

Решение

Не уверен, что я следую вашей логике, но совершенно очевидно, что "е" (ошибка) не должно быть в квадрате.

Давайте посмотрим, что вы должны использовать.

theta вектор-столбец неизвестных, y вектор-столбец измерений и X это матрица модели, где каждая строка является "примером". Так что вам нужно найти theta такой что:

y = X*theta 

Или эквивалентно, используйте метод оптимизации, чтобы найти theta минимизация текущей квадратичной ошибки (вот что делает эту проблему выпуклой оптимизацией):

e[n] = (y - X*theta[n])

e[n]^2 --> minimize 

Градиентный спуск использует градиент функции ошибки (относительно тета) для обновления theta вектор:

theta[n+1] = theta[n] - alpha*2*X'*e[n]

(Обратите внимание, что e[n] и theta[n] - векторы. Это математическая запись, а не матрица)

Итак, вы видите, что e[n] не возводится в квадрат в уравнении обновления.

Другие вопросы по тегам