Spark_UDAF

 2023-09-11 阅读 17 评论 0

摘要:import org.apache.spark.SparkContext import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SparkSession}/*** 自定义函数:* UD
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/*
 * 自定义函数:
 * UDF:User-Defined Function;用户定义(普通)函数,只对单行数值产生作用;一进一出
 * UDAF:User-Defined Aggregate Function;用户定义聚合函数,对多行数据产生作用(sum()、avg()...);多进一出
 * UDTF:User-Defined Table-Generating Function;用户定义表生成函数,输入一行输出多行;一进多出
 */
// 本示例演示 UDAF:多进一出
/**
 * Demo entry point: computes the average age per gender with the custom
 * `myAvg` aggregate function, invoked through Spark SQL.
 */
object UDAF {

  /**
   * Builds a local SparkSession, registers `MyAgeAvgFunction` under the name `myAvg`,
   * runs the aggregation over an in-memory list of students, and prints the input
   * schema, the result schema and the result rows.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val sparkSession: SparkSession = SparkSession
      .builder()
      .appName("UDAF")
      .master("local[*]")
      .getOrCreate()

    // Always release the session, even when the job fails part-way through.
    try {
      // NOTE(review): Student is not defined in the visible part of this file. It is
      // presumably a case class (id, name, gender, age); the SQL below relies on it
      // exposing `gender` and `age` columns — confirm against its definition.
      val students: Seq[Student] = Seq(
        Student(1, "zhangsan", "F", 22),
        Student(2, "lisi", "M", 38),
        Student(3, "wangwu", "M", 13),
        Student(4, "zhaoliu", "F", 17),
        Student(5, "songba", "M", 32),
        Student(6, "sunjiu", "M", 16),
        Student(7, "qianshiyi", "F", 17),
        Student(8, "yinshier", "F", 15),
        Student(9, "fangshisan", "M", 12),
        Student(10, "yeshisan", "F", 11),
        Student(11, "ruishiyi", "F", 26),
        Student(12, "chenshier", "M", 28)
      )

      // Seq -> DataFrame can be done with 1. toDF or 2. createDataFrame; option 2 is used here.
      val frame: DataFrame = sparkSession.createDataFrame(students)
      frame.printSchema()

      // Make the UDAF callable from SQL under the name "myAvg".
      sparkSession.udf.register("myAvg", new MyAgeAvgFunction)
      frame.createOrReplaceTempView("students")

      val resultDF: DataFrame =
        sparkSession.sql("select gender,myAvg(age) avgage from students group by gender")
      resultDF.printSchema()
      resultDF.show(false)
    } finally {
      sparkSession.stop()
    }
  }
}
// Custom aggregate function (UDAF): extends UserDefinedAggregateFunction
/**
 * User-defined aggregate function that averages a LongType column.
 *
 * The aggregation buffer keeps a running (sum, count) pair and the result is
 * `sum / count` as a Double. Null inputs are skipped, and null is returned when
 * no non-null value was aggregated, mirroring how SQL `avg` treats nulls and
 * empty inputs.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated as of Spark 3.0 in
 * favour of `Aggregator` registered through `functions.udaf`; consider migrating
 * if this project targets Spark 3.x — confirm the Spark version in use.
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  /** Schema of the function's input arguments: a single LongType field named "age". */
  override def inputSchema: StructType =
    new StructType().add("age", LongType)

  /**
   * Schema of the intermediate aggregation buffer:
   *  - `sum`   accumulates all ages seen so far, e.g. 43 + 52 + 61 + 78 = 234
   *  - `count` accumulates how many values were added, e.g. 1 + 1 + 1 + 1 = 4
   */
  override def bufferSchema: StructType =
    new StructType().add("sum", LongType).add("count", LongType)

  /** Type of the final result: sum / count is a Double. */
  override def dataType: DataType = DoubleType

  /** The function is deterministic: identical input yields identical output. */
  override def deterministic: Boolean = true

  /** Resets the buffer to its starting state. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // running sum of all ages passed in
    buffer(1) = 0L // number of ages passed in
  }

  /**
   * Folds one input row into the buffer: adds the row's value to the sum and
   * increments the count. Rows whose value is null are ignored so they neither
   * raise an error nor distort the average.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Combines two partial buffers (for example, one per partition) into `buffer1`.
   * E.g. partials p1(321, 6), p2(128, 2), p3(219, 3) are summed component-wise.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // total of the ages
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // total number of values
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Produces the final result from the fully merged buffer.
   *
   * @return sum / count as a Double, or null when count is zero (null is how this
   *         API represents SQL NULL; it avoids returning NaN from 0 / 0.0)
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null
    else buffer.getLong(0) / count.toDouble
  }
}

版权声明:本站所有资料均为网友推荐收集整理而来,仅供学习和研究交流使用。

原文链接:https://hbdhgg.com/3/49102.html

发表评论:

本站为非营利网站,部分文章来源或改编自互联网及其他公众平台,主要目的在于分享信息,版权归原作者所有,内容仅供读者参考,如有侵权请联系我们删除!

Copyright © 2022 匯編語言學習筆記 Inc. 保留所有权利。

底部版权信息