|
12 | 12 | from sklearn.preprocessing import label_binarize |
13 | 13 |
|
14 | 14 | from turftopic.base import ContextualModel, Encoder |
15 | | -from turftopic.dynamic import DynamicTopicModel, bin_timestamps |
| 15 | +from turftopic.dynamic import DynamicTopicModel |
16 | 16 | from turftopic.feature_importance import ( |
17 | 17 | cluster_centroid_distance, |
18 | 18 | ctf_idf, |
@@ -335,20 +335,26 @@ def fit_transform_dynamic( |
335 | 335 | embeddings: Optional[np.ndarray] = None, |
336 | 336 | bins: Union[int, list[datetime]] = 10, |
337 | 337 | ): |
338 | | - time_labels, self.time_bin_edges = bin_timestamps(timestamps, bins) |
339 | | - temporal_components = [] |
340 | | - temporal_importances = [] |
| 338 | + time_labels, self.time_bin_edges = self.bin_timestamps( |
| 339 | + timestamps, bins |
| 340 | + ) |
| 341 | + n_comp, n_vocab = self.components_.shape |
| 342 | + n_bins = len(self.time_bin_edges) - 1 |
| 343 | + if hasattr(self, "components_"): |
| 344 | + doc_topic_matrix = label_binarize( |
| 345 | + self.labels_, classes=self.classes_ |
| 346 | + ) |
| 347 | + else: |
| 348 | + doc_topic_matrix = self.fit_transform( |
| 349 | + raw_documents, embeddings=embeddings |
| 350 | + ) |
| 351 | + self.temporal_components = np.zeros( |
| 352 | + (n_bins, n_comp, n_vocab), dtype=doc_topic_matrix.dtype |
| 353 | + ) |
| 354 | + self.temporal_importance_ = np.zeros((n_bins, n_comp)) |
341 | 355 | if embeddings is None: |
342 | 356 | embeddings = self.encoder_.encode(raw_documents) |
343 | | - for i_timebin in np.arange(len(self.time_bin_edges) - 1): |
344 | | - if hasattr(self, "components_"): |
345 | | - doc_topic_matrix = label_binarize( |
346 | | - self.labels_, classes=self.classes_ |
347 | | - ) |
348 | | - else: |
349 | | - doc_topic_matrix = self.fit_transform( |
350 | | - raw_documents, embeddings=embeddings |
351 | | - ) |
| 357 | + for i_timebin in np.unique(time_labels): |
352 | 358 | topic_importances = doc_topic_matrix[time_labels == i_timebin].sum( |
353 | 359 | axis=0 |
354 | 360 | ) |
@@ -382,8 +388,6 @@ def fit_transform_dynamic( |
382 | 388 | mask_terms = t_doc_term_matrix.sum(axis=0).astype(np.float64) |
383 | 389 | mask_terms[mask_terms == 0] = np.nan |
384 | 390 | components *= mask_terms |
385 | | - temporal_components.append(components) |
386 | | - temporal_importances.append(topic_importances) |
387 | | - self.temporal_components_ = np.stack(temporal_components) |
388 | | - self.temporal_importance_ = np.stack(temporal_importances) |
| 391 | + self.temporal_components_[i_timebin] = components |
| 392 | + self.temporal_importance_[i_timebin].append(topic_importances) |
389 | 393 | return doc_topic_matrix |
0 commit comments