diff --git a/src/tevatron/arguments.py b/src/tevatron/arguments.py index c2554b8c..b6b67231 100644 --- a/src/tevatron/arguments.py +++ b/src/tevatron/arguments.py @@ -99,12 +99,15 @@ def __post_init__(self): self.dataset_split = 'train' self.dataset_language = 'default' if self.train_dir is not None: - files = os.listdir(self.train_dir) - self.train_path = [ - os.path.join(self.train_dir, f) - for f in files - if f.endswith('jsonl') or f.endswith('json') - ] + if os.path.isdir(self.train_dir): + files = os.listdir(self.train_dir) + self.train_path = [ + os.path.join(self.train_dir, f) + for f in files + if f.endswith('jsonl') or f.endswith('json') + ] + else: + self.train_path = [self.train_dir] else: self.train_path = None diff --git a/src/tevatron/data.py b/src/tevatron/data.py index e5a5a455..970111cb 100644 --- a/src/tevatron/data.py +++ b/src/tevatron/data.py @@ -30,7 +30,7 @@ def __init__( self.total_len = len(self.train_data) def create_one_example(self, text_encoding: List[int], is_query=False): - item = self.tok.encode_plus( + item = self.tok.prepare_for_model( text_encoding, truncation='only_first', max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len, @@ -95,7 +95,7 @@ def __len__(self): def __getitem__(self, item) -> Tuple[str, BatchEncoding]: text_id, text = (self.encode_data[item][f] for f in self.input_keys) - encoded_text = self.tok.encode_plus( + encoded_text = self.tok.prepare_for_model( text, max_length=self.max_len, truncation='only_first', diff --git a/src/tevatron/datasets/preprocessor.py b/src/tevatron/datasets/preprocessor.py index 2daa1acf..4756d848 100644 --- a/src/tevatron/datasets/preprocessor.py +++ b/src/tevatron/datasets/preprocessor.py @@ -54,4 +54,4 @@ def __call__(self, example): add_special_tokens=False, max_length=self.text_max_length, truncation=True) - return {'text_id': docid, 'text': text} + return {'text_id': docid, 'text': text} \ No newline at end of file diff --git a/src/tevatron/driver/encode.py b/src/tevatron/driver/encode.py index 84a484e7..c31d9d47 100644 --- a/src/tevatron/driver/encode.py +++ b/src/tevatron/driver/encode.py @@ -52,8 +52,7 @@ def main(): ) tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=False, + cache_dir=model_args.cache_dir ) model = DenseModel.load( diff --git a/src/tevatron/driver/train.py b/src/tevatron/driver/train.py index e95c26a2..ff529403 100644 --- a/src/tevatron/driver/train.py +++ b/src/tevatron/driver/train.py @@ -2,6 +2,7 @@ import os import sys +import torch from transformers import AutoConfig, AutoTokenizer from transformers import ( HfArgumentParser, @@ -66,8 +67,7 @@ def main(): ) tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=False, + cache_dir=model_args.cache_dir ) model = DenseModel.build( model_args, @@ -78,7 +78,13 @@ def main(): train_dataset = HFTrainDataset(tokenizer=tokenizer, data_args=data_args, cache_dir=data_args.data_cache_dir or model_args.cache_dir) + if training_args.local_rank > 0: + print("Waiting for main process to perform the mapping") + torch.distributed.barrier() train_dataset = TrainDataset(data_args, train_dataset.process(), tokenizer) + if training_args.local_rank == 0: + print("Loading results from main process") + torch.distributed.barrier() trainer_cls = GCTrainer if training_args.grad_cache else Trainer trainer = trainer_cls( diff --git a/src/tevatron/faiss_retriever/__main__.py b/src/tevatron/faiss_retriever/__main__.py index 3099768f..8b8e9b1a 100644 --- a/src/tevatron/faiss_retriever/__main__.py +++ b/src/tevatron/faiss_retriever/__main__.py @@ -19,7 +19,7 @@ def search_queries(retriever, q_reps, p_lookup, args): if args.batch_size > 0: - all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size) + all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size, args.quiet) else: all_scores, all_indices = retriever.search(q_reps, args.depth) @@ -56,6 +56,7 @@ def main(): parser.add_argument('--depth', type=int, default=1000) parser.add_argument('--save_ranking_to', required=True) parser.add_argument('--save_text', action='store_true') + parser.add_argument('--quiet', action='store_true') args = parser.parse_args() diff --git a/src/tevatron/faiss_retriever/retriever.py b/src/tevatron/faiss_retriever/retriever.py index 315c7637..bbe375fc 100644 --- a/src/tevatron/faiss_retriever/retriever.py +++ b/src/tevatron/faiss_retriever/retriever.py @@ -2,6 +2,7 @@ import faiss import logging +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -17,11 +18,11 @@ def add(self, p_reps: np.ndarray): def search(self, q_reps: np.ndarray, k: int): return self.index.search(q_reps, k) - def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int): + def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int, quiet: bool=False): num_query = q_reps.shape[0] all_scores = [] all_indices = [] - for start_idx in range(0, num_query, batch_size): + for start_idx in tqdm(range(0, num_query, batch_size), disable=quiet): nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k) all_scores.append(nn_scores) all_indices.append(nn_indices)