
Commit f758087

fix: correct embedding provider detection for non-Doubao providers (#347)
1 parent 526e2cd commit f758087

1 file changed

Lines changed: 15 additions & 15 deletions

File tree

opencontext/llm/llm_client.py

@@ -266,14 +266,14 @@ async def _openai_chat_completion_stream_async(self, messages: List[Dict[str, An
 
     def _request_embedding(self, text: str, **kwargs) -> List[float]:
         try:
-            if self.provider == LLMProvider.OPENAI.value:
-                response = self.client.embeddings.create(model=self.model, input=[text])
-                embedding = response.data[0].embedding
-            else:
+            if self.provider == LLMProvider.DOUBAO.value:
                 response = self.client.multimodal_embeddings.create(
                     model=self.model, input=[{"type": "text", "text": text}]
                 )
                 embedding = response.data.embedding
+            else:
+                response = self.client.embeddings.create(model=self.model, input=[text])
+                embedding = response.data[0].embedding
 
             # Record token usage
             if hasattr(response, "usage") and response.usage:
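Why the check was inverted: with the old code, any provider other than OPENAI (including self-hosted or otherwise OpenAI-compatible endpoints) fell into the else branch and was sent to Doubao's multimodal_embeddings endpoint, which only the Doubao SDK exposes. Treating DOUBAO as the special case routes every other provider through the standard embeddings API. A minimal sketch of the corrected dispatch; the enum string values and the helper name are assumptions for illustration, not taken from the repository:

from enum import Enum
from typing import List


class LLMProvider(Enum):
    # String values assumed for illustration; the real enum lives in opencontext.
    OPENAI = "openai"
    DOUBAO = "doubao"


def request_embedding(client, provider: str, model: str, text: str) -> List[float]:
    if provider == LLMProvider.DOUBAO.value:
        # Doubao's SDK uses a multimodal endpoint and returns a single data object.
        response = client.multimodal_embeddings.create(
            model=model, input=[{"type": "text", "text": text}]
        )
        return response.data.embedding
    # OpenAI and every other OpenAI-compatible provider use the standard endpoint.
    response = client.embeddings.create(model=model, input=[text])
    return response.data[0].embedding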
@@ -313,14 +313,14 @@ def _request_embedding(self, text: str, **kwargs) -> List[float]:
 
     async def _request_embedding_async(self, text: str, **kwargs) -> List[float]:
         try:
-            if self.provider == LLMProvider.OPENAI.value:
-                response = await self.async_client.embeddings.create(model=self.model, input=[text])
-                embedding = response.data[0].embedding
-            else:
+            if self.provider == LLMProvider.DOUBAO.value:
                 response = self.client.multimodal_embeddings.create(
                     model=self.model, input=[{"type": "text", "text": text}]
                 )
                 embedding = response.data.embedding
+            else:
+                response = await self.async_client.embeddings.create(model=self.model, input=[text])
+                embedding = response.data[0].embedding
 
             # Record token usage
             if hasattr(response, "usage") and response.usage:
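The async variant gets the same inversion: non-Doubao providers now await self.async_client.embeddings.create, while the Doubao branch keeps its synchronous multimodal call, as the unchanged context lines show. A hedged usage sketch; only _request_embedding_async comes from the diff, the construction below is an assumption:

async def embed(client, text: str) -> list:
    # With the fix, an OpenAI-compatible provider awaits the standard
    # embeddings endpoint instead of being routed to the Doubao path.
    return await client._request_embedding_async(text)

# Assumed usage, not taken from the repository:
# import asyncio
# client = LLMClient(provider="openai", model="text-embedding-3-small")
# vector = asyncio.run(embed(client, "hello world"))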
@@ -476,20 +476,20 @@ def _extract_error_summary(error: Any) -> str:
 
         elif self.llm_type == LLMType.EMBEDDING:
             # Test with a simple text
-            if self.provider == LLMProvider.OPENAI.value:
-                response = self.client.embeddings.create(model=self.model, input=["test"])
-                if response.data and len(response.data) > 0 and response.data[0].embedding:
-                    return True, "Embedding model validation successful"
-                else:
-                    return False, "Embedding model returned empty response"
-            else:
+            if self.provider == LLMProvider.DOUBAO.value:
                 response = self.client.multimodal_embeddings.create(
                     model=self.model, input=[{"type": "text", "text": "test"}]
                 )
                 if response.data and response.data.embedding:
                     return True, "Embedding model validation successful"
                 else:
                     return False, "Embedding model returned empty response"
+            else:
+                response = self.client.embeddings.create(model=self.model, input=["test"])
+                if response.data and len(response.data) > 0 and response.data[0].embedding:
+                    return True, "Embedding model validation successful"
+                else:
+                    return False, "Embedding model returned empty response"
         else:
             return False, f"Unsupported LLM type: {self.llm_type}"
 
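The validation path applies the same Doubao-first check with a "test" input. A standalone sketch that mirrors the branch above (the function name and the bare "doubao" string are assumptions; the result messages come from the diff):

from typing import Tuple


def validate_embedding_model(client, provider: str, model: str) -> Tuple[bool, str]:
    if provider == "doubao":  # assumed string value of LLMProvider.DOUBAO
        # Doubao: multimodal endpoint, single data object.
        response = client.multimodal_embeddings.create(
            model=model, input=[{"type": "text", "text": "test"}]
        )
        ok = bool(response.data and response.data.embedding)
    else:
        # All other providers: standard embeddings endpoint, list of data items.
        response = client.embeddings.create(model=model, input=["test"])
        ok = bool(response.data and len(response.data) > 0 and response.data[0].embedding)
    if ok:
        return True, "Embedding model validation successful"
    return False, "Embedding model returned empty response"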
