 from sentence_transformers import SentenceTransformer
 from sklearn.base import TransformerMixin
 from sklearn.preprocessing import normalize
-from tokenizers import Tokenizer
 from tqdm import trange

+from turftopic.base import ContextualModel
+
 Offsets = list[tuple[int, int]]
 Lengths = list[int]

@@ -39,13 +40,11 @@ def _encode_tokens(
         """
         token_embeddings = []
         offsets = []
-        tokenizer = Tokenizer.from_pretrained(self.model_card_data.base_model)
         for start_index in trange(
             0,
             len(texts),
             batch_size,
-            disable=not show_progress_bar,
-            desc="Encoding tokens...",
+            desc="Encoding batches...",
         ):
             batch = texts[start_index : start_index + batch_size]
             features = self.tokenize(batch)
@@ -59,10 +58,18 @@ def _encode_tokens(
             )
             end_token = start_token + n_tokens
             for i_doc in range(len(batch)):
-                _token_embeddings = output_features["token_embeddings"][
-                    i_doc, start_token[i_doc] : end_token[i_doc], :
-                ].numpy(force=True)
-                _offsets = tokenizer.encode(batch[i_doc]).offsets
+                _token_embeddings = (
+                    output_features["token_embeddings"][
+                        i_doc, start_token[i_doc] : end_token[i_doc], :
+                    ]
+                    .float()
+                    .numpy(force=True)
+                )
+                _n = _token_embeddings.shape[0]
+                # Extract the character offsets and prune them at the maximum context length
+                _offsets = self.tokenizer(
+                    batch[i_doc], return_offsets_mapping=True, verbose=False
+                )["offset_mapping"][:_n]
                 token_embeddings.append(_token_embeddings)
                 offsets.append(_offsets)
         return token_embeddings, offsets
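The removed lines built a standalone `tokenizers.Tokenizer` whose offsets could run past the encoder's maximum context length; the replacement reuses the model's own tokenizer and prunes the offset mapping to `_n`, the number of token embeddings the encoder actually returned. A minimal sketch of that alignment, assuming a Hugging Face fast tokenizer (the checkpoint name is illustrative):

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any fast tokenizer supports offset mappings.
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    text = "Late interaction topic modeling."
    offsets = tokenizer(text, return_offsets_mapping=True, verbose=False)["offset_mapping"]

    # Suppose the encoder only returned embeddings for the first _n tokens
    # because the document was truncated at the maximum context length:
    _n = 4
    for start, end in offsets[:_n]:
        print(repr(text[start:end]))  # each pruned offset still indexes the raw string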
@@ -145,11 +152,12 @@ def encode_windows(
         for emb, offs in zip(token_embeddings, token_offsets):
             _offsets = []
             _embeddings = []
-            for start_index in trange(0, len(emb), step_size):
+            for start_index in range(0, len(emb), step_size):
                 end_index = start_index + window_size
                 window_emb = np.mean(emb[start_index:end_index], axis=0)
+                off = offs[start_index:end_index]
                 _embeddings.append(window_emb)
-                _offsets.append((offs[start_index][0], offs[end_index][1]))
+                _offsets.append((off[0][0], off[-1][1]))
             window_embeddings.append(normalize(np.stack(_embeddings)))
             window_offsets.append(_offsets)
         return window_embeddings, window_offsets
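Swapping `offs[end_index]` for a slice also fixes an off-by-one: whenever the last window is shorter than `window_size`, `offs[end_index]` is out of bounds, while `off[-1]` is always the last token actually inside the window. A toy sketch of the pooling loop, with hypothetical window parameters:

    import numpy as np
    from sklearn.preprocessing import normalize

    rng = np.random.default_rng(0)
    emb = rng.normal(size=(10, 4))                  # 10 toy token embeddings
    offs = [(i * 5, i * 5 + 4) for i in range(10)]  # toy character offsets
    window_size, step_size = 4, 3                   # hypothetical parameters

    _embeddings, _offsets = [], []
    for start_index in range(0, len(emb), step_size):
        end_index = start_index + window_size
        window_emb = np.mean(emb[start_index:end_index], axis=0)
        off = offs[start_index:end_index]           # tail window may be short
        _embeddings.append(window_emb)
        _offsets.append((off[0][0], off[-1][1]))

    window_embeddings = normalize(np.stack(_embeddings))
    print(_offsets)  # the final span ends at the last real token, no IndexError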
@@ -197,7 +205,7 @@ def unflatten_repr(
     repr = []
     start_index = 0
     for length in lengths:
-        repr.append(flat_repr[start_index : length])
+        repr.append(flat_repr[start_index : start_index + length])
         start_index += length
     return repr

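The slicing fix in `unflatten_repr` is not cosmetic: `flat_repr[start_index : length]` yields empty slices as soon as `start_index >= length`, silently dropping documents. A quick check of the corrected behavior on assumed toy input:

    import numpy as np

    flat_repr = np.arange(12).reshape(6, 2)  # six chunk representations
    lengths = [2, 1, 3]                      # chunks per document

    repr_, start_index = [], 0
    for length in lengths:
        repr_.append(flat_repr[start_index : start_index + length])
        start_index += length

    print([r.shape[0] for r in repr_])  # [2, 1, 3]; the old slicing gave [2, 0, 0]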
@@ -217,11 +225,11 @@ def get_document_chunks(
     chunks = []
     for doc, _offs in zip(raw_documents, offsets):
         for start_char, end_char in _offs:
-            chunks.append(raw_documents[start_char, end_char])
+            chunks.append(doc[start_char:end_char])
     return chunks


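Likewise, `raw_documents[start_char, end_char]` indexed the list of documents with a tuple, which raises a `TypeError`; the fix slices the current document string instead. For example:

    doc = "Topic models describe themes in text."
    offsets = [(0, 12), (13, 37)]

    chunks = [doc[start_char:end_char] for start_char, end_char in offsets]
    print(chunks)  # ['Topic models', 'describe themes in text.']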
-class LateModel(TransformerMixin):
+class LateWrapper(ContextualModel, TransformerMixin):
     def __init__(
         self,
         model: TransformerMixin,
@@ -236,7 +244,7 @@ def __init__(
         self.window_size = window_size
         self.step_size = step_size

-    def encode_documents(
+    def encode_late(
         self, raw_documents: list[str]
     ) -> tuple[np.ndarray, list[Offsets]]:
         if self.window_size is None:
@@ -264,7 +272,7 @@ def transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.encode_documents(raw_documents)
+            embeddings, offsets = self.encode_late(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.transform(chunks, embeddings=flat_embeddings)
@@ -281,7 +289,7 @@ def fit_transform(
         offsets: list[Offsets] = None,
     ):
         if (embeddings is None) or (offsets is None):
-            embeddings, offsets = self.encode_documents(raw_documents)
+            embeddings, offsets = self.encode_late(raw_documents)
         flat_embeddings, lengths = flatten_repr(embeddings)
         chunks = get_document_chunks(raw_documents, offsets)
         out_array = self.model.fit_transform(
@@ -291,3 +299,23 @@ def fit_transform(
             return unflatten_repr(out_array, lengths)
         else:
             return pool_flat(out_array, lengths)
+
+    @property
+    def components_(self):
+        return self.model.components_
+
+    @property
+    def hierarchy(self):
+        return self.model.hierarchy
+
+    @property
+    def topic_names(self):
+        return self.model.topic_names
+
+    @property
+    def classes_(self):
+        return self.model.classes_
+
+    @property
+    def vectorizer(self):
+        return self.model.vectorizer
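With the rename to `LateWrapper` and the forwarded `components_`, `topic_names`, `classes_`, `hierarchy`, and `vectorizer` properties, the wrapper can stand in for the wrapped model in the usual turftopic workflow. A hedged usage sketch; the import path of `LateWrapper` and the window parameters are assumptions, not confirmed by this diff:

    from sentence_transformers import SentenceTransformer
    from turftopic import KeyNMF
    from turftopic.late import LateWrapper  # assumed module location

    encoder = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative encoder
    model = KeyNMF(10, encoder=encoder)
    wrapper = LateWrapper(model, window_size=50, step_size=40)  # hypothetical sizes

    corpus = ["first document ...", "second document ..."]
    doc_topic = wrapper.fit_transform(corpus)  # pooled per-document topic scores
    print(wrapper.topic_names)  # forwarded from the wrapped model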