Перезапись методов с помощью миксин-паттерна не работает как задумано
Я пытаюсь представить мод /mixin для проблемы. В частности, я сосредоточен здесь на SpeechRecognitionProblem
, Я намерен изменить эту проблему, и поэтому я стараюсь сделать следующее:
class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):
def hparams(self, defaults, model_hparams):
SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
p = defaults
p.vocab_size['targets'] = vocab_size
def feature_encoders(self, data_dir):
# ...
Так что этот не делает много. Это вызывает hparams()
функция из базового класса, а затем меняет некоторые значения.
Теперь уже есть некоторые готовые проблемы, например Libri Speech:
@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
# ..
Однако, чтобы применить мои модификации, я делаю это:
@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
# ..
Это должно, если я не ошибаюсь, переписать все (с одинаковыми подписями) в Librispeech
и вместо вызова функций SpeechRecognitionProblemMod
,
Поскольку я смог обучить модель с этим кодом, я предполагаю, что она работает так, как задумано.
Теперь вот моя проблема:
После тренировки хочу сериализовать модель. Это обычно работает. Тем не менее, это не так с моим модом, и я на самом деле знаю, почему:
В определенный момент hparams()
вызывается. Отладка до этого момента покажет мне следующее:
self # {LibrispeechMod}
self.hparams # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>
self.hparams
должно быть <bound method SpeechRecognitionProblemMod.hparams of ..>
! Казалось бы, по какой-то причине hparams()
из SpeechRecognitionProblem
вызывается напрямую вместо SpeechRecognitionProblemMod
, Но обратите внимание, что это правильный тип для feature_encoders()
!
Дело в том, что я знаю, что это работает во время тренировок. Я вижу, что гиперпараметры (hparams) применяются соответственно просто потому, что имена узлов графа модели меняются в результате моих модификаций.
Есть одна специальность, которую я должен указать. tensor2tensor
позволяет динамически загружать t2t_usr_dir
, которые являются дополнительными модулями Python, которые загружаются import_usr_dir
, Я использую эту функцию и в моем сценарии сериализации:
if usr_dir:
logging.info('Loading user dir %s' % usr_dir)
import_usr_dir(usr_dir)
Это может быть единственным виновником, которого я вижу на данный момент, хотя я не смогу сказать, почему это может вызвать проблему.
Если кто-то видит что-то, чего нет у меня, я был бы рад получить подсказку, что я здесь делаю неправильно.
Так какую ошибку вы получаете?
Ради полноты, это результат неправильного hparams()
вызываемый метод:
NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint
symbol_modality_256_256
неправильно. Так должно быть symbol_modality_<vocab-size>_256
где <vocab-size>
это размер словаря, который устанавливается в SpeechRecognitionProblemMod.hparams
,
1 ответ
Итак, это странное поведение произошло из-за того, что я выполнял удаленную отладку и что исходные файлы usr_dir
не были правильно синхронизированы. Все работает, как задумано, но исходные файлы там, где не совпадают.
Дело закрыто.