Исключение в TensorBoard при вызове tf.summary.image для MFCC формы [-1, 125, 128, 1]

Следуя этому руководству, я преобразую тензор формы [batch_size, 16000, 1] в MFCC методом, описанным по ссылке:

def gen_spectrogram(wav, sr=16000):
    """Build the graph ops that turn raw audio into its first 13 MFCCs.

    Pipeline: STFT magnitude -> mel-scale warp -> stabilized log -> MFCC.

    Args:
        wav: Audio tensor handed straight to `tf.contrib.signal.stft`.
            NOTE(review): the STFT frames the innermost axis, so this
            presumably must be `[..., samples]`; a trailing channel axis of
            size 1 would need to be squeezed by the caller -- confirm.
        sr: Sample rate of `wav` in Hz, used to build the mel filterbank.

    Returns:
        The MFCC tensor sliced to at most 13 coefficients on its last axis.
    """
    # A 1024-point STFT with 75% overlap between frames (step 256 of 1024).
    # Each frame spans 64 ms at the default 16 kHz sample rate.
    stfts = tf.contrib.signal.stft(
        wav, frame_length=1024, frame_step=256, fft_length=1024)
    spectrograms = tf.abs(stfts)

    # Warp the linear scale spectrograms into the mel-scale.
    num_spectrogram_bins = stfts.shape[-1].value
    lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 80
    # The original passed an undefined `sample_rate` here; the function's
    # own `sr` parameter is the intended sample rate.
    linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins,
        sr, lower_edge_hertz, upper_edge_hertz)
    mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
    # Re-attach the static shape explicitly: leading dims of the spectrogram
    # plus the mel-bin dim of the weight matrix.
    mel_spectrograms.set_shape(
        spectrograms.shape[:-1].concatenate(
            linear_to_mel_weight_matrix.shape[-1:]
        )
    )

    # Compute a stabilized log to get log-magnitude mel-scale spectrograms.
    # The small offset keeps log() finite for silent (all-zero) bins.
    log_mel_spectrograms = tf.log(mel_spectrograms + 1e-6)

    # Compute MFCCs from log_mel_spectrograms and take the first 13.
    mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms)
    return mfccs[..., :13]

Затем я привожу результат к форме [batch_size, 125, 128, 1]. Если передать его в tf.layers.conv2d, всё, кажется, работает нормально. Однако если я пытаюсь вызвать tf.summary.image, я получаю следующую ошибку:

print(spec)
# => Tensor("spectrogram/Reshape:0", shape=(?, 125, 128, 1), dtype=float32)

tf.summary.image('spec', spec)

Caused by op u'spectrogram/stft/rfft', defined at:
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 162, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/Users/rsilveira/rnd/ml-engine/trainer/flatv1.py", line 103, in <module>
    runner.run(model_fn)
  File "trainer/runner.py", line 88, in run
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/training.py", line 432, in train_and_evaluate
    executor.run_local()
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/training.py", line 611, in run_local
    hooks=train_hooks)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/Users/rsilveira/rnd/ml-engine/trainer/flatv1.py", line 53, in model_fn
    spec = gen_spectrogram(x)
  File "/Users/rsilveira/rnd/ml-engine/trainer/flatv1.py", line 22, in gen_spectrogram
    step,
  File "/Library/Python/2.7/site-packages/tensorflow/contrib/signal/python/ops/spectral_ops.py", line 91, in stft
    return spectral_ops.rfft(framed_signals, [fft_length])
  File "/Library/Python/2.7/site-packages/tensorflow/python/ops/spectral_ops.py", line 136, in _rfft
    return fft_fn(input_tensor, fft_length, name)
  File "/Library/Python/2.7/site-packages/tensorflow/python/ops/gen_spectral_ops.py", line 619, in rfft
    "RFFT", input=input, fft_length=fft_length, name=name)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Input dimension 4 must have length of at least 512 but got: 320

Не уверен, с чего начать поиск причины. Что я здесь упускаю?

0 ответов

Другие вопросы по тегам