Обучите RNN на большой строке, используя Estimator API
Мы бы хотели, чтобы наш оценщик предсказывал слово или следующую работу, как это делает клавиатура вашего смартфона. Мы хотели бы обучить его на некотором текстовом файле.
Итак, мы пошли дальше и посмотрели на API-интерфейс tenorflow и нашли
estimator = RNNEstimator(
head=tf.contrib.estimator.regression_head(),
sequence_feature_columns=[token_emb],
rnn_cell_fn=rnn_cell_fn)
который представляется удобным способом создания оценщика для RNN. Теперь мы столкнулись с проблемами со столбцами функций. Мы настраиваем их так
token_sequence = sequence_categorical_column_with_hash_bucket(
key="text", hash_bucket_size=num_of_categories, dtype=tf.string)
token_emb = embedding_column(categorical_column=token_sequence,
dimension=8)
где 'text'
определяется в нашей входной функции
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"text": features},
y=labels,
batch_size=batch_size,
num_epochs=None,
shuffle=True)
где features
это просто длинный список 40-символьных последовательностей, взятых из нашего оригинального текста.
Проблемы
- Можно ли в любом случае использовать элементарные столбцы на строковых входах? Документация на самом деле мало что выдаёт.
- Что делать с ярлыками? На данный момент мы получаем ошибку, так как они никогда не приводятся к целым числам
Даже при вводе произвольных целых чисел в качестве меток мы все равно получаем ошибку при вызове
estimator.train(input_fn=train_input_fn, steps=100)
который говорит"Заданный тип: {}". Format(type(features))) ValueError: features должен быть словарем
Tensor
s. Данный тип:
поэтому мы определенно делаем что-то не так здесь. Любая помощь приветствуется:)
1 ответ
Существует краткий пример передачи функций строковых слов в качестве SparseTensors в классификацию по метке за шаг Estimator
(т.е. моделирование языка) в модульных тестах StateSavingRnnEstimator. Это выглядит примерно как то, что вы пытаетесь сделать, с оговоркой, что данный Оценщик устарел; может иметь смысл взять идеи и определить свои собственные model_fn
,