Skip to content

Commit a9a015e

Browse files
committed
NA centroids and terms when no documents are available for a topic within a time bin
1 parent 389f3e9 commit a9a015e

1 file changed

Lines changed: 24 additions & 12 deletions

File tree

turftopic/models/cluster.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
ClusteringTopicModel(n_reduce_to=10)
3636
"""
3737

38-
feature_message = """
39-
feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroids'
38+
feature_message = """
39+
feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroid'
4040
"""
4141

4242

@@ -68,14 +68,19 @@ def smallest_hierarchical_join(
6868

6969

7070
def calculate_topic_vectors(
71-
cluster_labels: np.ndarray, embeddings: np.ndarray
71+
cluster_labels: np.ndarray, embeddings: np.ndarray,
72+
time_index: Optional[np.ndarray] = None,
7273
) -> np.ndarray:
7374
"""Calculates topic centroids."""
7475
centroids = []
7576
unique_labels = np.unique(cluster_labels)
7677
unique_labels = np.sort(unique_labels)
7778
for label in unique_labels:
78-
centroid = np.mean(embeddings[cluster_labels == label], axis=0)
79+
label_index = cluster_labels == label
80+
if time_index is not None:
81+
label_index = label_index * time_index
82+
label_embeddings = embeddings[label_index]
83+
centroid = np.mean(label_embeddings, axis=0)
7984
centroids.append(centroid)
8085
centroids = np.stack(centroids)
8186
return centroids
@@ -169,6 +174,7 @@ def __init__(
169174
self.feature_importance = feature_importance
170175
self.n_reduce_to = n_reduce_to
171176
self.reduction_method = reduction_method
177+
self.components_ = None
172178

173179
def _merge_agglomerative(self, n_reduce_to: int) -> np.ndarray:
174180
n_topics = self.components_.shape[0]
@@ -309,11 +315,10 @@ def fit_transform_dynamic(
309315
if embeddings is None:
310316
embeddings = self.encoder_.encode(raw_documents)
311317
for i_timebin in np.arange(len(self.time_bin_edges) - 1):
312-
if self.labels_ is not None:
318+
if self.components_ is not None:
313319
doc_topic_matrix = label_binarize(self.labels_, classes=self.classes_)
314320
else:
315321
doc_topic_matrix = self.fit_transform(raw_documents, embeddings=embeddings)
316-
317322
topic_importances = doc_topic_matrix[time_labels == i_timebin].sum(axis=0)
318323
topic_importances = topic_importances / topic_importances.sum()
319324
t_doc_term_matrix = self.doc_term_matrix[time_labels == i_timebin]
@@ -326,17 +331,24 @@ def fit_transform_dynamic(
326331
)
327332
elif self.feature_importance == 'c-tf-idf':
328333
components = ctf_idf(t_doc_topic_matrix, t_doc_term_matrix)
329-
elif self.feature_importance == 'centroids':
330-
t_labels = self.labels_[time_labels == i_timebin]
331-
t_embeddings = embeddings[time_labels == i_timebin]
332-
t_topic_vectors = calculate_topic_vectors(t_labels, t_embeddings)
334+
elif self.feature_importance == 'centroid':
335+
time_index = time_labels == i_timebin
336+
t_topic_vectors = calculate_topic_vectors(
337+
self.labels_, embeddings, time_index,
338+
)
339+
topic_mask = np.isnan(t_topic_vectors).all(
340+
axis=1, keepdims=True
341+
)
342+
t_topic_vectors[:] = 0
333343
components = cluster_centroid_distance(
334344
t_topic_vectors,
335345
self.vocab_embeddings,
336346
metric="cosine",
337347
)
338-
mask = t_doc_term_matrix.sum(axis=0)
339-
components = components * mask
348+
components *= topic_mask
349+
mask_terms = t_doc_term_matrix.sum(axis=0).astype(np.float64)
350+
mask_terms[mask_terms == 0] = np.nan
351+
components *= mask_terms
340352
temporal_components.append(components)
341353
temporal_importances.append(topic_importances)
342354
self.temporal_components_ = np.stack(temporal_components)

0 commit comments

Comments
 (0)