Как изменить уровень временной классификации (CTC) в сети для подключения к сети, чтобы он также давал нам оценку достоверности?
Я пытаюсь распознать слова по обрезанным изображениям самих слов, обучая модель CRNN(CNN+LSTM+CTC). Я запутался, как добавить показатель доверия наряду с узнаваемыми словами. Я использую тензор потока и слежу за реализацией https://github.com/TJCVRS/CRNN_Tensorflow. Может кто-нибудь подсказать мне, как изменить уровень подключения к сети (CTC) в сети, чтобы дать нам показатель достоверности?
2 ответа
Одно обновление от меня:
я наконец достиг результата, передав спрогнозированную метку обратно в функцию потерь ctc и взяв анти-лог негатив отрицательной потери. Я нахожу это значение очень точным, чем использование анти-журнала log_prob.
Есть два решения, о которых я могу думать прямо сейчас:
- оба декодера TensorFlow предоставляют информацию о значении распознанного текста. ctc_greedy_decoder возвращает neg_sum_logits, который содержит оценку для каждого элемента пакета. То же самое верно для ctc_beam_search_decoder, который возвращает log_probabilities, который содержит оценки для каждого луча каждого элемента пакета.
- возьмите распознанный текст из любого из двух декодеров. Поместите в свой код другую функцию потери CTC и введите выходную матрицу RNN и распознанный текст в функцию потери. Тогда результатом будет вероятность (хорошо, вы должны отменить минус и журнал, но это должно быть легко) увидеть данный текст в матрице.
Решение (1) быстрее и проще в реализации, однако решение (2) является более точным. Но разница не должна быть слишком большой, если CRNN хорошо обучен и ширина луча декодера поиска луча достаточно велика.
Посмотрите на код TF-CRNN в следующей строке - оценка уже возвращается как переменная log_prob: https://github.com/MaybeShewill-CV/CRNN_Tensorflow/blob/master/tools/train_shadownet.py
И вот пример автономного кода, который иллюстрирует решение (2): https://gist.github.com/githubharald/8b6f3d489fc014b0faccbae8542060dc