Multimodal search using CLIP
This example builds a multimodal image-search demo over the DiffusionDB dataset: CLIP text embeddings for semantic search, a LanceDB full-text index for keyword search, and DuckDB for SQL queries, all wired into a Gradio app.

In [2]:
!pip install --quiet -U lancedb
!pip install --quiet gradio transformers torch torchvision
In [1]:
import io
import PIL
import duckdb
import lancedb
First run setup: Download data and pre-process
In [ ]:
### Get dataset
!wget https://eto-public.s3.us-west-2.amazonaws.com/datasets/diffusiondb_lance.tar.gz
!tar -xvf diffusiondb_lance.tar.gz
!mv diffusiondb_test rawdata.lance
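Optionally, verify the extraction before indexing. This is a minimal check, assuming the archive ends up at ~/datasets/rawdata.lance, the path the next cell reads from:

import lance

# Confirm the download and extraction by counting rows in the raw dataset.
print(lance.dataset("~/datasets/rawdata.lance").count_rows())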
In [30]:
# Remove null prompts and build the table + full-text-search index
import lance
import pyarrow.compute as pc

# Load the raw dataset extracted above (alternatively, download
# s3://eto-public/datasets/diffusiondb/small_10k.lance to this path)
data = lance.dataset("~/datasets/rawdata.lance").to_table()

# Create the LanceDB table, dropping rows whose prompt is null
db = lancedb.connect("~/datasets/demo")
tbl = db.create_table("diffusiondb", data.filter(~pc.field("prompt").is_null()))

# Index the prompt column for keyword (full-text) search
tbl.create_fts_index(["prompt"])
Out[30]:
<lance.dataset.LanceDataset at 0x3045db590>
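With the index in place, a quick keyword query confirms full-text search works; `tbl` is the table created above, and "portrait" is just an arbitrary test term:

# Sanity check: a string query hits the full-text index on "prompt".
hits = tbl.search("portrait").limit(3).to_pandas()
print(hits["prompt"].tolist())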
Create / Open LanceDB Table
In [2]:
db = lancedb.connect("~/datasets/demo")
tbl = db.open_table("diffusiondb")
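Printing the schema shows what each row carries, per the creation step above: image bytes, the prompt text, an embedding vector, and metadata such as image_nsfw:

# Inspect the columns available for search and SQL filtering.
print(tbl.schema)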
Create a CLIP embedding function for text queries
In [3]:
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

MODEL_ID = "openai/clip-vit-base-patch32"

tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

def embed_func(query):
    # Tokenize the query and project it into CLIP's shared text/image space
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    text_features = model.get_text_features(**inputs)
    return text_features.detach().numpy()[0]
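A quick smoke test: embed_func returns a single vector per query (the query string here is just an example). For clip-vit-base-patch32 the text embedding is 512-dimensional:

# embed_func maps one text query to one CLIP text embedding.
vec = embed_func("a watercolor painting of a fox")
print(vec.shape)  # (512,) for clip-vit-base-patch32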
Search functions for Gradio
In [4]:
def find_image_vectors(query):
    # Embed the text query with CLIP and run a vector search
    emb = embed_func(query)
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        f"embedding = embed_func('{query}')\n"
        "tbl.search(embedding).limit(9).to_pandas()"
    )
    return (_extract(tbl.search(emb).limit(9).to_pandas()), code)

def find_image_keywords(query):
    # Passing a string to search() uses the full-text index on "prompt"
    code = (
        "import lancedb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        f"tbl.search('{query}').limit(9).to_pandas()"
    )
    return (_extract(tbl.search(query).limit(9).to_pandas()), code)

def find_image_sql(query):
    code = (
        "import lancedb\n"
        "import duckdb\n"
        "db = lancedb.connect('~/datasets/demo')\n"
        "tbl = db.open_table('diffusiondb')\n\n"
        "diffusiondb = tbl.to_lance()\n"
        f"duckdb.sql('{query}').to_df()"
    )
    # DuckDB resolves the `diffusiondb` name in the SQL to this local variable
    diffusiondb = tbl.to_lance()
    return (_extract(duckdb.sql(query).to_df()), code)

def _extract(df):
    # Decode image bytes into PIL images, paired with their prompts
    image_col = "image"
    return [(PIL.Image.open(io.BytesIO(row[image_col])), row["prompt"]) for _, row in df.iterrows()]
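These can be exercised directly before wiring up the UI. Each returns a list of (PIL image, prompt) pairs plus the code snippet string the app displays; the queries below reuse the defaults from the Gradio tabs:

# Vector search: CLIP embedding of the query against the stored embeddings.
images, snippet = find_image_vectors("portraits of a person")
print(len(images), "results; top prompt:", images[0][1])

# Keyword search: full-text match on the prompt column.
images, snippet = find_image_keywords("ninja turtle")
print(snippet)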
Set up the Gradio interface
In [28]:
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Embeddings"):
            vector_query = gr.Textbox(value="portraits of a person", show_label=False)
            b1 = gr.Button("Submit")
        with gr.Tab("Keywords"):
            keyword_query = gr.Textbox(value="ninja turtle", show_label=False)
            b2 = gr.Button("Submit")
        with gr.Tab("SQL"):
            sql_query = gr.Textbox(value="SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9", show_label=False)
            b3 = gr.Button("Submit")
    with gr.Row():
        code = gr.Code(label="Code", language="python")
    with gr.Row():
        gallery = gr.Gallery(
            label="Found images", show_label=False, elem_id="gallery"
        ).style(columns=[3], rows=[3], object_fit="contain", height="auto")

    b1.click(find_image_vectors, inputs=vector_query, outputs=[gallery, code])
    b2.click(find_image_keywords, inputs=keyword_query, outputs=[gallery, code])
    b3.click(find_image_sql, inputs=sql_query, outputs=[gallery, code])

demo.launch()
Running on local URL: http://127.0.0.1:7881
To create a public link, set `share=True` in `launch()`.
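As the launch message notes, passing share=True generates a temporary public URL if you want to demo the app beyond localhost:

# Optional: expose the app via a temporary public Gradio link.
demo.launch(share=True)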