Creating a text dataset for LLM training using Lance in Rust¶
In this example, we will demonstrate how to reproduce in Rust the LLM dataset creation example shown in the Python examples.
Note
The Hugging Face Python API supports loading data in streaming mode, and shuffling is provided as a built-in feature. The Rust API lacks these features, so the data is downloaded manually and shuffled within each batch.
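In the complete example below, this per-batch shuffle uses Arrow's take kernel with a permuted index array. A minimal standalone sketch of that idiom (the helper name is illustrative, not part of the example):
use arrow::array::{Array, ArrayRef, UInt32Array};
use arrow::error::ArrowError;
use rand::seq::SliceRandom;
use rand::SeedableRng;
// Shuffle the rows of one Arrow column; a per-batch stand-in for a global dataset shuffle.
fn shuffle_within_batch(column: &dyn Array, seed: u64) -> Result<ArrayRef, ArrowError> {
    // Build a permutation of the row indices, then gather the rows in that order.
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut indices: Vec<u32> = (0..column.len() as u32).collect();
    indices.shuffle(&mut rng);
    arrow::compute::take(column, &UInt32Array::from(indices), None)
}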
This example will show how to:
- Download and process a text dataset in parts from Hugging Face
- Tokenize the text data with a custom RecordBatchReader
- Save it as a Lance dataset using the Lance API
The Rust implementation follows the same concepts as the Python version, but uses Rust-specific APIs and patterns, which are significantly more verbose.
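At its core, the pipeline boils down to handing an Arrow RecordBatchReader to lance::Dataset::write, the same call the complete example ends with. A minimal sketch of just that step (the helper name is illustrative; error handling mirrors the example below):
use arrow::record_batch::RecordBatchReader;
use lance::dataset::WriteParams;
use std::error::Error;
// Write any batch reader out as a new Lance dataset at `uri`.
async fn write_reader_to_lance(
    reader: impl RecordBatchReader + Send + 'static,
    uri: &str,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    // Lance pulls batches from the reader and persists them as a new dataset.
    lance::Dataset::write(reader, uri, Some(WriteParams::default())).await?;
    Ok(())
}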
Complete Example¶
use arrow::array::{Array, Int64Builder, ListBuilder, UInt32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::record_batch::RecordBatchReader;
use futures::StreamExt;
use hf_hub::{api::sync::Api, Repo, RepoType};
use lance::dataset::WriteParams;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use std::error::Error;
use std::fs::File;
use std::io::Write;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tokenizers::Tokenizer;
// A custom streaming batch reader: reads WikiText parquet shards, tokenizes each line,
// and yields Arrow batches of token ids
struct WikiTextBatchReader {
schema: Arc<Schema>,
parquet_readers: Vec<Option<ParquetRecordBatchReaderBuilder<File>>>,
current_reader_idx: usize,
current_reader: Option<Box<dyn RecordBatchReader + Send>>,
tokenizer: Tokenizer,
num_samples: u64,
cur_samples_cnt: u64,
}
impl WikiTextBatchReader {
fn new(
parquet_readers: Vec<ParquetRecordBatchReaderBuilder<File>>,
tokenizer: Tokenizer,
num_samples: Option<u64>,
) -> Result<Self, Box<dyn Error + Send + Sync>> {
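        // The output schema has a single "input_ids" column: one list of token ids per sample.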
let schema = Arc::new(Schema::new(vec![Field::new(
"input_ids",
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
false,
)]));
Ok(Self {
schema,
parquet_readers: parquet_readers.into_iter().map(Some).collect(),
current_reader_idx: 0,
current_reader: None,
tokenizer,
num_samples: num_samples.unwrap_or(100_000),
cur_samples_cnt: 0,
})
}
fn process_batch(
&mut self,
input_batch: &RecordBatch,
) -> Result<RecordBatch, arrow::error::ArrowError> {
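        // Tokenize each non-null row of the "text" column line by line, append the token ids
        // to a list builder, then shuffle the resulting rows within this batch.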
let num_rows = input_batch.num_rows();
let mut token_builder = ListBuilder::new(Int64Builder::with_capacity(num_rows * 1024)); // Pre-allocate space
let mut should_break = false;
let column = input_batch.column_by_name("text").unwrap();
let string_array = column
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
for i in 0..num_rows {
if self.cur_samples_cnt >= self.num_samples {
should_break = true;
break;
}
if !Array::is_null(string_array, i) {
let text = string_array.value(i);
// Split paragraph into lines
for line in text.split('\n') {
if let Ok(encoding) = self.tokenizer.encode(line, true) {
let tb_values = token_builder.values();
for &id in encoding.get_ids() {
tb_values.append_value(id as i64);
}
token_builder.append(true);
self.cur_samples_cnt += 1;
if self.cur_samples_cnt % 5000 == 0 {
println!("Processed {} rows", self.cur_samples_cnt);
}
if self.cur_samples_cnt >= self.num_samples {
should_break = true;
break;
}
}
}
}
}
// Create array and shuffle it
let input_ids_array = token_builder.finish();
// Create shuffled array by randomly sampling indices
let mut rng = rand::rngs::StdRng::seed_from_u64(1337);
let len = input_ids_array.len();
let mut indices: Vec<u32> = (0..len as u32).collect();
indices.shuffle(&mut rng);
// Take values in shuffled order
let indices_array = UInt32Array::from(indices);
let shuffled = arrow::compute::take(&input_ids_array, &indices_array, None)?;
let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(shuffled)]);
if should_break {
println!("Stop at {} rows", self.cur_samples_cnt);
self.parquet_readers.clear();
self.current_reader = None;
}
batch
}
}
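// RecordBatchReader only adds schema(); the batches themselves come from the Iterator
// impl below, which lance::Dataset::write drives on demand.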
impl RecordBatchReader for WikiTextBatchReader {
fn schema(&self) -> Arc<Schema> {
self.schema.clone()
}
}
impl Iterator for WikiTextBatchReader {
type Item = Result<RecordBatch, arrow::error::ArrowError>;
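    // Pull batches from the current parquet reader; when it is exhausted, lazily build
    // the next reader from its builder until all shards have been processed.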
fn next(&mut self) -> Option<Self::Item> {
loop {
// If we have a current reader, try to get next batch
if let Some(reader) = &mut self.current_reader {
if let Some(batch_result) = reader.next() {
return Some(batch_result.and_then(|batch| self.process_batch(&batch)));
}
}
// If no current reader or current reader is exhausted, try to get next reader
if self.current_reader_idx < self.parquet_readers.len() {
if let Some(builder) = self.parquet_readers[self.current_reader_idx].take() {
match builder.build() {
Ok(reader) => {
self.current_reader = Some(Box::new(reader));
self.current_reader_idx += 1;
continue;
}
Err(e) => {
return Some(Err(arrow::error::ArrowError::ExternalError(Box::new(e))))
}
}
}
}
// No more readers available
return None;
}
}
}
fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async {
// Load tokenizer
let tokenizer = load_tokenizer("gpt2")?;
// Set up Hugging Face API
// Download from https://huggingface.co/datasets/Salesforce/wikitext/tree/main/wikitext-103-raw-v1
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
"Salesforce/wikitext".into(),
RepoType::Dataset,
"main".into(),
));
// Define the parquet files we want to download
let train_files = vec![
"wikitext-103-raw-v1/train-00000-of-00002.parquet",
"wikitext-103-raw-v1/train-00001-of-00002.parquet",
];
let mut parquet_readers = Vec::new();
for file in &train_files {
println!("Downloading file: {}", file);
let file_path = repo.get(file)?;
let data = std::fs::read(file_path)?;
// Create a temporary file in the system temp directory and write the downloaded data to it
let mut temp_file = NamedTempFile::new()?;
temp_file.write_all(&data)?;
// Create the parquet reader builder with a larger batch size
let builder = ParquetRecordBatchReaderBuilder::try_new(temp_file.into_file())?
.with_batch_size(8192); // Increase batch size for better performance
parquet_readers.push(builder);
}
if parquet_readers.is_empty() {
println!("No parquet files found to process.");
return Ok(());
}
// Create batch reader
let num_samples: u64 = 500_000;
let batch_reader = WikiTextBatchReader::new(parquet_readers, tokenizer, Some(num_samples))?;
// Save as Lance dataset
println!("Writing to Lance dataset...");
let lance_dataset_path = "rust_wikitext_lance_dataset.lance";
let write_params = WriteParams::default();
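        // Writing consumes the reader: batches are tokenized and flushed as they are produced.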
lance::Dataset::write(batch_reader, lance_dataset_path, Some(write_params)).await?;
// Verify the dataset
let ds = lance::Dataset::open(lance_dataset_path).await?;
let scanner = ds.scan();
let mut stream = scanner.try_into_stream().await?;
let mut total_rows = 0;
while let Some(batch_result) = stream.next().await {
let batch = batch_result?;
total_rows += batch.num_rows();
}
println!(
"Lance dataset created successfully with {} rows",
total_rows
);
println!("Dataset location: {}", lance_dataset_path);
Ok(())
})
}
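// Fetch tokenizer.json for the given model from the Hugging Face Hub and build a Tokenizer from it.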
fn load_tokenizer(model_name: &str) -> Result<Tokenizer, Box<dyn Error + Send + Sync>> {
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
model_name.into(),
RepoType::Model,
"main".into(),
));
let tokenizer_path = repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_path)?;
Ok(tokenizer)
}