Skip to content

Commit 9472ac4

Browse files
feat(ml): add qdrant ingestion
refactor: use local qdrant implementation for tests; chore: clean up imports; chore: add qdrant to py3xxML tox envs; chore: add qdrant dependency to ml_test extra; chore: run precommit; chore: make QdrantWriteTransform private; chore: add comment to CHANGES.md
1 parent 211cd66 commit 9472ac4

4 files changed

Lines changed: 510 additions & 3 deletions

File tree

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
* Updates minimum Go version to 1.26.1 ([#37897](https://github.com/apache/beam/issues/37897)).
111111
* (Python) Added image embedding support in `apache_beam.ml.rag` package ([#37628](https://github.com/apache/beam/issues/37628)).
112112
* (Python) Added support for Python version 3.14 ([#37247](https://github.com/apache/beam/issues/37247)).
113+
* (Python) Added a [Qdrant](https://qdrant.tech/) `VectorDatabaseWriteConfig` implementation ([#38141](https://github.com/apache/beam/issues/38141)).
113114

114115
## Breaking Changes
115116

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from dataclasses import dataclass, field
18+
from typing import Any, Callable, Dict, Optional
19+
20+
from qdrant_client import QdrantClient, models
21+
22+
import apache_beam as beam
23+
from apache_beam.ml.rag.types import EmbeddableItem
24+
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
25+
26+
DEFAULT_WRITE_BATCH_SIZE = 1000
27+
28+
29+
@dataclass
class QdrantConnectionParameters:
  """Connection settings used to construct a ``QdrantClient``.

  Every field is forwarded to the ``QdrantClient`` constructor under the same
  name; ``kwargs`` is expanded as additional keyword arguments.

  Exactly one of ``location``, ``url``, ``host`` or ``path`` must be set.

  Raises:
    ValueError: If none, or more than one, of ``location``, ``url``, ``host``
      and ``path`` is provided.
  """
  location: Optional[str] = None
  url: Optional[str] = None
  port: Optional[int] = 6333
  grpc_port: int = 6334
  prefer_grpc: bool = False
  https: Optional[bool] = None
  api_key: Optional[str] = None
  prefix: Optional[str] = None
  timeout: Optional[int] = None
  host: Optional[str] = None
  path: Optional[str] = None
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    targets = {
        "location": self.location,
        "url": self.url,
        "host": self.host,
        "path": self.path,
    }
    provided = [name for name, value in targets.items() if value]
    if not provided:
      raise ValueError(
          "One of location, url, host, or path must be provided for Qdrant")
    # Fail at pipeline construction time instead of when the client is built
    # on a worker: the client accepts only a single connection target.
    if len(provided) > 1:
      raise ValueError(
          "Only one of location, url, host, or path may be provided for "
          f"Qdrant, got: {', '.join(provided)}")
48+
49+
50+
@dataclass
class QdrantWriteConfig(VectorDatabaseWriteConfig):
  """Configuration for writing ``EmbeddableItem`` elements to Qdrant.

  Attributes:
    connection_params: Settings used to build the Qdrant client.
    collection_name: Collection to upsert into. Must be non-empty.
    timeout: Passed as ``timeout`` to every upsert call.
    batch_size: Number of points buffered before an upsert is issued.
    kwargs: Extra keyword arguments forwarded to every upsert call.
    dense_embedding_key: Vector name used for the dense embedding.
    sparse_embedding_key: Vector name used for the sparse embedding.

  Raises:
    ValueError: If ``collection_name`` is empty.
  """
  connection_params: QdrantConnectionParameters
  collection_name: str
  timeout: Optional[float] = None
  batch_size: int = DEFAULT_WRITE_BATCH_SIZE
  kwargs: Dict[str, Any] = field(default_factory=dict)
  dense_embedding_key: str = "dense"
  sparse_embedding_key: str = "sparse"

  def __post_init__(self):
    if not self.collection_name:
      raise ValueError("Collection name must be provided")

  def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]:
    """Returns the PTransform that writes items using this configuration."""
    return _QdrantWriteTransform(self)

  def create_converter(self) -> Callable[[EmbeddableItem], models.PointStruct]:
    """Returns a function mapping an ``EmbeddableItem`` to a ``PointStruct``.

    The returned function raises ``ValueError`` when the item has neither a
    dense nor a sparse embedding.
    """
    def convert(item: EmbeddableItem) -> models.PointStruct:
      if item.dense_embedding is None and item.sparse_embedding is None:
        raise ValueError(
            "EmbeddableItem must have at least one embedding (dense or sparse)")

      # Named vectors: each embedding is stored under its configured key.
      vector = {}
      if item.dense_embedding is not None:
        vector[self.dense_embedding_key] = item.dense_embedding
      if item.sparse_embedding is not None:
        sparse_indices, sparse_values = item.sparse_embedding
        vector[self.sparse_embedding_key] = models.SparseVector(
            indices=sparse_indices,
            values=sparse_values,
        )

      # Numeric string ids are sent as integers. isdecimal() is used instead
      # of isdigit() because isdigit() also accepts characters such as "²"
      # that int() cannot parse, which would raise an unrelated ValueError.
      point_id = item.id
      if isinstance(point_id, str) and point_id.isdecimal():
        point_id = int(point_id)

      return models.PointStruct(
          id=point_id,
          vector=vector,
          # Empty metadata is sent as no payload at all.
          payload=item.metadata if item.metadata else None,
      )

    return convert
91+
92+
93+
class _QdrantWriteTransform(beam.PTransform):
  """Converts incoming items to Qdrant points and writes them.

  The conversion function comes from ``config.create_converter()``; the
  resulting points are handed to ``_QdrantWriteFn`` for batched upserts.
  """
  def __init__(self, config: QdrantWriteConfig):
    self.config = config

  def expand(self, input_or_inputs: beam.PCollection[EmbeddableItem]):
    to_point = self.config.create_converter()
    points = input_or_inputs | "Convert to Records" >> beam.Map(to_point)
    return points | beam.ParDo(_QdrantWriteFn(self.config))
102+
103+
104+
class _QdrantWriteFn(beam.DoFn):
  """Buffers points and upserts them to a Qdrant collection in batches.

  A batch is written once ``config.batch_size`` points have accumulated, and
  whatever remains is written when the bundle finishes. The client is created
  in ``setup`` and closed in ``teardown``.
  """
  def __init__(self, config: QdrantWriteConfig):
    self.config = config
    self._batch: list = []
    self._client: Optional[QdrantClient] = None

  def setup(self):
    params = self.config.connection_params
    # check_compatibility defaults to False but is merged with the user
    # supplied kwargs, so passing it there overrides the default instead of
    # raising "got multiple values for keyword argument".
    extra_kwargs = {"check_compatibility": False, **params.kwargs}
    self._client = QdrantClient(
        location=params.location,
        url=params.url,
        port=params.port,
        grpc_port=params.grpc_port,
        prefer_grpc=params.prefer_grpc,
        https=params.https,
        api_key=params.api_key,
        prefix=params.prefix,
        timeout=params.timeout,
        host=params.host,
        path=params.path,
        **extra_kwargs,
    )

  def start_bundle(self):
    # A flush that raises leaves its points in the buffer. Start every bundle
    # empty so they are not written as part of a later bundle, should this
    # instance be reused (NOTE(review): reuse after failure is runner
    # dependent).
    self._batch = []

  def process(self, element, *args, **kwargs):
    self._batch.append(element)
    if len(self._batch) >= self.config.batch_size:
      self._flush()

  def finish_bundle(self):
    self._flush()

  def teardown(self):
    if self._client:
      self._client.close()
      self._client = None

  def _flush(self):
    """Upserts the buffered points, then clears the buffer.

    Raises:
      RuntimeError: If there are points to write but no client exists.
    """
    if not self._batch:
      return
    if not self._client:
      raise RuntimeError("Qdrant client is not initialized")
    self._client.upsert(
        collection_name=self.config.collection_name,
        points=self._batch,
        timeout=self.config.timeout,
        **self.config.kwargs,
    )
    # Only reached when the upsert succeeded.
    self._batch = []

  def display_data(self):
    res = super().display_data()
    res["collection"] = self.config.collection_name
    res["batch_size"] = self.config.batch_size
    return res

0 commit comments

Comments
 (0)