Керас выбирает случайную ветвь

Я хочу случайным образом выбирать между двумя ветвями сети во время обучения и выполнения. Я хочу быть в состоянии сделать многие из этих случайных выборов.

Моя мотивация состоит в том, чтобы исследовать сходства между заменяемыми частями сети после того, как вся система сходится. В конце концов, я бы хотел, чтобы моя система выглядела так: введите описание изображения здесь

Я пробовал этот подход, но, похоже, есть проблема с состоянием K.switch заявление:

# can ignore details of this function
def make_first_half(input):
    m = Conv2D(16,
               kernel_size=(3, 3),
               activation='relu',
               input_shape=input_shape)(input)
    m = Conv2D(32, (3, 3), activation='relu')(m)
    m = MaxPooling2D(pool_size=(2, 2))(m)
    m = Flatten()(m)
    return m

# can ignore details of this function
def make_second_half(input):
    m = Dense(32, activation='relu')(input)
    m = Dense(num_classes, activation='softmax')(m)
    return m


input = Input(shape=input_shape)

a1 = make_first_half(input)
a2 = make_first_half(input)

a = K.switch(
    lambda x: 1 == random.randint(0, 1),
    a1, a2)

prediction = make_second_half(a)

Я думаю, что я понимаю, что не так: состояние K.switch должен быть тензором вместо функции, возвращающей логическое значение. Тем не менее, я не уверен, как я могу сделать тензор, который является случайным для каждого раза, когда он достигает этой точки. То есть я не хочу случайным образом выбирать путь в самом начале, а затем каждый раз идти по этому пути, игнорируя другой.

Я также предпочел бы проводить только одну, а не несколько тренировок (т.е. только одну Model заявление, один compile заявление, и один fit заявление).

Я использую бэкэнд TensorFlow.

Вполне возможно, что я думаю об этом неправильно. Пожалуйста, объясните, как я могу добиться этого случайного ветвления.

0 ответов

Другие вопросы по тегам