Аналог tf.depthwise_conv2d с использованием Jax jax.lax.conv

Я переношу код с Tensorflow на Jax и сталкиваюсь со следующими трудностями:

У меня есть два массива: R и S. У нас есть:

R.shape
(10,201,11)

а также

S.shape
(61,11)

Мне нужно свернуть каждый S[:,i] с соответствующим R[j,:,i] для всех j из 0:9, в результате получится shape = [10,201]. Это можно сделать в Tensorflow, выполнив следующие действия:

R1 = tf.expand_dims(R,axis=1)
output=tf.nn.depthwise_conv2d(R1,S,strides=[1,1,1,1],padding='SAME')

Использование функции тензорного потока tf.nn.depthwise_conv2d.

Мне интересно, есть ли способ сделать это с помощью jax.lax.conv. Существует небольшой учебник по функциям сверток Jax" здесь, но это своего рода трудно следовать, и, кажется, предназначены для общего 2d сверток; не 1d свертки применялись равномерно по всем строкам.

0 ответов

Другие вопросы по тегам