Сопоставление шаблонов через сиамскую архитектуру CNN + гиперсеть

В настоящее время я пытаюсь создать NN для "сопоставления с шаблоном" (нахождение под-изображения в более крупном) в Keras. Идея основана на документе: http://cs231n.stanford.edu/reports/2017/pdfs/817.pdf Используется сиамская архитектура CNN, в которой одна ветвь предназначена для изображения шаблона, а другая - для сцены. Выходные данные ветви шаблона изменяются и передаются весам (3, 3, 512, 512) последнего слоя conv2D ветви сцены (гиперсети). Выходные данные ветви сцены окончательно изменяются в соответствии с размером входного изображения сцены. Этот вывод представляет собой двоичное изображение, которое должно показывать локализацию шаблона в сцене.

Прогнозирование на шаблоне из обучающего набора и прогнозирование на шаблоне, которого нет в обучающем наборе

В качестве CNN используется VGG11 или VGG16. Гиперсеть реализована следующим образом:

def hypernet(x):
    # -- Reshape x to the weight shape of a conv2D
    #    with kernel=(3,3) and 512 input and output channels => (3, 3, 512, 512)
    x_list = []
    for i in range(512):
        y = Lambda(lambda a: a[:, :, :, i])(x)
        y = Flatten()(y)
        y = Dense(4608, activation="relu", trainable=False)(y)
        x_list.append(y)
    res = Concatenate()(x_list)
    # (1, 2359296,) -> (3, 3, 512, 512)
    res = Reshape((3, 3, 512, 512), name="h_reshape_to_weight")(res)
    # -- Remove the batch dimension
    res = Lambda(lambda a: a[0], name="h_remove_batchdim")(res)
    return res


def create_model(self):
    # -- TEMPLATE LEG
    # -- Define the template input image for the "leg" network
    w_template, h_template = self.t_shape
    input_template = Input(batch_shape=(1, h_template, w_template, 3), dtype=np.float32, name="input")
    # -- Template branch architecture: VGG11 or VGG16, etc.
    xt= leg(LegType.TEMPLATE, self.architecture, shape=(h_template, w_template, 3),
                  input_tensor=input_template)

    # -- HYPERNETWORK
    # -- Reshape the "leg" network, so that its output can be used as the weights of the main network
    xt_output = hypernet(xt)

    # -- SCENE LEG
    # -- Define the scene input image for the main network
    w_scene, h_scene = self.s_shape
    input_scene = Input(batch_shape=(1, h_scene, w_scene, 3), dtype=np.float32, name="input")
    # -- Scene branch architecture: VGG11 or VGG16, etc.
    xs= leg(LegType.SCENE, self.architecture, shape=(h_scene, w_scene, 3), input_tensor=input_scene)

    # -- Feed the output of the leg network as the weights of the main network (Last Conv2D of the CNN)
    # https://stackru.com/questions/56812831/keras-hypernetwork-implementation
    x = layers.Lambda(lambda a: K.conv2d(a[0], a[1], padding="same"), name="weight_input")([xs, xt_output])
    x = layers.Activation(activation="relu", name="last_conv_activation")(x)

    # -- Reshape the output to two binary masks with sizes match the size of the scene image
    x = layers.Reshape((h_scene, w_scene, 2), name="output_reshape")(x)
    output = layers.Activation(activation="softmax", name="output_softmax")(x)

    self.model = Model(inputs=[input_template, input_scene], outputs=output)

Оптимизатор: Adam (lr = 0,001), потеря = category_crossentropy, metric = acc, размер пакета 1, потому что изображения довольно большие (768x768 + 256x384). По тесту он обучен на 500 парах изображений, где сцена остается прежней, а изменяются только изображения шаблона. Сеть способна правильно предсказать шаблон, только если обучена на нем. Когда набор для обучения и проверки основан на разных шаблонах, точность и val_accuracy сходятся к 0,9888, что в основном представляет собой двоичное изображение значений интенсивности = 0.

В настоящее время я изучаю следующие изменения: 1) Использование верхних слоев FCN 2) Использование предварительно обученного VGG16 + замораживание 3) Увеличение регуляризации L2 4) Увеличение отсева 5) Увеличение / уменьшение LR 6) Оптимизатор SGD

Я только начал изучать NN, поэтому, пожалуйста, извините, если я пропустил некоторые важные данные. Есть ли у кого-нибудь представление о том, что я могу сделать, чтобы сеть узнала взаимосвязь между шаблоном и сценой, а не только запомнила конкретный шаблон? Спасибо за ответы заранее!

0 ответов

Другие вопросы по тегам