Реализация пакета Stata "nl" (нелинейные наименьшие квадраты) в Python

Кто-нибудь знает, есть ли реализация Stata nl (нелинейный метод наименьших квадратов) в Python? Я пытался использовать lmfit так же как optimize.leastsq от scipy, но оба, похоже, не работают.

Уравнение для регрессии

Y = x1 + b1 + 0.3*log(x2-b2)*b3 - 0.7*x3*b3 + b5*x2

где Y является зависимой переменной, x's независимые переменные, а b's являются коэффициентами для оценки.

С использованием lmfit пакет, я попробовал следующее:


from lmfit import minimize, Parameters, Parameter, report_fit
import pandas as pd
import numpy as np

inputfile = "testdata.csv"
df = pd.read_csv(inputfile)

x1= df['x1']
x2 = df['x2']
x3= df['x3']
y= df['y']


def fcn2min(params, x1, x2, x3, y):

    b1 = params['b1'].value
    b2 = params['b2'].value
    b3 = params['b3'].value
    b5 = params['b5'].value
    model = x1 + b1 + (0.3)*np.log(x2-b2)*b3 - (0.7)*x3*b3 + b5*x2
    return model - y

params = Parameters()

params.add('b1', value= 10)

params.add('b2', value= 1990)

params.add('b3', value= 5)

params.add('b5', value= 12)


result = minimize(fcn2min, params, args=(x1, x2, x3, y))

print report_fit(result) 

В результате все параметры оцениваются как NaN. Кто-нибудь может объяснить, что я сделал не так? Или есть хорошая реализация функции Stata nl в Python? Большое спасибо!

Вот данные в файле CSV:


x1, x2, x3, у
1981,15.2824955,14.56475067,2.936807632
1982,15.2635746,15.52343941,2.908272743
1983,15.30461597,16.30871582,2.940227509
1984,15.37490845,16.76519966,3.001846313
1985,15.41295338,17.04235458,3.030970573
1986,15.44680405,17.25271797,3.055702209
1987,15.48135281,17.44781876,3.081344604
1988,15.52259159,17.62217331,3.113491058
1989,15.5565939,17.71343422,3.138068199
1990,15.57392025,17.81187439,3.144176483
1991,15.57197666,17.89474106,3.128887177
1992,15.60479259,17.98217583,3.14837265
1993,15.63134575,18.06685829,3.161927223
1994,15.67116165,18.16578865,3.18959713
1995,15.69621944,18.27449799,3.202876091
1996,15.7329874,18.38712311,3.228042603
1997,15.77698135,18.50685883,3.260077477
1998,15.81788635,18.63579178,3.289312363
1999,15.86141682,18.76427078,3.321393967
2000,15.89737129,18.89691544,3.34650898
2001,15.90485096,18.99729347,3.344522476
2002,15.92070866,19.06253433,3.351119995


1 ответ

Просто чтобы исправить ситуацию, причина неудачи в том, что вы не проверяете случай, x2-b2 может быть отрицательным, так что np.log(x2-b2) является NaN, Конечно, если целевая функция возвращает NaNПриступ остановится и не сможет найти хорошего решения. Вы можете попробовать добавить верхнюю границу b2, Как и другие, я подозреваю, что если вы угадаете b1 быть 10 и b2 быть 1990, что у вас есть какая-то простая ошибка в вашей целевой функции, которая вызывает NaN происходить. Часто бывает полезно один раз вызвать целевую функцию и, возможно, даже построить начальное условие.

Или вы можете обвинить инструмент.

Другие вопросы по тегам