-from typing import Callable, Optional
+import itertools
+import warnings
+from typing import Callable, Iterable, Optional, Union
 
 import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
 from sklearn.base import TransformerMixin
+from sklearn.preprocessing import normalize
+from tokenizers import Tokenizer
+from tqdm import trange
 
-from turftopic.encoders.contextual import Offsets
-
+Offsets = list[tuple[int, int]]
 Lengths = list[int]
 
 
+class LateSentenceTransformer(SentenceTransformer):
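+    """SentenceTransformer subclass that, in addition to regular document
+    embeddings, produces contextual token- and window-level embeddings
+    together with character offsets into the original documents."""
+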
+    def encode(
+        self, sentences: Union[str, list[str], np.ndarray], *args, **kwargs
+    ):
+        warnings.warn(
+            "Encoder is contextual but topic model is not using contextual embeddings. Perhaps you wanted to use another topic model."
+        )
+        return super().encode(sentences, *args, **kwargs)
+
+    def _encode_tokens(
+        self,
+        texts,
+        batch_size=32,
+        show_progress_bar=True,
+    ) -> tuple[list[np.ndarray], list[Offsets]]:
+        """
+        Returns
+        -------
+        token_embeddings: list[np.ndarray]
+            Embedding matrix of tokens for each document.
+        offsets: list[list[tuple[int, int]]]
+            Start and end character of each token in each document.
+        """
+        token_embeddings = []
+        offsets = []
+        tokenizer = Tokenizer.from_pretrained(self.model_card_data.base_model)
+        for start_index in trange(
+            0,
+            len(texts),
+            batch_size,
+            disable=not show_progress_bar,
+            desc="Encoding tokens...",
+        ):
+            batch = texts[start_index : start_index + batch_size]
+            features = self.tokenize(batch)
+            with torch.no_grad():
+                output_features = self.forward(features)
+            n_tokens = output_features["attention_mask"].sum(axis=1)
+            # Find the first nonzero element in each document's attention mask.
+            # The document could be padded from the left, so we have to watch out for this.
+            start_token = torch.argmax(
+                (output_features["attention_mask"] > 0).to(torch.long), axis=1
+            )
+            end_token = start_token + n_tokens
+            for i_doc in range(len(batch)):
+                _token_embeddings = output_features["token_embeddings"][
+                    i_doc, start_token[i_doc] : end_token[i_doc], :
+                ].numpy(force=True)
+                _offsets = tokenizer.encode(batch[i_doc]).offsets
+                token_embeddings.append(_token_embeddings)
+                offsets.append(_offsets)
+        return token_embeddings, offsets
+
+    def encode_tokens(
+        self,
+        sentences: list[str],
+        batch_size: int = 32,
+        show_progress_bar: bool = True,
+    ):
+        """Produces contextual token embeddings over all documents.
+
+        Parameters
+        ----------
+        sentences: list[str]
+            Documents to encode contextually.
+        batch_size: int, default 32
+            Size of the batch of documents to encode at once.
+        show_progress_bar: bool, default True
+            Indicates whether a progress bar should be displayed when encoding.
+
+        Returns
+        -------
+        token_embeddings: list[np.ndarray]
+            Embedding matrix of tokens for each document.
+        offsets: list[list[tuple[int, int]]]
+            Start and end character of each token in each document.
+        """
+        # This is needed because _encode_tokens does not normalize embeddings,
+        # which normally happens to document embeddings.
+        token_embeddings, offsets = self._encode_tokens(
+            sentences,
+            batch_size=batch_size,
+            show_progress_bar=show_progress_bar,
+        )
+        token_embeddings = [normalize(emb) for emb in token_embeddings]
+        return token_embeddings, offsets
+
+    def encode_windows(
+        self,
+        sentences: list[str],
+        batch_size: int = 32,
+        window_size: int = 50,
+        step_size: int = 40,
+        show_progress_bar: bool = True,
+    ):
+        """Produces contextual embeddings for a sliding window of tokens, similar to C-Top2Vec.
+
+        Parameters
+        ----------
+        sentences: list[str]
+            Documents to encode contextually.
+        batch_size: int, default 32
+            Size of the batch of documents to encode at once.
+        window_size: int, default 50
+            Size of the sliding window.
+        step_size: int, default 40
+            Step size of the window.
+            If step_size < window_size, windows will overlap.
+            If step_size == window_size, the windows are adjacent and do not overlap.
+            If step_size > window_size, there will be gaps between the windows.
+            In this case, we throw a warning, as this is probably unintended behaviour.
+        show_progress_bar: bool, default True
+            Indicates whether a progress bar should be displayed when encoding.
+
+        Returns
+        -------
+        window_embeddings: list[np.ndarray]
+            Embedding matrix of windows for each document.
+        offsets: list[list[tuple[int, int]]]
+            Start and end character of each window in each document.
+        """
+        if step_size > window_size:
+            # Warn about gaps between the windows, as documented above.
+            warnings.warn(
+                "step_size is larger than window_size, so some tokens will not be covered by any window."
+            )
+        token_embeddings, token_offsets = self._encode_tokens(
+            sentences,
+            batch_size=batch_size,
+            show_progress_bar=show_progress_bar,
+        )
+        window_embeddings = []
+        window_offsets = []
+        for emb, offs in zip(token_embeddings, token_offsets):
+            _offsets = []
+            _embeddings = []
+            for start_index in trange(0, len(emb), step_size):
+                end_index = start_index + window_size
+                window_emb = np.mean(emb[start_index:end_index], axis=0)
+                _embeddings.append(window_emb)
+                # Clip to the last available token so that the final window
+                # does not index past the end of the offsets.
+                end_char = offs[min(end_index, len(offs)) - 1][1]
+                _offsets.append((offs[start_index][0], end_char))
+            window_embeddings.append(normalize(np.stack(_embeddings)))
+            window_offsets.append(_offsets)
+        return window_embeddings, window_offsets
+
+
 def flatten_repr(
     repr: list[np.ndarray],
 ) -> tuple[np.ndarray, Lengths]:
@@ -78,12 +225,37 @@ class LateModel(TransformerMixin):
     def __init__(
         self,
         model: TransformerMixin,
-        batch_size: int = 32,
+        batch_size: Optional[int] = 32,
+        window_size: Optional[int] = None,
+        step_size: Optional[int] = None,
         pooling: Optional[Callable] = None,
     ):
         self.model = model
         self.batch_size = batch_size
         self.pooling = pooling
+        self.window_size = window_size
+        self.step_size = step_size
+
+    def encode_documents(
+        self, raw_documents: list[str]
+    ) -> tuple[np.ndarray, list[Offsets]]:
+        if self.window_size is None:
+            embeddings, offsets = self.model.encoder.encode_tokens(
+                raw_documents, batch_size=self.batch_size
+            )
+            return embeddings, offsets
+        # If window_size is specified but step_size is not, we set the step size
+        # to the window size, thereby getting non-overlapping windows.
+        step_size = (
+            self.window_size if self.step_size is None else self.step_size
+        )
+        embeddings, offsets = self.model.encoder.encode_windows(
+            raw_documents,
+            batch_size=self.batch_size,
+            window_size=self.window_size,
+            step_size=step_size,
+        )
+        return embeddings, offsets
 
     def transform(
         self,
@@ -92,9 +264,7 @@ def transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.model.encoder.encode_tokens(
-                raw_documents, batch_size=self.batch_size
-            )
+            embeddings, offsets = self.encode_documents(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.transform(chunks, embeddings=flat_embeddings)
@@ -111,9 +281,7 @@ def fit_transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.model.encoder.encode_tokens(
-                raw_documents, batch_size=self.batch_size
-            )
+            embeddings, offsets = self.encode_documents(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.fit_transform(
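
For orientation, a minimal usage sketch of the classes touched by this diff follows. The checkpoint name and the `KeyNMF` wiring are illustrative assumptions rather than part of the change; `LateModel.encode_documents` only requires that the wrapped model expose the contextual encoder as `model.encoder`.

```python
# Hedged sketch: the model checkpoint and KeyNMF wiring are assumptions for illustration.
from turftopic import KeyNMF

docs = ["The first document.", "Another, somewhat longer document."]

encoder = LateSentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Token-level contextual embeddings plus character offsets per document.
token_embeddings, token_offsets = encoder.encode_tokens(docs)

# Sliding-window embeddings; windows overlap here because step_size < window_size.
window_embeddings, window_offsets = encoder.encode_windows(
    docs, window_size=50, step_size=40
)

# LateModel reads the encoder from `model.encoder` and picks encode_tokens or
# encode_windows depending on whether window_size is set.
topic_model = KeyNMF(10, encoder=encoder)
late_model = LateModel(topic_model, window_size=50, step_size=40)
doc_topic_matrix = late_model.fit_transform(docs)
```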