BPE 分词器高性能优化:从 10 分钟到 1 秒的实践
本文是 CS336 作业一 的延伸阅读,详细介绍 BPE 分词器的优化实现。
背景
文档中推荐使用的 cppyy 在 Mac 和 Linux 环境中有问题。为了追求高性能,我使用 Pybind11 来绑定 C++ 代码:预分词由 Python 处理,而 BPE 归并过程交给 C++。实际最大的瓶颈还是预分词,可以直接用已有的代码 pretokenization_example.py 做分块并行(8核 100s → 16核 30s)。
核心优化策略
1. 并行化处理
- 使用 OpenMP 并行处理
build_initial_counts()函数 - 每个线程维护本地统计结果(
ThreadLocal结构),避免频繁的锁竞争 - 最终合并各线程的结果到全局统计中
2. 惰性删除的优先级队列
- 使用最大堆(priority_queue)快速找到最高频的 token pair
- 采用"惰性删除"策略:不直接从队列中删除过期的 pair
- 当从队列顶部取出 pair 时,检查它是否仍然有效(频率是否与当前统计匹配)
- 复杂度从 O(NlogN) 降低到 O(KlogN),其中 K 是需要跳过的过期条目数
3. 增量更新机制
merge_and_update()函数在合并 token 时,只更新受影响的相邻 pair- 维护
pair_positions索引结构,记录每个 pair 在哪些单词的什么位置出现 - 避免每次合并后重新扫描所有单词,大幅减少计算量
4. 高效的数据结构
- 使用整数 ID(0-255)表示字节,避免频繁的字符串操作
- 自定义哈希函数
PairHash支持 pair 作为 unordered_map 的键 - 使用
-1标记已合并的 token,避免数据移动
5. 内存友好的表示
- 单词存储为整数向量而不是字符串
- 词汇表使用
map<int, Bytes>,支持快速 ID 到字节串的查找 - 特殊 token 在训练结束时添加,不影响核心训练过程
6. 灵活的训练控制
- 支持指定目标词汇表大小
- 支持特殊 token(如
<pad>) - 返回完整的合并记录,便于后续编码使用
详细示例:优化工作流程
用一个具体例子说明这些优化如何协同工作。
输入数据
words = {"low", "lower", "widest", "newest"}
counts = [5, 2, 3, 6]初始 token 频率统计(频率相同时按字典序比较):
| Token Pair | 频率 | 来源 |
|---|---|---|
| (“e”,“s”) | 9 | newest(6) + widest(3) |
| (“s”,“t”) | 9 | newest(6) + widest(3) |
| (“w”,“e”) | 8 | newest(6) + lower(2) |
| (“l”,“o”) | 7 | low(5) + lower(2) |
| (“o”,“w”) | 7 | low(5) + lower(2) |
频率相同时的字典序比较:
"es"对应 (“e”,“s”)"st"对应 (“s”,“t”)- 字典序比较:
"es" < "st",所以("e","s")的优先级低于("s","t")
在最大堆中,优先级低的会下沉,所以堆顶是 ("s","t")。
惰性删除队列工作流程
第一次合并:合并 ("s","t") 为新 token 256
初始队列状态(最大堆,堆顶优先):
堆顶: ("s","t"):9 <-- 将被合并 ("e","s"):9 ("w","e"):8 ("l","o"):7 ("o","w"):7合并操作的影响:
- 单词 “newest”:
[110,101,119,101,115,116]→[110,101,119,101,256,-1] - 单词 “widest”:
[119,105,100,101,115,116]→[119,105,100,101,256,-1]
- 单词 “newest”:
增量更新(而非重新计算全部):
// 对于 "newest": // 删除受影响 pairs: ("e","s"):6, ("s","t"):6 // 添加新 pairs: ("e",256):6 // 对于 "widest": // 删除受影响 pairs: ("e","s"):3, ("s","t"):3 // 添加新 pairs: ("e",256):3队列更新(惰性方式):
- 不立即删除队列中的旧条目
- 将新 pairs 推入队列:
("e",256):9 - 队列现在包含新旧混合条目
下一次获取最高频 pair 时:
while (!pair_queue.empty()) { best_info = pair_queue.top(); pair_queue.pop(); auto it = pair_counts.find(best_info.pair); if (it != pair_counts.end() && it->second == best_info.count) { break; // 有效,使用它 } // 否则,这是过期条目,继续检查下一个 }
增量更新的具体数值变化
合并前全局统计:
pair_counts = {
("e","s"):9,
("s","t"):9,
("w","e"):8,
("l","o"):7,
("o","w"):7,
("n","e"):6,
("e","w"):6,
...
}合并 ("s","t") 后的增量更新:
对于 “newest”(频率 6):
- 删除左相邻
("e","s"):pair_counts[("e","s")] -= 6→ 从 9 到 3 - 删除
("s","t")自身:pair_counts[("s","t")] -= 6→ 从 9 到 3 - 添加新左相邻
("e",256):pair_counts[("e",256)] += 6→ 从 0 到 6
对于 “widest”(频率 3):
- 删除左相邻
("e","s"):pair_counts[("e","s")] -= 3→ 从 3 到 0 - 删除
("s","t")自身:pair_counts[("s","t")] -= 3→ 从 3 到 0 - 添加新左相邻
("e",256):pair_counts[("e",256)] += 3→ 从 6 到 9
并行处理优势
假设有 8 个线程,处理 100 万个单词:
优化前(串行):
- 单线程扫描 100 万单词,统计所有相邻 pair
- 时间复杂度:O(N×M),其中 M 是平均单词长度
优化后(并行):
#pragma omp parallel for schedule(static)
for (size_t i = 0; i < 1000000; ++i) {
// 每个线程处理约 125,000 个单词
// 线程本地统计,无锁竞争
}
// 最后合并线程本地结果
性能提升:
- 理想情况下:8 线程加速 ≈ 6-7 倍
- 实际考虑线程创建、合并开销:加速 ≈ 5-6 倍
内存效率对比
| 方法 | 存储 “newest” | 合并 “st” 后 | 内存占用 |
|---|---|---|---|
| 字符串数组 | ["n","e","w","e","s","t"] | ["n","e","w","e","st"] | 需要移动/复制字符串 |
| 整数ID+标记 | [110,101,119,101,115,116] | [110,101,119,101,256,-1] | 只修改两个整数 |
性能对比
测试机器:autodl Xeon(R) Platinum 8352V 32 核心 CPU 60GB 内存,预分词使用 24 个核心并行工作。
| 数据 | 版本 | 预分词 | BPE归并训练 | 总时间 |
|---|---|---|---|---|
| TinyStoriesV2-GPT4-train | Python | 29.65s | 10min++ | 不可接受 |
| TinyStoriesV2-GPT4-train | Cpp 未优化归并 | 27.337s | 366.644s | 394.16s |
| TinyStoriesV2-GPT4-train | Cpp 优化归并 | 26.767s | 1.081s | 28.03s |
| TinyStoriesV2-GPT4-train | Rust 优化预分词 | 67.261s | - | - |
Python 的 regex 库底层 C 语言优化非常出色,C++ 的 regex 库对 Unicode 支持不完善,Rust 性能反而比 Python 慢一倍。正如文档所说:"…but the regex package in Python is, if anything, even faster."
总结
这些优化使得算法能够高效处理大规模文本数据,特别是在构建初始统计和迭代合并阶段表现出色:
- 并行化处理加速了初始计数
- 惰性删除的优先级队列减少了维护开销
- 增量更新机制避免了不必要的重复计算
最终实现了从 10 分钟以上到约 1 秒的性能提升,提速超过 300 倍。