Наборы данных TensorFlow 1.2 с функцией карты
У меня есть следующий кусок кода, который работает в TF1.3 и TF1.4. Когда я пробую это в t1.2, код работает, но просто зависает. Я использую только tf1.2, потому что я хочу выполнить тестирование на облачном мл-движке Google, и на этом этапе движок поддерживает только tf1.2
Вот мой входной файл CSV:
A B Result
2 2 4
2 3 5
Вот мой код:
csv_defaults = OrderedDict([("A", [0]), ("B", [0]), ("Result", [0])]);
file_path = "InputFile.csv";
def csv_decoder(line):
parsed = tf.decode_csv(line, list(csv_defaults.values()), field_delim="\t");
return parsed[0];
def test():
dataset = (TextLineDataset(file_path)
.skip(1)
.map(csv_decoder)
.batch(512));
iterator = dataset.make_one_shot_iterator();
columns = iterator.get_next();
return columns;
input_fn = test();
with tf.Session() as sess:
columns = sess.run(input_fn);
print(columns);
Это вывод в тф 1.4
[2 2]
Когда я запускаю тот же код в TF 1.2, код просто зависает и ничего не возвращает..
Из https://github.com/tensorflow/tensorflow/issues/13751 я знаю, что в tf 1.2 функция parse_csv не может возвращать dict, tuple или namedtuple (я также попробовал их все). Таким образом, я сократил это до возвращения только тензор. В ошибке @mrry рекомендует извлечь значения для объектов и затем создать кортеж вручную. Функция parser(record) возвращает тензор и, похоже, работает. Мой parse_csv также возвращает тензор, но он все еще не работает. Может кто-нибудь, пожалуйста, помогите мне?
Извините, если я упускаю что-то очевидное. Я только использую tf в течение прошлых нескольких недель и искал много ответов.