Как я могу вызвать классификаторы scikit-learn из Java?

У меня есть классификатор, который я обучил, используя Scikit-Learn Python. Как я могу использовать классификатор из программы Java? Могу ли я использовать Jython? Есть ли способ сохранить классификатор в Python и загрузить его в Java? Есть ли другой способ использовать его?

5 ответов

Решение

Вы не можете использовать jython, так как scikit-learn в значительной степени опирается на numpy и scipy, которые имеют много скомпилированных расширений C и Fortran, поэтому не могут работать в jython.

Самым простым способом использования scikit-learn в среде Java будет:

  • представить классификатор как службу HTTP / Json, например, с помощью микрорамки, такой как колба, бутылка или карниз, и вызвать его из java, используя клиентскую библиотеку HTTP.

  • написать приложение-оболочку командной строки на python, которое считывает данные на stdin и выводит прогнозы на stdout, используя некоторый формат, такой как CSV или JSON (или некоторое двоичное представление более низкого уровня), и вызывает программу python из java, например, с использованием Apache Commons Exec.

  • сделать так, чтобы программа python выводила необработанные числовые параметры, полученные в подходящее время (как правило, в виде массива значений с плавающей запятой), и переопределяет функцию предсказания в java (это обычно легко для линейных предсказательных моделей, где предсказание часто представляет собой просто пороговое произведение точек),

Последний подход будет намного более трудоемким, если вам потребуется повторно реализовать извлечение функций также в Java.

Наконец, вы можете использовать библиотеку Java, такую ​​как Weka или Mahout, которая реализует необходимые вам алгоритмы вместо того, чтобы пытаться использовать scikit-learn from Java.

Для этого есть проект JPMML.

Во-первых, вы можете сериализовать модель scikit-learn для PMML (которая является внутренним XML), используя библиотеку sklearn2pmml непосредственно из python, или сначала выгрузить ее в python и конвертировать, используя jpmml-sklearn в java или из командной строки, предоставленной этой библиотекой. Далее, вы можете загрузить файл pmml, десериализовать и выполнить загруженную модель, используя jpmml-valuator в вашем коде Java.

Этот способ работает не со всеми моделями scikit-learn, но со многими из них.

Вы можете использовать портер, я тестировал sklearn-porter ( https://github.com/nok/sklearn-porter), и он хорошо работает для Java.

Мой код следующий:

import pandas as pd
from sklearn import tree
from sklearn_porter import Porter

train_dataset = pd.read_csv('./result2.csv').as_matrix()

X_train = train_dataset[:90, :8]
Y_train = train_dataset[:90, 8:]

X_test = train_dataset[90:, :8]
Y_test = train_dataset[90:, 8:]

print X_train.shape
print Y_train.shape


clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

porter = Porter(clf, language='java')
output = porter.export(embed_data=True)
print(output)

В моем случае я использую DecisionTreeClassifier, и вывод

печать (выход)

следующий код в виде текста в консоли:

class DecisionTreeClassifier {

  private static int findMax(int[] nums) {
    int index = 0;
    for (int i = 0; i < nums.length; i++) {
        index = nums[i] > nums[index] ? i : index;
    }
    return index;
  }


  public static int predict(double[] features) {
    int[] classes = new int[2];

    if (features[5] <= 51.5) {
        if (features[6] <= 21.0) {

            // HUGE amount of ifs..........

        }
    }

    return findMax(classes);
  }

  public static void main(String[] args) {
    if (args.length == 8) {

        // Features:
        double[] features = new double[args.length];
        for (int i = 0, l = args.length; i < l; i++) {
            features[i] = Double.parseDouble(args[i]);
        }

        // Prediction:
        int prediction = DecisionTreeClassifier.predict(features);
        System.out.println(prediction);

    }
  }
}

Вот некоторый код для решения JPMML:

ЧАСТЬ ПИТОНА

# helper function to determine the string columns which have to be one-hot-encoded in order to apply an estimator.
def determine_categorical_columns(df):
    categorical_columns = []
    x = 0
    for col in df.dtypes:
        if col == 'object':
            val = df[df.columns[x]].iloc[0]
            if not isinstance(val,Decimal):
                categorical_columns.append(df.columns[x])
        x += 1
    return categorical_columns

categorical_columns = determine_categorical_columns(df)
other_columns = list(set(df.columns).difference(categorical_columns))


#construction of transformators for our example
labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns]
nones = [(d, None) for d in other_columns]
transformators = labelBinarizers+nones

mapper = DataFrameMapper(transformators,df_out=True)
gbc = GradientBoostingClassifier()

#construction of the pipeline
lm = PMMLPipeline([
    ("mapper", mapper),
    ("estimator", gbc)
])

--JAVA PART -

//Initialisation.
String pmmlFile = "ScikitLearnNew.pmml";
PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

//Determine which features are required as input
HashMap<String, Field>() inputFieldMap = new HashMap<String, Field>();
for (int i = 0; i < evaluator.getInputFields().size();i++) {
  InputField curInputField = evaluator.getInputFields().get(i);
  String fieldName = curInputField.getName().getValue();
  inputFieldMap.put(fieldName.toLowerCase(),curInputField.getField());
}


//prediction

HashMap<String,String> argsMap = new HashMap<String,String>();
//... fill argsMap with input

Map<FieldName, ?> res;
// here we keep only features that are required by the model
Map<FieldName,String> args = new HashMap<FieldName, String>();
Iterator<String> iter = argsMap.keySet().iterator();
while (iter.hasNext()) {
  String key = iter.next();
  Field f = inputFieldMap.get(key);
  if (f != null) {
    FieldName name =f.getName();
    String value = argsMap.get(key);
    args.put(name, value);
  }
}
//the model is applied to input, a probability distribution is obtained
res = evaluator.evaluate(args);
SegmentResult segmentResult = (SegmentResult) res;
Object targetValue = segmentResult.getTargetValue();
ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;

В качестве альтернативы вы можете просто сгенерировать код Python из обученной модели. Вот инструмент, который может помочь вам с этим https://github.com/BayesWitnesses/m2cgen

Я оказался в похожей ситуации. Я рекомендую вырезать классификатор микросервис. У вас может быть микросервис классификатора, который работает на python, а затем предоставляет вызовы этому сервису через некоторый RESTFul API, обеспечивающий формат обмена данными JSON/XML. Я думаю, что это более чистый подход.

Другие вопросы по тегам