Skip to content

Commit

Permalink
Capitalize class names
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxuan-ji committed Sep 13, 2020
1 parent 26de4a6 commit 8a636c7
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 24 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,22 @@ Here's how to initialize the T5 reranker from [Document Ranking with a Pretrained

```python
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import monoT5
from pygaggle.rerank.transformer import MonoT5

model_name = 'castorini/monot5-base-msmarco'
model_name = 'castorini/monoT5-base-msmarco'
tokenizer_name = 't5-base'
reranker = monoT5(model_name, tokenizer_name)
reranker = MonoT5(model_name, tokenizer_name)
```

Alternatively, here's the BERT reranker from [Passage Re-ranking with BERT](https://arxiv.org/pdf/1901.04085.pdf), which isn't as good as the T5 reranker:

```python
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import monoBERT
from pygaggle.rerank.transformer import MonoBERT

model_name = 'castorini/monobert-large-msmarco'
model_name = 'castorini/monoBERT-large-msmarco'
tokenizer_name = 'bert-large-uncased'
reranker = monoBERT(model_name, tokenizer_name)
reranker = MonoBERT(model_name, tokenizer_name)
```

Either way, continue with a complete reranking example:
Expand Down
12 changes: 6 additions & 6 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
greedy_decode)


__all__ = ['monoT5',
__all__ = ['MonoT5',
'UnsupervisedTransformerReranker',
'monoBERT',
'MonoBERT',
'QuestionAnsweringTransformerReranker']


class monoT5(Reranker):
class MonoT5(Reranker):
def __init__(self,
model_name_or_instance: Union[str, T5ForConditionalGeneration] = 'castorini/monot5-base-msmarco',
model_name_or_instance: Union[str, T5ForConditionalGeneration] = 'castorini/monoT5-base-msmarco',
tokenizer_name_or_instance: Union[str, QueryDocumentBatchTokenizer] = 't5-base'):
if isinstance(model_name_or_instance, str):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down Expand Up @@ -107,9 +107,9 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
return texts


class monoBERT(Reranker):
class MonoBERT(Reranker):
def __init__(self,
model_name_or_instance: Union[str, PreTrainedModel] = 'castorini/monobert-large-msmarco',
model_name_or_instance: Union[str, PreTrainedModel] = 'castorini/monoBERT-large-msmarco',
tokenizer_name_or_instance: Union[str, PreTrainedTokenizer] = 'bert-large-uncased'):
if isinstance(model_name_or_instance, str):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down
8 changes: 4 additions & 4 deletions pygaggle/run/evaluate_document_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from pygaggle.rerank.bm25 import Bm25Reranker
from pygaggle.rerank.transformer import (
UnsupervisedTransformerReranker,
monoT5,
monoBERT
MonoT5,
MonoBERT
)
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
Expand Down Expand Up @@ -85,7 +85,7 @@ def construct_t5(options: DocumentRankingEvaluationOptions) -> Reranker:
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
return monoT5(model, tokenizer)
return MonoT5(model, tokenizer)


def construct_transformer(options:
Expand All @@ -106,7 +106,7 @@ def construct_seq_class_transformer(options: DocumentRankingEvaluationOptions
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name)
return monoBERT(model, tokenizer)
return MonoBERT(model, tokenizer)


def construct_bm25(options: DocumentRankingEvaluationOptions) -> Reranker:
Expand Down
8 changes: 4 additions & 4 deletions pygaggle/run/evaluate_kaggle_highlighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from pygaggle.rerank.bm25 import Bm25Reranker
from pygaggle.rerank.transformer import (
QuestionAnsweringTransformerReranker,
monoBERT,
monoT5,
MonoBERT,
MonoT5,
UnsupervisedTransformerReranker
)
from pygaggle.rerank.random import RandomReranker
Expand Down Expand Up @@ -82,7 +82,7 @@ def construct_t5(options: KaggleEvaluationOptions) -> Reranker:
tokenizer = AutoTokenizer.from_pretrained(
options.model_name, do_lower_case=options.do_lower_case)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
return monoT5(model, tokenizer)
return MonoT5(model, tokenizer)


def construct_transformer(options: KaggleEvaluationOptions) -> Reranker:
Expand Down Expand Up @@ -124,7 +124,7 @@ def construct_seq_class_transformer(options:
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
options.tokenizer_name, do_lower_case=options.do_lower_case)
return monoBERT(model, tokenizer)
return MonoBERT(model, tokenizer)


def construct_qa_transformer(options: KaggleEvaluationOptions) -> Reranker:
Expand Down
8 changes: 4 additions & 4 deletions pygaggle/run/evaluate_passage_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from pygaggle.rerank.bm25 import Bm25Reranker
from pygaggle.rerank.transformer import (
UnsupervisedTransformerReranker,
monoT5,
monoBERT
MonoT5,
MonoBERT
)
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
Expand Down Expand Up @@ -83,7 +83,7 @@ def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
return monoT5(model, tokenizer)
return MonoT5(model, tokenizer)


def construct_transformer(options:
Expand Down Expand Up @@ -116,7 +116,7 @@ def construct_seq_class_transformer(options: PassageRankingEvaluationOptions
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name)
return monoBERT(model, tokenizer)
return MonoBERT(model, tokenizer)


def construct_bm25(options: PassageRankingEvaluationOptions) -> Reranker:
Expand Down

0 comments on commit 8a636c7

Please sign in to comment.