У объекта 'int' нет атрибута '__getitem__' mxnet

Я пытался выучить mxnet из учебника, когда выполняю загрузку данных. У меня 'int' нет 'getitem', но я не могу найти местоположение ошибки, пожалуйста, помогите мне, спасибо:

import mxnet as mx
import numpy as np

class SimpleData :
    def __init__(self,data,label,pad = 0):
        self.data = data
        self.label = label
        self.pad = pad

class SimpleIter:
        def __init__(self,mean,std,data_shape,label_shape,num_of_classes,num_batch = 10):
        self._provide_data = zip(['data'],data_shape[0])
        self._provide_label = zip(['softmax_label'],label_shape[0])
        self.cur_batch = 0
        self.num_batch = 10
        self.mean = mean 
        self.std = std
        self.data_shape = data_shape[0]
        self.label_shape = label_shape[0]
        self.num_of_classes  = num_of_classes

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        self.cur_batch = 0

    @property 
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if(self.cur_batch < self.num_batch):
            self.cur_batch += 1
            data = [mx.nd.array(np.random.normal(self.mean,self.std,    ((self.data_shape)[0][0]/self.num_batch,self.data_shape[0][1])))]
            label = [mx.nd.array(np.random.randint(0,10,    ((self.data_shape)[0][1]/self.num_batch)))]
            print data
            print label
            return SimpleBatch(data,label)
        else:
            raise StopIteration

class SyntheticData:
    def     __init__(self,mean,std,num_records,num_of_features,num_classes):
        self.mean = mean
        self.std = std
        self.data_shape = zip(num_records,num_of_features)
        self.label_shape = zip(num_records,)
        self.num_classes = num_classes

        def get_iter(self):
            return     SimpleIter(self.mean,self.std,self.data_shape,self.label_shape,self.num_classes)
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data = net,name = 'fc1',num_hidden = 64)
net = mx.sym.Activation(data = net,name = 'relu_1',act_type = 'relu')
net = mx.sym.FullyConnected(data = net,name = 'fc2',num_hidden = 10)
net = mx.sym.SoftmaxOutput(data = net,name = 'softmax')
data = SyntheticData(10,128,[100],[100],10) 
mod.fit(data.get_iter(), 
    eval_data=data.get_iter(),
    optimizer='sgd',
    optimizer_params={'learning_rate':0.1},
    eval_metric='acc',
    num_epoch = 5)

ошибка:

    TypeError                                 Traceback (most recent call last)
 <ipython-input-273-a7375f022406> in <module>()
      4         optimizer_params={'learning_rate':0.1},
      5         eval_metric='acc',
----> 6         num_epoch = 5)

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/base_module.pyc in fit(self, train_data, eval_data, eval_metric, epoch_end_callback, batch_end_callback, kvstore, optimizer, optimizer_params, eval_end_callback, eval_batch_end_callback, initializer, arg_params, aux_params, allow_missing, force_rebind, force_init, begin_epoch, num_epoch, validation_metric, monitor)
    440 
    441         self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
--> 442                   for_training=True, force_rebind=force_rebind)
    443         if monitor is not None:
    444             self.install_monitor(monitor)

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/module.pyc in bind(self, data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind, shared_module, grad_req)
    386                                                      fixed_param_names=self._fixed_param_names,
    387                                                      grad_req=grad_req,
--> 388                                                      state_names=self._state_names)
    389         self._total_exec_bytes = self._exec_group._total_exec_bytes
    390         if shared_module is not None:

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/executor_group.pyc in __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names, for_training, inputs_need_grad, shared_group, logger, fixed_param_names, grad_req, state_names)
    203                                for name in self.symbol.list_outputs()]
    204 
--> 205         self.bind_exec(data_shapes, label_shapes, shared_group)
    206 
    207     def decide_slices(self, data_shapes):

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-py2.7.egg/mxnet/module/executor_group.pyc in bind_exec(self, data_shapes, label_shapes, shared_group, reshape)
    282 
    283         # calculate workload and bind executors
--> 284         self.data_layouts = self.decide_slices(data_shapes)
    285         if label_shapes is not None:
    286             # call it to make sure labels has the same batch size     as data

/usr/local/lib/python2.7/dist-packages/mxnet-0.9.4-    py2.7.egg/mxnet/module/executor_group.pyc in decide_slices(self,     data_shapes)
       220                 continue
       221 
-->     222             batch_size = shape[axis]
       223             if self.batch_size is not None:
       224                 assert batch_size == self.batch_size, ("all data      must have the same batch size: "

TypeError: 'int' object has no attribute '__getitem__'

1 ответ

Решение

Я думаю, что ваша проблема в определении вашего data_shape,

self.data_shape = data_shape[0]

Как вы определили, self.data_shape это просто инт. В вашем случае, я думаю, что это должно быть просто

self.data_shape = data_shape

Так что когда shape[axis] доступ из decide_slices он может получить количество элементов.

Другие вопросы по тегам