Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored operator set generation #69

Merged
merged 3 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions include/operators/operator_set.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#ifndef OPERATOR_SET_H
#define OPERATOR_SET_H

#include "operator.h"
#include "tracing.h"
#include <stddef.h>
#include <string.h>

typedef struct operator_set_opdomain operator_set_opdomain;
typedef struct operator_set_opname operator_set_opname;
typedef struct operator_set_opversion operator_set_opversion;

//TODO clean up includes
#include "operators/operator_info.h"
#include "operators/operator.h"

struct operator_set_opversion
{
size_t version;
operator_preparer preparer;
operator_info *info;
};

struct operator_set_opname
{
char *name;
operator_set_opversion *opversions[];
};

struct operator_set_opdomain
{
char *name;
operator_set_opname *opnames[];
};

extern operator_set_opdomain *operator_set[];

static __attribute__((unused))
operator_preparer
operator_set_find_preparer(
char *name,
size_t version
) {
operator_set_opversion *tmp = NULL;
for (operator_set_opdomain **opdomain = operator_set; *opdomain; opdomain++)
{
for (operator_set_opname **opname = (*opdomain)->opnames; *opname; opname++)
{
if (strcmp((*opname)->name,name) == 0) {
for (operator_set_opversion **opversion = (*opname)->opversions; *opversion; opversion++)
{
if ((*opversion)->version <= version) {
if (!tmp || (*opversion)->version >= tmp->version) {
tmp = *opversion;
}
}
}
if (tmp) {
TRACE(2, true, "Found operator '%s' version '%zu'", name, tmp->version);
return tmp->preparer;
}
}
}
}
TRACE_FATAL(0, true, "No Operator not found with name '%s' for opset '%zu'", name, tmp->version);
return NULL;
}

static __attribute__((unused))
operator_info*
operator_set_find_info(
char *name,
size_t version)
{
operator_set_opversion *tmp = NULL;
for (operator_set_opdomain **opdomain = operator_set; *opdomain; opdomain++)
{
for (operator_set_opname **opname = (*opdomain)->opnames; *opname; opname++)
{
if (strcmp((*opname)->name, name) == 0)
{
for (operator_set_opversion **opversion = (*opname)->opversions; *opversion; opversion++)
{
if ((*opversion)->version <= version)
{
if (!tmp || (*opversion)->version >= tmp->version)
{
tmp = *opversion;
}
}
}
if (tmp)
{
TRACE(2, true, "Found operator '%s' version '%zu'", name, tmp->version);
return tmp->info;
}
}
}
}
TRACE_FATAL(0, true, "No Operator not found with name '%s' for opset '%zu'", name, tmp->version);
return NULL;
}

#endif
63 changes: 0 additions & 63 deletions include/operators/operator_sets.h

This file was deleted.

10 changes: 7 additions & 3 deletions scripts/onnx_generator/OnnxWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,13 @@ def __init__(self, schema):
def __repr__(self):
return f"OnnxSchema({self.__dict__.__repr__()})"

def _operator_name(self, schema):
name = f"operator__{self._domain(schema)}__{schema.name}__{schema.since_version}"
return re.sub(r"\W", "_", name).lower()
def _operator_name(self, schema, name=True, version=True):
opname = f"operator__{self._domain(schema)}"
if name:
opname += f"__{schema.name}"
if version:
opname += f"__{schema.since_version}"
return re.sub(r"\W", "_", opname).lower()

def _domain(self, schema):
domain = "ai.onnx"
Expand Down
163 changes: 110 additions & 53 deletions scripts/onnx_generator/OperatorSets.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,145 @@
from .Template import Template
import re

class OperatorSetEntry(Template):
class SourceOperatorVersion(Template):
_basepath = "{path}"
_filepath = "{schema.domain}/{schema.name}/{schema.version}/opversion_{operator_name}.c"
_template = '''
{{
.name = "{name}",
.preparer = prepare_{operator_name},
.info = &info_{operator_name}
}}
//this file was generated by {scriptpath}

#include "operators/operator_set.h"
#include "operators/{schema.domain}/{schema.name}/{schema.version}/{schema.operator_name}.h"

operator_set_opversion opversion_{operator_name} = {{
.version = {version},
.preparer = prepare_{operator_name},
.info = &info_{operator_name}
}};
'''
def __init__(self, schema):
def __init__(self, schema, path):
self.path = path
self.schema = schema
self.version = schema.version
self.name = schema.name
self.operator_name = schema.operator_name


class OperatorSet(Template):
class SourceOperator(Template):
_basepath = "{path}"
_filepath = "{domain}/{name}/opname_operator__{_domain}__{_name}.c"
_template = '''
operator_set {name} = {{
.version = {version},
.domain = "{domain}",
.length = {length},
.entries = {{
{entries}
}}
//this file was generated by {scriptpath}

#include "operators/operator_set.h"

{extern}

operator_set_opname opname_operator__{_domain}__{_name} = {{
.name = "{name}",
.opversions = {{
{entries}
NULL
}}
}};
'''
def __init__(self, domain, name, schemas, path):
self.path = path
self.name = name
self.schemas = schemas
self.domain = domain

