@@ -53,7 +53,9 @@ def smallest_hierarchical_join(
5353 classes = list (classes_ )
5454 while len (classes ) > n_to :
5555 smallest = np .argmin (topic_sizes )
56- dist = cosine_distances (np .atleast_2d (topic_vectors [smallest ]), topic_vectors )
56+ dist = cosine_distances (
57+ np .atleast_2d (topic_vectors [smallest ]), topic_vectors
58+ )
5759 closest = np .argsort (dist [0 ])[1 ]
5860 merge_inst .append ((classes [smallest ], classes [closest ]))
5961 classes .pop (smallest )
@@ -68,7 +70,8 @@ def smallest_hierarchical_join(
6870
6971
7072def calculate_topic_vectors (
71- cluster_labels : np .ndarray , embeddings : np .ndarray ,
73+ cluster_labels : np .ndarray ,
74+ embeddings : np .ndarray ,
7275 time_index : Optional [np .ndarray ] = None ,
7376) -> np .ndarray :
7477 """Calculates topic centroids."""
@@ -138,20 +141,22 @@ class ClusteringTopicModel(ContextualModel, ClusterMixin, DynamicTopicModel):
138141
139142 def __init__ (
140143 self ,
141- encoder : Union [Encoder , str ] = "sentence-transformers/all-MiniLM-L6-v2" ,
144+ encoder : Union [
145+ Encoder , str
146+ ] = "sentence-transformers/all-MiniLM-L6-v2" ,
142147 vectorizer : Optional [CountVectorizer ] = None ,
143148 dimensionality_reduction : Optional [TransformerMixin ] = None ,
144149 clustering : Optional [ClusterMixin ] = None ,
145150 feature_importance : Literal [
146151 "c-tf-idf" , "soft-c-tf-idf" , "centroid"
147152 ] = "soft-c-tf-idf" ,
148153 n_reduce_to : Optional [int ] = None ,
149- reduction_method : Literal ["agglomerative" , "smallest" ] = "agglomerative" ,
154+ reduction_method : Literal [
155+ "agglomerative" , "smallest"
156+ ] = "agglomerative" ,
150157 ):
151158 self .encoder = encoder
152- if feature_importance not in ["c-tf-idf" ,
153- "soft-c-tf-idf" ,
154- "centroid" ]:
159+ if feature_importance not in ["c-tf-idf" , "soft-c-tf-idf" , "centroid" ]:
155160 raise ValueError (feature_message )
156161 if isinstance (encoder , int ):
157162 raise TypeError (integer_message )
@@ -168,7 +173,9 @@ def __init__(
168173 else :
169174 self .clustering = clustering
170175 if dimensionality_reduction is None :
171- self .dimensionality_reduction = TSNE (n_components = 2 , metric = "cosine" )
176+ self .dimensionality_reduction = TSNE (
177+ n_components = 2 , metric = "cosine"
178+ )
172179 else :
173180 self .dimensionality_reduction = dimensionality_reduction
174181 self .feature_importance = feature_importance
@@ -225,7 +232,9 @@ def _estimate_parameters(
225232 self .vocab_embeddings = self .encoder_ .encode (
226233 self .vectorizer .get_feature_names_out ()
227234 ) # type: ignore
228- document_topic_matrix = label_binarize (self .labels_ , classes = self .classes_ )
235+ document_topic_matrix = label_binarize (
236+ self .labels_ , classes = self .classes_
237+ )
229238 if self .feature_importance == "soft-c-tf-idf" :
230239 self .components_ = soft_ctf_idf (document_topic_matrix , doc_term_matrix ) # type: ignore
231240 elif self .feature_importance == "centroid" :
@@ -266,7 +275,9 @@ def fit_predict(
266275 self .doc_term_matrix = self .vectorizer .fit_transform (raw_documents )
267276 console .log ("Term extraction done." )
268277 status .update ("Reducing Dimensionality" )
269- reduced_embeddings = self .dimensionality_reduction .fit_transform (embeddings )
278+ reduced_embeddings = self .dimensionality_reduction .fit_transform (
279+ embeddings
280+ )
270281 console .log ("Dimensionality reduction done." )
271282 status .update ("Clustering documents" )
272283 self .labels_ = self .clustering .fit_predict (reduced_embeddings )
@@ -279,7 +290,9 @@ def fit_predict(
279290 console .log ("Parameter estimation done." )
280291 if self .n_reduce_to is not None :
281292 n_topics = self .classes_ .shape [0 ]
282- status .update (f"Reducing topics from { n_topics } to { self .n_reduce_to } " )
293+ status .update (
294+ f"Reducing topics from { n_topics } to { self .n_reduce_to } "
295+ )
283296 if self .reduction_method == "agglomerative" :
284297 self .labels_ = self ._merge_agglomerative (self .n_reduce_to )
285298 else :
@@ -316,25 +329,32 @@ def fit_transform_dynamic(
316329 embeddings = self .encoder_ .encode (raw_documents )
317330 for i_timebin in np .arange (len (self .time_bin_edges ) - 1 ):
318331 if self .components_ is not None :
319- doc_topic_matrix = label_binarize (self .labels_ , classes = self .classes_ )
332+ doc_topic_matrix = label_binarize (
333+ self .labels_ , classes = self .classes_
334+ )
320335 else :
321- doc_topic_matrix = self .fit_transform (raw_documents , embeddings = embeddings )
322- topic_importances = doc_topic_matrix [time_labels == i_timebin ].sum (axis = 0 )
336+ doc_topic_matrix = self .fit_transform (
337+ raw_documents , embeddings = embeddings
338+ )
339+ topic_importances = doc_topic_matrix [time_labels == i_timebin ].sum (
340+ axis = 0
341+ )
323342 topic_importances = topic_importances / topic_importances .sum ()
324343 t_doc_term_matrix = self .doc_term_matrix [time_labels == i_timebin ]
325344 t_doc_topic_matrix = doc_topic_matrix [time_labels == i_timebin ]
326345 if "c-tf-idf" in self .feature_importance :
327- if self .feature_importance == ' soft-c-tf-idf' :
346+ if self .feature_importance == " soft-c-tf-idf" :
328347 components = soft_ctf_idf (
329- t_doc_topic_matrix ,
330- t_doc_term_matrix
348+ t_doc_topic_matrix , t_doc_term_matrix
331349 )
332- elif self .feature_importance == ' c-tf-idf' :
350+ elif self .feature_importance == " c-tf-idf" :
333351 components = ctf_idf (t_doc_topic_matrix , t_doc_term_matrix )
334- elif self .feature_importance == ' centroid' :
352+ elif self .feature_importance == " centroid" :
335353 time_index = time_labels == i_timebin
336354 t_topic_vectors = calculate_topic_vectors (
337- self .labels_ , embeddings , time_index ,
355+ self .labels_ ,
356+ embeddings ,
357+ time_index ,
338358 )
339359 topic_mask = np .isnan (t_topic_vectors ).all (
340360 axis = 1 , keepdims = True
0 commit comments