Skip to content

Commit 83fd785

Browse files
Fixed raw encode calls
1 parent 443bf25 commit 83fd785

5 files changed

Lines changed: 12 additions & 12 deletions

File tree

turftopic/models/ctm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,7 @@ def transform(
193193
Document-topic matrix.
194194
"""
195195
if embeddings is None:
196-
embeddings = self.encoder_.encode(raw_documents)
196+
embeddings = self.encode_documents(raw_documents)
197197
if self.combined:
198198
bow = self.vectorizer.fit_transform(raw_documents)
199199
contextual_embeddings = np.concatenate(
@@ -219,7 +219,7 @@ def fit_transform(
219219
with console.status("Fitting model") as status:
220220
if embeddings is None:
221221
status.update("Encoding documents")
222-
embeddings = self.encoder_.encode(raw_documents)
222+
embeddings = self.encode_documents(raw_documents)
223223
console.log("Documents encoded.")
224224
status.update("Extracting terms.")
225225
document_term_matrix = self.vectorizer.fit_transform(raw_documents)

turftopic/models/cvp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def __init__(
6262
self.classes_ = np.array([name for name in self._seeds])
6363
self.concept_matrix_ = []
6464
for _, (positive, negative) in self._seeds.items():
65-
positive_emb = self.encoder_.encode(positive)
66-
negative_emb = self.encoder_.encode(negative)
65+
positive_emb = self.encoder_.encode(list(positive))
66+
negative_emb = self.encoder_.encode(list(negative))
6767
cv = np.mean(positive_emb, axis=0) - np.mean(negative_emb, axis=0)
6868
self.concept_matrix_.append(cv / np.linalg.norm(cv))
6969
self.concept_matrix_ = np.stack(self.concept_matrix_)
@@ -92,7 +92,7 @@ def fit_transform(self, raw_documents=None, y=None, embeddings=None):
9292
"Either embeddings or raw_documents has to be passed, both are None."
9393
)
9494
if embeddings is None:
95-
embeddings = self.encoder_.encode(raw_documents)
95+
embeddings = self.encoder_.encode(list(raw_documents))
9696
return embeddings @ self.concept_matrix_.T
9797

9898
def transform(self, raw_documents=None, embeddings=None):

turftopic/models/decomp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def fit_transform(
140140
with console.status("Fitting model") as status:
141141
if self.embeddings is None:
142142
status.update("Encoding documents")
143-
self.embeddings = self.encoder_.encode(raw_documents)
143+
self.embeddings = self.encode_documents(raw_documents)
144144
console.log("Documents encoded.")
145145
status.update("Decomposing embeddings")
146146
if isinstance(self.decomposition, FastICA) and (y is not None):
@@ -153,7 +153,7 @@ def fit_transform(
153153
vocab = self.vectorizer.fit(raw_documents).get_feature_names_out()
154154
console.log("Term extraction done.")
155155
status.update("Encoding vocabulary")
156-
self.vocab_embeddings = self.encoder_.encode(vocab)
156+
self.vocab_embeddings = self.encode_documents(vocab)
157157
if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
158158
raise ValueError(
159159
NOT_MATCHING_ERROR.format(
@@ -636,7 +636,7 @@ def transform(
636636
Document-topic matrix.
637637
"""
638638
if embeddings is None:
639-
embeddings = self.encoder_.encode(raw_documents)
639+
embeddings = self.encode_documents(raw_documents)
640640
return self.decomposition.transform(embeddings)
641641

642642
def print_topics(

turftopic/models/gmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def fit_transform(
206206
with console.status("Fitting model") as status:
207207
if embeddings is None:
208208
status.update("Encoding documents")
209-
embeddings = self.encoder_.encode(raw_documents)
209+
embeddings = self.encode_documents(raw_documents)
210210
console.log("Documents encoded.")
211211
self.embeddings = embeddings
212212
status.update("Extracting terms.")
@@ -325,7 +325,7 @@ def transform(
325325
Document-topic matrix.
326326
"""
327327
if embeddings is None:
328-
embeddings = self.encoder_.encode(raw_documents)
328+
embeddings = self.encode_documents(raw_documents)
329329
if self.dimensionality_reduction is not None:
330330
embeddings = self.dimensionality_reduction.transform(embeddings)
331331
return self.gmm_.predict_proba(embeddings)

turftopic/models/senstopic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def fit_transform(
149149
with console.status("Fitting model") as status:
150150
if self.embeddings is None:
151151
status.update("Encoding documents")
152-
self.embeddings = self.encoder_.encode(raw_documents)
152+
self.embeddings = self.encode_documents(raw_documents)
153153
console.log("Documents encoded.")
154154
if self.n_components == "auto":
155155
status.update("Finding the number of components.")
@@ -177,7 +177,7 @@ def fit_transform(
177177
console.log("Term extraction done.")
178178
if getattr(self, "vocab_embeddings", None) is None:
179179
status.update("Encoding vocabulary")
180-
self.vocab_embeddings = self.encoder_.encode(vocab)
180+
self.vocab_embeddings = self.encode_documents(vocab)
181181
if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
182182
raise ValueError(
183183
NOT_MATCHING_ERROR.format(

0 commit comments

Comments (0)