forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbenchmark_dataset_load.py
79 lines (74 loc) · 1.86 KB
/
benchmark_dataset_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Benchmark dataset load."""
# %%
import os
from itertools import islice

from ffcv.fields.decoders import NDArrayDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import ToTensor
from torch.utils.benchmark import Timer
from torch.utils.data import DataLoader
from tqdm import tqdm

from torchgeo.datasets import BigEarthNet
# %%
# Benchmark configuration. `batch_size` and `num_workers` are passed to every
# loader below; `num_batches_load` is the batch count handed to
# load_num_batches in each timed statement.
batch_size = 64
num_workers = 0
num_batches_load = 50
# Directory containing the "BigEarthNet" dataset folder and the "FFCV" folder
# with the .beton file. Set the BENCHMARK_DATA_ROOT environment variable to
# point elsewhere instead of editing this user-specific default; an unset or
# empty variable falls back to the default.
data_root = os.environ.get("BENCHMARK_DATA_ROOT") or "/scratch/users/mike/data/"
# %%
# Build the same BigEarthNet dataset twice, differing only in `load_target`,
# so the two PyTorch loaders below can be timed with and without targets.
# NOTE(review): no `split` is passed here, while the FFCV loader reads
# "BigEarthNet_test.beton" — confirm both benchmarks read the same split.
ds_no_target = BigEarthNet(
    load_target=False,
    root=os.path.join(data_root, "BigEarthNet"),
)
ds_target = BigEarthNet(
    load_target=True,
    root=os.path.join(data_root, "BigEarthNet"),
)
# FFCV decoding pipeline: each field is decoded with NDArrayDecoder and then
# passed through ToTensor; every field gets its own fresh instances.
ffcv_pipeline = {
    field_name: [NDArrayDecoder(), ToTensor()]
    for field_name in ("image", "label")
}
# %%
# Baseline PyTorch loader over the dataset built with load_target=False.
dl_no_label = DataLoader(
    dataset=ds_no_target,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
)
# %%
# Identical loader settings, but over the dataset that also loads targets.
dl_with_label = DataLoader(
    dataset=ds_target,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
)
# %%
# FFCV loader over the pre-written .beton file, in random order (mirroring the
# shuffled PyTorch loaders) with distributed mode off.
# NOTE(review): batches_ahead=2 presumably sets FFCV's prefetch depth — confirm.
dl_ffcv = Loader(
    fname=os.path.join(data_root, "FFCV", "BigEarthNet_test.beton"),
    pipelines=ffcv_pipeline,
    order=OrderOption.RANDOM,
    batch_size=batch_size,
    num_workers=num_workers,
    batches_ahead=2,
    distributed=False,
)
# %%
def load_num_batches(num_batches, dataloader, *, show_progress=True):
    """Iterate over the first ``num_batches`` batches of ``dataloader``.

    The batches themselves are discarded; only the cost of producing them is
    of interest, since this is called inside the timed benchmark statements.

    Args:
        num_batches: Number of batches to pull; must be a non-negative int.
        dataloader: Any iterable of batches (e.g. a PyTorch ``DataLoader`` or
            an FFCV ``Loader``).
        show_progress: Wrap the iteration in a ``tqdm`` progress bar
            (default True). Pass False to keep the bar out of timings/output.

    Returns:
        True if exactly ``num_batches`` batches were loaded, False if the
        loader ran out of batches first.

    Raises:
        ValueError: If ``num_batches`` is negative (raised by ``islice``).
    """
    # islice checks its limit *before* requesting the next item, so exactly
    # ``num_batches`` batches are fetched from the loader and no more.
    batches = islice(dataloader, num_batches)
    if show_progress:
        # Give tqdm the real target count; len(dataloader) would overstate it
        # because iteration stops early.
        batches = tqdm(batches, total=num_batches)
    loaded = 0
    for _ in batches:
        loaded += 1
    return loaded == num_batches
# %%
# One Timer per loader. All three statements call load_num_batches with the
# same num_batches_load, so the measurements differ only in the loader used.
# NOTE(review): torch's Timer appears to copy the globals dict when it is
# constructed — if so, these must be created after the loaders and
# load_num_batches are defined.
timer_no_label = Timer(
    label="Load without label",
    globals=globals(),
    stmt="load_num_batches(num_batches_load, dl_no_label)",
)
# %%
timer_label = Timer(
    label="Load with label",
    globals=globals(),
    stmt="load_num_batches(num_batches_load, dl_with_label)",
)
# %%
timer_ffcv = Timer(
    label="Load ffcv",
    globals=globals(),
    stmt="load_num_batches(num_batches_load, dl_ffcv)",
)
# %%
# Run each benchmark once and print the resulting Measurement. A bare
# expression is only displayed in an interactive cell; printing also reports
# the results when the file is executed as a plain script.
# NOTE(review): number=1 is a single run per loader with no warm-up, so the
# timings include any first-iteration startup cost.
print(timer_no_label.timeit(number=1))
# %%
print(timer_label.timeit(number=1))
# %%
print(timer_ffcv.timeit(number=1))
# %%