Использование квантиля в Flux (Julia) в функции потерь

Я пытаюсь использовать квантиль в функции потерь для обучения! (для некоторой надежности, например, наименее обрезанных квадратов), но он изменяет массив, и Zygote выдает ошибкуMutating arrays is not supported, приходящий из sort!. Ниже приведен простой пример (содержание, конечно, не имеет смысла):

using Flux, StatsBase
xdata = randn(2, 100)   
ydata = randn(100)

model = Chain(Dense(2,10), Dense(10, 1))


function trimmedLoss(x,y; trimFrac=0.f05)
        yhat = model(x)
        absRes = abs.(yhat .- y) |> vec
        trimVal = quantile(absRes, 1.f0-trimFrac) 
        s = sum(ifelse.(absRes .> trimVal,  0.f0 , absRes ))/(length(absRes)*(1.f0-trimFrac))
        #s = sum(absRes)/length(absRes)   # using this and commenting out the two above works (no surprise)    
end

println(trimmedLoss(xdata, ydata)) #works ok

Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())

println(trimmedLoss(xdata, ydata)) #changed loss?

Это все в Flux 0.10 с Julia 1.2

Заранее благодарим за любые подсказки или обходные пути!

1 ответ

Решение

В идеале мы должны определить собственное сопряжение дляquantileтак что это работает из коробки. (Не стесняйтесь открывать вопрос, чтобы напомнить нам об этом.)

А пока есть быстрое решение. На самом деле сортировка вызывает здесь проблемы, поэтому, если выquantile(xs, p, sorted=true)это сработает. Очевидно, это требуетxs для сортировки для получения правильных результатов, поэтому вам может потребоваться использовать quantile(sort(xs), ...).

В зависимости от вашей версии Zygote вам также может понадобиться сопряжение для sort. Это довольно просто:

julia> using Zygote: @adjoint

julia> @adjoint function sort(x)
         p = sortperm(x)
         x[p], x̄ -> (x̄[invperm(p)],)
       end

julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)

Мы сделаем это встроенным в следующем выпуске Zygote, но сейчас, если вы добавите это в свой скрипт, он заставит ваш код работать.

Другие вопросы по тегам