From 362fdd750818a1c1549cc7a7a679e431baaee283 Mon Sep 17 00:00:00 2001 From: junkmd Date: Sun, 29 May 2022 15:01:05 +0900 Subject: [PATCH] fix codegenerator (enthought#284) (enthought#293) --- comtypes/tools/codegenerator.py | 345 ++++++++++++++++---------------- 1 file changed, 169 insertions(+), 176 deletions(-) diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index e6bd08d4..93ee65ad 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -9,6 +9,7 @@ import cStringIO as io import keyword import ctypes +import textwrap from comtypes.tools import typedesc import comtypes @@ -169,9 +170,8 @@ def __init__(self, ofi, known_symbols=None): self._externals = {} self.output = ofi self.stream = io.StringIO() - self.imports = {} - self.declarations = io.StringIO() -## self.stream = self.imports = self.output + self.imports = ImportedNamespaces() + self.declarations = DeclaredNameSpaces() self.known_symbols = known_symbols or {} self.done = set() # type descriptions that have been generated @@ -186,9 +186,7 @@ def generate(self, item): else: name = getattr(item, "name", None) if name in self.known_symbols: - mod = self.known_symbols[name] - if name not in self.imports: - self.imports[name] = mod + self.imports.add(name, symbols=self.known_symbols) self.done.add(item) if isinstance(item, typedesc.Structure): @@ -233,25 +231,20 @@ def _generate_typelib_path(self, filename): # resolution is with respect to current working directory -- later to be # relativized to comtypes.gen. if filename is not None: - # Hm, what is the CORRECT encoding? - print("# -*- coding: mbcs -*-", file=self.output) - print(file=self.output) - if os.path.isabs(filename): # absolute path - print("typelib_path = %r" % filename, file=self.declarations) + self.declarations.add("typelib_path", repr(filename)) elif not os.path.dirname(filename) and not os.path.isfile(filename): # no directory given, and not in current directory. - print("typelib_path = %r" % filename, file=self.declarations) + self.declarations.add("typelib_path", repr(filename)) else: # relative path; make relative to comtypes.gen. path = self._make_relative_path(filename, comtypes.gen.__path__[0]) - self.imports['os'] = None - - print("typelib_path = os.path.normpath(", file=self.declarations) - print(" os.path.abspath(os.path.join(os.path.dirname(__file__),", file=self.declarations) - print(" %r)))" % path, file=self.declarations) - + self.imports.add('os') + definition = "os.path.normpath(\n" \ + " os.path.abspath(os.path.join(os.path.dirname(__file__),\n" \ + " %r)))" % path + self.declarations.add("typelib_path", definition) p = os.path.normpath(os.path.abspath(os.path.join(comtypes.gen.__path__[0], path))) assert os.path.isfile(p) @@ -278,8 +271,7 @@ def generate_code(self, items, filename): filename = full_filename self.filename = filename - print("_lcid = 0 # change this if required", - file=self.declarations) + self.declarations.add("_lcid", "0", "change this if required") self._generate_typelib_path(filename) items = set(items) @@ -292,47 +284,32 @@ def generate_code(self, items, filename): items |= self.more items -= self.done - stream = self.stream.getvalue() - for key, value in self.imports.items(): - if value is None: - self.output.write('import ' + key + '\n') - elif key == '*': - self.output.write('from ' + value + ' import *\n') - elif key == '_check_version' or key in stream: - self.output.write('from ' + value + ' import ' + key + '\n') - - self.output.write("\n\n") - self.output.write(self.declarations.getvalue()) - self.output.write(stream) - - # XXX The space before '%s' is needed to make sure that the entire list - # does not get pushed to the next line when the first name is - # excessively long. - text = "__all__ = [%s]" % ", ".join( - [repr(str(n)) for n in self.names]) - - if len(text) > 80: - import textwrap + self.imports.add("ctypes", "*") # HACK: wildcard import is so ugly. + if tlib_mtime is not None: + logger.debug("filename: \"%s\": tlib_mtime: %s", filename, tlib_mtime) + self.imports.add('comtypes', '_check_version') + + if filename is not None: + # Hm, what is the CORRECT encoding? + print("# -*- coding: mbcs -*-", file=self.output) + print(file=self.output) + print(self.imports.getvalue(), file=self.output) + print(file=self.output) + print(self.declarations.getvalue(), file=self.output) + print(file=self.output) + print(self.stream.getvalue(), file=self.output) + dunder_all = "__all__ = [%s]" % ", ".join(repr(str(n)) for n in self.names) + if len(dunder_all) > 80: wrapper = textwrap.TextWrapper(subsequent_indent=" ", + initial_indent=" ", break_long_words=False) - - print("__all__ = [", file=self.output) - - text = ' ' + (", ".join([repr(str(n)) for n in self.names])) - for line in wrapper.wrap(text): - print(line, file=self.output) - - print("]", file=self.output) - else: - print(text, file=self.output) - + names = ", ".join(repr(str(n)) for n in self.names) + dunder_all = "__all__ = [\n%s\n]" % "\n".join(wrapper.wrap(names)) + print(dunder_all, file=self.output) + print(file=self.output) if tlib_mtime is not None: - self.imports['*'] = 'ctypes' - self.imports['_check_version'] = 'comtypes' - logger.debug("filename: \"%s\": tlib_mtime: %s", filename, - tlib_mtime) print("_check_version(%r, %f)" % (version, tlib_mtime), - file=self.output) + file=self.output) return loops @@ -351,10 +328,10 @@ def type_name(self, t, generate=True): x = get_real_type(t.typ) if isinstance(x, typedesc.FundamentalType): if x.name == "char": - self.need_STRING() + self.declarations.add("STRING", "c_char_p") return "STRING" elif x.name == "wchar_t": - self.need_WSTRING() + self.declarations.add("WSTRING", "c_wchar_p") return "WSTRING" result = "POINTER(%s)" % self.type_name(t.typ, generate) @@ -390,30 +367,13 @@ def type_name(self, t, generate=True): def need_VARIANT_imports(self, value): text = repr(value) if "Decimal(" in text: - self.imports['Decimal'] = 'decimal' + self.imports.add("decimal", "Decimal") if "datetime.datetime(" in text: - self.imports['datetime'] = None + self.imports.add("datetime") - _STRING_defined = False - def need_STRING(self): - if self._STRING_defined: - return - print("STRING = c_char_p", file=self.declarations) - self._STRING_defined = True - - _WSTRING_defined = False - def need_WSTRING(self): - if self._WSTRING_defined: - return - print("WSTRING = c_wchar_p", file=self.declarations) - self._WSTRING_defined = True - - _OPENARRAYS_defined = False - def need_OPENARRAYS(self): - if self._OPENARRAYS_defined: - return - print("OPENARRAY = POINTER(c_ubyte) # hack, see comtypes/tools/codegenerator.py", file=self.declarations) - self._OPENARRAYS_defined = True + def need_GUID(self): + if "GUID" in self.known_symbols: + self.imports.add("GUID", symbols=self.known_symbols) _arraytypes = 0 def ArrayType(self, tp): @@ -449,19 +409,9 @@ def Enumeration(self, tp): for item in tp.values: self.generate(item) if tp.name: - print("%s = c_int # enum" % tp.name, file=self.declarations) + print("%s = c_int # enum" % tp.name, file=self.stream) self.names.add(tp.name) - _GUID_defined = False - def need_GUID(self): - if self._GUID_defined: - return - - self._GUID_defined = True - modname = self.known_symbols.get("GUID") - if modname and 'GUID' not in self.imports: - self.imports['GUID'] = modname - _typedefs = 0 def Typedef(self, tp): self._typedefs += 1 @@ -470,16 +420,13 @@ def Typedef(self, tp): self.more.add(tp.typ) else: self.generate(tp.typ) - if self.type_name(tp.typ) in self.known_symbols: - stream = self.declarations - else: - stream = self.stream - if tp.name != self.type_name(tp.typ): - if stream == self.stream: + definition = self.type_name(tp.typ) + if tp.name != definition: + if definition in self.known_symbols: + self.declarations.add(tp.name, definition) + else: + print("%s = %s" % (tp.name, definition), file=self.stream) self.last_item_class = False - - print("%s = %s" % \ - (tp.name, self.type_name(tp.typ)), file=stream) self.names.add(tp.name) def FundamentalType(self, item): @@ -655,7 +602,7 @@ def StructureBody(self, body): (body.struct.name, align, body.struct.name), file=self.stream) if methods: - self.need_COMMETHOD() + self.imports.add("comtypes", "COMMETHOD") # method definitions normally span several lines. # Before we generate them, we need to 'import' everything they need. # So, call type_name for each field once, @@ -689,62 +636,6 @@ def StructureBody(self, body): print(" ),", file=self.stream) print("]", file=self.stream) - _midlSAFEARRAY_defined = False - def need_midlSAFEARRAY(self): - if self._midlSAFEARRAY_defined: - return - - self.imports['_midlSAFEARRAY'] = 'comtypes.automation' - self._midlSAFEARRAY_defined = True - - _CoClass_defined = False - def need_CoClass(self): - if self._CoClass_defined: - return - - self.imports['CoClass'] = 'comtypes' - self._CoClass_defined = True - - _helpstring_defined = False - def need_helpstring(self): - if self._helpstring_defined: - return - - self.imports['helpstring'] = 'comtypes' - self._helpstring_defined = True - - _dispid_defined = False - def need_dispid(self): - if self._dispid_defined: - return - - self.imports['dispid'] = 'comtypes' - self._dispid_defined = True - - _COMMETHOD_defined = False - def need_COMMETHOD(self): - if self._COMMETHOD_defined: - return - - self.imports['COMMETHOD'] = 'comtypes' - self._COMMETHOD_defined = True - - _DISPMETHOD_defined = False - def need_DISPMETHOD(self): - if self._DISPMETHOD_defined: - return - - self.imports['DISPMETHOD'] = 'comtypes' - self._DISPMETHOD_defined = True - - _DISPPROPERTY_defined = False - def need_DISPPROPERTY(self): - if self._DISPPROPERTY_defined: - return - - self.imports['DISPPROPERTY'] = 'comtypes' - self._DISPPROPERTY_defined = True - ################################################################ # top-level typedesc generators # @@ -794,7 +685,7 @@ def External(self, ext): ext.name = "%s.%s" % (modname, ext.symbol_name) self._externals[libdesc] = modname - self.imports[modname] = None + self.imports.add(modname) comtypes.client.GetModule(ext.tlib) def Constant(self, tp): @@ -806,7 +697,7 @@ def Constant(self, tp): def SAFEARRAYType(self, sa): self.generate(sa.typ) - self.need_midlSAFEARRAY() + self.imports.add("comtypes.automation", "_midlSAFEARRAY") _pointertypes = 0 def PointerType(self, tp): @@ -828,7 +719,7 @@ def PointerType(self, tp): def CoClass(self, coclass): self.need_GUID() - self.need_CoClass() + self.imports.add("comtypes", "CoClass") if not self.last_item_class: print(file=self.stream) print(file=self.stream) @@ -1073,16 +964,16 @@ def DispInterfaceBody(self, body): # non-toplevel method generators # def make_ComMethod(self, m, isdual): - self.need_COMMETHOD() + self.imports.add("comtypes", "COMMETHOD") # typ, name, idlflags, default if isdual: - self.need_dispid() + self.imports.add("comtypes", "dispid") idlflags = [dispid(m.memid)] + m.idlflags else: # We don't include the dispid for non-dispatch COM interfaces idlflags = m.idlflags if __debug__ and m.doc: - self.need_helpstring() + self.imports.add("comtypes", "helpstring") idlflags.insert(1, helpstring(m.doc)) self.last_item_class = False @@ -1106,7 +997,6 @@ def make_ComMethod(self, m, isdual): " '%s'," ) % (idlflags, self.type_name(m.returns), m.name) print(code, file=self.stream) - # self.stream.write("") arglist = [] for typ, name, idlflags, default in m.arguments: type_name = self.type_name(typ) @@ -1151,7 +1041,8 @@ def make_ComMethod(self, m, isdual): # [in, out] BYTE data[]); ########################################################### if isinstance(typ, typedesc.ComInterface): - self.need_OPENARRAYS() + self.declarations.add("OPENARRAY", "POINTER(c_ubyte)", + "hack, see comtypes/tools/codegenerator.py") type_name = "OPENARRAY" if 'in' not in idlflags: idlflags.append('in') @@ -1184,15 +1075,15 @@ def make_ComMethod(self, m, isdual): arglist.append(code) - self.stream.write(",\n".join(arglist)) - print("\n ),", file=self.stream) + print(",\n".join(arglist), file=self.stream) + print(" ),", file=self.stream) def make_DispMethod(self, m): - self.need_DISPMETHOD() - self.need_dispid() + self.imports.add("comtypes", "DISPMETHOD") + self.imports.add("comtypes", "dispid") idlflags = [dispid(m.dispid)] + m.idlflags if __debug__ and m.doc: - self.need_helpstring() + self.imports.add("comtypes", "helpstring") idlflags.insert(1, helpstring(m.doc)) self.last_item_class = False @@ -1250,15 +1141,15 @@ def make_DispMethod(self, m): arglist.append(code) - self.stream.write(",\n".join(arglist)) - print("\n ),", file=self.stream) + print(",\n".join(arglist), file=self.stream) + print(" ),", file=self.stream) def make_DispProperty(self, prop): - self.need_DISPPROPERTY() - self.need_dispid() + self.imports.add("comtypes", "DISPPROPERTY") + self.imports.add("comtypes", "dispid") idlflags = [dispid(prop.dispid)] + prop.idlflags if __debug__ and prop.doc: - self.need_helpstring() + self.imports.add("comtypes", "helpstring") idlflags.insert(1, helpstring(prop.doc)) self.last_item_class = False @@ -1274,6 +1165,108 @@ def make_DispProperty(self, prop): print(code, file=self.stream) + +class ImportedNamespaces(object): + def __init__(self): + if sys.version_info >= (3, 7): + self.data = {} + else: + from collections import OrderedDict + self.data = OrderedDict() + + def add(self, *names, symbols=None): + """Adds a namespace will be imported. + + Examples: + >>> imports = ImportedNamespaces() + >>> imports.add('datetime') + >>> imports.add('ctypes', '*') + >>> imports.add('decimal', 'Decimal') + >>> imports.add('GUID', symbols={'GUID': 'comtypes'}) + >>> for name in ('COMMETHOD', 'IUnknown', 'dispid', 'CoClass', + ... 'BSTR', 'DISPPROPERTY'): + ... imports.add('comtypes', name) + >>> imports.add('ctypes.wintypes') + >>> print(imports.getvalue()) + from ctypes import * + import datetime + from decimal import Decimal + from comtypes import BSTR, CoClass, COMMETHOD, dispid, DISPPROPERTY, \\ + GUID, IUnknown + import ctypes.wintypes + """ + if not names or len(names) > 2: + raise ValueError + if len(names) == 1: + (import_,) = names + if not symbols: + self.data[import_] = None + return + from_ = symbols[import_] + else: + from_, import_ = names + self.data[import_] = from_ + + def _make_line(self, import_, from_=None): + if from_ is None: + return "import %s" % import_ + code = "from %s import %s" % (from_, import_) + if len(code) > 80: + wrapper = textwrap.TextWrapper(subsequent_indent=" ", + break_long_words=False) + code = " \\\n".join(wrapper.wrap(code)) + return code + + def getvalue(self): + ns = {} + lines = [] + for key, val in self.data.items(): + if val is None: + ns[key] = val + elif key == "*": + lines.append(self._make_line("*", val)) + else: + ns.setdefault(val, set()).add(key) + for key, val in ns.items(): + if val is None: + lines.append(self._make_line(key)) + else: + names = ", ".join(sorted(val, key=lambda s: s.lower())) + lines.append(self._make_line(names, key)) + return "\n".join(lines) + + +class DeclaredNameSpaces(object): + def __init__(self): + if sys.version_info >= (3, 7): + self.data = {} + else: + from collections import OrderedDict + self.data = OrderedDict() + + def add(self, alias, definition, comment=None): + """Adds a namespace will be declared. + + Examples: + >>> declarations = DeclaredNameSpaces() + >>> declarations.add('STRING', 'c_char_p') + >>> declarations.add('_lcid', '0', 'change this if required') + >>> print(declarations.getvalue()) + STRING = c_char_p + _lcid = 0 # change this if required + """ + self.data[(alias, definition)] = comment + + def getvalue(self): + lines = [] + for (alias, definition), comment in self.data.items(): + code = "%s = %s" % (alias, definition) + if comment: + code = code + " # %s" % comment + lines.append(code) + return "\n".join(lines) + + # shortcut for development if __name__ == "__main__": from . import tlbparser