Пользовательская функция потери / цели с дополнительным вводом переменных в Keras
Я пытаюсь создать пользовательскую целевую функцию в Keras (бэкэнд тензорного потока) с дополнительным параметром, значение которого будет зависеть от обучаемого пакета.
Например:
def myLoss(self, stateValues):
def sparse_loss(y_true, y_pred):
foo = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
return tf.reduce_mean(foo * stateValues)
return sparse_loss
self.model.compile(loss=self.myLoss(stateValue = self.stateValue),
optimizer=Adam(lr=self.alpha))
Моя функция поезда заключается в следующем
for batch in batches:
self.stateValue = computeStateValueVectorForCurrentBatch(batch)
model.fit(xVals, yVals, batch_size=<num>)
Однако stateValue в функции потерь не обновляется. Он просто использует значение, которое stateValue имеет на шаге model.compile.
Я думаю, что это можно решить с помощью placeHolder для stateValue, но я не могу понять, как это сделать. Может кто-нибудь, пожалуйста, помогите?
1 ответ
Ваша функция потерь не обновляется, потому что keras не компилирует модель после каждой партии и, следовательно, не использует обновленную функцию потерь.
Вы можете определить пользовательский обратный вызов, который будет обновлять значение потери после каждой партии. Что-то вроде этого:
from keras.callbacks import Callback
class UpdateLoss(Callback):
def on_batch_end(self, batch, logs={}):
# I am not sure what is the type of the argument you are passing for computing stateValue ??
stateValue = computeStateValueVectorForCurrentBatch(batch)
self.model.loss = myLoss(stateValue)