Используя tf.keras в TF 2.0, как я могу определить пользовательский слой, который зависит от фазы обучения?

Я хочу построить пользовательский слой, используя tf.keras. Для простоты предположим, что он должен возвращать входные данные *2 во время обучения и входные данные *3 во время тестирования. Как правильно это сделать?

Я попробовал этот подход:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training:
            return inputs*2
        else:
            return inputs*3

Затем я могу использовать этот класс следующим образом:

>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)

Работает отлично! Тем не менее, когда я использую этот класс в модели, и я называю его fit() метод, кажется, что training не установлен в True, Я попытался добавить следующий код в начале call() метод, но training всегда равно 0.

if training is None:
    training = K.learning_phase()

Что мне не хватает?

редактировать

Я нашел решение (см. Мой ответ), но я все еще ищу более хорошее решение, используя @tf.function (Я предпочитаю автограф этому smart_cond() бизнес). К сожалению, похоже K.learning_phase() не играет с @tf.function (я думаю, что когда call() Функция отслеживается, фаза обучения жестко запрограммирована в графе: так как это происходит до вызова fit() метод, фаза обучения всегда 0). Это может быть ошибка, или, возможно, есть другой способ получить фазу обучения при использовании @tf.function,

2 ответа

Франсуа Шоле подтвердил, что правильное решение при использовании @tf.function является:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        if training:
            return inputs * 2
        else:
            return inputs * 3

В настоящее время есть ошибка (по состоянию на 15 февраля 2019 года), которая делает training всегда равно 0, но это будет исправлено в ближайшее время.

Следующий код не использует @tf.function, поэтому он не выглядит так хорошо (так как он не использует автограф), но работает нормально:

from tensorflow.python.keras.utils.tf_utils import smart_cond

class CustomLayer(Layer):
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)
Другие вопросы по тегам