Моделирование полной последовательности с помощью LSTM в Flux-Julia

Я пытаюсь обучить LSTM моделировать полную последовательность y на основе последовательности x (а не только последнего элемента или классификатора). С помощью следующего кода обучение не работает, хотя функция потерь работает. Похоже, точечный формализм с train не работает!? Есть идеи, как я мог это сделать? В Керасе это так просто.... Заранее спасибо, Маркус

    using Flux

# Create synthetic data first
### Function to generate x consisting of three variables and a sequence length of 200
    function generateX()
        x1 = Array{Float32, 1}(randn(200))
        x2 = Array{Float32, 1}(randn(200))
        x3 = Array{Float32, 1}(sin.((0:199) / 12*2*pi))
        xdata=[x1 x2 x3]'
        return(xdata)
    end

### Generate 50 of these sequences of x
    xdata = [generateX() for i in 1:50]

### Function to generate sequence of y from x sequence
    function yfromx(x)
        y=Array{Float32, 1}(0.2*cumsum(x[1,:].*x[2,:].*exp.(x[1,:])) .+x[3,:])
        return(y')
    end
    ydata =  map(yfromx, xdata);

### Now rearrange such that there is a sequence of 200 X inputs, i.e. an array of x vectors (and 50 of those sequences)
    xdata=Flux.batch(xdata) 
    xdata2 = [xdata[:,s,c] for s in 1:200, c in 1:50]
    xdata= [xdata2[:,c] for c in 1:50]

### Same for y
    ydata=Flux.batch(ydata)
    ydata2 = [ydata[:,s,c] for s in 1:200, c in 1:50]
    ydata= [ydata2[:,c] for c in 1:50]

### Define model and loss function. "model." returns sequence of y from sequence of x
    import Base.Iterators: flatten
    model=Chain(LSTM(3, 26), Dense(26,1))

    loss(x,y) = Flux.mse(collect(flatten(model.(x))),collect(flatten(y)))

    model.(xdata[1]) # works fine
    loss(xdata[2],ydata[2]) # also works fine

    Flux.train!(loss, params(model), zip(xdata, ydata), ADAM(0.005)) ## Does not work, see error below. How to work around?

Сообщение об ошибке

Mutating arrays is not supported

Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::getfield(Zygote, Symbol("##992#993")))(::Nothing) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/lib/array.jl:44
 [3] (::getfield(Zygote, Symbol("##2633#back#994")){getfield(Zygote, Symbol("##992#993"))})(::Nothing) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] copyto! at ./abstractarray.jl:725 [inlined]
 [5] (::typeof(∂(copyto!)))(::Array{Float32,1}) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
 [6] _collect at ./array.jl:550 [inlined]
 [7] (::typeof(∂(_collect)))(::Array{Float32,1}) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
 [8] collect at ./array.jl:544 [inlined]
 [9] (::typeof(∂(collect)))(::Array{Float32,1}) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
 [10] loss at ./In[20]:4 [inlined]
 [11] (::typeof(∂(loss)))(::Float32) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
 [12] #153 at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/lib/lib.jl:142 [inlined]
 [13] #283#back at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [14] #15 at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:69 [inlined]
 [15] (::typeof(∂(λ)))(::Float32) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
 [16] (::getfield(Zygote, Symbol("##38#39")){Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface.jl:101
 [17] gradient(::Function, ::Zygote.Params) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface.jl:47
 [18] macro expansion at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:68 [inlined]
 [19] macro expansion at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
 [20] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,1},1},1},Array{LinearAlgebra.Adjoint{Float32,Array{Float32,1}},1}}}, ::ADAM) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:66
 [21] train!(::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,1},1},1},Array{LinearAlgebra.Adjoint{Float32,Array{Float32,1}},1}}}, ::ADAM) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:64
 [22] top-level scope at In[24]:1

loss(xdata[2],ydata[2])

1 ответ

Что ж, следуя пути Фредерика, следующая потеря, похоже, сработает, но, честно говоря, мне это не совсем нравится, поэтому я все еще задаюсь вопросом, есть ли более элегантные / идиоматические / эффективные (?) Решения...

function loss(x,y) 
    yhat=model.(x)
    s=0
    for i in 1:length(yhat)
        s+=(yhat[i][1] - y[i][1])^2
    end

    s/=length(yhat)
    s    

end

Пожалуйста, посмотрите Buffer для мутирующего массива.

Другие вопросы по тегам