Skip to content

Commit 2e35ee8

Browse files
Fixed dynamic functionality in all models
1 parent 7d0deab commit 2e35ee8

5 files changed

Lines changed: 38 additions & 30 deletions

File tree

turftopic/dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -314,7 +314,7 @@ def plot_topics_over_time(self, top_k: int = 6):
314314
continue
315315
high = high[np.argsort(-values)]
316316
name_over_time.append(", ".join(vocab[high]))
317-
times = self.time_bin_edges[1:]
317+
times = self.time_bin_edges[:-1]
318318
fig.add_trace(
319319
go.Scatter(
320320
x=times,

turftopic/models/_keynmf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -254,7 +254,7 @@ def fit_transform_dynamic(
254254
time_bin_edges: list[datetime],
255255
) -> np.ndarray:
256256
self.time_bin_edges = time_bin_edges
257-
n_bins = len(time_bin_edges) + 1
257+
n_bins = len(time_bin_edges) - 1
258258
document_term_matrix = self.vectorize(keywords, fitting=True)
259259
check_non_negative(document_term_matrix, "NMF (input X)")
260260
document_topic_matrix, H = _initialize_nmf(

turftopic/models/cluster.py

Lines changed: 21 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from sklearn.preprocessing import label_binarize
1313

1414
from turftopic.base import ContextualModel, Encoder
15-
from turftopic.dynamic import DynamicTopicModel, bin_timestamps
15+
from turftopic.dynamic import DynamicTopicModel
1616
from turftopic.feature_importance import (
1717
cluster_centroid_distance,
1818
ctf_idf,
@@ -335,20 +335,26 @@ def fit_transform_dynamic(
335335
embeddings: Optional[np.ndarray] = None,
336336
bins: Union[int, list[datetime]] = 10,
337337
):
338-
time_labels, self.time_bin_edges = bin_timestamps(timestamps, bins)
339-
temporal_components = []
340-
temporal_importances = []
338+
time_labels, self.time_bin_edges = self.bin_timestamps(
339+
timestamps, bins
340+
)
341+
n_comp, n_vocab = self.components_.shape
342+
n_bins = len(self.time_bin_edges) - 1
343+
if hasattr(self, "components_"):
344+
doc_topic_matrix = label_binarize(
345+
self.labels_, classes=self.classes_
346+
)
347+
else:
348+
doc_topic_matrix = self.fit_transform(
349+
raw_documents, embeddings=embeddings
350+
)
351+
self.temporal_components = np.zeros(
352+
(n_bins, n_comp, n_vocab), dtype=doc_topic_matrix.dtype
353+
)
354+
self.temporal_importance_ = np.zeros((n_bins, n_comp))
341355
if embeddings is None:
342356
embeddings = self.encoder_.encode(raw_documents)
343-
for i_timebin in np.arange(len(self.time_bin_edges) - 1):
344-
if hasattr(self, "components_"):
345-
doc_topic_matrix = label_binarize(
346-
self.labels_, classes=self.classes_
347-
)
348-
else:
349-
doc_topic_matrix = self.fit_transform(
350-
raw_documents, embeddings=embeddings
351-
)
357+
for i_timebin in np.unique(time_labels):
352358
topic_importances = doc_topic_matrix[time_labels == i_timebin].sum(
353359
axis=0
354360
)
@@ -382,8 +388,6 @@ def fit_transform_dynamic(
382388
mask_terms = t_doc_term_matrix.sum(axis=0).astype(np.float64)
383389
mask_terms[mask_terms == 0] = np.nan
384390
components *= mask_terms
385-
temporal_components.append(components)
386-
temporal_importances.append(topic_importances)
387-
self.temporal_components_ = np.stack(temporal_components)
388-
self.temporal_importance_ = np.stack(temporal_importances)
391+
self.temporal_components_[i_timebin] = components
392+
self.temporal_importance_[i_timebin].append(topic_importances)
389393
return doc_topic_matrix

turftopic/models/gmm.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from sklearn.pipeline import Pipeline, make_pipeline
1111

1212
from turftopic.base import ContextualModel, Encoder
13-
from turftopic.dynamic import DynamicTopicModel, bin_timestamps
13+
from turftopic.dynamic import DynamicTopicModel
1414
from turftopic.feature_importance import soft_ctf_idf
1515
from turftopic.vectorizer import default_vectorizer
1616

@@ -168,7 +168,9 @@ def fit_transform_dynamic(
168168
embeddings: Optional[np.ndarray] = None,
169169
bins: Union[int, list[datetime]] = 10,
170170
):
171-
time_labels, self.time_bin_edges = bin_timestamps(timestamps, bins)
171+
time_labels, self.time_bin_edges = self.bin_timestamps(
172+
timestamps, bins
173+
)
172174
if hasattr(self, "components_"):
173175
doc_topic_matrix = self.transform(
174176
raw_documents, embeddings=embeddings
@@ -178,9 +180,13 @@ def fit_transform_dynamic(
178180
raw_documents, embeddings=embeddings
179181
)
180182
document_term_matrix = self.vectorizer.transform(raw_documents)
181-
temporal_components = []
182-
temporal_importances = []
183-
for i_timebin in np.arange(len(self.time_bin_edges) - 1):
183+
n_comp, n_vocab = self.components_.shape
184+
n_bins = len(self.time_bin_edges) - 1
185+
self.temporal_components = np.zeros(
186+
(n_bins, n_comp, n_vocab), dtype=document_term_matrix.dtype
187+
)
188+
self.temporal_importance_ = np.zeros((n_bins, n_comp))
189+
for i_timebin in np.unique(time_labels):
184190
topic_importances = doc_topic_matrix[time_labels == i_timebin].sum(
185191
axis=0
186192
)
@@ -190,8 +196,6 @@ def fit_transform_dynamic(
190196
doc_topic_matrix[time_labels == i_timebin],
191197
document_term_matrix[time_labels == i_timebin], # type: ignore
192198
)
193-
temporal_components.append(components)
194-
temporal_importances.append(topic_importances)
195-
self.temporal_components_ = np.stack(temporal_components)
196-
self.temporal_importance_ = np.stack(temporal_importances)
199+
self.temporal_components_[i_timebin] = components
200+
self.temporal_importance_[i_timebin].append(topic_importances)
197201
return doc_topic_matrix

turftopic/models/keynmf.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
from turftopic.base import ContextualModel, Encoder
1111
from turftopic.data import TopicData
12-
from turftopic.dynamic import DynamicTopicModel, bin_timestamps
12+
from turftopic.dynamic import DynamicTopicModel
1313
from turftopic.models._keynmf import KeywordExtractor, KeywordNMF
1414

1515

@@ -300,7 +300,7 @@ def partial_fit_dynamic(
300300
)
301301
else:
302302
self.time_bin_edges = bins
303-
time_labels, self.time_bin_edges = bin_timestamps(
303+
time_labels, self.time_bin_edges = self.bin_timestamps(
304304
timestamps, self.time_bin_edges
305305
)
306306
if keywords is None:

0 commit comments

Comments
 (0)