PySpark:: FP-Growth алгоритм (повысить ValueError("Параметры должны быть либо картой параметров, либо списком / кортежом карт параметров")

Я новичок в PySpark. Я использую вычислительную ассоциацию FPgrowth в PySpark. Я следовал за шагами ниже.

Пример данных

from pyspark.sql.session import SparkSession

spark = SparkSession.builder.getOrCreate()

# make some test data
columns = ['customer_id', 'product_id']
vals = [
     (370, 154),
     (41, 40),
     (109, 173),
     (18, 55),
     (105, 126),
     (370, 121),
     (41, 32323),
     (109, 22),
     (18, 55),
     (105, 133),
     (109, 22),
     (18, 55),
     (105, 133)
]

df = spark.createDataFrame(vals, columns)

df.show()
+-----------+----------+
|customer_id|product_id|
+-----------+----------+
|        370|       154|
|         41|        40|
|        109|       173|
|         18|        55|
|        105|       126|
|        370|       121|
|         41|     32323|
|        109|        22|
|         18|        55|
|        105|       133|
|        109|        22|
|         18|        55|
|        105|       133|
+-----------+----------+

### Prepare input data
from pyspark.sql.functions import collect_list, col

transactions = df.groupBy("customer_id")\
      .agg(collect_list("product_id").alias("product_ids"))\
      .rdd\
      .map(lambda x: (x.customer_id, x.product_ids))

transactions.collect()
[(370, [121, 154]),
 (41, [32323, 40]),
 (105, [133, 133, 126]),
 (18, [55, 55, 55]),
 (109, [22, 173, 22])]

## Convert .rdd to spark dataframe 
df2 = spark.createDataFrame(transactions)
df2.show()
+---+---------------+
| _1|             _2|
+---+---------------+
|370|     [121, 154]|
| 41|    [32323, 40]|
|105|[126, 133, 133]|
| 18|   [55, 55, 55]|
|109|  [22, 173, 22]|
+---+---------------+

df3 = df2.selectExpr("_1 as customer_id", "_2 as product_id")
df3.show()
df3.printSchema()
+-----------+---------------+
|customer_id|     product_id|
+-----------+---------------+
|        370|     [154, 121]|
|         41|    [32323, 40]|
|        105|[126, 133, 133]|
|         18|   [55, 55, 55]|
|        109|  [173, 22, 22]|
+-----------+---------------+

root
 |-- customer_id: long (nullable = true)
 |-- product_id: array (nullable = true)
 |    |-- element: long (containsNull = true)

 ## FPGrowth Model Building
 from pyspark.ml.fpm import FPGrowth
 fpGrowth = FPGrowth(itemsCol="product_id", minSupport=0.5, minConfidence=0.6)
 model = fpGrowth.fit(df3)

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-12-aa1f71745240> in <module>()
----> 1 model = fpGrowth.fit(df3)

/usr/lib/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
     62                 return self.copy(params)._fit(dataset)
     63             else:
---> 64                 return self._fit(dataset)
     65         else:
     66             raise ValueError("Params must be either a param map or a list/tuple of param maps, "

/usr/lib/spark/python/pyspark/ml/wrapper.py in _fit(self, dataset)
    263 
    264     def _fit(self, dataset):
--> 265         java_model = self._fit_java(dataset)
    266         return self._create_model(java_model)
    267 

/usr/lib/spark/python/pyspark/ml/wrapper.py in _fit_java(self, dataset)
    260         """
    261         self._transfer_params_to_java()
--> 262         return self._java_obj.fit(dataset._jdf)
    263 
    264     def _fit(self, dataset):

/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1131         answer = self.gateway_client.send_command(command)
   1132         return_value = get_return_value(
-> 1133             answer, self.gateway_client, self.target_id, self.name)
   1134 
   1135         for temp_arg in temp_args:

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    317                 raise Py4JJavaError(
    318                     "An error occurred while calling {0}{1}{2}.\n".
--> 319                     format(target_id, ".", name), value)
    320             else:
    321                 raise Py4JError(

Я посмотрел вверх, но я не понял, что пошло не так. Единственное, на что я могу указать, я преобразовал RDD в информационный фрейм.

Кто-нибудь может указать мне, что я делаю неправильно?

2 ответа

Решение

Ну я просто понял FPGrowth из pyspark.ml.fpm принимает фрейм данных pyspark, а не rdd. Таким образом, вышеупомянутый метод преобразовал мой набор данных в rdd.

Мне удалось избежать ситуации, предъявив иск PySpark collect_set список с groupby, чтобы получить фрейм данных и передать.

from pyspark.sql.session import SparkSession

# instantiate Spark
spark = SparkSession.builder.getOrCreate()

# make some test data
columns = ['customer_id', 'product_id']
vals = [
     (370, 154),
     (370, 40),
     (370, 173),
     (41, 55),
     (41, 126),
     (41, 121),
     (41, 321),
     (105, 22),
     (105, 55),
     (105, 133),
     (109, 22),
     (109, 55),
     (109, 133)    
]


# create DataFrame
df = spark.createDataFrame(vals, columns)

df.show()
+-----------+----------+
|customer_id|product_id|
+-----------+----------+
|        370|       154|
|        370|        40|
|        370|       173|
|         41|        55|
|         41|       126|
|         41|       121|
|         41|     32323|
|        105|        22|
|        105|        55|
|        105|       133|
|        109|        22|
|        109|        55|
|        109|       133|
+-----------+----------+

# Create dataframe for FPGrowth model input
from pyspark.sql.functions import collect_list, col
from pyspark.sql import functions as F 
from pyspark.sql.functions import *
transactions = df.groupBy("customer_id")\
      .agg(F.collect_set("product_id"))

transactions.show()
+-----------+-----------------------+
|customer_id|collect_set(product_id)|
+-----------+-----------------------+
|        370|         [154, 173, 40]|
|         41|    [321, 121, 126, 55]|
|        105|          [133, 22, 55]|
|        109|          [133, 22, 55]|
+-----------+-----------------------+

# FPGrowth model 
from pyspark.ml.fpm import FPGrowth
fpGrowth = FPGrowth(itemsCol="collect_set(product_id)", minSupport=0.5, minConfidence=0.6
 model_working = fpGrowth.fit(transactions)

# Display frequent itemsets
model_working.freqItemsets.show()
+-------------+----+
|        items|freq|
+-------------+----+
|         [55]|   3|
|         [22]|   2|
|     [22, 55]|   2|
|        [133]|   2|
|    [133, 22]|   2|
|[133, 22, 55]|   2|
|    [133, 55]|   2|
+-------------+----+

# Display generated association rules.
model_working.associationRules.show()

# transform examines the input items against all the association rules and summarise the
# consequents as prediction
model_working.transform(transactions).show()

+----------+----------+------------------+
|antecedent|consequent|        confidence|
+----------+----------+------------------+
|     [133]|      [22]|               1.0|
|     [133]|      [55]|               1.0|
| [133, 55]|      [22]|               1.0|
| [133, 22]|      [55]|               1.0|
|      [22]|      [55]|               1.0|
|      [22]|     [133]|               1.0|
|      [55]|      [22]|0.6666666666666666|
|      [55]|     [133]|0.6666666666666666|
|  [22, 55]|     [133]|               1.0|
+----------+----------+------------------+

+-----------+-----------------------+----------+
|customer_id|collect_set(product_id)|prediction|
+-----------+-----------------------+----------+
|        370|         [154, 173, 40]|        []|
|         41|    [321, 121, 126, 55]| [22, 133]|
|        105|          [133, 22, 55]|        []|
|        109|          [133, 22, 55]|        []|
+-----------+-----------------------+----------+

Если вы внимательно проверите трассировку, вы увидите источник проблемы:

Caused by: org.apache.spark.SparkException: Items in a transaction must be unique but got ....

замещать collect_list с collect_set и проблема будет исправлена.

Другие вопросы по тегам