Классификация с использованием приблизительных ближайших соседей в Scikit-Learn

У меня есть помеченный набор данных, имеющий набор функций 46D и около 5000 выборок, которые я хочу классифицировать, используя Приблизительные ближайшие соседи.

Поскольку я знаком с Scikit-Learn, я хочу использовать его для достижения этой цели.

Документация Scikit перечисляет LSHForest как один из вероятных методов для ANN, но мне неясно, как применить это для целей классификации.

1 ответ

Решение

Очень хороший вопрос К сожалению, в настоящее время scikit-learn не поддерживает пользовательскую модель соседей, однако вы можете самостоятельно реализовать простую оболочку, такую ​​как

from sklearn.neighbors import LSHForest
import numpy as np
from scipy.stats import mode

class LSH_KNN:

    def __init__(self, **kwargs):
        self.n_neighbors = kwargs['n_neighbors']
        self.lsh = LSHForest(**kwargs)

    def fit(self, X, y):
        self.y = y
        self.lsh.fit(X)

    def predict(self, X):
        _, indices = self.lsh.kneighbors(X, n_neighbors = self.n_neighbors)
        votes, _ = mode(self.y[indices], axis=1)
        return votes.flatten()
Другие вопросы по тегам