Модель SVM одного класса для классификации текста (scikit-learn)
Я пытаюсь классифицировать набор текстов, который будет использоваться для предсказания похожих текстов в тестовом наборе текстов. Я использую модель one_class_svm. "author_corpus" содержит список текстов, написанных одним автором, а "test_corpus" содержит список текстов, написанных как другими авторами, так и первоначальным автором. Я пытаюсь использовать one_class_svm для идентификации автора в тестовых текстах.
def analyse_corpus(author_corpus, test_corpus):
vectorizer = TfidfVectorizer()
author_vectors = vectorizer.fit_transform(author_corpus)
test_vectors = vectorizer.fit_transform(test_corpus)
model = OneClassSVM(gamma='auto')
model.fit(author_vectors)
test = model.predict(test_vectors)
Я получаю ошибку значения:
X.shape[1] = 2484 should be equal to 1478, the number of features at training time
Как я могу реализовать эту модель, чтобы точно предсказать авторство текстов в наборе тестов с учетом одного автора в наборе поездов? Любая помощь приветствуется.
Для справки, вот ссылка на руководство по модели one_class_svm: https://scikit-learn.org/stable/modules/generated/sklearn.svm.OneClassSVM.html
1 ответ
Вам следует fit
(обучить) модель на train
данных и делать прогнозы, используя обученную модель на test
данные.
fit
: подогнать (тренировать) модельfit_transform
: соответствует модели, а затем делает прогнозыtransform
: Делает предсказания
Вы делаете ошибку
test_vectors = vectorizer.fit_transform(test_corpus)
Пример использования
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
train = fetch_20newsgroups(subset='train', categories=['alt.atheism'], shuffle=True, random_state=42).data
test = fetch_20newsgroups(subset='train', categories=['alt.atheism', 'soc.religion.christian'], shuffle=True, random_state=42).data
vectorizer = TfidfVectorizer()
train_vectors = vectorizer.fit_transform(train)
test_vectors = vectorizer.transform(test)
model = OneClassSVM(gamma='auto')
model.fit(train_vectors)
test_predictions = model.predict(test_vectors)