Skip to content

Commit 398dbcf

Browse files
committed
feat: enhance vector database configuration handling and improve provider initialization
1 parent 0609cf6 commit 398dbcf

5 files changed

Lines changed: 65 additions & 23 deletions

File tree

domain/ai/api.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,19 +267,24 @@ async def get_vector_db_config(request: Request, user: User = Depends(get_curren
267267
async def update_vector_db_config(
268268
request: Request, payload: VectorDBConfigPayload, user: User = Depends(get_current_active_user)
269269
):
270-
entry = get_provider_entry(payload.type)
270+
provider_type = str(payload.type or "").strip()
271+
if not provider_type:
272+
raise HTTPException(status_code=400, detail="向量数据库类型不能为空")
273+
normalized_config = VectorDBConfigManager.normalize_config(payload.config)
274+
275+
entry = get_provider_entry(provider_type)
271276
if not entry:
272277
raise HTTPException(
273-
status_code=400, detail=f"未知的向量数据库类型: {payload.type}")
278+
status_code=400, detail=f"未知的向量数据库类型: {provider_type}")
274279
if not entry.get("enabled", True):
275280
raise HTTPException(status_code=400, detail="该向量数据库类型暂不可用")
276281

277-
provider_cls = get_provider_class(payload.type)
282+
provider_cls = get_provider_class(provider_type)
278283
if not provider_cls:
279284
raise HTTPException(
280-
status_code=400, detail=f"未找到类型 {payload.type} 对应的实现")
285+
status_code=400, detail=f"未找到类型 {provider_type} 对应的实现")
281286

282-
test_provider = provider_cls(payload.config)
287+
test_provider = provider_cls(normalized_config)
283288
try:
284289
await test_provider.initialize()
285290
except Exception as exc:
@@ -293,7 +298,7 @@ async def update_vector_db_config(
293298
except Exception:
294299
pass
295300

296-
await VectorDBConfigManager.save_config(payload.type, payload.config)
301+
await VectorDBConfigManager.save_config(provider_type, normalized_config)
297302
service = VectorDBService()
298303
await service.reload()
299304
config_data = await service.current_provider()

domain/ai/service.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import json
33
from collections.abc import Iterable
4-
from typing import Any, Dict, List, Optional, Tuple
4+
from typing import Any, Dict, List, Optional, Tuple, TypeVar
55

66
import httpx
77
from tortoise.exceptions import DoesNotExist
@@ -28,16 +28,37 @@
2828
"text-embedding-ada-002": 1536,
2929
}
3030

31+
T = TypeVar("T")
32+
3133

3234
class VectorDBConfigManager:
3335
TYPE_KEY = "VECTOR_DB_TYPE"
3436
CONFIG_KEY = "VECTOR_DB_CONFIG"
3537
DEFAULT_TYPE = "milvus_lite"
3638

39+
@classmethod
40+
def normalize_type(cls, provider_type: Any) -> str:
41+
normalized = str(provider_type or cls.DEFAULT_TYPE).strip()
42+
return normalized or cls.DEFAULT_TYPE
43+
44+
@classmethod
45+
def normalize_config(cls, config: Dict[str, Any] | None) -> Dict[str, Any]:
46+
normalized: Dict[str, Any] = {}
47+
for key, value in (config or {}).items():
48+
normalized_key = str(key).strip()
49+
if not normalized_key:
50+
continue
51+
if isinstance(value, str):
52+
value = value.strip()
53+
if not value:
54+
continue
55+
normalized[normalized_key] = value
56+
return normalized
57+
3758
@classmethod
3859
async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
3960
raw_type = await ConfigService.get(cls.TYPE_KEY, cls.DEFAULT_TYPE)
40-
provider_type = str(raw_type or cls.DEFAULT_TYPE)
61+
provider_type = cls.normalize_type(raw_type)
4162

4263
raw_config = await ConfigService.get(cls.CONFIG_KEY)
4364
config_dict: Dict[str, Any] = {}
@@ -48,12 +69,14 @@ async def load_config(cls) -> Tuple[str, Dict[str, Any]]:
4869
config_dict = {}
4970
elif isinstance(raw_config, dict):
5071
config_dict = raw_config
51-
return provider_type, config_dict
72+
return provider_type, cls.normalize_config(config_dict)
5273

5374
@classmethod
5475
async def save_config(cls, provider_type: str, config: Dict[str, Any]) -> None:
55-
await ConfigService.set(cls.TYPE_KEY, provider_type)
56-
await ConfigService.set(cls.CONFIG_KEY, json.dumps(config or {}))
76+
normalized_type = cls.normalize_type(provider_type)
77+
normalized_config = cls.normalize_config(config)
78+
await ConfigService.set(cls.TYPE_KEY, normalized_type)
79+
await ConfigService.set(cls.CONFIG_KEY, json.dumps(normalized_config))
5780

5881
@classmethod
5982
async def get_type(cls) -> str:
@@ -413,6 +436,7 @@ def __init__(self):
413436
self._provider_type: Optional[str] = None
414437
self._provider_config: Dict[str, Any] | None = None
415438
self._lock = asyncio.Lock()
439+
self._operation_lock = asyncio.Lock()
416440

