LIME Интерпретация классификации изображений для нескольких входов DNN
Я довольно новичок в Deep Learning, но мне удалось построить многоотраслевую архитектуру классификации изображений, которая дала вполне удовлетворительные результаты.
Не так важно: я работаю над оттоком клиентов KKBox ( https://kaggle.com/c/kkbox-churn-prediction-challenge/data), где я преобразовал поведение клиентов, транзакции и статические данные в тепловые карты и пытаюсь классифицировать основанные на них данные на что.
Сама классификация работает просто отлично. Моя проблема возникает, когда я пытаюсь применить LIME, чтобы увидеть, откуда приходят результаты. При использовании следующего кода: https://marcotcr.github.io/lime/tutorials/Tutorial%20-%20images.html за исключением того, что я использую список входных данных [members[0], Transactions[0],user_logs[0]], я получаю следующую ошибку: AttributeError: у объекта 'list' нет атрибута 'shape'
Что приходит на ум, так это то, что LIME, вероятно, не предназначен для архитектур с несколькими входами, таких как моя. С другой стороны, Microsoft Azure также имеет многоотраслевую архитектуру ( http://www.freepatentsonline.com/20180253637.pdf?fbclid=IwAR1j30etyDGPCmG-QGfb8qaGRysvnS_f5wLnKz-KdwEbp2Gk0_-OBsSepVc и предполагаемый результат (и предполагаемый результат) (и предполагаемый результат их использования) https://www.slideshare.net/FengZhu18/predicting-azure-churn-with-deep-learning-and-explaining-predictions-with-lime).
Я попытался объединить изображения в один вход, но такой подход дает гораздо худшие результаты, чем мульти-вход. LIME работает для этого подхода, хотя (хотя и не так понятно, как при обычном распознавании изображений).
Архитектура DNN:
# Members
members_input = Input(shape=(61,4,3), name='members_input')
x1 = Dropout(0.2)(members_input)
x1 = Conv2D(32, kernel_size = (61,4), padding='valid', activation='relu', strides=1)(x1)
x1 = GlobalMaxPooling2D()(x1)
# Transactions
transactions_input = Input(shape=(61,39,3), name='transactions_input')
x2 = Dropout(0.2)(transactions_input)
x2 = Conv2D(32, kernel_size = (61,1,), padding='valid', activation='relu', strides=1)(x2)
x2 = Conv2D(32, kernel_size = (1,39,), padding='valid', activation='relu', strides=1)(x2)
x2 = GlobalMaxPooling2D()(x2)
# User logs
userlogs_input = Input(shape=(61,7,3), name='userlogs_input')
x3 = Dropout(0.2)(userlogs_input)
x3 = Conv2D(32, kernel_size = (61,1,), padding='valid', activation='relu', strides=1)(x3)
x3 = Conv2D(32, kernel_size = (1,7,), padding='valid', activation='relu', strides=1)(x3)
x3 = GlobalMaxPooling2D()(x3)
# User_logs + Transactions + Members
merged = keras.layers.concatenate([x1,x2,x3]) # Merged layer
out = Dense(2)(merged)
out_2 = Activation('softmax')(out)
model = Model(inputs=[members_input, transactions_input, userlogs_input], outputs=out_2)
model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Попытка использования ЛАЙМ:
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance([members_test[0],transactions_test[0],user_logs_test[0]], model.predict, top_labels=2, hide_color=0, num_samples=1000)
Краткое изложение модели:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
transactions_input (InputLayer) (None, 61, 39, 3) 0
__________________________________________________________________________________________________
userlogs_input (InputLayer) (None, 61, 7, 3) 0
__________________________________________________________________________________________________
members_input (InputLayer) (None, 61, 4, 3) 0
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 61, 39, 3) 0 transactions_input[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 61, 7, 3) 0 userlogs_input[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 61, 4, 3) 0 members_input[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 1, 39, 32) 5888 dropout_2[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 1, 7, 32) 5888 dropout_3[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 1, 1, 32) 23456 dropout_1[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 1, 1, 32) 39968 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 1, 1, 32) 7200 conv2d_4[0][0]
__________________________________________________________________________________________________
global_max_pooling2d_1 (GlobalM (None, 32) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
global_max_pooling2d_2 (GlobalM (None, 32) 0 conv2d_3[0][0]
__________________________________________________________________________________________________
global_max_pooling2d_3 (GlobalM (None, 32) 0 conv2d_5[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 96) 0 global_max_pooling2d_1[0][0]
global_max_pooling2d_2[0][0]
global_max_pooling2d_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 2) 194 concatenate_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 2) 0 dense_1[0][0]
==================================================================================================
Отсюда мой вопрос: есть ли у кого-нибудь опыт работы с мульти-входной архитектурой DNN и LIME? Есть ли обходной путь, которого я не вижу? Есть ли другая интерпретируемая модель, которую я мог бы использовать?
Спасибо.