Перезапись методов с помощью миксин-паттерна не работает как задумано

Я пытаюсь представить мод /mixin для проблемы. В частности, я сосредоточен здесь на SpeechRecognitionProblem, Я намерен изменить эту проблему, и поэтому я стараюсь сделать следующее:

class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):

    def hparams(self, defaults, model_hparams):
        SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
        vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
        p = defaults
        p.vocab_size['targets'] = vocab_size

    def feature_encoders(self, data_dir): 
        # ...

Так что этот не делает много. Это вызывает hparams() функция из базового класса, а затем меняет некоторые значения.

Теперь уже есть некоторые готовые проблемы, например Libri Speech:

@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
    # ..

Однако, чтобы применить мои модификации, я делаю это:

@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
    # ..

Это должно, если я не ошибаюсь, переписать все (с одинаковыми подписями) в Librispeech и вместо вызова функций SpeechRecognitionProblemMod,

Поскольку я смог обучить модель с этим кодом, я предполагаю, что она работает так, как задумано.

Теперь вот моя проблема:

После тренировки хочу сериализовать модель. Это обычно работает. Тем не менее, это не так с моим модом, и я на самом деле знаю, почему:

В определенный момент hparams() вызывается. Отладка до этого момента покажет мне следующее:

self                  # {LibrispeechMod}
self.hparams          # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>

self.hparams должно быть <bound method SpeechRecognitionProblemMod.hparams of ..>! Казалось бы, по какой-то причине hparams() из SpeechRecognitionProblem вызывается напрямую вместо SpeechRecognitionProblemMod, Но обратите внимание, что это правильный тип для feature_encoders()!

Дело в том, что я знаю, что это работает во время тренировок. Я вижу, что гиперпараметры (hparams) применяются соответственно просто потому, что имена узлов графа модели меняются в результате моих модификаций.

Есть одна специальность, которую я должен указать. tensor2tensor позволяет динамически загружать t2t_usr_dir, которые являются дополнительными модулями Python, которые загружаются import_usr_dir, Я использую эту функцию и в моем сценарии сериализации:

if usr_dir:
    logging.info('Loading user dir %s' % usr_dir)
    import_usr_dir(usr_dir)

Это может быть единственным виновником, которого я вижу на данный момент, хотя я не смогу сказать, почему это может вызвать проблему.

Если кто-то видит что-то, чего нет у меня, я был бы рад получить подсказку, что я здесь делаю неправильно.


Так какую ошибку вы получаете?

Ради полноты, это результат неправильного hparams() вызываемый метод:

NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint

symbol_modality_256_256 неправильно. Так должно быть symbol_modality_<vocab-size>_256 где <vocab-size> это размер словаря, который устанавливается в SpeechRecognitionProblemMod.hparams,

1 ответ

Решение

Итак, это странное поведение произошло из-за того, что я выполнял удаленную отладку и что исходные файлы usr_dir не были правильно синхронизированы. Все работает, как задумано, но исходные файлы там, где не совпадают.

Дело закрыто.

Другие вопросы по тегам