Skip to content

Commit

Permalink
Merge pull request #1 from RUCAIBox/master
Browse files Browse the repository at this point in the history
Copy from RUCAIBox
  • Loading branch information
chenyushuo authored Jul 2, 2020
2 parents a0d6d7b + ca6f067 commit b10e7d2
Show file tree
Hide file tree
Showing 29 changed files with 1,893 additions and 19 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.vscode/
.idea/
2 changes: 2 additions & 0 deletions config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from config.configurator import Config
from config.parser import Parser
102 changes: 102 additions & 0 deletions config/abstract_configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import sys
from configparser import ConfigParser
from config.parser import Parser

class AbstractConfig(object):
    """
    Base class for ini-backed configuration holders.

    Parameters are kept as strings in ``self.args`` and converted to Python
    values on lookup (see ``__getitem__``). Command line arguments of the form
    ``--name=value`` are collected into ``self.cmd_args`` at construction time.
    """

    def __init__(self):
        self.cmd_args = dict()
        self._read_cmd_line()
        self.args = dict()
        self.must_args = []
        self.default_parser = Parser()

    def _read_config_file(self, file_name, arg_section):
        """
        This function is a protected function that read the config file and convert it to a dict
        :param file_name(str): The path of ini-style configuration file.
        :param arg_section(str): section name that distinguish running config file and model config file
        :return: A dict whose key and value are both str.
        :raises:
            FileNotFoundError: If `file_name` is not an existing file.
            ValueError: If the file has no section named `arg_section`.
        """
        if not os.path.isfile(file_name):
            raise FileNotFoundError("There is no config file named '%s'!" % file_name)
        config = ConfigParser()
        # Keep option names case-sensitive (ConfigParser lower-cases them by default).
        config.optionxform = str
        config.read(file_name, encoding="utf-8")

        # A file without any section and a file without the wanted section
        # are reported the same way.
        if arg_section not in config.sections():
            raise ValueError("'%s' is not in correct format!" % file_name)

        return dict(config[arg_section].items())

    def _read_cmd_line(self):
        """
        This function is a protected function that read ``--name=value`` arguments
        from ``sys.argv`` into ``self.cmd_args``
        :raises:
            SyntaxError: If an argument does not start with '--', or the same
                name is given twice with different values.
            ValueError: If an argument has no '=' separator.
        """
        for arg in sys.argv[1:]:
            if not arg.startswith("--"):
                raise SyntaxError("Commend arg must start with '--', but '%s' is not!" % arg)
            # Split on the first '=' only, so values may themselves contain '='.
            cmd_arg_name, separator, cmd_arg_value = arg[2:].partition("=")
            if not separator:
                raise ValueError("Command arg '%s' must be in the form '--name=value'!" % arg)
            if cmd_arg_name in self.cmd_args and cmd_arg_value != self.cmd_args[cmd_arg_name]:
                raise SyntaxError("There are duplicate commend arg '%s' with different value!" % arg)
            self.cmd_args[cmd_arg_name] = cmd_arg_value

    def _check_args(self):
        """
        This function is a protected function that check MUST parameters
        :raises:
            ValueError: If a name listed in ``self.must_args`` is missing from ``self.args``.
        """
        for parameter in self.must_args:
            if parameter not in self.args:
                raise ValueError("'%s' must be specified !" % parameter)

    def __getitem__(self, item):
        """
        Look up a parameter and convert its stored string to a Python value.
        :param item(str): parameter name.
        :return: the evaluated value when it is a str, int, float, list, tuple,
            bool or None; otherwise the raw string ('true'/'false' in any
            letter case become booleans).
        :raises:
            TypeError: If `item` is not a str.
            KeyError: If there is no such parameter.
        """
        if not isinstance(item, str):
            raise TypeError("index must be a str")

        if item not in self.args:
            raise KeyError("There are no parameter named '%s'" % item)
        param = self.args[item]

        # convert param from str to value, i.e. int, float or list etc.
        # SECURITY NOTE: eval() executes arbitrary expressions; config files and
        # command line arguments must come from a trusted source.
        # The expression is evaluated in an empty namespace (builtins only) so
        # that a value which happens to match a local or module-level name
        # (e.g. 'item', 'param', 'sys.argv') stays a plain string.
        try:
            value = eval(param, {}, {})
        except Exception:
            # Anything that is not a valid Python expression (paths, dataset
            # names such as 'ml-1m', bare words, ...) is kept as a string.
            lowered = param.lower()
            if lowered == "true":
                value = True
            elif lowered == "false":
                value = False
            else:
                value = param
        else:
            if not isinstance(value, (str, int, float, list, tuple, bool, type(None))):
                value = param
        return value

    def __setitem__(self, key, value):
        """
        Overwrite an existing parameter; the value is stored as ``str(value)``.
        :raises:
            TypeError: If `key` is not a str.
            KeyError: If `key` is not an existing parameter.
        """
        if not isinstance(key, str):
            raise TypeError("index must be a str")
        if key not in self.args:
            raise KeyError("There are no model parameter named '%s'" % key)

        self.args[key] = str(value)

    def __contains__(self, o):
        if not isinstance(o, str):
            raise TypeError("index must be a str!")
        return o in self.args

    def __str__(self):
        # One 'name=value' pair per line, using the stored (string) values.
        return '\n'.join("{}={}".format(arg, value) for arg, value in self.args.items())

    def __repr__(self):
        return self.__str__()
