Skip to content

Commit

Permalink
Merge pull request #2 from RUCAIBox/master
Browse files Browse the repository at this point in the history
Merge
  • Loading branch information
chenyushuo authored Jul 5, 2020
2 parents b10e7d2 + 26a60e9 commit 51c9ce2
Show file tree
Hide file tree
Showing 11 changed files with 100,095 additions and 26 deletions.
2 changes: 1 addition & 1 deletion config/abstract_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __getitem__(self, item):
# convert param from str to value, i.e. int, float or list etc.
try:
value = eval(param)
if not isinstance(value, (str, int, float, list, tuple, bool, None.__class__)):
if not isinstance(value, (str, int, float, list, tuple, dict, bool, None.__class__)):
value = param
except NameError:
if param.lower() == "true":
Expand Down
18 changes: 14 additions & 4 deletions config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from config.running_configurator import RunningConfig
from config.model_configurator import ModelConfig

from config.data_configurator import DataConfig

class Config(object):
"""
Expand Down Expand Up @@ -45,6 +45,10 @@ def __init__(self, config_file_name):
model_arg_file_name = os.path.join(os.path.dirname(config_file_name), model_name + '.config')
self.model_args = ModelConfig(model_arg_file_name)

dataset_name = self.run_args['dataset']
dataset_arg_file_name = os.path.join(os.path.dirname(config_file_name), dataset_name + '.config')
self.dataset_args = DataConfig(dataset_arg_file_name)

self.device = None

def init(self):
Expand Down Expand Up @@ -80,6 +84,8 @@ def __getitem__(self, item):
return self.run_args[item]
elif item in self.model_args:
return self.model_args[item]
elif item in self.dataset_args:
return self.dataset_args[item]
else:
raise KeyError("There are no parameter named '%s'" % item)

Expand All @@ -93,14 +99,16 @@ def __setitem__(self, key, value):
def __contains__(self, o):
if not isinstance(o, str):
raise TypeError("index must be a str!")
return o in self.run_args or o in self.model_args
return o in self.run_args or o in self.model_args or o in self.dataset_args

def __str__(self):

run_args_info = str(self.run_args)
model_args_info = str(self.model_args)
info = "\nRunning Hyper Parameters:\n%s\n\nRunning Model:%s\n\nModel Hyper Parameters:\n%s\n" % \
(run_args_info, self.run_args['model'], model_args_info)
dataset_args_info = str(self.dataset_args)
info = "\nRunning Hyper Parameters:\n%s\n\nRunning Model:%s\n\nDataset Hyper Parameters:%s\n\n" \
"Model Hyper Parameters:\n%s\n" % \
(run_args_info, self.run_args['model'], dataset_args_info, model_args_info)
return info

def __repr__(self):
Expand All @@ -112,6 +120,8 @@ def __repr__(self):
config = Config('../properties/overall.config')
config.init()
# print(config)
print(config['train.epochs'])
print(config['data.SEQ_LEN'])
print(config['process.split_by_ratio.train_ratio'])
print(config['eval.metric'])
print(config['eval.topk'])
Expand Down
19 changes: 19 additions & 0 deletions config/data_configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from config.abstract_configurator import AbstractConfig


class DataConfig(AbstractConfig):
def __init__(self, config_file_name):
super().__init__()
self.must_args = []
self.args = self._read_config_file(config_file_name, 'data')
for cmd_arg_name, cmd_arg_value in self.cmd_args.items():
if cmd_arg_name.startswith('data.'):
self.args[cmd_arg_name] = cmd_arg_value

self.default_args = self.default_parser.getargs()
for default_args in self.default_args.keys():

if str(default_args) not in self.args and (str(default_args).startswith('data.')):
self.args[str(default_args)] = str(self.default_args[default_args])

self._check_args()
7 changes: 4 additions & 3 deletions config/running_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ class RunningConfig(AbstractConfig):
def __init__(self, config_file_name):

super().__init__()
self.must_args = ['model', 'data.name', 'data.path']
self.must_args = ['model', 'dataset', 'dataset.path']
self.args = self._read_config_file(config_file_name, 'default')
for cmd_arg_name, cmd_arg_value in self.cmd_args.items():
if not cmd_arg_name.startswith('model.'):
if not cmd_arg_name.startswith('model.') or not cmd_arg_name.startswith('data.'):
self.args[cmd_arg_name] = cmd_arg_value

self.default_args = self.default_parser.getargs()
for default_args in self.default_args.keys():
if str(default_args) not in self.args and not str(default_args).startswith('model.'):
if str(default_args) not in self.args and not str(default_args).startswith('model.') \
and not str(default_args).startswith('data.'):
self.args[str(default_args)] = str(self.default_args[default_args])

self._check_args()
Loading

0 comments on commit 51c9ce2

Please sign in to comment.