Skip to content

Commit fc2ddb1

Browse files
authored
fix(llm): Fix when use custom embedding providers. (#339)
* fix(llm): Fix when use custom embedding providers. * fix validate. * fix validate.
1 parent fff9b3d commit fc2ddb1

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

opencontext/llm/llm_client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,14 +266,14 @@ async def _openai_chat_completion_stream_async(self, messages: List[Dict[str, An
266266

267267
def _request_embedding(self, text: str, **kwargs) -> List[float]:
268268
try:
269-
if self.provider == LLMProvider.DOUBAO.value:
269+
if self.provider != LLMProvider.DOUBAO.value:
270+
response = self.client.embeddings.create(model=self.model, input=[text])
271+
embedding = response.data[0].embedding
272+
else:
270273
response = self.client.multimodal_embeddings.create(
271274
model=self.model, input=[{"type": "text", "text": text}]
272275
)
273276
embedding = response.data.embedding
274-
else:
275-
response = self.client.embeddings.create(model=self.model, input=[text])
276-
embedding = response.data[0].embedding
277277

278278
# Record token usage
279279
if hasattr(response, "usage") and response.usage:
@@ -314,6 +314,7 @@ def _request_embedding(self, text: str, **kwargs) -> List[float]:
314314
async def _request_embedding_async(self, text: str, **kwargs) -> List[float]:
315315
try:
316316
if self.provider == LLMProvider.DOUBAO.value:
317+
# Only ark has multimodal_embeddings
317318
response = self.client.multimodal_embeddings.create(
318319
model=self.model, input=[{"type": "text", "text": text}]
319320
)

0 commit comments

Comments
 (0)