Skip to content

Commit

Permalink
Update path in generate_code.py
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed May 31, 2019
1 parent b5d6707 commit 59646a6
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

str_io = io.StringIO()

optional_namespace = 'std'


class Target(Enum):
ModelBuilder = 1
Expand Down Expand Up @@ -39,7 +37,7 @@ def get_param(elem: dict) -> Tuple[str, str]:
if elem['cpp_type'] == 'str':
return 'const std::string &', elem['name']
elif elem['cpp_type'] == 'optional_str':
return f'const {optional_namespace}::optional<std::string> &', elem['name']
return 'const dnn::optional<std::string> &', elem['name']
elif elem['cpp_type'] == 'str_list':
return 'const std::vector<std::string> &', elem['name']
elif elem['cpp_type'] == 'int32_list':
Expand Down Expand Up @@ -177,9 +175,6 @@ def update_code(file: str, label: str) -> None:


def generate_onnx_converter():
global optional_namespace
optional_namespace = 'nonstd'

with open('ops.yml') as f:
cfg = yaml.load(f)
infer_cfg(cfg, Target.OnnxConverter)
Expand Down Expand Up @@ -250,12 +245,10 @@ def get_input_param(x):
params = list(map(get_param, ipt_opt))
params_str = ', '.join(map(lambda param: "{} {}".format(*param), params))
cogoutl(f"void AddLayer{op['name']}{'' if op['converter'] else 'Impl'}({params_str});")
update_code('tools/onnx2daq/OnnxConverter.h', 'OnnxConverter auto generated methods')
update_code('include/tools/onnx2daq/OnnxConverter.h', 'OnnxConverter auto generated methods')


def generate_model_builder():
global optional_namespace
optional_namespace = 'std'
with open('ops.yml') as f:
cfg = yaml.load(f)
infer_cfg(cfg, Target.ModelBuilder)
Expand All @@ -266,7 +259,7 @@ def generate_model_builder():
ipt_opt = op['input'] + op['output']
params = list(map(get_param, ipt_opt))
if op['support_quant_asymm']:
params.append((f'const {optional_namespace}::optional<QuantInfo> &', 'output_quant_info'))
params.append(('const dnn::optional<QuantInfo> &', 'output_quant_info'))
params_str = ', '.join(map(lambda param: "{} {}".format(*param), params))
cogoutl("ModelBuilder::Index ModelBuilder::Add{}({}) {{".format(op['name'], params_str))
tensor_input = list(filter(lambda x: x['nnapi_type'] == 'tensor', op['input']))
Expand Down Expand Up @@ -310,11 +303,11 @@ def generate_model_builder():
ipt_opt = op['input'] + op['output']
params = list(map(get_param, ipt_opt))
if op['support_quant_asymm']:
params.append((f'const {optional_namespace}::optional<QuantInfo> &', 'output_quant_info'))
params.append(('const dnn::optional<QuantInfo> &', 'output_quant_info'))
params_str = ', '.join(map(lambda param: "{} {}".format(*param), params))
cogoutl("ModelBuilder::Index Add{}({});".format(op['name'], params_str))
cogoutl('#endif // __ANDROID_API__ >= {}'.format(op['api']))
update_code('dnnlibrary/include/ModelBuilder.h', 'ModelBuilder auto generated methods')
update_code('include/dnnlibrary/ModelBuilder.h', 'ModelBuilder auto generated methods')


def main():
Expand Down

0 comments on commit 59646a6

Please sign in to comment.