How to save a Keras model trained on a TPU?

I'm using Colab to experiment with an LSTM model, but I can't save the trained model.

import tensorflow as tf  # TF 1.x: tf.contrib was removed in TF 2.x

sess = tf.keras.backend.get_session()

# lstm_model, training_generator and TPU_WORKER are defined elsewhere in the notebook
training_model = lstm_model(seq_len=100, batch_size=128, stateful=False)

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    training_model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))

tpu_model.fit_generator(
    training_generator(seq_len=100, batch_size=1024),
    steps_per_epoch=100,
    epochs=4)

export_path = '/content/output/'
tf.saved_model.simple_save(
    sess,
    export_path,
    inputs={'input_image': tpu_model.input},
    outputs={t.name: t for t in tpu_model.outputs})

Here is the exception:

FailedPreconditionError                   Traceback (most recent call last)
<ipython-input-13-020e67d3772b> in <module>()
 29         export_path,
 30         inputs={'input_image': tpu_model.input},
---> 31         outputs={t.name: t for t in tpu_model.outputs})


...skipped....

FailedPreconditionError: Error while reading resource variable. This could mean that the variable was uninitialized. Not found: Resource worker/TFOptimizer/iterations/N10tensorflow3VarE does not exist.
 [[{{node ReadVariables_8976001795006639924/_2}} = _ReadVariablesOp[N=40, dtypes=[DT_INT64, DT_INT64, DT_INT64, DT_INT64, DT_INT64, ..., DT_FLOAT, DT_INT64, DT_INT64, DT_INT64, DT_INT64], _device="/job:worker/replica:0/task:0/device:CPU:0"](VarHandles_14315951673884632260/_0, VarHandles_14315951673884632260/_0:1, VarHandles_14315951673884632260/_0:2, VarHandles_14315951673884632260/_0:3, VarHandles_14315951673884632260/_0:4, VarHandles_14315951673884632260/_0:5, VarHandles_14315951673884632260/_0:6, VarHandles_14315951673884632260/_0:7, VarHandles_14315951673884632260/_0:8, VarHandles_14315951673884632260/_0:9, VarHandles_14315951673884632260/_0:10, VarHandles_14315951673884632260/_0:11, VarHandles_14315951673884632260/_0:12, VarHandles_14315951673884632260/_0:13, VarHandles_14315951673884632260/_0:14, VarHandles_14315951673884632260/_0:15, VarHandles_14315951673884632260/_0:16, VarHandles_14315951673884632260/_0:17, VarHandles_14315951673884632260/_0:18, VarHandles_14315951673884632260/_0:19, VarHandles_14315951673884632260/_0:20, VarHandles_14315951673884632260/_0:21, VarHandles_14315951673884632260/_0:22, VarHandles_14315951673884632260/_0:23, VarHandles_14315951673884632260/_0:24, VarHandles_14315951673884632260/_0:25, VarHandles_14315951673884632260/_0:26, VarHandles_14315951673884632260/_0:27, VarHandles_14315951673884632260/_0:28, VarHandles_14315951673884632260/_0:29, VarHandles_14315951673884632260/_0:30, VarHandles_14315951673884632260/_0:31, VarHandles_14315951...
 [[{{node ReadVariables_16894311020792346126/_3_G1412}} = _Send[T=DT_FLOAT, client_terminated=false, recv_device="/job:worker/replica:0/task:0/device:CPU:0", send_device="/job:worker/replica:0/task:0/device:TPU:0", send_device_incarnation=8311516724619575166, tensor_name="edge_133_ReadVariables_16894311020792346126/_3", _device="/job:worker/replica:0/task:0/device:TPU:0"](ReadVariables_16894311020792346126/_3:8)]]

Please advise.

2 Answers

Does it work if you replace your tf.saved_model.simple_save() call with something like

tpu_model.save_weights(os.path.join(export_path, 'weights.h5'), overwrite=True)

as in https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpu_and_keras.ipynb?

(this example and other links are listed at the bottom of https://colab.research.google.com/notebooks/tpu.ipynb)
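As a hedged illustration of the weights-only route, here is a minimal CPU-only sketch (no TPU involved), assuming TF 2.x with h5py installed. build_model() is a hypothetical stand-in for lstm_model(); the idea is to write the trained weights to an HDF5 file and load them into a freshly built model with the same architecture, then check that both give the same predictions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for lstm_model(): a tiny LSTM over sequences
    # of 10 integer tokens drawn from a 50-symbol vocabulary.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Embedding(input_dim=50, output_dim=8),
        tf.keras.layers.LSTM(16),
        tf.keras.layers.Dense(50, activation="softmax"),
    ])


# The "trained" model whose weights we want to keep.
trained = build_model()
x = np.random.randint(0, 50, size=(4, 10))
expected = trained.predict(x, verbose=0)

# Save only the weights; the ".weights.h5" suffix is required by Keras 3,
# and older tf.keras treats any ".h5" file name as HDF5.
weights_path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
trained.save_weights(weights_path, overwrite=True)

# Rebuild the same architecture and load the saved weights into it.
restored = build_model()
restored.load_weights(weights_path)
actual = restored.predict(x, verbose=0)

print(np.allclose(expected, actual))
```

Because only weights are stored, the model-building code has to be available when loading; the rebuilt model can also use different settings (for example a smaller batch size for inference) as long as the layer weights have the same shapes.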

You can simply try this for the SavedModel format:

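import tensorflow as tf

# connect to the TPU and create a distribution strategy (TF 2.x)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# make model (make_model() is your own model-building function)
with strategy.scope():
    model = make_model()

# save model locally from TPU using TensorFlow's "SavedModel" format
save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save('./model', options=save_locally)

# load model on TPU using TensorFlow's "SavedModel" format
with strategy.scope():
    load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    model = tf.keras.models.load_model('./model', options=load_locally)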
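To try the SaveOptions/LoadOptions round trip without a TPU, here is a minimal sketch assuming TF 2.3 or later. It uses a plain tf.Module (the hypothetical Scaler class) instead of a Keras model so it behaves the same across Keras versions. experimental_io_device='/job:localhost' tells TensorFlow which device should access the filesystem; in a distributed setup such as a Colab TPU, that is the local host where './model' lives, rather than the remote worker that owns the variables. On a single machine the local job is named "localhost", so the same options work there too.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy module with one trainable variable."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return self.w * x


module = Scaler()
export_dir = tempfile.mkdtemp()

# Perform the file writes from the local host's device.
save_opts = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
tf.saved_model.save(module, export_dir, options=save_opts)

# Perform the file reads from the local host's device as well.
load_opts = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
restored = tf.saved_model.load(export_dir, options=load_opts)

result = restored(tf.constant([1.0, 2.0, 3.0]))
print(result.numpy())
```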