Как получить строки меток в тонко настроенной сети, используя файлы контрольных точек или tf-запись?
Например, я настроил сеть VGG, используя свой собственный набор данных, только с двумя метками foo
а также bar
, Я конвертировал изображения в tf.record, используя пример по этой ссылке:
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
Я собираюсь создать API для прогнозирования изображений на основе этой новой модели, мой вопрос: есть ли какой-нибудь формальный способ получить строку метки из файлов контрольных точек или из набора данных (например, predict_image("abc.png")
возвращается foo
строка)? Поскольку я понятия не имею, какой узел в слое logits представляет метку foo
и какой из них представляет bar
Я пытался искать, но без помощи, и я все еще Noobie тензорного потока.
1 ответ
Модель (и, кстати, файлы контрольных точек) не имеют названия каждого класса. Все, что у него есть - это определенное количество выходных нейронов, первый из которых соответствует первому классу, второй - второму классу и так далее.
Если вы хотите узнать, какой из них какой, посмотрите на файл меток, созданный этой строкой (скорее всего, с названием label.txt):
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
В качестве альтернативы вы можете проверить содержимое labels_to_class_names
ДИКТ:
In [1]: class_names=['aaa', 'bbb', 'ccc']
In [2]: labels_to_class_names = dict(zip(range(len(class_names)), class_names))
In [3]: labels_to_class_names
Out[3]: {0: 'aaa', 1: 'bbb', 2: 'ccc'}
-> значение по индексу 0 в выходных данных модели = класс 'aaa' и т. д.