3535ClusteringTopicModel(n_reduce_to=10)
3636"""
3737
38- feature_message = """
39- feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroids '
38+ feature_message = """
39+ feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroid '
4040"""
4141
4242
@@ -68,14 +68,19 @@ def smallest_hierarchical_join(
6868
6969
7070def calculate_topic_vectors (
71- cluster_labels : np .ndarray , embeddings : np .ndarray
71+ cluster_labels : np .ndarray , embeddings : np .ndarray ,
72+ time_index : Optional [np .ndarray ] = None ,
7273) -> np .ndarray :
7374 """Calculates topic centroids."""
7475 centroids = []
7576 unique_labels = np .unique (cluster_labels )
7677 unique_labels = np .sort (unique_labels )
7778 for label in unique_labels :
78- centroid = np .mean (embeddings [cluster_labels == label ], axis = 0 )
79+ label_index = cluster_labels == label
80+ if time_index is not None :
81+ label_index = label_index * time_index
82+ label_embeddings = embeddings [label_index ]
83+ centroid = np .mean (label_embeddings , axis = 0 )
7984 centroids .append (centroid )
8085 centroids = np .stack (centroids )
8186 return centroids
@@ -169,6 +174,7 @@ def __init__(
169174 self .feature_importance = feature_importance
170175 self .n_reduce_to = n_reduce_to
171176 self .reduction_method = reduction_method
177+ self .components_ = None
172178
173179 def _merge_agglomerative (self , n_reduce_to : int ) -> np .ndarray :
174180 n_topics = self .components_ .shape [0 ]
@@ -309,11 +315,10 @@ def fit_transform_dynamic(
309315 if embeddings is None :
310316 embeddings = self .encoder_ .encode (raw_documents )
311317 for i_timebin in np .arange (len (self .time_bin_edges ) - 1 ):
312- if self .labels_ is not None :
318+ if self .components_ is not None :
313319 doc_topic_matrix = label_binarize (self .labels_ , classes = self .classes_ )
314320 else :
315321 doc_topic_matrix = self .fit_transform (raw_documents , embeddings = embeddings )
316-
317322 topic_importances = doc_topic_matrix [time_labels == i_timebin ].sum (axis = 0 )
318323 topic_importances = topic_importances / topic_importances .sum ()
319324 t_doc_term_matrix = self .doc_term_matrix [time_labels == i_timebin ]
@@ -326,17 +331,24 @@ def fit_transform_dynamic(
326331 )
327332 elif self .feature_importance == 'c-tf-idf' :
328333 components = ctf_idf (t_doc_topic_matrix , t_doc_term_matrix )
329- elif self .feature_importance == 'centroids' :
330- t_labels = self .labels_ [time_labels == i_timebin ]
331- t_embeddings = embeddings [time_labels == i_timebin ]
332- t_topic_vectors = calculate_topic_vectors (t_labels , t_embeddings )
334+ elif self .feature_importance == 'centroid' :
335+ time_index = time_labels == i_timebin
336+ t_topic_vectors = calculate_topic_vectors (
337+ self .labels_ , embeddings , time_index ,
338+ )
339+ topic_mask = np .isnan (t_topic_vectors ).all (
340+ axis = 1 , keepdims = True
341+ )
342+ t_topic_vectors [:] = 0
333343 components = cluster_centroid_distance (
334344 t_topic_vectors ,
335345 self .vocab_embeddings ,
336346 metric = "cosine" ,
337347 )
338- mask = t_doc_term_matrix .sum (axis = 0 )
339- components = components * mask
348+ components *= topic_mask
349+ mask_terms = t_doc_term_matrix .sum (axis = 0 ).astype (np .float64 )
350+ mask_terms [mask_terms == 0 ] = np .nan
351+ components *= mask_terms
340352 temporal_components .append (components )
341353 temporal_importances .append (topic_importances )
342354 self .temporal_components_ = np .stack (temporal_components )
0 commit comments