@@ -266,14 +266,14 @@ async def _openai_chat_completion_stream_async(self, messages: List[Dict[str, An
266266
267267 def _request_embedding (self , text : str , ** kwargs ) -> List [float ]:
268268 try :
269- if self .provider == LLMProvider .OPENAI .value :
270- response = self .client .embeddings .create (model = self .model , input = [text ])
271- embedding = response .data [0 ].embedding
272- else :
269+ if self .provider == LLMProvider .DOUBAO .value :
273270 response = self .client .multimodal_embeddings .create (
274271 model = self .model , input = [{"type" : "text" , "text" : text }]
275272 )
276273 embedding = response .data .embedding
274+ else :
275+ response = self .client .embeddings .create (model = self .model , input = [text ])
276+ embedding = response .data [0 ].embedding
277277
278278 # Record token usage
279279 if hasattr (response , "usage" ) and response .usage :
@@ -313,14 +313,14 @@ def _request_embedding(self, text: str, **kwargs) -> List[float]:
313313
314314 async def _request_embedding_async (self , text : str , ** kwargs ) -> List [float ]:
315315 try :
316- if self .provider == LLMProvider .OPENAI .value :
317- response = await self .async_client .embeddings .create (model = self .model , input = [text ])
318- embedding = response .data [0 ].embedding
319- else :
316+ if self .provider == LLMProvider .DOUBAO .value :
320317 response = self .client .multimodal_embeddings .create (
321318 model = self .model , input = [{"type" : "text" , "text" : text }]
322319 )
323320 embedding = response .data .embedding
321+ else :
322+ response = await self .async_client .embeddings .create (model = self .model , input = [text ])
323+ embedding = response .data [0 ].embedding
324324
325325 # Record token usage
326326 if hasattr (response , "usage" ) and response .usage :
@@ -476,20 +476,20 @@ def _extract_error_summary(error: Any) -> str:
476476
477477 elif self .llm_type == LLMType .EMBEDDING :
478478 # Test with a simple text
479- if self .provider == LLMProvider .OPENAI .value :
480- response = self .client .embeddings .create (model = self .model , input = ["test" ])
481- if response .data and len (response .data ) > 0 and response .data [0 ].embedding :
482- return True , "Embedding model validation successful"
483- else :
484- return False , "Embedding model returned empty response"
485- else :
479+ if self .provider == LLMProvider .DOUBAO .value :
486480 response = self .client .multimodal_embeddings .create (
487481 model = self .model , input = [{"type" : "text" , "text" : "test" }]
488482 )
489483 if response .data and response .data .embedding :
490484 return True , "Embedding model validation successful"
491485 else :
492486 return False , "Embedding model returned empty response"
487+ else :
488+ response = self .client .embeddings .create (model = self .model , input = ["test" ])
489+ if response .data and len (response .data ) > 0 and response .data [0 ].embedding :
490+ return True , "Embedding model validation successful"
491+ else :
492+ return False , "Embedding model returned empty response"
493493 else :
494494 return False , f"Unsupported LLM type: { self .llm_type } "
495495
0 commit comments