Как я могу выполнить следующие RNN на основе GRU, написанные в тензорном потоке?

Пока что я написал следующий код:

import pickle
import numpy as np
import pandas as pd
import tensorflow as tf

# load pickled objects (x and y)
x_input, y_actual = pickle.load(open('sample_input.pickle', 'rb'))
x_input = np.reshape(x_input, (50, 1))
y_actual = np.reshape(y_actual, (50, 1))

# parameters
batch_size = 50
hidden_size = 100

# create network graph
input_data = tf.placeholder(tf.float32, [batch_size, 1])
output_data = tf.placeholder(tf.float32, [batch_size, 1])

cell = tf.nn.rnn_cell.GRUCell(hidden_size)

initial_state = cell.zero_state(batch_size, tf.float32)

hidden_state = initial_state

output_of_cell, hidden_state = cell(inputs=input_data, state=hidden_state)

init_op = tf.initialize_all_variables()

softmax_w = tf.get_variable("softmax_w", [hidden_size, 1], )
softmax_b = tf.get_variable("softmax_b", [1])

logits = tf.matmul(output_of_cell, softmax_w) + softmax_b

probabilities = tf.nn.softmax(logits)

sess = tf.Session()
sess.run(init_op)

something = sess.run([probabilities, hidden_state], feed_dict={input_data:x_input, output_data:y_actual})

#cost = tf.nn.sigmoid_cross_entropy_with_logits(logits, output_data)


#sess.close()

Но я получаю ошибку для сoftmax_w/b как неинициализированные переменные.

Я не понимаю, как я должен использовать эти W а также b и проводить работу поезда.

Что-то вроде следующего:

## some cost function
## training operation minimizing cost function using gradient descent optimizer

1 ответ

Решение

tf.initialize_all_variables() получает "текущий" набор переменных из графика. Так как вы создаете softmax_w а также softmax_b после вашего звонка tf.initialize_all_variables()их нет в списке tf.initialize_all_variables() консультируется, и, следовательно, не инициализируется при запуске sess.run(init_op), Следующее должно работать:

softmax_w = tf.get_variable("softmax_w", [hidden_size, 1], )
softmax_b = tf.get_variable("softmax_b", [1])

init_op = tf.initialize_all_variables()
Другие вопросы по тегам