Обучение Detectron2 на базе данных COCO
Я пытаюсь обучить модель с помощью набора данных Detectron2 и COCO для обнаружения транспортных средств и людей, и у меня возникают проблемы с загрузкой модели.
Я использовал здесь сообщения на SO и https://github.com/immersive-limit/coco-manager (файл filter.py) код, чтобы отфильтровать набор данных COCO, чтобы включить только аннотации и изображения из классов "человек", "автомобиль"., "велосипед", "грузовик" и "велосипед". Теперь моя структура каталогов:
main
- annotations:
- instances_train2017_filtered.json
- instances_val2017_filtered.json
- images:
- train2017_filtered (lots of images inside)
- val2017_filtered (lots of images inside)
По сути, единственное, что я здесь сделал, - это удалил документы и изображения, не соответствующие этим классам, и изменил их идентификаторы (теперь они от 1 до 5).
Затем я использовал код из учебника Detectron2:
import random
import cv2
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
import os
from detectron2.model_zoo import model_zoo
from detectron2.utils.visualizer import Visualizer
register_coco_instances("train",
{},
"/home/jakub/Projects/coco/annotations/instances_train2017_filtered.json",
"/home/jakub/Projects/coco/images/train2017_filtered/")
register_coco_instances("val",
{},
"/home/jakub/Projects/coco/annotations/instances_val2017_filtered.json",
"/home/jakub/Projects/coco/images/val2017_filtered/")
metadata = MetadataCatalog.get("train")
dataset_dicts = DatasetCatalog.get("train")
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 300
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 5
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.DATASETS.TEST = ("val", )
predictor = DefaultPredictor(cfg)
img = cv2.imread("demo/input.jpg")
outputs = predictor(img)
for d in random.sample(dataset_dicts, 1):
im = cv2.imread(d["file_name"])
outputs = predictor(im)
v = Visualizer(im[:, :, ::-1],
metadata=metadata,
scale=0.8)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2.imwrite('demo/output_retrained.jpg', out.get_image()[:, :, ::-1])
Во время обучения я получаю следующие ошибки:
Unable to load 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (6, 1024) in the model!
Unable to load 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (6,) in the model!
Unable to load 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (20, 1024) in the model!
Unable to load 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (20,) in the model!
Unable to load 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (5, 256, 1, 1) in the model!
Unable to load 'roi_heads.mask_head.predictor.bias' to the model due to incompatible shapes: (80,) in the checkpoint but (5,) in the model!
Модель не может предсказать ничего полезного после обучения, несмотря на уменьшение total_loss во время обучения. Я понимаю, что должен получать предупреждения из-за несоответствия размера (я уменьшил количество классов), что нормально из того, что я видел в Интернете, но я не получаю "Пропущено" после каждой строки ошибки. Я думаю, что эта модель на самом деле здесь ничего не загружает, и мне интересно, почему и как я могу это исправить.
РЕДАКТИРОВАТЬ
Для сравнения, аналогичное поведение в почти идентичной ситуации было зарегистрировано как Проблема, но в конце каждой строки ошибки было "Пропущено", что делало их фактически предупреждениями, а не ошибками:https://github.com/facebookresearch/detectron2/issues/196
1 ответ
Это «предупреждение» в основном говорит о том, что вы пытаетесь инициализировать веса из модели, которая была обучена на другом количестве классов. Ожидается, что вы прочитали.
Я подозреваю, что вы не получаете никаких результатов от своего обучения, потому что в вашем MetadataCatalog не задано свойство thing_classes. Ты только звонишь
MetadataCatalog.get("train")
Звонок
MetadataCatalog.get("train").set(thing_classes=["person", "car", "bike", "truck", "bicycle"])
Должно решить проблему, но если это не так, я уверен, что ваш json поврежден.