Функция потери потока не считывает данные должным образом

Вероятно, это довольно простая ошибка с моей стороны, но я не могу ее понять. Я пытаюсь создать RNN, который будет изучать числовые последовательности. Пример набора данных (каждая строка представляет точку данных)

0 0 0 1 3
0 0 0 0 0
0 0 1 3 0 
...

В основном я следую этому примеру: https://www.juliabloggers.com/a-basic-rnn/

Мои данные и данные в примерах читаются как Array{Array{Float64,1},1}. Вот часть моего кода

function eval_model(model, x)
  out = model.(x)[end]
  Flux.reset!(model)
  return out
end

m = Chain(GRU(1, 40), Dense(40, 1, σ))

loss(y) = Flux.crossentropy(eval_model(m, y), y)

ps = Flux.params(m)

opt = Flux.ADAM()

@epochs 100 Flux.train!(loss, ps, data, opt)

Выход:

MethodError: no method matching loss(::Float64, ::Float64, ::Float64, ::Float64, ::Float64)
Closest candidates are:
  loss(::Any, ::Any) at In[4]:2

Функция потерь считывает каждое число в последовательности как отдельный вход для функции потерь (я пробовал другие длины последовательностей, и ошибка такая же, но это "MethodError: нет потери соответствия метода ((длина последовательности) *::Float64)".

В примере, над которым я работаю, это не проблема. Я мог бы построить процедуру обучения с нуля, но предпочел бы передать ее Flux.

1 ответ

Это проблема с формой ваших данных: train! предполагает, что вы разделите данные на предикторы и цели для применения к loss(x,y). Но в твоем случаеtrain! разбивает данные на что-то нежелательное.

data = zip(x,y) позволит train! чтобы разделить ваши данные на набор данных x и y.

Кроме того, ваш набор данных и функция потерь принимают только y, что делает для меня неоднозначным то, что вы пытаетесь предсказать. Если вы хотите предсказать следующий элемент в последовательности, тогда x следует опустить последний элемент в последовательности, а y пропускает первый элемент последовательности.

Рассмотрим этот пример, чтобы предсказать следующий элемент в последовательности:

x = data[:,1:end-1]
y = data[:,2:end]
loss(x,y) = Flux.crossentropy(eval_model(m, x), y)
@epochs 100 Flux.train!(loss, ps, zip(x,y), opt)
Другие вопросы по тегам