55
66import numpy as np
77from PIL import Image
8- from sentence_transformers import SentenceTransformer
98
109from turftopic .data import TopicData
11- from turftopic .encoders .multimodal import MultimodalEncoder
1210
1311UrlStr = str
1412
@@ -42,6 +40,61 @@ class MultimodalEmbeddings(TypedDict):
4240 document_embeddings : np .ndarray
4341
4442
def encode_multimodal(
    encoder, sentences: list[str], images: list[ImageRepr]
) -> MultimodalEmbeddings:
    """Produce multimodal embeddings of the documents passed to the model.

    Parameters
    ----------
    encoder
        MTEB or SentenceTransformer compatible embedding model.
    sentences: list[str]
        Textual documents to encode.
    images: list[ImageRepr]
        Corresponding images for each document.

    Returns
    -------
    MultimodalEmbeddings
        Text, image and joint document embeddings.

    Raises
    ------
    ValueError
        If `sentences` and `images` are not the same length.
    """
    if len(sentences) != len(images):
        raise ValueError("Images and documents were not the same length.")
    # Prefer the MTEB-style interface when available, otherwise fall back
    # to the SentenceTransformer-style `encode` call.
    if hasattr(encoder, "get_text_embeddings"):
        text_embeddings = np.array(encoder.get_text_embeddings(sentences))
    else:
        # np.asarray guards against encoders whose encode() returns a
        # plain list — .shape[1] below requires an ndarray.
        text_embeddings = np.asarray(encoder.encode(sentences))
    embedding_size = text_embeddings.shape[1]
    images = list(_load_images(images))
    if hasattr(encoder, "get_image_embeddings"):
        image_embeddings = np.array(encoder.get_image_embeddings(images))
    else:
        image_embeddings = []
        for image in images:
            if image is not None:
                image_embeddings.append(encoder.encode(image))
            else:
                # Keep one row per document: missing images become an
                # all-NaN vector of the text embedding width.
                image_embeddings.append(np.full(embedding_size, np.nan))
        image_embeddings = np.stack(image_embeddings)
    if hasattr(encoder, "get_fused_embeddings"):
        document_embeddings = np.array(
            encoder.get_fused_embeddings(
                texts=sentences,
                images=images,
            )
        )
    else:
        # No native fusion support: join text and image embeddings here.
        document_embeddings = _naive_join_embeddings(
            text_embeddings, image_embeddings
        )
    return {
        "text_embeddings": text_embeddings,
        "image_embeddings": image_embeddings,
        "document_embeddings": document_embeddings,
    }
96+
97+
4598class MultimodalModel :
4699 """Base model for multimodal topic models."""
47100
@@ -65,46 +118,7 @@ def encode_multimodal(
65118 Text, image and joint document embeddings.
66119
67120 """
68- if len (sentences ) != len (images ):
69- raise ValueError ("Images and documents were not the same length." )
70- if hasattr (self .encoder_ , "get_text_embeddings" ):
71- text_embeddings = np .array (
72- self .encoder_ .get_text_embeddings (sentences )
73- )
74- else :
75- text_embeddings = self .encoder_ .encode (sentences )
76- embedding_size = text_embeddings .shape [1 ]
77- images = list (_load_images (images ))
78- if hasattr (self .encoder_ , "get_image_embeddings" ):
79- image_embeddings = np .array (
80- self .encoder_ .get_image_embeddings (images )
81- )
82- else :
83- image_embeddings = []
84- for image in images :
85- if image is not None :
86- image_embeddings .append (self .encoder_ .encode (image ))
87- else :
88- image_embeddings .append (np .full (embedding_size , np .nan ))
89- image_embeddings = np .stack (image_embeddings )
90- print (image_embeddings )
91- if hasattr (self .encoder_ , "get_fused_embeddings" ):
92- document_embeddings = np .array (
93- self .encoder_ .get_fused_embeddings (
94- texts = sentences ,
95- images = images ,
96- )
97- )
98- else :
99- document_embeddings = _naive_join_embeddings (
100- text_embeddings , image_embeddings
101- )
102-
103- return {
104- "text_embeddings" : text_embeddings ,
105- "image_embeddings" : image_embeddings ,
106- "document_embeddings" : document_embeddings ,
107- }
121+ return encode_multimodal (self .encoder_ , sentences , images )
108122
109123 @staticmethod
110124 def validate_embeddings (embeddings : Optional [MultimodalEmbeddings ]):
0 commit comments