Skip to content

Commit 560682e

Browse files
Fixed transform() in KeyNMF
1 parent e282104 commit 560682e

2 files changed

Lines changed: 11 additions & 4 deletions

File tree

turftopic/models/_keynmf.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def batch_extract_keywords(
121121
documents: list[str],
122122
embeddings: Optional[np.ndarray] = None,
123123
seed_embedding: Optional[np.ndarray] = None,
124+
fitting: bool = True,
124125
) -> list[dict[str, float]]:
125126
if not len(documents):
126127
return []
@@ -136,9 +137,11 @@ def batch_extract_keywords(
136137
"Number of documents doesn't match number of embeddings."
137138
)
138139
keywords = []
139-
vectorizer = clone(self.vectorizer)
140-
document_term_matrix = vectorizer.fit_transform(documents)
141-
batch_vocab = vectorizer.get_feature_names_out()
140+
if fitting:
141+
document_term_matrix = self.vectorizer.fit_transform(documents)
142+
else:
143+
document_term_matrix = self.vectorizer.transform(documents)
144+
batch_vocab = self.vectorizer.get_feature_names_out()
142145
new_terms = list(set(batch_vocab) - set(self.key_to_index.keys()))
143146
if len(new_terms):
144147
self._add_terms(new_terms)

turftopic/models/keynmf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def extract_keywords(
9999
self,
100100
batch_or_document: Union[str, list[str]],
101101
embeddings: Optional[np.ndarray] = None,
102+
fitting: bool = True,
102103
) -> list[dict[str, float]]:
103104
"""Extracts keywords from a document or a batch of documents.
104105
@@ -115,6 +116,7 @@ def extract_keywords(
115116
batch_or_document,
116117
embeddings=embeddings,
117118
seed_embedding=self.seed_embedding,
119+
fitting=fitting,
118120
)
119121

120122
def vectorize(
@@ -260,7 +262,9 @@ def transform(
260262
)
261263
if keywords is None:
262264
keywords = self.extract_keywords(
263-
list(raw_documents), embeddings=embeddings
265+
list(raw_documents),
266+
embeddings=embeddings,
267+
fitting=False,
264268
)
265269
return self.model.transform(keywords)
266270

0 commit comments

Comments
 (0)