Как керасы определяют "точность" и "потеря"?

Я не могу найти, как Керас определяет "точность" и "потеря". Я знаю, что могу указать разные метрики (например, mse, кросс-энтропия), но keras выводит стандартную "точность". Как это определяется? Аналогично для потери: я знаю, что могу указать различные типы регуляризации - есть ли потери?

В идеале я хотел бы распечатать уравнение, используемое для его определения; если нет, я соглашусь на ответ здесь.

2 ответа

Решение

Посмотри на metrics.py, там вы можете найти определение всех доступных метрик, включая различные типы точности. Точность не печатается, если вы не добавите ее в список желаемых метрик при компиляции вашей модели.

Регуляризаторы по определению добавляются к убытку. Например, см. add_loss метод Layer учебный класс.

Обновить

Тип accuracy определяется на основе целевой функции, см. training.py, Выбор по умолчанию categorical_accuracy, Другие типы, такие как binary_accuracy а также sparse_categorical_accuracy выбираются, когда целевая функция является двоичной или разреженной.

После ответа Сергея библиотека Keras была немного подчищена, и исходный код в настоящее время довольно читабелен. Метрики определены вtensorflow.keras.metrics(документацию которого можно найти здесь ), а потери определены вtensorflow.keras.losses( документы ). Есть некоторое совпадение с модулем метрик, но это ожидаемо, поскольку конкретную функцию потерь также можно отслеживать как метрику.

Кроме того, если мы проверяем исходный код , если метрикой не является точность, метод вызывается вmetricsмодуль для получения конкретной метрической функции, т.е.tf.keras.metrics.get('binary_accuracy'). С другой стороны,get()метод всегда вызывается для получения конкретной функции потерь.

Также тип точности выбирается в зависимости от типа цели (binary_accuracy,categorical_accuracyи т. д.).

Все показатели/потери можно распечатать, вызвавdir()на модулях.

      metrics_list = [m for m in dir(tf.keras.metrics) if not m.startswith('_')]

losses_list = [m for m in dir(tf.keras.losses) if not m.startswith('_')]
Другие вопросы по тегам