Skip to content

Commit 12e6e86

Browse files
Merge pull request #129 from x-tabdeveloping/multimodal-sbert
Multimodal sbert
2 parents 77eeab1 + e192079 commit 12e6e86

8 files changed

Lines changed: 28 additions & 48 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ profile = "black"

 [project]
 name = "turftopic"
-version = "0.25.0"
+version = "0.25.1"
 description = "Topic modeling with contextual representations from sentence transformers."
 authors = [
     { name = "Márton Kardos <power.up1163@gmail.com>", email = "martonkardos@cas.au.dk" }

turftopic/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def encode_documents(self, raw_documents: Iterable[str]) -> np.ndarray:
         """
         if not hasattr(self.encoder_, "encode"):
             return self.encoder.get_text_embeddings(list(raw_documents))
-        return self.encoder_.encode(raw_documents)
+        return self.encoder_.encode(list(raw_documents))

     @abstractmethod
     def fit_transform(

turftopic/late.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -53,40 +53,20 @@ def _encode_tokens(
             Start and end character of each token in each document.
         """
         self.has_used_token_level = True
-        token_embeddings = []
-        offsets = []
-        for start_index in trange(
-            0,
-            len(texts),
-            batch_size,
-            desc="Encoding batches...",
-        ):
-            batch = texts[start_index : start_index + batch_size]
-            features = self.tokenize(batch)
-            with torch.no_grad():
-                output_features = self.forward(features)
-            n_tokens = output_features["attention_mask"].sum(axis=1)
-            # Find first nonzero elements in each document
-            # The document could be padded from the left, so we have to watch out for this.
-            start_token = torch.argmax(
-                (output_features["attention_mask"] > 0).to(torch.long), axis=1
-            )
-            end_token = start_token + n_tokens
-            for i_doc in range(len(batch)):
-                _token_embeddings = (
-                    output_features["token_embeddings"][
-                        i_doc, start_token[i_doc] : end_token[i_doc], :
-                    ]
-                    .float()
-                    .numpy(force=True)
-                )
-                _n = _token_embeddings.shape[0]
-                # We extract the character offsets and prune it at the maximum context length
-                _offsets = self.tokenizer(
-                    batch[i_doc], return_offsets_mapping=True, verbose=False
-                )["offset_mapping"][:_n]
-                token_embeddings.append(_token_embeddings)
-                offsets.append(_offsets)
+        token_embeddings = self.encode(
+            texts, output_value="token_embeddings", batch_size=batch_size
+        )
+        offsets = self.tokenizer(
+            texts, return_offsets_mapping=True, verbose=False
+        )["offset_mapping"]
+        offsets = [
+            offs[: len(embs)] for offs, embs in zip(offsets, token_embeddings)
+        ]
+        token_embeddings = [
+            embs.numpy(force=True)
+            for embs in token_embeddings
+            if torch.is_tensor(embs)
+        ]
         return token_embeddings, offsets

     def encode_tokens(

turftopic/models/ctm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def transform(
             Document-topic matrix.
         """
         if embeddings is None:
-            embeddings = self.encoder_.encode(raw_documents)
+            embeddings = self.encode_documents(raw_documents)
         if self.combined:
             bow = self.vectorizer.fit_transform(raw_documents)
             contextual_embeddings = np.concatenate(
@@ -219,7 +219,7 @@ def fit_transform(
         with console.status("Fitting model") as status:
             if embeddings is None:
                 status.update("Encoding documents")
-                embeddings = self.encoder_.encode(raw_documents)
+                embeddings = self.encode_documents(raw_documents)
             console.log("Documents encoded.")
             status.update("Extracting terms.")
             document_term_matrix = self.vectorizer.fit_transform(raw_documents)

turftopic/models/cvp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def __init__(
         self.classes_ = np.array([name for name in self._seeds])
         self.concept_matrix_ = []
         for _, (positive, negative) in self._seeds.items():
-            positive_emb = self.encoder_.encode(positive)
-            negative_emb = self.encoder_.encode(negative)
+            positive_emb = self.encoder_.encode(list(positive))
+            negative_emb = self.encoder_.encode(list(negative))
             cv = np.mean(positive_emb, axis=0) - np.mean(negative_emb, axis=0)
             self.concept_matrix_.append(cv / np.linalg.norm(cv))
         self.concept_matrix_ = np.stack(self.concept_matrix_)
@@ -92,7 +92,7 @@ def fit_transform(self, raw_documents=None, y=None, embeddings=None):
                 "Either embeddings or raw_documents has to be passed, both are None."
             )
         if embeddings is None:
-            embeddings = self.encoder_.encode(raw_documents)
+            embeddings = self.encoder_.encode(list(raw_documents))
         return embeddings @ self.concept_matrix_.T

     def transform(self, raw_documents=None, embeddings=None):

turftopic/models/decomp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def fit_transform(
         with console.status("Fitting model") as status:
             if self.embeddings is None:
                 status.update("Encoding documents")
-                self.embeddings = self.encoder_.encode(raw_documents)
+                self.embeddings = self.encode_documents(raw_documents)
             console.log("Documents encoded.")
             status.update("Decomposing embeddings")
             if isinstance(self.decomposition, FastICA) and (y is not None):
@@ -153,7 +153,7 @@ def fit_transform(
             vocab = self.vectorizer.fit(raw_documents).get_feature_names_out()
             console.log("Term extraction done.")
             status.update("Encoding vocabulary")
-            self.vocab_embeddings = self.encoder_.encode(vocab)
+            self.vocab_embeddings = self.encode_documents(vocab)
             if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
                 raise ValueError(
                     NOT_MATCHING_ERROR.format(
@@ -636,7 +636,7 @@ def transform(
             Document-topic matrix.
         """
         if embeddings is None:
-            embeddings = self.encoder_.encode(raw_documents)
+            embeddings = self.encode_documents(raw_documents)
         return self.decomposition.transform(embeddings)

     def print_topics(

turftopic/models/gmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def fit_transform(
         with console.status("Fitting model") as status:
             if embeddings is None:
                 status.update("Encoding documents")
-                embeddings = self.encoder_.encode(raw_documents)
+                embeddings = self.encode_documents(raw_documents)
             console.log("Documents encoded.")
             self.embeddings = embeddings
             status.update("Extracting terms.")
@@ -325,7 +325,7 @@ def transform(
             Document-topic matrix.
         """
         if embeddings is None:
-            embeddings = self.encoder_.encode(raw_documents)
+            embeddings = self.encode_documents(raw_documents)
         if self.dimensionality_reduction is not None:
             embeddings = self.dimensionality_reduction.transform(embeddings)
         return self.gmm_.predict_proba(embeddings)

turftopic/models/senstopic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def fit_transform(
         with console.status("Fitting model") as status:
             if self.embeddings is None:
                 status.update("Encoding documents")
-                self.embeddings = self.encoder_.encode(raw_documents)
+                self.embeddings = self.encode_documents(raw_documents)
             console.log("Documents encoded.")
             if self.n_components == "auto":
                 status.update("Finding the number of components.")
@@ -177,7 +177,7 @@ def fit_transform(
             console.log("Term extraction done.")
             if getattr(self, "vocab_embeddings", None) is None:
                 status.update("Encoding vocabulary")
-                self.vocab_embeddings = self.encoder_.encode(vocab)
+                self.vocab_embeddings = self.encode_documents(vocab)
             if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
                 raise ValueError(
                     NOT_MATCHING_ERROR.format(

Comments (0)