Как сохранить векторы в словаре в тензорном потоке?
Кажется, что tf.lookup.experimental.DenseHashTable
не может содержать векторы, и я не мог найти примеров, как его использовать.
1 ответ
Ниже вы можете найти простую реализацию словаря векторов в Tensorflow. Это также пример использованияtf.lookup.experimental.DenseHashTable
а также tf.TensorArray
.
Как уже говорилось, векторы нельзя хранить в tf.lookup.experimental.DenseHashTable
, и поэтому tf.TensorArray
используется для сохранения актуальных векторов.
Конечно, это простой пример, и он не включает удаление записей в словаре - операцию, которая потребует некоторого управления свободными ячейками массива. Кроме того, вы должны прочитать на соответствующих страницах APItf.lookup.experimental.DenseHashTable
а также tf.TensorArray
как настроить их под свои нужды.
import tensorflow as tf
class DictionaryOfVectors:
def __init__(self, dtype):
empty_key = tf.constant('')
deleted_key = tf.constant('deleted')
self.ht = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string,
value_dtype=tf.int32,
default_value=-1,
empty_key=empty_key,
deleted_key=deleted_key)
self.ta = tf.TensorArray(dtype, size=0, dynamic_size=True, clear_after_read=False)
self.inserts_counter = 0
@tf.function
def insertOrAssign(self, key, vec):
# Insert the vector to the TensorArray. The write() method returns a new
# TensorArray object with flow that ensures the write occurs. It should be
# used for subsequent operations.
with tf.init_scope():
self.ta = self.ta.write(self.inserts_counter, vec)
# Insert the same counter value to the hash table
self.ht.insert_or_assign(key, self.inserts_counter)
self.inserts_counter += 1
@tf.function
def lookup(self, key):
with tf.init_scope():
index = self.ht.lookup(key)
return self.ta.read(index)
dictionary_of_vectors = DictionaryOfVectors(dtype=tf.float32)
dictionary_of_vectors.insertOrAssign('first', [1,2,3,4,5])
print(dictionary_of_vectors.lookup('first'))
Пример немного сложнее, поскольку методы вставки и поиска украшены @tf.function
. Поскольку методы изменяют переменные, определенные вне них,tf.init_scope()
используется. Вы можете спросить, что изменилось вlookup()
метод, поскольку он фактически только читает из хеш-таблицы и массива. Причина в том, что в режиме графика индекс, возвращаемый изlookup()
call - это Tensor, а в реализации TensorArray есть строка, содержащая if index < 0:
что не удается:
OperatorNotAllowedInGraphError: использование
tf.Tensor
как Pythonbool
не допускается.
Когда мы используем tf.init_scope()
, как объясняется в документации по API, "код внутри блока init_scope выполняется с активным исполнением даже при трассировке tf.function
". Значит, в этом случае этот индекс не тензор, а скаляр.