PathLike (#4479)
Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>
dirkgr and epwalsh authored Jul 15, 2020
1 parent 2f87832 commit d693cf1
Showing 8 changed files with 34 additions and 21 deletions.
11 changes: 6 additions & 5 deletions allennlp/commands/train.py
@@ -7,7 +7,8 @@
 import argparse
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from os import PathLike
+from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.distributed as dist
@@ -106,8 +107,8 @@ def train_model_from_args(args: argparse.Namespace):


 def train_model_from_file(
-    parameter_filename: str,
-    serialization_dir: str,
+    parameter_filename: Union[str, PathLike],
+    serialization_dir: Union[str, PathLike],
     overrides: str = "",
     recover: bool = False,
     force: bool = False,
@@ -161,7 +162,7 @@ def train_model_from_file(

 def train_model(
     params: Params,
-    serialization_dir: str,
+    serialization_dir: Union[str, PathLike],
     recover: bool = False,
     force: bool = False,
     node_rank: int = 0,
@@ -287,7 +288,7 @@ def train_model(
 def _train_worker(
     process_rank: int,
     params: Params,
-    serialization_dir: str,
+    serialization_dir: Union[str, PathLike],
     include_package: List[str] = None,
     dry_run: bool = False,
     node_rank: int = 0,
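
With the widened signatures above, the training entry point accepts path objects as well as plain strings. A minimal usage sketch, assuming a hypothetical Jsonnet config file and output directory:

from pathlib import Path

from allennlp.commands.train import train_model_from_file

# Both arguments may now be pathlib.Path (or any os.PathLike) instead of str.
config_file = Path("experiments") / "my_model.jsonnet"  # hypothetical config file
output_dir = Path("runs") / "my_model"                   # hypothetical serialization dir
train_model_from_file(parameter_filename=config_file, serialization_dir=output_dir)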
5 changes: 3 additions & 2 deletions allennlp/common/file_utils.py
@@ -7,6 +7,7 @@
 import logging
 import tempfile
 import json
+from os import PathLike
 from urllib.parse import urlparse
 from pathlib import Path
 from typing import Optional, Tuple, Union, IO, Callable, Set, List, Iterator, Iterable
@@ -89,7 +90,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[


 def cached_path(
-    url_or_filename: Union[str, Path],
+    url_or_filename: Union[str, PathLike],
     cache_dir: Union[str, Path] = None,
     extract_archive: bool = False,
     force_extract: bool = False,
@@ -119,7 +120,7 @@
     if cache_dir is None:
         cache_dir = CACHE_DIRECTORY

-    if isinstance(url_or_filename, Path):
+    if isinstance(url_or_filename, PathLike):
         url_or_filename = str(url_or_filename)

     # If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here.
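
Because the isinstance check now targets os.PathLike rather than pathlib.Path, any object exposing __fspath__ is converted to a string before the URL-versus-local-path handling. A small sketch with a hypothetical custom path type:

import os
from pathlib import Path

from allennlp.common.file_utils import cached_path


class DatasetLocation:
    """Hypothetical wrapper; defining __fspath__ makes it count as an os.PathLike."""

    def __init__(self, root: str, name: str):
        self._path = os.path.join(root, name)

    def __fspath__(self) -> str:
        return self._path


# Both calls are stringified the same way before the cache lookup.
local_copy = cached_path(Path("data") / "train.txt")            # hypothetical local file
also_local = cached_path(DatasetLocation("data", "train.txt"))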
7 changes: 6 additions & 1 deletion allennlp/common/logging.py
@@ -1,6 +1,9 @@
 import logging
 from logging import Filter
 import os
+from os import PathLike
+from typing import Union
+
 import sys


@@ -54,7 +57,9 @@ def filter(self, record):
         return record.levelno < logging.ERROR


-def prepare_global_logging(serialization_dir: str, rank: int = 0, world_size: int = 1,) -> None:
+def prepare_global_logging(
+    serialization_dir: Union[str, PathLike], rank: int = 0, world_size: int = 1,
+) -> None:
     root_logger = logging.getLogger()

     # create handlers
5 changes: 3 additions & 2 deletions allennlp/common/params.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List
+from os import PathLike
+from typing import Any, Dict, List, Union
 from collections.abc import MutableMapping
 from collections import OrderedDict
 import copy
@@ -456,7 +457,7 @@ def _check_is_dict(self, new_history, value):

     @classmethod
     def from_file(
-        cls, params_file: str, params_overrides: str = "", ext_vars: dict = None
+        cls, params_file: Union[str, PathLike], params_overrides: str = "", ext_vars: dict = None
     ) -> "Params":
         """
         Load a `Params` object from a configuration file.
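
The same pattern applies when loading a configuration directly; a brief sketch assuming a hypothetical Jsonnet file on disk:

from pathlib import Path

from allennlp.common.params import Params

# params_file may now be a str or any os.PathLike.
params = Params.from_file(Path("experiments") / "my_model.jsonnet")  # hypothetical config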
3 changes: 2 additions & 1 deletion allennlp/common/testing/model_test_case.py
@@ -1,5 +1,6 @@
 import copy
 import json
+from os import PathLike
 from typing import Any, Dict, Iterable, Set, Union

 import torch
@@ -48,7 +49,7 @@ def set_up_model(self, param_file, dataset_file):

     def ensure_model_can_train_save_and_load(
         self,
-        param_file: str,
+        param_file: Union[PathLike, str],
         tolerance: float = 1e-4,
         cuda_device: int = -1,
         gradients_to_ignore: Set[str] = None,
8 changes: 5 additions & 3 deletions allennlp/models/archival.py
@@ -1,8 +1,8 @@
 """
 Helper functions for archiving models and restoring archived models.
 """
-
-from typing import NamedTuple
+from os import PathLike
+from typing import NamedTuple, Union
 import atexit
 import logging
 import os
@@ -89,7 +89,9 @@ def extract_module(self, path: str, freeze: bool = True) -> Module:


 def archive_model(
-    serialization_dir: str, weights: str = _DEFAULT_WEIGHTS, archive_path: str = None
+    serialization_dir: Union[str, PathLike],
+    weights: str = _DEFAULT_WEIGHTS,
+    archive_path: Union[str, PathLike] = None,
 ) -> None:
     """
     Archive the model weights, its training configuration, and its vocabulary to `model.tar.gz`.
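
A short sketch of archiving with path objects, assuming a hypothetical completed training run:

from pathlib import Path

from allennlp.models.archival import archive_model

run_dir = Path("runs") / "my_model"  # hypothetical serialization directory
archive_model(serialization_dir=run_dir, archive_path=run_dir / "model.tar.gz")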
11 changes: 6 additions & 5 deletions allennlp/models/model.py
@@ -5,7 +5,8 @@

 import logging
 import os
-from typing import Dict, List, Set, Type, Optional
+from os import PathLike
+from typing import Dict, List, Set, Type, Optional, Union

 try:
     from apex import amp
@@ -268,8 +269,8 @@ def _maybe_warn_for_unseparable_batches(self, output_key: str):
     def _load(
         cls,
         config: Params,
-        serialization_dir: str,
-        weights_file: Optional[str] = None,
+        serialization_dir: Union[str, PathLike],
+        weights_file: Optional[Union[str, PathLike]] = None,
         cuda_device: int = -1,
         opt_level: Optional[str] = None,
     ) -> "Model":
@@ -349,8 +350,8 @@ def _load(
     def load(
         cls,
         config: Params,
-        serialization_dir: str,
-        weights_file: Optional[str] = None,
+        serialization_dir: Union[str, PathLike],
+        weights_file: Optional[Union[str, PathLike]] = None,
         cuda_device: int = -1,
         opt_level: Optional[str] = None,
     ) -> "Model":
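
Loading an archived run follows the same pattern; a sketch assuming a hypothetical serialization directory that contains a config.json and a best.th weights file:

from pathlib import Path

from allennlp.common.params import Params
from allennlp.models.model import Model

run_dir = Path("runs") / "my_model"  # hypothetical serialization directory
config = Params.from_file(run_dir / "config.json")
model = Model.load(config, serialization_dir=run_dir, weights_file=run_dir / "best.th")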
5 changes: 3 additions & 2 deletions allennlp/training/util.py
@@ -5,6 +5,7 @@
 import logging
 import os
 import shutil
+from os import PathLike
 from typing import Any, Dict, Iterable, Optional, Union, Tuple, Set, List
 from collections import Counter

@@ -168,7 +169,7 @@ def datasets_from_params(


 def create_serialization_dir(
-    params: Params, serialization_dir: str, recover: bool, force: bool
+    params: Params, serialization_dir: Union[str, PathLike], recover: bool, force: bool
 ) -> None:
     """
     This function creates the serialization directory if it doesn't exist. If it already exists
@@ -413,7 +414,7 @@ def description_from_metrics(metrics: Dict[str, float]) -> str:


 def make_vocab_from_params(
-    params: Params, serialization_dir: str, print_statistics: bool = False
+    params: Params, serialization_dir: Union[str, PathLike], print_statistics: bool = False
 ) -> Vocabulary:
     vocab_params = params.pop("vocabulary", {})
     os.makedirs(serialization_dir, exist_ok=True)
