Commit
new feature: Add webdataset in audio
commit e1193cc (parent e04cd18)
Showing 24 changed files with 4,348 additions and 1 deletion.
@@ -0,0 +1,68 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa

from .cache import (
    cached_tarfile_samples,
    cached_tarfile_to_samples,
    lru_cleanup,
    pipe_cleaner,
)
from .compat import WebDataset, WebLoader, FluidWrapper
from webdataset.extradatasets import MockDataset, with_epoch, with_length
from .filters import (
    associate,
    batched,
    decode,
    detshuffle,
    extract_keys,
    getfirst,
    info,
    map,
    map_dict,
    map_tuple,
    pipelinefilter,
    rename,
    rename_keys,
    rsample,
    select,
    shuffle,
    slice,
    to_tuple,
    transform_with,
    unbatched,
    xdecode,
    data_filter,
    tokenize,
    resample,
    compute_fbank,
    spec_aug,
    sort,
    padding,
    cmvn,
)
from webdataset.handlers import (
    ignore_and_continue,
    ignore_and_stop,
    reraise_exception,
    warn_and_continue,
    warn_and_stop,
)
from .pipeline import DataPipeline
from .shardlists import (
    MultiShardSample,
    ResampledShards,
    SimpleShardList,
    non_empty,
    resampled,
    shardspec,
    single_node_only,
    split_by_node,
    split_by_worker,
)
from .tariterators import tarfile_samples, tarfile_to_samples
from .utils import PipelineStage, repeatedly
from webdataset.writer import ShardWriter, TarWriter, numpy_dumps
from webdataset.mix import RandomMix, RoundRobin
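For orientation, here is a minimal sketch (not part of this commit) of how the re-exported pieces are typically chained into a webdataset-style pipeline; the import path, shard pattern, and buffer size below are placeholders:

# Hypothetical usage sketch -- the package import path and shard names are made up.
from audio.streamdata import (  # placeholder module path for this package
    DataPipeline, SimpleShardList, split_by_worker, tarfile_to_samples, shuffle,
)

dataset = DataPipeline(
    SimpleShardList("shards/train-{000000..000099}.tar"),  # placeholder shard pattern
    split_by_worker,           # give each DataLoader worker its own shards
    tarfile_to_samples(),      # expand every tar shard into per-sample dicts
    shuffle(1000),             # shuffle within a 1000-sample buffer
)

for sample in dataset:
    pass  # each sample is a dict grouped from the tar members, plus "__key__" etc.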
@@ -0,0 +1,190 @@
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
import itertools, os, random, re, sys, time
from urllib.parse import urlparse

from . import filters
from webdataset import gopen
from webdataset.handlers import reraise_exception
from .tariterators import tar_file_and_group_expander

default_cache_dir = os.environ.get("WDS_CACHE", "./_cache")
default_cache_size = float(os.environ.get("WDS_CACHE_SIZE", "1e18"))

def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
    """Performs cleanup of the file cache in cache_dir using an LRU strategy,
    keeping the total size of all remaining files below cache_size."""
    if not os.path.exists(cache_dir):
        return
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            total_size += os.path.getsize(os.path.join(dirpath, filename))
    if total_size <= cache_size:
        return
    # sort files by last access time
    files = []
    for dirpath, dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            files.append(os.path.join(dirpath, filename))
    files.sort(key=keyfn, reverse=True)
    # delete files until we're under the cache size
    while len(files) > 0 and total_size > cache_size:
        fname = files.pop()
        total_size -= os.path.getsize(fname)
        if verbose:
            print("# deleting %s" % fname, file=sys.stderr)
        os.remove(fname)

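As a rough illustration (not from the diff), lru_cleanup can also be called directly to trim a cache directory to a byte budget; the directory and limit below are made up:

# Hypothetical call: keep ./_cache under ~5 GB, deleting the oldest files
# first (keyfn defaults to os.path.getctime).
lru_cleanup("./_cache", cache_size=5e9, verbose=True)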
def download(url, dest, chunk_size=1024 ** 2, verbose=False):
    """Download a file from `url` to `dest`."""
    temp = dest + f".temp{os.getpid()}"
    with gopen.gopen(url) as stream:
        with open(temp, "wb") as f:
            while True:
                data = stream.read(chunk_size)
                if not data:
                    break
                f.write(data)
    os.rename(temp, dest)


def pipe_cleaner(spec):
    """Guess the actual URL from a "pipe:" specification."""
    if spec.startswith("pipe:"):
        spec = spec[5:]
        words = spec.split(" ")
        for word in words:
            if re.match(r"^(https?|gs|ais|s3)", word):
                return word
    return spec

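To make pipe_cleaner's behavior concrete, a small hedged example of the mapping it performs (the URLs are illustrative):

# "pipe:" specs wrap a shell command; pipe_cleaner picks out the first token
# that looks like an http(s)/gs/ais/s3 URL so it can serve as a cache key.
spec = "pipe:curl -s -L https://example.com/shards/train-000000.tar"
print(pipe_cleaner(spec))  # -> https://example.com/shards/train-000000.tar

# Specs without a "pipe:" prefix are returned unchanged.
print(pipe_cleaner("data/train-000000.tar"))  # -> data/train-000000.tar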
def get_file_cached(
    spec,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    verbose=False,
):
    """Download `spec` into the local cache directory if not already present
    and return the path of the cached file."""
    if cache_size == -1:
        cache_size = default_cache_size
    if cache_dir is None:
        cache_dir = default_cache_dir
    url = url_to_name(spec)
    parsed = urlparse(url)
    dirname, filename = os.path.split(parsed.path)
    dirname = dirname.lstrip("/")
    dirname = re.sub(r"[:/|;]", "_", dirname)
    destdir = os.path.join(cache_dir, dirname)
    os.makedirs(destdir, exist_ok=True)
    dest = os.path.join(cache_dir, dirname, filename)
    if not os.path.exists(dest):
        if verbose:
            print("# downloading %s to %s" % (url, dest), file=sys.stderr)
        lru_cleanup(cache_dir, cache_size, verbose=verbose)
        download(spec, dest, verbose=verbose)
    return dest


def get_filetype(fname):
    """Return the output of the `file` command for `fname`."""
    with os.popen("file '%s'" % fname) as f:
        ftype = f.read()
    return ftype


def check_tar_format(fname):
    """Check whether a file is a tar archive."""
    ftype = get_filetype(fname)
    return "tar archive" in ftype or "gzip compressed" in ftype


verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))

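A hedged usage sketch of get_file_cached (URL and sizes are placeholders): it maps the URL to a path under the cache directory, downloads the file once, and returns the local path on later calls:

# Hypothetical call: the shard is fetched into WDS_CACHE (default ./_cache)
# on first use; subsequent calls just return the already-cached path.
local_path = get_file_cached(
    "https://example.com/shards/train-000000.tar",  # placeholder URL
    cache_size=10e9,   # let lru_cleanup evict old files once the cache exceeds ~10 GB
    cache_dir=None,    # fall back to WDS_CACHE / ./_cache
    verbose=True,
)
print(local_path)  # e.g. ./_cache/shards/train-000000.tar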
def cached_url_opener(
    data,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    url_to_name=pipe_cleaner,
    validator=check_tar_format,
    verbose=False,
    always=False,
):
    """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
    verbose = verbose or verbose_cache
    for sample in data:
        assert isinstance(sample, dict), sample
        assert "url" in sample
        url = sample["url"]
        attempts = 5
        try:
            if not always and os.path.exists(url):
                dest = url
            else:
                dest = get_file_cached(
                    url,
                    cache_size=cache_size,
                    cache_dir=cache_dir,
                    url_to_name=url_to_name,
                    verbose=verbose,
                )
            if verbose:
                print("# opening %s" % dest, file=sys.stderr)
            assert os.path.exists(dest)
            if not validator(dest):
                ftype = get_filetype(dest)
                with open(dest, "rb") as f:
                    data = f.read(200)
                os.remove(dest)
                raise ValueError(
                    "%s (%s) is not a tar archive, but a %s, contains %s"
                    % (dest, url, ftype, repr(data))
                )
            try:
                stream = open(dest, "rb")
                sample.update(stream=stream)
                yield sample
            except FileNotFoundError as exn:
                # dealing with race conditions in lru_cleanup
                attempts -= 1
                if attempts > 0:
                    time.sleep(random.random() * 10)
                    continue
                raise exn
        except Exception as exn:
            exn.args = exn.args + (url,)
            if handler(exn):
                continue
            else:
                break

def cached_tarfile_samples(
    src,
    handler=reraise_exception,
    cache_size=-1,
    cache_dir=None,
    verbose=False,
    url_to_name=pipe_cleaner,
    always=False,
):
    streams = cached_url_opener(
        src,
        handler=handler,
        cache_size=cache_size,
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
        always=always,
    )
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples


cached_tarfile_to_samples = filters.pipelinefilter(cached_tarfile_samples)
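For context, a hedged sketch of how the cached variant might slot into a pipeline in place of tarfile_to_samples; the shard pattern and cache settings are placeholders, not from this commit:

# cached_tarfile_to_samples is the pipeline-filter form of cached_tarfile_samples:
# shards are downloaded into the local cache once and re-read from disk afterwards.
# Assumes DataPipeline, SimpleShardList, and split_by_worker from this package.
dataset = DataPipeline(
    SimpleShardList("https://example.com/shards/train-{000000..000009}.tar"),
    split_by_worker,
    cached_tarfile_to_samples(cache_dir="./_cache", cache_size=20e9),
)

for sample in dataset:
    break  # each sample is a dict of grouped tar members (plus "__key__", etc.)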