Минимальный пример Numeric.AD не скомпилируется

Я пытаюсь скомпилировать следующий минимальный пример из Numeric.AD:

import Numeric.AD 
timeAndGrad f l = grad f l
main = putStrLn "hi"

и я сталкиваюсь с этой ошибкой:

test.hs:3:24:
    Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
                                       s a)
                                  -> Numeric.AD.Internal.Reverse.Reverse s a’
                with actual type ‘t’
      because type variable ‘s’ would escape its scope
    This (rigid, skolem) type variable is bound by
      a type expected by the context:
        Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
        f (Numeric.AD.Internal.Reverse.Reverse s a)
        -> Numeric.AD.Internal.Reverse.Reverse s a
      at test.hs:3:19-26
    Relevant bindings include
      l :: f a (bound at test.hs:3:15)
      f :: t (bound at test.hs:3:13)
      timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
    In the first argument of ‘grad’, namely ‘f’
    In the expression: grad f l

Любой ключ к пониманию того, почему это происходит? Глядя на предыдущие примеры, я понимаю, что это "выравнивание" gradтип:

grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a

но на самом деле мне нужно сделать что-то подобное в моем коде. На самом деле, это самый минимальный пример, который не будет компилироваться. Более сложная вещь, которую я хочу сделать, это что-то вроде этого:

example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
    where gradient = grad f x
          gradientFn = grad f
          (other where clauses involving gradient and gradient "function")

Вот немного более сложная версия с сигнатурами типов, которая компилируется.

{-# LANGUAGE RankNTypes #-}

import Numeric.AD 
import Numeric.AD.Internal.Reverse

-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l

-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
       where f' = Lift . f . extractAll
       -- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work

extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
           where extract (Lift x) = x -- non-exhaustive pattern match

dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)

-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]

Тем не менее, я не могу понять, как использовать первую версию, grad2в коде, потому что я не знаю, как бороться с Reverse s a, Вторая версия, grad2', имеет правильный тип, потому что я использую внутренний конструктор Lift создать Reverse s a, но я не должен понимать, как внутренности (в частности, параметр s) работает, потому что выходной градиент равен 0. Использование другого конструктора Reverse (здесь не показано) также создает неправильный градиент.

Кроме того, есть ли примеры библиотек / кода, где люди использовали ad код? Я думаю, что мой вариант использования очень распространен.

1 ответ

Решение

С where f' = Lift . f . extractAll по сути, вы создаете заднюю дверь в базовый тип автоматической дифференциации, который отбрасывает все производные и сохраняет только постоянные значения. Если вы затем используете это для gradНеудивительно, что вы получаете нулевой результат!

Разумный способ - просто использовать grad как это:

dist :: Floating a => [a] -> a
dist [x, y] = sqrt $ x^2 + y^2
-- preferrable is of course `dist = sqrt . sum . map (^2)`

main = print $ grad dist [1,2]
-- output: [0.4472135954999579,0.8944271909999159]

Вам не нужно знать что-то более сложное, чтобы использовать автоматическое дифференцирование. Пока ты только дифференцируешься Num или же Floating-полиморфные функции, все будет работать как есть. Если вам нужно дифференцировать функцию, переданную в качестве аргумента, вам нужно сделать этот аргумент полиморфным ранга 2 (альтернативой может быть переключение на версию ранга 1 ad функции, но я осмелюсь сказать, что это менее элегантно и на самом деле не приносит вам много).

{-# LANGUAGE Rank2Types, UnicodeSyntax #-}

mainWith :: (∀n . Floating n => [n] -> n) -> IO ()
mainWith f = print $ grad f [1,2]

main = mainWith dist
Другие вопросы по тегам