PySpark заполнить отсутствующее / неправильное значение сгруппированным средним
У меня есть Spark dataframe с одним отсутствующим и одним неправильным значением.
from pyspark.sql import Row
from pyspark.sql.types import StringType, DoubleType, StructType, StructField
# fruit sales data
data = [Row(id='01', fruit='Apple', qty=5.0),
Row(id='02', fruit='Apple', qty=1.0),
Row(id='03', fruit='Apple', qty=None),
Row(id='04', fruit='Pear', qty=6.0),
Row(id='05', fruit='Pear', qty=2.0),
Row(id='06', fruit='Mango', qty=6.0),
Row(id='07', fruit='Mango', qty=-4.0),
Row(id='08', fruit='Mango', qty=2.0)]
# create dataframe
df = spark.createDataFrame(data)
df.show()
+-----+---+----+
|fruit| id| qty|
+-----+---+----+
|Apple| 01| 5.0|
|Apple| 02| 1.0|
|Apple| 03|null|
| Pear| 04| 6.0|
| Pear| 05| 2.0|
|Mango| 06| 6.0|
|Mango| 07|-4.0|
|Mango| 08| 2.0|
+-----+---+----+
Выполнить заливку по всему значению столбца просто. Но как я могу сгруппировать среднее? Для иллюстрации я хотел бы null
в строке 3 следует заменить на mean(qty)
от Apple
- в этом случае (5+1)/2=3. Точно так же -4.0
неверное значение (без отрицательного кол-во) в строке 7, которое я хотел бы заменить на (6+2)/2=4
В чистом Python я бы сделал что-то вроде этого:
def replace_with_grouped_mean(df, value, column, to_groupby):
invalid_mask = (df[column] == value)
# get the mean without the invalid value
means_by_group = (df[~invalid_mask].groupby(to_groupby)[column].mean())
# get an array of the means for all of the data
means_array = means_by_group[df[to_groupby].values].values
# assign the invalid values to means
df.loc[invalid_mask, column] = means_array[invalid_mask]
return df
И в конечном итоге сделать:
x = replace_with_grouped_mean(df=df, value=-4, column='qty', to_groupby='fruit')
Однако я не совсем уверен, как этого добиться в PySpark. Любая помощь / указатели приветствуются!
1 ответ
Примечание: когда мы делаем группу, строки, имеющие Null
игнорируются Если у нас есть 3 строки, один из которых имеет значение Null
, то среднее значение делится на 2, а не на 3, потому что третье значение было Null
, Ключевым моментом здесь является использование функции Window().
from pyspark.sql.functions import avg, col, when
from pyspark.sql.window import Window
w = Window().partitionBy('fruit')
#Replace negative values of 'qty' with Null, as we don't want to consider them while averaging.
df = df.withColumn('qty',when(col('qty')<0,None).otherwise(col('qty')))
df = df.withColumn('qty',when(col('qty').isNull(),avg(col('qty')).over(w)).otherwise(col('qty')))
df.show()
+-----+---+---+
|fruit| id|qty|
+-----+---+---+
| Pear| 04|6.0|
| Pear| 05|2.0|
|Mango| 06|6.0|
|Mango| 07|4.0|
|Mango| 08|2.0|
|Apple| 01|5.0|
|Apple| 02|1.0|
|Apple| 03|3.0|
+-----+---+---+