Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
- Adds cpu prediction support
- Fixes exception w/ not existing Raganato dataset directory
- Fixes protobuf lib latest version incompatibility protocolbuffers/protobuf#10051
  • Loading branch information
andrea-gasparini committed Oct 23, 2022
1 parent 2576c15 commit d937b5a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
data/*
predictions/*

!.placeholder

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
5 changes: 4 additions & 1 deletion esc/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--prediction-types", type=str, required=True, nargs="+")
parser.add_argument("--evaluate", action="store_true", default=False)
# default + not required
parser.add_argument("--cpu", action="store_true", default=False)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--tokens-per-batch", type=int, default=4000)
parser.add_argument("--output-errors", action="store_true", default=False)
Expand All @@ -328,7 +329,9 @@ def main() -> None:
wsd_model = ESCModule.load_from_checkpoint(args.ckpt)
wsd_model.freeze()

if args.device >= 0:
if args.cpu:
wsd_model.to("cpu")
elif args.device >= 0:
wsd_model.to(torch.device(args.device))

tokenizer = get_tokenizer(
Expand Down
3 changes: 3 additions & 0 deletions esc/utils/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def chunks(lst, n):

def list_elems_in_dir(dir_path: str, only_files: bool = False, only_dirs: bool = False) -> List[str]:

if not isdir(dir_path):
return list()

elems_in_dir = [e for e in listdir(dir_path)]

if only_files:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ pytorch-lightning==0.9.0
wandb
nltk==3.4.5
nlp
black==21.5b2
black==21.5b2
protobuf<=3.20.1

0 comments on commit d937b5a

Please sign in to comment.