Используя tf.keras в TF 2.0, как я могу определить пользовательский слой, который зависит от фазы обучения?
Я хочу построить пользовательский слой, используя tf.keras. Для простоты предположим, что он должен возвращать входные данные *2 во время обучения и входные данные *3 во время тестирования. Как правильно это сделать?
Я попробовал этот подход:
class CustomLayer(Layer):
@tf.function
def call(self, inputs, training=None):
if training:
return inputs*2
else:
return inputs*3
Затем я могу использовать этот класс следующим образом:
>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)
Работает отлично! Тем не менее, когда я использую этот класс в модели, и я называю его fit()
метод, кажется, что training
не установлен в True
, Я попытался добавить следующий код в начале call()
метод, но training
всегда равно 0.
if training is None:
training = K.learning_phase()
Что мне не хватает?
редактировать
Я нашел решение (см. Мой ответ), но я все еще ищу более хорошее решение, используя @tf.function
(Я предпочитаю автограф этому smart_cond()
бизнес). К сожалению, похоже K.learning_phase()
не играет с @tf.function
(я думаю, что когда call()
Функция отслеживается, фаза обучения жестко запрограммирована в графе: так как это происходит до вызова fit()
метод, фаза обучения всегда 0). Это может быть ошибка, или, возможно, есть другой способ получить фазу обучения при использовании @tf.function
,
2 ответа
Франсуа Шоле подтвердил, что правильное решение при использовании @tf.function
является:
class CustomLayer(Layer):
@tf.function
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
if training:
return inputs * 2
else:
return inputs * 3
В настоящее время есть ошибка (по состоянию на 15 февраля 2019 года), которая делает training
всегда равно 0
, но это будет исправлено в ближайшее время.
Следующий код не использует @tf.function
, поэтому он не выглядит так хорошо (так как он не использует автограф), но работает нормально:
from tensorflow.python.keras.utils.tf_utils import smart_cond
class CustomLayer(Layer):
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)