Проблема tf.keras: пользовательский слой «gradient_reversal layer» вызывал ошибку

Я пытаюсь реализовать сеть состязательного обучения домена (DANN), используя тензорный поток 2.3.0 со встроенными слоями keras. Очень не повезло, у меня было много ошибок, которые, возможно, были вызваны версией tf или keras, что для меня является чрезвычайно сложной задачей.

Во-первых, я пытаюсь добавить настраиваемый слой обращения градиента в начало классификатора «домена»:

      # Domain Adaptation section - classify the domain
def domain_model(encoder):
    flip_layer = GradientReversal(hp_lambda=1000) # custom layer that reverse the gradient
    dann_in = flip_layer(encoder.output) # if I remove this layer, everthing works well
    domain_classifier = Flatten(name="do4")(dann_in)
    domain_classifier = BatchNormalization(name="do5")(domain_classifier)
    domain_classifier = Activation("relu", name="do6")(domain_classifier)
    domain_classifier = Dropout(0.5)(domain_classifier)
    domain_classifier = Dense(64, activation='softmax', name="do7")(domain_classifier)
    domain_classifier = Activation("relu", name="do8")(domain_classifier)
    dann_out = Dense(2, activation='softmax', name="domain")(domain_classifier)
    domain_classification_model = Model(inputs=encoder.input, outputs=dann_out)
    return domain_classification_model

Слой обращения градиента записывается так:

      @tf.custom_gradient # if I comment this, the function will not be called
def reverse_gradient(X, hp_lambda):
    """Flips the sign of the incoming gradient during training."""
    try:
        reverse_gradient.num_calls += 1
    except AttributeError:
        reverse_gradient.num_calls = 1

    grad_name = "GradientReversal%d" % reverse_gradient.num_calls

    @ops.RegisterGradient(grad_name)
    def _flip_gradients(grad):
        return [tf.negative(grad) * hp_lambda]

    with tf.Graph().as_default() as g:
        with g.gradient_override_map({'Identity': grad_name}):
            y = tf.identity(X)
    return y


class GradientReversal(Layer):
"""Layer that flips the sign of gradient during training."""

def __init__(self, hp_lambda, **kwargs):
    super(GradientReversal, self).__init__(**kwargs)
    self.supports_masking = True
    self.hp_lambda = hp_lambda

@staticmethod
def get_output_shape_for(input_shape):
    return input_shape

def build(self, input_shape):
    self._trainable_weights = [] # I'm not sure if it's necessary.
    #self.trainable_weights = [] 

def call(self, x, mask=None):
    return reverse_gradient(x, self.hp_lambda)

def get_config(self):
    config = {}
    base_config = super(GradientReversal, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

Моя сеть состоит из автоэнкодера и классификатора домена, подключенного к encoder.output:

      # domain adversarial learning: ae+domain
self.domain_classification_model = domain_model(self.encoder)
self.domain_classification_model.compile(optimizer="Adam",
                                             loss=['binary_crossentropy'], metrics=['accuracy'])
self.comb_model = Model(inputs=self.autoencoder.input,
                            outputs=[self.autoencoder.output, 
                                     self.domain_classification_model.output])
self.comb_model.compile(optimizer="Adam",
                            loss=['mse', 'binary_crossentropy'], loss_weights=[1, 200], metrics=['accuracy'], )

При построении этой сети возникает ошибка, я почти уверен, что это вызвано моим слоем обращения градиента:

      2021-02-26 22:34:45.115828: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
  File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site- 
packages\tensorflow\python\keras\engine\base_layer_v1.py", line 776, in __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 258, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:

D:\Skill-worker-research\Python code and example data\SupplementarySoftware_DeepHL_python\DeepHL_python\Attention_model.py:268 call  *
    return reverse_gradient(x, self.hp_lambda)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\custom_gradient.py:264 __call__  **
    return self._d(self._f, a, k)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\custom_gradient.py:220 decorated
    return _graph_mode_decorator(wrapped, args, kwargs)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\custom_gradient.py:325 _graph_mode_decorator
    result, grad_fn = f(*args)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:503 __iter__
    self._disallow_iteration()
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:499 _disallow_iteration
    self._disallow_in_graph_mode("iterating over `tf.Tensor`")
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:479 _disallow_in_graph_mode
    " this function with @tf.function.".format(task))

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "D:/Skill-worker-research/Python code and example data/SupplementarySoftware_DeepHL_python/DeepHL_python/train.py", line 111, in train
dtc.initialize()  # compile self.model
File "D:\Skill-worker-research\Python code and example data\SupplementarySoftware_DeepHL_python\DeepHL_python\Attention_model.py", line 453, in initialize
self.domain_classification_model = domain_model(self.encoder)
File "D:\Skill-worker-research\Python code and example data\SupplementarySoftware_DeepHL_python\DeepHL_python\Attention_model.py", line 213, in domain_model
dann_in = flip_layer(encoder.output)
File "D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer_v1.py", line 783, in __call__
str(e) + '\n"""')
TypeError: You are attempting to use Python control flow in a layer that was not declared to be dynamic. Pass `dynamic=True` to the class constructor.
Encountered error:
"""
in user code:

D:\Skill-worker-research\Python code and example data\SupplementarySoftware_DeepHL_python\DeepHL_python\Attention_model.py:268 call  *
    return reverse_gradient(x, self.hp_lambda)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\custom_gradient.py:264 __call__  **
    return self._d(self._f, a, k)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\custom_gradient.py:220 decorated
    return _graph_mode_decorator(wrapped, args, kwargs)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\custom_gradient.py:325 _graph_mode_decorator
    result, grad_fn = f(*args)
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:503 __iter__
    self._disallow_iteration()
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:499 _disallow_iteration
    self._disallow_in_graph_mode("iterating over `tf.Tensor`")
D:\Users\xiqxi\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:479 _disallow_in_graph_mode
    " this function with @tf.function.".format(task))

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

"""

Может быть, потому что некоторые функции, которые я написал, могут работать только в tensorflow_v1, а не в tensorflow_V2? Но я искал во всей сети, только несколько человек пишут слой с переворачивающимся градиентом, используя чистый tensorflow_V2.

Если кто знаком с tensorflow или Keras? Не могли бы вы дать мне какие-нибудь предложения? Буду очень признателен, если вы сможете обсудить со мной.

PS: Мои импортные пакеты следующие:

      import tensorflow as tf
print(tf.__version__)
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout, Activation, TimeDistributed, Flatten, Masking, Embedding, Conv1D, RepeatVector, Permute, Lambda, AveragePooling1D, BatchNormalization, MaxPooling1D, UpSampling1D

from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K

Я сослался на этот код , который пишет в keras.

0 ответов