Можете ли вы получить максимальное количество тензоров с разными размерами, используя pytorch?

Учитывая тензор pytorch длины, скажем, 15, существует ли «хороший» способ получить максимальные значения непересекающихся подмножеств этого тензора? В частности, при наличии списка l=(5,4,6)Я хочу максимум первых 5 элементов t, затем максимум следующих 4 элементов и максимум последних 6.

Если элементы lравны, тензор можно изменить, а максимум для каждой строки можно найти за один шаг. Но я не вижу способа сделать это красиво, когда элементы l разные, не прибегая к циклу. Было бы неплохо найти параллельный способ сделать это.

Для контекста: я работаю над проблемой обучения с подкреплением, в которой мои состояния представляют собой графики. Действие — это просто вершина графа. Я использую ДКН. Когда я выбираю пакет из буфера, для каждого графа в буфере я генерирую значение Q для каждой вершины, и для каждого графа я хочу найти максимальное значение среди этих значений Q. Но когда графы имеют разное количество вершин, я сталкиваюсь с проблемой, описанной выше: я не вижу способа получить максимальное значение Q для каждого графа без необходимости перебора пакета.

Вот пример того, как это сделать с помощью цикла:

      # Case when the subset lengths are unbalanced
q_values_tensor = torch.randint(10,(15,1))
lengths_list = [5,4,6]
max_q_values_unbalanced = torch.zeros((1,len(lengths_list)))
tensor_index=0
for list_index, length in enumerate(lengths_list):
    max_q_values_unbalanced[0, list_index] = torch.max(q_values_tensor[tensor_index:tensor_index+length])
    tensor_index += length

И пример, когда все длины подмножеств равны (я надеюсь, что может быть решение, похожее на это):

      # Case when the subset lengths are balanced (what I'd like to replicate)
q_values_tensor = torch.randint(10,(15,1))
lengths_list = [5,5,5]
reshaped_q_values_tensor = q_values_tensor.view(-1,lengths_list[0])
max_q_values_balanced = reshaped_q_values_tensor.max(1)[0]

0 ответов

Другие вопросы по тегам