Ошибка при переходе с использованием API оценки в тензорном потоке
Я пытаюсь запустить простой SVM
классификатор по набору данных радужной оболочки, предоставляя данные, используя input_fn, возвращая tf.data.dataset
объект, но я сталкиваюсь со следующей ошибкой.
Traceback (most recent call last):
File "tf_test.py", line 45, in <module>
est.fit(steps=1, input_fn=input_fn)
File "/venv/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 524, in fit
loss = self._train_model(input_fn=input_fn, hooks=hooks)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1038, in _train_model
features, labels = input_fn()
ValueError: too many values to unpack (expected 2)
Я думаю, что вышеупомянутая ошибка является ошибкой в тензорном потоке, поэтому я попытался вернуть итератор вместо использования строки tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
который сталкивается с другой ошибкой следующим образом
Traceback (most recent call last):
File "tf_test.py", line 48, in <module>
est.fit(steps=1, input_fn=input_fn)
File "venv/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 524, in fit
loss = self._train_model(input_fn=input_fn, hooks=hooks)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1041, in _train_model
model_fn_ops = self._get_train_ops(features, labels)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1264, in _get_train_ops
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1227, in _call_model_fn
model_fn_results = self._model_fn(features, labels, **kwargs)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/linear.py", line 251, in sdca_model_fn
features.update(layers.transform_features(features, feature_columns))
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column_ops.py", line 653, in transform_features
check_feature_columns(feature_columns)
File "/venv/lib/python3.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column_ops.py", line 777, in check_feature_columns
key = f.key
AttributeError: 'str' object has no attribute 'key'
Я подготовил этот автономный пример, чтобы продемонстрировать проблему.
import copy
import tempfile
import tensorflow as tf
tf.enable_eager_execution()
_, filename = tempfile.mkstemp()
cols = ["example_id", "sepal_len", "sepal_width", "petal_len", "petal_width", "label"]
data = "\n".join([
#",".join(cols),
"1,5.1,3.5,1.4,0.2,Iris-setosa",
"2,4.9,3.0,1.4,0.2,Iris-setosa",
"2,4.7,3.2,1.3,0.2,Iris-setosa",
"4,4.6,3.1,1.5,0.2,Iris-setosa",
"5,5.0,3.6,1.4,0.2,Iris-setosa",
"6,5.7,2.5,5.0,2.0,Iris-virginica",
"7,7.0,3.2,4.7,1.4,Iris-versicolor",
"8,6.5,3.2,5.1,2.0,Iris-virginica",
"9,6.4,2.7,5.3,1.9,Iris-virginica",
"10,6.8,3.0,5.5,2.1,Iris-virginica",
"11,5.4,3.9,1.7,0.4,Iris-setosa"
])
with open(filename, 'w') as f:
f.write(data)
batch_size = 2
FIELD_DEFAULTS = [[0], [0.0], [0.0], [0.0], [0.0], [0]]
def _parse(line):
fields = tf.io.decode_csv(line, FIELD_DEFAULTS)
features = dict(zip(cols, fields))
label = features.pop('label')
return features, label
def input_fn():
dataset = tf.data.TextLineDataset(filename).skip(1)
dataset = dataset.map(_parse)
return dataset
# return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
feature_cols = copy.copy(cols)
feature_cols.remove('example_id')
est = tf.contrib.learn.SVM(example_id_column='example_ids', feature_columns=feature_cols)
est.fit(steps=1, input_fn=input_fn)
out = est.predict(input_fn=test_input_fn, yield_single_examples=False)
assert(len(out['classes']) == len(data))
Как я могу сделать эту работу с tensorflow==1.14