Аналог tf.depthwise_conv2d с использованием Jax jax.lax.conv
Я переношу код с Tensorflow на Jax и сталкиваюсь со следующими трудностями:
У меня есть два массива: R и S. У нас есть:
R.shape
(10,201,11)
а также
S.shape
(61,11)
Мне нужно свернуть каждый S[:,i] с соответствующим R[j,:,i] для всех j из 0:9, в результате получится shape = [10,201]. Это можно сделать в Tensorflow, выполнив следующие действия:
R1 = tf.expand_dims(R,axis=1)
output=tf.nn.depthwise_conv2d(R1,S,strides=[1,1,1,1],padding='SAME')
Использование функции тензорного потока tf.nn.depthwise_conv2d.
Мне интересно, есть ли способ сделать это с помощью jax.lax.conv. Существует небольшой учебник по функциям сверток Jax" здесь, но это своего рода трудно следовать, и, кажется, предназначены для общего 2d сверток; не 1d свертки применялись равномерно по всем строкам.