Как разделить или умножить все нестроковые столбцы фрейма данных PySpark с константой с плавающей точкой?

Мой входной фрейм выглядит следующим образом

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Basics").getOrCreate()

df=spark.createDataFrame(data=[('Alice',4.300,None),('Bob',float('nan'),897)],schema=['name','High','Low'])

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice| 4.3|null|
|  Bob| NaN| 897|
+-----+----+----+

Ожидаемый результат при делении на 10,0

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice| 0.43|null|
|  Bob| NaN| 89.7|
+-----+----+----+

1 ответ

Решение

Я не знаю ни о какой библиотечной функции, которая могла бы сделать это, но этот фрагмент, похоже, отлично справляется:

CONSTANT = 10.0

for field in df.schema.fields:
    if str(field.dataType) in ['DoubleType', 'FloatType', 'LongType', 'IntegerType', 'DecimalType']:
        name = str(field.name)
        df = df.withColumn(name, col(name)/CONSTANT)


df.show()

выходы:

+-----+----+----+
| name|High| Low|
+-----+----+----+
|Alice|0.43|null|
|  Bob| NaN|89.7|
+-----+----+----+

Приведенный ниже код должен эффективно решить вашу проблему.

from pyspark.sql.functions import col

allowed_types = ['DoubleType', 'FloatType', 'LongType', 'IntegerType', 'DecimalType']

df = df.select(*[(col(field.name)/10).name(field.name) if str(field.dataType) in allowed_types else col(field.name) for field in df.schema.fields]

Итеративное использование withColumn может быть не очень хорошей идеей, если количество столбцов велико.
Это связано с тем, что фреймы данных PySpark неизменяемы, поэтому, по сути, мы будем создавать новый DataFrame для каждого столбца, приведенного с использованием withColumn, что будет очень медленным процессом.

Вот здесь и пригодится приведенный выше код.

Другие вопросы по тегам