Обеспечение передачи данных в Keras `fit_generator()` запускается заново с каждой эпохой

Я создал генератор данных для передачи в Керас fit_generator(), но количество данных не всегда точно кратно размеру пакета и steps_per_epochпоэтому, когда эпоха заканчивается, я хочу убедиться, что генератор данных сбрасывается и запускается в начале набора данных (HDF5) еще раз. Мой классификатор ML - LSTM, и это данные временной последовательности, поэтому порядок важен.

Я знаю, что есть on_epoch_end() Функция обратного вызова доступна, когда вы используете генератор данных, встроенный в класс Python, такой как описанный в статье Afshine Amidi. Подробный пример того, как использовать генераторы данных в Keras, но я никогда не смог заставить работать классовую версию генератора. В любом случае, у меня огромный набор данных объемом 20 ГБ, и у меня много проблем с отмиранием моих ядер jupyter и с ошибками памяти, даже в PyCharm, и я предпочитаю по возможности использовать более простую версию генератора данных.

Вот генератор данных, который я передаю в Керас fit_generator(), Обратите внимание на строку кода, которая печатает индексы, используемые для извлечения данных из объекта набора данных HDF5 при каждой итерации генератора циклических данных. Если вы посмотрите на то, что эта строка кода выводит на экран, вы заметите, что индексы не синхронизированы с эпохами:

def data_generator(dataID,
                   batch_size, 
                   dim = (20, 100)):      

    print("\nIn data_generator.\n")      

    DataDir = "data/"
    BatchDir = DataDir + dataID + "/"
    h5path = BatchDir + dataID + ".h5"

    f = h5py.File(h5path, "r")
    data = f["sigccm"]
    labels = f["Attack"]

    # number of records to process
    nrecs = len(data)

    outputshape = (batch_size, *dim)

    while True:            

        for i in range(0, nrecs, batch_size):

            'Generate one batch of data'

            if i + batch_size > nrecs:

                # upperbound = nrecs 

                # If we can get a complete batch
                # out of the remaining data, go
                # ahead and wrap up this epoch and
                # start the next one.

                break

            else: 
                upperbound = i + batch_size
                print("\ndata[%d : %d]\n" % (i, upperbound)) # <<<<< KEY LINE of CODE
                X = np.array(data[i : upperbound]) 
                y = np.array(labels[i : upperbound])             

            if outputshape != X.shape:
                msg = "Wrong shape: "
                idx = "index = {:d}, ".format(index)
                shp = "X.shape = {:s}".format(str(X.shape))
                msg = msg + idx + shp
                print(msg)

            else:

                # Label (Attack) field has an extra
                # nested array dimension. Get rid of it.

                u = np.array(y)
                y = np.resize(u, (batch_size, 1))
                yield X, y  

    f.close()

Вот что я получаю при запуске Keras:

In data_generator.
data[0 : 200]

Epoch 1/10
data[200 : 400]
data[400 : 600]

 1/24 [>.............................] - ETA: 1:33 - loss: 0.2952 - acc: 0.1350
data[600 : 800]

 2/24 [=>............................] - ETA: 1:10 - loss: 0.2234 - acc: 0.4975
data[800 : 1000]

 3/24 [==>...........................] - ETA: 1:01 - loss: 0.1923 - acc: 0.6200
data[1000 : 1200]

 4/24 [====>.........................] - ETA: 55s - loss: 0.1753 - acc: 0.6813 
data[1200 : 1400]

 5/24 [=====>........................] - ETA: 49s - loss: 0.1652 - acc: 0.7170
data[1400 : 1600]

 6/24 [======>.......................] - ETA: 45s - loss: 0.1587 - acc: 0.7400
data[1600 : 1800]

 7/24 [=======>......................] - ETA: 42s - loss: 0.1536 - acc: 0.7571
data[1800 : 2000]

 8/24 [=========>....................] - ETA: 39s - loss: 0.1496 - acc: 0.7700
data[2000 : 2200]

 9/24 [==========>...................] - ETA: 36s - loss: 0.1469 - acc: 0.7794
data[2200 : 2400]

10/24 [===========>..................] - ETA: 34s - loss: 0.1447 - acc: 0.7870
data[2400 : 2600]

11/24 [============>.................] - ETA: 31s - loss: 0.1426 - acc: 0.7936
data[2600 : 2800]

12/24 [==============>...............] - ETA: 29s - loss: 0.1411 - acc: 0.7988
data[2800 : 3000]

13/24 [===============>..............] - ETA: 26s - loss: 0.1395 - acc: 0.8035
data[3000 : 3200]

14/24 [================>.............] - ETA: 24s - loss: 0.1385 - acc: 0.8071
data[3200 : 3400]

15/24 [=================>............] - ETA: 21s - loss: 0.1375 - acc: 0.8103
data[3400 : 3600]

16/24 [===================>..........] - ETA: 19s - loss: 0.1365 - acc: 0.8134
data[3600 : 3800]

