Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

fix: datamodule can't load files with square brackets in names #1501

Merged
merged 15 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion flash/core/data/utilities/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import glob
import re
from functools import partial
from os import PathLike
from typing import Union
from urllib.parse import parse_qs, quote, urlencode, urlparse

import fsspec
import numpy as np
Expand Down Expand Up @@ -139,9 +144,30 @@ def _get_loader(file_path: str, loaders):
)


# Drive-letter paths such as ``C:\dir\file`` or ``C:/dir/file``. The separator
# must be followed by a character that is not the same separator, so strings
# like ``c://host/file`` are not matched here.
WINDOWS_FILE_PATH_RE = re.compile("^[a-zA-Z]:(\\\\[^\\\\]|/[^/]).*")


def is_local_path(file_path: str) -> bool:
    """Return ``True`` when ``file_path`` refers to the local file system.

    A path counts as local if it matches ``WINDOWS_FILE_PATH_RE`` or if
    ``urlparse`` reports either no scheme or the ``file`` scheme.
    """
    # The drive-letter check runs first: ``urlparse`` would otherwise report
    # the drive letter as the URL scheme.
    if WINDOWS_FILE_PATH_RE.fullmatch(file_path) is not None:
        return True
    scheme = urlparse(file_path).scheme
    return scheme == "" or scheme == "file"


def escape_url(url: str) -> str:
    """Percent-encode the path, query and fragment of ``url``.

    The scheme and network location are copied through unchanged. The query
    string is decoded with ``parse_qs`` and re-encoded with ``urlencode``, so
    special characters in keys and values are escaped.

    Args:
        url: The URL to escape.

    Returns:
        The escaped URL. A ``?`` is only emitted when there is a query string
        and a ``#`` only when there is a fragment.
    """
    parsed = urlparse(url)
    escaped = f"{parsed.scheme}://{parsed.netloc}{quote(parsed.path)}"

    # keep_blank_values=True so that parameters such as ``key=`` survive the
    # round trip instead of being silently dropped.
    query = urlencode(parse_qs(parsed.query, keep_blank_values=True), doseq=True)
    if query:
        escaped += f"?{query}"

    # Keep the fragment: dropping it would truncate the original URL.
    if parsed.fragment:
        escaped += f"#{quote(parsed.fragment)}"

    return escaped


def escape_file_path(file_path: Union[str, PathLike]) -> str:
    """Return ``file_path`` as a string with special characters escaped.

    Local paths (as decided by ``is_local_path``) are passed through
    ``glob.escape``; everything else is handed to ``escape_url``.
    """
    path_text = str(file_path)
    if is_local_path(path_text):
        return glob.escape(path_text)
    return escape_url(path_text)


def load(file_path: str, loaders):
    """Open ``file_path`` with fsspec and return the result of its loader.

    Args:
        file_path: Path or URL of the file to read.
        loaders: Passed to ``_get_loader`` together with ``file_path`` to
            obtain the callable that reads the opened file.

    Returns:
        Whatever the selected loader returns for the opened file object.
    """
    loader = _get_loader(file_path, loaders)
    # escaping file_path to avoid fsspec treating the path as a glob pattern
    # fsspec ignores `expand=False` in read mode
    with fsspec.open(escape_file_path(file_path)) as file:
        return loader(file)


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ pandas>=1.1.0
jsonargparse[signatures]>=3.17.0, <=4.9.0
click>=7.1.2
protobuf<=3.20.1
fsspec
fsspec[http]>=2021.6.1,<=2022.7.1
lightning-utilities>=0.3.0
11 changes: 10 additions & 1 deletion tests/core/data/utilities/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def write_tsv(file_path):
@pytest.mark.parametrize(
"extension,write",
[(extension, write_image) for extension in IMG_EXTENSIONS]
+ [(extension, write_numpy) for extension in NP_EXTENSIONS],
+ [(extension, write_numpy) for extension in NP_EXTENSIONS]
# it shouldn't try to expand glob patterns in filenames
+ [(filename, write_image) for filename in ("image [test].jpeg",)],
)
def test_load_image(tmpdir, extension, write):
file_path = os.path.join(tmpdir, f"test{extension}")
Expand Down Expand Up @@ -149,6 +151,13 @@ def test_load_data_frame(tmpdir, extension, write):
Image.Image,
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."),
),
# it shouldn't try to expand glob patterns in URLs
pytest.param(
"https://pl-flash-data.s3.amazonaws.com/images/ant_1 [test].jpg",
load_image,
Image.Image,
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."),
),
pytest.param(
"https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg",
load_spectrogram,
Expand Down