Indexing a dataset with HNSW (Hierarchical Navigable Small World)

HNSW is a graph-based algorithm for approximate nearest neighbor search in high-dimensional spaces. In this example, we will demonstrate how to build an HNSW vector index against a Lance dataset.

This example will show how to:

  1. Generate synthetic test data of specified dimensions

  2. Build a hierarchical graph structure for efficient vector search using Lance API

  3. Perform vector search with different parameters and compute the ground truth using L2 distance search

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! HNSW is a graph based algorithm for approximate neighbor search in high-dimensional spaces.
//! In this example, we will demonstrate how to build HNSW vector indexing against a Lance dataset.
//! Run with `cargo run -v --package lance-examples --example hnsw`.
// linked to `docs/examples/Rust/hnsw.rst`
#![allow(clippy::print_stdout)]
use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::{types::Float32Type, Array, FixedSizeListArray};
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::record_batch::RecordBatchIterator;
use arrow_select::concat::concat;
use futures::stream::StreamExt;
use lance::Dataset;
use lance_index::vector::v3::subindex::IvfSubIndex;
use lance_index::vector::{
    flat::storage::FlatFloatStorage,
    hnsw::{builder::HnswBuildParams, HNSW},
};
use lance_linalg::distance::DistanceType;

/// Computes the exact `k` nearest neighbors of `query` by brute force.
///
/// Every vector stored in `fsl` is compared against `query` with
/// `lance_linalg::distance::l2_distance`; the candidates are ordered by
/// ascending distance and the row indices of the first `k` are returned.
/// When `fsl` holds fewer than `k` vectors, every row index is returned.
///
/// Row indices are narrowed to `u32` with an `as` cast so they can be compared
/// with the `u32` node ids collected from the HNSW search in `main`.
fn ground_truth(fsl: &FixedSizeListArray, query: &[f32], k: usize) -> HashSet<u32> {
    let mut dists: Vec<_> = (0..fsl.len())
        .map(|i| {
            let dist = lance_linalg::distance::l2_distance(
                query,
                fsl.value(i).as_primitive::<Float32Type>().values(),
            );
            (dist, i as u32)
        })
        .collect();
    // `total_cmp` defines a total order over `f32`, so a NaN distance gets a
    // deterministic position instead of panicking the way
    // `partial_cmp(..).unwrap()` would.
    dists.sort_by(|a, b| a.0.total_cmp(&b.0));
    dists.truncate(k);
    dists.into_iter().map(|(_, i)| i).collect()
}

pub async fn create_test_vector_dataset(output: &str, num_rows: usize, dim: i32) {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "vector",
        DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
        false,
    )]));

    let mut batches = Vec::new();

    // Create a few batches
    for _ in 0..2 {
        let v_builder = Float32Builder::new();
        let mut list_builder = FixedSizeListBuilder::new(v_builder, dim);

        for _ in 0..num_rows {
            for _ in 0..dim {
                list_builder.values().append_value(rand::random::<f32>());
            }
            list_builder.append(true);
        }
        let array = Arc::new(list_builder.finish());
        let batch = RecordBatch::try_new(schema.clone(), vec![array]).unwrap();
        batches.push(batch);
    }
    let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
    println!("Writing dataset to {}", output);
    Dataset::write(batch_reader, output, None).await.unwrap();
}

/// Runs the HNSW example end to end.
///
/// Opens the dataset at `uri`, or generates `test_vectors.lance` when `uri` is
/// `None`. It then loads the `vector` column into memory and computes the exact
/// top-`k` neighbors of the first vector. Finally it builds an HNSW index for
/// several `ef_construction` values and prints the recall together with the
/// build and search timings for each.
///
/// # Panics
///
/// Panics if the dataset cannot be opened or scanned, or if index construction
/// or search returns an error.
#[tokio::main]
async fn main() {
    let uri: Option<String> = None; // None means generate test data
    let column = "vector";
    let ef = 100;
    let max_edges = 30;
    let max_level = 7;

    // 1. Generate a synthetic test data of specified dimensions
    let dataset = match uri.as_deref() {
        Some(path) => Dataset::open(path).await.expect("Failed to open dataset"),
        None => {
            println!("No uri is provided, generating test dataset...");
            let output = "test_vectors.lance";
            create_test_vector_dataset(output, 1000, 64).await;
            Dataset::open(output).await.expect("Failed to open dataset")
        }
    };

    println!("Dataset schema: {:#?}", dataset.schema());
    // Scan only the vector column and keep each batch's column array.
    let batches = dataset
        .scan()
        .project(&[column])
        .unwrap()
        .try_into_stream()
        .await
        .unwrap()
        .then(|batch| async move { batch.unwrap().column_by_name(column).unwrap().clone() })
        .collect::<Vec<_>>()
        .await;
    // Concatenate the per-batch arrays into one contiguous fixed-size-list array.
    let arrs = batches.iter().map(|b| b.as_ref()).collect::<Vec<_>>();
    let fsl = concat(&arrs).unwrap().as_fixed_size_list().clone();
    // `fsl.len()` counts rows (vectors) after concatenation, not record batches.
    println!("Loaded {} vectors", fsl.len());

    let vector_store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));

    // Use the first stored vector as the query and compute its exact neighbors once.
    let q = fsl.value(0);
    let k = 10;
    let gt = ground_truth(&fsl, q.as_primitive::<Float32Type>().values(), k);

    for ef_construction in [15, 30, 50] {
        let now = std::time::Instant::now();
        // 2. Build a hierarchical graph structure for efficient vector search using Lance API
        let hnsw = HNSW::index_vectors(
            vector_store.as_ref(),
            HnswBuildParams::default()
                .max_level(max_level)
                .num_edges(max_edges)
                .ef_construction(ef_construction),
        )
        .unwrap();
        let construct_time = now.elapsed().as_secs_f32();
        let now = std::time::Instant::now();
        // 3. Perform vector search with different parameters and compute the ground truth using L2 distance search
        let results: HashSet<u32> = hnsw
            .search_basic(q.clone(), k, ef, None, vector_store.as_ref())
            .unwrap()
            .iter()
            .map(|node| node.id)
            .collect();
        let search_time = now.elapsed().as_micros();
        // Recall is the fraction of the exact top-`k` ids that the HNSW search also returned.
        println!(
            "level={}, ef_construct={}, ef={} recall={}: construct={:.3}s search={:.3} us",
            max_level,
            ef_construction,
            ef,
            results.intersection(&gt).count() as f32 / k as f32,
            construct_time,
            search_time
        );
    }
}