Реализовать трансферное обучение по niftynet

Я хочу реализовать трансферное обучение с использованием архитектуры Dense V-Net. Когда я искал, как это сделать, я обнаружил, что эта функция в настоящее время работает ( Как я могу реализовать трансферное обучение в NiftyNet?).

Хотя из этого ответа совершенно ясно, что не существует прямого способа его реализации, я пытался:

1) Создайте Dense V-Net

2) Восстановите вес из файла.ckpt

3) Внедрить трансферное обучение самостоятельно

Для выполнения шага 1 я подумал, что смогу использовать модуль niftynet.network.dense_vnet. Поэтому я попробовал следующее:

checkpoint = '/path_to_ckpt/model.ckpt-3000.index'
x = tf.placeholder(dtype=tf.float32, shape=[None,1,144,144,144])
architecture_parameters = dict(
    use_bdo=False,
    use_prior=False,
    use_dense_connections=True,
    use_coords=False)

hyperparameters = dict(
    prior_size=12,
    n_dense_channels=(4, 8, 16),
    n_seg_channels=(12, 24, 24),
    n_input_channels=(24, 24, 24),
    dilation_rates=([1] * 5, [1] * 10, [1] * 10),
    final_kernel=3,
    augmentation_scale=0)
model_instance = DenseVNet(num_classes=9,hyperparameters=hyperparameters,
                             architecture_parameters=architecture_parameters)

model_net = DenseVNet.layer_op(model_instance, x)

Однако я получаю следующую ошибку:

TypeError: Failed to convert object of type <type 'list'> to Tensor. Contents: [None, 1, 72, 72, 24]. Consider casting elements to a supported type.

Итак, вопрос в следующем:

Есть ли способ реализовать это?

2 ответа

Трансферное обучение было добавлено в NiftyNet.

Вы можете выбрать, какие переменные вы хотите восстановить через vars_to_restore параметр конфигурации и какие переменные заморозить через vars_to_freeze параметр конфигурации

Смотрите здесь для получения дополнительной информации.

Простое обучение переносу может быть достигнуто с восстановлением весов из существующей модели так, как вы устанавливаете параметр starting_iter в [TRAINING] раздел вашего конфигурационного файла на номер предварительно обученной модели. В вашем примере starting_iter=3000,

Это восстановит веса из вашей модели, и с этой инициализацией начнутся новые итерации.

Здесь архитектура вашей модели должна быть точно такой же, иначе вы получите ошибку.

Для более сложного обучения переносу или, возможно, также для тонкой настройки, когда вы можете восстановить только часть весов, здесь есть отличная реализация. Возможно, он скоро будет объединен с официальным репозиторием niftynet, но вы уже можете его использовать.

Другие вопросы по тегам