Skip to content

Commit e282104

Browse files
Added seed phrases to KeyNMF
1 parent 5dc7c7a commit e282104

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

turftopic/models/_keynmf.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def batch_extract_keywords(
120120
self,
121121
documents: list[str],
122122
embeddings: Optional[np.ndarray] = None,
123+
seed_embedding: Optional[np.ndarray] = None,
123124
) -> list[dict[str, float]]:
124125
if not len(documents):
125126
return []
@@ -142,6 +143,16 @@ def batch_extract_keywords(
142143
if len(new_terms):
143144
self._add_terms(new_terms)
144145
total = embeddings.shape[0]
146+
# Relevance based on similarity to seed embedding
147+
document_relevance = None
148+
if seed_embedding is not None:
149+
if self.metric == "cosine":
150+
document_relevance = cosine_similarity(
151+
[seed_embedding], embeddings
152+
)[0]
153+
else:
154+
document_relevance = np.dot(embeddings, seed_embedding)
155+
document_relevance[document_relevance < 0] = 0
145156
for i in range(total):
146157
terms = document_term_matrix[i, :].todense()
147158
embedding = embeddings[i].reshape(1, -1)
@@ -162,14 +173,13 @@ def batch_extract_keywords(
162173
)
163174
)
164175
if self.metric == "cosine":
165-
sim = cosine_similarity(embedding, word_embeddings).astype(
166-
np.float64
167-
)
176+
sim = cosine_similarity(embedding, word_embeddings)
168177
sim = np.ravel(sim)
169178
else:
170-
sim = np.dot(word_embeddings, embedding[0]).T.astype(
171-
np.float64
172-
)
179+
sim = np.dot(word_embeddings, embedding[0]).T
180+
# If a seed is specified, we multiply by the document's relevance
181+
if document_relevance is not None:
182+
sim = document_relevance[i] * sim
173183
kth = min(self.top_n, len(sim) - 1)
174184
top = np.argpartition(-sim, kth)[:kth]
175185
top_words = batch_vocab[important_terms][top]

turftopic/models/keynmf.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class KeyNMF(ContextualModel, DynamicTopicModel):
4949
Random state to use so that results are exactly reproducible.
5050
metric: "cosine" or "dot", default "cosine"
5151
Similarity metric to use for keyword extraction.
52+
seed_phrase: str, default None
53+
Describes an aspect of the corpus that the model should explore.
54+
It can be a free-text query, such as
55+
"Christian Denominations: Protestantism and Catholicism"
5256
"""
5357

5458
def __init__(
@@ -61,6 +65,7 @@ def __init__(
6165
top_n: int = 25,
6266
random_state: Optional[int] = None,
6367
metric: Literal["cosine", "dot"] = "cosine",
68+
seed_phrase: Optional[str] = None,
6469
):
6570
self.random_state = random_state
6671
self.n_components = n_components
@@ -85,6 +90,10 @@ def __init__(
8590
encoder=self.encoder_,
8691
metric=self.metric,
8792
)
93+
self.seed_phrase = seed_phrase
94+
self.seed_embedding = None
95+
if self.seed_phrase is not None:
96+
self.seed_embedding = self.encoder_.encode([self.seed_phrase])[0]
8897

8998
def extract_keywords(
9099
self,
@@ -103,7 +112,9 @@ def extract_keywords(
103112
if isinstance(batch_or_document, str):
104113
batch_or_document = [batch_or_document]
105114
return self.extractor.batch_extract_keywords(
106-
batch_or_document, embeddings=embeddings
115+
batch_or_document,
116+
embeddings=embeddings,
117+
seed_embedding=self.seed_embedding,
107118
)
108119

109120
def vectorize(

0 commit comments

Comments (0)