Обучение CNN на ТПУ с использованием Google Colab

Мне нужно точно настроить CNN (vgg16) для большого набора данных изображений. Я использую Google Colab, и мне нужно использовать TPU для ускорения обучения. Как указано в примере документации, которую я использую tf.keras вместо kerasПосле создания модели, замены полностью связанных слоев и указания обучаемых я скомпилировал модель, используя categorical_crossentropy Функция потерь, и создал генератор поездов:

model.compile(loss='categorical_crossentropy', 
          optimizer= tensorflow.train.RMSPropOptimizer(learning_rate=1e-4),
          metrics = ['accuracy'])
train_data_gen = ImageDataGenerator(rescale=1./255,rotation_range = 20, 
                                width_shift_range = 0.2, 
                                height_shift_range = 0.2, 
                                horizontal_flip = True)
train_gen= train_data_gen.flow_from_directory('/content/drive/My Drive/data/train', 
                                          target_size=(224, 224), 
                                          batch_size = 1024, 
                                          class_mode='categorical' )

Затем я преобразовал модель в совместимую модель TPU с тензорным потоком:

tf.logging.set_verbosity(tf.logging.INFO)
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model,
strategy=tf.contrib.tpu.TPUDistributionStrategy(
    tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))

Когда я пытаюсь тренировать модель, используя следующую команду:

history = tpu_model.fit_generator(train_gen, 
                    steps_per_epoch=train_gen.samples/train_gen.batch_size, 
                    epochs=1000, 
                    verbose=2, 
                    shuffle= False)

Я получаю это сообщение об ошибке:

RuntimeError: Compilation failed: Compilation failure: 
Detected unsupported operations when trying to compile graph 
cluster_4125121914893370080[] on XLA_TPU_JIT: Placeholder 
(No registered 'Placeholder' OpKernel for XLA_TPU_JIT devices compatible          
with node {{node tpu_140154018445800/input_1}} Registered:  
device='TPU'
device='CPU'
device='GPU'
device='XLA_CPU'
){{node tpu_140154018445800/input_1}}

Может кто-нибудь помочь мне решить это?

0 ответов

Другие вопросы по тегам