Прекратить отслеживание массивов в Flux (Джулия)

В настоящее время я пытаюсь реализовать пакетное обновление Flux для Джулии.

Во время своих вычислений я получаю пакет скаляров, многократно выполняя

δ = Gt - model(St)[1]
push!(deltas,δ)

где модель - нейронная сеть

global model= Chain(
    Dense(statesize,10, leakyrelu),
    Dense(10,10,leakyrelu),
    Dense(10,1))

В итоге я получаю дельты массива, и я хотел бы выполнить пакетное обновление градиента (размер пакета = 19) во второй нейронной сети, где каждый градиент взвешивается соответствующей дельтой. Я написал функцию обновления

function vupdate2!(S_batch,model,α,deltas)

   function v_loss_total(x)
       return sum(reshape(deltas,(1,19)) .* model(x))
   end

   local ps = Flux.params(model)
   local gs = Flux.Tracker.gradient(() -> v_loss_total(S_batch), ps)
   for p in ps
       Flux.Tracker.update!( p,  α.* gs[p])
   end
end

Проблема в том, что строка, в которой вычисляются градиенты, выдает ошибку: MethodError: no method matching Float32(::Tracker.TrackedReal{Float64})

Я думаю, проблема в том, что мой дельта-массив отслеживается. Глядя на вывод функции v_loss_total для случайного ввода, я получаю:

julia> v_loss_total(S_batch)
-6752.433690476287 (tracked) (tracked)

Интересно, что это число отслеживается дважды (?), Что, как я полагаю, происходит от умножения двух отслеживаемых чисел вместе (т.е. записей дельт и модели (S_batch)). Есть ли способ сначала отследить массив дельты? Буду признателен за любую помощь.

2 ответа

Решение

Ладно, как оказалось, есть функция

Flux.Tracker.data()

что делает именно то, что мне нужно. Он берет отслеживаемое число и возвращает сам Float. См. Также: https://github.com/FluxML/Flux.jl/issues/640

Что сработало для меня в julia 1.2, так это доступ к float как к полю с помощью .data

Вышеупомянутая функция, предложенная GreenLogic, возвращает только другой трекер.

Другие вопросы по тегам