Аппроксимация значения SARSA для полюса тележки
У меня есть вопрос по этому вопросу.
Во входной ячейке 142 я вижу это модифицированное обновление
w += alpha * (reward - discount * q_hat_next) * q_hat_grad
где q_hat_next
является Q(S', a')
а также q_hat_grad
является производной от Q(S, a)
(предполагать S, a, R, S' a'
последовательность).
Мой вопрос: не должно ли обновление быть таким?
w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
Какая интуиция стоит за модифицированным обновлением?
1 ответ
Я думаю, что вы правы. Я также ожидал бы, что обновление содержит термин ошибки TD, который должен быть reward + discount * q_hat_next - q_hat
,
Для справки, это реализация:
if done: # (terminal state reached)
w += alpha*(reward - q_hat) * q_hat_grad
break
else:
next_action = policy(env, w, next_state, epsilon)
q_hat_next = approx(w, next_state, next_action)
w += alpha*(reward - discount*q_hat_next)*q_hat_grad
state = next_state
И это псевдокод из " Усиленного обучения: введение" (Саттон и Барто) (стр. 171):
Поскольку реализация является TD(0), n
равен 1. Тогда обновление в псевдокоде можно упростить:
w <- w + a[G - v(S_t,w)] * dv(S_t,w)
становится (путем замены G == reward + discount*v(S_t+1,w))
)
w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)
Или с именами переменных в исходном примере кода:
w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
Я получил ту же формулу обновления, что и у вас. Похоже, ошибка в обновлении состояния нетерминала.
Только терминальный корпус (если done
верно) должно быть правильно, потому что тогда q_hat_next
всегда равен 0 по определению, так как эпизод окончен, и больше никакого вознаграждения получить нельзя.