Передача обучения / переподготовки с помощью TensorFlow Estimators

Мне не удалось выяснить, как использовать трансферное обучение / последний уровень переподготовки с новым API TF Estimator.

Estimator требует model_fn который содержит архитектуру сети, а также обучающие и оценочные операции, как определено в документации. Пример model_fn с использованием архитектуры CNN здесь.

Если я хочу переобучить последний слой, например, начальной архитектуры, я не уверен, нужно ли мне указывать всю модель в этом model_fn, затем загрузите предварительно обученные веса, или есть ли способ использовать сохраненный график, как это делается в "традиционном" подходе (пример здесь).

Это было поднято как проблема, но все еще открыто, и ответы для меня неясны.

1 ответ

Решение

Можно загрузить метаграф во время определения модели и использовать SessionRunHook для загрузки весов из файла ckpt.

def model(features, labels, mode, params):
    # Create the graph here

    return tf.estimator.EstimatorSpec(mode, 
            predictions,
            loss,
            train_op,
            training_hooks=[RestoreHook()])

SessionRunHook может быть:

class RestoreHook(tf.train.SessionRunHook):

    def after_create_session(self, session, coord=None):
        if session.run(tf.train.get_or_create_global_step()) == 0:
            # load weights here

Таким образом, веса загружаются на первом этапе и сохраняются во время обучения в контрольных точках модели.

Другие вопросы по тегам