Невозможно прочитать из файла Tensorflow tfrecord
Я могу создать файл tfrecords с помощью приведенного ниже кода.
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def convert_to_tfrecord(images,labels,file_name):
# images is a numpy array of shape (num_images,channel,rows,column)
# labels is a numpy array of shape (num_images,)
num_labels = np.shape(labels)
(num_images,depth,rows,cols) = np.shape(images)
writer = tf.python_io.TFRecordWriter(file_name)
for index in range(num_images):
image_raw = images[index]
image_raw = image_raw.astype(np.float32)
image_raw = image_raw.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()
Но при чтении данных из файла tfrecord с помощью функции ниже
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
img_features = tf.parse_single_example(
serialized_example,
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(img_features['image_raw'], tf.float32)
label = tf.cast(img_features['label'], tf.int32)
height = tf.cast(img_features['height'], tf.int32)
width = tf.cast(img_features['width'], tf.int32)
depth = tf.cast(img_features['depth'], tf.int32)
image_shape = tf.stack([depth,height, width])
image = tf.reshape(image, image_shape)
return image,label
def inputs(batch_size, num_epochs):
filename = ['set1.tfrecords']
# dir_path is a global variable
file_path = dir_path + 'set1.tfrecords'
filename_queue = tf.train.string_input_producer([file_path], num_epochs=1)
image,label = read_and_decode(filename_queue)
images, sparse_labels = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, num_threads=2,
capacity=1000 + 3 * batch_size, min_after_dequeue=1000)
return images, sparse_labels
Я постоянно получаю следующую ошибку
images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 1225, in shuffle_batch
name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 781, in _shuffle_batch
dtypes=types, shapes=shapes, shared_name=shared_name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 641, in __init__
shapes = _as_shape_list(shapes, dtypes)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 77, in _as_shape_list
raise ValueError("All shapes must be fully defined: %s" % shapes)
ValueError: All shapes must be fully defined: [TensorShape([Dimension(None)]), TensorShape([])]
В чем причина вышеуказанной ошибки и как ее преодолеть? Я могу прочитать файл tfrecords, перебирая файл с помощью tf.python_io.tf_record_iterator(path=filename)
,
1 ответ
Ошибка возникает потому, что tf.train.shuffle_batch
Необходимо знать форму ваших тензоров, чтобы иметь возможность их пакетировать (элементы в пакете должны иметь одинаковую форму). Однако в принципе необработанные данные могут иметь разные размеры, поэтому tf.decode_raw
не устанавливает никакой формы для вашего тензора.
В комментариях вы упоминаете, что все ваши изображения имеют форму (192,81,2)
так что вам нужно только установить эту форму в тензор изображения, прежде чем вернуться из read_and_decode
:
def read_and_decode(filename_queue):
# rest of your code here
image_shape = [height, width, depth]
image = tf.reshape(image, image_shape)
image.set_shape(image_shape) #<<<<<<<<<<<<<<<
return image,label