Как получить строки меток в тонко настроенной сети, используя файлы контрольных точек или tf-запись?

Например, я настроил сеть VGG, используя свой собственный набор данных, только с двумя метками foo а также bar, Я конвертировал изображения в tf.record, используя пример по этой ссылке:

labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

Я собираюсь создать API для прогнозирования изображений на основе этой новой модели, мой вопрос: есть ли какой-нибудь формальный способ получить строку метки из файлов контрольных точек или из набора данных (например, predict_image("abc.png") возвращается foo строка)? Поскольку я понятия не имею, какой узел в слое logits представляет метку fooи какой из них представляет bar

Я пытался искать, но без помощи, и я все еще Noobie тензорного потока.

1 ответ

Модель (и, кстати, файлы контрольных точек) не имеют названия каждого класса. Все, что у него есть - это определенное количество выходных нейронов, первый из которых соответствует первому классу, второй - второму классу и так далее.

Если вы хотите узнать, какой из них какой, посмотрите на файл меток, созданный этой строкой (скорее всего, с названием label.txt):

dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

В качестве альтернативы вы можете проверить содержимое labels_to_class_names ДИКТ:

In [1]: class_names=['aaa', 'bbb', 'ccc']

In [2]: labels_to_class_names = dict(zip(range(len(class_names)), class_names))

In [3]: labels_to_class_names
Out[3]: {0: 'aaa', 1: 'bbb', 2: 'ccc'}

-> значение по индексу 0 в выходных данных модели = класс 'aaa' и т. д.

Другие вопросы по тегам