// NameAggs.scala
  1. package com.winhc.bigdata.spark.udf
  2. import org.apache.commons.lang3.StringUtils
  3. import org.apache.spark.sql.Row
  4. import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  5. import org.apache.spark.sql.types._
  6. /**
  7. * @Description:原告,被告聚合
  8. * @author π
  9. * @date 2020/10/26 15:15
  10. */
  11. class NameAggs(max: Int) extends UserDefinedAggregateFunction {
  12. val flags = Seq("0", "1", "2", "4", "8")
  13. val split = "\u0001"
  14. override def inputSchema: StructType = StructType(Array[StructField](
  15. StructField("yg_name", DataTypes.StringType)
  16. , StructField("bg_name", DataTypes.StringType)
  17. , StructField("flag", DataTypes.StringType)
  18. , StructField("bus_date", DataTypes.StringType)
  19. ))
  20. override def bufferSchema: StructType = StructType(
  21. Array[StructField](
  22. StructField("t1", DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType))
  23. ,StructField("t2", DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType))
  24. )
  25. )
  26. override def dataType: DataType = DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType)
  27. override def deterministic: Boolean = true
  28. override def initialize(buffer: MutableAggregationBuffer): Unit = {
  29. buffer.update(0, Map[String, String]())
  30. buffer.update(1, Map[String, String]())
  31. }
  32. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  33. if (buffer.size >= max) {
  34. return
  35. }
  36. val yg_name = input.getString(0)
  37. val bg_name = input.getString(1)
  38. val flag = input.getString(2)
  39. val bus_date = input.getString(3)
  40. if (StringUtils.isBlank(yg_name) && StringUtils.isBlank(bg_name)) {
  41. return
  42. }
  43. if (!flags.contains(flag)) {
  44. return
  45. }
  46. val map0 = buffer.getMap[String, String](0).toMap
  47. val map1 = buffer.getMap[String, String](1).toMap
  48. var map_new0 = scala.collection.mutable.Map[String, String](map0.toSeq: _*)
  49. var map_new1 = scala.collection.mutable.Map[String, String](map1.toSeq: _*)
  50. if (StringUtils.isNotBlank(yg_name) && StringUtils.isNotBlank(bg_name)) {
  51. map_new0 ++= Map(bus_date -> s"$yg_name$split$bg_name")
  52. } else {
  53. map_new1 ++= Map(bus_date -> s"$yg_name$split$bg_name")
  54. }
  55. buffer.update(0, map_new0)
  56. buffer.update(1, map_new1)
  57. }
  58. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  59. buffer1(0) = buffer1.getAs[Map[String, String]](0) ++ buffer2.getAs[Map[String, String]](0)
  60. buffer1(1) = buffer1.getAs[Map[String, String]](1) ++ buffer2.getAs[Map[String, String]](1)
  61. }
  62. override def evaluate(buffer: Row): Any = {
  63. var yg_name = ""
  64. var bg_name = ""
  65. val m0: Map[String, String] = buffer.getAs[Map[String, String]](0)
  66. val m1: Map[String, String] = buffer.getAs[Map[String, String]](1)
  67. println("m0" + m0 + "m1" + m1)
  68. if (m0.isEmpty && m1.isEmpty) {
  69. return Map("yg_name" -> yg_name, "bg_name" -> bg_name)
  70. }else if(!m0.isEmpty){
  71. val key = m0.keySet.toSeq.sorted.head
  72. val Array(a, b) = m0(key).split(s"$split",-1)
  73. yg_name = a
  74. bg_name = b
  75. }else{
  76. val key = m1.keySet.toSeq.sorted.head
  77. val Array(x, y) = m1(key).split(s"$split",-1)
  78. yg_name = x
  79. bg_name = y
  80. }
  81. Map("yg_name" -> yg_name, "bg_name" -> bg_name)
  82. }
  83. }