Ошибка индексирования списка с помощью numba guvectorize
Я новичок в Numba / Numbapro. Я пытался запустить один из примеров, этот об обобщенных Ufuncs с помощью guvectorize:
(Здесь ссылка на пример): http://docs.continuum.io/numbapro/quickstart.html
import numbapro as numbapro
@numbapro.guvectorize(['void(int32[:], int32[:])'], '(n)->()')
def sum_row(inp, out):
"""
Sum every row
function type: two arrays
(note: scalar is represented as an array of length 1)
signature: n elements to scalar
"""
tmp = 0.
for i in range(inp.shape[0]):
tmp += inp[i]
out[0] = tmp
Я получаю эту ошибку:
IndexError Traceback (most recent call last)
<ipython-input-98-79514a184595> in <module>()
----> 1 @numbapro.guvectorize(['void(int32[:], int32[:])'], '(n)->()')
2 def sum_row(inp, out):
3 """
4 Sum every row
5
/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/decorators.pyc in wrap(func)
117 for fty in ftylist:
118 guvec.add(fty)
--> 119 return guvec.build_ufunc()
120
121 return wrap
/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/ufuncbuilder.pyc in build_ufunc(self)
149
150 for sig, cres in self.nb_func.overloads.items():
--> 151 dtypenums, ptr = self.build(cres)
152 dtypelist.append(dtypenums)
153 ptrlist.append(utils.longint(ptr))
/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/ufuncbuilder.pyc in build(self, cres)
167 signature = cres.signature
168 wrapper = build_gufunc_wrapper(ctx, cres.llvm_func, signature,
--> 169 self.sin, self.sout)
170 ctx.engine.add_module(wrapper.module)
171 ptr = ctx.engine.get_pointer_to_function(wrapper)
/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/wrappers.pyc in build_gufunc_wrapper(context, func, signature, sin, sout)
143 for i, (typ, sym) in enumerate(zip(signature.args, sin + sout)):
144 ary = GUArrayArg(context, builder, arg_args, arg_dims, arg_steps, i,
--> 145 step_offset, typ, sym, sym_dim)
146 step_offset += ary.ndim
147 arrays.append(ary)
/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/wrappers.pyc in __init__(self, context, builder, args, dims, steps, i, step_offset, typ, syms, sym_dim)
207 self.array = arycls(context, builder)
208 self.array.data = builder.bitcast(self.data, self.array.data.type)
--> 209 self.array.shape = cgutils.pack_array(builder, self.shape)
210 self.array.strides = cgutils.pack_array(builder, self.strides)
211 self.array_value = self.array._getpointer()
/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/cgutils.pyc in pack_array(builder, values)
257 def pack_array(builder, values):
258 n = len(values)
--> 259 ty = values[0].type
260 ary = Constant.undef(Type.array(ty, n))
261 for i, v in enumerate(values):
IndexError: list index out of range
Я не нашел больше документации, чем эта ссылка. Я делаю что-то неправильно? Я обнаружил, что это происходит, когда в подписи есть пустая скобка. Я использую это на машине с Linux, и моя версия numbapro - 0.14.1
Заранее спасибо,
Alex
1 ответ
Наконец я пришел к выводу сам. Вы можете выделить фактический массив из n элементов, чтобы вы могли писать как подпись (n)->(n)
вместо (n)->()
, В документации должна быть ошибка, или она устарела.
Однако это неэффективно, так как мы должны выделить весь массив, и это будет пустой тратой памяти (хотя это работает!). Даже больше, используйте guvectorize
для суммирования элементов в массиве это, вероятно, лучший, аккуратный и эффективный способ сделать это, поэтому я использую @jit.