# BPE 分词器高性能优化：从 10 分钟到 1 秒的实践


> 本文是 [CS336 作业一](/posts/cs336-assign1/) 的延伸阅读，详细介绍 BPE 分词器的优化实现。

## 背景

文档中推荐使用的 cppyy 在 Mac 和 Linux 环境中有问题。为了追求高性能，我使用 Pybind11 来绑定 C++ 代码：预分词由 Python 处理，而 BPE 归并过程交给 C++。实际最大的瓶颈还是预分词，可以直接用已有的代码 `pretokenization_example.py` 做分块并行（8核 100s → 16核 30s）。

## 核心优化策略

### 1. 并行化处理

- 使用 OpenMP 并行处理 `build_initial_counts()` 函数
- 每个线程维护本地统计结果（`ThreadLocal` 结构），避免频繁的锁竞争
- 最终合并各线程的结果到全局统计中

### 2. 惰性删除的优先级队列

- 使用最大堆（priority_queue）快速找到最高频的 token pair
- 采用"惰性删除"策略：不直接从队列中删除过期的 pair
- 当从队列顶部取出 pair 时，检查它是否仍然有效（频率是否与当前统计匹配）
- 复杂度从 O(NlogN) 降低到 O(KlogN)，其中 K 是需要跳过的过期条目数

### 3. 增量更新机制

- `merge_and_update()` 函数在合并 token 时，只更新受影响的相邻 pair
- 维护 `pair_positions` 索引结构，记录每个 pair 在哪些单词的什么位置出现
- 避免每次合并后重新扫描所有单词，大幅减少计算量

### 4. 高效的数据结构

- 使用整数 ID（0-255）表示字节，避免频繁的字符串操作
- 自定义哈希函数 `PairHash` 支持 pair 作为 unordered_map 的键
- 使用 `-1` 标记已合并的 token，避免数据移动

### 5. 内存友好的表示

- 单词存储为整数向量而不是字符串
- 词汇表使用 `map<int, Bytes>`，支持快速 ID 到字节串的查找
- 特殊 token 在训练结束时添加，不影响核心训练过程

### 6. 灵活的训练控制

- 支持指定目标词汇表大小
- 支持特殊 token（如 `<pad>`）
- 返回完整的合并记录，便于后续编码使用

---

## 详细示例：优化工作流程

用一个具体例子说明这些优化如何协同工作。

### 输入数据

```cpp
words = {"low", "lower", "widest", "newest"}
counts = [5, 2, 3, 6]
```

初始 token 频率统计（频率相同时按字典序比较）：

| Token Pair | 频率 | 来源                  |
| ---------- | ---- | --------------------- |
| ("e","s")  | 9    | newest(6) + widest(3) |
| ("s","t")  | 9    | newest(6) + widest(3) |
| ("w","e")  | 8    | newest(6) + lower(2)  |
| ("l","o")  | 7    | low(5) + lower(2)     |
| ("o","w")  | 7    | low(5) + lower(2)     |

**频率相同时的字典序比较：**
- `"es"` 对应 ("e","s")
- `"st"` 对应 ("s","t")
- 字典序比较：`"es" < "st"`，所以 `("e","s")` 的优先级**低于** `("s","t")`

在最大堆中，优先级低的会下沉，所以堆顶是 `("s","t")`。

### 惰性删除队列工作流程

**第一次合并：合并 `("s","t")` 为新 token 256**

1. **初始队列状态**（最大堆，堆顶优先）：
   ```
   堆顶: ("s","t"):9  <-- 将被合并
         ("e","s"):9
         ("w","e"):8
         ("l","o"):7
         ("o","w"):7
   ```

2. **合并操作的影响**：
   - 单词 "newest"：`[110,101,119,101,115,116]` → `[110,101,119,101,256,-1]`
   - 单词 "widest"：`[119,105,100,101,115,116]` → `[119,105,100,101,256,-1]`

3. **增量更新（而非重新计算全部）**：
   ```cpp
   // 对于 "newest"：
   // 删除受影响 pairs: ("e","s"):6, ("s","t"):6
   // 添加新 pairs: ("e",256):6

   // 对于 "widest"：
   // 删除受影响 pairs: ("e","s"):3, ("s","t"):3
   // 添加新 pairs: ("e",256):3
   ```

