-
Notifications
You must be signed in to change notification settings - Fork 1
/
faiss_embedding_writer.rs
141 lines (110 loc) · 4.35 KB
/
faiss_embedding_writer.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
use clap::{ArgAction, Parser};
use rustserini::encode::auto::AutoDocumentEncoder;
use rustserini::encode::base::{DocumentEncoder, RepresentationWriter};
use rustserini::encode::vector_writer::{FaissRepresentationWriter, JsonlCollectionIterator};
use std::collections::HashMap;
use std::time::Instant;
/// A Rust example of encoding a corpus and storing the embeddings in a FAISS Index
/// Download the msmarco passage dataset using the below command:
/// mkdir corpus/msmarco-passage
/// wget https://huggingface.co/datasets/Tevatron/msmarco-passage-corpus/resolve/main/corpus.jsonl.gz -P corpus/msmarco-passage
/// cargo run --example faiss_embedding_writer -- --corpus corpus/msmarco-passage/corpus.jsonl.gz --embeddings-dir corpus/msmarco-passage --encoder bert-base-uncased --tokenizer bert-base-uncased
///
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Directory that contains corpus files to be encoded, in jsonl format.
    #[arg(short, long)]
    corpus: String,
    /// Fields that contents in jsonl has (in order) separated by comma.
    #[arg(short, long, default_value = "text,title")]
    fields: String,
    /// delimiter for the fields
    #[arg(short, long, default_value = "\n")]
    delimiter: String,
    /// shard-id 0-based
    #[arg(short, long, default_value_t = 0)]
    shard_id: u8,
    /// number of shards
    #[arg(long, default_value_t = 1)]
    shard_num: u8,
    /// directory to store encoded corpus
    #[arg(short, long, required = true)]
    embeddings_dir: String,
    /// Whether to store the embeddings in a faiss index or in a jsonl file
    // NOTE(review): ArgAction::SetFalse means this flag defaults to `true`
    // and passing `--to-faiss` turns it OFF — confirm this inversion is
    // intended before renaming or changing the action.
    #[arg(long, action=ArgAction::SetFalse)]
    to_faiss: bool,
    /// Use lowercase in tokenizer
    #[arg(long, action=ArgAction::SetTrue)]
    lowercase: bool,
    /// Strip accents in tokenizer
    #[arg(long, action=ArgAction::SetTrue)]
    strip_accents: bool,
    /// Encoder name or path
    #[arg(long)]
    encoder: String,
    /// Tokenizer name or path
    #[arg(long)]
    tokenizer: String,
    /// Batch size for encoding
    #[arg(short, long, default_value_t = 8)]
    batch_size: usize,
    /// GPU Device ==> cpu or cuda:0
    #[arg(long, default_value = "cpu")]
    device: String,
    /// Whether to use fp16
    #[arg(long, action=ArgAction::SetTrue)]
    fp16: bool,
    /// max length of the input
    #[arg(short, long, default_value_t = 512)]
    max_length: u16,
    /// Embedding dimension
    // Default raised from 2 to 768: the example encoder is bert-base-uncased
    // (768-dim hidden states) and main() builds a 768-dim FAISS index, so the
    // old default of 2 silently disagreed with both.
    #[arg(long, default_value_t = 768)]
    embedding_dim: u32,
}
/// Entry point: streams a jsonl corpus in batches, encodes each batch with a
/// BERT-style document encoder, and writes the embeddings (plus docids) into
/// a flat FAISS index on disk.
fn main() {
    let start = Instant::now();
    let args = Args::parse();

    // Field names each jsonl record carries, in order, e.g. ["text", "title"].
    let fields: Vec<String> = args.fields.split(',').map(|s| s.to_string()).collect();

    let mut iterator: JsonlCollectionIterator =
        JsonlCollectionIterator::new(fields, "docid".to_string(), args.delimiter, args.batch_size);
    let _ = iterator.load(args.corpus);

    // Keep the index dimension consistent with the writer: the original code
    // hard-coded 768 here while constructing the writer from --embedding-dim,
    // so the two could silently disagree.
    let mut writer: FaissRepresentationWriter =
        FaissRepresentationWriter::new(&args.embeddings_dir, args.embedding_dim);
    writer.init_index(args.embedding_dim, "Flat");
    let _ = writer.open_file();

    let encoder: AutoDocumentEncoder = AutoDocumentEncoder::new(
        &args.encoder,
        Some(&args.tokenizer),
        args.lowercase,
        args.strip_accents,
    );

    let mut counter: usize = 0;
    for batch in iterator.iter() {
        let batch_text: Vec<String> = batch["text"].to_vec();
        let batch_title: Vec<String> = batch["title"].to_vec();
        let batch_id: Vec<String> = batch["id"].to_vec();

        // Encode with CLS pooling. The original replaced a failed batch with
        // an empty embedding vector, which would desynchronize docids from
        // vectors in the index; skip the batch loudly instead.
        let mut embeddings: Vec<f32> = match encoder.encode(&batch_text, &batch_title, "cls") {
            Ok(e) => e.to_vec(),
            Err(_) => {
                eprintln!("Encoding failed after {} batches; batch skipped", counter);
                continue;
            }
        };

        let mut batch_info = HashMap::new();
        batch_info.insert("text", batch_text);
        batch_info.insert("title", batch_title);
        batch_info.insert("id", batch_id);
        let _ = writer.write(&batch_info, &mut embeddings);

        counter += 1;
        // Sparse progress heartbeat to keep console output low.
        if counter % 100 == 0 {
            println!("Batch {} encoded", counter);
        }
    }

    let _ = writer.save_index();
    let _ = writer.save_docids();

    let duration = start.elapsed();
    // The original message said "expensive_function()" (copied from the
    // std::time docs example); report what was actually measured.
    println!("Total encoding time: {:?}", duration);
}