125 changes: 125 additions & 0 deletions config/configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import torch
import random
import numpy as np
from config.running_configurator import RunningConfig
from config.model_configurator import ModelConfig


class Config(object):
    """
    Unified view over the running configuration and the model configuration.

    Both are read from ini-style files. The running config file is the one
    passed to the constructor and its parameters live in the [default]
    section. The model config file sits in the same directory, is named
    ``<model>.config`` after the value of the ``model`` parameter, and its
    parameters live in the [model] section.
    The running config MUST define: model, data.name, data.path

    Once built, the object is used like a dictionary:
        config = Config("./overall.config")
        ratio = config["process.ratio"]
        metric = config["eval.metric"]
    Keys MUST be str; values come back converted to their Python type.
    Supported value types: str, int, float, list, tuple, bool, None
    """

    def __init__(self, config_file_name):
        """
        :param config_file_name(str): The path of ini-style configuration file.
        :raises :
            FileNotFoundError: If `config_file` is not existing.
            ValueError: If `config_file` is not in correct format or
                        MUST parameter are not defined
        """
        self.run_args = RunningConfig(config_file_name)

        # The model config file lives next to the running config file.
        config_dir = os.path.dirname(config_file_name)
        model_file = self.run_args['model'] + '.config'
        self.model_args = ModelConfig(os.path.join(config_dir, model_file))

        self.device = None

    def init(self):
        """
        Global initialization: select the GPU, pick the torch device and fix
        every random seed.
        """
        seed_value = self.run_args['seed']
        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.run_args['gpu_id'])

        # Get the device that run on.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for seeder in seeders:
            seeder(seed_value)
        torch.backends.cudnn.deterministic = True

    def dump_config_file(self, config_file):
        """
        Write the model's hyper parameters to a new config file.
        :param config_file: file name that write to.
        """
        self.model_args.dump_config_file(config_file)

    def __getitem__(self, item):
        if item == "device":
            if self.device is None:
                raise SyntaxError("device only can be get after init() !")
            return self.device

        # Running parameters take precedence over model parameters.
        for source in (self.run_args, self.model_args):
            if item in source:
                return source[item]

        raise KeyError("There are no parameter named '%s'" % item)

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError("index must be a str")
        if key in self.run_args:
            raise SyntaxError("running args can not be changed when running!")
        self.model_args[key] = value

    def __contains__(self, o):
        if not isinstance(o, str):
            raise TypeError("index must be a str!")
        if o in self.run_args:
            return True
        return o in self.model_args

    def __str__(self):
        running_text = str(self.run_args)
        model_text = str(self.model_args)
        template = "\nRunning Hyper Parameters:\n%s\n\nRunning Model:%s\n\nModel Hyper Parameters:\n%s\n"
        return template % (running_text, self.run_args['model'], model_text)

    def __repr__(self):
        return self.__str__()


if __name__ == '__main__':
    # Manual smoke test: load the sample configuration, print a few
    # parameters, then overwrite one model parameter and print it again.
    demo_config = Config('../properties/overall.config')
    demo_config.init()

    demo_keys = (
        'process.split_by_ratio.train_ratio',
        'eval.metric',
        'eval.topk',
        'model.learning_rate',
        'eval.group_view',
        'data.separator',
        'model.reg_mf',
        'device',
    )
    for demo_key in demo_keys:
        print(demo_config[demo_key])

    demo_config['model.reg_mf'] = 0.6
    print(demo_config['model.reg_mf'])
30 changes: 30 additions & 0 deletions config/model_configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from configparser import ConfigParser
from config.abstract_configurator import AbstractConfig


