@@ -18,43 +18,63 @@ def batched(iterable, n: int) -> Iterable[List[str]]:
1818
def encode_chunks(
    encoder,
    texts,
    batch_size=64,
    window_size=50,
    step_size=40,
):
    """Encode documents into overlapping token-window chunk embeddings.

    Each document in ``texts`` is tokenized and run through ``encoder`` in
    batches; its token embeddings are then split into windows of
    ``window_size`` tokens starting every ``step_size`` tokens, and each
    window is mean-pooled into one chunk embedding.

    Parameters
    ----------
    encoder
        Model exposing ``tokenize``, ``forward`` and ``tokenizer``
        (assumes a SentenceTransformer-style interface — confirm against
        the caller).
    texts : Sequence[str]
        Raw document strings.
    batch_size : int
        Number of documents encoded per forward pass.
    window_size : int
        Chunk width in tokens.
    step_size : int
        Stride between consecutive chunk starts, in tokens.

    Returns
    -------
    chunk_embeddings : list[np.ndarray]
        One ``(n_chunks, dim)`` embedding matrix per document.
    chunk_positions : list[list[tuple[int, int]]]
        ``(start_char, end_char)`` of each chunk within its document.
    """
    chunk_positions = []
    chunk_embeddings = []
    for start_index in trange(
        0,
        len(texts),
        batch_size,
        desc="Encoding batches...",
    ):
        batch = texts[start_index : start_index + batch_size]
        features = encoder.tokenize(batch)
        with torch.no_grad():
            output_features = encoder.forward(features)
        token_embeddings = output_features["token_embeddings"]
        attention_mask = output_features["attention_mask"]
        n_tokens = attention_mask.sum(axis=1)
        # Find the first non-padding token in each document: sequences may be
        # left-padded, so the first attended position is not necessarily 0.
        start_token = torch.argmax((attention_mask > 0).to(torch.long), axis=1)
        end_token = start_token + n_tokens
        for i_doc in range(len(batch)):
            _chunk_embeddings = []
            _chunk_positions = []
            doc_text = texts[start_index + i_doc]
            for chunk_start in range(
                start_token[i_doc], end_token[i_doc], step_size
            ):
                chunk_end = min(chunk_start + window_size, end_token[i_doc])
                _emb = token_embeddings[i_doc, chunk_start:chunk_end, :].mean(
                    axis=0
                )
                _chunk_embeddings.append(_emb)
                chunk_text = (
                    encoder.tokenizer.decode(
                        features["input_ids"][i_doc, chunk_start:chunk_end],
                        skip_special_tokens=True,
                    )
                    .replace("[CLS]", "")
                    .replace("[SEP]", "")
                    .strip()
                )
                # NOTE(review): decoding may normalize whitespace/characters,
                # in which case find() returns -1 and the span is bogus —
                # consider a tokenizer offset mapping for exact positions.
                start_char = doc_text.find(chunk_text)
                end_char = start_char + len(chunk_text)
                _chunk_positions.append((start_char, end_char))
            if _chunk_embeddings:
                _chunk_embeddings = np.stack(_chunk_embeddings)
            else:
                # An empty/fully-padded document yields no chunks; np.stack
                # raises on an empty list, so emit an empty (0, dim) matrix
                # to keep one entry per document.
                _chunk_embeddings = np.empty((0, token_embeddings.shape[-1]))
            chunk_embeddings.append(_chunk_embeddings)
            chunk_positions.append(_chunk_positions)
    return chunk_embeddings, chunk_positions
0 commit comments