Как правильно определить собственный градиент STE во Flux?
Я пытаюсь написать собственный градиент STE с помощью Flux. Активация - это, в основном, функция sign(), а ее градиент является входящим градиентом, если только его абсолютное значение <=1, и отменяется другим способом. Реализация, которую я в настоящее время имею, кажется, не работает правильно
binarize(x) = x>=0 ? true : false
binarize(x::Flux.Tracker.TrackedReal) = Flux.Tracker.track(binarize, x)
@grad function binarize(x)
return binarize.(Flux.Tracker.data(x)), Δ -> (abs(x) <= 1 ? x : 0, )
0), )
end
Таким образом, для случайной матрицы 5x1 я получаю:
>> a= param(randn(5))
>> Tracked 5-element Array{Float64,1}:
-0.3605564089879154
-0.7853512499733902
0.8102988051980005
-0.9715952052917924
-1.276343849200165
>> c= binarize.(a)
>> 5-element BitArray{1}:
false
false
true
false
false
>> Tracker.back!(c, [1,1,1,1,1])
>> a.grad
5-element Array{Float64,1}:
0.0
0.0
0.0
0.0
0.0
Я ожидал бы, что градиент a будет похож на a, за исключением последнего элемента, который будет 0.
Что я делаю неправильно?