Как определить матричный продукт в ускорении-haskell

Я пытаюсь определить типобезопасную матричную библиотеку вычислений поверх ускорения, частично для образовательных целей, частично чтобы посмотреть, является ли это практическим подходом.

Но я полностью застрял, когда дело доходит до правильного определения произведения матриц - то есть способом, которым GHC принимает / компилирует мой код.

У меня было несколько попыток, которые были вариациями этого:

Linear.hs

{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}

import qualified Data.Array.Accelerate as A

import GHC.TypeLits
import Data.Array.Accelerate ( (:.)(..), Array
                             , Exp, Shape, FullShape, Slice
                             , DIM0, DIM1, DIM2, Z(Z)
                             , IsFloating, IsNum, Elt, Acc
                             , Any(Any), All(All))
import           Data.Proxy

newtype Matrix (rows :: Nat) (cols :: Nat) a = AccMatrix {unMatrix :: Acc (Array DIM2 a)}
(#*#) :: forall k m n a. (KnownNat k, KnownNat m, KnownNat n, IsNum a, Elt a) =>
    Matrix k m a -> Matrix m n a -> Matrix k n a
 v #*# w = let v' = unMatrix v
               w' = unMatrix w
           in AccMatrix $ A.generate (A.index2 k' n') undefined
          where k' = fromInteger $ natVal (Proxy :: Proxy k)
                n' = fromInteger $ natVal (Proxy :: Proxy n)
                aux :: Acc (Array (FullShape (Z :. Int) :. Int) e) -> Acc (Array (FullShape (Z :. All) :. Int) e) -> Exp ((Z :. Int) :. Int) -> Exp e
                aux v w sh = let (Z:.i:.j) = A.unlift sh
                                 v' = A.slice v (A.lift $ Z:.i:.All)
                                 w' = A.slice w (A.lift $ Z:.All:.j)
                              in A.the $ A.sum $ A.zipWith (*) v' w'

Ошибка stack build дает мне это

.../src/Linear.hs:196:55:
    Couldn't match type ‘A.Plain ((Z :. head0) :. head1)’
                   with ‘(Z :. Int) :. Int’
    The type variables ‘head0’, ‘head1’ are ambiguous
    Expected type: Exp (A.Plain ((Z :. head0) :. head1))
      Actual type: Exp ((Z :. Int) :. Int)
    Relevant bindings include
      i :: head0 (bound at src/Linear.hs:196:38)
      j :: head1 (bound at src/Linear.hs:196:41)
    In the first argument of ‘A.unlift’, namely ‘sh’
    In the expression: A.unlift sh

.../src/Linear.hs:197:47:
    Couldn't match type ‘FullShape (A.Plain (Z :. head0))’
                   with ‘Z :. Int’
    The type variable ‘head0’ is ambiguous
    Expected type: Acc
                     (Array (FullShape (A.Plain (Z :. head0) :. All)) e)
      Actual type: Acc (Array (FullShape (Z :. Int) :. Int) e)
    Relevant bindings include
      v' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
        (bound at src/Linear.hs:197:34)
      i :: head0 (bound at src/Linear.hs:196:38)
    In the first argument of ‘A.slice’, namely ‘v’
    In the expression: A.slice v (A.lift $ Z :. i :. All)

.../src/Linear.hs:198:39:
    Couldn't match type ‘A.SliceShape (A.Plain ((Z :. All) :. head1))’
                   with ‘A.SliceShape (A.Plain (Z :. head0)) :. Int’
    The type variables ‘head0’, ‘head1’ are ambiguous
    Expected type: Acc
                     (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
      Actual type: Acc
                     (Array (A.SliceShape (A.Plain ((Z :. All) :. head1))) e)
    Relevant bindings include
      w' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
        (bound at src/Linear.hs:198:34)
      v' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
        (bound at src/Linear.hs:197:34)
      i :: head0 (bound at src/Linear.hs:196:38)
      j :: head1 (bound at src/Linear.hs:196:41)
    In the expression: A.slice w (A.lift $ Z :. All :. j)
    In an equation for ‘w'’: w' = A.slice w (A.lift $ Z :. All :. j)

.../src/Linear.hs:198:47:
    Couldn't match type ‘FullShape (A.Plain ((Z :. All) :. head1))’
                   with ‘(Z :. Int) :. Int’
    The type variable ‘head1’ is ambiguous
    Expected type: Acc
                     (Array (FullShape (A.Plain ((Z :. All) :. head1))) e)
      Actual type: Acc (Array (FullShape (Z :. All) :. Int) e)
    Relevant bindings include
      j :: head1 (bound at src/Linear.hs:196:41)
    In the first argument of ‘A.slice’, namely ‘w’
    In the expression: A.slice w (A.lift $ Z :. All :. j)

Я ознакомился с документацией Accelerate, и я также читаю ускоренную арифметику, которая имеет аналогичную цель, но не использует TypeLits чтобы утверждать размеры массива / вектора.

Я также попытался сделать ванильную версию (т.е. без моего собственного типа матрицы), на случай, если мои типы были неправильными, что, как я считаю, страдает от того же неправильного представления об использовании slice, Я включил это просто для полноты картины, я могу добавить сообщения об ошибках, но я решил опустить их, так как считаю, что они не связаны с проблемой выше.

(#*#) :: forall a. (IsNum a, Elt a) =>
    Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Maybe (Acc (Array DIM2 a))   
v #*# w = let Z:.k :.m = A.unlift $ A.arrayShape $ I.run v
              Z:.m':.n = A.unlift $ A.arrayShape $ I.run w
           in if m /= m'
                 then Nothing
                 else Just $ AccMatrix $ A.generate (A.index2 k n) (aux v w)
          where aux :: Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Exp DIM2 -> Exp a
                aux v w sh = let (Z:.i:.j) = A.unlift sh
                                 v' = A.slice v (A.lift $ Z:.i:.All)
                                 w' = A.slice w (A.lift $ Z:.All:.j)
                              in A.the $ A.sum $ A.zipWith (*) v' w'

1 ответ

Решение

Ваш код на самом деле правильный. К сожалению, проверка типов не достаточно умна, чтобы понять это, поэтому вы должны помочь:

let (Z:.i:.j) = A.unlift sh

становится

let (Z:.i:.j) = A.unlift sh :: (Z :. Exp Int) :. Exp Int

Критическая вещь здесь заключается в том, что A.unlift :: A.Unlift c e => c (A.Plain e) -> e но A.Plain является связанным типом семейства (и, следовательно, неинъективным), поэтому тип e не может быть определено без сигнатуры типа, и e требуется выбрать экземпляр для использования Unlift c e, Отсюда и ошибки "неоднозначного типа" - это действительно e это неоднозначно.


У вас также есть несвязанная ошибка. aux должен иметь тип

aux :: (IsNum e, Elt e) => ...

или же

aux :: (e ~ a) => ... 

в последнем случае a это одна в типе подписи (#*#) так что у него уже есть ограничения IsNum, Elt

Другие вопросы по тегам