4. **队列更新（惰性方式）**：
   - 不立即删除队列中的旧条目
   - 将新 pairs 推入队列：`("e",256):9`
   - 队列现在包含新旧混合条目

5. **下一次获取最高频 pair 时**：
   ```cpp
   while (!pair_queue.empty()) {
       best_info = pair_queue.top();
       pair_queue.pop();

       auto it = pair_counts.find(best_info.pair);
       if (it != pair_counts.end() && it->second == best_info.count) {
           break;  // 有效，使用它
       }
       // 否则，这是过期条目，继续检查下一个
   }
   ```

### 增量更新的具体数值变化

**合并前全局统计**：

```
pair_counts = {
    ("e","s"):9,
    ("s","t"):9,
    ("w","e"):8,
    ("l","o"):7,
    ("o","w"):7,
    ("n","e"):6,
    ("e","w"):6,
    ...
}
```

**合并 `("s","t")` 后的增量更新**：

对于 "newest"（频率 6）：
1. 删除左相邻 `("e","s")`：`pair_counts[("e","s")] -= 6` → 从 9 到 3
2. 删除 `("s","t")` 自身：`pair_counts[("s","t")] -= 6` → 从 9 到 3
3. 添加新左相邻 `("e",256)`：`pair_counts[("e",256)] += 6` → 从 0 到 6

对于 "widest"（频率 3）：
1. 删除左相邻 `("e","s")`：`pair_counts[("e","s")] -= 3` → 从 3 到 0
2. 删除 `("s","t")` 自身：`pair_counts[("s","t")] -= 3` → 从 3 到 0
3. 添加新左相邻 `("e",256)`：`pair_counts[("e",256)] += 3` → 从 6 到 9

### 并行处理优势

假设有 8 个线程，处理 100 万个单词：

**优化前（串行）**：
- 单线程扫描 100 万单词，统计所有相邻 pair
- 时间复杂度：O(N×M)，其中 M 是平均单词长度

**优化后（并行）**：
```cpp
#pragma omp parallel for schedule(static)
for (size_t i = 0; i < 1000000; ++i) {
    // 每个线程处理约 125,000 个单词
    // 线程本地统计，无锁竞争
}
// 最后合并线程本地结果
```

**性能提升**：
- 理想情况下：8 线程加速 ≈ 6-7 倍
- 实际考虑线程创建、合并开销：加速 ≈ 5-6 倍

### 内存效率对比

| 方法        | 存储 "newest"               | 合并 "st" 后               | 内存占用            |
| ----------- | --------------------------- | -------------------------- | ------------------- |
| 字符串数组  | `["n","e","w","e","s","t"]` | `["n","e","w","e","st"]`   | 需要移动/复制字符串 |
| 整数ID+标记 | `[110,101,119,101,115,116]` | `[110,101,119,101,256,-1]` | 只修改两个整数      |

---

## 性能对比

测试机器：autodl Xeon(R) Platinum 8352V 32 核心 CPU 60GB 内存，预分词使用 24 个核心并行工作。

|           数据           | 版本            | 预分词  | BPE归并训练 | 总时间   |
| :----------------------: | --------------- | ------- | ----------- | -------- |
| TinyStoriesV2-GPT4-train | Python          | 29.65s  | 10min++     | 不可接受 |
| TinyStoriesV2-GPT4-train | Cpp 未优化归并  | 27.337s | 366.644s    | 394.16s  |
| TinyStoriesV2-GPT4-train | Cpp 优化归并    | 26.767s | 1.081s      | 28.03s   |
| TinyStoriesV2-GPT4-train | Rust 优化预分词 | 67.261s | -           | -        |

> Python 的 regex 库底层 C 语言优化非常出色，C++ 的 regex 库对 Unicode 支持不完善，Rust 性能反而比 Python 慢一倍。正如文档所说：*"...but the regex package in Python is, if anything, even faster."*

## 总结

这些优化使得算法能够高效处理大规模文本数据，特别是在构建初始统计和迭代合并阶段表现出色：

- **并行化处理**加速了初始计数
- **惰性删除的优先级队列**减少了维护开销
- **增量更新机制**避免了不必要的重复计算

最终实现了从 10 分钟以上到约 1 秒的性能提升，提速超过 **300 倍**。

