Pytorch LSTM против LSTMCell

В чем разница между LSTM и LSTMCell в Pytorch (в настоящее время версия 1.1)? Кажется, что LSTMCell - это особый случай LSTM (то есть только с одним слоем, однонаправленный, без выпадения).

Тогда какова цель наличия обеих реализаций? Если я что-то не упустил, тривиально использовать объект LSTM в качестве LSTMCell (или, наоборот, довольно просто использовать несколько LSTMCell для создания объекта LSTM)

2 ответа

Решение

Да, вы можете подражать друг другу, причина их разделения - эффективность.

LSTMCell это ячейка, которая принимает аргументы:

  • Ввод партии формы × размер ввода;
  • Кортеж скрытых состояний LSTM формы партии x скрытых размеров.

Это простая реализация уравнений.

LSTM это слой, применяющий ячейку LSTM (или несколько ячеек LSTM) в цикле for, но цикл сильно оптимизирован с использованием cuDNN. Его вклад

  • Трехмерный тензор входных данных формы партии × длина ввода × размер ввода;
  • Необязательно, начальное состояние LSTM, то есть кортеж скрытых состояний пакета формы × скрытое затемнение (или кортеж таких кортежей, если LSTM является двунаправленным)

Вам часто может потребоваться использовать ячейку LSTM в другом контексте, чем применять ее к последовательности, то есть создать LSTM, работающий по древовидной структуре. Когда вы пишете декодер в моделях последовательности-последовательности, вы также вызываете ячейку в цикле и останавливаете цикл при декодировании символа конца последовательности.

Приведу несколько конкретных примеров:

      # LSTM example:
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
# LSTMCell example:
>>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
        hx, cx = rnn(input[i], (hx, cx))
        output.append(hx)

Ключевое отличие:

  1. LSTM: аргумент 2, стоит num_layers, количество повторяющихся слоев. Есть seq_len * num_layers=5 * 2клетки. Нет петли, но больше ячеек.
  2. LSTMCell: в for петля ( seq_len=5 раз), каждый выход ith экземпляр будет вводом (i+1)thпример. Есть только одна ячейка, действительно рекуррентная.

Если мы установим num_layers=1 в LSTM или добавьте еще один LSTMCell, коды выше будут такими же.

Очевидно, что в LSTM проще применять параллельные вычисления.

Другие вопросы по тегам