Итерация и накопление по массиву с помощью пользовательской функции
В течение более 7 лет были некоторые связанные вопросы, но я снова поднимаю эту проблему, так как не вижу ни одного "итерационного" способа итерации.
Задача состоит в следующем: если у меня есть массив nrum arr и есть пользовательская функция fn, как я могу итеративно применить fn к arr? 'fn' не может быть создан инструментами ufunc.
Ниже приводится код toy_code, который я придумываю:
import numpy as np
r_list = np.arange(1,6,dtype=np.float32)
# r_list = [1. 2. 3. 4. 5.]
r_list_extended = np.append([0.],r_list)
R_list_extended = np.zeros_like(r_list_extended)
print(r_list)
gamma = 0.99
pv_mc = lambda a, x: x+ a*gamma
# no cumsum, accumulate available
for i in range(len(r_list_extended)):
if i ==0: continue
else: R_list_extended[i] = pv_mc(R_list_extended[i-1],r_list_extended[i])
R_list = R_list_extended[1:]
print(R_list)
# R_list == [ 1. 2.99 5.9601 9.900499 14.80149401]
r_list является массивом r для каждого времени. R_list - совокупная сумма дисконтированного r. Предположим, что r_list и R_list предварительно возвращены. Цикл выше делает R[t]: r[t] + гамма * R[t-1]
Я не думаю, что это лучший способ использовать numpy.... Если кто-то может использовать tenorflow, то tf.scan() выполняет работу, как показано ниже:
import numpy as np
import tensorflow as tf
r_list = np.arange(1,6,dtype=np.float32)
# r_list = [1. 2. 3. 4. 5.]
gamma = 0.99
pv_mc = lambda a, x: x+ a*gamma
R_list_graph = tf.scan(pv_mc, r_list, initializer=np.array(0,dtype=np.float32))
with tf.Session() as sess:
R_list = sess.run(R_list_graph, feed_dict={})
print(R_list)
# R_list = [ 1. 2.99 5.9601 9.900499 14.801495]
Заранее спасибо за помощь!
1 ответ
Вы могли бы использовать np.frompyfunc
, чья документация несколько неясна.
import numpy as np
r_list = np.arange(1,6,dtype=np.float32)
# r_list = [1. 2. 3. 4. 5.]
r_list_extended = np.append([0.],r_list)
R_list_extended = np.zeros_like(r_list_extended)
print(r_list)
gamma = 0.99
pv_mc = lambda a, x: x+ a*gamma
ufunc = np.frompyfunc(pv_mc, 2, 1)
R_list = ufunc.accumulate(r_list, dtype=np.object).astype(float)
print(R_list)