Как загрузить обученную модель protobuf TF1 в TF2?

Я создал и обучил модель с использованием stable-baselines, в которой используется Tensorflow 1. Теперь мне нужно использовать эту обученную модель в среде, где у меня есть доступ только к Tensorflow 2 или PyTorch. Я решил, что выберу Tensorflow 2, поскольку в документации сказано, что я смогу загружать модели, созданные с помощью Tensorflow 1.

Я могу без проблем загрузить pb-файл в Tensorflow 1:

global_session = tf.Session()

with global_session.as_default():
    model_loaded = tf.saved_model.load_v2('tensorflow_model')
    model_loaded = model_loaded.signatures['serving_default']

init = tf.global_variables_initializer()
global_session.run(init)

Однако в Tensorflow 2 я получаю следующую ошибку:

can_be_imported = tf.saved_model.contains_saved_model('tensorflow_model')
assert(can_be_imported)
model_loaded = tf.saved_model.load('tensorflow_model/')

ValueError: Node 'loss/gradients/model/batch_normalization_3/FusedBatchNormV3_1_grad/FusedBatchNormGradV3' has an _output_shapes attribute inconsistent with the GraphDef for output #3: Dimension 0 in both shapes must be equal, but are 0 and 64. Shapes are [0] and [64].

Определение модели:

NUM_CHANNELS = 64

BN1 = BatchNormalization()
BN2 = BatchNormalization()
BN3 = BatchNormalization()
BN4 = BatchNormalization()
BN5 = BatchNormalization()
BN6 = BatchNormalization()
CONV1 = Conv2D(NUM_CHANNELS, kernel_size=3, strides=1, padding='same')
CONV2 = Conv2D(NUM_CHANNELS, kernel_size=3, strides=1, padding='same')
CONV3 = Conv2D(NUM_CHANNELS, kernel_size=3, strides=1)
CONV4 = Conv2D(NUM_CHANNELS, kernel_size=3, strides=1)
FC1 = Dense(128)
FC2 = Dense(64)
FC3 = Dense(7)

def modified_cnn(inputs, **kwargs):
    relu = tf.nn.relu
    log_softmax = tf.nn.log_softmax
    
    layer_1_out = relu(BN1(CONV1(inputs)))
    layer_2_out = relu(BN2(CONV2(layer_1_out)))
    layer_3_out = relu(BN3(CONV3(layer_2_out)))
    layer_4_out = relu(BN4(CONV4(layer_3_out)))
    
    flattened = tf.reshape(layer_4_out, [-1, NUM_CHANNELS * 3 * 2]) 
    
    layer_5_out = relu(BN5(FC1(flattened)))
    layer_6_out = relu(BN6(FC2(layer_5_out)))
    
    return log_softmax(FC3(layer_6_out))

class CustomCnnPolicy(CnnPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomCnnPolicy, self).__init__(*args, **kwargs, cnn_extractor=modified_cnn)

model = PPO2(CustomCnnPolicy, env, verbose=1)

Сохранение модели в TF1:

with model.graph.as_default():
    tf.saved_model.simple_save(model.sess, 'tensorflow_model', inputs={"obs": model.act_model.obs_ph},
                                   outputs={"action": model.act_model._policy_proba})

Полностью воспроизводимый код можно найти в следующих двух блокнотах google colab: сохранение и загрузка Tensorflow 1 и загрузка Tensorflow 2

Прямая ссылка на сохраненную модель: модель

1 ответ

Решение

Вы можете использовать уровень совместимости TensorFlow.

Все v1 функциональность доступна под tf.compat.v1 пространство имен.

Мне удалось загрузить вашу модель в TF 2.1 (ничего особенного в этой версии, просто она у меня локально):

import tensorflow as tf

tf.__version__
Out[2]: '2.1.0'

model = tf.compat.v1.saved_model.load_v2('~/tmp/tensorflow_model')

model.signatures
Out[3]: _SignatureMap({'serving_default': <tensorflow.python.eager.wrap_function.WrappedFunction object at 0x7ff9244a6908>})
Другие вопросы по тегам