Влияние --oaa 2 и --loss_function=logistic в Vowpal Wabbit
Какие параметры я должен использовать в VW для задачи двоичной классификации? Например, давайте использовать rcv1_small.dat. Я думал, что лучше использовать функцию логистической потери (или шарнир), и нет смысла использовать --oaa 2
, Тем не менее, эмпирические результаты (с прогрессивной валидацией 0/1 потерь, зарегистрированных во всех 4 экспериментах) показывают, что лучшая комбинация --oaa 2
без логистических потерь (т.е. с квадратическими потерями по умолчанию):
cd vowpal_wabbit/test/train-sets
cat rcv1_small.dat | vw --binary
# average loss = 0.0861
cat rcv1_small.dat | vw --binary --loss_function=logistic
# average loss = 0.0909
cat rcv1_small.dat | sed 's/^-1/2/' | vw --oaa 2
# average loss = 0.0857
cat rcv1_small.dat | sed 's/^-1/2/' | vw --oaa 2 --loss_function=logistic
# average loss = 0.0934
Мой основной вопрос: почему --oaa 2
не дает точно такие же результаты, как --binary
(в вышеуказанной настройке) ?
Мои второстепенные вопросы: почему оптимизация логистических потерь не улучшает потери 0/1 (по сравнению с оптимизацией квадратичных потерь по умолчанию)? Это специфика этого конкретного набора данных?
1 ответ
Я испытал нечто подобное при использовании --csoaa
, Подробности можно найти здесь. Я предполагаю, что в случае проблемы мультикласса с N классами (независимо от того, что вы указали 2 в качестве количества классов), vw практически работает с N копиями функций. Тот же пример получает другое значение ft_offset, когда оно предсказано / изучено для каждого возможного класса, и это смещение используется в алгоритме хеширования. Таким образом, все классы получают "независимый" набор функций из одной и той же строки набора данных. Конечно, значения объектов одинаковы, но vw не сохраняет значения - только веса объектов. И веса разные для каждого возможного класса. А так как объем оперативной памяти, используемой для хранения этих весов, фиксируется с -b
(-b 18
по умолчанию) - чем больше у вас классов, тем больше шансов получить хеш-коллизию. Вы можете попытаться увеличить -b
значение и проверьте разницу между --oaa 2
а также --binary
результаты уменьшаются. Но я могу ошибаться, поскольку я не слишком углубился в код vw.
Что касается функции потерь - вы не можете напрямую сравнивать средние значения потерь в квадрате (по умолчанию) и логистических функциях потерь. Вы должны получить необработанные значения прогноза из результата, полученного с квадратом потерь, и получить потери этих прогнозов с точки зрения логистических потерь. Функция будет: log(1 + exp(-label * prediction)
где метка является априори известным ответом. Такие функции (float getLoss(float prediction, float label)
) для всех функций потерь, реализованных в vw, можно найти в файле loss_functions.cc. Или вы можете предварительно масштабировать исходное значение прогноза до [0..1] с помощью 1.f / (1.f + exp(- prediction)
а затем рассчитайте потери в журнале, как описано на сайте kaggle.com:
double val = 1.f / (1.f + exp(- prediction); // y = f(x) -> [0, 1]
if (val < 1e-15) val = 1e-15;
if (val > (1.0 - 1e-15)) val = 1.0 - 1e-15;
float xx = (label < 0)?0:1; // label {-1,1} -> {0,1}
double loss = xx*log(val) + (1.0 - xx) * log(1.0 - val);
loss *= -1;
Вы также можете масштабировать необработанные прогнозы до [0..1] с помощью скрипта /vowpal_wabbit/utl/logistic или --link=logistic
параметр. Оба используют 1/(1+exp(-i))
,