Creating text dataset for LLM training using Lance in Rust

In this example, we will demonstrate how to reproduce in Rust the LLM dataset creation shown in the Python example, Creating text dataset for LLM training using Lance.

Note

The huggingface Python API supports loading data in streaming mode, and shuffling is provided as a built-in feature. The Rust API lacks these features, so the data is downloaded manually and shuffled within each batch.
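Since shuffling is not built in, the example permutes the row indices of each batch with rand and reorders the Arrow array with arrow::compute::take. Below is a minimal sketch of that per-batch shuffle; the helper name shuffle_rows and the Int64Array input are illustrative placeholders, and the full program later applies the same idea to the tokenized input_ids column.

use arrow::array::{ArrayRef, Int64Array, UInt32Array};
use arrow::error::ArrowError;
use rand::seq::SliceRandom;
use rand::SeedableRng;

// Permute the row indices of `values` and reorder the array with `take`.
// The fixed seed (1337) matches the full example; any seed works.
fn shuffle_rows(values: &Int64Array) -> Result<ArrayRef, ArrowError> {
    let mut rng = rand::rngs::StdRng::seed_from_u64(1337);
    let mut indices: Vec<u32> = (0..values.len() as u32).collect();
    indices.shuffle(&mut rng);
    arrow::compute::take(values, &UInt32Array::from(indices), None)
}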

This example will show how to:

  1. Download and process a text dataset in parts from huggingface

  2. Tokenize the text data with a custom RecordBatchReader

  3. Save it as a Lance dataset using Lance API

The Rust implementation follows the same concepts as the Python version, but uses Rust-specific APIs and patterns that are significantly more verbose. A skeleton of the custom reader is sketched below, followed by the full program.
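As an orientation for the full listing, here is a minimal sketch of what Lance expects from a custom reader. The type name MyReader and the stubbed next body are placeholders; the real WikiTextBatchReader in the full program fills them in with Parquet reading and tokenization.

use std::sync::Arc;

use arrow::datatypes::Schema;
use arrow::error::ArrowError;
use arrow::record_batch::{RecordBatch, RecordBatchReader};

struct MyReader {
    schema: Arc<Schema>,
    // ... data source (e.g. Parquet readers) and a tokenizer ...
}

// Lance can write any type implementing `RecordBatchReader`, i.e. an
// `Iterator` over `Result<RecordBatch, ArrowError>` plus a `schema()` method.
impl Iterator for MyReader {
    type Item = Result<RecordBatch, ArrowError>;
    fn next(&mut self) -> Option<Self::Item> {
        None // stub: yield one tokenized RecordBatch per call until the source is exhausted
    }
}

impl RecordBatchReader for MyReader {
    fn schema(&self) -> Arc<Schema> {
        self.schema.clone()
    }
}

// Writing is then a single call (inside an async context):
// lance::Dataset::write(reader, "my_dataset.lance", Some(WriteParams::default())).await?;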

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! This example demonstrates how to:
//!
//! 1. Download and process a text dataset in parts from huggingface
//! 2. Tokenize the text data with a custom RecordBatchReader
//! 3. Save it as a Lance dataset using Lance API
//!
//! Run with `cargo run -v --package lance-examples --example llm_dataset_creation`
//!
//!
// linked to `docs/examples/Rust/llm_dataset_creation.rst`

use arrow::array::{Array, Int64Builder, ListBuilder, UInt32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::record_batch::RecordBatchReader;
use futures::StreamExt;
use hf_hub::{api::sync::Api, Repo, RepoType};
use lance::dataset::WriteParams;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use std::error::Error;
use std::fs::File;
use std::io::Write;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tokenizers::Tokenizer;

// Implement a custom stream batch reader
struct WikiTextBatchReader {
    schema: Arc<Schema>,
    parquet_readers: Vec<Option<ParquetRecordBatchReaderBuilder<File>>>,
    current_reader_idx: usize,
    current_reader: Option<Box<dyn RecordBatchReader + Send>>,
    tokenizer: Tokenizer,
    num_samples: u64,
    cur_samples_cnt: u64,
}

impl WikiTextBatchReader {
    fn new(
        parquet_readers: Vec<ParquetRecordBatchReaderBuilder<File>>,
        tokenizer: Tokenizer,
        num_samples: Option<u64>,
    ) -> Result<Self, Box<dyn Error + Send + Sync>> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "input_ids",
            DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
            false,
        )]));

        Ok(Self {
            schema,
            parquet_readers: parquet_readers.into_iter().map(Some).collect(),
            current_reader_idx: 0,
            current_reader: None,
            tokenizer,
            num_samples: num_samples.unwrap_or(100_000),
            cur_samples_cnt: 0,
        })
    }

    fn process_batch(
        &mut self,
        input_batch: &RecordBatch,
    ) -> Result<RecordBatch, arrow::error::ArrowError> {
        let num_rows = input_batch.num_rows();
        let mut token_builder = ListBuilder::new(Int64Builder::with_capacity(num_rows * 1024)); // Pre-allocate space
        let mut should_break = false;

        let column = input_batch.column_by_name("text").unwrap();
        let string_array = column
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();
        for i in 0..num_rows {
            if self.cur_samples_cnt >= self.num_samples {
                should_break = true;
                break;
            }
            if !Array::is_null(string_array, i) {
                let text = string_array.value(i);
                // Split paragraph into lines
                for line in text.split('\n') {
                    if let Ok(encoding) = self.tokenizer.encode(line, true) {
                        let tb_values = token_builder.values();
                        for &id in encoding.get_ids() {
                            tb_values.append_value(id as i64);
                        }
                        token_builder.append(true);
                        self.cur_samples_cnt += 1;
                        if self.cur_samples_cnt % 5000 == 0 {
                            println!("Processed {} rows", self.cur_samples_cnt);
                        }
                        if self.cur_samples_cnt >= self.num_samples {
                            should_break = true;
                            break;
                        }
                    }
                }
            }
        }

        // Create array and shuffle it
        let input_ids_array = token_builder.finish();

        // Create shuffled array by randomly sampling indices
        let mut rng = rand::rngs::StdRng::seed_from_u64(1337);
        let len = input_ids_array.len();
        let mut indices: Vec<u32> = (0..len as u32).collect();
        indices.shuffle(&mut rng);

        // Take values in shuffled order
        let indices_array = UInt32Array::from(indices);
        let shuffled = arrow::compute::take(&input_ids_array, &indices_array, None)?;

        let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(shuffled)]);
        if should_break {
            println!("Stop at {} rows", self.cur_samples_cnt);
            self.parquet_readers.clear();
            self.current_reader = None;
        }

        batch
    }
}

impl RecordBatchReader for WikiTextBatchReader {
    fn schema(&self) -> Arc<Schema> {
        self.schema.clone()
    }
}

impl Iterator for WikiTextBatchReader {
    type Item = Result<RecordBatch, arrow::error::ArrowError>;
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // If we have a current reader, try to get next batch
            if let Some(reader) = &mut self.current_reader {
                if let Some(batch_result) = reader.next() {
                    return Some(batch_result.and_then(|batch| self.process_batch(&batch)));
                }
            }

            // If no current reader or current reader is exhausted, try to get next reader
            if self.current_reader_idx < self.parquet_readers.len() {
                if let Some(builder) = self.parquet_readers[self.current_reader_idx].take() {
                    match builder.build() {
                        Ok(reader) => {
                            self.current_reader = Some(Box::new(reader));
                            self.current_reader_idx += 1;
                            continue;
                        }
                        Err(e) => {
                            return Some(Err(arrow::error::ArrowError::ExternalError(Box::new(e))))
                        }
                    }
                }
            }

            // No more readers available
            return None;
        }
    }
}

fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    let rt = tokio::runtime::Runtime::new()?;
    rt.block_on(async {
        // Load tokenizer
        let tokenizer = load_tokenizer("gpt2")?;

        // Set up Hugging Face API
        // Download from https://huggingface.co/datasets/Salesforce/wikitext/tree/main/wikitext-103-raw-v1
        let api = Api::new()?;
        let repo = api.repo(Repo::with_revision(
            "Salesforce/wikitext".into(),
            RepoType::Dataset,
            "main".into(),
        ));

        // Define the parquet files we want to download
        let train_files = vec![
            "wikitext-103-raw-v1/train-00000-of-00002.parquet",
            "wikitext-103-raw-v1/train-00001-of-00002.parquet",
        ];

        let mut parquet_readers = Vec::new();
        for file in &train_files {
            println!("Downloading file: {}", file);
            let file_path = repo.get(file)?;
            let data = std::fs::read(file_path)?;

            // Create a temporary file in the system temp directory and write the downloaded data to it
            let mut temp_file = NamedTempFile::new()?;
            temp_file.write_all(&data)?;

            // Create the parquet reader builder with a larger batch size
            let builder = ParquetRecordBatchReaderBuilder::try_new(temp_file.into_file())?
                .with_batch_size(8192); // Increase batch size for better performance
            parquet_readers.push(builder);
        }

        if parquet_readers.is_empty() {
            println!("No parquet files found to process.");
            return Ok(());
        }

        // Create batch reader
        let num_samples: u64 = 500_000;
        let batch_reader = WikiTextBatchReader::new(parquet_readers, tokenizer, Some(num_samples))?;

        // Save as Lance dataset
        println!("Writing to Lance dataset...");
        let lance_dataset_path = "rust_wikitext_lance_dataset.lance";

        let write_params = WriteParams::default();
        lance::Dataset::write(batch_reader, lance_dataset_path, Some(write_params)).await?;

        // Verify the dataset
        let ds = lance::Dataset::open(lance_dataset_path).await?;
        let scanner = ds.scan();
        let mut stream = scanner.try_into_stream().await?;

        let mut total_rows = 0;
        while let Some(batch_result) = stream.next().await {
            let batch = batch_result?;
            total_rows += batch.num_rows();
        }

        println!(
            "Lance dataset created successfully with {} rows",
            total_rows
        );
        println!("Dataset location: {}", lance_dataset_path);

        Ok(())
    })
}

fn load_tokenizer(model_name: &str) -> Result<Tokenizer, Box<dyn Error + Send + Sync>> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        model_name.into(),
        RepoType::Model,
        "main".into(),
    ));

    let tokenizer_path = repo.get("tokenizer.json")?;
    let tokenizer = Tokenizer::from_file(tokenizer_path)?;

    Ok(tokenizer)
}