Как сохранить векторы в словаре в тензорном потоке?

Кажется, что tf.lookup.experimental.DenseHashTable не может содержать векторы, и я не мог найти примеров, как его использовать.

1 ответ

Ниже вы можете найти простую реализацию словаря векторов в Tensorflow. Это также пример использованияtf.lookup.experimental.DenseHashTable а также tf.TensorArray.

Как уже говорилось, векторы нельзя хранить в tf.lookup.experimental.DenseHashTable, и поэтому tf.TensorArray используется для сохранения актуальных векторов.

Конечно, это простой пример, и он не включает удаление записей в словаре - операцию, которая потребует некоторого управления свободными ячейками массива. Кроме того, вы должны прочитать на соответствующих страницах APItf.lookup.experimental.DenseHashTable а также tf.TensorArray как настроить их под свои нужды.

import tensorflow as tf


class DictionaryOfVectors:

  def __init__(self, dtype):
    empty_key = tf.constant('')
    deleted_key = tf.constant('deleted')

    self.ht = tf.lookup.experimental.DenseHashTable(key_dtype=tf.string,
                                                    value_dtype=tf.int32,
                                                    default_value=-1,
                                                    empty_key=empty_key,
                                                    deleted_key=deleted_key)
    self.ta = tf.TensorArray(dtype, size=0, dynamic_size=True, clear_after_read=False)
    self.inserts_counter = 0

  @tf.function
  def insertOrAssign(self, key, vec):
    # Insert the vector to the TensorArray. The write() method returns a new
    # TensorArray object with flow that ensures the write occurs. It should be 
    # used for subsequent operations.
    with tf.init_scope():
      self.ta = self.ta.write(self.inserts_counter, vec)

      # Insert the same counter value to the hash table
      self.ht.insert_or_assign(key, self.inserts_counter)
      self.inserts_counter += 1

  @tf.function
  def lookup(self, key):
    with tf.init_scope():
      index = self.ht.lookup(key)
      return self.ta.read(index)

dictionary_of_vectors = DictionaryOfVectors(dtype=tf.float32)
dictionary_of_vectors.insertOrAssign('first', [1,2,3,4,5])
print(dictionary_of_vectors.lookup('first'))

Пример немного сложнее, поскольку методы вставки и поиска украшены @tf.function. Поскольку методы изменяют переменные, определенные вне них,tf.init_scope()используется. Вы можете спросить, что изменилось вlookup()метод, поскольку он фактически только читает из хеш-таблицы и массива. Причина в том, что в режиме графика индекс, возвращаемый изlookup() call - это Tensor, а в реализации TensorArray есть строка, содержащая if index < 0: что не удается:

OperatorNotAllowedInGraphError: использование tf.Tensor как Python bool не допускается.

Когда мы используем tf.init_scope(), как объясняется в документации по API, "код внутри блока init_scope выполняется с активным исполнением даже при трассировке tf.function". Значит, в этом случае этот индекс не тензор, а скаляр.

Другие вопросы по тегам