Как использовать обученные контрольные точки модели BERT для прогнозирования?
Я обучил BERT с помощью SQUAD 2.0 и получил model.ckpt.data, model.ckpt.meta. model.ckpt.index (F1 балл: 81) в выходном каталоге вместе с Foretions.json и т. д. с помощью BERT-master/run_squad.py
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
--do_train=True \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
Я попытался скопировать модель.ckpt.meta, model.ckpt.index, model.ckpt.data в каталог $BERT_LARGE_DIR и изменил флаги run_squad.py следующим образом, чтобы только предсказать ответ, а не обучать с использованием набора данных:
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
--do_train=False \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
Выдает ведро каталог /model.ckpt не существует ошибки.
Как использовать контрольные точки, созданные после обучения, и использовать их для прогнозирования?
2 ответа
Обычно обученные контрольные точки создаются в каталоге, указанном --output_dir
параметр во время тренировки. (В вашем случае это gs://some_bucket/squad_large/). Каждый контрольно-пропускной пункт будет иметь номер. Вы должны определить наибольшее число; пример: model.ckpt-12345
, Теперь установите --init_checkpoint
параметр в вашей оценке / прогнозе, используя выходной каталог и последнюю сохраненную контрольную точку (модель с наибольшим числом). (В вашем случае это должно быть что-то вроде --init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>
)
Во втором коде ФЛАГ init_checkpoint
Я думаю, что это должно быть:
--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt
как в приведенном выше, а не --init_checkpoint=$BERT_LARGE_DIR/model.ckpt
,
Если проблема не устранена, используете ли вы multi_cased_L-12_H-768_A-12
предварительно обученные модели?