Spark MultilayerPerceptronClassifier Класс Вероятности
Я опытный программист Python, пытающийся перевести некоторый код Python в Spark для задачи классификации. Я впервые работаю в Spark/Scala.
В Python нейронные сети Keras/tenorflow и sci-kit Learn отлично справляются с многоклассовой классификацией, и я могу легко вернуть 3 самых вероятных класса вместе с вероятностями, которые являются ключевыми для этого проекта.
В целом мне удалось переместить код в Spark (Scala), и я смог сгенерировать правильные прогнозы, но мне не удалось найти способ вернуть вероятности для самых предсказуемых классов из MultilayerPerceptronClassifier в MLlib.
Наиболее близкое решение, которое я нашел, было в этом посте: Как получить вероятности классификации из MultilayerPerceptronClassifier? Однако я не могу заставить работать решение в посте либо потому, что в нем отсутствует ключевой фрагмент кода, либо я слишком новичок в Scala (возможно, в последнем), чтобы внести необходимые корректировки.
Кто-нибудь решил эту проблему?
Это текущие версии в моей среде. Версия Spark: 2.1.1 Версия Scala: 2.11.8
Спасибо за вашу помощь,
РКБ
1 ответ
Если вы внимательно посмотрите на результаты MultilayerPerceptronClassificationModel.transform
(model
а также test
как определено в примере конвейера в официальной документации)
val result = model.transform(test)
result.printSchema
root
|-- label: double (nullable = true)
|-- features: vector (nullable = true)
|-- rawPrediction: vector (nullable = true)
|-- probability: vector (nullable = true)
|-- prediction: double (nullable = false)
вы увидите, что они содержат probability
колонка.
Хранится как o.a.s.ml.linalg.Vector
колонка:
result.select($"probability").show(3, false)
+---------------------------------------------------+
|probability |
+---------------------------------------------------+
|[2.630203838780848E-29,1.7323171642231641E-19,1.0] |
|[1.0,1.448487547623119E-121,4.530084532282489E-44] |
|[1.0,5.157808976162274E-122,2.5702890543589884E-44]|
+---------------------------------------------------+
only showing top 3 rows
и могут быть доступны с использованием стандартных методов.
Эта функция доступна начиная с версии Spark 2.3 ( SPARK-12664 Expose вероятно, rawPrediction в MultilayerPerceptronClassificationModel).