目录

BPE 分词器高性能优化:从 10 分钟到 1 秒的实践

本文是 CS336 作业一 的延伸阅读,详细介绍 BPE 分词器的优化实现。

背景

文档中推荐使用的 cppyy 在 Mac 和 Linux 环境中有问题。为了追求高性能,我使用 Pybind11 来绑定 C++ 代码:预分词由 Python 处理,而 BPE 归并过程交给 C++。实际最大的瓶颈还是预分词,可以直接用已有的代码 pretokenization_example.py 做分块并行(8核 100s → 16核 30s)。

核心优化策略

1. 并行化处理

  • 使用 OpenMP 并行处理 build_initial_counts() 函数
  • 每个线程维护本地统计结果(ThreadLocal 结构),避免频繁的锁竞争
  • 最终合并各线程的结果到全局统计中

2. 惰性删除的优先级队列

  • 使用最大堆(priority_queue)快速找到最高频的 token pair
  • 采用"惰性删除"策略:不直接从队列中删除过期的 pair
  • 当从队列顶部取出 pair 时,检查它是否仍然有效(频率是否与当前统计匹配)
  • 复杂度从 O(NlogN) 降低到 O(KlogN),其中 K 是需要跳过的过期条目数

3. 增量更新机制

  • merge_and_update() 函数在合并 token 时,只更新受影响的相邻 pair
  • 维护 pair_positions 索引结构,记录每个 pair 在哪些单词的什么位置出现
  • 避免每次合并后重新扫描所有单词,大幅减少计算量

4. 高效的数据结构

  • 使用整数 ID(0-255)表示字节,避免频繁的字符串操作
  • 自定义哈希函数 PairHash 支持 pair 作为 unordered_map 的键
  • 使用 -1 标记已合并的 token,避免数据移动

5. 内存友好的表示

  • 单词存储为整数向量而不是字符串
  • 词汇表使用 map<int, Bytes>,支持快速 ID 到字节串的查找
  • 特殊 token 在训练结束时添加,不影响核心训练过程

6. 灵活的训练控制

  • 支持指定目标词汇表大小
  • 支持特殊 token(如 <pad>
  • 返回完整的合并记录,便于后续编码使用

详细示例:优化工作流程

用一个具体例子说明这些优化如何协同工作。

输入数据

words = {"low", "lower", "widest", "newest"}
counts = [5, 2, 3, 6]

初始 token 频率统计(频率相同时按字典序比较):

Token Pair频率来源
(“e”,“s”)9newest(6) + widest(3)
(“s”,“t”)9newest(6) + widest(3)
(“w”,“e”)8newest(6) + lower(2)
(“l”,“o”)7low(5) + lower(2)
(“o”,“w”)7low(5) + lower(2)

频率相同时的字典序比较:

  • "es" 对应 (“e”,“s”)
  • "st" 对应 (“s”,“t”)
  • 字典序比较:"es" < "st",所以 ("e","s") 的优先级低于 ("s","t")

在最大堆中,优先级低的会下沉,所以堆顶是 ("s","t")

惰性删除队列工作流程

第一次合并:合并 ("s","t") 为新 token 256

  1. 初始队列状态(最大堆,堆顶优先):

    堆顶: ("s","t"):9  <-- 将被合并
          ("e","s"):9
          ("w","e"):8
          ("l","o"):7
          ("o","w"):7
  2. 合并操作的影响

    • 单词 “newest”:[110,101,119,101,115,116][110,101,119,101,256,-1]
    • 单词 “widest”:[119,105,100,101,115,116][119,105,100,101,256,-1]
  3. 增量更新(而非重新计算全部)

    // 对于 "newest":
    // 删除受影响 pairs: ("e","s"):6, ("s","t"):6
    // 添加新 pairs: ("e",256):6
    
    // 对于 "widest":
    // 删除受影响 pairs: ("e","s"):3, ("s","t"):3
    // 添加新 pairs: ("e",256):3
    
  4. 队列更新(惰性方式)

    • 不立即删除队列中的旧条目
    • 将新 pairs 推入队列:("e",256):9
    • 队列现在包含新旧混合条目
  5. 下一次获取最高频 pair 时

    while (!pair_queue.empty()) {
        best_info = pair_queue.top();
        pair_queue.pop();
    
        auto it = pair_counts.find(best_info.pair);
        if (it != pair_counts.end() && it->second == best_info.count) {
            break;  // 有效,使用它
        }
        // 否则,这是过期条目,继续检查下一个
    }

增量更新的具体数值变化

合并前全局统计

pair_counts = {
    ("e","s"):9,
    ("s","t"):9,
    ("w","e"):8,
    ("l","o"):7,
    ("o","w"):7,
    ("n","e"):6,
    ("e","w"):6,
    ...
}

合并 ("s","t") 后的增量更新

对于 “newest”(频率 6):

  1. 删除左相邻 ("e","s")pair_counts[("e","s")] -= 6 → 从 9 到 3
  2. 删除 ("s","t") 自身:pair_counts[("s","t")] -= 6 → 从 9 到 3
  3. 添加新左相邻 ("e",256)pair_counts[("e",256)] += 6 → 从 0 到 6

对于 “widest”(频率 3):

  1. 删除左相邻 ("e","s")pair_counts[("e","s")] -= 3 → 从 3 到 0
  2. 删除 ("s","t") 自身:pair_counts[("s","t")] -= 3 → 从 3 到 0
  3. 添加新左相邻 ("e",256)pair_counts[("e",256)] += 3 → 从 6 到 9

并行处理优势

假设有 8 个线程,处理 100 万个单词:

优化前(串行)

  • 单线程扫描 100 万单词,统计所有相邻 pair
  • 时间复杂度:O(N×M),其中 M 是平均单词长度

优化后(并行)

#pragma omp parallel for schedule(static)
for (size_t i = 0; i < 1000000; ++i) {
    // 每个线程处理约 125,000 个单词
    // 线程本地统计,无锁竞争
}
// 最后合并线程本地结果

性能提升

  • 理想情况下:8 线程加速 ≈ 6-7 倍
  • 实际考虑线程创建、合并开销:加速 ≈ 5-6 倍

内存效率对比

方法存储 “newest”合并 “st” 后内存占用
字符串数组["n","e","w","e","s","t"]["n","e","w","e","st"]需要移动/复制字符串
整数ID+标记[110,101,119,101,115,116][110,101,119,101,256,-1]只修改两个整数

性能对比

测试机器:autodl Xeon(R) Platinum 8352V 32 核心 CPU 60GB 内存,预分词使用 24 个核心并行工作。

数据版本预分词BPE归并训练总时间
TinyStoriesV2-GPT4-trainPython29.65s10min++不可接受
TinyStoriesV2-GPT4-trainCpp 未优化归并27.337s366.644s394.16s
TinyStoriesV2-GPT4-trainCpp 优化归并26.767s1.081s28.03s
TinyStoriesV2-GPT4-trainRust 优化预分词67.261s--

Python 的 regex 库底层 C 语言优化非常出色,C++ 的 regex 库对 Unicode 支持不完善,Rust 性能反而比 Python 慢一倍。正如文档所说:"…but the regex package in Python is, if anything, even faster."

总结

这些优化使得算法能够高效处理大规模文本数据,特别是在构建初始统计和迭代合并阶段表现出色:

  • 并行化处理加速了初始计数
  • 惰性删除的优先级队列减少了维护开销
  • 增量更新机制避免了不必要的重复计算

最终实现了从 10 分钟以上到约 1 秒的性能提升,提速超过 300 倍