Tensorflow - Как инициализировать все глобальные переменные из NumPy

Я хочу объединить 2 сети в одну сеть, сохраняя вес исходной сети.

Я сохранил веса в их форме с помощью:

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    weights[i.name] = i.eval()

Я не могу найти способ загрузки весов в переменные новой сети. Есть ли способ загрузить веса для всех переменных?

Я пробовал следующее, но получаю ошибку en:

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    i.initializer = weights[i.name]

Ошибка:

AttributeError: can't set attribute

1 ответ

Решение

Вы можете написать обе функции

def save_to_dict(sess, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
    return {v.name: sess.run(v) for v in tf.global_variables()}


def load_from_dict(sess, data):
    for v in tf.global_variables():
        if v.name in data.keys():
            sess.run(v.assign(data[v.name]))

Хитрость заключается в том, чтобы просто перебрать все переменные и просто проверить, существуют ли они в словаре, как

import tensorflow as tf
import numpy as np


def save_to_dict(sess, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
    return {v.name: sess.run(v) for v in tf.global_variables()}


def load_from_dict(sess, data):
    for v in tf.global_variables():
        if v.name in data.keys():
            sess.run(v.assign(data[v.name]))


def network(x):
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc0')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc1')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc2')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc3')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc4')
    return x


element = np.random.randn(8, 10)
weights = None

# first session
with tf.Session() as sess:

    x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    y = network(x)
    sess.run(tf.global_variables_initializer())

    # first evaluation
    expected = sess.run(y, {x: element})

    # dump as dict
    weights = save_to_dict(sess)

# destroy session and graph
tf.reset_default_graph()

# second session
with tf.Session() as sess:

    x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    y = network(x)
    sess.run(tf.global_variables_initializer())

    # use randomly initialized parameters
    actual = sess.run(y, {x: element})
    assert np.sum(np.abs(actual - expected)) > 0  # should NOT match

    # load previous parameters
    load_from_dict(sess, weights)

    actual = sess.run(y, {x: element})
    assert np.sum(np.abs(actual - expected)) == 0  # should match

Таким образом, вы можете просто удалить некоторые параметры из словаря, изменить вес перед загрузкой и даже изменить имя параметра.

Другие вопросы по тегам