Embedding functions
Representing multi-modal data as vector embeddings is becoming standard practice. An embedding function can be thought of as a key part of the data processing pipeline that each request passes through. The assumption is that after the initial setup, these components and the underlying methodology are not expected to change for a given project.
For this purpose, LanceDB introduces an embedding functions API that you set up once, during the configuration stage of your project. After that, the table remembers it, and the embedding function effectively disappears into the background: you no longer have to pass callables around manually, and can instead focus on the rest of your data engineering pipeline.
Embedding functions on LanceDB Cloud
When using embedding functions with LanceDB Cloud, the embeddings are generated on the source device and sent to the cloud. This means that the source device must have the necessary resources to generate the embeddings.
Warning
Using the embedding function registry means that you don't have to explicitly generate the embeddings yourself. However, if your embedding function changes, you'll have to re-configure your table with the new embedding function and regenerate the embeddings. In the future, we plan to support the ability to change the embedding function via table metadata and have LanceDB automatically take care of regenerating the embeddings.
1. Define the embedding function
In the LanceDB Python SDK, we define a global embedding function registry with many different embedding models, and more are coming soon. Here, let's use CLIP as an example.
from lancedb.embeddings import get_registry
registry = get_registry()
clip = registry.get("open-clip").create()
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses Pydantic's `BaseModel`, which can be used to write complex schemas simply, as we'll see next!
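For instance, here is a minimal sketch of a custom text embedding function. The registry name `my-embedder` and the hash-based pseudo-embeddings are purely illustrative; a real implementation would call an actual model:
import hashlib

import numpy as np
from lancedb.embeddings import TextEmbeddingFunction, register

@register("my-embedder")  # illustrative registry name
class MyEmbeddings(TextEmbeddingFunction):
    def ndims(self):
        # Dimensionality of the vectors this function produces
        return 384

    def generate_embeddings(self, texts):
        # Placeholder: hash each string into a deterministic pseudo-vector
        # so the sketch stays self-contained; swap in a real model call here.
        vectors = []
        for text in texts:
            seed = int(hashlib.md5(text.encode()).hexdigest(), 16) % 2**32
            rng = np.random.default_rng(seed)
            vectors.append(rng.standard_normal(self.ndims()).tolist())
        return vectors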
In the TypeScript SDK, the choices are more limited. For now, only the OpenAI embedding function is available.
In the Rust SDK, the choices are more limited. For now, only the OpenAI embedding function is available. But unlike the Python and TypeScript SDKs, you need to manually register the OpenAI embedding function.
# Make sure to include the `openai` feature
[dependencies]
lancedb = { version = "*", features = ["openai"] }
use std::{iter::once, sync::Arc};

use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
    arrow::IntoArrow,
    connect,
    embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
    query::{ExecutableQuery, QueryBase},
    Result,
};

#[tokio::main]
async fn main() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
    let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model(
        api_key,
        "text-embedding-3-large",
    )?);

    let db = connect(tempdir).execute().await?;
    db.embedding_registry()
        .register("openai", embedding.clone())?;

    let table = db
        .create_table("vectors", make_data())
        .add_embedding(EmbeddingDefinition::new(
            "text",             // source column
            "openai",           // registered embedding function
            Some("embeddings"), // destination (vector) column
        ))?
        .execute()
        .await?;

    let query = Arc::new(StringArray::from_iter_values(once("something warm")));
    let query_vector = embedding.compute_query_embeddings(query)?;
    let mut results = table
        .vector_search(query_vector)?
        .limit(1)
        .execute()
        .await?;

    let rb = results.next().await.unwrap()?;
    let out = rb
        .column_by_name("text")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let text = out.iter().next().unwrap().unwrap();
    println!("Closest match: {}", text);
    Ok(())
}

// A minimal sketch of the data helper used above: it builds a small Arrow
// record batch with a `text` column (the embedding source) plus `id` and
// `price` columns. The sample rows are illustrative.
fn make_data() -> impl IntoArrow {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("text", DataType::Utf8, false),
        Field::new("price", DataType::Float64, false),
    ]));
    let id = Int32Array::from(vec![1, 2, 3, 4]);
    let text = StringArray::from(vec![
        "Black T-Shirt",
        "Leather Jacket",
        "Winter Parka",
        "Hooded Sweatshirt",
    ]);
    let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]);
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(id), Arc::new(text), Arc::new(price)],
    )
    .unwrap();
    Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema))
}
2. Define the data model or schema
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as the source or vector column. Here's how:
from lancedb.pydantic import LanceModel, Vector

class Pets(LanceModel):
    vector: Vector(clip.ndims()) = clip.VectorField()
    image_uri: str = clip.SourceField()
`VectorField` tells LanceDB to use the `clip` embedding function to generate query embeddings for the `vector` column, and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
For the TypeScript SDK, a schema can be inferred from input data, or an explicit Arrow schema can be provided.
3. Create table and add data
Now that we have chosen/defined our embedding function and the schema, we can create the table and ingest data without needing to explicitly generate the embeddings at all:
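In Python, with the `Pets` schema above, this looks like the following minimal sketch (it assumes `uris` is a list of image paths):
import lancedb

db = lancedb.connect("~/lancedb")
table = db.create_table("pets", schema=Pets)

# The `vector` column is populated automatically from `image_uri`
table.add([{"image_uri": uri} for uri in uris])
The TypeScript SDK follows the same pattern; here, a custom embedding function is defined, registered, and attached to the table: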
import * as lancedb from "@lancedb/lancedb";
import "@lancedb/lancedb/embedding/openai";
import { EmbeddingFunction, LanceSchema, register } from "@lancedb/lancedb/embedding";
import { type Float, Float32, Utf8 } from "apache-arrow";

const db = await lancedb.connect(databaseDir);

@register("my_embedding")
class MyEmbeddingFunction extends EmbeddingFunction<string> {
  toJSON(): object {
    return {};
  }
  ndims() {
    return 3;
  }
  embeddingDataType(): Float {
    return new Float32();
  }
  async computeQueryEmbeddings(_data: string) {
    // This is a placeholder for a real embedding function
    return [1, 2, 3];
  }
  async computeSourceEmbeddings(data: string[]) {
    // This is a placeholder for a real embedding function
    return Array.from({ length: data.length }).fill([1, 2, 3]) as number[][];
  }
}

const func = new MyEmbeddingFunction();
const data = [{ text: "pepperoni" }, { text: "pineapple" }];

// Option 1: manually specify the embedding function
const table = await db.createTable("vectors", data, {
  embeddingFunction: {
    function: func,
    sourceColumn: "text",
    vectorColumn: "vector",
  },
  mode: "overwrite",
});

// Option 2: provide the embedding function through a schema
const schema = LanceSchema({
  text: func.sourceField(new Utf8()),
  vector: func.vectorField(),
});
const table2 = await db.createTable("vectors2", data, {
  schema,
  mode: "overwrite",
});
4. Querying your table
Not only can you forget about the embeddings during ingestion, you also don't need to worry about them when you query the table.
Our OpenCLIP query embedding function supports querying via both text and images:
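For example, we can search with a text query (the query string here is illustrative):
results = (
    table.search("dog")
    .limit(10)
    .to_pandas()
)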
Or we can search using an image:
from pathlib import Path

from PIL import Image

p = Path("path/to/images/samoyed_100.jpg")
query_image = Image.open(p)
results = (
    table.search(query_image)
    .limit(10)
    .to_pandas()
)
Both of the above snippets return a pandas DataFrame with the 10 closest vectors to the query.
Rate limit handling
The `EmbeddingFunction` class wraps the calls for source and query embedding generation in a rate limit handler that retries failed requests with exponential backoff. By default, the maximum number of retries is set to 7. You can tune it by setting it to a different number, or disable retrying by setting it to 0.
An example of how to do this is shown below:
clip = registry.get("open-clip").create() # Defaults to 7 max retries
clip = registry.get("open-clip").create(max_retries=10) # Increase max retries to 10
clip = registry.get("open-clip").create(max_retries=0) # Retries disabled
Note
Embedding functions can also fail due to other errors that have nothing to do with rate limits. This is why the error is also logged.
Some fun with Pydantic
LanceDB is integrated with Pydantic, which was used in the example above to define the schema in Python. It's also used behind the scenes by the embedding function API to ingest useful information as table metadata.
You can also use the integration for adding utility operations to the schema. For example, in our multi-modal example, you can search images using text or another image. Let's define a utility property that loads the image:
class Pets(LanceModel):
    vector: Vector(clip.ndims()) = clip.VectorField()
    image_uri: str = clip.SourceField()

    @property
    def image(self):
        return Image.open(self.image_uri)
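Since search results can be converted back into `Pets` instances with `to_pydantic`, the matched image is then one attribute access away (a small sketch; the query string is illustrative):
rs = table.search("dog").limit(3).to_pydantic(Pets)
rs[0].image.show()  # open the closest matching image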
Now that you have the basic idea about LanceDB embedding functions and the embedding function registry, let's dive deeper into defining your own custom functions.