Сверточная нейронная сеть Google JAX 1D

Я пытаюсь реализовать 1D сверточную нейронную сеть в Google Jax с помощью stax.GeneralConv() (https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html). У меня есть одномерный входной массив с 18 и выходной массив с 6 записями. Я хочу реализовать CNN с шириной ядра 3 следующим образом:

init_random_params, conv_net = stax.serial(
    GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
    LogSoftmax,
    Dense(6),
)

с исходными параметрами сети:

rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))

Но я получаю следующую ошибку:

stax.py", line 75, in <listcomp>
    next(filter_shape_iter) for c in rhs_spec]

IndexError: tuple index out of range

stax требует, чтобы размерное число rhs_spec было не менее 2 символов, но я использую одномерный фильтр. Есть ли у кого-нибудь идеи, как решить эту проблему?

1 ответ

Я сам этого не пробовал, но ожидаю, что 1-мерная свертка по-прежнему требует одного направления для свертки, например

Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))

Другими словами, отбрасывая W ось для перехода от 2-го к 1-му витку.

Форма ввода, соответствующая NHC является (batch_size, sequence_length, num_channels).

Обратите внимание, что даже если количество каналов может быть равно 1, вам все равно нужно включить эту ось, потому что GeneralConv выполняет поиск по индексу по строкам num_channels = input_shape['NHC'.index('C')].

Другие вопросы по тегам