batch_size не совпадает в torchtext BucketIterator
Я поставил batch_size
равно 64
, но когда я распечатаю train_batch и val_batch, размер не равен 64.
Данные поезда и данные в формате VAL представлены в следующем формате:
Во-первых, я определяю TEXT
а также LABEL
поле.
tokenize = lambda x: x.split()
TEXT = data.Field(sequential=True, tokenize=tokenize)
LABEL = data.Field(sequential=False)
А потом я продолжаю пытаться следовать учебникам, и написал вещи ниже:
train_data, valid_data = data.TabularDataset.splits(
path='.',
train='train_intent.csv', validation='val.csv',
format='csv',
fields= {'sentences': ('text', TEXT),
'labels': ('label',LABEL)}
)
test_data = data.TabularDataset(
path='test.csv',
format='csv',
fields={'sentences': ('text', TEXT)}
)
TEXT.build_vocab(train_data)
LABEL.build_vocab(train_data)
BATCH_SIZE = 64
train_iter, val_iter = data.BucketIterator.splits(
(train_data, valid_data),
batch_sizes=(BATCH_SIZE, BATCH_SIZE),
sort_key=lambda x: len(x.text),
sort_within_batch=False,
repeat=False,
device=device
)
Но когда я хочу знать, хорошо это или нет, я просто нахожу ниже странные вещи:
train_batch = next(iter(train_iter))
print(train_batch.text.shape)
print(train_batch.label.shape)
[output]
torch.Size([15, 64])
torch.Size([64])
И ошибка вывода процесса поезда ValueError: Expected input batch_size (15) to match target batch_size (64).
:
def train(model, iterator, optimizer, criterion):
epoch_loss = 0
model.train()
for batch in iterator:
optimizer.zero_grad()
predictions = model(batch.text)
loss = criterion(predictions, batch.label)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(iterator)
Любой, кто может дать мне подсказку, будет высоко оценен. Спасибо!
2 ответа
Размер возвращаемой партии не всегда равен batch_size
, Например: у вас есть 100 данных поезда, batch_size равно 64. Возвращенный batch_size должен быть [64, 36]
,
Я тоже столкнулся с этой проблемой. Я думаю, проблема в том, что размер batch_size не находится в позиции shape[0]. В вашем вопросе:
train_batch = next(iter(train_iter))
print(train_batch.text.shape)
print(train_batch.label.shape)
[output]
torch.Size([15, 64])
torch.Size([64])
15 - это max_sequence_length в пакете, которое можно исправить с помощью fix_length в определении поля, а 64 - это batch_size. Я думаю, вы можете изменить текст, чтобы решить эту проблему, но я также ищу лучший ответ.