Как использовать пакеты AutoGrad?

Я пытаюсь сделать простую вещь: использовать автоград для получения градиентов и сделать градиентный спуск:

import tangent

def model(x):
    return a*x + b

def loss(x,y):
    return (y-model(x))**2.0

После получения потерь для пары ввода-вывода я хочу получить градиенты относительно потерь:

    l = loss(1,2)
    # grad_a = gradient of loss wrt a?
    a = a - grad_a
    b = b - grad_b

Но учебные руководства по библиотеке не показывают, как получить градиент по отношению к a или b, то есть к параметрам, таким образом, ни автоград, ни касательные.

1 ответ

Вы можете указать это со вторым аргументом функции grad:

def f(x,y):
    return x*x + x*y

f_x = grad(f,0) # derivative with respect to first argument
f_y = grad(f,1) # derivative with respect to second argument

print("f(2,3)   = ", f(2.0,3.0))
print("f_x(2,3) = ", f_x(2.0,3.0)) 
print("f_y(2,3) = ", f_y(2.0,3.0))

В вашем случае "a" и "b" должны быть входными данными для функции потерь, которая передает их в модель для вычисления производных.

Был похожий вопрос, на который я только что ответил: частичная производная с использованием Autograd

Здесь это может помочь:

import autograd.numpy as np
from autograd import grad
def tanh(x):
  y=np.exp(-x)
  return (1.0-y)/(1.0+y)

grad_tanh = grad(tanh)

print(grad_tanh(1.0))

e=0.00001
g=(tanh(1+e)-tanh(1))/e
print(g)

Выход:

0.39322386648296376
0.39322295790622513

Вот что вы можете создать:

import autograd.numpy as np
from autograd import grad  # grad(f) returns f'

def f(x): # tanh
  y = np.exp(-x)
  return  (1.0 - y) / ( 1.0 + y)

D_f   = grad(f) # Obtain gradient function
D2_f = grad(D_f)# 2nd derivative
D3_f = grad(D2_f)# 3rd derivative
D4_f = grad(D3_f)# etc.
D5_f = grad(D4_f)
D6_f = grad(D5_f)

import  matplotlib.pyplot  as plt
plt.subplots(figsize = (9,6), dpi=153 )
x = np.linspace(-7, 7, 100)
plt.plot(x, list(map(f, x)),
         x, list(map(D_f , x)),
         x, list(map(D2_f , x)),
         x, list(map(D3_f , x)),
         x, list(map(D4_f , x)),
         x, list(map(D5_f , x)),
         x, list(map(D6_f , x)))
plt.show()

Выход:

Другие вопросы по тегам