Минимальный пример Numeric.AD не скомпилируется
Я пытаюсь скомпилировать следующий минимальный пример из Numeric.AD:
import Numeric.AD
timeAndGrad f l = grad f l
main = putStrLn "hi"
и я сталкиваюсь с этой ошибкой:
test.hs:3:24:
Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
s a)
-> Numeric.AD.Internal.Reverse.Reverse s a’
with actual type ‘t’
because type variable ‘s’ would escape its scope
This (rigid, skolem) type variable is bound by
a type expected by the context:
Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
f (Numeric.AD.Internal.Reverse.Reverse s a)
-> Numeric.AD.Internal.Reverse.Reverse s a
at test.hs:3:19-26
Relevant bindings include
l :: f a (bound at test.hs:3:15)
f :: t (bound at test.hs:3:13)
timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
In the first argument of ‘grad’, namely ‘f’
In the expression: grad f l
Любой ключ к пониманию того, почему это происходит? Глядя на предыдущие примеры, я понимаю, что это "выравнивание" grad
тип:
grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a
но на самом деле мне нужно сделать что-то подобное в моем коде. На самом деле, это самый минимальный пример, который не будет компилироваться. Более сложная вещь, которую я хочу сделать, это что-то вроде этого:
example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
where gradient = grad f x
gradientFn = grad f
(other where clauses involving gradient and gradient "function")
Вот немного более сложная версия с сигнатурами типов, которая компилируется.
{-# LANGUAGE RankNTypes #-}
import Numeric.AD
import Numeric.AD.Internal.Reverse
-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l
-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
where f' = Lift . f . extractAll
-- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work
extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
where extract (Lift x) = x -- non-exhaustive pattern match
dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)
-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]
Тем не менее, я не могу понять, как использовать первую версию, grad2
в коде, потому что я не знаю, как бороться с Reverse s a
, Вторая версия, grad2'
, имеет правильный тип, потому что я использую внутренний конструктор Lift
создать Reverse s a
, но я не должен понимать, как внутренности (в частности, параметр s
) работает, потому что выходной градиент равен 0. Использование другого конструктора Reverse
(здесь не показано) также создает неправильный градиент.
Кроме того, есть ли примеры библиотек / кода, где люди использовали ad
код? Я думаю, что мой вариант использования очень распространен.
1 ответ
С where f' = Lift . f . extractAll
по сути, вы создаете заднюю дверь в базовый тип автоматической дифференциации, который отбрасывает все производные и сохраняет только постоянные значения. Если вы затем используете это для grad
Неудивительно, что вы получаете нулевой результат!
Разумный способ - просто использовать grad
как это:
dist :: Floating a => [a] -> a
dist [x, y] = sqrt $ x^2 + y^2
-- preferrable is of course `dist = sqrt . sum . map (^2)`
main = print $ grad dist [1,2]
-- output: [0.4472135954999579,0.8944271909999159]
Вам не нужно знать что-то более сложное, чтобы использовать автоматическое дифференцирование. Пока ты только дифференцируешься Num
или же Floating
-полиморфные функции, все будет работать как есть. Если вам нужно дифференцировать функцию, переданную в качестве аргумента, вам нужно сделать этот аргумент полиморфным ранга 2 (альтернативой может быть переключение на версию ранга 1 ad
функции, но я осмелюсь сказать, что это менее элегантно и на самом деле не приносит вам много).
{-# LANGUAGE Rank2Types, UnicodeSyntax #-}
mainWith :: (∀n . Floating n => [n] -> n) -> IO ()
mainWith f = print $ grad f [1,2]
main = mainWith dist