11import asyncio
22import json
33from collections .abc import Iterable
4- from typing import Any , Dict , List , Optional , Tuple
4+ from typing import Any , Dict , List , Optional , Tuple , TypeVar
55
66import httpx
77from tortoise .exceptions import DoesNotExist
2828 "text-embedding-ada-002" : 1536 ,
2929}
3030
31+ T = TypeVar ("T" )
32+
3133
3234class VectorDBConfigManager :
3335 TYPE_KEY = "VECTOR_DB_TYPE"
3436 CONFIG_KEY = "VECTOR_DB_CONFIG"
3537 DEFAULT_TYPE = "milvus_lite"
3638
39+ @classmethod
40+ def normalize_type (cls , provider_type : Any ) -> str :
41+ normalized = str (provider_type or cls .DEFAULT_TYPE ).strip ()
42+ return normalized or cls .DEFAULT_TYPE
43+
44+ @classmethod
45+ def normalize_config (cls , config : Dict [str , Any ] | None ) -> Dict [str , Any ]:
46+ normalized : Dict [str , Any ] = {}
47+ for key , value in (config or {}).items ():
48+ normalized_key = str (key ).strip ()
49+ if not normalized_key :
50+ continue
51+ if isinstance (value , str ):
52+ value = value .strip ()
53+ if not value :
54+ continue
55+ normalized [normalized_key ] = value
56+ return normalized
57+
3758 @classmethod
3859 async def load_config (cls ) -> Tuple [str , Dict [str , Any ]]:
3960 raw_type = await ConfigService .get (cls .TYPE_KEY , cls .DEFAULT_TYPE )
40- provider_type = str ( raw_type or cls .DEFAULT_TYPE )
61+ provider_type = cls .normalize_type ( raw_type )
4162
4263 raw_config = await ConfigService .get (cls .CONFIG_KEY )
4364 config_dict : Dict [str , Any ] = {}
@@ -48,12 +69,14 @@ async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
4869 config_dict = {}
4970 elif isinstance (raw_config , dict ):
5071 config_dict = raw_config
51- return provider_type , config_dict
72+ return provider_type , cls . normalize_config ( config_dict )
5273
@classmethod
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
    """Persist the provider type and configuration in normalized form.

    Both inputs are passed through ``normalize_type`` / ``normalize_config``
    before being written, so the stored values are already stripped and
    free of blank entries.

    Args:
        provider_type: Provider identifier to store under ``cls.TYPE_KEY``.
        config: Provider settings; stored under ``cls.CONFIG_KEY`` as a
            JSON string.
    """
    normalized_type = cls.normalize_type(provider_type)
    normalized_config = cls.normalize_config(config)
    await ConfigService.set(cls.TYPE_KEY, normalized_type)
    await ConfigService.set(cls.CONFIG_KEY, json.dumps(normalized_config))
5780
5881 @classmethod
5982 async def get_type (cls ) -> str :
@@ -413,6 +436,7 @@ def __init__(self):
413436 self ._provider_type : Optional [str ] = None
414437 self ._provider_config : Dict [str , Any ] | None = None
415438 self ._lock = asyncio .Lock ()
439+ self ._operation_lock = asyncio .Lock ()
416440
417441 async def _ensure_provider (self ) -> BaseVectorProvider :
418442 if self ._provider is None :
@@ -449,33 +473,38 @@ async def reload(self) -> BaseVectorProvider:
449473 self ._provider_config = normalized_config
450474 return provider
451475
async def _run_provider_call(self, provider: BaseVectorProvider, method_name: str, *args: Any, **kwargs: Any) -> Any:
    """Run a provider method in a worker thread while holding the operation lock.

    The method is looked up by name on ``provider`` and executed through
    ``asyncio.to_thread``. ``self._operation_lock`` is held for the whole
    call, so calls made through this helper on the same instance run one
    at a time.

    Args:
        provider: The provider object whose method is invoked.
        method_name: Name of the attribute to call on ``provider``.
        *args: Positional arguments forwarded to the method.
        **kwargs: Keyword arguments forwarded to the method.

    Returns:
        Whatever the provider method returns.

    Raises:
        AttributeError: If ``provider`` has no attribute ``method_name``.
    """
    # Resolve the attribute before taking the lock so a bad name fails fast.
    method = getattr(provider, method_name)
    async with self._operation_lock:
        return await asyncio.to_thread(method, *args, **kwargs)
480+
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
    """Delegate to the active provider's ``ensure_collection``.

    The provider is obtained from ``_ensure_provider`` and the call is
    routed through ``_run_provider_call`` instead of being invoked directly
    in this coroutine.

    Args:
        collection_name: Forwarded to the provider as the collection name.
        vector: Forwarded to the provider unchanged.
        dim: Forwarded to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "ensure_collection", collection_name, vector, dim)
455484
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
    """Delegate to the active provider's ``upsert_vector``.

    The call is routed through ``_run_provider_call`` instead of being
    invoked directly in this coroutine.

    Args:
        collection_name: Forwarded to the provider as the collection name.
        data: Forwarded to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "upsert_vector", collection_name, data)
459488
async def delete_vector(self, collection_name: str, path: str) -> None:
    """Delegate to the active provider's ``delete_vector``.

    The call is routed through ``_run_provider_call`` instead of being
    invoked directly in this coroutine.

    Args:
        collection_name: Forwarded to the provider as the collection name.
        path: Forwarded to the provider unchanged.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "delete_vector", collection_name, path)
463492
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
    """Delegate to the active provider's ``search_vectors`` and return its result.

    The call is routed through ``_run_provider_call`` instead of being
    invoked directly in this coroutine.

    Args:
        collection_name: Forwarded to the provider as the collection name.
        query_embedding: Forwarded to the provider unchanged.
        top_k: Forwarded to the provider unchanged.

    Returns:
        Whatever the provider's ``search_vectors`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "search_vectors", collection_name, query_embedding, top_k)
467496
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
    """Delegate to the active provider's ``search_by_path`` and return its result.

    The call is routed through ``_run_provider_call`` instead of being
    invoked directly in this coroutine.

    Args:
        collection_name: Forwarded to the provider as the collection name.
        query_path: Forwarded to the provider unchanged.
        top_k: Forwarded to the provider unchanged.

    Returns:
        Whatever the provider's ``search_by_path`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "search_by_path", collection_name, query_path, top_k)
471500
async def get_all_stats(self) -> Dict[str, Any]:
    """Delegate to the active provider's ``get_all_stats`` and return its result.

    The call is routed through ``_run_provider_call`` instead of being
    invoked directly in this coroutine.

    Returns:
        Whatever the provider's ``get_all_stats`` returns.
    """
    provider = await self._ensure_provider()
    return await self._run_provider_call(provider, "get_all_stats")
475504
async def clear_all_data(self) -> None:
    """Delegate to the active provider's ``clear_all_data``.

    The call is routed through ``_run_provider_call`` instead of being
    invoked directly in this coroutine.
    """
    provider = await self._ensure_provider()
    await self._run_provider_call(provider, "clear_all_data")
479508
480509 async def current_provider (self ) -> Dict [str , Any ]:
481510 provider_type , provider_config = await VectorDBConfigManager .load_config ()
0 commit comments