-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmain.py
More file actions
162 lines (145 loc) · 5.18 KB
/
main.py
File metadata and controls
162 lines (145 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from dotenv import load_dotenv
from psycopg_pool import ConnectionPool
from pgvector.psycopg import register_vector
import functools
import cocoindex
import os
from numpy.typing import NDArray
import numpy as np
@cocoindex.transform_flow()
def code_to_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[NDArray[np.float32]]:
    """
    Embed the text using a SentenceTransformer model.

    This transform is used both by the indexing flow (once per chunk) and by
    ``search`` (once per query string), so stored chunks and incoming queries
    are embedded with the same model.
    """
    # You can also switch to a Gemini embedding model:
    # return text.transform(
    #     cocoindex.functions.EmbedText(
    #         api_type=cocoindex.LlmApiType.GEMINI,
    #         model="text-embedding-004",
    #     )
    # )
    # NOTE(review): switching models changes the embedding space (and likely
    # the vector dimension); confirm the existing index gets rebuilt.
    return text.transform(
        cocoindex.functions.SentenceTransformerEmbed(
            model="sentence-transformers/all-MiniLM-L6-v2"
        )
    )
@cocoindex.flow_def(name="CodeEmbedding")
def code_embedding_flow(
    flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
) -> None:
    """
    Index source and documentation files as embedded chunks in Postgres.

    Files matching the included patterns under ``../..`` are read. Each file is
    tagged with its detected programming language, split into chunks, and every
    chunk is embedded via ``code_to_embedding``. The chunks are exported to the
    ``code_embeddings`` Postgres target, keyed by ``(filename, location)``, with
    a cosine-similarity vector index on ``embedding``.
    """
    # NOTE(review): a relative path, presumably resolved against the process's
    # working directory rather than this file's location; confirm.
    local_files = cocoindex.sources.LocalFile(
        path=os.path.join("..", ".."),
        included_patterns=["*.py", "*.rs", "*.toml", "*.md", "*.mdx"],
        excluded_patterns=["**/.*", "target", "**/node_modules"],
    )
    data_scope["files"] = flow_builder.add_source(local_files)

    collector = data_scope.add_collector()

    with data_scope["files"].row() as doc:
        # The detected language feeds the splitter so chunk boundaries can
        # follow the file's syntax.
        doc["language"] = doc["filename"].transform(
            cocoindex.functions.DetectProgrammingLanguage()
        )
        doc["chunks"] = doc["content"].transform(
            cocoindex.functions.SplitRecursively(),
            language=doc["language"],
            chunk_size=1000,
            min_chunk_size=300,
            chunk_overlap=300,
        )
        with doc["chunks"].row() as piece:
            piece["embedding"] = piece["text"].call(code_to_embedding)
            collector.collect(
                filename=doc["filename"],
                location=piece["location"],
                code=piece["text"],
                embedding=piece["embedding"],
                start=piece["start"],
                end=piece["end"],
            )

    cosine_index = cocoindex.VectorIndexDef(
        field_name="embedding",
        metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
    )
    collector.export(
        "code_embeddings",
        cocoindex.targets.Postgres(),
        primary_key_fields=["filename", "location"],
        vector_indexes=[cosine_index],
    )
@functools.cache
def connection_pool() -> ConnectionPool:
    """
    Get a connection pool to the database.

    The pool is built on the first call from the ``COCOINDEX_DATABASE_URL``
    environment variable and then cached, so all callers in the process share
    one pool.

    Raises:
        KeyError: if ``COCOINDEX_DATABASE_URL`` is not set.
    """
    # open=True is the current default, but psycopg_pool >= 3.2 warns when the
    # pool is relied on to open implicitly, and the default is slated to
    # become False. Being explicit keeps today's behavior without the warning.
    return ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"], open=True)
# Number of nearest chunks returned per search.
TOP_K = 5


# Declaring it as a query handler, so that you can easily run queries in CocoInsight.
@code_embedding_flow.query_handler(
    result_fields=cocoindex.QueryHandlerResultFields(
        embedding=["embedding"], score="score"
    )
)
def search(query: str) -> cocoindex.QueryOutput:
    """
    Return the TOP_K indexed chunks nearest to ``query``.

    The query is embedded with ``code_to_embedding``, the same transform used
    at indexing time. Rows are ranked by pgvector's ``<=>`` distance, and each
    result's ``score`` is reported as ``1.0 - distance``.
    """
    # Table backing the "code_embeddings" export target of the flow.
    target_table = cocoindex.utils.get_target_default_name(
        code_embedding_flow, "code_embeddings"
    )
    embedded_query = code_to_embedding.eval(query)
    # The table name is interpolated, not parameterized; it comes from
    # cocoindex, not from the user.
    sql = f"""
SELECT filename, code, embedding, embedding <=> %s AS distance, start, "end"
FROM {target_table} ORDER BY distance LIMIT %s
"""
    with connection_pool().connection() as conn:
        # Register pgvector's type adapters on this (pooled) connection.
        register_vector(conn)
        with conn.cursor() as cur:
            cur.execute(sql, (embedded_query, TOP_K))
            matches = [
                {
                    "filename": filename,
                    "code": code,
                    "embedding": embedding,
                    "score": 1.0 - distance,
                    "start": start,
                    "end": end,
                }
                for filename, code, embedding, distance, start, end in cur.fetchall()
            ]
            return cocoindex.QueryOutput(
                query_info=cocoindex.QueryInfo(
                    embedding=embedded_query,
                    similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
                ),
                results=matches,
            )
def _main() -> None:
    """
    Update the index, then interactively answer search queries from stdin.

    The loop ends on an empty input line, or when stdin reaches end-of-file
    (e.g. Ctrl-D, or piped input running out).
    """
    # Make sure the flow is built and up-to-date.
    stats = code_embedding_flow.update()
    print("Updated index: ", stats)
    # Run queries in a loop to demonstrate the query capabilities.
    while True:
        try:
            query = input("Enter search query (or Enter to quit): ")
        except EOFError:
            # input() raises EOFError on end-of-file; quit cleanly instead of
            # crashing with a traceback. Emit a newline since the prompt line
            # was never terminated.
            print()
            break
        if query == "":
            break
        # Run the query function with the database connection pool and the query.
        query_output = search(query)
        print("\nSearch results:")
        for result in query_output.results:
            print(
                f"[{result['score']:.3f}] {result['filename']} (L{result['start']['line']}-L{result['end']['line']})"
            )
            print(f" {result['code']}")
            print("---")
        print()


if __name__ == "__main__":
    # Load .env before cocoindex.init() so settings defined there are visible.
    load_dotenv()
    cocoindex.init()
    _main()