Как интерпретировать результаты Spark OneHotEncoder

Я прочитал запись ОНЕ из документов Spark,

Горячее кодирование отображает столбец индексов меток в столбец двоичных векторов, не более одного единственного значения. Это кодирование позволяет алгоритмам, которые ожидают непрерывных функций, таких как логистическая регрессия, использовать категориальные функции.

но, к сожалению, они не дают полного объяснения результата ОНЕ. Итак, запустил данный код:

from pyspark.ml.feature import OneHotEncoder, StringIndexer

df = sqlContext.createDataFrame([
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
], ["id", "category"])

stringIndexer = StringIndexer(inputCol="category",      outputCol="categoryIndex")
model = stringIndexer.fit(df)
indexed = model.transform(df)

encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec")
encoded = encoder.transform(indexed)
encoded.show()

И получил результаты:

   +---+--------+-------------+-------------+
   | id|category|categoryIndex|  categoryVec|
   +---+--------+-------------+-------------+
   |  0|       a|          0.0|(2,[0],[1.0])|
   |  1|       b|          2.0|    (2,[],[])|
   |  2|       c|          1.0|(2,[1],[1.0])|
   |  3|       a|          0.0|(2,[0],[1.0])|
   |  4|       a|          0.0|(2,[0],[1.0])|
   |  5|       c|          1.0|(2,[1],[1.0])|
   +---+--------+-------------+-------------+

Как я могу интерпретировать результаты OHE(последний столбец)?

1 ответ

Решение

Горячее кодирование преобразует значения в categoryIndex в двоичный вектор, где максимум одно значение может быть 1. Поскольку имеется три значения, вектор имеет длину 2, и отображение выглядит следующим образом:

0  -> 10
1  -> 01
2  -> 00

(Почему отображение такое? Посмотрите на этот вопрос о том, что одноразовый кодер отбрасывает последнюю категорию.)

Значения в столбце categoryVecименно они, но представлены в разреженном формате. В этом формате нули вектора не печатаются. Первое значение (2) показывает длину вектора, второе значение - массив, в котором перечислены ноль или более индексов, в которых найдены ненулевые записи. Третье значение - это другой массив, который сообщает, какие числа находятся по этим индексам. Итак, (2,[0],[1.0]) означает вектор длины 2 с 1,0 в позиции 0 и 0 в другом месте.

Смотрите: https://spark.apache.org/docs/latest/mllib-data-types.html

Другие вопросы по тегам