-
Notifications
You must be signed in to change notification settings - Fork 0
/
mkdataset.py
71 lines (60 loc) · 2.59 KB
/
mkdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#! /usr/bin/env python
import argparse
import json
from mmlkg.data import dataset
from mmlkg.data.hdf5 import HDF5
_MODALITIES = ["textual", "numerical", "temporal", "visual", "spatial"]
if __name__ == "__main__":
    # Command-line entry point: bundle the CSV files produced by
    # `generateInput.py` into a single HDF5 dataset file.
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", help="Directory with CSV files "
                        + "(generated by `generateInput.py`)",
                        required=True)
    parser.add_argument("-c", "--config",
                        help="JSON file with hyperparameters",
                        default=None)
    parser.add_argument("-m", "--modalities", nargs='*',
                        help="Which modalities to include",
                        choices=[m.lower() for m in _MODALITIES],
                        default=_MODALITIES)
    parser.add_argument("-o", "--output", help="Output directory",
                        default=None)
    parser.add_argument("--compression", help="Compress generated dataset",
                        choices=["gzip", "lzf", "none"],
                        default="lzf")
    flags = parser.parse_args()

    # Default hyperparameter structure expected by the dataset generator.
    config = {"encoders": dict(), "optim": dict()}
    if flags.config is not None:
        print("[CONF] Using configuration from %s" % flags.config)
        with open(flags.config, 'r') as f:
            config = json.load(f)
        # Robustness fix: the loaded JSON replaces the defaults wholesale,
        # so guarantee the keys downstream code expects even when the
        # user-supplied file omits one of them.
        config.setdefault("encoders", dict())
        config.setdefault("optim", dict())

    # Write next to the input unless an explicit output directory is given.
    out_dir = flags.input if flags.output is None else flags.output
    out_dir = out_dir + '/' if not out_dir.endswith('/') else out_dir
    out_file = out_dir + 'dataset.h5'

    with HDF5(out_file, 'w', flags.compression) as hf:
        # Accumulators for per-task splits; each task is written only
        # once every expected component has been seen.
        nc_data = dict()  # node-classification components
        lp_data = dict()  # link-prediction components
        for name, item in dataset.generate_dataset(flags, config):
            if name == 'num_nodes':
                hf.write_metadata(name, item)
            elif name in _MODALITIES:
                hf.write_modality_data(item, name)
            elif name in ['num_classes',
                          'training',
                          'testing',
                          'validation']:
                if item is not None:
                    nc_data[name] = item
            elif name in ['entities',
                          'triples',
                          'training_lp',
                          'testing_lp',
                          'validation_lp']:
                if item is not None:
                    lp_data[name] = item
            # Any other item name is deliberately ignored.

        # Persist a task only when all of its components arrived
        # (4 for node classification, 5 for link prediction).
        if len(nc_data) == 4:
            hf.write_task_data(nc_data, HDF5.NODE_CLASSIFICATION)
        if len(lp_data) == 5:
            hf.write_task_data(lp_data, HDF5.LINK_PREDICTION)

    print(f"Dataset saved to {out_file}")