Сверточная нейронная сеть Google JAX 1D
Я пытаюсь реализовать 1D сверточную нейронную сеть в Google Jax с помощью stax.GeneralConv() (https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html). У меня есть одномерный входной массив с 18 и выходной массив с 6 записями. Я хочу реализовать CNN с шириной ядра 3 следующим образом:
init_random_params, conv_net = stax.serial(
GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
LogSoftmax,
Dense(6),
)
с исходными параметрами сети:
rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))
Но я получаю следующую ошибку:
stax.py", line 75, in <listcomp>
next(filter_shape_iter) for c in rhs_spec]
IndexError: tuple index out of range
stax требует, чтобы размерное число rhs_spec было не менее 2 символов, но я использую одномерный фильтр. Есть ли у кого-нибудь идеи, как решить эту проблему?
1 ответ
Я сам этого не пробовал, но ожидаю, что 1-мерная свертка по-прежнему требует одного направления для свертки, например
Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))
Другими словами, отбрасывая W
ось для перехода от 2-го к 1-му витку.
Форма ввода, соответствующая NHC
является (batch_size, sequence_length, num_channels)
.
Обратите внимание, что даже если количество каналов может быть равно 1, вам все равно нужно включить эту ось, потому что GeneralConv
выполняет поиск по индексу по строкам num_channels = input_shape['NHC'.index('C')]
.