Реализация пакета Stata "nl" (нелинейные наименьшие квадраты) в Python
Кто-нибудь знает, есть ли реализация Stata nl
(нелинейный метод наименьших квадратов) в Python
? Я пытался использовать lmfit
так же как optimize.leastsq
от scipy
, но оба, похоже, не работают.
Уравнение для регрессии
Y = x1 + b1 + 0.3*log(x2-b2)*b3 - 0.7*x3*b3 + b5*x2
где Y
является зависимой переменной, x's
независимые переменные, а b's
являются коэффициентами для оценки.
С использованием lmfit
пакет, я попробовал следующее:
from lmfit import minimize, Parameters, Parameter, report_fit
import pandas as pd
import numpy as np
inputfile = "testdata.csv"
df = pd.read_csv(inputfile)
x1= df['x1']
x2 = df['x2']
x3= df['x3']
y= df['y']
def fcn2min(params, x1, x2, x3, y):
b1 = params['b1'].value
b2 = params['b2'].value
b3 = params['b3'].value
b5 = params['b5'].value
model = x1 + b1 + (0.3)*np.log(x2-b2)*b3 - (0.7)*x3*b3 + b5*x2
return model - y
params = Parameters()
params.add('b1', value= 10)
params.add('b2', value= 1990)
params.add('b3', value= 5)
params.add('b5', value= 12)
result = minimize(fcn2min, params, args=(x1, x2, x3, y))
print report_fit(result)
В результате все параметры оцениваются как NaN. Кто-нибудь может объяснить, что я сделал не так? Или есть хорошая реализация функции Stata nl в Python? Большое спасибо!
Вот данные в файле CSV:
x1, x2, x3, у
1981,15.2824955,14.56475067,2.936807632
1982,15.2635746,15.52343941,2.908272743
1983,15.30461597,16.30871582,2.940227509
1984,15.37490845,16.76519966,3.001846313
1985,15.41295338,17.04235458,3.030970573
1986,15.44680405,17.25271797,3.055702209
1987,15.48135281,17.44781876,3.081344604
1988,15.52259159,17.62217331,3.113491058
1989,15.5565939,17.71343422,3.138068199
1990,15.57392025,17.81187439,3.144176483
1991,15.57197666,17.89474106,3.128887177
1992,15.60479259,17.98217583,3.14837265
1993,15.63134575,18.06685829,3.161927223
1994,15.67116165,18.16578865,3.18959713
1995,15.69621944,18.27449799,3.202876091
1996,15.7329874,18.38712311,3.228042603
1997,15.77698135,18.50685883,3.260077477
1998,15.81788635,18.63579178,3.289312363
1999,15.86141682,18.76427078,3.321393967
2000,15.89737129,18.89691544,3.34650898
2001,15.90485096,18.99729347,3.344522476
2002,15.92070866,19.06253433,3.351119995
1 ответ
Просто чтобы исправить ситуацию, причина неудачи в том, что вы не проверяете случай, x2-b2
может быть отрицательным, так что np.log(x2-b2)
является NaN
, Конечно, если целевая функция возвращает NaN
Приступ остановится и не сможет найти хорошего решения. Вы можете попробовать добавить верхнюю границу b2
, Как и другие, я подозреваю, что если вы угадаете b1
быть 10 и b2
быть 1990, что у вас есть какая-то простая ошибка в вашей целевой функции, которая вызывает NaN
происходить. Часто бывает полезно один раз вызвать целевую функцию и, возможно, даже построить начальное условие.
Или вы можете обвинить инструмент.