Skip to content

Commit 4bcceb8

Browse files
fit_transform() in clustering models now predicts importance based on proximity to cluster centroid
1 parent 5c829a1 commit 4bcceb8

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

turftopic/models/cluster.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from sklearn.cluster import HDBSCAN
1515
from sklearn.exceptions import NotFittedError
1616
from sklearn.feature_extraction.text import CountVectorizer
17+
from sklearn.metrics.pairwise import cosine_similarity
1718
from sklearn.preprocessing import label_binarize, scale
1819

1920
from turftopic.base import ContextualModel, Encoder
@@ -448,7 +449,11 @@ def fit_transform(
448449
self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
449450
):
450451
labels = self.fit_predict(raw_documents, y, embeddings)
451-
return label_binarize(labels, classes=self.classes_)
452+
document_topic_matrix = label_binarize(labels, classes=self.classes_)
453+
document_topic_matrix = document_topic_matrix * cosine_similarity(
454+
self.embeddings, self._calculate_topic_vectors()
455+
)
456+
return document_topic_matrix
452457

453458
def estimate_temporal_components(
454459
self,

0 commit comments

Comments
 (0)