Есть ли достойный обходной путь для сохранения контрольных точек на локальном диске при использовании TPU в Tensorflow?

Ответ на этот вопрос:

Как сохранить файл контрольной точки Tensorflow из Google Colab Laboratory в режиме TPU?

Официальный способ сохранения контрольной точки при использовании TPU Tensorflow - использование облачной службы Google.

Я работаю, если есть обходной путь для тех, кто не хочет использовать GCS. Возможно, для каждой переменной сделайте.eval(), сохраните переменную. И затем установите переменную сохранения в значение 'init' для каждой переменной.

Однако я предвижу главную проблему - сохранение и загрузка параметров для оптимизаторов.

Для Кераса веса действительно сохраняются из ТПУ в локальные

https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpu_and_keras.ipynb

ИНФОРМАЦИЯ: tenorflow: копирование весов TPU в CPU

Итак, я думаю, что есть общий обходной путь, без использования керас.

2 ответа

Посмотрите на этот код от Keras

Если я правильно понял, веса не сохраняются напрямую из TPU, вместо этого веса синхронизируются с процессором и сохраняются в хранилище colab.

Я только что нашел решение ниже, увидев этот поток, поэтому я хотел добавить эту опцию. Из документации по тензорному потоку есть optionполе, которое вы можете использовать в функциях сохранения/загрузки/восстановления в keras, а также tf.train.Checkpointи метод сохранения tf.train.CheckpointManagerкоторые позволяют вам использовать экспериментальную стратегию синхронизации с локальным хостом.

Копирование их примера кода:

      model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

Источники документации:

Другие вопросы по тегам