
在 PySpark 中处理数据倾斜问题是非常重要的,因为数据倾斜会导致某些任务执行时间过长,从而影响整个作业的性能。以下是一些常见的优化方法:
通过重新分区可以将数据均匀分布到各个分区中。可以使用 repartition 或 coalesce 方法来调整分区数量。
df = df.repartition(100, "key_column")在进行全局聚合之前,先进行局部聚合,可以减少数据传输量。
df = df.groupBy("key_column").agg(F.collect_list("value_column"))
df = df.groupBy("key_column").agg(F.flatten(F.collect_list("value_column")).alias("value_column"))如果一个表很小,可以使用广播 join 来避免数据倾斜。
from pyspark.sql.functions import broadcast
small_df = spark.read.csv("small_table.csv")
large_df = spark.read.csv("large_table.csv")
result = large_df.join(broadcast(small_df), "key_column")在 key 上添加随机值(盐值),以分散热点 key 的负载。
import random
def add_salt(key):
return (key, random.randint(1, 10))
df = df.withColumn("salted_key", F.udf(add_salt)("key_column"))
df = df.groupBy("salted_key").agg(F.collect_list("value_column"))
df = df.withColumn("key_column", F.col("salted_key").getItem(0)).drop("salted_key")对数据进行采样,找出热点 key,然后对这些 key 进行特殊处理。
sample_df = df.sample(False, 0.1)
hot_keys = sample_df.groupBy("key_column").count().filter(F.col("count") > 1000).select("key_column").collect()
hot_keys = [row["key_column"] for row in hot_keys]
def handle_hot_keys(key):
if key in hot_keys:
return (key, random.randint(1, 10))
else:
return (key, 0)
df = df.withColumn("salted_key", F.udf(handle_hot_keys)("key_column"))
df = df.groupBy("salted_key").agg(F.collect_list("value_column"))
df = df.withColumn("key_column", F.when(F.col("salted_key").getItem(1) == 0, F.col("salted_key").getItem(0)).otherwise(F.col("key_column"))).drop("salted_key")增加 Shuffle 操作的分区数,可以更好地分散数据。
spark.conf.set("spark.sql.shuffle.partitions", 200)根据业务需求,实现自定义的 Partitioner 来更好地控制数据的分布。
class CustomPartitioner:
def __init__(self, num_partitions):
self.num_partitions = num_partitions
def getPartition(self, key):
# 自定义分区逻辑
return hash(key) % self.num_partitions
rdd = df.rdd.partitionBy(100, CustomPartitioner(100))
df = rdd.toDF()在数据倾斜发生之前,先进行预聚合,减少后续操作的数据量。
df = df.groupBy("key_column", "sub_key_column").agg(F.sum("value_column").alias("sum_value"))
df = df.groupBy("key_column").agg(F.sum("sum_value").alias("total_value"))原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。