Как взорвать столбец ArrayType в искровом фрейме данных, который содержит нули и пустые массивы.
У меня есть датафрейм, состоящий из следующих данных
val df = List(
(1,"wwe",List(1,2,3)),
(2,"dsad",List.empty),
(3,"dfd",null)).toDF("id","name","value")
df.show
+---+----+---------+
| id|name| value|
+---+----+---------+
| 1| wwe|[1, 2, 3]|
| 2|dsad| []|
| 3| dfd| null|
+---+----+---------+
Чтобы взорвать значения столбцов массива, я использовал следующую логику
def explodeWithNull(f:StructField): Column ={
explode(
when(
col(f.name).isNotNull, col(f.name)
).otherwise(
f.dataType.asInstanceOf[ArrayType].elementType match{
case StringType => array(lit(""))
case DoubleType => array(lit(0.0))
case IntegerType => array(lit(0))
case _ => array(lit(""))
}
)
)
}
def explodeAllArraysColumns(dataframe: DataFrame): DataFrame = {
val schema: StructType = dataframe.schema
val arrayFileds: Seq[StructField] = schema.filter(f => f.dataType.typeName == "array")
arrayFileds.foldLeft(dataframe) {
(df: DataFrame, f: StructField) => df.withColumn(f.name,explodeWithNull(f))
}
}
explodeAllArraysColumns(df).show
+---+----+-----+
| id|name|value|
+---+----+-----+
| 1| wwe| 1|
| 1| wwe| 2|
| 1| wwe| 3|
| 3| dfd| 0|
+---+----+-----+
взорвавшись таким образом, я пропускаю строку, которая является пустым массивом в df. В идеале я не хочу пропустить эту строку, я хочу либо значение NULL, либо значение по умолчанию для этого столбца в разобранном фрейме данных. Как этого добиться?
2 ответа
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql import Row
from pyspark.sql.types import ArrayType
from pyspark.sql.functions import *
from functools import reduce
def explode_outer(df, columns_to_explode):
array_fields = dict([(field.name, field.dataType)
for field in df.schema.fields
if type(field.dataType) == ArrayType])
return reduce(lambda df_with_explode, column:
df_with_explode.withColumn(column, explode(
when(size(df_with_explode[column]) != 0, df_with_explode[column])
.otherwise(array(lit(None).cast(array_fields[column].elementType))))),
columns_to_explode, df)
from pyspark.sql.functions import *
def flatten_df(nested_df):
flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']
flat_df = nested_df.select(flat_cols +
[col(nc + '.' + c).alias(nc + '_' + c)
for nc in nested_cols
for c in nested_df.select(nc + '.*').columns])
print("flatten_df_count :", flat_df.count())
return flat_df
def explode_df(nested_df):
flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct' and c[1][:5] != 'array']
array_cols = [c[0] for c in nested_df.dtypes if c[1][:5] == 'array']
for array_col in array_cols:
schema = new_df.select(array_col).dtypes[0][1]
nested_df = nested_df.withColumn(array_col, when(col(array_col).isNotNull(), col(array_col)).otherwise(array(lit(None)).cast(schema)))
nested_df = nested_df.withColumn("tmp", arrays_zip(*array_cols)).withColumn("tmp", explode("tmp")).select([col("tmp."+c).alias(c) for c in array_cols] + flat_cols)
print("explode_dfs_count :", nested_df.count())
return nested_df
new_df = flatten_df(myDf)
while True:
array_cols = [c[0] for c in new_df.dtypes if c[1][:5] == 'array']
if len(array_cols):
new_df = flatten_df(explode_df(new_df))
else:
break
new_df.printSchema()
Использовал
arrays_zip
а также
explode
решить эту проблему