Эффективно создать сжатую диагональную матрицу в mxnet

В моей задаче у меня есть вектор, содержащий n элементы. Учитывая размер окна k Я хочу, чтобы эффективно создать размер матрицы n x 2k+1 который содержит полосу диагонали. Например:

a = [a_1, a_2, a_3, a_4]
k = 1
b = [[0, a_1, a_2],
    [a_1, a_2, a_3],
    [a_2, a_3, a_4],
    [a_3, a_4, a_5],
    [a_4, a_5, 0]]

Наивный способ реализовать это будет использовать для циклов

out_data = mx.ndarray.zeros((n, 2k+1))
for i in range(0, n):
    for j in range(0, 2k+1):
        index = i - k + j
        if not (index < 0 or index >= seq_len):
            out_data[i][j] = in_data[index]

Это очень медленно.

Создать полную матрицу было бы легко, просто используя tile а также reshapeОднако маскирующая часть не ясна.

Обновление Я нашел более быструю, но все еще очень медленную реализацию:

window = 2*self.windowSize + 1
in_data_reshaped = in_data.reshape((batch_size, seq_len))
out_data = mx.ndarray.zeros((seq_len * window))
for i in range(0, seq_len):
    copy_from_start = max(i - self.windowSize, 0)
    copy_from_end = min(seq_len -1, i+1+self.windowSize)
    copy_length = copy_from_end - copy_from_start
    copy_to_start = i*window + (2*self.windowSize + 1 - copy_length)
    copy_to_end = copy_to_start + copy_length
    out_data[copy_to_start:copy_to_end] = in_data_reshaped[copy_from_start:copy_from_end]
out_data = out_data.reshape((seq_len, window))

1 ответ

Решение

Если в вашей работе, k а также n постоянны, и вы можете делать то, что вы хотите, используя комбинацию mxnet.nd.gather_nd() а также mx.nd.scatter_nd, Хотя генерация тензора индексов столь же неэффективна, потому что вам нужно сделать это только один раз, это не будет проблемой. Вы хотели бы использовать gather_nd эффективно "дублировать" ваши данные из исходного массива, а затем использовать scatter_nd рассеять их до окончательной формы матрицы. Кроме того, вы можете просто объединить 0 элемент вашего входного массива (например, [a_1, a_2, a_3] превратится в [0, a_1, a_2, a_3]), а затем использовать только mxnet.nd.gather_nd() дублировать элементы в вашей окончательной матрице.

Другие вопросы по тегам