417441
async def _ensure_provider(self) -> BaseVectorProvider:
418442
if self._provider is None:
@@ -449,33 +473,38 @@ async def reload(self) -> BaseVectorProvider:
449473
self._provider_config = normalized_config
450474
return provider
451475

476+
async def _run_provider_call(self, provider: BaseVectorProvider, method_name: str, *args, **kwargs) -> T:
477+
method = getattr(provider, method_name)
478+
async with self._operation_lock:
479+
return await asyncio.to_thread(method, *args, **kwargs)
480+
452481
async def ensure_collection(self, collection_name: str, vector: bool = True, dim: int = DEFAULT_VECTOR_DIMENSION) -> None:
453482
provider = await self._ensure_provider()
454-
provider.ensure_collection(collection_name, vector, dim)
483+
await self._run_provider_call(provider, "ensure_collection", collection_name, vector, dim)
455484

456485
async def upsert_vector(self, collection_name: str, data: Dict[str, Any]) -> None:
457486
provider = await self._ensure_provider()
458-
provider.upsert_vector(collection_name, data)
487+
await self._run_provider_call(provider, "upsert_vector", collection_name, data)
459488

460489
async def delete_vector(self, collection_name: str, path: str) -> None:
461490
provider = await self._ensure_provider()
462-
provider.delete_vector(collection_name, path)
491+
await self._run_provider_call(provider, "delete_vector", collection_name, path)
463492

464493
async def search_vectors(self, collection_name: str, query_embedding, top_k: int = 5):
465494
provider = await self._ensure_provider()
466-
return provider.search_vectors(collection_name, query_embedding, top_k)
495+
return await self._run_provider_call(provider, "search_vectors", collection_name, query_embedding, top_k)
467496

468497
async def search_by_path(self, collection_name: str, query_path: str, top_k: int = 20):
469498
provider = await self._ensure_provider()
470-
return provider.search_by_path(collection_name, query_path, top_k)
499+
return await self._run_provider_call(provider, "search_by_path", collection_name, query_path, top_k)
471500

472501
async def get_all_stats(self) -> Dict[str, Any]:
473502
provider = await self._ensure_provider()
474-
return provider.get_all_stats()
503+
return await self._run_provider_call(provider, "get_all_stats")
475504

476505
async def clear_all_data(self) -> None:
477506
provider = await self._ensure_provider()
478-
provider.clear_all_data()
507+
await self._run_provider_call(provider, "clear_all_data")
479508

480509
async def current_provider(self) -> Dict[str, Any]:
481510
provider_type, provider_config = await VectorDBConfigManager.load_config()

domain/ai/vector_providers/milvus_lite.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from pathlib import Path
23
from typing import Any, Dict, List, Optional
34

@@ -23,12 +24,14 @@ class MilvusLiteProvider(BaseVectorProvider):
2324

2425
def __init__(self, config: Dict[str, Any] | None = None):
2526
super().__init__(config)
26-
self.db_path = Path(self.config.get("db_path") or "data/db/milvus.db")
27+
raw_db_path = self.config.get("db_path")
28+
db_path = str(raw_db_path).strip() if raw_db_path is not None else ""
29+
self.db_path = Path(db_path or "data/db/milvus.db")
2730
self.client: MilvusClient | None = None
2831

2932
async def initialize(self) -> None:
3033
try:
31-
self.client = MilvusClient(str(self.db_path))
34+
self.client = await asyncio.to_thread(MilvusClient, str(self.db_path))
3235
except Exception as exc: # pragma: no cover - depends on local environment
3336
raise RuntimeError(f"Failed to open Milvus Lite at {self.db_path}: {exc}") from exc
3437

domain/ai/vector_providers/milvus_server.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Any, Dict, List, Optional
23

34
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
@@ -32,11 +33,14 @@ def __init__(self, config: Dict[str, Any] | None = None):
3233
self.client: MilvusClient | None = None
3334

3435
async def initialize(self) -> None:
35-
uri = self.config.get("uri")
36+
uri = str(self.config.get("uri") or "").strip()
3637
if not uri:
3738
raise RuntimeError("Milvus Server URI is required")
39+
token = self.config.get("token")
40+
if isinstance(token, str):
41+
token = token.strip() or None
3842
try:
39-
self.client = MilvusClient(uri=uri, token=self.config.get("token"))
43+
self.client = await asyncio.to_thread(MilvusClient, uri=uri, token=token)
4044
except Exception as exc: # pragma: no cover - depends on remote availability
4145
raise RuntimeError(f"Failed to connect to Milvus Server {uri}: {exc}") from exc
4246

domain/ai/vector_providers/qdrant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Any, Dict, List, Optional, Sequence
23
from uuid import NAMESPACE_URL, uuid5
34

@@ -40,7 +41,7 @@ async def initialize(self) -> None:
4041
api_key = (self.config.get("api_key") or None) or None
4142
try:
4243
client = QdrantClient(url=url, api_key=api_key)
43-
client.get_collections()
44+
await asyncio.to_thread(client.get_collections)
4445
self.client = client
4546
except Exception as exc: # pragma: no cover - 依赖外部服务
4647
raise RuntimeError(f"Failed to connect to Qdrant at {url}: {exc}") from exc

0 commit comments

Comments
 (0)