Как сделать BatchNormalisation обучаемым, когда все остальные промежуточные уровни (в случае ResNet,DenseNet) зависают?
from keras.applications.densenet import DenseNet201
conv_base = DenseNet201(weights= 'imagenet', include_top=False, input_shape= (200,200,3))
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(3, activation='softmax'))
conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
if layer.name == 'conv4_block40_0_bn':
set_trainable = True
if set_trainable:
layer.trainable = True
else:
layer.trainable = False
Я использую его для замораживания верхних слоев DenseNet201. Теперь, как добавить или удалить BatchNormalisation/Dropout из предварительно обученной (DenseNet201) модели. Как сделать обучаемым или нет из любого места, где он присутствует?Как должен выглядеть код?
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras import optimizers
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(200,200),
batch_size=20,
class_mode='categorical',
shuffle = True
)
validation_generator = test_datagen.flow_from_directory(
test_dir,
target_size=(200,200),
batch_size=20,
class_mode='categorical',
shuffle = True)
checkpoint = ModelCheckpoint('model-{epoch:03d}-{acc:03f}-{val_acc:03f}.h5', verbose=1, monitor='val_loss',save_best_only=True, mode='auto')
model.compile(loss='categorical_crossentropy',optimizer=optimizers.RMSprop(lr=2e-6),metrics=['acc'])
history = model.fit_generator(
train_generator,
steps_per_epoch=1180//20,
epochs=100,
validation_data=validation_generator,
validation_steps=290//20,
callbacks=[checkpoint]
)
Я использую ImageDataGenerator для получения пакетных данных. Я новичок в глубоком обучении.