Commit 11e3c6e

Merge pull request #68 from x-tabdeveloping/embedding_dim
Informative error message
2 parents 0bc8a0e + 636f667

4 files changed: 43 additions & 1 deletion


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ line-length=79
 
 [tool.poetry]
 name = "turftopic"
-version = "0.7.0"
+version = "0.7.1"
 description = "Topic modeling with contextual representations from sentence transformers."
 authors = ["Márton Kardos <power.up1163@gmail.com>"]
 license = "MIT"

turftopic/models/_keynmf.py

Lines changed: 13 additions & 0 deletions
@@ -16,6 +16,12 @@
 
 from turftopic.base import Encoder
 
+NOT_MATCHING_ERROR = (
+    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
+    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
+    + "Try to initialize the model with the encoder you used for computing the embeddings."
+)
+
 
 def batched(iterable, n: int) -> Iterable[list[str]]:
     "Batch data into tuples of length n. The last batch may be shorter."

@@ -143,6 +149,13 @@ def batch_extract_keywords(
                 self.term_embeddings[self.key_to_index[term]]
                 for term in batch_vocab[important_terms]
             ]
+            if self.term_embeddings.shape[1] != embeddings.shape[1]:
+                raise ValueError(
+                    NOT_MATCHING_ERROR.format(
+                        n_dims=embeddings.shape[1],
+                        n_word_dims=self.term_embeddings.shape[1],
+                    )
+                )
             sim = cosine_similarity(embedding, word_embeddings).astype(
                 np.float64
             )
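
For context, a minimal sketch of the mismatch this check catches. It assumes turftopic's fit(raw_documents, embeddings=...) signature and encoder constructor argument (both implied by the error message itself), and that the default encoder has a different dimensionality than the one used for precomputation; the toy corpus and encoder choice are illustrative:

from sentence_transformers import SentenceTransformer
from turftopic import KeyNMF

# Toy corpus; stands in for real documents.
corpus = [f"sample document number {i} about cats and dogs" for i in range(40)]

# Document embeddings precomputed with a 768-dimensional encoder.
encoder = SentenceTransformer("all-mpnet-base-v2")
embeddings = encoder.encode(corpus)

# Without an encoder argument, the model embeds its vocabulary with the
# default encoder; if that model's dimensionality differs from 768, the
# new guard raises ValueError with NOT_MATCHING_ERROR:
# KeyNMF(2).fit(corpus, embeddings=embeddings)

# Passing the encoder used for precomputation keeps both spaces aligned.
model = KeyNMF(2, encoder=encoder)
model.fit(corpus, embeddings=embeddings)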

turftopic/models/cluster.py

Lines changed: 16 additions & 0 deletions
@@ -39,6 +39,12 @@
     feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroid'
     """
 
+NOT_MATCHING_ERROR = (
+    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
+    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
+    + "Try to initialize the model with the encoder you used for computing the embeddings."
+)
+
 
 def smallest_hierarchical_join(
     topic_vectors: np.ndarray,

@@ -370,6 +376,16 @@ def estimate_components(
             self.vocab_embeddings = self.encoder_.encode(
                 self.vectorizer.get_feature_names_out()
             )  # type: ignore
+            if (
+                self.vocab_embeddings.shape[1]
+                != self.topic_vectors_.shape[1]
+            ):
+                raise ValueError(
+                    NOT_MATCHING_ERROR.format(
+                        n_dims=self.topic_vectors_.shape[1],
+                        n_word_dims=self.vocab_embeddings.shape[1],
+                    )
+                )
             self.components_ = cluster_centroid_distance(
                 self.topic_vectors_,
                 self.vocab_embeddings,
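
The guard in cluster.py sits on the 'centroid' feature-importance path, where cluster centroids computed in document-embedding space are compared against vocabulary embeddings produced by the encoder, so the two must share a dimensionality. A usage sketch under the same assumed fit(raw_documents, embeddings=...) signature; the generated corpus and encoder are illustrative:

from sentence_transformers import SentenceTransformer
from turftopic import ClusteringTopicModel

# Toy corpus; stands in for a realistically sized document collection.
corpus = [f"short document about subject {i % 3}" for i in range(60)]

encoder = SentenceTransformer("all-mpnet-base-v2")
embeddings = encoder.encode(corpus)

# Topic vectors are centroids in *document* embedding space; 'centroid'
# importance measures their distance to *vocabulary* embeddings, so the
# new check verifies both come from encoders of the same dimensionality.
model = ClusteringTopicModel(feature_importance="centroid", encoder=encoder)
model.fit(corpus, embeddings=embeddings)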

turftopic/models/decomp.py

Lines changed: 13 additions & 0 deletions
@@ -11,6 +11,12 @@
 from turftopic.base import ContextualModel, Encoder
 from turftopic.vectorizer import default_vectorizer
 
+NOT_MATCHING_ERROR = (
+    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
+    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
+    + "Try to initialize the model with the encoder you used for computing the embeddings."
+)
+
 
 class SemanticSignalSeparation(ContextualModel):
     """Separates the embedding matrix into 'semantic signals' with

@@ -115,6 +121,13 @@ def fit_transform(
             console.log("Term extraction done.")
             status.update("Encoding vocabulary")
             self.vocab_embeddings = self.encoder_.encode(vocab)
+            if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
+                raise ValueError(
+                    NOT_MATCHING_ERROR.format(
+                        n_dims=self.embeddings.shape[1],
+                        n_word_dims=self.vocab_embeddings.shape[1],
+                    )
+                )
             console.log("Vocabulary encoded.")
             status.update("Estimating term importances")
             vocab_topic = self.decomposition.transform(self.vocab_embeddings)
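
In decomp.py the check fires inside fit_transform, right after the vocabulary is encoded, so a mismatch fails fast with the informative message instead of a later shape error in the decomposition. A sketch of catching it, under the same assumptions as above (in particular, that the default encoder's dimensionality differs from the precomputed one):

from sentence_transformers import SentenceTransformer
from turftopic import SemanticSignalSeparation

corpus = [f"sample document number {i} about sports and politics" for i in range(40)]

# Precomputed 768-dimensional document embeddings.
embeddings = SentenceTransformer("all-mpnet-base-v2").encode(corpus)

# Forgetting the encoder: the vocabulary is embedded with the default
# model, and if its dimensionality differs from 768, fit raises early.
try:
    SemanticSignalSeparation(2).fit(corpus, embeddings=embeddings)
except ValueError as err:
    print(err)  # NOT_MATCHING_ERROR with both dimensionalities filled in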
