前言
本文讲解Spark如何获取当前分区的partitionId,这是一位群友提出的问题,其实只要通过TaskContext.get.partitionId(我是在官网上看到的),下面给出一些示例。
1、代码
下面的代码主要测试SparkSession,SparkContext创建的rdd和df是否都支持。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37package com.dkl.leanring.partition
import org.apache.spark.sql.SparkSession
import org.apache.spark.TaskContext
/**
* 获取当前分区的partitionId
*/
object GetPartitionIdDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().appName("GetPartitionIdDemo").master("local").getOrCreate()
val sc = spark.sparkContext
val data = Seq(1, 2, 3, 4)
// 测试rdd,三个分区
val rdd = sc.parallelize(data, 3)
rdd.foreach(i => {
println("partitionId:" + TaskContext.get.partitionId)
})
import spark.implicits._
// 测试df,三个分区
val df = rdd.toDF("id")
df.show
df.foreach(row => {
println("partitionId:" + TaskContext.get.partitionId)
})
// 测试df,两个分区
val data1 = Array((1, 2), (3, 4))
val df1 = spark.createDataFrame(data1).repartition(2)
df1.show()
df1.foreach(row => {
println("partitionId:" + TaskContext.get.partitionId)
})
}
}