class ModelConfig(AbstractConfig):
    """
    Model hyper parameters read from the [model] section of a config file.

    Precedence, highest first: command line arguments whose name starts with
    'model.', the config file, then the defaults of the default parser whose
    name starts with 'model.'.
    """

    def __init__(self, config_file_name):
        """
        :param config_file_name(str): The path of the model's ini-style configuration file.
        """
        super().__init__()
        self.must_args = []
        self.args = self._read_config_file(config_file_name, 'model')

        # Command line values for model options override the file.
        for cmd_arg_name, cmd_arg_value in self.cmd_args.items():
            if cmd_arg_name.startswith('model.'):
                self.args[cmd_arg_name] = cmd_arg_value

        # Fill in model options that are still missing from the parser defaults.
        self.default_args = self.default_parser.getargs()
        for default_name, default_value in self.default_args.items():
            arg_name = str(default_name)
            if arg_name not in self.args and arg_name.startswith('model.'):
                self.args[arg_name] = str(default_value)

        self._check_args()

    def dump_config_file(self, config_file):
        """
        This function can dump the model's hyper parameters to a new config file
        :param config_file: file name that write to.
        """
        model_config = ConfigParser()
        # Mirror the reader: keep option names case-sensitive (the default
        # lower-cases them) and write UTF-8, so a dumped file loads back
        # with the same parameter names.
        model_config.optionxform = str
        model_config['model'] = self.args
        with open(config_file, 'w', encoding='utf-8') as configfile:
            model_config.write(configfile)
41 changes: 41 additions & 0 deletions config/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from optparse import OptionParser


class Parser(object):
    """
    Declares the known options with their default values and parses the
    command line with them.
    """

    def __init__(self):
        # (option name, default, optparse type, help text); the option name is
        # used both as the '--' flag and as the destination key.
        option_table = (
            ('gpu_id', 0, 'int', 'GPU id that be used'),
            ('seed', 2020, 'int', 'random seed'),
            ('data.num_workers', 0, 'int', 'multi-process when load data'),
            ('data.separator', 0, 'int', 'data separator'),
            ('process.remove_lower_value_by_key.key', 'rating', 'string', 'data filter'),
            ('process.remove_lower_value_by_key.min_remain_value', 3, 'int', 'data filter'),
            ('process.neg_sample_to.num', 100, 'int', 'number of neg samples'),
            ('eval.metric', '["Recall", "Hit", "MRR"]', 'str', 'evaluation metric'),
            ('eval.topk', '[10, 20]', 'str', 'evaluation K'),
            ('eval.candidate_neg', 0, 'int', 'number of candidate neg items when testing'),
            ('eval.test_batch_size', 128, 'int', 'test batch size'),
            ('model.learning_rate', 0.001, 'float', 'learning rate'),
        )

        option_parser = OptionParser()
        for opt_name, opt_default, opt_type, opt_help in option_table:
            option_parser.add_option('--' + opt_name, dest=opt_name, default=opt_default,
                                     type=opt_type, help=opt_help)

        parsed_options, positional = option_parser.parse_args()
        self.options = parsed_options
        self.args = positional

    def getargs(self):
        """Return the parsed options as a dict keyed by option name."""
        return vars(self.options)


if __name__ == '__main__':
    # Manual check: show the option values parsed from the command line.
    cli_parser = Parser()
    print(cli_parser.getargs())

19 changes: 19 additions & 0 deletions config/running_configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from config.abstract_configurator import AbstractConfig


class RunningConfig(AbstractConfig):
    """
    Running parameters read from the [default] section of a config file.

    Precedence, highest first: command line arguments whose name does not
    start with 'model.', the config file, then the defaults of the default
    parser whose name does not start with 'model.'.
    The parameters model, data.name and data.path MUST be present.
    """

    def __init__(self, config_file_name):
        """
        :param config_file_name(str): The path of the running ini-style configuration file.
        """
        super().__init__()
        self.must_args = ['model', 'data.name', 'data.path']
        self.args = self._read_config_file(config_file_name, 'default')

        # Command line values for non-model options override the file.
        self.args.update(
            (cli_name, cli_value)
            for cli_name, cli_value in self.cmd_args.items()
            if not cli_name.startswith('model.')
        )

        # Fill in non-model options that are still missing from the parser defaults.
        self.default_args = self.default_parser.getargs()
        for raw_name, raw_default in self.default_args.items():
            option_name = str(raw_name)
            if option_name.startswith('model.') or option_name in self.args:
                continue
            self.args[option_name] = str(raw_default)

        self._check_args()
2 changes: 2 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .dataset import *
from .data import Data
Loading

0 comments on commit b10e7d2

Please sign in to comment.