Spark UDAF - используя дженерики в качестве типа ввода?

Я хочу написать Spark UDAF, где тип столбца может быть любым, для которого определен Scala Numeric. Я искал через Интернет, но нашел только примеры с конкретными типами, такими как DoubleType, LongType, Разве это не возможно? Но как тогда использовать этот UDAF с другими числовыми значениями?

1 ответ

Решение

Для простоты предположим, что вы хотите определить пользовательский sum, Вам придется предоставить TypeTag для типа ввода и использования отражения Scala для определения схем:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor

case class MySum [T : TypeTag](implicit n: Numeric[T]) 
    extends UserDefinedAggregateFunction {

  val dt = schemaFor[T].dataType
  def inputSchema = new StructType().add("x", dt)
  def bufferSchema = new StructType().add("x", dt)

  def dataType = dt
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0,  n.zero)
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, n.plus(buffer.getAs[T](0), input.getAs[T](0)))
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, n.plus(buffer1.getAs[T](0),  buffer2.getAs[T](0)))    
  }

  def evaluate(buffer: Row) = buffer.getAs[T](0)
}

С помощью функции, определенной выше, мы можем создать экземпляр, обрабатывающий определенные типы:

val sumOfLong = MySum[Long]
spark.range(10).select(sumOfLong($"id")).show
+---------+
|mysum(id)|
+---------+
|       45|
+---------+

Примечание:

Чтобы получить ту же гибкость, что и встроенные агрегатные функции, вам нужно определить свои собственные AggregateFunction, лайк ImperativeAggregate или же DeclarativeAggregate, Это возможно, но это внутренний API.

Другие вопросы по тегам