Как использовать countDistinct в Scala с Spark?

Я попытался использовать функцию countDistinct, которая должна быть доступна в Spark 1.5 согласно блогу DataBrick. Однако я получил следующее исключение:

Exception in thread "main" org.apache.spark.sql.AnalysisException: undefined function countDistinct;

Я обнаружил, что в списке рассылки разработчиков Spark они предлагают использовать count и разные функции, чтобы получить тот же результат, который должен получить countDistinct:

count(distinct <columnName>)
// Instead
countDistinct(<columnName>)

Поскольку я строю выражения агрегации динамически из списка имен функций агрегации, я бы предпочел, чтобы не было особых случаев, требующих другой обработки.

Итак, возможно ли объединить его:

  • регистрация нового UDAF, который будет псевдонимом для count(отличное имя столбца)
  • регистрация вручную уже реализована в функции Spark CountDistinct, которая, вероятно, является одной из следующих функций импорта:

    import org.apache.spark.sql.catalyst.expressions. {CountDistinctFunction, CountDistinct}

  • или сделать это по-другому?

РЕДАКТИРОВАТЬ: Пример (с некоторыми местными ссылками и ненужным кодом):

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Column, SQLContext, DataFrame}
import org.apache.spark.sql.functions._

import scala.collection.mutable.ListBuffer


class Flattener(sc: SparkContext) {
  val sqlContext = new SQLContext(sc)

  def flatTable(data: DataFrame, groupField: String): DataFrame = {
    val flatteningExpressions = data.columns.zip(TypeRecognizer.getTypes(data)).
      flatMap(x => getFlatteningExpressions(x._1, x._2)).toList

    data.groupBy(groupField).agg (
      expr(s"count($groupField) as groupSize"),
      flatteningExpressions:_*
    )
  }

  private def getFlatteningExpressions(fieldName: String, fieldType: DType): List[Column] = {
    val aggFuncs = getAggregationFunctons(fieldType)

    aggFuncs.map(f => expr(s"$f($fieldName) as ${fieldName}_$f"))
  }

  private def getAggregationFunctons(fieldType: DType): List[String] = {
    val aggFuncs = new ListBuffer[String]()

    if(fieldType == DType.NUMERIC) {
      aggFuncs += ("avg", "min", "max")
    }

    if(fieldType == DType.CATEGORY) {
      aggFuncs += "countDistinct"
    }

    aggFuncs.toList
  }

}

2 ответа

Решение

countDistinct может использоваться в двух разных формах:

df.groupBy("A").agg(expr("count(distinct B)")

или же

df.groupBy("A").agg(countDistinct("B"))

Однако ни один из этих методов не работает, если вы хотите использовать их в одном столбце с пользовательским UDAF (реализовано как UserDefinedAggregateFunction в Spark 1.5):

// Assume that we have already implemented and registered StdDev UDAF 
df.groupBy("A").agg(countDistinct("B"), expr("StdDev(B)"))

// Will cause
Exception in thread "main" org.apache.spark.sql.AnalysisException: StdDev is implemented based on the new Aggregate Function interface and it cannot be used with functions implemented based on the old Aggregate Function interface.;

Из-за этих ограничений кажется, что наиболее разумным является использование countDistinct в качестве UDAF, что должно позволить одинаково обрабатывать все функции, а также использовать countDistinct вместе с другими UDAF.

Пример реализации может выглядеть так:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class CountDistinct extends UserDefinedAggregateFunction{
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
  }

  override def bufferSchema: StructType = StructType(
      StructField("items", ArrayType(StringType, true)) :: Nil
  )

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  override def deterministic: Boolean = true

  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[String](0).length
  }

  override def dataType: DataType = IntegerType
}

Не уверен, что я действительно понял вашу проблему, но это пример агрегированной функции countDistinct:

val values = Array((1, 2), (1, 3), (2, 2), (1, 2))
val myDf = sc.parallelize(values).toDF("id", "foo")
import org.apache.spark.sql.functions.countDistinct
myDf.groupBy('id).agg(countDistinct('foo) as 'distinctFoo) show
/**
+---+-------------------+
| id|COUNT(DISTINCT foo)|
+---+-------------------+
|  1|                  2|
|  2|                  1|
+---+-------------------+
*/
Другие вопросы по тегам