@@ -120,6 +120,7 @@ def batch_extract_keywords(
         self,
         documents: list[str],
         embeddings: Optional[np.ndarray] = None,
+        seed_embedding: Optional[np.ndarray] = None,
     ) -> list[dict[str, float]]:
         if not len(documents):
             return []
@@ -142,6 +143,16 @@ def batch_extract_keywords(
         if len(new_terms):
             self._add_terms(new_terms)
         total = embeddings.shape[0]
+        # Relevance based on similarity to seed embedding
+        document_relevance = None
+        if seed_embedding is not None:
+            if self.metric == "cosine":
+                document_relevance = cosine_similarity(
+                    [seed_embedding], embeddings
+                )[0]
+            else:
+                document_relevance = np.dot(embeddings, seed_embedding)
+            document_relevance[document_relevance < 0] = 0
         for i in range(total):
             terms = document_term_matrix[i, :].todense()
             embedding = embeddings[i].reshape(1, -1)
@@ -162,14 +173,13 @@ def batch_extract_keywords(
                 )
             )
             if self.metric == "cosine":
-                sim = cosine_similarity(embedding, word_embeddings).astype(
-                    np.float64
-                )
+                sim = cosine_similarity(embedding, word_embeddings)
                 sim = np.ravel(sim)
             else:
-                sim = np.dot(word_embeddings, embedding[0]).T.astype(
-                    np.float64
-                )
+                sim = np.dot(word_embeddings, embedding[0]).T
+            # If a seed is specified, we multiply by the document's relevance
+            if document_relevance is not None:
+                sim = document_relevance[i] * sim
             kth = min(self.top_n, len(sim) - 1)
             top = np.argpartition(-sim, kth)[:kth]
             top_words = batch_vocab[important_terms][top]
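For context, a minimal usage sketch of the new `seed_embedding` parameter. The encoder, model name, and the `extractor` instance below are illustrative assumptions, not part of this change:

```python
# Sketch only: assumes the class defining batch_extract_keywords is already
# instantiated as `extractor`, and that embeddings come from sentence-transformers.
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")
documents = [
    "The central bank raised interest rates again.",
    "New fossils shed light on early bird evolution.",
]
# 2D array of document embeddings, one row per document
embeddings = encoder.encode(documents)

# Steer extraction towards a topic of interest: each document's keyword
# similarities get scaled by its relevance to the seed.
seed_embedding = encoder.encode("monetary policy and inflation")

keywords = extractor.batch_extract_keywords(
    documents,
    embeddings=embeddings,
    seed_embedding=seed_embedding,
)
```

Since relevance is clipped at zero, documents anti-correlated with the seed end up with all-zero keyword scores rather than sign-flipped ones.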