Способы улучшения обучения универсальным дифференциальным уравнениям с помощью sciml_train

Около месяца назад я задал вопрос о стратегиях лучшей сходимости при обучении нейронно-дифференциального уравнения. С тех пор я заставил этот пример работать, используя полученный мне совет, но когда я применил тот же совет к более сложной модели, я снова застрял. Весь мой код находится на Julia, в основном с использованием библиотеки DiffEqFlux. Стремясь сделать этот пост как можно более кратким, я не буду делиться всем своим кодом для всего, что я пробовал, но если кто-то хочет получить к нему доступ для устранения неполадок, я могу предоставить его.

Что я пытаюсь сделать

Данные, которые я пытаюсь изучить, взяты из модели SIRx:

      function SIRx!(du, u, p, t)
    β, μ, γ, a, b = Float32.([280, 1/50, 365/22, 100, 0.05])
    S, I, x = u
    du[1] = μ*(1-x) - β*S*I - μ*S
    du[2] = β*S*I - (μ+γ)*I
    du[3] = a*I - b*x
    nothing
end;

Исходное условие, которое я использовал, было u0 = Float32.([0.062047128, 1.3126149f-7, 0.9486445]);. Я генерировал данные от t = 0 до 25, отбирал каждые 0,02 (при обучении я использую только каждые 20 точек или около того для скорости, и использование большего количества не улучшает результаты). Данные выглядят так: Данные обучения

UDE, который я тренирую, это

      function SIRx_ude!(du, u, p, t)
    μ, γ = Float32.([1/50, 365/22])
    S,I,x = u
    du[1] = μ*(1-x) - μ*S + ann_dS(u, @view p[1:lenS])[1]
    du[2] = -(μ+γ)*I + ann_dI(u, @view p[lenS+1:lenS+lenI])[1]
    du[3] = ann_dx(u, @view p[lenI+1:end])[1]
    nothing
end;

Каждая из нейронных сетей ( ann_dS, ann_dI, ann_dx) определены с помощью FastChain(FastDense(3, 20, tanh), FastDense(20, 1)). Я пробовал использовать одну нейронную сеть с 3 входами и 3 выходами, но это было медленнее и не лучше. Я также сначала попробовал нормализовать входные данные в сеть, но это не имеет существенного значения, кроме замедления работы.

Что я пробовал

  • Одиночная съемка Сеть просто помещает линию в середину данных. Это происходит даже тогда, когда я больше взвешиваю более ранние точки данных в функции потерь. Одноразовое обучение
  • Многократная съемка . Лучший результат, который у меня был, был при многократной съемке. Как видно здесь, это не просто прямая линия, но она не совсем соответствует данным результата множественной съемки . Я пробовал условия непрерывности от 0,1 до 100 и размер групп от 3 до 30, и это не имеет существенного значения.
  • Различные другие стратегии Я также пробовал итеративно увеличивать соответствие, двухэтапное обучение с коллокацией и мини-пакетирование, как описано здесь: https://diffeqflux.sciml.ai/dev/examples/local_minima, https: // diffeqflux.sciml.ai / dev / examples / collocation/, https://diffeqflux.sciml.ai/dev/examples/minibatch/. Итеративное наращивание аппроксимации хорошо работает в первой паре итераций, но по мере увеличения длины она снова возвращается к подгонке прямой линии. Двухэтапное обучение совмещению очень хорошо работает для этапа 1, но на самом деле оно не улучшает производительность на втором этапе (я пробовал как одиночную, так и множественную съемку для второго этапа). Наконец, мини-дозирование работало примерно так же, как и одиночная съемка (то есть не очень хорошо), но гораздо быстрее.

Мой вопрос

Таким образом, я понятия не имею, что попробовать. Существует так много стратегий, каждая из которых имеет так много параметров, которые можно настроить. Мне нужен способ более точной диагностики проблемы, чтобы я мог лучше решить, что делать дальше. Если у кого-то есть опыт решения такого рода проблем, я буду признателен за любой совет или руководство, которое я смогу получить.

1 ответ

Это не лучший вопрос SO, потому что он более исследовательский. Вы снизили допуски ODE? Это улучшит ваш расчет градиента, что может помочь. Какую функцию активации вы используете? Я бы использовал что-то вроде softplus вместо tanhтак что у вас не будет насыщающего поведения. Вы масштабировали собственные значения и учли проблемы, исследованные в статье ODE для жестких нейронных сетей ? Большие нейронные сети? Разные темпы обучения? АДАМ? И т.п.

Это гораздо лучше подходит для форума для обсуждения, такого как JuliaLang Discourse . Мы можем продолжить там, так как прогулка по нему не будет плодотворной без некоторых шагов вперед и назад.

Другие вопросы по тегам