Как керасы определяют "точность" и "потеря"?
Я не могу найти, как Керас определяет "точность" и "потеря". Я знаю, что могу указать разные метрики (например, mse, кросс-энтропия), но keras выводит стандартную "точность". Как это определяется? Аналогично для потери: я знаю, что могу указать различные типы регуляризации - есть ли потери?
В идеале я хотел бы распечатать уравнение, используемое для его определения; если нет, я соглашусь на ответ здесь.
2 ответа
Посмотри на metrics.py
, там вы можете найти определение всех доступных метрик, включая различные типы точности. Точность не печатается, если вы не добавите ее в список желаемых метрик при компиляции вашей модели.
Регуляризаторы по определению добавляются к убытку. Например, см. add_loss
метод Layer
учебный класс.
Обновить
Тип accuracy
определяется на основе целевой функции, см. training.py
, Выбор по умолчанию categorical_accuracy
, Другие типы, такие как binary_accuracy
а также sparse_categorical_accuracy
выбираются, когда целевая функция является двоичной или разреженной.
После ответа Сергея библиотека Keras была немного подчищена, и исходный код в настоящее время довольно читабелен. Метрики определены вtensorflow.keras.metrics
(документацию которого можно найти здесь ), а потери определены вtensorflow.keras.losses
( документы ). Есть некоторое совпадение с модулем метрик, но это ожидаемо, поскольку конкретную функцию потерь также можно отслеживать как метрику.
Кроме того, если мы проверяем исходный код , если метрикой не является точность, метод вызывается вmetrics
модуль для получения конкретной метрической функции, т.е.tf.keras.metrics.get('binary_accuracy')
. С другой стороны,get()
метод всегда вызывается для получения конкретной функции потерь.
Также тип точности выбирается в зависимости от типа цели (binary_accuracy
,categorical_accuracy
и т. д.).
Все показатели/потери можно распечатать, вызвавdir()
на модулях.
metrics_list = [m for m in dir(tf.keras.metrics) if not m.startswith('_')]
losses_list = [m for m in dir(tf.keras.losses) if not m.startswith('_')]