Ошибка при переходе с использованием API оценки в тензорном потоке

Я пытаюсь запустить простой SVM классификатор по набору данных радужной оболочки, предоставляя данные, используя input_fn, возвращая tf.data.dataset объект, но я сталкиваюсь со следующей ошибкой.

Traceback (most recent call last):
  File "tf_test.py", line 45, in <module>
    est.fit(steps=1, input_fn=input_fn)
  File "/venv/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 524, in fit
    loss = self._train_model(input_fn=input_fn, hooks=hooks)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1038, in _train_model
    features, labels = input_fn()
ValueError: too many values to unpack (expected 2)

Я думаю, что вышеупомянутая ошибка является ошибкой в ​​тензорном потоке, поэтому я попытался вернуть итератор вместо использования строки tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() который сталкивается с другой ошибкой следующим образом

Traceback (most recent call last):
  File "tf_test.py", line 48, in <module>
    est.fit(steps=1, input_fn=input_fn)
  File "venv/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 524, in fit
    loss = self._train_model(input_fn=input_fn, hooks=hooks)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1041, in _train_model
    model_fn_ops = self._get_train_ops(features, labels)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1264, in _get_train_ops
    return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1227, in _call_model_fn
    model_fn_results = self._model_fn(features, labels, **kwargs)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/linear.py", line 251, in sdca_model_fn
    features.update(layers.transform_features(features, feature_columns))
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column_ops.py", line 653, in transform_features
    check_feature_columns(feature_columns)
  File "/venv/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column_ops.py", line 777, in check_feature_columns
    key = f.key
AttributeError: 'str' object has no attribute 'key'

Я подготовил этот автономный пример, чтобы продемонстрировать проблему.

import copy
import tempfile
import tensorflow as tf
tf.enable_eager_execution()

_, filename = tempfile.mkstemp()

cols = ["example_id", "sepal_len", "sepal_width", "petal_len", "petal_width", "label"]

data = "\n".join([
    #",".join(cols),
    "1,5.1,3.5,1.4,0.2,Iris-setosa",
    "2,4.9,3.0,1.4,0.2,Iris-setosa",
    "2,4.7,3.2,1.3,0.2,Iris-setosa",
    "4,4.6,3.1,1.5,0.2,Iris-setosa",
    "5,5.0,3.6,1.4,0.2,Iris-setosa",
    "6,5.7,2.5,5.0,2.0,Iris-virginica",
    "7,7.0,3.2,4.7,1.4,Iris-versicolor",
    "8,6.5,3.2,5.1,2.0,Iris-virginica",
    "9,6.4,2.7,5.3,1.9,Iris-virginica",
    "10,6.8,3.0,5.5,2.1,Iris-virginica",
    "11,5.4,3.9,1.7,0.4,Iris-setosa"
])

with open(filename, 'w') as f:
    f.write(data)

batch_size = 2

FIELD_DEFAULTS = [[0], [0.0], [0.0], [0.0], [0.0], [0]]

def _parse(line):
    fields = tf.io.decode_csv(line, FIELD_DEFAULTS)
    features = dict(zip(cols, fields))
    label = features.pop('label')
    return features, label

def input_fn():
    dataset = tf.data.TextLineDataset(filename).skip(1)
    dataset = dataset.map(_parse)
    return dataset
    # return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()


feature_cols = copy.copy(cols)
feature_cols.remove('example_id')
est = tf.contrib.learn.SVM(example_id_column='example_ids', feature_columns=feature_cols)
est.fit(steps=1, input_fn=input_fn)
out = est.predict(input_fn=test_input_fn, yield_single_examples=False)
assert(len(out['classes']) == len(data))

Как я могу сделать эту работу с tensorflow==1.14

0 ответов

Другие вопросы по тегам