Skip to content

Commit e0c6732

Browse files
Refactored multimodal encoding
1 parent f1208ab commit e0c6732

1 file changed

Lines changed: 56 additions & 42 deletions

File tree

turftopic/multimodal.py

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55

66
import numpy as np
77
from PIL import Image
8-
from sentence_transformers import SentenceTransformer
98

109
from turftopic.data import TopicData
11-
from turftopic.encoders.multimodal import MultimodalEncoder
1210

1311
UrlStr = str
1412

@@ -42,6 +40,61 @@ class MultimodalEmbeddings(TypedDict):
4240
document_embeddings: np.ndarray
4341

4442

43+
def encode_multimodal(
    encoder, sentences: list[str], images: list[ImageRepr]
) -> "MultimodalEmbeddings":
    """Produce multimodal embeddings of the documents passed to the model.

    Parameters
    ----------
    encoder
        MTEB or SentenceTransformer compatible embedding model.
    sentences: list[str]
        Textual documents to encode.
    images: list[ImageRepr]
        Corresponding images for each document.

    Returns
    -------
    MultimodalEmbeddings
        Text, image and joint document embeddings.

    Raises
    ------
    ValueError
        If `sentences` and `images` are not the same length.
    """
    if len(sentences) != len(images):
        raise ValueError("Images and documents were not the same length.")
    # MTEB-style multimodal encoders expose dedicated text/image entry
    # points; otherwise fall back to SentenceTransformer-style `encode`.
    if hasattr(encoder, "get_text_embeddings"):
        text_embeddings = np.array(encoder.get_text_embeddings(sentences))
    else:
        text_embeddings = encoder.encode(sentences)
    embedding_size = text_embeddings.shape[1]
    # Keep the loaded PIL images in a separate local instead of rebinding
    # the `images` parameter, so the two representations stay distinct.
    loaded_images = list(_load_images(images))
    if hasattr(encoder, "get_image_embeddings"):
        image_embeddings = np.array(
            encoder.get_image_embeddings(loaded_images)
        )
    else:
        image_embeddings = []
        for image in loaded_images:
            if image is not None:
                image_embeddings.append(encoder.encode(image))
            else:
                # Missing image: NaN placeholder row keeps alignment with
                # the corresponding document.
                image_embeddings.append(np.full(embedding_size, np.nan))
        image_embeddings = np.stack(image_embeddings)
    if hasattr(encoder, "get_fused_embeddings"):
        document_embeddings = np.array(
            encoder.get_fused_embeddings(
                texts=sentences,
                images=loaded_images,
            )
        )
    else:
        document_embeddings = _naive_join_embeddings(
            text_embeddings, image_embeddings
        )
    return {
        "text_embeddings": text_embeddings,
        "image_embeddings": image_embeddings,
        "document_embeddings": document_embeddings,
    }
96+
97+
4598
class MultimodalModel:
4699
"""Base model for multimodal topic models."""
47100

@@ -65,46 +118,7 @@ def encode_multimodal(
65118
Text, image and joint document embeddings.
66119
67120
"""
68-
if len(sentences) != len(images):
69-
raise ValueError("Images and documents were not the same length.")
70-
if hasattr(self.encoder_, "get_text_embeddings"):
71-
text_embeddings = np.array(
72-
self.encoder_.get_text_embeddings(sentences)
73-
)
74-
else:
75-
text_embeddings = self.encoder_.encode(sentences)
76-
embedding_size = text_embeddings.shape[1]
77-
images = list(_load_images(images))
78-
if hasattr(self.encoder_, "get_image_embeddings"):
79-
image_embeddings = np.array(
80-
self.encoder_.get_image_embeddings(images)
81-
)
82-
else:
83-
image_embeddings = []
84-
for image in images:
85-
if image is not None:
86-
image_embeddings.append(self.encoder_.encode(image))
87-
else:
88-
image_embeddings.append(np.full(embedding_size, np.nan))
89-
image_embeddings = np.stack(image_embeddings)
90-
print(image_embeddings)
91-
if hasattr(self.encoder_, "get_fused_embeddings"):
92-
document_embeddings = np.array(
93-
self.encoder_.get_fused_embeddings(
94-
texts=sentences,
95-
images=images,
96-
)
97-
)
98-
else:
99-
document_embeddings = _naive_join_embeddings(
100-
text_embeddings, image_embeddings
101-
)
102-
103-
return {
104-
"text_embeddings": text_embeddings,
105-
"image_embeddings": image_embeddings,
106-
"document_embeddings": document_embeddings,
107-
}
121+
return encode_multimodal(self.encoder_, sentences, images)
108122

109123
@staticmethod
110124
def validate_embeddings(embeddings: Optional[MultimodalEmbeddings]):

0 commit comments

Comments
 (0)