Skip to content

Commit 3f0d4fe

Browse files
committed
implement review comments
1 parent 65eb519 commit 3f0d4fe

3 files changed

Lines changed: 39 additions & 41 deletions

File tree

crates/bpe/benchmarks/performance.rs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,6 @@ use rand::rngs::StdRng;
1313
use rand::SeedableRng;
1414
use rand::{rng, Rng};
1515

16-
fn get_rng(seed: u64) -> StdRng {
17-
// Expand the u64 seed to 32 bytes
18-
let mut seed_bytes = [0u8; 32];
19-
seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
20-
StdRng::from_seed(seed_bytes)
21-
}
22-
2316
fn counting_benchmark(c: &mut Criterion) {
2417
for (name, bpe, _, _) in TOKENIZERS.iter() {
2518
let input = create_test_string(&bpe.bpe, 80_000);
@@ -107,10 +100,7 @@ fn encoding_benchmark(c: &mut Criterion) {
107100
|b, bytes| {
108101
b.iter_batched(
109102
|| select_test_string(&text, *bytes),
110-
|text| {
111-
bpe.bpe
112-
.encode_minimal_dropout(text.as_bytes(), 0.1, get_rng(0))
113-
},
103+
|text| bpe.bpe.encode_minimal_dropout(text.as_bytes(), 0.1, rng()),
114104
criterion::BatchSize::SmallInput,
115105
)
116106
},

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -558,17 +558,20 @@ impl BytePairEncoding {
558558
/// In more detail: the tokenization uses dynamic programming, i.e. it models the tokenization as a graph,
559559
/// where every position between text bytes is a node and two nodes are connected when the text slice between those two nodes matches a token.
560560
// It then tries to find the shortest possible path from the beginning of the text till the end, i.e. it finds the shortest possible encoding.
561-
// For this nodes are processed from right to left. At each node, edges starting at that node and ending on the right are tested and
562-
// the one producing the shortest path is stored together with the length of the shortest path to that node.
561+
// For this, nodes are processed from right to left. At each node, edges starting at that node and ending on the right are tested and
562+
// the one producing the shortest path is stored together with the length of the shortest path to that node.
563563
// The length of the shortest path is stored as second value, the edge (or rather token) is stored as first value.
564+
// Then, we walk in reverse direction through the table along the shortest path.
565+
// Note: the reason for constructing the table from back to front is that
566+
// the reconstruction outputs the path from start till end (i.e. we don't have to reverse the path afterwards).
564567
//
565568
// For the dropout (when dropout > 0.0), we uniformly drop edges from the graph, but always keep the one-byte tokens such that the graph stays connected.
566569
// Note: this is very different from how BPE works and cannot produce the same output as the algorithm
567570
// in the [paper's repository](https://github.com/VProv/BPE-Dropout/blob/master/bpe.py#L98), for two main reasons:
568571
// - `encode_minimal` already doesn't follow the original heap-based BPE procedure
569-
// - randomness source in dropout works differently in rust and python
570572
// - BPE-dropout authors discard all multi-byte tokens for each word separately, while this implementation does not split the "sentence" into words first
571573
// and hence may include previously discarded tokens later down the byte stream. At the sentence level, though, we don't expect it to make much difference.
574+
// Also, this implementation of BPE constructs merges on the fly from the set of tokens and hence might come up with a different set of merges for the same dictionary.
572575
#[cfg(feature = "rand")]
573576
pub fn encode_minimal_dropout<R: rand::Rng>(
574577
&self,

crates/bpe/tests/src/lib.rs

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#[cfg(test)]
22
mod tests {
3-
use std::time;
4-
53
use itertools::Itertools;
64
use rand::{rng, Rng};
75
use tiktoken_rs::cl100k_base_singleton;
@@ -157,31 +155,38 @@ mod tests {
157155
}
158156

159157
let bpe = &cl100k_base().bpe;
160-
for bytes in [10000, 20000] {
161-
for _ in 0..8 {
162-
let input = create_test_bytes(bpe, bytes);
163-
let encoded = bpe.encode_minimal(&input);
164-
let encoded_d_0_2 = bpe.encode_minimal_dropout(&input, 0.2, get_rng(0));
165-
let encoded_d_0_9 = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1));
166-
let encoded_d_1_0 = bpe.encode_minimal_dropout(&input, 1.0, get_rng(2));
167-
let decoded = bpe.decode_tokens(&encoded);
168-
let decoded_min = bpe.decode_tokens(&encoded_d_min);
169-
let decoded_max = bpe.decode_tokens(&encoded_d_max);
170-
let decoded_max_again = bpe.decode_tokens(&encoded_d_1_0);
171-
println!("Input length: {}, Encoded length: {}, Encoded with dropout length: {}-{}, max {}",
172-
input.len(), encoded.len(), encoded_d_min.len(), encoded_d_max.len(), encoded_d_1_0.len());
173-
assert_eq!(input, decoded);
174-
assert_eq!(input, decoded_min);
175-
assert_eq!(input, decoded_max);
176-
assert_eq!(input, decoded_max_again);
177-
assert_eq!(input.len(), encoded_d_1_0.len());
178-
assert!(encoded_d_min.len() >= encoded.len());
179-
assert!(encoded_d_max.len() > encoded.len());
180-
181-
assert_ne!(encoded, encoded_d_min);
182-
assert_ne!(encoded, encoded_d_max);
183-
assert_ne!(encoded_d_max, encoded_d_1_0);
184-
}
158+
let bytes = 10000;
159+
for _ in 0..8 {
160+
let input = create_test_bytes(bpe, bytes);
161+
let encoded = bpe.encode_minimal(&input);
162+
let encoded_d_0_2 = bpe.encode_minimal_dropout(&input, 0.2, get_rng(0));
163+
let encoded_d_0_9 = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1));
164+
let encoded_d_1_0 = bpe.encode_minimal_dropout(&input, 1.0, get_rng(1));
165+
let encoded_d_0_9_again = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1));
166+
let decoded = bpe.decode_tokens(&encoded);
167+
let decoded_min = bpe.decode_tokens(&encoded_d_0_2);
168+
let decoded_max = bpe.decode_tokens(&encoded_d_0_9);
169+
let decoded_max_again = bpe.decode_tokens(&encoded_d_0_9_again);
170+
println!(
171+
"Input length: {}, Encoded length: {}, Encoded with dropout length: {}-{}, max {}",
172+
input.len(),
173+
encoded.len(),
174+
encoded_d_0_2.len(),
175+
encoded_d_0_9.len(),
176+
encoded_d_0_9_again.len()
177+
);
178+
assert_eq!(encoded_d_0_9, encoded_d_0_9_again);
179+
assert_eq!(input, decoded);
180+
assert_eq!(input, decoded_min);
181+
assert_eq!(input, decoded_max);
182+
assert_eq!(input, decoded_max_again);
183+
assert_eq!(input.len(), encoded_d_1_0.len());
184+
assert!(encoded_d_0_2.len() >= encoded.len());
185+
assert!(encoded_d_0_9.len() > encoded.len());
186+
187+
assert_ne!(encoded, encoded_d_0_2);
188+
assert_ne!(encoded, encoded_d_0_9);
189+
assert_ne!(encoded_d_0_9, encoded_d_1_0);
185190
}
186191
}
187192
}

0 commit comments

Comments
 (0)