Spark UDAF - используя дженерики в качестве типа ввода?
Я хочу написать Spark UDAF, где тип столбца может быть любым, для которого определен Scala Numeric. Я искал через Интернет, но нашел только примеры с конкретными типами, такими как DoubleType
, LongType
, Разве это не возможно? Но как тогда использовать этот UDAF с другими числовыми значениями?
1 ответ
Для простоты предположим, что вы хотите определить пользовательский sum
, Вам придется предоставить TypeTag
для типа ввода и использования отражения Scala для определения схем:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
case class MySum [T : TypeTag](implicit n: Numeric[T])
extends UserDefinedAggregateFunction {
val dt = schemaFor[T].dataType
def inputSchema = new StructType().add("x", dt)
def bufferSchema = new StructType().add("x", dt)
def dataType = dt
def deterministic = true
def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, n.zero)
def update(buffer: MutableAggregationBuffer, input: Row) = {
if (!input.isNullAt(0))
buffer.update(0, n.plus(buffer.getAs[T](0), input.getAs[T](0)))
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
buffer1.update(0, n.plus(buffer1.getAs[T](0), buffer2.getAs[T](0)))
}
def evaluate(buffer: Row) = buffer.getAs[T](0)
}
С помощью функции, определенной выше, мы можем создать экземпляр, обрабатывающий определенные типы:
val sumOfLong = MySum[Long]
spark.range(10).select(sumOfLong($"id")).show
+---------+
|mysum(id)|
+---------+
| 45|
+---------+
Примечание:
Чтобы получить ту же гибкость, что и встроенные агрегатные функции, вам нужно определить свои собственные AggregateFunction
, лайк ImperativeAggregate
или же DeclarativeAggregate
, Это возможно, но это внутренний API.