Обучающий классификатор текста пространственно-трансформеров - пример минимального обучения
После ряда экспериментов по линиям, указанным в этих примерах:
https://github.com/explosion/spacy-transformers/blob/master/examples/train_textcat.py
Я обнаружил, что не могу наблюдать эффект любого обучения при вызове nlp.update на моделях пространственного преобразователя. Я пробовал с en_trf_bertbaseuncased_lg, как показано ниже, и с моделью en_trf_distilbertbaseuncased_lg безуспешно. Однако я могу получить классификацию текста с помощью просторных примеров TextCategorizer и LSTM, которые работают.
Поэтому я хотел бы спросить, что я мог бы сделать, чтобы изменить приведенный ниже код, чтобы получить результат менее 1.0 для "THE_POSITIVE_LABEL" при вызове doc.cats для этого тестового предложения. В настоящее время он работает без ошибок, но всегда возвращает 1.0 для оценки. Я попытался использовать этот пример после запуска правильного набора тренировок и наблюдения за одинаковыми значениями P,R,F потерь, которые просто прыгали вокруг каждой оценки. Исправленная версия может тогда служить простой проверкой функциональности.
import spacy
from collections import Counter
nlp = spacy.load('en_trf_bertbaseuncased_lg')
textcat = nlp.create_pipe(
"trf_textcat",
config={
"architecture": "softmax_class_vector", # have also tried "softmax_last_hidden" with "words_per_batch" like in one of the examples
'token_vector_width': 768 # added as otherwise it complains about textcat config not having 'token_vector_width'
}
)
textcat.add_label("THE_POSITIVE_LABEL")
nlp.add_pipe(textcat, last=True)
nlp.begin_training() # added as otherwise it says trf_textcat has no model when we call doc.cats
print(nlp("an example of a document that does not match the label").cats)
#{'THE_POSITIVE_LABEL': 1.0} is printed
optimizer = nlp.resume_training()
optimizer.alpha = 0.001
optimizer.trf_weight_decay = 0.005
optimizer.L2 = 0.0
optimizer.trf_lr = 2e-5
losses = Counter()
texts = ['an example of a document that does not match the label',]
annotations = [{'THE_POSITIVE_LABEL': 0.},]
nlp.update(texts, annotations, sgd=optimizer, drop=0.1, losses=losses)
print(nlp("an example of a document that does not match the label").cats)
#{'THE_POSITIVE_LABEL': 1.0} is again printed