Как использовать опцию Trax SelfAttention с несколькими головками?

Я играю с моделью из библиотеки Самовниманияtrax .

когда я установил n_heads=1, все отлично работает. Но когда я установил n_heads=2, мой код ломается.

Я использую только активации ввода и один слой SelfAttention.

Вот минимальный код:

      import trax
import numpy as np

attention = trax.layers.SelfAttention(n_heads=2)

activations = np.random.randint(0, 10, (1, 100, 1)).astype(np.float32)
input = (activations, )

init = attention.init(input)

output = attention(input)

Но у меня ошибка:

       File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/layers/research/efficient_attention.py, line 1637, in forward_unbatched_h
    return forward_unbatched(*i_h, weights=w_h, state=s_h)

  File [...]/layers/research/efficient_attention.py, line 1175, in forward_unbatched
    q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32)

IndexError: tuple index out of range

Что я делаю не так?

0 ответов

Другие вопросы по тегам