File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1414from sklearn .cluster import HDBSCAN
1515from sklearn .exceptions import NotFittedError
1616from sklearn .feature_extraction .text import CountVectorizer
17+ from sklearn .metrics .pairwise import cosine_similarity
1718from sklearn .preprocessing import label_binarize , scale
1819
1920from turftopic .base import ContextualModel , Encoder
@@ -448,7 +449,11 @@ def fit_transform(
448449 self , raw_documents , y = None , embeddings : Optional [np .ndarray ] = None
449450 ):
450451 labels = self .fit_predict (raw_documents , y , embeddings )
451- return label_binarize (labels , classes = self .classes_ )
452+ document_topic_matrix = label_binarize (labels , classes = self .classes_ )
453+ document_topic_matrix = document_topic_matrix * cosine_similarity (
454+ self .embeddings , self ._calculate_topic_vectors ()
455+ )
456+ return document_topic_matrix
452457
453458 def estimate_temporal_components (
454459 self ,
You can’t perform that action at this time.
0 commit comments