-
Notifications
You must be signed in to change notification settings - Fork 621
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from RUCAIBox/master
Copy from RUCAIBox
- Loading branch information
Showing
29 changed files
with
1,893 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.vscode/ | ||
.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from config.configurator import Config | ||
from config.parser import Parser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
import sys | ||
from configparser import ConfigParser | ||
from config.parser import Parser | ||
|
||
class AbstractConfig(object): | ||
def __init__(self): | ||
self.cmd_args = dict() | ||
self._read_cmd_line() | ||
self.args = dict() | ||
self.must_args = [] | ||
self.default_parser = Parser() | ||
|
||
def _read_config_file(self, file_name, arg_section): | ||
""" | ||
This function is a protected function that read the config file and convert it to a dict | ||
:param file_name(str): The path of ini-style configuration file. | ||
:param arg_section(str): section name that distinguish running config file and model config file | ||
:return: A dict whose key and value are both str. | ||
""" | ||
if not os.path.isfile(file_name): | ||
raise FileNotFoundError("There is no config file named '%s'!" % file_name) | ||
config = ConfigParser() | ||
config.optionxform = str | ||
config.read(file_name, encoding="utf-8") | ||
sections = config.sections() | ||
|
||
if len(sections) == 0: | ||
raise ValueError("'%s' is not in correct format!" % file_name) | ||
elif arg_section not in sections: | ||
raise ValueError("'%s' is not in correct format!" % file_name) | ||
else: | ||
config_arg = dict(config[arg_section].items()) | ||
|
||
return config_arg | ||
|
||
def _read_cmd_line(self): | ||
|
||
for arg in sys.argv[1:]: | ||
if not arg.startswith("--"): | ||
raise SyntaxError("Commend arg must start with '--', but '%s' is not!" % arg) | ||
cmd_arg_name, cmd_arg_value = arg[2:].split("=") | ||
if cmd_arg_name in self.cmd_args and cmd_arg_value != self.cmd_args[cmd_arg_name]: | ||
raise SyntaxError("There are duplicate commend arg '%s' with different value!" % arg) | ||
else: | ||
self.cmd_args[cmd_arg_name] = cmd_arg_value | ||
|
||
def _check_args(self): | ||
""" | ||
This function is a protected function that check MUST parameters | ||
""" | ||
for parameter in self.must_args: | ||
if parameter not in self.args: | ||
raise ValueError("'%s' must be specified !" % parameter) | ||
|
||
def __getitem__(self, item): | ||
|
||
if not isinstance(item, str): | ||
raise TypeError("index must be a str") | ||
# Get device or other parameters | ||
|
||
if item in self.args: | ||
param = self.args[item] | ||
else: | ||
raise KeyError("There are no parameter named '%s'" % item) | ||
|
||
# convert param from str to value, i.e. int, float or list etc. | ||
try: | ||
value = eval(param) | ||
if not isinstance(value, (str, int, float, list, tuple, bool, None.__class__)): | ||
value = param | ||
except NameError: | ||
if param.lower() == "true": | ||
value = True | ||
elif param.lower() == "false": | ||
value = False | ||
else: | ||
value = param | ||
return value | ||
|
||
def __setitem__(self, key, value): | ||
|
||
if not isinstance(key, str): | ||
raise TypeError("index must be a str") | ||
if key not in self.args: | ||
raise KeyError("There are no model parameter named '%s'" % key) | ||
|
||
self.args[key] = str(value) | ||
|
||
def __contains__(self, o): | ||
if not isinstance(o, str): | ||
raise TypeError("index must be a str!") | ||
return o in self.args | ||
|
||
def __str__(self): | ||
args_info = '\n'.join( | ||
["{}={}".format(arg, value) for arg, value in self.args.items()]) | ||
return args_info | ||
|
||
def __repr__(self): | ||
|
||
return self.__str__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import os | ||
import torch | ||
import random | ||
import numpy as np | ||
from config.running_configurator import RunningConfig | ||
from config.model_configurator import ModelConfig | ||
|
||
|
||
class Config(object): | ||
""" | ||
A configuration class that load predefined hyper parameters. | ||
This class can read arguments from ini-style configuration file. There are two type | ||
of config file can be defined: running config file and model config file. Each file should | ||
be named as XXX.config, and model config file MUST be named as the model name. | ||
In the ini-style config file, only one section is cared. For running config file, the section | ||
MUST be [default] . For model config file, it MUST be [model] | ||
There are three parameter MUST be included in config file: model, data.name, data.path | ||
After initialization successful, the objective of this class can be used as | ||
a dictionary: | ||
config = Configurator("./overall.config") | ||
ratio = config["process.ratio"] | ||
metric = config["eval.metric"] | ||
All the parameter key MUST be str, but returned value is exactly the corresponding type | ||
support parameter type: str, int, float, list, tuple, bool, None | ||
""" | ||
|
||
def __init__(self, config_file_name): | ||
""" | ||
:param config_file_name(str): The path of ini-style configuration file. | ||
:raises : | ||
FileNotFoundError: If `config_file` is not existing. | ||
ValueError: If `config_file` is not in correct format or | ||
MUST parameter are not defined | ||
""" | ||
|
||
self.run_args = RunningConfig(config_file_name) | ||
|
||
model_name = self.run_args['model'] | ||
model_arg_file_name = os.path.join(os.path.dirname(config_file_name), model_name + '.config') | ||
self.model_args = ModelConfig(model_arg_file_name) | ||
|
||
self.device = None | ||
|
||
def init(self): | ||
""" | ||
This function is a global initialization function that fix random seed and gpu device. | ||
""" | ||
init_seed = self.run_args['seed'] | ||
gpu_id = self.run_args['gpu_id'] | ||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) | ||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Get the device that run on. | ||
|
||
random.seed(init_seed) | ||
np.random.seed(init_seed) | ||
torch.manual_seed(init_seed) | ||
torch.cuda.manual_seed(init_seed) | ||
torch.cuda.manual_seed_all(init_seed) | ||
torch.backends.cudnn.deterministic = True | ||
|
||
def dump_config_file(self, config_file): | ||
""" | ||
This function can dump the model's hyper parameters to a new config file | ||
:param config_file: file name that write to. | ||
""" | ||
self.model_args.dump_config_file(config_file) | ||
|
||
def __getitem__(self, item): | ||
if item == "device": | ||
if self.device is None: | ||
raise SyntaxError("device only can be get after init() !") | ||
else: | ||
return self.device | ||
elif item in self.run_args: | ||
return self.run_args[item] | ||
elif item in self.model_args: | ||
return self.model_args[item] | ||
else: | ||
raise KeyError("There are no parameter named '%s'" % item) | ||
|
||
def __setitem__(self, key, value): | ||
if not isinstance(key, str): | ||
raise TypeError("index must be a str") | ||
if key in self.run_args: | ||
raise SyntaxError("running args can not be changed when running!") | ||
self.model_args[key] = value | ||
|
||
def __contains__(self, o): | ||
if not isinstance(o, str): | ||
raise TypeError("index must be a str!") | ||
return o in self.run_args or o in self.model_args | ||
|
||
def __str__(self): | ||
|
||
run_args_info = str(self.run_args) | ||
model_args_info = str(self.model_args) | ||
info = "\nRunning Hyper Parameters:\n%s\n\nRunning Model:%s\n\nModel Hyper Parameters:\n%s\n" % \ | ||
(run_args_info, self.run_args['model'], model_args_info) | ||
return info | ||
|
||
def __repr__(self): | ||
|
||
return self.__str__() | ||
|
||
|
||
if __name__ == '__main__': | ||
config = Config('../properties/overall.config') | ||
config.init() | ||
# print(config) | ||
print(config['process.split_by_ratio.train_ratio']) | ||
print(config['eval.metric']) | ||
print(config['eval.topk']) | ||
print(config['model.learning_rate']) | ||
print(config['eval.group_view']) | ||
print(config['data.separator']) | ||
print(config['model.reg_mf']) | ||
print(config['device']) | ||
config['model.reg_mf'] = 0.6 | ||
print(config['model.reg_mf']) | ||
#config.dump_config_file('../properties/mf_new.config') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from configparser import ConfigParser | ||
from config.abstract_configurator import AbstractConfig | ||
|
||
|
||
class ModelConfig(AbstractConfig): | ||
def __init__(self, config_file_name): | ||
super().__init__() | ||
self.must_args = [] | ||
self.args = self._read_config_file(config_file_name, 'model') | ||
for cmd_arg_name, cmd_arg_value in self.cmd_args.items(): | ||
if cmd_arg_name.startswith('model.'): | ||
self.args[cmd_arg_name] = cmd_arg_value | ||
|
||
self.default_args = self.default_parser.getargs() | ||
for default_args in self.default_args.keys(): | ||
|
||
if str(default_args) not in self.args and str(default_args).startswith('model.'): | ||
self.args[str(default_args)] = str(self.default_args[default_args]) | ||
|
||
self._check_args() | ||
|
||
def dump_config_file(self, config_file): | ||
""" | ||
This function can dump the model's hyper parameters to a new config file | ||
:param config_file: file name that write to. | ||
""" | ||
model_config = ConfigParser() | ||
model_config['model'] = self.args | ||
with open(config_file, 'w') as configfile: | ||
model_config.write(configfile) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from optparse import OptionParser | ||
|
||
|
||
class Parser(object): | ||
def __init__(self): | ||
parser = OptionParser() | ||
parser.add_option('--gpu_id', dest='gpu_id', default=0, type='int', help='GPU id that be used') | ||
parser.add_option('--seed', dest='seed', default=2020, type='int', help='random seed') | ||
parser.add_option('--data.num_workers', dest='data.num_workers', default=0, type='int', | ||
help='multi-process when load data') | ||
parser.add_option('--data.separator', dest='data.separator', default=0, type='int', help='data separator') | ||
parser.add_option('--process.remove_lower_value_by_key.key', dest='process.remove_lower_value_by_key.key', | ||
default='rating', type='string', help='data filter') | ||
parser.add_option('--process.remove_lower_value_by_key.min_remain_value', | ||
dest='process.remove_lower_value_by_key.min_remain_value', | ||
default=3, type='int', help='data filter') | ||
parser.add_option('--process.neg_sample_to.num', | ||
dest='process.neg_sample_to.num', | ||
default=100, type='int', help='number of neg samples') | ||
parser.add_option('--eval.metric', dest='eval.metric', default='["Recall", "Hit", "MRR"]', | ||
type='str', help='evaluation metric') | ||
parser.add_option('--eval.topk', dest='eval.topk', default='[10, 20]', | ||
type='str', help='evaluation K') | ||
parser.add_option('--eval.candidate_neg', dest='eval.candidate_neg', default=0, | ||
type='int', help='number of candidate neg items when testing') | ||
parser.add_option('--eval.test_batch_size', dest='eval.test_batch_size', default=128, | ||
type='int', help='test batch size') | ||
parser.add_option('--model.learning_rate', dest='model.learning_rate', default=0.001, | ||
type='float', help='learning rate') | ||
|
||
(self.options, self.args) = parser.parse_args() | ||
|
||
def getargs(self): | ||
return self.options.__dict__ | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = Parser() | ||
args = parser.getargs() | ||
print(args) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from config.abstract_configurator import AbstractConfig | ||
|
||
|
||
class RunningConfig(AbstractConfig): | ||
def __init__(self, config_file_name): | ||
|
||
super().__init__() | ||
self.must_args = ['model', 'data.name', 'data.path'] | ||
self.args = self._read_config_file(config_file_name, 'default') | ||
for cmd_arg_name, cmd_arg_value in self.cmd_args.items(): | ||
if not cmd_arg_name.startswith('model.'): | ||
self.args[cmd_arg_name] = cmd_arg_value | ||
|
||
self.default_args = self.default_parser.getargs() | ||
for default_args in self.default_args.keys(): | ||
if str(default_args) not in self.args and not str(default_args).startswith('model.'): | ||
self.args[str(default_args)] = str(self.default_args[default_args]) | ||
|
||
self._check_args() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .dataset import * | ||
from .data import Data |
Oops, something went wrong.