17/24 [====================>.........] - ETA: 17s - loss: 0.1358 - acc: 0.8159
data[3800 : 4000]

18/24 [=====================>........] - ETA: 14s - loss: 0.1351 - acc: 0.8181
data[4000 : 4200]

19/24 [======================>.......] - ETA: 12s - loss: 0.1344 - acc: 0.8203
data[4200 : 4400]

20/24 [========================>.....] - ETA: 9s - loss: 0.1339 - acc: 0.8220 
data[4400 : 4600]

21/24 [=========================>....] - ETA: 7s - loss: 0.1334 - acc: 0.8236
data[4600 : 4800]

22/24 [==========================>...] - ETA: 4s - loss: 0.1327 - acc: 0.8255
data[4800 : 5000]

23/24 [===========================>..] - ETA: 2s - loss: 0.1325 - acc: 0.8265
data[0 : 200]


In data_generator.
data[0 : 200]
data[200 : 400]
data[200 : 400]
data[400 : 600]
data[400 : 600]
data[600 : 800]
data[600 : 800]
data[800 : 1000]
data[800 : 1000]
data[1000 : 1200]

24/24 [==============================] - 75s 3s/step - loss: 0.1318 - acc: 0.8281 - val_loss: 0.1190 - val_acc: 0.8625
Epoch 2/10
data[1200 : 1400]

 1/24 [>.............................] - ETA: 43s - loss: 0.1242 - acc: 0.8550
data[1400 : 1600]

 2/24 [=>............................] - ETA: 46s - loss: 0.1209 - acc: 0.8600
data[1600 : 1800]

 3/24 [==>...........................] - ETA: 46s - loss: 0.1208 - acc: 0.8600
data[1800 : 2000]

 4/24 [====>.........................] - ETA: 45s - loss: 0.1199 - acc: 0.8613
data[2000 : 2200]

 5/24 [=====>........................] - ETA: 43s - loss: 0.1194 - acc: 0.8620
data[2200 : 2400]

 6/24 [======>.......................] - ETA: 41s - loss: 0.1196 - acc: 0.8617
data[2400 : 2600]

 7/24 [=======>......................] - ETA: 40s - loss: 0.1202 - acc: 0.8607
data[2600 : 2800]

 8/24 [=========>....................] - ETA: 37s - loss: 0.1202 - acc: 0.8606
data[2800 : 3000]

 9/24 [==========>...................] - ETA: 35s - loss: 0.1203 - acc: 0.8606
data[3000 : 3200]

10/24 [===========>..................] - ETA: 33s - loss: 0.1206 - acc: 0.8600
data[3200 : 3400]

11/24 [============>.................] - ETA: 30s - loss: 0.1209 - acc: 0.8595
data[3400 : 3600]

12/24 [==============>...............] - ETA: 28s - loss: 0.1209 - acc: 0.8596
data[3600 : 3800]

13/24 [===============>..............] - ETA: 26s - loss: 0.1211 - acc: 0.8592
data[3800 : 4000]

14/24 [================>.............] - ETA: 23s - loss: 0.1211 - acc: 0.8593
data[4000 : 4200]

15/24 [=================>............] - ETA: 21s - loss: 0.1213 - acc: 0.8590
data[4200 : 4400]

16/24 [===================>..........] - ETA: 19s - loss: 0.1215 - acc: 0.8588
data[4400 : 4600]

17/24 [====================>.........] - ETA: 16s - loss: 0.1214 - acc: 0.8588
data[4600 : 4800]

18/24 [=====================>........] - ETA: 14s - loss: 0.1215 - acc: 0.8586
data[4800 : 5000]

19/24 [======================>.......] - ETA: 11s - loss: 0.1217 - acc: 0.8584
data[0 : 200]

20/24 [========================>.....] - ETA: 9s - loss: 0.1216 - acc: 0.8585 
data[200 : 400]

21/24 [=========================>....] - ETA: 7s - loss: 0.1217 - acc: 0.8583
data[400 : 600]

22/24 [==========================>...] - ETA: 4s - loss: 0.1218 - acc: 0.8582
data[600 : 800]

23/24 [===========================>..] - ETA: 2s - loss: 0.1216 - acc: 0.8585
data[800 : 1000]

data[0 : 200]

Spew заканчивается здесь, потому что ядро ​​умирает в jupyter, а в PyCharm я получаю MemoryError в следующей строке кода в генераторе данных:

X = np.array(data[i: upperbound])

Итак, у меня есть два вопроса:

  1. Как я могу гарантировать, что генератор данных сбрасывается с каждой эпохой и запускается заново в начале набора данных?
  2. Любые предположения о том, как избежать ошибок в памяти и отмирающих ядер jupyter, также приветствуются, но я предполагаю, что где-то достигну жесткого предела и просто должен использовать меньший набор данных.

0 ответов

Другие вопросы по тегам