Использование квантиля в Flux (Julia) в функции потерь
Я пытаюсь использовать квантиль в функции потерь для обучения! (для некоторой надежности, например, наименее обрезанных квадратов), но он изменяет массив, и Zygote выдает ошибкуMutating arrays is not supported
, приходящий из sort!
. Ниже приведен простой пример (содержание, конечно, не имеет смысла):
using Flux, StatsBase
xdata = randn(2, 100)
ydata = randn(100)
model = Chain(Dense(2,10), Dense(10, 1))
function trimmedLoss(x,y; trimFrac=0.f05)
yhat = model(x)
absRes = abs.(yhat .- y) |> vec
trimVal = quantile(absRes, 1.f0-trimFrac)
s = sum(ifelse.(absRes .> trimVal, 0.f0 , absRes ))/(length(absRes)*(1.f0-trimFrac))
#s = sum(absRes)/length(absRes) # using this and commenting out the two above works (no surprise)
end
println(trimmedLoss(xdata, ydata)) #works ok
Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())
println(trimmedLoss(xdata, ydata)) #changed loss?
Это все в Flux 0.10 с Julia 1.2
Заранее благодарим за любые подсказки или обходные пути!
1 ответ
В идеале мы должны определить собственное сопряжение дляquantile
так что это работает из коробки. (Не стесняйтесь открывать вопрос, чтобы напомнить нам об этом.)
А пока есть быстрое решение. На самом деле сортировка вызывает здесь проблемы, поэтому, если выquantile(xs, p, sorted=true)
это сработает. Очевидно, это требуетxs
для сортировки для получения правильных результатов, поэтому вам может потребоваться использовать quantile(sort(xs), ...)
.
В зависимости от вашей версии Zygote вам также может понадобиться сопряжение для sort
. Это довольно просто:
julia> using Zygote: @adjoint
julia> @adjoint function sort(x)
p = sortperm(x)
x[p], x̄ -> (x̄[invperm(p)],)
end
julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)
Мы сделаем это встроенным в следующем выпуске Zygote, но сейчас, если вы добавите это в свой скрипт, он заставит ваш код работать.