-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtarsplit
executable file
·128 lines (112 loc) · 3.66 KB
/
tarsplit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#
import argparse
import os
import subprocess
import sys
from tarproclib import reader, writer
# Command-line interface.  Arguments are parsed at import time because this
# file is an executable script rather than an importable module.
parser = argparse.ArgumentParser(
    # The first positional parameter of ArgumentParser is ``prog``; the summary
    # has to be passed as ``description`` or it replaces the program name in
    # the usage line instead of describing the tool.
    description="Split a tar file into shards based on size or number of samples."
)
# A new shard is started once the current one holds at least this many
# samples.  Parsed as float, so values such as "1e5" are accepted.
parser.add_argument("-n", "--num-samples", default=100000, type=float)
# A new shard is started once the accumulated sample size (as computed by
# sample_size) reaches this value.
parser.add_argument("-s", "--max-size", default=1e9, type=float)
# Print each sample's "__key__" to stderr while splitting.
parser.add_argument("-v", "--verbose", action="store_true")
# Shell command template that is formatted with str.format and executed after
# each shard has been closed.
parser.add_argument("-C", "--command", default=None)
# Output prefix, or a complete pattern when it contains "{" (it is then
# formatted with ``shard=<index>``).  A resulting name that starts with "|" is
# treated as a shell command receiving the shard on stdin.
parser.add_argument("-o", "--output", default="temp")
# Shell command that receives each shard on stdin; the shard name is appended
# to it, separated by a space.
parser.add_argument("-O", "--open", default=None)
# Use the ".tgz" suffix for generated names and pass compress=True to the
# tar writer.
parser.add_argument("-z", "--compress", action="store_true")
# NOTE(review): --start is parsed but never read anywhere in this script.
parser.add_argument("--start", default=0, type=int)
# Stop once this many shards have been written.
parser.add_argument("--maxshards", default=1000000000, type=int)
parser.add_argument(
    "--nodelete", action="store_true", help="don't delete after executing command"
)
# Input tar source handed to reader.TarIterator; defaults to "-".
parser.add_argument("input", default="-", nargs="?")
args = parser.parse_args()
def dprint(*values, **options):
    """Print ``values`` to standard error.

    Keyword arguments are passed through to ``print``.  ``file`` is already
    supplied here, so passing it again raises ``TypeError``.
    """
    print(*values, **options, file=sys.stderr)
def sample_size(sample):
    """Return the combined length of all keys and values in ``sample``.

    ``sample`` is a mapping whose keys and values both support ``len``.
    An empty mapping yields 0.
    """
    return sum(len(name) + len(payload) for name, payload in sample.items())
# Running totals of samples and size over all shards closed so far.
total_count = total_size = 0
# Index of the next shard to open, and the name of the most recently opened one.
shard = 0
shard_name = None
# Sample count and accumulated size of the shard currently being written.
count = size = 0
# Writer for the open shard; None means no shard is open.
sink = None
def finish_shard():
    """Close the current shard and run the optional post-processing command.

    Does nothing when no shard is open.  Otherwise the module-level ``sink``
    is closed and reset to ``None``.  If ``--command`` was given, it is
    formatted with the fields below and executed through ``os.system``:

    ``{shard}``     the shard name as generated from the output pattern
    ``{abspath}``   absolute path of the shard
    ``{basename}``  final path component of the shard name
    ``{dirname}``   directory part of the shard name
    ``{base}``      ``basename`` without its extension
    ``{ext}``       extension of ``basename``, including the leading dot

    After a successful command the shard file is removed unless
    ``--nodelete`` was given.

    Raises:
        AssertionError: the command returned a non-zero status; the
            exception argument is the ``(status, cmd)`` tuple.
    """
    global sink
    if sink is None:
        return
    # NOTE(review): the wait happens before close(); presumably a child
    # process only exits after its stdin is closed, so confirm this ordering
    # (wait() raises subprocess.TimeoutExpired after 60 seconds).
    if hasattr(sink, "process"):
        sink.process.wait(timeout=60.0)
    sink.close()
    sink = None
    if args.command is None:
        return
    basename = os.path.basename(shard_name)
    base, ext = os.path.splitext(basename)
    cmd = args.command.format(
        shard=shard_name,
        abspath=os.path.abspath(shard_name),
        basename=basename,
        # Previously computed with os.path.basename(), which made {dirname}
        # a duplicate of {basename}.
        dirname=os.path.dirname(shard_name),
        base=base,
        ext=ext,
    )
    print(f"# {cmd}", file=sys.stderr)
    status = os.system(cmd)
    if status != 0:
        # Raised explicitly (same type and payload as the former assert) so
        # the check is not stripped under ``python -O``.
        raise AssertionError((status, cmd))
    if not args.nodelete:
        print(f"# removing {shard_name}", file=sys.stderr)
        os.unlink(shard_name)
# Shard naming: an --output value containing "{" is used verbatim as a format
# pattern; otherwise a zero-padded shard index and a suffix chosen by
# --compress are appended to it.
if "{" in args.output:
    output_pattern = args.output
else:
    output_pattern = args.output + (
        "-{shard:06d}.tgz" if args.compress else "-{shard:06d}.tar"
    )
# Stream samples from the input and distribute them over output shards.
for sample in reader.TarIterator(args.input):
    if args.verbose:
        dprint(sample.get("__key__"))

    # Roll over to a new shard before the first sample and whenever the
    # current shard has reached either the sample or the size limit.
    if sink is None or count >= args.num_samples or size >= args.max_size:
        total_count, total_size = total_count + count, total_size + size
        count = size = 0
        finish_shard()
        if shard >= args.maxshards:
            break

        shard_name = output_pattern.format(shard=shard)
        dprint(f"# writing {shard_name} ({total_count}, {total_size})")

        # A shard goes to a shell pipeline when its name starts with "|" or
        # when --open was given; otherwise it is written to a regular file.
        if shard_name[0] == "|":
            pipe_command = shard_name[1:]
        elif args.open is not None:
            pipe_command = args.open + " " + shard_name
        else:
            pipe_command = None

        if pipe_command is None:
            sink_stream = open(shard_name, "wb")
        else:
            process = subprocess.Popen(
                pipe_command, stdin=subprocess.PIPE, shell=True
            )
            sink_stream = process.stdin
            # Keep a handle on the child process alongside its stdin stream.
            sink_stream.process = process

        sink = writer.TarWriter(
            sink_stream, compress=True if args.compress else None
        )
        shard += 1

    sink.write(sample)
    count += 1
    size += sample_size(sample)

# Close whatever shard is still open once the input is exhausted.
finish_shard()