def __init__(self, domain, version, schemas):
self._domain = re.sub(r"\W", "_",domain).lower()
self._name = re.sub(r"\W", "_",name).lower()

entries = []
extern = []
for s in schemas:
entries.append(f"&opversion_{s.operator_name},")
extern.append(f'extern operator_set_opversion opversion_{s.operator_name};')
self.entries = "\n ".join(entries)
self.extern = "\n".join(extern)

class SourceDomain(Template):
_basepath = "{path}"
_filepath = "{domain}/opdomain_operator__{_domain}.c"
_template = '''
//this file was generated by {scriptpath}

#include "operators/operator_set.h"

{extern}

operator_set_opdomain opdomain_operator__{_domain} = {{
.name = "{domain}",
.opnames = {{
{entries}
NULL
}}
}};
'''
def __init__(self, domain, operators, path):
self.path = path
self.domain = domain
domain_sane = re.sub(r"\W","_",domain)
self.version = version
self.schemas = schemas
self.name = f"operator_set__{domain_sane}__{version}"
self.length = len(schemas)
self.entries = ",".join([ str(OperatorSetEntry(s)) for s in self.schemas ])
self._domain = re.sub(r"\W", "_",domain).lower()
self.operators = operators
entries = []
extern = []
for op in operators:
_op = re.sub(r"\W", "_",op).lower()
entries.append(f"&opname_operator__{self._domain}__{_op},")
extern.append(f'extern operator_set_opname opname_operator__{self._domain}__{_op};')
self.entries = "\n ".join(entries)
self.extern = "\n".join(extern)

class Source(Template):
class OperatorSet(Template):
_basepath = "{path}"
_filepath = "operator_sets.c"
_filepath = "operator_set.c"
_template = '''
//this file was generated by {scriptpath}
#include "operators/operator_sets.h"

{includes}
#include "operators/operator_set.h"

{sets}
{extern}

operator_sets all_operator_sets = {{
.length = {length},
.sets = {{
{set_refs}
}}
operator_set_opdomain *operator_set[] = {{
{entries}
NULL
}};
'''

def __init__(self, headers, path):
self.headers = headers
def __init__(self, domains, path):
self.path = path
self.domains = domains
entries = []
extern = []
for domain in domains:
_domain = re.sub(r"\W", "_",domain).lower()
entries.append(f"&opdomain_operator__{_domain},")
extern.append(f'extern operator_set_opdomain opdomain_operator__{_domain};')
self.entries = "\n ".join(entries)
self.extern = "\n".join(extern)

class Sets(Template):
def __init__(self, schemas, path):
self.schemas = schemas
self.path = path

sets = []
versions = set()
domain2name2version2schema = {}
for header in self.headers:
schema = header.schema
name2version2schema = domain2name2version2schema.setdefault(schema.domain,{})
self.domain2name2version2schema = {}
for schema in self.schemas:
name2version2schema = self.domain2name2version2schema.setdefault(schema.domain,{})
name2version2schema.setdefault(schema.name,{})[schema.version] = schema
versions.add(schema.version)

for version in versions:
for domain, name2version2schema in domain2name2version2schema.items():
tmp = []
for _name, version2schema in name2version2schema.items():
for v in range(version, 0, -1):
if v in version2schema:
tmp.append(version2schema[v])
break
# print(sets, domain, version, tmp)
sets.append(OperatorSet(domain, version, tmp))
def __iter__(self):
yield OperatorSet(self.domain2name2version2schema.keys(), self.path)
for domain, name2version2schema in self.domain2name2version2schema.items():
yield SourceDomain(domain, name2version2schema.keys(), self.path)
for name, version2schema in name2version2schema.items():
yield SourceOperator(domain, name, version2schema.values(), self.path)
for schema in version2schema.values():
yield SourceOperatorVersion(schema, self.path)



self.includes = "\n".join([ f'#include "operators/{h.filepath(False,False)}"' for h in self.headers ])
self.sets = "\n\n".join([ str(s) for s in sets])
self.length = len(sets)
self.set_refs = ",\n".join([f"&{s.name}" for s in sets])

3 changes: 1 addition & 2 deletions scripts/onnx_generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,10 @@ def fatal(text, error=1):
resolvers = [ OperatorTypeResolver.Source(h,path) for h in headers ]
note("generating onnx operator sets")
path = f"{args.path[-1]}/{args.sets[-1]}/"
sets = [OperatorSets.Source(headers,path)]
sets = OperatorSets.Sets(schemas,path)
note("generating onnx operator template")
path = f"{args.path[-1]}/{args.template[-1]}/"
templates = itertools.chain(*[ OperatorTemplate.Templates(h,path) for h in headers ])
note("generating onnx operator info")
path = f"{args.path[-1]}/{args.info[-1]}/"
info = [ OperatorInfo.Source(h, path) for h in headers ]

Expand Down
Loading