Пользовательская функция потери / цели с дополнительным вводом переменных в Keras

Я пытаюсь создать пользовательскую целевую функцию в Keras (бэкэнд тензорного потока) с дополнительным параметром, значение которого будет зависеть от обучаемого пакета.

Например:

def myLoss(self, stateValues):
    def sparse_loss(y_true, y_pred):
        foo = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        return tf.reduce_mean(foo * stateValues)
    return sparse_loss


self.model.compile(loss=self.myLoss(stateValue = self.stateValue),
        optimizer=Adam(lr=self.alpha))

Моя функция поезда заключается в следующем

for batch in batches:
    self.stateValue = computeStateValueVectorForCurrentBatch(batch)
    model.fit(xVals, yVals, batch_size=<num>)

Однако stateValue в функции потерь не обновляется. Он просто использует значение, которое stateValue имеет на шаге model.compile.

Я думаю, что это можно решить с помощью placeHolder для stateValue, но я не могу понять, как это сделать. Может кто-нибудь, пожалуйста, помогите?

1 ответ

Решение

Ваша функция потерь не обновляется, потому что keras не компилирует модель после каждой партии и, следовательно, не использует обновленную функцию потерь.

Вы можете определить пользовательский обратный вызов, который будет обновлять значение потери после каждой партии. Что-то вроде этого:

from keras.callbacks import Callback

class UpdateLoss(Callback):
    def on_batch_end(self, batch, logs={}):
        # I am not sure what is the type of the argument you are passing for computing stateValue ??
        stateValue = computeStateValueVectorForCurrentBatch(batch)
        self.model.loss = myLoss(stateValue)
Другие вопросы по тегам