Spark MultilayerPerceptronClassifier Класс Вероятности

Я опытный программист Python, пытающийся перевести некоторый код Python в Spark для задачи классификации. Я впервые работаю в Spark/Scala.

В Python нейронные сети Keras/tenorflow и sci-kit Learn отлично справляются с многоклассовой классификацией, и я могу легко вернуть 3 самых вероятных класса вместе с вероятностями, которые являются ключевыми для этого проекта.

В целом мне удалось переместить код в Spark (Scala), и я смог сгенерировать правильные прогнозы, но мне не удалось найти способ вернуть вероятности для самых предсказуемых классов из MultilayerPerceptronClassifier в MLlib.

Наиболее близкое решение, которое я нашел, было в этом посте: Как получить вероятности классификации из MultilayerPerceptronClassifier? Однако я не могу заставить работать решение в посте либо потому, что в нем отсутствует ключевой фрагмент кода, либо я слишком новичок в Scala (возможно, в последнем), чтобы внести необходимые корректировки.

Кто-нибудь решил эту проблему?

Это текущие версии в моей среде. Версия Spark: 2.1.1 Версия Scala: 2.11.8

Спасибо за вашу помощь,

РКБ

1 ответ

Если вы внимательно посмотрите на результаты MultilayerPerceptronClassificationModel.transform (model а также test как определено в примере конвейера в официальной документации)

val result = model.transform(test)

result.printSchema
root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

вы увидите, что они содержат probability колонка.

Хранится как o.a.s.ml.linalg.Vector колонка:

result.select($"probability").show(3, false)
+---------------------------------------------------+
|probability                                        |
+---------------------------------------------------+
|[2.630203838780848E-29,1.7323171642231641E-19,1.0] |
|[1.0,1.448487547623119E-121,4.530084532282489E-44] |
|[1.0,5.157808976162274E-122,2.5702890543589884E-44]|
+---------------------------------------------------+
only showing top 3 rows

и могут быть доступны с использованием стандартных методов.

Эта функция доступна начиная с версии Spark 2.3 ( SPARK-12664 Expose вероятно, rawPrediction в MultilayerPerceptronClassificationModel).

Другие вопросы по тегам