Почему изменяемая карта становится неизменяемой автоматически в UserDefinedAggregateFunction(UDAF) в Spark

Я пытаюсь определить UserDefinedAggregateFunction(UDAF) в Spark, который подсчитывает количество вхождений для каждого уникального значения в столбце группы.

Это пример: предположим, у меня есть датафрейм df как это,

+----+----+
|col1|col2|
+----+----+
|   a|  a1|
|   a|  a1|
|   a|  a2|
|   b|  b1|
|   b|  b2|
|   b|  b3|
|   b|  b1|
|   b|  b1|
+----+----+

У меня будет UDAF DistinctValues

val func = new DistinctValues

Затем я применяю его к фрейму данных

val agg_value = df.groupBy("col1").agg(func(col("col2")).as("DV"))

Я ожидаю что-то вроде этого:

+----+--------------------------+
|col1|DV                        |
+----+--------------------------+
|   a|  Map(a1->2, a2->1)       |
|   b|  Map(b1->3, b2->1, b3->1)|
+----+--------------------------+

Итак, я выступил с UDAF, как это,

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.LongType
import Array._

class DistinctValues extends UserDefinedAggregateFunction {
  def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil)

  def dataType: DataType =  MapType(StringType, LongType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = scala.collection.mutable.Map()
  }

  def update(buffer: MutableAggregationBuffer, input: Row) : Unit = {
    val str = input.getAs[String](0)
    var mp = buffer.getAs[scala.collection.mutable.Map[String, Long]](0)
    var c:Long = mp.getOrElse(str, 0)
    c = c + 1
    mp.put(str, c)
    buffer(0) = mp
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = {
    var mp1 = buffer1.getAs[scala.collection.mutable.Map[String, Long]](0)
    var mp2 = buffer2.getAs[scala.collection.mutable.Map[String, Long]](0)
    mp2 foreach {
        case (k ,v) => {
            var c:Long = mp1.getOrElse(k, 0)
            c = c + v
            mp1.put(k ,c)
        }
    }
    buffer1(0) = mp1
  }

  def evaluate(buffer: Row): Any = {
      buffer.getAs[scala.collection.mutable.Map[String, LongType]](0)
  }
}

Тогда у меня есть эта функция на моем фрейме данных,

val func = new DistinctValues
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV"))

Это дало такую ​​ошибку,

func: DistinctValues = $iwC$$iwC$DistinctValues@17f48a25
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 32.0 failed 4 times, most recent failure: Lost task 1.3 in stage 32.0 (TID 884, ip-172-31-22-166.ec2.internal): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map
at $iwC$$iwC$DistinctValues.update(<console>:39)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:431)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:187)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:180)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:116)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:152)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
at org.apache.spark.scheduler.Task.run(Task.scala:89)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:745)

Похоже, в update(buffer: MutableAggregationBuffer, input: Row) метод, переменная buffer это immutable.Map, программа устала бросать его mutable.Map,

Но я использовал mutable.Map инициализировать buffer переменная в initialize(buffer: MutableAggregationBuffer, input:Row) метод. Это та же самая переменная, переданная update метод? А также buffer является mutableAggregationBufferтак оно и должно быть изменчивым, верно?

Почему мой изменяемый. Карта стала неизменной? Кто-нибудь знает, что случилось?

Мне действительно нужна изменчивая карта в этой функции, чтобы выполнить задачу. Я знаю, что есть обходной путь, чтобы создать изменяемую карту из неизменяемой карты, а затем обновить ее. Но я действительно хочу знать, почему изменяемый в программе автоматически преобразуется в неизменяемый, для меня это не имеет смысла.

1 ответ

Решение

Поверьте это MapType в вашем StructType, buffer поэтому держит Map, который был бы неизменным.

Вы можете конвертировать его, но почему бы не оставить его неизменным и сделать это:

mp = mp + (k -> c)

добавить запись в неизменяемый Map?

Рабочий пример ниже:

class DistinctValues extends UserDefinedAggregateFunction {
  def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("_2", IntegerType) :: Nil)

  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil)

  def dataType: DataType =  MapType(StringType, LongType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map()
  }

  def update(buffer: MutableAggregationBuffer, input: Row) : Unit = {
    val str = input.getAs[String](0)
    var mp = buffer.getAs[Map[String, Long]](0)
    var c:Long = mp.getOrElse(str, 0)
    c = c + 1
    mp = mp  + (str -> c)
    buffer(0) = mp
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = {
    var mp1 = buffer1.getAs[Map[String, Long]](0)
    var mp2 = buffer2.getAs[Map[String, Long]](0)
    mp2 foreach {
        case (k ,v) => {
            var c:Long = mp1.getOrElse(k, 0)
            c = c + v
            mp1 = mp1 + (k -> c)
        }
    }
    buffer1(0) = mp1
  }

  def evaluate(buffer: Row): Any = {
      buffer.getAs[Map[String, LongType]](0)
  }
}

Поздно на вечеринку. Я только что обнаружил, что можно использовать

override def bufferSchema: StructType = StructType(List(
    StructField("map", ObjectType(classOf[mutable.Map[String, Long]]))
))

использовать mutable.Map в буфере.

Другие вопросы по тегам