Как использовать обученные контрольные точки модели BERT для прогнозирования?

Я обучил BERT с помощью SQUAD 2.0 и получил model.ckpt.data, model.ckpt.meta. model.ckpt.index (F1 балл: 81) в выходном каталоге вместе с Foretions.json и т. д. с помощью BERT-master/run_squad.py

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
  --do_train=True \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

Я попытался скопировать модель.ckpt.meta, model.ckpt.index, model.ckpt.data в каталог $BERT_LARGE_DIR и изменил флаги run_squad.py следующим образом, чтобы только предсказать ответ, а не обучать с использованием набора данных:

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
  --do_train=False \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

Выдает ведро каталог /model.ckpt не существует ошибки.

Как использовать контрольные точки, созданные после обучения, и использовать их для прогнозирования?

2 ответа

Решение

Обычно обученные контрольные точки создаются в каталоге, указанном --output_dir параметр во время тренировки. (В вашем случае это gs://some_bucket/squad_large/). Каждый контрольно-пропускной пункт будет иметь номер. Вы должны определить наибольшее число; пример: model.ckpt-12345, Теперь установите --init_checkpoint параметр в вашей оценке / прогнозе, используя выходной каталог и последнюю сохраненную контрольную точку (модель с наибольшим числом). (В вашем случае это должно быть что-то вроде --init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>)

Во втором коде ФЛАГ init_checkpoint Я думаю, что это должно быть:

--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt

как в приведенном выше, а не --init_checkpoint=$BERT_LARGE_DIR/model.ckpt,

Если проблема не устранена, используете ли вы multi_cased_L-12_H-768_A-12 предварительно обученные модели?

Другие вопросы по тегам