Пакетный размер в форме ввода цепей CNN
У меня есть тренировочный набор из 9957 изображений. Тренировочный набор имеет форму (9957, 3, 60, 80). Требуется ли размер партии при установке тренировочного набора на модель? При необходимости можно ли считать исходную форму правильной для подгонки к слою conv2D или мне нужно добавить размер пакета в input_shape?
X_train.shape
(9957, 60,80,3) из chainer.datasets import split_dataset_random из chainer.dataset import DatasetMixin
import numpy as np
class MyDataset(DatasetMixin):
def __init__(self, X, labels):
super(MyDataset, self).__init__()
self.X_ = X
self.labels_ = labels
self.size_ = X.shape[0]
def __len__(self):
return self.size_
def get_example(self, i):
return np.transpose(self.X_[i, ...], (2, 0, 1)), self.labels_[i]
batch_size = 3
label_train = y_trainHot1
dataset = MyDataset(X_train1, label_train)
dataset_train, valid = split_dataset_random(dataset, 8000, seed=0)
train_iter = iterators.SerialIterator(dataset_train, batch_size)
valid_iter = iterators.SerialIterator(valid, batch_size, repeat=False,
shuffle=False)
1 ответ
Приведенный ниже код говорит вам, что вам не нужно заботиться о размере партии самостоятельно. Вы просто используете DatsetMixin
а также SerialIterator
как указано в руководстве по цепочке.
from chainer.dataset import DatasetMixin
from chainer.iterators import SerialIterator
import numpy as np
NUM_IMAGES = 9957
NUM_CHANNELS = 3 # RGB
IMAGE_WIDTH = 60
IMAGE_HEIGHT = 80
NUM_CLASSES = 10
BATCH_SIZE = 32
TRAIN_SIZE = min(8000, int(NUM_IMAGES * 0.9))
images = np.random.rand(NUM_IMAGES, NUM_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT)
labels = np.random.randint(0, NUM_CLASSES, (NUM_IMAGES,))
class MyDataset(DatasetMixin):
def __init__(self, images_, labels_):
# note: input arg.'s tailing underscore is just to avoid shadowing
super(MyDataset, self).__init__()
self.images_ = images_
self.labels_ = labels_
self.size_ = len(labels_)
def __len__(self):
return self.size_
def get_example(self, i):
return self.images_[i, ...], self.labels_[i]
dataset_train = MyDataset(images[:TRAIN_SIZE, ...], labels[:TRAIN_SIZE])
dataset_valid = MyDataset(images[TRAIN_SIZE:, ...], labels[TRAIN_SIZE:])
train_iter = SerialIterator(dataset_train, BATCH_SIZE)
valid_iter = SerialIterator(dataset_valid, BATCH_SIZE, repeat=False, shuffle=False)
###############################################################################
"""This block is just for the confirmation.
.. note: NOT recommended to call :func:`concat_examples` in your code.
Use :class:`chainer.updaters.StandardUpdater` instead.
"""
from chainer.dataset import concat_examples
batch_image, batch_label = concat_examples(next(train_iter))
print("batch_image.shape\n{}".format(batch_image.shape))
print("batch_label.shape\n{}".format(batch_label.shape))
Выход
batch_image.shape
(32, 3, 60, 80)
batch_label.shape
(32,)
Следует отметить, что chainer.dataset.concat_example
это немного сложная часть. Обычно пользователи не обращают внимания на эту функцию, если вы используете StandardUpdater
которая скрывает нативную функцию chainer.dataset.concat_example
,
Так как цепейник спроектирован по схеме Trainer
, (Standard)Updater
, немного Optimizer
, (Serial)Iterator
а также Dataset(Mixin)
, если вы не будете следовать этой схеме, вы должны погрузиться в море chainer
исходный код.