diff --git a/codegen.py b/codegen.py index 90b4392e4..f8054ea35 100644 --- a/codegen.py +++ b/codegen.py @@ -10,20 +10,13 @@ from codegen.Enum import Enum from codegen.Bitfield import Bitfield from codegen.Versions import Versions +from codegen.Module import Module from codegen.naming_conventions import clean_comment_str logging.basicConfig(level=logging.DEBUG) FIELD_TYPES = ("add", "field") -VER = "stream.version" - - -def write_file(filename: str, contents: str): - file_dir = os.path.dirname(filename) - if not os.path.exists(file_dir): - os.makedirs(file_dir) - with open(filename, 'w', encoding='utf-8') as file: - file.write(contents) +VER = "self.context.version" class XmlParser: @@ -34,6 +27,8 @@ def __init__(self, format_name): """Set up the xml parser.""" self.format_name = format_name + # which encoding to use for the output files + self.encoding='utf-8' # elements for versions self.version_string = None @@ -53,19 +48,27 @@ def generate_module_paths(self, root): """preprocessing - generate module paths for imports relative to the output dir""" for child in root: # only check stuff that has a name - ignore version tags - if child.tag not in ("version", "module", "token"): - class_name = convention.name_class(child.attrib["name"]) - out_segments = ["formats", self.format_name, child.tag, ] - if child.tag == "niobject": - out_segments.append(child.attrib["module"]) - out_segments.append(class_name) + if child.tag not in ("version", "token"): + base_segments = os.path.join("formats", self.format_name) + if child.tag == "module": + # for modules, set the path to base/module_name + class_name = convention.name_module(child.attrib["name"]) + class_segments = [class_name] + else: + # for classes, set the path to module_path/tag/class_name or + # base/tag/class_name if it's not part of a module + class_name = convention.name_class(child.attrib["name"]) + if child.attrib.get("module"): + base_segments = self.path_dict[convention.name_module(child.attrib["module"])] + class_segments = [child.tag, class_name, ] # store the final relative module path for this class - self.path_dict[class_name] = os.path.join(*out_segments) + self.path_dict[class_name] = os.path.join(base_segments, *class_segments) self.tag_dict[class_name.lower()] = child.tag self.path_dict["Array"] = "array" self.path_dict["BasicBitfield"] = "bitfield" self.path_dict["BitfieldMember"] = "bitfield" + self.path_dict["ContextReference"] = "context" self.path_dict["UbyteEnum"] = "base_enum" self.path_dict["UshortEnum"] = "base_enum" self.path_dict["UintEnum"] = "base_enum" @@ -81,7 +84,8 @@ def load_xml(self, xml_file): for child in root: self.replace_tokens(child) - self.apply_conventions(child) + if child.tag not in ('version', 'module'): + self.apply_conventions(child) try: if child.tag in self.struct_types: Compound(self, child) @@ -91,8 +95,8 @@ def load_xml(self, xml_file): # self.write_basic(child) elif child.tag == "enum": Enum(self, child) - # elif child.tag == "module": - # self.read_module(child) + elif child.tag == "module": + Module(self, child) elif child.tag == "version": versions.read(child) elif child.tag == "token": @@ -110,19 +114,6 @@ def read_token(self, token): for sub_token in token], token.attrib["attrs"].split(" "))) - @staticmethod - def get_names(struct): - # struct types can be organized in a hierarchy - # if inherit attribute is defined, look for corresponding base block - class_name = convention.name_class(struct.attrib.get("name")) - class_basename = struct.attrib.get("inherit") - class_debug_str = clean_comment_str(struct.text, indent="\t") - if class_basename: - # avoid turning None into 'None' if class doesn't inherit - class_basename = convention.name_class(class_basename) - # logging.debug(f"Struct {class_name} is based on {class_basename}") - return class_name, class_basename, class_debug_str - @staticmethod def apply_convention(struct, func, params): for k in params: @@ -137,22 +128,14 @@ def apply_conventions(self, struct): if field.tag in FIELD_TYPES: self.apply_convention(field, convention.name_attribute, ("name",)) self.apply_convention(field, convention.name_class, ("type",)) + self.apply_convention(field, convention.name_class, ("onlyT",)) + self.apply_convention(field, convention.name_class, ("excludeT",)) + for default in field: + self.apply_convention(field, convention.name_class, ("onlyT",)) # filter comment str struct.text = clean_comment_str(struct.text, indent="\t", class_comment='"""') - def collect_types(self, imports, struct): - """Iterate over all fields in struct and collect type references""" - # import classes used in the fields - for field in struct: - if field.tag in ("add", "field", "member"): - field_type = convention.name_class(field.attrib["type"]) - if field_type not in imports: - if field_type == "self.template": - imports.append("typing") - else: - imports.append(field_type) - def method_for_type(self, dtype: str, mode="read", attr="self.dummy", arg=None, template=None): if self.tag_dict[dtype.lower()] == "enum": storage = self.storage_dict[dtype] @@ -207,11 +190,6 @@ def replace_tokens(self, xml_struct): for op_token, op_str in fixed_tokens: expr_str = expr_str.replace(op_token, op_str) xml_struct.attrib[attrib] = expr_str - # onlyT & excludeT act as aliases for deprecated cond - for t, pref in (("onlyT", ""), ("excludeT", "!")): - if t in xml_struct.attrib: - xml_struct.attrib["cond"] = pref+xml_struct.attrib[t] - break for xml_child in xml_struct: self.replace_tokens(xml_child) @@ -224,6 +202,17 @@ def copy_src_to_generated(): copy_tree(src_dir, trg_dir) +def create_inits(): + """Create a __init__.py file in all subdirectories that don't have one, to prevent error on second import""" + base_dir = os.path.join(os.getcwd(), 'generated') + init_file = "__init__.py" + for root, dirs, files in os.walk(base_dir): + if init_file not in files: + # __init__.py does not exist, create it + with open(os.path.join(root, init_file), 'x'): pass + # don't go into subdirectories that start with a double underscore + dirs[:] = [dirname for dirname in dirs if dirname[:2] != '__'] + def generate_classes(): logging.info("Starting class generation") cwd = os.getcwd() @@ -237,6 +226,7 @@ def generate_classes(): logging.info(f"Reading {format_name} format") xmlp = XmlParser(format_name) xmlp.load_xml(xml_path) + create_inits() generate_classes() diff --git a/codegen/BaseClass.py b/codegen/BaseClass.py index dddaf9f3c..22422d097 100644 --- a/codegen/BaseClass.py +++ b/codegen/BaseClass.py @@ -40,7 +40,7 @@ def get_code_from_src(self,): if self.parser.format_name in root and py_name == name.lower(): src_path = os.path.join(root, name) print("found source", src_path) - with open(src_path, "r") as f: + with open(src_path, "r", encoding=self.parser.encoding) as f: return f.read() return "" diff --git a/codegen/Bitfield.py b/codegen/Bitfield.py index 6989e6e56..ec9f65762 100644 --- a/codegen/Bitfield.py +++ b/codegen/Bitfield.py @@ -22,8 +22,12 @@ def get_mask(self): num_bits = int(field.attrib["numbits"]) elif "width" in field.attrib: num_bits = int(field.attrib["width"]) + elif "bit" in field.attrib: + num_bits = 1 + field.attrib["pos"] = field.attrib["bit"] + field.attrib["type"] = "bool" else: - raise AttributeError(f"Neither width or mask or numbits are defined for {field.name}") + raise AttributeError(f"Neither width, mask, bit or numbits are defined for {field.attrib['name']}") pos = int(field.attrib["pos"]) mask = ~((~0) << (pos + num_bits)) & ((~0) << pos) @@ -45,7 +49,7 @@ def read(self): self.class_basename = "BasicBitfield" # write to python file - with open(self.out_file, "w") as f: + with open(self.out_file, "w", encoding=self.parser.encoding) as f: # write the header stuff super().write(f) self.map_pos() diff --git a/codegen/Compound.py b/codegen/Compound.py index f30c1e56c..799a0a88b 100644 --- a/codegen/Compound.py +++ b/codegen/Compound.py @@ -3,7 +3,7 @@ from .Union import Union, get_params FIELD_TYPES = ("add", "field") -VER = "stream.version" +VER = "self.context.version" class Compound(BaseClass): @@ -27,20 +27,25 @@ def read(self): self.imports.add("numpy") # write to python file - with open(self.out_file, "w") as f: + with open(self.out_file, "w", encoding=self.parser.encoding) as f: # write the header stuff super().write(f) + if not self.class_basename: + f.write(f"\n\n\tcontext = ContextReference()") + # check all fields/members in this class and write them as fields # for union in self.field_unions.values(): # union.write_declaration(f) if "def __init__" not in self.src_code: - f.write(f"\n\n\tdef __init__(self, arg=None, template=None):") + f.write(f"\n\n\tdef __init__(self, context, arg=None, template=None):") f.write(f"\n\t\tself.name = ''") # classes that this class inherits from have to be read first if self.class_basename: - f.write(f"\n\t\tsuper().__init__(arg, template)") + f.write(f"\n\t\tsuper().__init__(context, arg, template)") + else: + f.write(f"\n\t\tself._context = context") f.write(f"\n\t\tself.arg = arg") f.write(f"\n\t\tself.template = template") f.write(f"\n\t\tself.io_size = 0") @@ -60,7 +65,6 @@ def read(self): # classes that this class inherits from have to be read first if self.class_basename: f.write(f"\n\t\tsuper().{method_type}(stream)") - for union in self.field_unions: last_condition = union.write_io(f, method_type, last_condition) diff --git a/codegen/Enum.py b/codegen/Enum.py index f5d6ea9e1..93829ea21 100644 --- a/codegen/Enum.py +++ b/codegen/Enum.py @@ -1,7 +1,7 @@ from .BaseClass import BaseClass FIELD_TYPES = ("add", "field") -VER = "stream.version" +VER = "self.context.version" class Enum(BaseClass): @@ -18,7 +18,7 @@ def read(self): self.class_basename = enum_base self.imports.add(enum_base) # write to python file - with open(self.out_file, "w") as f: + with open(self.out_file, "w", encoding=self.parser.encoding) as f: # write the header stuff super().write(f) for option in self.struct: diff --git a/codegen/Imports.py b/codegen/Imports.py index b19e189cb..88574eae0 100644 --- a/codegen/Imports.py +++ b/codegen/Imports.py @@ -1,3 +1,6 @@ +from os.path import sep + + NO_CLASSES = ("Padding",) @@ -11,6 +14,9 @@ def __init__(self, parser, xml_struct): self.imports = [] # import parent class self.add(xml_struct.attrib.get("inherit")) + # import ContextReference class + if xml_struct.tag in parser.struct_types and not xml_struct.attrib.get("inherit"): + self.add("ContextReference") # import classes used in the fields for field in xml_struct: @@ -25,9 +31,23 @@ def __init__(self, parser, xml_struct): self.add(field_type) # arr1 needs typing.List arr1 = field.attrib.get("arr1") + if arr1 is None: + arr1 = field.attrib.get("length") if arr1: self.add("typing") self.add("Array") + type_attribs = ("onlyT", "excludeT") + for attrib in type_attribs: + attrib_type = field.attrib.get(attrib) + if attrib_type: + self.add(attrib_type) + + for default in field: + if default.tag in ("default",): + for attrib in type_attribs: + attrib_type = default.attrib.get(attrib) + if attrib_type: + self.add(attrib_type) def add(self, cls_to_import, import_from=None): if cls_to_import: @@ -43,7 +63,7 @@ def write(self, stream): if class_import in NO_CLASSES: continue if class_import in self.path_dict: - import_path = "generated." + self.path_dict[class_import].replace("\\", ".") + import_path = "generated." + self.path_dict[class_import].replace(sep, ".") local_imports.append(f"from {import_path} import {class_import}\n") else: module_imports.append(f"import {class_import}\n") diff --git a/codegen/Module.py b/codegen/Module.py new file mode 100644 index 000000000..8eedfad48 --- /dev/null +++ b/codegen/Module.py @@ -0,0 +1,24 @@ +import os +from codegen.naming_conventions import clean_comment_str, name_module + +class Module: + + def __init__(self, parser, element): + self.parser = parser + self.element = element + self.read(element) + self.write(parser.path_dict[name_module(element.attrib["name"])]) + + def read(self, element): + self.comment_str = clean_comment_str(element.text, indent="", class_comment='"""')[2:] + self.priority = int(element.attrib.get("priority","")) + self.depends = [name_module(module) for module in element.attrib.get("depends","").split(" ")] + self.custom = bool(eval(element.attrib.get("custom","true").replace("true","True").replace("false","False"),{})) + + def write(self, rel_path): + with open(os.path.join(os.getcwd(), "generated", rel_path, "__init__.py"), "w", encoding=self.parser.encoding) as file: + file.write(self.comment_str) + file.write(f'\n\n__priority__ = {repr(self.priority)}') + file.write(f'\n__depends__ = {repr(self.depends)}') + file.write(f'\n__custom__ = {repr(self.custom)}') + file.write(f'\n') \ No newline at end of file diff --git a/codegen/Union.py b/codegen/Union.py index 353dd9943..f9edff089 100644 --- a/codegen/Union.py +++ b/codegen/Union.py @@ -1,48 +1,69 @@ from codegen.expression import Expression, Version +from codegen.Versions import Versions from .naming_conventions import clean_comment_str -VER = "stream.version" +VER = "self.context.version" -def get_params(field): - # parse all conditions +def get_attr_with_backups(field, attribute_keys): + # return the value of the first attribute in the list that is not empty or + # missing + for key in attribute_keys: + attr_value = field.attrib.get(key) + if attr_value: + return attr_value + else: + return None + +def get_conditions(field): conditionals = [] - field_name = field.attrib["name"] - field_type = field.attrib["type"] - pad_mode = field.attrib.get("padding") - template = field.attrib.get("template") - ver1 = field.attrib.get("ver1") + ver1 = get_attr_with_backups(field, ["ver1", "since"]) if ver1: ver1 = Version(ver1) - else: - ver1 = field.attrib.get("since") - if ver1: - ver1 = Version(ver1) + ver2 = get_attr_with_backups(field, ["ver2", "until"]) - ver2 = field.attrib.get("ver2") if ver2: ver2 = Version(ver2) - else: - ver2 = field.attrib.get("until") - if ver2: - ver2 = Version(ver2) vercond = field.attrib.get("vercond") + versions = field.attrib.get("versions") + if versions: + versions = [Versions.format_id(version) for version in versions.split(" ")] cond = field.attrib.get("cond") + onlyT = field.attrib.get("onlyT") + excludeT = field.attrib.get("excludeT") if ver1 and ver2: - conditionals.append(f"{ver1} <= {VER} < {ver2}") + conditionals.append(f"{ver1} <= {VER} <= {ver2}") elif ver1: conditionals.append(f"{VER} >= {ver1}") elif ver2: - conditionals.append(f"{VER} < {ver2}") + conditionals.append(f"{VER} <= {ver2}") if vercond: - vercond = Expression(vercond) + vercond = Expression(vercond, g_vars=True) conditionals.append(f"{vercond}") + if versions: + conditionals.append(f"({' or '.join([f'is_{version}(self.context)' for version in versions])})") if cond: cond = Expression(cond) conditionals.append(f"{cond}") - arr1 = field.attrib.get("arr1") - arr2 = field.attrib.get("arr2") + if onlyT: + conditionals.append(f"isinstance(self, {onlyT})") + if excludeT: + conditionals.append(f"not isinstance(self, {excludeT})") + return conditionals + +def get_params(field): + # parse all attributes and return the python-evaluatable string + + field_name = field.attrib["name"] + field_type = field.attrib["type"] + pad_mode = field.attrib.get("padding") + template = field.attrib.get("template") + + conditionals = get_conditions(field) + arg = field.attrib.get("arg") + arr1 = get_attr_with_backups(field, ["arr1", "length"]) + arr2 = get_attr_with_backups(field, ["arr2", "width"]) if arg: arg = Expression(arg) if arr1: @@ -51,6 +72,22 @@ def get_params(field): arr2 = Expression(arr2) return arg, template, arr1, arr2, conditionals, field_name, field_type, pad_mode +def condition_indent(base_indent, conditionals, last_condition=""): + # determine the python condition and indentation level based on whether the + # last used condition was the same. + if conditionals: + new_condition = f"if {' and '.join(conditionals)}:" + # merge subsequent fields that have the same condition + if last_condition != new_condition: + last_condition = new_condition + else: + new_condition="" + indent = base_indent + "\t" + else: + indent = base_indent + new_condition = "" + + return indent, new_condition, last_condition class Union: def __init__(self, compound, union_name): @@ -101,12 +138,6 @@ def write_declaration(self, f): if self.compound.parser.tag_dict[field_type.lower()] == "enum": field_default = field_type + "." + field_default f.write(f" = {field_default}") - - # todo - handle several defaults? maybe save as docstring - # load defaults for this - # for default in field: - # if default.tag != "default": - # raise AttributeError("self.struct children's children must be 'default' tag") def get_basic_type(self): """If this union has just one field, return its dtype if it is basic""" @@ -117,58 +148,93 @@ def get_basic_type(self): if self.compound.parser.tag_dict[t.lower()] == "basic": return t + def get_default_string(self, default_string, arg, template, arr1, arr2, field_name, field_type): + # get the default (or the best guess of it) + field_type_lower = field_type.lower() + if field_type_lower in self.compound.parser.tag_dict: + type_of_field_type = self.compound.parser.tag_dict[field_type_lower] + # get the field's default, if it exists + if default_string: + # we have to check if the default is an enum default value, in which case it has to be a member of that enum + if type_of_field_type == "enum": + default_string = f'{field_type}.{default_string}' + elif type_of_field_type in ("bitfield", "bitflags"): + default_string = f'{field_type}({default_string})' + # no default, so guess one + else: + if type_of_field_type in ("compound", "struct", "niobject", "enum", "bitfield", "bitflags", "bitstruct"): + if type_of_field_type in self.compound.parser.struct_types: + arguments = f"context, {arg}, {template}" + else: + arguments = "" + default_string = f"{field_type}({arguments})" + else: + default_string = 0 + + if arr1: + default_string = "Array()" + if self.compound.parser.tag_dict[field_type_lower] == "basic": + valid_arrs = tuple(str(arr) for arr in (arr1, arr2) if arr and ".arg" not in str(arr)) + arr_str = ", ".join(valid_arrs) + if field_type_lower in ("ubyte", "byte", "short", "ushort", "int", "uint", "uint64", "int64", "float"): + default_string = f"numpy.zeros(({arr_str}), dtype='{field_type_lower}')" + # todo - if we do this, it breaks when arg is used in array + # default_string = f"[{default_string} for _ in range({Expression(arr1)})]" + return default_string + + def default_assigns(self, field, arg, template, arr1, arr2, field_name, field_type, base_indent): + field_default = self.get_default_string(field.attrib.get('default'), arg, template, arr1, arr2, field_name, field_type) + default_children = field.findall("default") + if default_children: + defaults = [(f'{base_indent}else:', f'{base_indent}\tself.{field_name} = {field_default}')] + last_default = len(default_children)-1 + last_condition = "" + for i, default_element in enumerate(default_children): + + # get the condition + conditions = get_conditions(default_element) + indent, condition, last_condition = condition_indent(base_indent, conditions, last_condition) + if not condition: + raise AttributeError(f"Default tag without or with overlapping conditions on {field.attrib['name']} {condition} {default_element.get('value')}") + if i != last_default: + condition = f'{base_indent}el{condition}' + else: + condition = f'{base_indent}{condition}' + + default = self.get_default_string(default_element.attrib.get("value"), arg, template, arr1, arr2, field_name, field_type) + defaults.append((condition, f'{indent}self.{field_name} = {default}')) + + defaults = defaults[::-1] + else: + defaults = [("", f'{base_indent}self.{field_name} = {field_default}')] + return defaults + def write_init(self, f): + last_condition="" + base_indent = "\n\t\t" for field in self.members: field_debug_str = clean_comment_str(field.text, indent="\t\t") arg, template, arr1, arr2, conditionals, field_name, field_type, pad_mode = get_params(field) - field_type_lower = field_type.lower() if field_debug_str.strip(): f.write(field_debug_str) - field_default = field.attrib.get("default") - if field_type_lower in self.compound.parser.tag_dict: - type_of_field_type = self.compound.parser.tag_dict[field_type_lower] - # write the field's default, if it exists - if field_default: - # we have to check if the default is an enum default value, in which case it has to be a member of that enum - if type_of_field_type == "enum": - field_default = field_type + "." + field_default - # no default, so guess one - else: - if type_of_field_type in ( - "compound", "struct", "niobject", "enum", "bitfield", "bitflags", "bitstruct"): - if type_of_field_type in ("compound", "struct", "niobject"): - arguments = f"{arg}, {template}" - else: - arguments = "" - field_default = f"{field_type}({arguments})" - if not field_default: - field_default = 0 - if arr1: - field_default = "Array()" - if self.compound.parser.tag_dict[field_type_lower] == "basic": - valid_arrs = tuple(str(arr) for arr in (arr1, arr2) if arr and ".arg" not in str(arr)) - arr_str = ", ".join(valid_arrs) - if field_type_lower in ("ubyte", "byte", "short", "ushort", "int", "uint", "uint64", "int64", "float"): - field_default = f"numpy.zeros(({arr_str}), dtype='{field_type_lower}')" - # todo - if we do this, it breaks when arg is used in array - # field_default = f"[{field_default} for _ in range({Expression(arr1)})]" - f.write(f"\n\t\tself.{field_name} = {field_default}") - def write_io(self, f, method_type, last_condition=""): + indent, new_condition, last_condition = condition_indent(base_indent, conditionals, last_condition) + if new_condition: + f.write(f"{base_indent}{new_condition}") + + defaults = self.default_assigns(field, arg, template, arr1, arr2, field_name, field_type, indent) + for condition, default in defaults: + if condition: + f.write(condition) + f.write(default) + def write_io(self, f, method_type, last_condition=""): + base_indent = "\n\t\t" for field in self.members: arg, template, arr1, arr2, conditionals, field_name, field_type, pad_mode = get_params(field) - # does a condition for this union exist? - if conditionals: - new_condition = f"if {' and '.join(conditionals)}:" - # merge subsequent fields that have the same condition - if last_condition != new_condition: - f.write(f"\n\t\t{new_condition}") - indent = "\n\t\t\t" - else: - indent = "\n\t\t" - new_condition = "" - last_condition = new_condition + indent, new_condition, last_condition = condition_indent(base_indent, conditionals, last_condition) + if new_condition: + f.write(f"{base_indent}{new_condition}") if arr1: if self.compound.parser.tag_dict[field_type.lower()] == "basic": valid_arrs = tuple(str(arr) for arr in (arr1, arr2) if arr) diff --git a/codegen/Versions.py b/codegen/Versions.py index 551ec8a19..4e44ce708 100644 --- a/codegen/Versions.py +++ b/codegen/Versions.py @@ -1,7 +1,14 @@ +from codegen.naming_conventions import name_enum_key +from codegen.expression import Version + class Versions: """Creates and writes a version block""" + @staticmethod + def format_id(version_id): + return version_id.lower() + def __init__(self, parser): self.parent = parser self.versions = [] @@ -12,14 +19,18 @@ def read(self, xml_struct): def write(self, out_file): full_game_names = [] if self.versions: - with open(out_file, "w") as stream: + with open(out_file, "w", encoding=self.parent.encoding) as stream: + stream.write(f"from enum import Enum\n\n\n") + for version in self.versions: - stream.write(f"def is_{version.attrib['id'].lower()}(inst):") + stream.write(f"def is_{self.format_id(version.attrib['id'])}(inst):") conds_list = [] for k, v in version.attrib.items(): if k != "id": name = k.lower() val = v.strip() + if name == 'num': + val = str(Version(val)) if " " in val: conds_list.append(f"inst.{name} in ({val.replace(' ', ', ')})") else: @@ -28,7 +39,7 @@ def write(self, out_file): stream.write("\n\t\treturn True") stream.write("\n\n\n") - stream.write(f"def set_{version.attrib['id'].lower()}(inst):") + stream.write(f"def set_{self.format_id(version.attrib['id'])}(inst):") for k, v in version.attrib.items(): if k != "id": name = k.lower() @@ -39,29 +50,57 @@ def write(self, out_file): suffix = "._value" else: suffix = "" + if name == "num": + val = str(Version(val)) stream.write(f"\n\tinst.{name}{suffix} = {val}") stream.write("\n\n\n") + # go through all the games, record them and map defaults to versions + full_name_key_map = {} + version_default_map = {} + version_game_map = {} + for version in self.versions: + version_default_map[version.attrib['id']] = set() + game_names = version.text.split(', ') + for i, game_name in enumerate(game_names): + game_name = game_name.strip() + # detect defaults and add them to the map + if len(game_name) > 4: + if game_name[:2] == '{{' and game_name[-2:] == '}}': + game_name = game_name[2:-2] + + version_default_map[version.attrib['id']].add(name_enum_key(game_name)) + game_names[i] = game_name + if game_name not in full_name_key_map: + full_name_key_map[game_name] = name_enum_key(game_name) + version_game_map[version.attrib['id']] = [full_name_key_map[game_name] for game_name in game_names] + + # define game enum + full_name_key_map = {full_name: key for full_name, key in sorted(full_name_key_map.items(), key=lambda item: item[1])} + full_name_key_map["Unknown Game"] = "UNKNOWN_GAME" + stream.write(f"games = Enum('Games',{repr([(key, full_name) for full_name, key in full_name_key_map.items()])})") + stream.write("\n\n\n") + # write game lookup function stream.write(f"def get_game(inst):") for version in self.versions: - stream.write(f"\n\tif is_{version.attrib['id'].lower()}(inst):") - full_game_name = version.text.replace('"', '').strip() - full_game_names.append(full_game_name) - stream.write(f"\n\t\treturn '{full_game_name}'") - stream.write("\n\treturn 'Unknown Game'") + stream.write(f"\n\tif is_{self.format_id(version.attrib['id'])}(inst):") + stream.write(f"\n\t\treturn [{', '.join([f'games.{key}' for key in version_game_map[version.attrib['id']]])}]") + stream.write("\n\treturn [games.UNKOWN_GAME]") stream.write("\n\n\n") # write game version setting function stream.write(f"def set_game(inst, game):") + # first check all the defaults for version in self.versions: - full_game_name = version.text.replace('"', '').strip() - stream.write(f"\n\tif game == '{full_game_name}':") - stream.write(f"\n\t\tset_{version.attrib['id'].lower()}(inst)") + if len(version_default_map[version.attrib['id']]) > 0: + stream.write(f"\n\tif game in {{{', '.join([f'games.{key}' for key in version_default_map[version.attrib['id']]])}}}:") + stream.write(f"\n\t\treturn set_{self.format_id(version.attrib['id'])}(inst)") + # then the rest + for version in self.versions: + non_default_games = set(version_game_map[version.attrib['id']]) - version_default_map[version.attrib['id']] + if len(non_default_games) > 0: + stream.write(f"\n\tif game in {{{', '.join([f'games.{key}' for key in non_default_games])}}}:") + stream.write(f"\n\t\treturn set_{self.format_id(version.attrib['id'])}(inst)") stream.write("\n\n\n") - full_game_names.sort() - full_game_names.append("Unknown Game") - # write game list - stream.write(f"games = {str(full_game_names)}") - stream.write("\n\n\n") diff --git a/codegen/expression.py b/codegen/expression.py index bbc2300da..9d6a8fafb 100644 --- a/codegen/expression.py +++ b/codegen/expression.py @@ -1,7 +1,7 @@ """Expression parser (for arr1, arr2, cond, and vercond xml attributes of tag).""" -from codegen import naming_conventions as convention +from codegen.naming_conventions import name_attribute class Version(object): @@ -60,88 +60,15 @@ class Expression(object): operators = {'==', '!=', '>=', '<=', '&&', '||', '&', '|', '-', '!', '<', '>', '/', '*', '+', '%'} - def __init__(self, expr_str, name_filter=None): + def __init__(self, expr_str, g_vars=False): try: left, self._op, right = self._partition(expr_str) - self._left = self._parse(left, name_filter) - self._right = self._parse(right, name_filter) + self._left = self._parse(left, g_vars) + self._right = self._parse(right, g_vars) except: print("error while parsing expression '%s'" % expr_str) raise - def eval(self, data=None): - """Evaluate the expression to an integer.""" - - if isinstance(self._left, Expression): - left = self._left.eval(data) - elif isinstance(self._left, str): - if self._left == '""': - left = "" - else: - left = data - for part in self._left.split("."): - left = getattr(left, part) - elif isinstance(self._left, type): - left = isinstance(data, self._left) - elif self._left is None: - pass - else: - assert (isinstance(self._left, int)) # debug - left = self._left - - if not self._op: - return left - - if isinstance(self._right, Expression): - right = self._right.eval(data) - elif isinstance(self._right, str): - if (not self._right) or self._right == '""': - right = "" - else: - right = getattr(data, self._right) - elif isinstance(self._right, type): - right = isinstance(data, self._right) - elif self._right is None: - pass - else: - assert (isinstance(self._right, int)) # debug - right = self._right - - if self._op == '==': - return left == right - elif self._op == '!=': - return left != right - elif self._op == '>=': - return left >= right - elif self._op == '<=': - return left <= right - elif self._op == '&&': - return left and right - elif self._op == '||': - return left or right - elif self._op == '&': - return left & right - elif self._op == '|': - return left | right - elif self._op == '-': - return left - right - elif self._op == '!': - return not (right) - elif self._op == '>': - return left > right - elif self._op == '<': - return left < right - elif self._op == '/': - return left / right - elif self._op == '*': - return left * right - elif self._op == '+': - return left + right - elif self._op == '%': - return left % right - else: - raise NotImplementedError("expression syntax error: operator '" + self._op + "' not implemented") - def __str__(self): """Reconstruct the expression to a string.""" @@ -163,7 +90,7 @@ def __str__(self): return f"{left} {op} {right}".strip() @classmethod - def _parse(cls, expr_str, name_filter=None): + def _parse(cls, expr_str, g_vars=False): """Returns an Expression, string, or int, depending on the contents of .""" if not expr_str: @@ -171,10 +98,10 @@ def _parse(cls, expr_str, name_filter=None): return None # brackets or operators => expression if ("(" in expr_str) or (")" in expr_str): - return Expression(expr_str, name_filter) + return Expression(expr_str, g_vars) for op in cls.operators: if expr_str.find(op) != -1: - return Expression(expr_str, name_filter) + return Expression(expr_str, g_vars) # try to convert it to one of the following classes for create_cls in (int, Version): try: @@ -185,16 +112,12 @@ def _parse(cls, expr_str, name_filter=None): # at this point, expr_str is a single attribute # apply name filter on each component separately # (where a dot separates components) - if name_filter is None: - def name_filter(x): - return convention.name_attribute(x) - prefix = "self." - # globals are stored on the stream - # it is only a global if the leftmost member has version in it - # ie. general_info.ms2_version is not a global - if "version" in expr_str.split(".")[0].lower(): - prefix = "stream." - return prefix + ('.'.join(name_filter(comp) for comp in expr_str.split("."))) + if g_vars: + # globals are stored on the context + prefix = "self.context." + else: + prefix = "self." + return prefix + ('.'.join(name_attribute(comp) for comp in expr_str.split("."))) @classmethod def _partition(cls, expr_str): @@ -332,16 +255,6 @@ def _scan_brackets(expr_str, fromIndex=0): raise ValueError("expression syntax error (non-matching brackets?)") return start_pos, end_pos - def map_(self, func): - if isinstance(self._left, Expression): - self._left.map_(func) - else: - self._left = func(self._left) - if isinstance(self._right, Expression): - self._right.map_(func) - else: - self._right = func(self._right) - if __name__ == "__main__": import doctest diff --git a/codegen/naming_conventions.py b/codegen/naming_conventions.py index 57e7dd4b3..306ddea2d 100644 --- a/codegen/naming_conventions.py +++ b/codegen/naming_conventions.py @@ -102,6 +102,17 @@ def name_class(name): return ''.join(part.capitalize() for part in name_parts(name)) +def name_enum_key(name): + """Converts a key name into a name suitable for an enum key. + :param name: the key name + :type name: str + :return: Reformatted key name. + >>> name_enum_key('Some key name') + 'SOME_KEY_NAME' + """ + return '_'.join(part.upper() for part in name_parts(name)) + + def clean_comment_str(comment_str="", indent="", class_comment=""): """Reformats an XML comment string into multi-line a python style comment block""" if comment_str is None: @@ -112,4 +123,15 @@ def clean_comment_str(comment_str="", indent="", class_comment=""): lines = [f"\n{indent}{class_comment}",] + [f"\n{indent}{line.strip()}" for line in comment_str.strip().split("\n")] + [f"\n{indent}{class_comment}",] else: lines = [f"\n{indent}# {line.strip()}" for line in comment_str.strip().split("\n")] - return "\n" + "".join(lines) \ No newline at end of file + return "\n" + "".join(lines) + + +def name_module(name): + """Converts a module name into a name suitable for a python module + :param name: the module name + :type name: str + :return: Reformatted module name + >>> name_module('BSHavok') + 'bshavok' + """ + return name.lower() diff --git a/source/bitfield.py b/source/bitfield.py index 6f0ae9ed9..3cab0e48f 100644 --- a/source/bitfield.py +++ b/source/bitfield.py @@ -21,7 +21,7 @@ def __set__(self, instance, value): instance._value |= (value << self.pos) & self.mask -class BasicBitfield(int): +class BasicBitfield(object): _value: int = 0 def set_defaults(self): @@ -34,36 +34,6 @@ def __hash__(self): def __int__(self): return self._value - def __eq__(self, other): - if isinstance(other, BasicBitfield): - return self._value == other._value - elif isinstance(other, int): - return self._value == other - return False - - def __new__(cls, *args, **kwargs): - return super(BasicBitfield, cls).__new__(cls) - - def __add__(self, other): - self._value += other - return self - - def __sub__(self, other): - self._value -= other - return self - - def __mul__(self, other): - self._value *= other - return self - - def __floordiv__(self, other): - self._value //= other - return self - - def __truediv__(self, other): - self._value /= other - return self - def __init__(self, value=None): super().__init__() if value is not None: @@ -84,6 +54,185 @@ def __str__(self): info += f"\n\t{field} = {str(val)}" return info + # rich comparison methods + def __lt__(self, other): + return self._value < other + + def __le__(self, other): + return self._value <= other + + def __eq__(self, other): + return self._value == other + + def __ne__(self, other): + return self._value != other + + def __gt__(self, other): + return self._value > other + + def __ge__(self, other): + return self._value >= other + + # basic arithmetic functions + def __add__(self, other): + return self._value + other + + def __sub__(self, other): + return self._value - other + + def __mul__(self, other): + return self._value * other + + def __truediv__(self, other): + return self._value / other + + def __floordiv__(self, other): + return self._value // other + + def __mod__(self, other): + return self._value % other + + def __divmod__(self, other): + return divmod(self._value, other) + + def __pow__(self, other, modulo=None): + if modulo is None: + return pow(self._value, other) + else: + return pow(self._value, other, modulo) + + def __lshift__(self, other): + return self._value << other + + def __rshift__(self, other): + return self._value >> other + + def __and__(self, other): + return self._value & other + + def __xor__(self, other): + return self._value ^ other + + def __or__(self, other): + return self._value | other + + # reflected basic arithmetic functions + def __radd__(self, other): + return other + self._value + + def __rsub__(self, other): + return other - self._value + + def __rmul__(self, other): + return other * self._value + + def __rtruediv__(self, other): + return other / self._value + + def __rfloordiv__(self, other): + return other // self._value + + def __rmod__(self, other): + return other % self._value + + def __rdivmod__(self, other): + return divmod(other, self._value) + + def __rpow__(self, other, modulo=None): + if modulo is None: + return pow(other, self._value) + else: + return pow(other, self._value, modulo) + + def __rlshift__(self, other): + return other << self._value + + def __rrshift__(self, other): + return other >> self._value + + def __rand__(self, other): + return other & self._value + + def __rxor__(self, other): + return other ^ self._value + + def __ror__(self, other): + return other | self._value + + # arithmetic assignments + def __iadd__(self, other): + self._value = int(self._value + other) + return self + + def __isub__(self, other): + self._value = int(self._value - other) + return self + + def __imul__(self, other): + self._value = int(self._value * other) + return self + + def __itruediv__(self, other): + self._value = int(self._value / other) + return self + + def __ifloordiv__(self, other): + self._value = int(self._value // other) + return self + + def __imod__(self, other): + self._value = int(self._value % other) + return self + + def __ipow__(self, other, modulo=None): + if modulo is None: + self._value = int(pow(self._value, other)) + else: + self._value = int(pow(self._value, other, modulo)) + return self + + def __ilshift__(self, other): + self._value = int(self._value << other) + return self + + def __irshift__(self, other): + self._value = int(self._value >> other) + return self + + def __iand__(self, other): + self._value = int(self._value & other) + return self + + def __ixor__(self, other): + self._value = int(self._value ^ other) + return self + + def __ior__(self, other): + self._value = int(self._value | other) + return self + + # unary operators + def __neg__(self): + return -self._value + + def __pos__(self): + return +self._value + + def __abs__(self): + return abs(self._value) + + def __invert__(self): + return ~self._value + + def __complex__(self): + return complex(self._value) + + def __float__(self): + return float(self._value) + + def __index__(self): + return self.__int__() + class AlphaFunction(IntEnum): """Describes alpha blend modes for NiAlphaProperty.""" diff --git a/source/context.py b/source/context.py new file mode 100644 index 000000000..73c4098af --- /dev/null +++ b/source/context.py @@ -0,0 +1,7 @@ +class ContextReference(object): + + def __get__(self, instance, owner): + return instance._context + + def __set__(self, instance, value): + raise AttributeError(f"Can't modify context attribute!") \ No newline at end of file diff --git a/source/formats/fgm/fgm.xml b/source/formats/fgm/fgm.xml index 55871d7a7..7d3beab20 100644 --- a/source/formats/fgm/fgm.xml +++ b/source/formats/fgm/fgm.xml @@ -9,7 +9,7 @@ Planet Zoo 1.6 Jurassic World Evolution - + Commonly used version expressions. Disneyland Adventure ZTUAC @@ -19,7 +19,7 @@ JWE, 25108 is JWE on switch - + Global Tokens. NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens. diff --git a/source/formats/manis/manis.xml b/source/formats/manis/manis.xml index 212293e9d..62e0e3de7 100644 --- a/source/formats/manis/manis.xml +++ b/source/formats/manis/manis.xml @@ -4,7 +4,7 @@ - + Commonly used version expressions. PZ PZ @@ -12,7 +12,7 @@ PC - + Global Tokens. NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens. @@ -224,10 +224,10 @@ always FF FF always FF FF - + rest 228 bytes - rest 14 bytes + rest 14 bytes always FF @@ -241,17 +241,17 @@ - - + + - - + + - - + + - - + + @@ -259,10 +259,10 @@ - + - + diff --git a/source/formats/matcol/matcol.xml b/source/formats/matcol/matcol.xml index 1ae254856..1d19a561b 100644 --- a/source/formats/matcol/matcol.xml +++ b/source/formats/matcol/matcol.xml @@ -2,13 +2,13 @@ - + Commonly used version expressions. PZ JWE - + Global Tokens. NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens. diff --git a/source/formats/ms2/ms2.xml b/source/formats/ms2/ms2.xml index 630ca5f19..fa74f137c 100644 --- a/source/formats/ms2/ms2.xml +++ b/source/formats/ms2/ms2.xml @@ -4,7 +4,7 @@ Old - + Commonly used version expressions. Disneyland Adventure ZTUAC @@ -15,7 +15,7 @@ JWE, 25108 is JWE on switch - + Global Tokens. NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens. @@ -323,18 +323,18 @@ used to find bone info name of ms2 - gives relevant info on the mdl, including counts and pack base - name pointers for each material - lod info for each level, only present if models are present (despite the count sometimes saying otherwise!) - instantiate the meshes with materials - model data blocks for this mdl2 + gives relevant info on the mdl, including counts and pack base + name pointers for each material + lod info for each level, only present if models are present (despite the count sometimes saying otherwise!) + instantiate the meshes with materials + model data blocks for this mdl2 - index into ms2 names array - index into ms2 names array - unknown, nonzero in PZ flamingo juvenile, might be junk (padding) - unknown, nonzero in PZ flamingo juvenile, might be junk (padding) + index into ms2 names array + index into ms2 names array + unknown, nonzero in PZ flamingo juvenile, might be junk (padding) + unknown, nonzero in PZ flamingo juvenile, might be junk (padding) @@ -522,7 +522,7 @@ index count 7 zero joint count - unnk 78 count + unnk 78 count jwe only, everything is shifted a bit due to extra uint 0 index into ms2 string table for bones used here @@ -597,28 +597,28 @@ small number small number small number - 0s, might be related to count 7 in PC + 0s, might be related to count 7 in PC size of the name buffer below, including trailing zeros 0s 0 or 1 0s - 0s + 0s 1, 1 matches bone count from bone info 0 usually 0s corresponds to bone transforms - might be pointers - - used by ptero, 16 bytes per entry - - + might be pointers + + used by ptero, 16 bytes per entry + + - ? - 1FAA FFAAFF00 000000 - counts hitchecks for pz - 0 + ? + 1FAA FFAAFF00 000000 + counts hitchecks for pz + 0 sometimes an array of floats index into bone info bones for each joint; bone that the joint is attached to diff --git a/source/formats/ovl/ovl.xml b/source/formats/ovl/ovl.xml index 21c92b760..1acff727e 100644 --- a/source/formats/ovl/ovl.xml +++ b/source/formats/ovl/ovl.xml @@ -9,7 +9,7 @@ Planet Zoo 1.6+ Jurassic World Evolution - + Commonly used version expressions. Disneyland Adventure ZTUAC @@ -19,7 +19,7 @@ JWE, 25108 is JWE on switch - + Global Tokens. NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens. diff --git a/source/formats/tex/tex.xml b/source/formats/tex/tex.xml index e2c071331..ac510424e 100644 --- a/source/formats/tex/tex.xml +++ b/source/formats/tex/tex.xml @@ -9,7 +9,7 @@ Planet Zoo 1.6 Jurassic World Evolution - + Commonly used version expressions. Disneyland Adventure ZTUAC @@ -19,7 +19,7 @@ JWE, 25108 is JWE on switch - + Global Tokens. NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens. diff --git a/source/formats/voxelskirt/voxelskirt.xml b/source/formats/voxelskirt/voxelskirt.xml index ca1700c3d..3facceb63 100644 --- a/source/formats/voxelskirt/voxelskirt.xml +++ b/source/formats/voxelskirt/voxelskirt.xml @@ -7,7 +7,7 @@ Planet Zoo Jurassic World Evolution - + Commonly used version expressions. ZTUAC PC @@ -15,7 +15,7 @@ JWE, 25108 is JWE on switch - + Global Tokens. NOTE: These must be listed after the above tokens so that they replace last. For example, `verexpr` uses these tokens.