Писпарк НЛП - CountVectorizer Max DF или TF. Как отфильтровать общие вхождения из набора данных
Я использую CountVectorizer
чтобы подготовить набор данных для ML. Я хочу отфильтровать редкие слова, и я использую параметр CountVectorizer
, minDF или minTF для этого. Я также хотел бы удалить элементы, которые часто появляются в моем наборе данных. Я не вижу параметр maxTF или maxDF, который я могу установить. Есть ли хороший способ сделать это?
df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])
Так что в этом случае, если я хотел удалить параметры, которые появлялись "4" раза или 40% времени, и те, которые появлялись 2 раза или меньше. Это уберет "b" и "c".
В настоящее время я бегу CountVectorizer(minDf=3......)
для нижней границы требования Как я могу отфильтровать элементы, которые появляются чаще, чем я хочу моделировать.
1 ответ
Я полагаю, что вы запрашиваете параметр CountVectorizer, но похоже, что пока нет параметров для этого. Это не простой или практичный способ сделать это простым, но это работает. Я надеюсь, что это поможет вам:
from pyspark.sql.types import *
from pyspark.sql import functions as F
df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])
counts_df = df \
.select(F.explode('raw').alias('testCol')) \
.groupby('testCol') \
.agg(F.count('testCol').alias('count')).persist() # this will be used multiple times
total = counts_df \
.agg(F.sum('count').alias('total')) \
.rdd.take(1)[0]['total']
min_times = 3
max_times = total * 0.4
filtered_elements = counts_df \
.filter((min_times>F.col('count')) | (F.col('count')>max_times)) \
.select('testCol') \
.rdd.map(lambda row: row['testCol']) \
.collect()
def removeElements(arr):
return list(set(arr) - set(filtered_elements))
remove_udf = F.udf(removeElements, ArrayType(StringType()))
filtered_df = df \
.withColumn('raw', remove_udf('raw'))
Результаты:
filtered_df.show()
+-----+---+
|label|raw|
+-----+---+
| 0|[a]|
| 1|[a]|
+-----+---+