# 图像搜索
# 简介
图像搜索已成为一种流行且强大的应用程序,使用户能够通过匹配特征或视觉内容来查找相似的图像。随着计算机视觉和深度学习的快速发展,这种能力得到了极大的增强。
本指南旨在帮助您利用最新的技术和工具进行图像搜索。在本指南中,您将学习如何:
- 使用公共数据集和模型创建具有向量嵌入的数据集
- 使用 MyScale 进行图像相似性搜索,MyScale 是一个强大的平台,可以简化搜索过程并提供快速准确的结果
如果您更感兴趣于探索 MyScale 的功能,请随意跳过 构建数据集 部分,直接进入 将数据导入 MyScale 部分。
您可以按照 导入数据 部分中提供的说明,在 MyScale 控制台上导入此数据集。导入后,您可以直接转到 查询 MyScale 部分,享受此示例应用程序。
# 先决条件
在开始之前,我们需要安装 clickhouse python client (opens new window) 和 HuggingFace 的 datasets
库以下载示例数据。
pip install datasets clickhouse-connect
要按照 构建数据集 部分中概述的步骤进行操作,我们需要安装 transformers 和其他必要的依赖项。
pip install requests transformers torch tqdm
# 构建数据集
# 下载和处理数据
我们从 unsplash 数据集 (opens new window) 下载数据,并使用 Lite 数据集。
wget https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip
# 将下载的文件解压到临时目录
unzip unsplash-research-dataset-lite-latest.zip -d tmp
我们读取下载的数据并将其转换为 Pandas 数据帧。
import numpy as np
import pandas as pd
import glob
documents = ['photos', 'conversions']
datasets = {}
for doc in documents:
files = glob.glob("tmp/" + doc + ".tsv*")
subsets = []
for filename in files:
df = pd.read_csv(filename, sep='\t', header=0)
subsets.append(df)
datasets[doc] = pd.concat(subsets, axis=0, ignore_index=True)
df_photos = datasets['photos']
df_conversions = datasets['conversions']
# 生成图像嵌入
为了从图像中提取嵌入,我们定义了一个 extract_image_features
函数,该函数利用 HuggingFace 的 clip-vit-base-patch32 (opens new window) 模型。生成的嵌入是 512 维向量。
import torch
from transformers import CLIPProcessor, CLIPModel
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def extract_image_features(image):
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model.get_image_features(**inputs)
outputs = outputs / outputs.norm(dim=-1, keepdim=True)
return outputs.squeeze(0).tolist()
然后,我们从 df_photos
数据帧中选择前 1000 个照片 ID,下载相应的图像,并使用 extract_image_features
函数提取它们的图像嵌入。
from PIL import Image
import requests
from tqdm.auto import tqdm
# 选择前 1000 个照片 ID
photo_ids = df_photos['photo_id'][:1000].tolist()
# 创建一个只包含所选照片 ID 的新数据帧
df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index(drop=True)
# 在数据帧中只保留 'photo_id' 和 'photo_image_url' 两列
df_photos = df_photos[['photo_id', 'photo_image_url']]
# 在数据帧中添加一个新列 'photo_embed'
df_photos['photo_embed'] = None
# 下载图像并使用 'extract_image_features' 函数提取它们的嵌入
for i, row in tqdm(df_photos.iterrows(), total=len(df_photos)):
# 通过修改图像 URL 构造一个较小尺寸的图像的下载 URL
url = row['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"
try:
res = requests.get(url, stream=True).raw
image = Image.open(res)
except:
# 如果图像下载失败,则移除该照片
photo_ids.remove(row['photo_id'])
continue
# 提取特征嵌入
df_photos.at[i, 'photo_embed'] = extract_image_features(image)
# 创建数据集
现在我们有了两个数据帧:一个用于包含嵌入的照片信息,另一个用于转换信息。
df_photos = df_photos[df_photos['photo_id'].isin(photo_ids)].reset_index().rename(columns={'index': 'id'})
df_conversions = df_conversions[df_conversions['photo_id'].isin(photo_ids)].reset_index(drop=True)
df_conversions = df_conversions[['photo_id', 'keyword']].reset_index().rename(columns={'index': 'id'})
最后,我们将数据帧转换为 Parquet 文件,然后按照 步骤 (opens new window) 将它们上传到 Hugging Face 仓库 myscale/unsplash-examples (opens new window) 中。这样可以方便地访问和共享数据。
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
# 从数据和模式创建一个 Table 对象
photos_table = pa.Table.from_pandas(df_photos)
conversion_table = pa.Table.from_pandas(df_conversions)
# 将表写入 Parquet 文件
pq.write_table(photos_table, 'photos.parquet')
pq.write_table(conversion_table, 'conversions.parquet')
# 将数据导入 MyScale
# 加载数据
要将数据导入 MyScale,首先,我们从之前部分创建的 HuggingFace 数据集 myscale/unsplash-examples (opens new window) 中加载数据,并将其转换为 Panda DataFrames。以下代码片段显示了如何加载数据并将其转换为 Panda DataFrames。
注意:photo_embed
是一个 512 维浮点向量,表示使用 CLIP (opens new window) 模型从图像中提取的图像特征。
from datasets import load_dataset
photos = load_dataset("myscale/unsplash-examples", data_files="photos-all.parquet", split="train")
conversions = load_dataset("myscale/unsplash-examples", data_files="conversions-all.parquet", split="train")
# 将数据集转换为 Panda DataFrame
photo_df = photos.to_pandas()
conversion_df = conversions.to_pandas()
# 将 photo_embed 从 np 数组转换为列表
photo_df['photo_embed'] = photo_df['photo_embed'].apply(lambda x: x.tolist())
# 创建表
接下来,我们在 MyScale 中创建表。在开始之前,您需要从 MyScale 控制台中获取集群主机、用户名和密码信息。以下代码片段创建了两个表,一个用于照片信息,另一个用于转换信息。
import clickhouse_connect
# 初始化客户端
client = clickhouse_connect.get_client(host='YOUR_CLUSTER_HOST', port=443, username='YOUR_USERNAME', password='YOUR_CLUSTER_PASSWORD')
# 如果存在,则删除表
client.command("DROP TABLE IF EXISTS default.myscale_photos")
client.command("DROP TABLE IF EXISTS default.myscale_conversions")
# 创建照片表
client.command("""
CREATE TABLE default.myscale_photos
(
id UInt64,
photo_id String,
photo_image_url String,
photo_embed Array(Float32),
CONSTRAINT vector_len CHECK length(photo_embed) = 512
)
ORDER BY id
""")
# 创建转换表
client.command("""
CREATE TABLE default.myscale_conversions
(
id UInt64,
photo_id String,
keyword String
)
ORDER BY id
""")
# 导入数据
在创建表之后,我们将从数据集中加载的数据插入到表中,并创建一个向量索引以加速后续的向量搜索查询。以下代码片段显示了如何将数据插入到表中,并使用余弦距离度量创建一个向量索引。
# 从数据集中上传数据
client.insert("default.myscale_photos", photo_df.to_records(index=False).tolist(),
column_names=photo_df.columns.tolist())
client.insert("default.myscale_conversions", conversion_df.to_records(index=False).tolist(),
column_names=conversion_df.columns.tolist())
# 检查插入数据的数量
print(f"照片数量: {client.command('SELECT count(*) FROM default.myscale_photos')}")
print(f"转换数量: {client.command('SELECT count(*) FROM default.myscale_conversions')}")
# 使用余弦距离创建向量索引
client.command("""
ALTER TABLE default.myscale_photos
ADD VECTOR INDEX photo_embed_index photo_embed
TYPE MSTG
('metric_type=Cosine')
""")
# 检查向量索引的状态,确保向量索引的状态为 'Built'
get_index_status="SELECT status FROM system.vector_indices WHERE name='photo_embed_index'"
print(f"索引构建状态: {client.command(get_index_status)}")
# 查询 MyScale
# 查找前 K 个相似图像
要使用向量搜索查找前 K 个相似图像,请按照以下步骤进行操作:
首先,让我们随机选择一张图像并使用 show_image() 函数显示它。
import requests
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
# 使用 URL 下载图像
def download(url):
response = requests.get(url)
return Image.open(BytesIO(response.content))
# 定义一个方法,使用 URL 显示在线图像
def show_image(url, title=None):
img = download(url)
fig = plt.figure(figsize=(4, 4))
plt.imshow(img)
plt.show()
# 显示每个表中的行数
print(f"照片数量: {client.command('SELECT count(*) FROM default.myscale_photos')}")
print(f"转换数量: {client.command('SELECT count(*) FROM default.myscale_conversions')}")
# 从表中随机选择一张图像作为目标
random_image = client.query("SELECT * FROM default.myscale_photos ORDER BY rand() LIMIT 1")
assert random_image.row_count == 1
target_image_id = random_image.first_item["photo_id"]
target_image_url = random_image.first_item["photo_image_url"]
target_image_embed = random_image.first_item["photo_embed"]
print("当前选择的图像 id={}, url={}".format(target_image_id, target_image_url))
# 显示目标图像
print("正在加载目标图像...")
show_image(target_image_url)
一张示例图像:
然后,使用向量搜索来识别与所选图像最相似的前 K 个候选项,并显示这些候选项:
# 查询数据库以找到与给定图像最相似的前 K 个图像
top_k = 10
results = client.query(f"""
SELECT photo_id, photo_image_url, distance(photo_embed, {target_image_embed}) as dist
FROM default.myscale_photos
WHERE photo_id != '{target_image_id}'
ORDER BY dist
LIMIT {top_k}
""")
# 下载图像并将其添加到列表中
images_url = []
for r in results.named_results():
# 通过修改图像 URL 构造一个较小尺寸的图像的下载 URL
url = r['photo_image_url'] + "?q=75&fm=jpg&w=200&fit=max"
images_url.append(download(url))
# 显示候选图像
print("正在加载候选图像...")
for row in range(int(top_k / 5)):
fig, axs = plt.subplots(1, 5, figsize=(20, 4))
for i, img in enumerate(images_url[row * 5:row * 5 + 5]):
axs[i % 5].imshow(img)
plt.show()
相似的候选图像:
# 分析每个候选图像的转换信息
在识别出前 K 个相似图像之后,您可以使用结构化字段和向量字段的 SQL 查询来对每个候选的转换信息进行分析。
要计算每个候选图像的总转换次数,您可以使用以下 SQL 查询将图像搜索结果与 conversions
表进行连接:
# 显示每个候选图像的总下载次数
results = client.query(f"""
SELECT photo_id, count(*) as count
FROM default.myscale_conversions
JOIN (
SELECT photo_id, distance(photo_embed, {target_image_embed}) as dist
FROM default.myscale_photos
ORDER BY dist ASC
LIMIT {top_k}
) AS target_photos
ON default.myscale_conversions.photo_id = target_photos.photo_id
GROUP BY photo_id
ORDER BY count DESC
""")
print("每个候选的总下载次数")
for r in results.named_results():
print("- {}: {}".format(r['photo_id'], r['count']))
示例输出:
每个候选的总下载次数
- Qgb9urMZ8lw: 1729
- f0OL01IHbCM: 1444
- Bgae-sqbe_g: 313
- XYg2zLjxxa0: 207
- BkW8I1n354I: 184
- 5yFOvJZp7Rg: 63
- sKPPBn6OkJg: 48
- joL0nSbZ-lI: 20
- fzDtQWW8dV4: 8
- DCAERnaj31U: 3
在计算了每个候选图像的总转换次数之后,您可以找到下载次数最多的候选图像,并查看该图像每个下载关键字的详细转换信息。使用以下 SQL 查询:
# 显示最受欢迎的候选图像和前 5 个相关的下载关键字
most_popular_candidate = results.first_item['photo_id']
# 显示最受欢迎的图像
candidate_url = client.command(f"""
SELECT photo_image_url FROM default.myscale_photos WHERE photo_id = '{most_popular_candidate}'
""")
print("正在加载最受欢迎的候选图像...")
show_image(candidate_url)
# 查找前 5 个下载关键字
results = client.query(f"""
SELECT keyword, count(*) as count
FROM default.myscale_conversions
WHERE photo_id='{most_popular_candidate}'
GROUP BY keyword
ORDER BY count DESC
LIMIT 5
""")
print("最受欢迎的候选图像的相关关键字和下载次数")
for r in results.named_results():
print(f"- {r['keyword']}: {r['count']}")
在前 10 个候选项中,最受欢迎的候选图像:
示例输出:
最受欢迎的候选图像的相关关键字和下载次数
- bee: 1615
- bees: 21
- bumblebee: 13
- honey: 13
- honey bee: 12