|
1 | 1 | import itertools |
2 | 2 | import json |
3 | 3 | import random |
| 4 | +from datetime import datetime |
4 | 5 | from typing import Dict, Iterable, List, Optional, Union |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 | from rich.console import Console |
8 | 9 | from sentence_transformers import SentenceTransformer |
9 | 10 | from sklearn.decomposition import NMF, MiniBatchNMF |
| 11 | +from sklearn.decomposition._nmf import (_initialize_nmf, |
| 12 | + _update_coordinate_descent) |
10 | 13 | from sklearn.exceptions import NotFittedError |
11 | 14 | from sklearn.feature_extraction import DictVectorizer |
12 | 15 | from sklearn.feature_extraction.text import CountVectorizer |
13 | 16 | from sklearn.metrics.pairwise import cosine_similarity |
| 17 | +from sklearn.utils import check_array |
14 | 18 |
|
15 | 19 | from turftopic.base import ContextualModel, Encoder |
16 | 20 | from turftopic.data import TopicData |
| 21 | +from turftopic.dynamic import DynamicTopicModel, bin_timestamps |
17 | 22 | from turftopic.vectorizer import default_vectorizer |
18 | 23 |
|
19 | 24 |
|
def fit_timeslice(
    X,
    W,
    H,
    tol=1e-4,
    max_iter=200,
    l1_reg_W=0,
    l1_reg_H=0,
    l2_reg_W=0,
    l2_reg_H=0,
    verbose=0,
    shuffle=False,
    random_state=None,
):
    """Fit the topic-term matrix ``H`` by coordinate descent while keeping
    the precomputed document-topic matrix ``W`` fixed.

    This is used to get temporal components in dynamic KeyNMF: the global
    document-topic assignments stay frozen and only the per-time-slice
    term loadings are re-estimated.

    Parameters
    ----------
    X: Document-term matrix of the time slice.
    W: Fixed document-topic matrix; never updated here.
    H: Initial topic-term matrix; refitted and returned.
    tol: Early-stopping threshold on the relative violation.
    max_iter: Maximum number of coordinate-descent sweeps.
    l1_reg_W, l2_reg_W: Accepted for signature parity with sklearn's
        solver; unused because W is not updated.
    l1_reg_H, l2_reg_H: L1/L2 regularization applied to the H update.
    verbose: If truthy, print convergence progress.
    shuffle: Whether to randomize coordinate order each sweep.
    random_state: Seed for the shuffling RNG; ``None`` reuses numpy's
        global ``RandomState`` singleton (same fallback sklearn uses).

    Returns
    -------
    (W, H, n_iter): the unchanged W, the fitted H, and the number of
    coordinate-descent sweeps actually performed.
    """
    # sklearn's Cython kernel updates the factor passed as second argument,
    # so H is handed over transposed and C-ordered.
    Ht = check_array(H.T, order="C")
    if random_state is None:
        rng = np.random.mtrand._rand
    else:
        rng = np.random.RandomState(random_state)
    # Pre-initialize so the return statement (and the first ratio check)
    # are well-defined even when max_iter < 1 leaves the loop body unrun.
    n_iter = 0
    violation_init = 0.0
    for n_iter in range(1, max_iter + 1):
        # Single-sided update: only H moves, unlike sklearn's full solver
        # which alternates W and H updates.
        violation = _update_coordinate_descent(
            X.T, Ht, W, l1_reg_H, l2_reg_H, shuffle, rng
        )
        if n_iter == 1:
            violation_init = violation
            if violation_init == 0:
                # Already at a stationary point; nothing to improve.
                break
        if verbose:
            print("violation:", violation / violation_init)
        if violation / violation_init <= tol:
            if verbose:
                print("Converged at iteration", n_iter + 1)
            break
    return W, Ht.T, n_iter
| 63 | + |
| 64 | + |
20 | 65 | def batched(iterable, n: int) -> Iterable[List[str]]: |
21 | 66 | "Batch data into tuples of length n. The last batch may be shorter." |
22 | 67 | # batched('ABCDEFG', 3) --> ABC DEF G |
@@ -48,7 +93,7 @@ def __iter__(self) -> Iterable[Dict[str, float]]: |
48 | 93 | yield deserialize_keywords(line.strip()) |
49 | 94 |
|
50 | 95 |
|
51 | | -class KeyNMF(ContextualModel): |
| 96 | +class KeyNMF(ContextualModel, DynamicTopicModel): |
52 | 97 | """Extracts keywords from documents based on semantic similarity of |
53 | 98 | term encodings to document encodings. |
54 | 99 | Topics are then extracted with non-negative matrix factorization from |
@@ -305,3 +350,32 @@ def prepare_topic_data( |
305 | 350 | "topic_names": self.topic_names, |
306 | 351 | } |
307 | 352 | return res |
| 353 | + |
    def fit_transform_dynamic(
        self,
        raw_documents,
        timestamps: list[datetime],
        embeddings: Optional[np.ndarray] = None,
        bins: Union[int, list[datetime]] = 10,
    ) -> np.ndarray:
        """Fit the model on the whole corpus, then estimate a separate
        topic-term matrix and topic-importance vector per time slice.

        Parameters
        ----------
        raw_documents: Documents to model.
        timestamps: One datetime per document, aligned with raw_documents.
        embeddings: Optional precomputed document embeddings, forwarded to
            ``prepare_topic_data``.
        bins: Number of time bins, or explicit bin edges as datetimes.

        Returns
        -------
        np.ndarray
            The global document-topic matrix (not per-slice).
        """
        # Assign each document to a time bin; edges are kept on the instance
        # for downstream dynamic-topic utilities.
        time_labels, self.time_bin_edges = bin_timestamps(timestamps, bins)
        # One global fit over the full corpus; per-slice refits below only
        # re-estimate components, not document-topic assignments.
        topic_data = self.prepare_topic_data(
            raw_documents, embeddings=embeddings
        )
        # NOTE(review): assumes bin_timestamps yields labels indexable into
        # len(edges)+1 slots — confirm against its contract.
        n_bins = len(self.time_bin_edges) + 1
        n_comp, n_vocab = self.components_.shape
        # Per-slice topic-term matrices and per-slice topic weights; slices
        # with no documents keep their all-zero rows.
        self.temporal_components_ = np.zeros((n_bins, n_comp, n_vocab))
        self.temporal_importance_ = np.zeros((n_bins, n_comp))
        for label in np.unique(time_labels):
            idx = np.nonzero(time_labels == label)
            X = topic_data["document_term_matrix"][idx]
            W = topic_data["document_topic_matrix"][idx]
            # Fresh H initialization for this slice; W stays fixed while
            # fit_timeslice refits the components against the slice's X.
            _, H = _initialize_nmf(
                X, self.components_.shape[0], random_state=self.random_state
            )
            _, H, _ = fit_timeslice(X, W, H, random_state=self.random_state)
            self.temporal_components_[label] = H
            # Relative topic prevalence within the slice (normalized to 1).
            # np.squeeze(np.asarray(...)) flattens the np.matrix that a
            # sparse W.sum(axis=0) would produce — presumably W can be
            # sparse; verify against prepare_topic_data.
            topic_importances = np.squeeze(np.asarray(W.sum(axis=0)))
            topic_importances = topic_importances / topic_importances.sum()
            self.temporal_importance_[label] = topic_importances
        return topic_data["document_topic_matrix"]
0 commit comments