Передача обучения / переподготовки с помощью TensorFlow Estimators
Мне не удалось выяснить, как использовать трансферное обучение / последний уровень переподготовки с новым API TF Estimator.
Estimator
требует model_fn
который содержит архитектуру сети, а также обучающие и оценочные операции, как определено в документации. Пример model_fn
с использованием архитектуры CNN здесь.
Если я хочу переобучить последний слой, например, начальной архитектуры, я не уверен, нужно ли мне указывать всю модель в этом model_fn
, затем загрузите предварительно обученные веса, или есть ли способ использовать сохраненный график, как это делается в "традиционном" подходе (пример здесь).
Это было поднято как проблема, но все еще открыто, и ответы для меня неясны.
1 ответ
Можно загрузить метаграф во время определения модели и использовать SessionRunHook для загрузки весов из файла ckpt.
def model(features, labels, mode, params):
# Create the graph here
return tf.estimator.EstimatorSpec(mode,
predictions,
loss,
train_op,
training_hooks=[RestoreHook()])
SessionRunHook может быть:
class RestoreHook(tf.train.SessionRunHook):
def after_create_session(self, session, coord=None):
if session.run(tf.train.get_or_create_global_step()) == 0:
# load weights here
Таким образом, веса загружаются на первом этапе и сохраняются во время обучения в контрольных точках модели.