dask.dataframe.shuffle.set_index
中获取 divisions 的步骤分析
主要流程概述
在 set_index
函数中,当 divisions=None
时,系统需要通过分析数据来动态计算分区边界。这个过程分为以下几个关键步骤:
1. 初始检查和准备
if divisions is None:sizes = df.map_partitions(sizeof) if repartition else []divisions = index2._repartition_quantiles(npartitions, upsample=upsample)mins = index2.map_partitions(M.min)maxes = index2.map_partitions(M.max)divisions, sizes, mins, maxes = base.compute(divisions, sizes, mins, maxes)
步骤说明:
- 计算每个分区的大小(如果启用重新分区)
- 调用
_repartition_quantiles
计算近似分位数 - 并行计算每个分区的最小值和最大值
- 使用
base.compute
触发实际计算
2. 分位数计算过程 (_repartition_quantiles
)
_repartition_quantiles
方法调用 partition_quantiles
函数,该函数执行以下步骤:
2.1 生成采样策略
def sample_percentiles(num_old, num_new, chunk_length, upsample=1.0, random_state=None):# 计算随机百分位比例random_percentage = 1 / (1 + (4 * num_new / num_old) ** 0.5)# 生成等间距和随机百分位
2.2 创建计算图
# 1. 数据类型信息
dtype_dsk = {(name0, 0): (dtype_info, df_keys[0])}# 2. 每个分区的百分位摘要
val_dsk = {(name1, i): (percentiles_summary, key, df.npartitions, npartitions, upsample, state)for i, (state, key) in enumerate(zip(state_data, df_keys))
}# 3. 合并和压缩摘要
merge_dsk = create_merge_tree(merge_and_compress_summaries, sorted(val_dsk), name2)# 4. 最终处理
last_dsk = {(name3, 0): (pd.Series, (process_val_weights, merged_key, npartitions, (name0, 0)), qs, None, df.name)
}
3. 数据后处理
divisions = methods.tolist(divisions)
if type(sizes) is not list:sizes = methods.tolist(sizes)
mins = methods.tolist(mins)
maxes = methods.tolist(maxes)
4. 空数据检测和重新分区
empty_dataframe_detected = pd.isnull(divisions).all()
if repartition or empty_dataframe_detected:total = sum(sizes)npartitions = max(math.ceil(total / partition_size), 1)npartitions = min(npartitions, df.npartitions)# 插值生成新的分界点divisions = np.interp(x=np.linspace(0, n - 1, npartitions + 1),xp=np.linspace(0, n - 1, n),fp=divisions,).tolist()
5. 数据类型特殊处理
if pd.api.types.is_categorical_dtype(index2.dtype):dtype = index2.dtypemins = pd.Categorical(mins, dtype=dtype).codes.tolist()maxes = pd.Categorical(maxes, dtype=dtype).codes.tolist()
6. 排序优化检查
if (mins == sorted(mins) and maxes == sorted(maxes) and all(mx < mn for mx, mn in zip(maxes[:-1], mins[1:]))):divisions = mins + [maxes[-1]]result = set_sorted_index(df, index, drop=drop, divisions=divisions)return result.map_partitions(M.sort_index)
这个检查的作用:
- 如果数据已经按索引排序,可以直接使用最小值和最大值作为分界点
- 避免昂贵的shuffle操作
分位数计算详细过程
核心算法:percentiles_summary
函数
def percentiles_summary(df, num_old, num_new, upsample, state):"""Summarize data using percentiles and derived weights."""# 1. 生成采样百分位qs = sample_percentiles(num_old, num_new, len(df), upsample, state)# 2. 计算百分位值vals = df.quantile(qs)# 3. 转换为权重return percentiles_to_weights(qs, vals, len(df))
权重计算:percentiles_to_weights
函数
def percentiles_to_weights(qs, vals, length):"""Weigh percentile values by length and the difference between percentiles"""if length == 0:return ()diff = np.ediff1d(qs, 0.0, 0.0)weights = 0.5 * length * (diff[1:] + diff[:-1])return vals.tolist(), weights.tolist()
权重计算原理:
- 每个百分位值的权重 = 0.5 × 分区长度 × (前一个百分位差 + 后一个百分位差)
- 这样确保权重反映该值在数据分布中的重要性
合并和压缩:merge_and_compress_summaries
函数
def merge_and_compress_summaries(vals_and_weights):"""Merge and sort percentile summaries that are already sorted."""vals, weights = zip(*vals_and_weights)vals = list(merge_sorted(*vals))weights = list(merge_sorted(*weights))# 压缩重复值compressed_vals = []compressed_weights = []# ... 压缩逻辑return compressed_vals, compressed_weights
最终处理:process_val_weights
函数
这个函数(我们之前分析过的)将合并后的值-权重对转换为最终的分区边界。
完整流程图和示例
让我创建一个详细的流程图来展示整个过程:
开始 set_index(df, index, divisions=None)↓
检查 divisions 是否为 None↓
是 → 开始动态计算分界点↓
1. 并行计算:- 每个分区的大小 (sizes)- 每个分区的分位数 (divisions)- 每个分区的最小值 (mins)- 每个分区的最大值 (maxes)↓
2. 触发计算:base.compute(divisions, sizes, mins, maxes)↓
3. 分位数计算过程:┌─────────────────────────────────────────┐│ 对每个分区执行 percentiles_summary: ││ 1. 生成采样百分位 (等间距 + 随机) ││ 2. 计算百分位值 ││ 3. 转换为权重 │└─────────────────────────────────────────┘↓
4. 合并所有分区的摘要:┌─────────────────────────────────────────┐│ merge_and_compress_summaries: ││ 1. 合并排序的值和权重 ││ 2. 压缩重复值 │└─────────────────────────────────────────┘↓
5. 最终处理:┌─────────────────────────────────────────┐│ process_val_weights: ││ 1. 处理巨型值 ││ 2. 计算最终分界点 ││ 3. 处理数据类型 │└─────────────────────────────────────────┘↓
6. 后处理:- 转换为列表格式- 检查空数据- 处理重新分区- 特殊数据类型处理↓
7. 排序优化检查:- 如果数据已排序,使用 min/max 作为分界点- 否则继续到 shuffle 阶段↓
调用 set_partition 进行实际的数据重排↓
结束
关键优化策略
- 采样策略:结合等间距和随机百分位,平衡计算效率和准确性
- 排序检测:如果数据已排序,避免昂贵的shuffle操作
- 数据类型感知:特别处理分类、时间等特殊数据类型
- 内存优化:通过压缩和合并减少内存使用
- 分布式计算:利用Dask的并行计算能力
性能考虑
- 时间复杂度:O(n log n),主要由排序和分位数计算决定
- 空间复杂度:O(n),存储采样数据和权重
- 网络开销:需要收集所有分区的统计信息
- 计算开销:需要两次数据遍历(统计 + shuffle)
总结
dask.dataframe.shuffle.set_index
中获取 divisions 的过程是一个复杂的分布式算法,主要包含以下步骤:
核心步骤
- 并行统计:计算每个分区的分位数、大小、最小值、最大值
- 分位数计算:使用采样策略生成代表性百分位
- 权重分配:根据数据分布为每个值分配权重
- 合并压缩:合并所有分区的统计信息并压缩重复值
- 分界点计算:使用
process_val_weights
计算最终分界点 - 优化检查:检测数据是否已排序,避免不必要的shuffle
关键特点
- 分布式设计:充分利用Dask的并行计算能力
- 智能采样:结合等间距和随机采样策略
- 类型感知:特别处理不同数据类型
- 性能优化:检测已排序数据,避免重复计算
- 内存高效:通过压缩和合并减少内存使用
这个算法是Dask DataFrame实现高效分布式排序和分区的核心,通过巧妙的采样和合并策略,在保证准确性的同时实现了良好的性能。
自己实现
import numpy as np
import pandas as pd# 1️⃣ 采样百分位
def sample_percentiles(num_old, num_new, chunk_length, upsample=1.0, random_state=None):"""简单版本:等间距百分位"""return np.linspace(0, 1, num_new + 1)# 2️⃣ 计算百分位摘要(值+权重)
def percentiles_summary(series, num_old, num_new):qs = sample_percentiles(num_old, num_new, len(series))vals = series.quantile(qs).to_numpy()diff = np.ediff1d(qs, 0.0, 0.0)weights = 0.5 * len(series) * (diff[1:] + diff[:-1])return vals.tolist(), weights.tolist()# 3️⃣ 合并多个分区的摘要
def merge_and_compress_summaries(summaries):all_vals = []all_weights = []for vals, weights in summaries:all_vals.extend(vals)all_weights.extend(weights)# 按值排序order = np.argsort(all_vals)vals = np.array(all_vals)[order]weights = np.array(all_weights)[order]# 压缩重复值compressed_vals = []compressed_weights = []last_val = Nonefor v, w in zip(vals, weights):if last_val is not None and v == last_val:compressed_weights[-1] += welse:compressed_vals.append(v)compressed_weights.append(w)last_val = vreturn np.array(compressed_vals), np.array(compressed_weights)# 4️⃣ 最终处理:计算分界点
def process_val_weights(vals, weights, npartitions):if len(vals) == 0:return np.array([])if len(vals) == npartitions + 1:return valselif len(vals) < npartitions + 1:q_weights = np.cumsum(weights)q_target = np.linspace(q_weights[0], q_weights[-1], npartitions + 1)return np.interp(q_target, q_weights, vals)else:target_weight = weights.sum() / npartitionsjumbo_mask = weights >= target_weightjumbo_vals = vals[jumbo_mask]trimmed_vals = vals[~jumbo_mask]trimmed_weights = weights[~jumbo_mask]trimmed_npartitions = npartitions - len(jumbo_vals)q_weights = np.cumsum(trimmed_weights)q_target = np.linspace(0, q_weights[-1], trimmed_npartitions + 1)left = np.searchsorted(q_weights, q_target, side="left")right = np.searchsorted(q_weights, q_target, side="right") - 1lower = np.minimum(left, right)trimmed = trimmed_vals[lower]rv = np.concatenate([trimmed, jumbo_vals])rv.sort()return rv# 5️⃣ 模拟 set_index 中 divisions 的获取
def simulate_set_index(df, column, npartitions):num_old = len(df)# 假设原始有分区(这里手动切分成2块模拟)partitions = np.array_split(df[column], 2)summaries = [percentiles_summary(p, num_old, npartitions) for p in partitions]vals, weights = merge_and_compress_summaries(summaries)divisions = process_val_weights(vals, weights, npartitions)return divisions# ========== DEMO 使用 ==========
df = pd.DataFrame({"x": np.random.randint(0, 100, size=50)})divs = simulate_set_index(df, "x", npartitions=4)print("原始数据示例:\n", df.head())
print("\n计算得到的 divisions:", divs)