У объекта 'int' нет атрибута '__getitem__' mxnet
Я пытался выучить mxnet из учебника, когда выполняю загрузку данных. У меня 'int' нет 'getitem', но я не могу найти местоположение ошибки, пожалуйста, помогите мне, спасибо:
import mxnet as mx
import numpy as np
class SimpleData :
def __init__(self,data,label,pad = 0):
self.data = data
self.label = label
self.pad = pad
class SimpleIter:
def __init__(self,mean,std,data_shape,label_shape,num_of_classes,num_batch = 10):
self._provide_data = zip(['data'],data_shape[0])
self._provide_label = zip(['softmax_label'],label_shape[0])
self.cur_batch = 0
self.num_batch = 10
self.mean = mean
self.std = std
self.data_shape = data_shape[0]
self.label_shape = label_shape[0]
self.num_of_classes = num_of_classes
def __iter__(self):
return self
def __next__(self):
return self.next()
def reset(self):
self.cur_batch = 0
@property
def provide_data(self):
return self._provide_data
@property
def provide_label(self):
return self._provide_label
def next(self):
if(self.cur_batch < self.num_batch):
self.cur_batch += 1
data = [mx.nd.array(np.random.normal(self.mean,self.std, ((self.data_shape)[0][0]/self.num_batch,self.data_shape[0][1])))]
label = [mx.nd.array(np.random.randint(0,10, ((self.data_shape)[0][1]/self.num_batch)))]
print data
print label
return SimpleBatch(data,label)
else:
raise StopIteration
class SyntheticData:
def __init__(self,mean,std,num_records,num_of_features,num_classes):
self.mean = mean
self.std = std
self.data_shape = zip(num_records,num_of_features)
self.label_shape = zip(num_records,)
self.num_classes = num_classes
def get_iter(self):
return SimpleIter(self.mean,self.std,self.data_shape,self.label_shape,self.num_classes)
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data = net,name = 'fc1',num_hidden = 64)
net = mx.sym.Activation(data = net,name = 'relu_1',act_type = 'relu')
net = mx.sym.FullyConnected(data = net,name = 'fc2',num_hidden = 10)
net = mx.sym.SoftmaxOutput(data = net,name = 'softmax')
data = SyntheticData(10,128,[100],[100],10)
mod.fit(data.get_iter(),
eval_data=data.get_iter(),
optimizer='sgd',
optimizer_params={'learning_rate':0.1},
eval_metric='acc',
num_epoch = 5)
ошибка:
TypeError Traceback (most recent call last)
<ipython-input-273-a7375f022406> in <module>()
4 optimizer_params={'learning_rate':0.1},
5 eval_metric='acc',
----> 6 num_epoch = 5)
/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/base_module.pyc in fit(self, train_data, eval_data, eval_metric, epoch_end_callback, batch_end_callback, kvstore, optimizer, optimizer_params, eval_end_callback, eval_batch_end_callback, initializer, arg_params, aux_params, allow_missing, force_rebind, force_init, begin_epoch, num_epoch, validation_metric, monitor)
440
441 self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
--> 442 for_training=True, force_rebind=force_rebind)
443 if monitor is not None:
444 self.install_monitor(monitor)
/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/module.pyc in bind(self, data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind, shared_module, grad_req)
386 fixed_param_names=self._fixed_param_names,
387 grad_req=grad_req,
--> 388 state_names=self._state_names)
389 self._total_exec_bytes = self._exec_group._total_exec_bytes
390 if shared_module is not None:
/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/executor_group.pyc in __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names, for_training, inputs_need_grad, shared_group, logger, fixed_param_names, grad_req, state_names)
203 for name in self.symbol.list_outputs()]
204
--> 205 self.bind_exec(data_shapes, label_shapes, shared_group)
206
207 def decide_slices(self, data_shapes):
/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/executor_group.pyc in bind_exec(self, data_shapes, label_shapes, shared_group, reshape)
282
283 # calculate workload and bind executors
--> 284 self.data_layouts = self.decide_slices(data_shapes)
285 if label_shapes is not None:
286 # call it to make sure labels has the same batch size as data
/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4- py2.7.egg/mxnet/module/executor_group.pyc in decide_slices(self, data_shapes)
220 continue
221
--> 222 batch_size = shape[axis]
223 if self.batch_size is not None:
224 assert batch_size == self.batch_size, ("all data must have the same batch size: "
TypeError: 'int' object has no attribute '__getitem__'
1 ответ
Решение
Я думаю, что ваша проблема в определении вашего data_shape
,
self.data_shape = data_shape[0]
Как вы определили, self.data_shape
это просто инт. В вашем случае, я думаю, что это должно быть просто
self.data_shape = data_shape
Так что когда shape[axis]
доступ из decide_slices
он может получить количество элементов.