Функция потери потока не считывает данные должным образом
Вероятно, это довольно простая ошибка с моей стороны, но я не могу ее понять. Я пытаюсь создать RNN, который будет изучать числовые последовательности. Пример набора данных (каждая строка представляет точку данных)
0 0 0 1 3
0 0 0 0 0
0 0 1 3 0
...
В основном я следую этому примеру: https://www.juliabloggers.com/a-basic-rnn/
Мои данные и данные в примерах читаются как Array{Array{Float64,1},1}. Вот часть моего кода
function eval_model(model, x)
out = model.(x)[end]
Flux.reset!(model)
return out
end
m = Chain(GRU(1, 40), Dense(40, 1, σ))
loss(y) = Flux.crossentropy(eval_model(m, y), y)
ps = Flux.params(m)
opt = Flux.ADAM()
@epochs 100 Flux.train!(loss, ps, data, opt)
Выход:
MethodError: no method matching loss(::Float64, ::Float64, ::Float64, ::Float64, ::Float64)
Closest candidates are:
loss(::Any, ::Any) at In[4]:2
Функция потерь считывает каждое число в последовательности как отдельный вход для функции потерь (я пробовал другие длины последовательностей, и ошибка такая же, но это "MethodError: нет потери соответствия метода ((длина последовательности) *::Float64)".
В примере, над которым я работаю, это не проблема. Я мог бы построить процедуру обучения с нуля, но предпочел бы передать ее Flux.
1 ответ
Это проблема с формой ваших данных: train!
предполагает, что вы разделите данные на предикторы и цели для применения к loss(x,y)
. Но в твоем случаеtrain!
разбивает данные на что-то нежелательное.
data = zip(x,y)
позволит train!
чтобы разделить ваши данные на набор данных x и y.
Кроме того, ваш набор данных и функция потерь принимают только y, что делает для меня неоднозначным то, что вы пытаетесь предсказать. Если вы хотите предсказать следующий элемент в последовательности, тогда x следует опустить последний элемент в последовательности, а y пропускает первый элемент последовательности.
Рассмотрим этот пример, чтобы предсказать следующий элемент в последовательности:
x = data[:,1:end-1]
y = data[:,2:end]
loss(x,y) = Flux.crossentropy(eval_model(m, x), y)
@epochs 100 Flux.train!(loss, ps, zip(x,y), opt)