Pytorch LSTM против LSTMCell
В чем разница между LSTM и LSTMCell в Pytorch (в настоящее время версия 1.1)? Кажется, что LSTMCell - это особый случай LSTM (то есть только с одним слоем, однонаправленный, без выпадения).
Тогда какова цель наличия обеих реализаций? Если я что-то не упустил, тривиально использовать объект LSTM в качестве LSTMCell (или, наоборот, довольно просто использовать несколько LSTMCell для создания объекта LSTM)
2 ответа
Да, вы можете подражать друг другу, причина их разделения - эффективность.
LSTMCell
это ячейка, которая принимает аргументы:
- Ввод партии формы × размер ввода;
- Кортеж скрытых состояний LSTM формы партии x скрытых размеров.
Это простая реализация уравнений.
LSTM
это слой, применяющий ячейку LSTM (или несколько ячеек LSTM) в цикле for, но цикл сильно оптимизирован с использованием cuDNN. Его вклад
- Трехмерный тензор входных данных формы партии × длина ввода × размер ввода;
- Необязательно, начальное состояние LSTM, то есть кортеж скрытых состояний пакета формы × скрытое затемнение (или кортеж таких кортежей, если LSTM является двунаправленным)
Вам часто может потребоваться использовать ячейку LSTM в другом контексте, чем применять ее к последовательности, то есть создать LSTM, работающий по древовидной структуре. Когда вы пишете декодер в моделях последовательности-последовательности, вы также вызываете ячейку в цикле и останавливаете цикл при декодировании символа конца последовательности.
Приведу несколько конкретных примеров:
# LSTM example:
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
# LSTMCell example:
>>> rnn = nn.LSTMCell(10, 20)
>>> input = torch.randn(3, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
hx, cx = rnn(input[i], (hx, cx))
output.append(hx)
Ключевое отличие:
- LSTM: аргумент
2
, стоитnum_layers
, количество повторяющихся слоев. Естьseq_len * num_layers=5 * 2
клетки. Нет петли, но больше ячеек. - LSTMCell: в
for
петля (seq_len=5
раз), каждый выходith
экземпляр будет вводом(i+1)th
пример. Есть только одна ячейка, действительно рекуррентная.
Если мы установим
num_layers=1
в LSTM или добавьте еще один LSTMCell, коды выше будут такими же.
Очевидно, что в LSTM проще применять параллельные вычисления.