Как можно передать два трехмерных тензора только с одним общим измерением (размером пакета) в dynamic_lstm?

Я хотел бы передать 2 тензора с разными размерами в tf.nn.dynamic_rnn, У меня возникли трудности, потому что размеры не совпадают. Я открыт для предложений о лучшем способе сделать это. Эти тензоры являются партиями из tf.data.Dataset

У меня есть 2 тензора формы:

тензор 1: (?,?, 1024)

тензор 2: (?,?, 128)

Первое измерение - это размер пакета, второе измерение - это количество временных шагов, а третье измерение - это число объектов, вводимых на каждом временном шаге.

В настоящее время у меня есть проблема, что число временных шагов для каждого измерения не совпадают. Не только это, но и они несовместимы по размеру между образцами (для некоторых образцов тензор 1 имеет 71 шаг, иногда он может иметь 74 или 77).

Наилучшее решение для динамического дополнения количества временных шагов в более коротком тензоре для каждого образца? Если так, как бы я это сделал?

Ниже приведен фрагмент кода, чтобы показать, что я хотел бы сделать:

#Get the next batch from my tf.data.Dataset
video_id, label, rgb, audio = my_iter.get_next()

print (rgb.shape)    #(?, ?, 1024)
print (audio.shape)    #(?, ?, 128)

lstm_layer = tf.contrib.rnn.BasicLSTMCell(lstm_size)

#This instruction throws an InvalidArgumentError, I have shown the output below this code
concatenated_features = tf.concat([rgb, audio], 2)
print (concatenated_features.shape)    #(?, ?, 1152)

outputs,_= tf.nn.dynamic_rnn(lstm_layer, concatenated_features, dtype="float32")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(num_epochs):
        sess.run(my_iter.initializer)
        for j in range(num_steps):
            my_outputs = sess.run(outputs)

Ошибка при звонке tf.concat в сеансе:

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [52,77,1024] vs. shape[1] = [52,101,128]

1 ответ

Решение

Вот решение, которое мне удалось найти, которое не может быть идеальным, но решает проблему, если у кого-то нет лучшего решения. Кроме того, я новичок в TensorFlow и открыт для редактирования, если мои рассуждения неверны.

Проблема усложняется тем, что тензоры хранятся в tf.data.Dataset и любой предложенный Dataset.map Функция (используется для выполнения поэлементных операций) работает с символическими тензорами (которые в этом случае не имеют точной формы). По этой причине я не смог создать Dataset.map использовать функцию tf.pad но я открыт для решений, которые делают.

Это решение использует Dataset.map функция и tf.py_func обернуть функцию python как операцию TensorFlow. Эта функция находит разницу между двумя тензорами (теперь np.arrays внутри функции), а затем использует np.pad чтобы заполнить измерение временных шагов 0 после данных.

def pad_timesteps(video, labs, rgb, audio):
""" Function to pad the timesteps of visual or audio features so that they are equal    
"""
    rgb_timesteps = rgb.shape[1] #Get the number of timesteps for rgb
    audio_timesteps = audio.shape[1] #Get the number of timesteps for audio

    if rgb_timesteps < audio_timesteps:
        difference = audio_timesteps - rgb_timesteps
        #How much you want to pad dimension 1, 2 and 3
        #Each padding tuple is the amount to pad before and after the data in that dimension
        np_padding = ((0, 0), (0,difference), (0,0))
        #This tuple contains the values that are to be used to pad the data
        padding_values = ((0,0), (0,0), (0,0))
        rgb = np.pad(rgb, np_padding, mode='constant', constant_values=padding_values)

    elif rgb_timesteps > audio_timesteps:
        difference = rgb_timesteps - audio_timesteps
        np_padding = ((0,0), (0,difference), (0,0))
        padding_values = ((0,0), (0,0), (0,0))
        audio = np.pad(audio, np_padding, mode='constant', constant_values=padding_values)

    return video, labs, rgb, audio

dataset = dataset.map(lambda video, label, rgb, audio: tuple(tf.py_func(pad_timesteps, [video, label, rgb, audio], [tf.string, tf.int64, tf.float32, tf.float32])))
Другие вопросы по тегам