Обновление пользовательских выходных слоев сети LSTM

У меня есть задача по генерации текста: научиться предсказывать следующее слово с помощью сети LSTM с несколькими выходными слоями. После завершения генерации предложения я рассчитываю вознаграждение за все предложение и пытаюсь обновить выходные слои, участвовавшие в генерации (содействующие слои получают рассчитанное значение вознаграждения, остальные получают 0). Моя проблема в том, что даже если я обновляю только выбранные выходные слои, кажется, что вместо этого обновляются веса других слоев.

У меня есть минимизированный пример с фиктивными данными, чтобы представить проблему:

      import random

import numpy as np
import tensorflow as tf

from keras.layers import Input, LSTM, Dense, Embedding
from keras.utils import pad_sequences
from tensorflow.keras.models import Model


def policy_gradient_loss(y_true, y_pred):
    return tf.reduce_mean(tf.math.log(y_pred) * float(y_true))

# Define the model with 3 output layers (named 'a', 'b' and 'c').
input_layer = Input(shape=(4,))
embedding_layer = Embedding(input_dim=10, output_dim=4)(input_layer)
lstm_layer = LSTM(4)(embedding_layer)
output_layers = [Dense(3, activation='softmax', name=name)(lstm_layer) for name in ['a', 'b', 'c']]
model = Model(inputs=input_layer, outputs=output_layers)
model.compile(loss=[policy_gradient_loss] * 3, optimizer='adam', run_eagerly=True)

# Dummy input data.
input_data = np.array([[2, 3, 4, 5]])

# Create target data to reward only the 'b' output layer.
target_data = [np.array([0]) for _ in range(len(model._output_layers))]
target_data[1] = np.array([10]) 

# Save initial weights.
initial_weights = model.get_weights()

model.train_on_batch(input_data, y=target_data)

# Save weights after the learning.
updated_weights = model.get_weights()

# Compare the before-after weights.
for layer_idx, (layer_name, initial_w, updated_w) in enumerate(zip([layer.name for layer in model.layers], initial_weights, updated_weights)):
    if not tf.math.reduce_all(tf.equal(initial_w, updated_w)):
        print(f'The weights in layer {layer_idx} ({layer_name}). has changed.')

Результат:

      The weights in layer 0 (input_1). has changed.
The weights in layer 1 (embedding). has changed.
The weights in layer 2 (lstm). has changed.
The weights in layer 3 (a). has changed.

Я ожидаю обновления слоя 4. (выходной слой «b») вместо слоя «a» (или, по крайней мере, рядом с «a»).

Что мне не хватает? Мое ожидание или моя реализация неверны? (Или оба...?)

0 ответов

Другие вопросы по тегам