-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembed.py
More file actions
100 lines (79 loc) · 3.87 KB
/
embed.py
File metadata and controls
100 lines (79 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
from pandas import DataFrame
from PIL import Image
from io import BytesIO
import base64
import math
from transformers import AutoImageProcessor, ResNetForImageClassification
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_client.models import VectorParams, Distance
def get_image_urls(directory: str, sample_size: int) -> list[str]:
    """Collect up to *sample_size* image paths from each recycling category.

    Args:
        directory: Root directory containing one subdirectory per category.
        sample_size: Maximum number of images taken from each category.

    Returns:
        A flat list of image paths across all five categories, in category
        order (cardboard, glass, metal, paper, plastic).
    """
    categories = ("cardboard", "glass", "metal", "paper", "plastic")
    urls: list[str] = []
    for category in categories:
        folder = f"{directory}/{category}"
        # sorted() makes the sample deterministic; os.listdir order is arbitrary.
        urls.extend(f"{folder}/{item}" for item in sorted(os.listdir(folder))[:sample_size])
    return urls
def create_data_frame(urls: list[str]) -> DataFrame:
    """Wrap image paths in a single-column DataFrame ("image_url").

    Note: the original annotation said ``urls: str``, but callers pass a
    list of paths — the annotation is corrected here; behavior is unchanged.
    """
    return DataFrame.from_records({"image_url": urls})
def create_images(data_frame: DataFrame) -> list[Image.Image]:
    """Open a PIL image for every path in the frame's "image_url" column."""
    return [Image.open(url) for url in data_frame["image_url"]]
def resize(url: str, target_width: int) -> Image.Image:
    """Open the image at *url* and scale it to *target_width* pixels wide,
    preserving the original aspect ratio.

    Args:
        url: Path to the image file.
        target_width: Desired width of the resized image in pixels.

    Returns:
        The resized PIL image.
    """
    image = Image.open(url)
    aspect_ratio = image.width / image.height
    # height = width / (width/height); the previous code multiplied by the
    # ratio instead of dividing, which distorted every non-square image.
    new_height = math.floor(target_width / aspect_ratio)
    return image.resize((target_width, new_height))
def resize_images(urls: list[str]) -> list[Image.Image]:
    """Resize every image at *urls* to a fixed 256-pixel target width."""
    return [resize(url, target_width=256) for url in urls]
def to_base64(image: Image.Image) -> str:
    """Encode *image* as a base64 JPEG string (used for lightweight previews).

    Args:
        image: The PIL image to encode.

    Returns:
        The JPEG bytes of the image, base64-encoded as a UTF-8 string.
    """
    buffer = BytesIO()
    # JPEG has no alpha channel: converting to RGB first prevents save()
    # from raising on RGBA or paletted inputs; RGB images pass through unchanged.
    image.convert("RGB").save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def get_base64_strings(images: list[Image.Image]) -> list[str]:
    """Base64-encode each PIL image via to_base64."""
    return [to_base64(image) for image in images]
def main() -> None:
    """Embed a small sample of recycling images with ResNet-50 and upload
    the vectors plus metadata to a Qdrant collection.

    Reads QDRANT_DB_URL and QDRANT_API_KEY from the environment (via .env).
    Exits early if the target collection already exists.
    """
    # Get full paths of sample images; keep the batch small so the
    # transformer forward pass stays cheap.
    total_images = 15
    categories = 5
    sample_urls = get_image_urls("recycling", total_images // categories)
    # Store image metadata in a dataframe:
    payloads = create_data_frame(sample_urls)
    payloads["type"] = "recycling"
    # Create PIL images from each url (full size — these feed the model):
    images = create_images(payloads)
    # Resize images for the preview strings only:
    resized_images = resize_images(sample_urls)
    # Create base64 string representation of images for user preview:
    payloads["base64"] = get_base64_strings(resized_images)
    # Create embeddings:
    embedding_model = "microsoft/resnet-50"
    processor = AutoImageProcessor.from_pretrained(embedding_model, use_fast=True)
    model = ResNetForImageClassification.from_pretrained(embedding_model)
    # Call processor on images to create dictionary of inputs for model:
    inputs = processor(images, return_tensors="pt")
    outputs = model(**inputs)
    # The classifier logits are used directly as embedding vectors here.
    embeddings = outputs.logits
    # Store embedding length for the collection's vector config:
    embedding_length = len(embeddings[0])
    # Initialize Qdrant client with environment variables:
    load_dotenv()
    client = QdrantClient(url=os.getenv('QDRANT_DB_URL'), api_key=os.getenv('QDRANT_API_KEY'))
    # Create recycling collection (idempotent: bail out if it already exists):
    name = "recycling_images"
    if client.collection_exists(collection_name=name):
        print("Collection already exists")
        return
    # create_collection returns a bool we don't need; don't bind it.
    client.create_collection(
        collection_name=name,
        vectors_config=VectorParams(size=embedding_length, distance=Distance.COSINE),
    )
    # Convert dataframe to a list of per-image payload dictionaries:
    payload_dictionaries = payloads.to_dict(orient="records")
    # Create records pairing payload metadata with vector embeddings.
    # .tolist() converts each torch tensor row to a plain list of floats
    # so the client can serialize it.
    records = [
        models.Record(id=i, payload=payload, vector=embeddings[i].tolist())
        for i, payload in enumerate(payload_dictionaries)
    ]
    # Upload to cluster:
    client.upload_records(collection_name=name, records=records)


if __name__ == "__main__":
    # Guard the entry point so importing this module has no side effects.
    main()