Comment sélectionner la première ligne de chaque groupe?

J’ai un DataFrame généré comme suit:

df.groupBy($"Hour", $"Category") .agg(sum($"value") as "TotalValue") .sort($"Hour".asc, $"TotalValue".desc)) 

Les résultats ressemblent à:

 +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 0| cat13| 22.1| | 0| cat95| 19.6| | 0| cat105| 1.3| | 1| cat67| 28.5| | 1| cat4| 26.8| | 1| cat13| 12.6| | 1| cat23| 5.3| | 2| cat56| 39.6| | 2| cat40| 29.7| | 2| cat187| 27.9| | 2| cat68| 9.8| | 3| cat8| 35.6| | ...| ....| ....| +----+--------+----------+ 

Comme vous pouvez le constater, le DataFrame est classé par Hour dans un ordre croissant, puis par TotalValue dans un ordre décroissant.

Je voudrais sélectionner la première rangée de chaque groupe, c.-à-d.

  • du groupe d’Heure == 0 select (0, cat26,30.9)
  • du groupe d’Heure == 1 select (1, cat67,28.5)
  • du groupe d’Heure == 2 select (2, cat56,39.6)
  • etc

Donc, le résultat souhaité serait:

 +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 1| cat67| 28.5| | 2| cat56| 39.6| | 3| cat8| 35.6| | ...| ...| ...| +----+--------+----------+ 

Il pourrait être utile de pouvoir sélectionner également les N premières lignes de chaque groupe.

Toute aide est grandement appréciée.

    Fonctions de fenêtre :

    Quelque chose comme ça devrait faire l’affaire:

     import org.apache.spark.sql.functions.{row_number, max, broadcast} import org.apache.spark.sql.expressions.Window val df = sc.parallelize(Seq( (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3), (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3), (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8), (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue") val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc) val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

    Cette méthode sera inefficace en cas de biais important des données.

    Agrégation SQL simple suivie de la join :

    Vous pouvez également vous joindre à un bloc de données agrégé:

     val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value")) val dfTopByJoin = df.join(broadcast(dfMax), ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value")) .drop("max_hour") .drop("max_value") dfTopByJoin.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

    Il conservera les valeurs en double (s’il y a plus d’une catégorie par heure avec la même valeur totale). Vous pouvez les supprimer comme suit:

     dfTopByJoin .groupBy($"hour") .agg( first("category").alias("category"), first("TotalValue").alias("TotalValue")) 

    Utilisation de la commande sur les structs :

    Neat, bien que pas très bien testé, truc qui ne nécessite pas de jointures ou de fonctions de fenêtre:

     val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs")) .groupBy($"hour") .agg(max("vs").alias("vs")) .select($"Hour", $"vs.Category", $"vs.TotalValue") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

    Avec l’API DataSet (Spark 1.6+, 2.0+):

    Spark 1.6 :

     case class Record(Hour: Integer, Category: Ssortingng, TotalValue: Double) df.as[Record] .groupBy($"hour") .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y) .show // +---+--------------+ // | _1| _2| // +---+--------------+ // |[0]|[0,cat26,30.9]| // |[1]|[1,cat67,28.5]| // |[2]|[2,cat56,39.6]| // |[3]| [3,cat8,35.6]| // +---+--------------+ 

    Spark 2.0 ou version ultérieure :

     df.as[Record] .groupByKey(_.Hour) .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y) 

    Les deux dernières méthodes peuvent tirer parti de la combinaison des faces de la carte et ne requièrent pas une lecture aléatoire complète. Par conséquent, la plupart du temps, elles doivent présenter de meilleures performances que les fonctions de fenêtre et les jointures. Celles-ci peuvent également être utilisées avec le streaming structuré en mode sortie complet.

    Ne pas utiliser :

     df.orderBy(...).groupBy(...).agg(first(...), ...) 

    Cela peut sembler fonctionner (surtout en mode local ) mais il n’est pas fiable ( SPARK-16207 ). Crédits à Tzach Zohar pour relier le problème JIRA pertinent .

    La même note s’applique à

     df.orderBy(...).dropDuplicates(...) 

    qui utilise en interne un plan d’exécution équivalent.

    Pour Spark 2.0.2 avec regroupement par plusieurs colonnes:

     import org.apache.spark.sql.functions.row_number import org.apache.spark.sql.expressions.Window val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc) val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn") 

    C’est exactement la même chose que la réponse de zero323 mais en mode requête SQL

    En supposant que le dataframe est créé et enregistré comme

     df.createOrReplaceTempView("table") //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|0 |cat26 |30.9 | //|0 |cat13 |22.1 | //|0 |cat95 |19.6 | //|0 |cat105 |1.3 | //|1 |cat67 |28.5 | //|1 |cat4 |26.8 | //|1 |cat13 |12.6 | //|1 |cat23 |5.3 | //|2 |cat56 |39.6 | //|2 |cat40 |29.7 | //|2 |cat187 |27.9 | //|2 |cat68 |9.8 | //|3 |cat8 |35.6 | //+----+--------+----------+ 

    Fonction de fenêtre:

     sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false) //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|1 |cat67 |28.5 | //|3 |cat8 |35.6 | //|2 |cat56 |39.6 | //|0 |cat26 |30.9 | //+----+--------+----------+ 

    Agrégation SQL simple suivie de la jointure:

     sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " + "(select Hour, Category, TotalValue from table tmp1 " + "join " + "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " + "on " + "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " + "group by tmp3.Hour") .show(false) //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|1 |cat67 |28.5 | //|3 |cat8 |35.6 | //|2 |cat56 |39.6 | //|0 |cat26 |30.9 | //+----+--------+----------+ 

    Utilisation de la commande sur les structures:

     sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false) //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|1 |cat67 |28.5 | //|3 |cat8 |35.6 | //|2 |cat56 |39.6 | //|0 |cat26 |30.9 | //+----+--------+----------+ 

    DataSets manière et ne font pas s sont les mêmes que dans la réponse originale

    La solution ci-dessous ne comprend qu’un seul groupeBy et extrait les lignes de votre dataframe qui contiennent la valeur maxValue en une seule fois. Pas besoin d’autres jointures ou Windows.

     import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.DataFrame //df is the dataframe with Day, Category, TotalValue implicit val dfEnc = RowEncoder(df.schema) val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}} 

    Si le dataframe doit être groupé par plusieurs colonnes, cela peut aider

     val keys = List("Hour", "Category"); val selectFirstValueOfNoneGroupedColumns = df.columns .filterNot(keys.toSet) .map(_ -> "first").toMap val grouped = df.groupBy(keys.head, keys.tail: _*) .agg(selectFirstValueOfNoneGroupedColumns) 

    J’espère que cela aide quelqu’un avec un problème similaire

    Ici vous pouvez faire comme ça –

      val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour") data.withColumnRenamed("_1","Hour").show 

    Nous pouvons utiliser la fonction de fenêtre rank () où vous choisiriez le rang = 1). Le rang ajoute simplement un nombre pour chaque ligne d’un groupe (dans ce cas, ce serait l’heure)

    voici un exemple. (extrait de https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )

     val dataset = spark.range(9).withColumn("bucket", 'id % 3) import org.apache.spark.sql.expressions.Window val byBucket = Window.partitionBy('bucket).orderBy('id) scala> dataset.withColumn("rank", rank over byBucket).show +---+------+----+ | id|bucket|rank| +---+------+----+ | 0| 0| 1| | 3| 0| 2| | 6| 0| 3| | 1| 1| 1| | 4| 1| 2| | 7| 1| 3| | 2| 2| 1| | 5| 2| 2| | 8| 2| 3| +---+------+----+