diff --git a/rosidl_adapter/CMakeLists.txt b/rosidl_adapter/CMakeLists.txt index 33d9e879a..ae1337a91 100644 --- a/rosidl_adapter/CMakeLists.txt +++ b/rosidl_adapter/CMakeLists.txt @@ -10,6 +10,7 @@ ament_python_install_package(${PROJECT_NAME}) if(BUILD_TESTING) find_package(ament_cmake_pytest REQUIRED) find_package(ament_lint_auto REQUIRED) + find_package(ament_cmake_mypy REQUIRED) ament_lint_auto_find_test_dependencies() ament_add_pytest_test(pytest test) endif() diff --git a/rosidl_adapter/package.xml b/rosidl_adapter/package.xml index 192c883f1..4d97110ec 100644 --- a/rosidl_adapter/package.xml +++ b/rosidl_adapter/package.xml @@ -26,6 +26,7 @@ python3-empy rosidl_cli + ament_cmake_mypy ament_cmake_pytest ament_lint_common ament_lint_auto diff --git a/rosidl_adapter/rosidl_adapter/__init__.py b/rosidl_adapter/rosidl_adapter/__init__.py index 2cfd58295..7351d577f 100644 --- a/rosidl_adapter/rosidl_adapter/__init__.py +++ b/rosidl_adapter/rosidl_adapter/__init__.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path -def convert_to_idl(package_dir, package_name, interface_file, output_dir): + +def convert_to_idl(package_dir: Path, package_name: str, interface_file: Path, + output_dir: Path) -> Path: if interface_file.suffix == '.msg': from rosidl_adapter.msg import convert_msg_to_idl return convert_msg_to_idl( diff --git a/rosidl_adapter/rosidl_adapter/__main__.py b/rosidl_adapter/rosidl_adapter/__main__.py index f7a5bbf46..6adf087c0 100644 --- a/rosidl_adapter/rosidl_adapter/__main__.py +++ b/rosidl_adapter/rosidl_adapter/__main__.py @@ -16,4 +16,5 @@ from rosidl_adapter.main import main -sys.exit(main()) +main() +sys.exit() diff --git a/rosidl_adapter/rosidl_adapter/action/__init__.py b/rosidl_adapter/rosidl_adapter/action/__init__.py index 6e8719b6c..53e94a080 100644 --- a/rosidl_adapter/rosidl_adapter/action/__init__.py +++ b/rosidl_adapter/rosidl_adapter/action/__init__.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from rosidl_adapter.parser import parse_action_string -from rosidl_adapter.resource import expand_template +from rosidl_adapter.resource import ActionData, expand_template -def convert_action_to_idl(package_dir, package_name, input_file, output_dir): +def convert_action_to_idl(package_dir: Path, package_name: str, input_file: Path, + output_dir: Path) -> Path: assert package_dir.is_absolute() assert not input_file.is_absolute() assert input_file.suffix == '.action' @@ -30,7 +33,7 @@ def convert_action_to_idl(package_dir, package_name, input_file, output_dir): output_file = output_dir / input_file.with_suffix('.idl').name abs_output_file = output_file.absolute() print(f'Writing output file: {abs_output_file}') - data = { + data: ActionData = { 'pkg_name': package_name, 'relative_input_file': input_file.as_posix(), 'action': action, diff --git a/rosidl_adapter/rosidl_adapter/cli.py b/rosidl_adapter/rosidl_adapter/cli.py index f1f5bb5eb..c50e7116b 100644 --- a/rosidl_adapter/rosidl_adapter/cli.py +++ b/rosidl_adapter/rosidl_adapter/cli.py @@ -13,8 +13,9 @@ # limitations under the License. import argparse -import pathlib +from pathlib import Path import sys +from typing import Callable, List, Literal, TYPE_CHECKING from catkin_pkg.package import package_exists_at from catkin_pkg.package import parse_package @@ -27,7 +28,18 @@ from rosidl_cli.command.translate.extensions import TranslateCommandExtension -def convert_files_to_idl(extension, conversion_function, argv=sys.argv[1:]): +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + ConversionFunctionType: TypeAlias = Callable[[Path, str, Path, Path], Path] + + +def convert_files_to_idl( + extension: Literal['.msg', '.srv', '.action'], + conversion_function: 'ConversionFunctionType', + argv: List[str] = sys.argv[1:] +) -> None: + parser = argparse.ArgumentParser( description=f'Convert {extension} files to .idl') parser.add_argument( @@ -36,7 +48,7 @@ def convert_files_to_idl(extension, conversion_function, argv=sys.argv[1:]): args = parser.parse_args(argv) for interface_file in args.interface_files: - interface_file = pathlib.Path(interface_file) + interface_file = Path(interface_file) package_dir = interface_file.parent.absolute() while ( len(package_dir.parents) and @@ -48,8 +60,7 @@ def convert_files_to_idl(extension, conversion_function, argv=sys.argv[1:]): f"Could not find package for '{interface_file}'", file=sys.stderr) continue - warnings = [] - pkg = parse_package(package_dir, warnings=warnings) + pkg = parse_package(package_dir, warnings=[]) conversion_function( package_dir, pkg.name, @@ -63,14 +74,14 @@ class TranslateToIDL(TranslateCommandExtension): def translate( self, - package_name, - interface_files, - include_paths, - output_path - ): + package_name: str, + interface_files: List[str], + include_paths: List[str], + output_path: Path + ) -> List[str]: translated_interface_files = [] - for interface_file in interface_files: - prefix, interface_file = interface_path_as_tuple(interface_file) + for interface_file_str in interface_files: + prefix, interface_file = interface_path_as_tuple(interface_file_str) output_dir = output_path / interface_file.parent translated_interface_file = self.conversion_function( prefix, package_name, interface_file, output_dir) @@ -87,7 +98,7 @@ class TranslateMsgToIDL(TranslateToIDL): input_format = 'msg' @property - def conversion_function(self): + def conversion_function(self) -> 'ConversionFunctionType': return convert_msg_to_idl @@ -96,7 +107,7 @@ class TranslateSrvToIDL(TranslateToIDL): input_format = 'srv' @property - def conversion_function(self): + def conversion_function(self) -> 'ConversionFunctionType': return convert_srv_to_idl @@ -104,5 +115,5 @@ class TranslateActionToIDL(TranslateToIDL): input_format = 'action' @property - def conversion_function(self): + def conversion_function(self) -> 'ConversionFunctionType': return convert_action_to_idl diff --git a/rosidl_adapter/rosidl_adapter/main.py b/rosidl_adapter/rosidl_adapter/main.py index bbcd50845..759159c2a 100644 --- a/rosidl_adapter/rosidl_adapter/main.py +++ b/rosidl_adapter/rosidl_adapter/main.py @@ -17,12 +17,13 @@ import os import pathlib import sys +from typing import List from rosidl_adapter import convert_to_idl -def main(argv=sys.argv[1:]): +def main(argv: List[str] = sys.argv[1:]) -> None: parser = argparse.ArgumentParser( description='Convert interface files to .idl') parser.add_argument( diff --git a/rosidl_adapter/rosidl_adapter/msg/__init__.py b/rosidl_adapter/rosidl_adapter/msg/__init__.py index b02b7b5bd..1798ba4ae 100644 --- a/rosidl_adapter/rosidl_adapter/msg/__init__.py +++ b/rosidl_adapter/rosidl_adapter/msg/__init__.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from rosidl_adapter.parser import parse_message_string -from rosidl_adapter.resource import expand_template +from pathlib import Path +from typing import Final, Optional, Union +from rosidl_adapter.parser import BaseType, parse_message_string, Type +from rosidl_adapter.resource import expand_template, MsgData -def convert_msg_to_idl(package_dir, package_name, input_file, output_dir): + +def convert_msg_to_idl(package_dir: Path, package_name: str, input_file: Path, + output_dir: Path) -> Path: assert package_dir.is_absolute() assert not input_file.is_absolute() assert input_file.suffix == '.msg' @@ -30,7 +34,7 @@ def convert_msg_to_idl(package_dir, package_name, input_file, output_dir): output_file = output_dir / input_file.with_suffix('.idl').name abs_output_file = output_file.absolute() print(f'Writing output file: {abs_output_file}') - data = { + data: MsgData = { 'pkg_name': package_name, 'relative_input_file': input_file.as_posix(), 'msg': msg, @@ -40,7 +44,7 @@ def convert_msg_to_idl(package_dir, package_name, input_file, output_dir): return output_file -MSG_TYPE_TO_IDL = { +MSG_TYPE_TO_IDL: Final = { 'bool': 'boolean', 'byte': 'octet', 'char': 'uint8', @@ -59,7 +63,7 @@ def convert_msg_to_idl(package_dir, package_name, input_file, output_dir): } -def to_idl_literal(idl_type, value): +def to_idl_literal(idl_type: str, value: str) -> str: if idl_type[-1] == ']' or idl_type.startswith('sequence<'): content = repr(tuple(value)).replace('\\', r'\\').replace('"', r'\"') return f'"{content}"' @@ -73,24 +77,24 @@ def to_idl_literal(idl_type, value): return value -def string_to_idl_string_literal(string): +def string_to_idl_string_literal(string: str) -> str: """Convert string to character literal as described in IDL 4.2 section 7.2.6.3 .""" estr = string.encode().decode('unicode_escape') estr = estr.replace('"', r'\"') return '"{0}"'.format(estr) -def string_to_idl_wstring_literal(string): +def string_to_idl_wstring_literal(string: str) -> str: return string_to_idl_string_literal(string) -def get_include_file(base_type): +def get_include_file(base_type: BaseType) -> Optional[str]: if base_type.is_primitive_type(): return None return f'{base_type.pkg_name}/msg/{base_type.type}.idl' -def get_idl_type(type_): +def get_idl_type(type_: Union[str, Type]) -> str: if isinstance(type_, str): identifier = MSG_TYPE_TO_IDL[type_] elif type_.is_primitive_type(): diff --git a/rosidl_adapter/rosidl_adapter/parser.py b/rosidl_adapter/rosidl_adapter/parser.py index 836a3b751..aeee4b933 100644 --- a/rosidl_adapter/rosidl_adapter/parser.py +++ b/rosidl_adapter/rosidl_adapter/parser.py @@ -16,28 +16,29 @@ import re import sys import textwrap +from typing import Final, Iterable, List, Optional, Tuple, TYPE_CHECKING, TypedDict, Union -PACKAGE_NAME_MESSAGE_TYPE_SEPARATOR = '/' -COMMENT_DELIMITER = '#' -CONSTANT_SEPARATOR = '=' -ARRAY_UPPER_BOUND_TOKEN = '<=' -STRING_UPPER_BOUND_TOKEN = '<=' +PACKAGE_NAME_MESSAGE_TYPE_SEPARATOR: Final = '/' +COMMENT_DELIMITER: Final = '#' +CONSTANT_SEPARATOR: Final = '=' +ARRAY_UPPER_BOUND_TOKEN: Final = '<=' +STRING_UPPER_BOUND_TOKEN: Final = '<=' -SERVICE_REQUEST_RESPONSE_SEPARATOR = '---' -SERVICE_REQUEST_MESSAGE_SUFFIX = '_Request' -SERVICE_RESPONSE_MESSAGE_SUFFIX = '_Response' -SERVICE_EVENT_MESSAGE_SUFFIX = '_Event' +SERVICE_REQUEST_RESPONSE_SEPARATOR: Final = '---' +SERVICE_REQUEST_MESSAGE_SUFFIX: Final = '_Request' +SERVICE_RESPONSE_MESSAGE_SUFFIX: Final = '_Response' +SERVICE_EVENT_MESSAGE_SUFFIX: Final = '_Event' -ACTION_REQUEST_RESPONSE_SEPARATOR = '---' -ACTION_GOAL_SUFFIX = '_Goal' -ACTION_RESULT_SUFFIX = '_Result' -ACTION_FEEDBACK_SUFFIX = '_Feedback' +ACTION_REQUEST_RESPONSE_SEPARATOR: Final = '---' +ACTION_GOAL_SUFFIX: Final = '_Goal' +ACTION_RESULT_SUFFIX: Final = '_Result' +ACTION_FEEDBACK_SUFFIX: Final = '_Feedback' -ACTION_GOAL_SERVICE_SUFFIX = '_Goal' -ACTION_RESULT_SERVICE_SUFFIX = '_Result' -ACTION_FEEDBACK_MESSAGE_SUFFIX = '_Feedback' +ACTION_GOAL_SERVICE_SUFFIX: Final = '_Goal' +ACTION_RESULT_SERVICE_SUFFIX: Final = '_Result' +ACTION_FEEDBACK_MESSAGE_SUFFIX: Final = '_Feedback' -PRIMITIVE_TYPES = [ +PRIMITIVE_TYPES: Final = [ 'bool', 'byte', 'char', @@ -59,20 +60,29 @@ 'time', # for compatibility only ] -VALID_PACKAGE_NAME_PATTERN = re.compile( +VALID_PACKAGE_NAME_PATTERN: Final = re.compile( '^' '(?!.*__)' # no consecutive underscores '(?!.*_$)' # no underscore at the end '[a-z]' # first character must be alpha '[a-z0-9_]*' # followed by alpha, numeric, and underscore '$') -VALID_FIELD_NAME_PATTERN = VALID_PACKAGE_NAME_PATTERN +VALID_FIELD_NAME_PATTERN: Final = VALID_PACKAGE_NAME_PATTERN # relaxed patterns used for compatibility with ROS 1 messages # VALID_FIELD_NAME_PATTERN = re.compile('^[A-Za-z][A-Za-z0-9_]*$') -VALID_MESSAGE_NAME_PATTERN = re.compile('^[A-Z][A-Za-z0-9]*$') +VALID_MESSAGE_NAME_PATTERN: Final = re.compile('^[A-Z][A-Za-z0-9]*$') # relaxed patterns used for compatibility with ROS 1 messages # VALID_MESSAGE_NAME_PATTERN = re.compile('^[A-Za-z][A-Za-z0-9]*$') -VALID_CONSTANT_NAME_PATTERN = re.compile('^[A-Z]([A-Z0-9_]?[A-Z0-9]+)*$') +VALID_CONSTANT_NAME_PATTERN: Final = re.compile('^[A-Z]([A-Z0-9_]?[A-Z0-9]+)*$') + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + PrimitiveType: TypeAlias = Union[bool, float, int, str] + + class Annotations(TypedDict, total=False): + comment: List[str] + unit: str class InvalidSpecification(Exception): @@ -101,7 +111,8 @@ class UnknownMessageType(InvalidSpecification): class InvalidValue(Exception): - def __init__(self, type_, value_string, message_suffix=None): + def __init__(self, type_: Union['Type', str], value_string: str, + message_suffix: Optional[str] = None) -> None: message = "value '%s' can not be converted to type '%s'" % \ (value_string, type_) if message_suffix is not None: @@ -109,7 +120,7 @@ def __init__(self, type_, value_string, message_suffix=None): super(InvalidValue, self).__init__(message) -def is_valid_package_name(name): +def is_valid_package_name(name: str) -> bool: try: m = VALID_PACKAGE_NAME_PATTERN.match(name) except TypeError: @@ -117,7 +128,7 @@ def is_valid_package_name(name): return m is not None and m.group(0) == name -def is_valid_field_name(name): +def is_valid_field_name(name: str) -> bool: try: m = VALID_FIELD_NAME_PATTERN.match(name) except TypeError: @@ -125,7 +136,7 @@ def is_valid_field_name(name): return m is not None and m.group(0) == name -def is_valid_message_name(name): +def is_valid_message_name(name: str) -> bool: try: prefix = 'Sample_' if name.startswith(prefix): @@ -146,7 +157,7 @@ def is_valid_message_name(name): return m is not None and m.group(0) == name -def is_valid_constant_name(name): +def is_valid_constant_name(name: str) -> bool: try: m = VALID_CONSTANT_NAME_PATTERN.match(name) except TypeError: @@ -158,7 +169,7 @@ class BaseType: __slots__ = ['pkg_name', 'type', 'string_upper_bound'] - def __init__(self, type_string, context_package_name=None): + def __init__(self, type_string: str, context_package_name: Optional[str] = None) -> None: # check for primitive types if type_string in PRIMITIVE_TYPES: self.pkg_name = None @@ -194,10 +205,14 @@ def __init__(self, type_string, context_package_name=None): # either the type string contains the package name self.pkg_name = parts[0] self.type = parts[1] - else: + elif context_package_name: # or the package name is provided by context self.pkg_name = context_package_name self.type = type_string + else: + raise ValueError('Either parts has length 2 or context_package_name exist' + 'otherwise BaseType Malformed.') + if not is_valid_package_name(self.pkg_name): raise InvalidResourceName( "'{}' is an invalid package name. It should have the pattern '{}'".format( @@ -209,20 +224,20 @@ def __init__(self, type_string, context_package_name=None): self.string_upper_bound = None - def is_primitive_type(self): + def is_primitive_type(self) -> bool: return self.pkg_name is None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if other is None or not isinstance(other, BaseType): return False return self.pkg_name == other.pkg_name and \ self.type == other.type and \ self.string_upper_bound == other.string_upper_bound - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __str__(self): + def __str__(self) -> str: if self.pkg_name is not None: return '%s/%s' % (self.pkg_name, self.type) @@ -237,7 +252,7 @@ class Type(BaseType): __slots__ = ['is_array', 'array_size', 'is_upper_bound'] - def __init__(self, type_string, context_package_name=None): + def __init__(self, type_string: str, context_package_name: Optional[str] = None) -> None: # check for array brackets self.is_array = type_string[-1] == ']' @@ -247,8 +262,8 @@ def __init__(self, type_string, context_package_name=None): try: index = type_string.rindex('[') except ValueError: - raise TypeError("the type ends with ']' but does not " + - "contain a '['" % type_string) + raise TypeError("the type %s ends with ']' but does not " % type_string + + "contain a '['") array_size_string = type_string[index + 1:-1] # get array limit if array_size_string != '': @@ -279,13 +294,13 @@ def __init__(self, type_string, context_package_name=None): type_string, context_package_name=context_package_name) - def is_dynamic_array(self): + def is_dynamic_array(self) -> bool: return self.is_array and (not self.array_size or self.is_upper_bound) - def is_fixed_size_array(self): - return self.is_array and self.array_size and not self.is_upper_bound + def is_fixed_size_array(self) -> bool: + return self.is_array and bool(self.array_size) and not self.is_upper_bound - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if other is None or not isinstance(other, Type): return False return super(Type, self).__eq__(other) and \ @@ -293,10 +308,10 @@ def __eq__(self, other): self.array_size == other.array_size and \ self.is_upper_bound == other.is_upper_bound - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __str__(self): + def __str__(self) -> str: s = super(Type, self).__str__() if self.is_array: s += '[' @@ -312,7 +327,7 @@ class Constant: __slots__ = ['type', 'name', 'value', 'annotations'] - def __init__(self, primitive_type, name, value_string): + def __init__(self, primitive_type: str, name: str, value_string: str) -> None: if primitive_type not in PRIMITIVE_TYPES: raise TypeError("the constant type '%s' must be a primitive type" % primitive_type) @@ -328,16 +343,16 @@ def __init__(self, primitive_type, name, value_string): self.value = parse_primitive_value_string( Type(primitive_type), value_string) - self.annotations = {} + self.annotations: 'Annotations' = {} - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if other is None or not isinstance(other, Constant): return False return self.type == other.type and \ self.name == other.name and \ self.value == other.value - def __str__(self): + def __str__(self) -> str: value = self.value if self.type in ('string', 'wstring'): value = "'%s'" % value @@ -346,7 +361,8 @@ def __str__(self): class Field: - def __init__(self, type_, name, default_value_string=None): + def __init__(self, type_: 'Type', name: str, + default_value_string: Optional[str] = None) -> None: if not isinstance(type_, Type): raise TypeError( "the field type '%s' must be a 'Type' instance" % type_) @@ -362,9 +378,9 @@ def __init__(self, type_, name, default_value_string=None): self.default_value = parse_value_string( type_, default_value_string) - self.annotations = {} + self.annotations: 'Annotations' = {} - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if other is None or not isinstance(other, Field): return False else: @@ -372,7 +388,7 @@ def __eq__(self, other): self.name == other.name and \ self.default_value == other.default_value - def __str__(self): + def __str__(self) -> str: s = '%s %s' % (str(self.type), self.name) if self.default_value is not None: if self.type.is_primitive_type() and not self.type.is_array and \ @@ -385,11 +401,12 @@ def __str__(self): class MessageSpecification: - def __init__(self, pkg_name, msg_name, fields, constants): + def __init__(self, pkg_name: str, msg_name: str, fields: Iterable['Field'], + constants: Iterable['Constant']) -> None: self.base_type = BaseType( pkg_name + PACKAGE_NAME_MESSAGE_TYPE_SEPARATOR + msg_name) self.msg_name = msg_name - self.annotations = {} + self.annotations: 'Annotations' = {} self.fields = [] for index, field in enumerate(fields): @@ -420,7 +437,7 @@ def __init__(self, pkg_name, msg_name, fields, constants): 'the constants iterable contains duplicate names: %s' % ', '.join(sorted(duplicate_constant_names))) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not other or not isinstance(other, MessageSpecification): return False return self.base_type == other.base_type and \ @@ -429,7 +446,7 @@ def __eq__(self, other): len(self.constants) == len(other.constants) and \ self.constants == other.constants - def __str__(self): + def __str__(self) -> str: """Output an equivalent .msg IDL string.""" output = ['# ', str(self.base_type), '\n'] for constant in self.constants: @@ -441,7 +458,7 @@ def __str__(self): return ''.join(output) -def parse_message_file(pkg_name, interface_filename): +def parse_message_file(pkg_name: str, interface_filename: str) -> MessageSpecification: basename = os.path.basename(interface_filename) msg_name = os.path.splitext(basename)[0] with open(interface_filename, 'r', encoding='utf-8') as h: @@ -449,7 +466,7 @@ def parse_message_file(pkg_name, interface_filename): pkg_name, msg_name, h.read()) -def extract_file_level_comments(message_string): +def extract_file_level_comments(message_string: str) -> Tuple[List[str], List[str]]: lines = message_string.splitlines() index = next( (i for i, v in enumerate(lines) if not v.startswith(COMMENT_DELIMITER)), -1) @@ -463,10 +480,11 @@ def extract_file_level_comments(message_string): return file_level_comments, file_content -def parse_message_string(pkg_name, msg_name, message_string): - fields = [] - constants = [] - last_element = None # either a field or a constant +def parse_message_string(pkg_name: str, msg_name: str, + message_string: str) -> MessageSpecification: + fields: List[Field] = [] + constants: List[Constant] = [] + last_element: Union[Field, Constant, None] = None # either a field or a constant # replace tabs with spaces message_string = message_string.replace('\t', ' ') @@ -513,13 +531,13 @@ def parse_message_string(pkg_name, msg_name, message_string): if index == -1: # line contains a field field_name, _, default_value_string = rest.partition(' ') - default_value_string = default_value_string.lstrip() - if not default_value_string: - default_value_string = None + optional_default_value_string: Optional[str] = default_value_string.lstrip() + if not optional_default_value_string: + optional_default_value_string = None try: fields.append(Field( Type(type_string, context_package_name=pkg_name), - field_name, default_value_string)) + field_name, optional_default_value_string)) except Exception as err: print( "Error processing '{line}' of '{pkg}/{msg}': '{err}'".format( @@ -555,7 +573,7 @@ def parse_message_string(pkg_name, msg_name, message_string): return msg -def process_comments(instance): +def process_comments(instance: Union[MessageSpecification, Field, Constant]) -> None: if 'comment' in instance.annotations: lines = instance.annotations['comment'] @@ -590,7 +608,8 @@ def process_comments(instance): instance.annotations['comment'] = textwrap.dedent(text).split('\n') -def parse_value_string(type_, value_string): +def parse_value_string(type_: Type, value_string: str) -> Union['PrimitiveType', + List['PrimitiveType']]: if type_.is_primitive_type() and not type_.is_array: return parse_primitive_value_string(type_, value_string) @@ -640,10 +659,11 @@ def parse_value_string(type_, value_string): "parsing string values into type '%s' is not supported" % type_) -def parse_string_array_value_string(element_string, expected_size): +def parse_string_array_value_string(element_string: str, + expected_size: Optional[int]) -> List[str]: # Walks the string, if start with quote (' or ") find next unescapted quote, # returns a list of string elements - value_strings = [] + value_strings: List[str] = [] while len(element_string) > 0: element_string = element_string.lstrip(' ') if element_string[0] == ',': @@ -677,7 +697,7 @@ def parse_string_array_value_string(element_string, expected_size): return value_strings -def find_matching_end_quote(string, quote): +def find_matching_end_quote(string: str, quote: str) -> int: # Given a string, walk it and find the next unescapted quote # returns the index of the ending quote if successful, -1 otherwise ending_quote_idx = -1 @@ -695,7 +715,7 @@ def find_matching_end_quote(string, quote): return -1 -def parse_primitive_value_string(type_, value_string): +def parse_primitive_value_string(type_: Type, value_string: str) -> 'PrimitiveType': if not type_.is_primitive_type() or type_.is_array: raise ValueError('the passed type must be a non-array primitive type') primitive_type = type_.type @@ -791,10 +811,13 @@ def parse_primitive_value_string(type_, value_string): assert False, "unknown primitive type '%s'" % primitive_type -def validate_field_types(spec, known_msg_types): +def validate_field_types(spec: Union[MessageSpecification, + 'ServiceSpecification', + 'ActionSpecification'], + known_msg_types: List[BaseType]) -> None: if isinstance(spec, MessageSpecification): spec_type = 'Message' - fields = spec.fields + fields: List[Field] = spec.fields elif isinstance(spec, ServiceSpecification): spec_type = 'Service' fields = spec.request.fields + spec.response.fields @@ -818,7 +841,8 @@ def validate_field_types(spec, known_msg_types): class ServiceSpecification: - def __init__(self, pkg_name, srv_name, request, response): + def __init__(self, pkg_name: str, srv_name: str, request: MessageSpecification, + response: MessageSpecification) -> None: self.pkg_name = pkg_name self.srv_name = srv_name assert isinstance(request, MessageSpecification) @@ -826,7 +850,7 @@ def __init__(self, pkg_name, srv_name, request, response): assert isinstance(response, MessageSpecification) self.response = response - def __str__(self): + def __str__(self) -> str: """Output an equivalent .srv IDL string.""" output = ['# ', str(self.pkg_name), '/', str(self.srv_name), '\n'] output.append(str(self.request)) @@ -835,7 +859,7 @@ def __str__(self): return ''.join(output) -def parse_service_file(pkg_name, interface_filename): +def parse_service_file(pkg_name: str, interface_filename: str) -> ServiceSpecification: basename = os.path.basename(interface_filename) srv_name = os.path.splitext(basename)[0] with open(interface_filename, 'r', encoding='utf-8') as h: @@ -843,7 +867,8 @@ def parse_service_file(pkg_name, interface_filename): pkg_name, srv_name, h.read()) -def parse_service_string(pkg_name, srv_name, message_string): +def parse_service_string(pkg_name: str, srv_name: str, + message_string: str) -> ServiceSpecification: lines = message_string.splitlines() separator_indices = [ index for index, line in enumerate(lines) if line == SERVICE_REQUEST_RESPONSE_SEPARATOR] @@ -869,7 +894,11 @@ def parse_service_string(pkg_name, srv_name, message_string): class ActionSpecification: - def __init__(self, pkg_name, action_name, goal, result, feedback): + goal_service: ServiceSpecification + result_service: ServiceSpecification + + def __init__(self, pkg_name: str, action_name: str, goal: MessageSpecification, + result: MessageSpecification, feedback: MessageSpecification) -> None: self.pkg_name = pkg_name self.action_name = action_name assert isinstance(goal, MessageSpecification) @@ -880,14 +909,15 @@ def __init__(self, pkg_name, action_name, goal, result, feedback): self.feedback = feedback -def parse_action_file(pkg_name, interface_filename): +def parse_action_file(pkg_name: str, interface_filename: str) -> ActionSpecification: basename = os.path.basename(interface_filename) action_name = os.path.splitext(basename)[0] with open(interface_filename, 'r', encoding='utf-8') as h: return parse_action_string(pkg_name, action_name, h.read()) -def parse_action_string(pkg_name, action_name, action_string): +def parse_action_string(pkg_name: str, action_name: str, + action_string: str) -> ActionSpecification: lines = action_string.splitlines() separator_indices = [ index for index, line in enumerate(lines) if line == ACTION_REQUEST_RESPONSE_SEPARATOR] diff --git a/rosidl_adapter/rosidl_adapter/py.typed b/rosidl_adapter/rosidl_adapter/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/rosidl_adapter/rosidl_adapter/resource/__init__.py b/rosidl_adapter/rosidl_adapter/resource/__init__.py index b629021a1..a422e4309 100644 --- a/rosidl_adapter/rosidl_adapter/resource/__init__.py +++ b/rosidl_adapter/rosidl_adapter/resource/__init__.py @@ -14,9 +14,13 @@ from io import StringIO import os +from pathlib import Path import sys +from typing import Any, Optional, TypedDict import em +from rosidl_adapter.parser import ActionSpecification, MessageSpecification, ServiceSpecification + try: from em import Configuration @@ -25,7 +29,25 @@ em_has_configuration = False -def expand_template(template_name, data, output_file, encoding='utf-8'): +class Data(TypedDict): + pkg_name: str + relative_input_file: str + + +class MsgData(Data): + msg: MessageSpecification + + +class SrvData(Data): + srv: ServiceSpecification + + +class ActionData(Data): + action: ActionSpecification + + +def expand_template(template_name: str, data: Data, output_file: Path, + encoding: str = 'utf-8') -> None: content = evaluate_template(template_name, data) if output_file.exists(): @@ -38,14 +60,14 @@ def expand_template(template_name, data, output_file, encoding='utf-8'): output_file.write_text(content, encoding=encoding) -_interpreter = None +_interpreter: Optional[em.Interpreter] = None -def evaluate_template(template_name, data): +def evaluate_template(template_name: str, data: Data) -> str: global _interpreter # create copy before manipulating - data = dict(data) - data['TEMPLATE'] = _evaluate_template + data_copy = dict(data) + data_copy['TEMPLATE'] = _evaluate_template template_path = os.path.join(os.path.dirname(__file__), template_name) @@ -71,11 +93,11 @@ def evaluate_template(template_name, data): with open(template_path, 'r') as h: content = h.read() _interpreter.invoke( - 'beforeFile', name=template_name, file=h, locals=data) + 'beforeFile', name=template_name, file=h, locals=data_copy) if em_has_configuration: - _interpreter.string(content, locals=data) + _interpreter.string(content, locals=data_copy) else: - _interpreter.string(content, template_path, locals=data) + _interpreter.string(content, template_path, locals=data_copy) _interpreter.invoke('afterFile') return output.getvalue() @@ -90,8 +112,11 @@ def evaluate_template(template_name, data): _interpreter = None -def _evaluate_template(template_name, **kwargs): +def _evaluate_template(template_name: str, **kwargs: Any) -> None: global _interpreter + if _interpreter is None: + raise RuntimeError('_evaluate_template called without running evaluate_template.') + template_path = os.path.join(os.path.dirname(__file__), template_name) with open(template_path, 'r') as h: _interpreter.invoke( diff --git a/rosidl_adapter/rosidl_adapter/srv/__init__.py b/rosidl_adapter/rosidl_adapter/srv/__init__.py index c57b7013f..710a97d8e 100644 --- a/rosidl_adapter/rosidl_adapter/srv/__init__.py +++ b/rosidl_adapter/rosidl_adapter/srv/__init__.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from rosidl_adapter.parser import parse_service_string -from rosidl_adapter.resource import expand_template +from rosidl_adapter.resource import expand_template, SrvData -def convert_srv_to_idl(package_dir, package_name, input_file, output_dir): +def convert_srv_to_idl(package_dir: Path, package_name: str, input_file: Path, + output_dir: Path) -> Path: assert package_dir.is_absolute() assert not input_file.is_absolute() assert input_file.suffix == '.srv' @@ -30,7 +33,7 @@ def convert_srv_to_idl(package_dir, package_name, input_file, output_dir): output_file = output_dir / input_file.with_suffix('.idl').name abs_output_file = output_file.absolute() print(f'Writing output file: {abs_output_file}') - data = { + data: SrvData = { 'pkg_name': package_name, 'relative_input_file': input_file.as_posix(), 'srv': srv, diff --git a/rosidl_adapter/test/parse_msg_files.py b/rosidl_adapter/test/parse_msg_files.py index 9dbc31a4c..09ccbb46e 100755 --- a/rosidl_adapter/test/parse_msg_files.py +++ b/rosidl_adapter/test/parse_msg_files.py @@ -17,11 +17,12 @@ import argparse import os import sys +from typing import List, Literal from rosidl_adapter.parser import parse_message_file -def main(argv=sys.argv[1:]): +def main(argv: List[str] = sys.argv[1:]) -> Literal[0]: parser = argparse.ArgumentParser( description='Parse all recursively found .msg files.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -45,8 +46,8 @@ def main(argv=sys.argv[1:]): return 0 -def get_files(paths): - files = [] +def get_files(paths: str) -> List[str]: + files: List[str] = [] for path in paths: if os.path.isdir(path): for dirpath, dirnames, filenames in os.walk(path): diff --git a/rosidl_adapter/test/test_base_type.py b/rosidl_adapter/test/test_base_type.py index dfaf18eef..1f1b123da 100644 --- a/rosidl_adapter/test/test_base_type.py +++ b/rosidl_adapter/test/test_base_type.py @@ -18,7 +18,7 @@ from rosidl_adapter.parser import InvalidResourceName -def test_base_type_constructor(): +def test_base_type_constructor() -> None: primitive_types = [ 'bool', 'byte', @@ -73,7 +73,7 @@ def test_base_type_constructor(): BaseType('pkg/Foo Bar') -def test_base_type_methods(): +def test_base_type_methods() -> None: assert BaseType('bool').is_primitive_type() assert not BaseType('pkg/Foo').is_primitive_type() diff --git a/rosidl_adapter/test/test_cli_extensions.py b/rosidl_adapter/test/test_cli_extensions.py index 218d7bb41..b4f33997e 100644 --- a/rosidl_adapter/test/test_cli_extensions.py +++ b/rosidl_adapter/test/test_cli_extensions.py @@ -15,13 +15,15 @@ import filecmp import pathlib +from pytest import CaptureFixture from rosidl_cli.command.translate.api import translate DATA_PATH = pathlib.Path(__file__).parent / 'data' -def test_translation_extensions(tmp_path, capsys): +def test_translation_extensions(tmp_path: pathlib.Path, + capsys: CaptureFixture[str]) -> None: # NOTE(hidmic): pytest and empy do not play along, # the latter expects some proxy will stay in sys.stdout # and the former insists in overwriting it diff --git a/rosidl_adapter/test/test_constant.py b/rosidl_adapter/test/test_constant.py index 26d396abd..5fcb08148 100644 --- a/rosidl_adapter/test/test_constant.py +++ b/rosidl_adapter/test/test_constant.py @@ -17,7 +17,7 @@ from rosidl_adapter.parser import Constant -def test_constant_constructor(): +def test_constant_constructor() -> None: value = Constant('bool', 'FOO', '1') assert value @@ -28,10 +28,10 @@ def test_constant_constructor(): Constant('bool', 'FOO BAR', '') with pytest.raises(ValueError): - Constant('bool', 'FOO', None) + Constant('bool', 'FOO', None) # type: ignore[arg-type] -def test_constant_methods(): +def test_constant_methods() -> None: assert Constant('bool', 'FOO', '1') != 23 assert Constant('bool', 'FOO', '1') == Constant('bool', 'FOO', '1') diff --git a/rosidl_adapter/test/test_extract_message_comments.py b/rosidl_adapter/test/test_extract_message_comments.py index b03ad26b1..5b4ff5ee8 100644 --- a/rosidl_adapter/test/test_extract_message_comments.py +++ b/rosidl_adapter/test/test_extract_message_comments.py @@ -15,7 +15,7 @@ from rosidl_adapter.parser import parse_message_string -def test_extract_message_comments(): +def test_extract_message_comments() -> None: # multi line file-level comment msg_spec = parse_message_string('pkg', 'Foo', '# comment 1\n#\n# comment 2\nbool value') assert len(msg_spec.annotations) == 1 diff --git a/rosidl_adapter/test/test_field.py b/rosidl_adapter/test/test_field.py index 291cb42a4..c27b2f01c 100644 --- a/rosidl_adapter/test/test_field.py +++ b/rosidl_adapter/test/test_field.py @@ -19,7 +19,7 @@ from rosidl_adapter.parser import Type -def test_field_constructor(): +def test_field_constructor() -> None: type_ = Type('bool') field = Field(type_, 'foo') assert field.type == type_ @@ -30,10 +30,10 @@ def test_field_constructor(): assert field.default_value with pytest.raises(TypeError): - Field('type', 'foo') + Field('type', 'foo') # type: ignore[arg-type] with pytest.raises(NameError): - Field(type_, 'foo bar') + Field(type_, 'foo bar') # type: ignore[arg-type] type_ = Type('bool[2]') field = Field(type_, 'foo', '[false, true]') @@ -48,7 +48,7 @@ def test_field_constructor(): Field(type_, 'foo', '[false, true]') -def test_field_methods(): +def test_field_methods() -> None: assert Field(Type('bool'), 'foo') != 23 assert (Field(Type('bool'), 'foo', '1') == diff --git a/rosidl_adapter/test/test_message_specification.py b/rosidl_adapter/test/test_message_specification.py index 4ce925266..a3e17178b 100644 --- a/rosidl_adapter/test/test_message_specification.py +++ b/rosidl_adapter/test/test_message_specification.py @@ -20,7 +20,7 @@ from rosidl_adapter.parser import Type -def test_message_specification_constructor(): +def test_message_specification_constructor() -> None: msg_spec = MessageSpecification('pkg', 'Foo', [], []) assert msg_spec.base_type.pkg_name == 'pkg' assert msg_spec.base_type.type == 'Foo' @@ -28,13 +28,13 @@ def test_message_specification_constructor(): assert len(msg_spec.constants) == 0 with pytest.raises(TypeError): - MessageSpecification('pkg', 'Foo', None, []) + MessageSpecification('pkg', 'Foo', None, []) # type: ignore[arg-type] with pytest.raises(TypeError): - MessageSpecification('pkg', 'Foo', [], None) + MessageSpecification('pkg', 'Foo', [], None) # type: ignore[arg-type] with pytest.raises(TypeError): - MessageSpecification('pkg', 'Foo', ['field'], []) + MessageSpecification('pkg', 'Foo', ['field'], []) # type: ignore[list-item] with pytest.raises(TypeError): - MessageSpecification('pkg', 'Foo', [], ['constant']) + MessageSpecification('pkg', 'Foo', [], ['constant']) # type: ignore[list-item] field = Field(Type('bool'), 'foo', '1') constant = Constant('bool', 'BAR', '1') @@ -50,7 +50,7 @@ def test_message_specification_constructor(): MessageSpecification('pkg', 'Foo', [], [constant, constant]) -def test_message_specification_methods(): +def test_message_specification_methods() -> None: field = Field(Type('bool'), 'foo', '1') constant = Constant('bool', 'BAR', '1') msg_spec = MessageSpecification('pkg', 'Foo', [field], [constant]) diff --git a/rosidl_adapter/test/test_parse_action_string.py b/rosidl_adapter/test/test_parse_action_string.py index c3d6e0f0c..445367172 100644 --- a/rosidl_adapter/test/test_parse_action_string.py +++ b/rosidl_adapter/test/test_parse_action_string.py @@ -18,7 +18,7 @@ from rosidl_adapter.parser import parse_action_string -def test_invalid_action_specification(): +def test_invalid_action_specification() -> None: with pytest.raises(InvalidActionSpecification): parse_action_string('pkg', 'Foo', '') @@ -27,11 +27,11 @@ def test_invalid_action_specification(): parse_action_string('pkg', 'Foo', 'bool foo\n---\nint8 bar') -def test_valid_action_string(): +def test_valid_action_string() -> None: parse_action_string('pkg', 'Foo', 'bool foo\n---\nint8 bar\n---') -def test_valid_action_string1(): +def test_valid_action_string1() -> None: spec = parse_action_string('pkg', 'Foo', 'bool foo\n---\nint8 bar\n---\nbool foo') # Goal checks assert spec.goal.base_type.pkg_name == 'pkg' @@ -50,7 +50,7 @@ def test_valid_action_string1(): assert len(spec.feedback.constants) == 0 -def test_valid_action_string2(): +def test_valid_action_string2() -> None: spec = parse_action_string( 'pkg', 'Foo', '#comment---\n \nbool foo\n---\n#comment\n \nint8 bar\n---\nbool foo') # Goal checks @@ -70,7 +70,7 @@ def test_valid_action_string2(): assert len(spec.feedback.constants) == 0 -def test_valid_action_string3(): +def test_valid_action_string3() -> None: spec = parse_action_string( 'pkg', 'Foo', diff --git a/rosidl_adapter/test/test_parse_message_file.py b/rosidl_adapter/test/test_parse_message_file.py index aed6de95c..c1b889992 100644 --- a/rosidl_adapter/test/test_parse_message_file.py +++ b/rosidl_adapter/test/test_parse_message_file.py @@ -21,7 +21,7 @@ from rosidl_adapter.parser import parse_message_file -def test_parse_message_file(): +def test_parse_message_file() -> None: path = tempfile.mkdtemp(prefix='test_parse_message_file_') try: filename = os.path.join(path, 'Foo.msg') diff --git a/rosidl_adapter/test/test_parse_message_string.py b/rosidl_adapter/test/test_parse_message_string.py index 569bc6a3d..d14bd840d 100644 --- a/rosidl_adapter/test/test_parse_message_string.py +++ b/rosidl_adapter/test/test_parse_message_string.py @@ -19,7 +19,7 @@ from rosidl_adapter.parser import parse_message_string -def test_parse_message_string(): +def test_parse_message_string() -> None: msg_spec = parse_message_string('pkg', 'Foo', '') assert msg_spec.base_type.pkg_name == 'pkg' assert msg_spec.base_type.type == 'Foo' diff --git a/rosidl_adapter/test/test_parse_primitive_value_string.py b/rosidl_adapter/test/test_parse_primitive_value_string.py index 6b328bb56..ce8f1d0bf 100644 --- a/rosidl_adapter/test/test_parse_primitive_value_string.py +++ b/rosidl_adapter/test/test_parse_primitive_value_string.py @@ -19,14 +19,14 @@ from rosidl_adapter.parser import Type -def test_parse_primitive_value_string_invalid(): +def test_parse_primitive_value_string_invalid() -> None: with pytest.raises(ValueError): parse_primitive_value_string(Type('pkg/Foo'), '') with pytest.raises(ValueError): parse_primitive_value_string(Type('bool[]'), '') -def test_parse_primitive_value_string_bool(): +def test_parse_primitive_value_string_bool() -> None: valid_bool_string_values = { 'true': True, 'TrUe': True, @@ -47,7 +47,7 @@ def test_parse_primitive_value_string_bool(): parse_primitive_value_string(Type('bool'), 'true ') -def test_parse_primitive_value_string_integer(): +def test_parse_primitive_value_string_integer() -> None: integer_types = { 'byte': [8, True], 'char': [8, True], @@ -85,7 +85,7 @@ def test_parse_primitive_value_string_integer(): Type(integer_type), str(upper_bound + 1)) -def test_parse_primitive_value_string_hex(): +def test_parse_primitive_value_string_hex() -> None: integer_types = { 'byte': [8, True], 'char': [8, True], @@ -123,7 +123,7 @@ def test_parse_primitive_value_string_hex(): Type(integer_type), hex(upper_bound + 1)) -def test_parse_primitive_value_string_oct(): +def test_parse_primitive_value_string_oct() -> None: integer_types = { 'byte': [8, True], 'char': [8, True], @@ -161,7 +161,7 @@ def test_parse_primitive_value_string_oct(): Type(integer_type), oct(upper_bound + 1)) -def test_parse_primitive_value_string_bin(): +def test_parse_primitive_value_string_bin() -> None: integer_types = { 'byte': [8, True], 'char': [8, True], @@ -199,7 +199,7 @@ def test_parse_primitive_value_string_bin(): Type(integer_type), bin(upper_bound + 1)) -def test_parse_primitive_value_string_float(): +def test_parse_primitive_value_string_float() -> None: for float_type in ['float32', 'float64']: value = parse_primitive_value_string( Type(float_type), '0') @@ -216,7 +216,7 @@ def test_parse_primitive_value_string_float(): Type(float_type), 'value') -def test_parse_primitive_value_string_string(): +def test_parse_primitive_value_string_string() -> None: value = parse_primitive_value_string( Type('string'), 'foo') assert value == 'foo' @@ -286,7 +286,7 @@ def test_parse_primitive_value_string_string(): assert value == '"foo"' -def test_parse_primitive_value_wstring_string(): +def test_parse_primitive_value_wstring_string() -> None: value = parse_primitive_value_string( Type('wstring'), 'foo') assert value == 'foo' @@ -356,10 +356,10 @@ def test_parse_primitive_value_wstring_string(): assert value == '"foo"' -def test_parse_primitive_value_string_unknown(): +def test_parse_primitive_value_string_unknown() -> None: class CustomType(Type): - def is_primitive_type(self): + def is_primitive_type(self) -> bool: return True type_ = CustomType('pkg/Foo') diff --git a/rosidl_adapter/test/test_parse_service_string.py b/rosidl_adapter/test/test_parse_service_string.py index 5da60e7f6..762b593bb 100644 --- a/rosidl_adapter/test/test_parse_service_string.py +++ b/rosidl_adapter/test/test_parse_service_string.py @@ -19,7 +19,7 @@ from rosidl_adapter.parser import parse_service_string -def test_parse_service_string(): +def test_parse_service_string() -> None: with pytest.raises(InvalidServiceSpecification): parse_service_string('pkg', 'Foo', '') diff --git a/rosidl_adapter/test/test_parse_unicode.py b/rosidl_adapter/test/test_parse_unicode.py index 89a2aba6a..c56d22cd5 100644 --- a/rosidl_adapter/test/test_parse_unicode.py +++ b/rosidl_adapter/test/test_parse_unicode.py @@ -23,7 +23,7 @@ from rosidl_adapter.parser import parse_message_string -def test_parse_message_string_with_unicode_comments(): +def test_parse_message_string_with_unicode_comments() -> None: # Similar to `test_parse_message_string.py` but we only care about the comments part. msg_spec = parse_message_string('pkg', 'Foo', '#comment ڭ with ڮ some گ unicode ڰ sprinkles\ \n \n # ♔ ♕ ♖ ♗ ♘ ♙ ♚ ♛ ♜ ♝ ♞ ♟') @@ -34,7 +34,7 @@ def test_parse_message_string_with_unicode_comments(): parse_message_string('pkg', 'Foo', 'bool # comment ✌✌') -def test_parse_message_file_with_unicode_comments(): +def test_parse_message_file_with_unicode_comments() -> None: # Like `test_parse_message_file.py` but with a unicode comment line. path = tempfile.mkdtemp(prefix='test_parse_message_file_with_unicode_comments_') try: @@ -61,7 +61,7 @@ def test_parse_message_file_with_unicode_comments(): shutil.rmtree(path) -def test_extract_message_unicode_comments(): +def test_extract_message_unicode_comments() -> None: # Like `test_extract_message_commnets.py` but with several unicode symbols as comments. # multi line file-level comment msg_spec = parse_message_string('pkg', 'Foo', '# ¡¢£¤¥¦§¨©ª«¬­®¯°±²³´µ¶\n' diff --git a/rosidl_adapter/test/test_parse_value_string.py b/rosidl_adapter/test/test_parse_value_string.py index b3b2695c4..33d80ed6f 100644 --- a/rosidl_adapter/test/test_parse_value_string.py +++ b/rosidl_adapter/test/test_parse_value_string.py @@ -19,11 +19,11 @@ from rosidl_adapter.parser import Type -def test_parse_value_string_primitive(): +def test_parse_value_string_primitive() -> None: parse_value_string(Type('bool'), '1') -def test_parse_value_string(): +def test_parse_value_string() -> None: with pytest.raises(InvalidValue): parse_value_string(Type('bool[]'), '1') @@ -57,6 +57,6 @@ def test_parse_value_string(): assert value -def test_parse_value_string_not_implemented(): +def test_parse_value_string_not_implemented() -> None: with pytest.raises(NotImplementedError): parse_value_string(Type('pkg/Foo[]'), '') diff --git a/rosidl_adapter/test/test_type.py b/rosidl_adapter/test/test_type.py index 1900f48c8..4241f7660 100644 --- a/rosidl_adapter/test/test_type.py +++ b/rosidl_adapter/test/test_type.py @@ -17,7 +17,7 @@ from rosidl_adapter.parser import Type -def test_type_constructor(): +def test_type_constructor() -> None: type_ = Type('bool') assert type_.pkg_name is None assert type_.type == 'bool' @@ -63,7 +63,7 @@ def test_type_constructor(): Type('bool[<=0]') -def test_type_methods(): +def test_type_methods() -> None: assert Type('bool[5]') != 23 assert Type('pkg/Foo') == Type('pkg/Foo') diff --git a/rosidl_adapter/test/test_valid_names.py b/rosidl_adapter/test/test_valid_names.py index 0e449ead5..ad62aef48 100644 --- a/rosidl_adapter/test/test_valid_names.py +++ b/rosidl_adapter/test/test_valid_names.py @@ -21,7 +21,7 @@ from rosidl_adapter.parser import is_valid_package_name -def test_is_valid_package_name(): +def test_is_valid_package_name() -> None: for valid_package_name in [ 'foo', 'foo_bar']: assert is_valid_package_name(valid_package_name) @@ -29,10 +29,10 @@ def test_is_valid_package_name(): '_foo', 'foo_', 'foo__bar', 'foo-bar']: assert not is_valid_package_name(invalid_package_name) with pytest.raises(InvalidResourceName): - is_valid_package_name(None) + is_valid_package_name(None) # type: ignore[arg-type] -def test_is_valid_field_name(): +def test_is_valid_field_name() -> None: for valid_field_name in [ 'foo', 'foo_bar']: is_valid_field_name(valid_field_name) @@ -40,10 +40,10 @@ def test_is_valid_field_name(): '_foo', 'foo_', 'foo__bar', 'foo-bar']: assert not is_valid_field_name(invalid_field_name) with pytest.raises(InvalidResourceName): - is_valid_field_name(None) + is_valid_field_name(None) # type: ignore[arg-type] -def test_is_valid_message_name(): +def test_is_valid_message_name() -> None: for valid_message_name in [ 'Foo', 'FooBar']: assert is_valid_message_name(valid_message_name) @@ -51,10 +51,10 @@ def test_is_valid_message_name(): '0foo', '_Foo', 'Foo_', 'Foo_Bar']: assert not is_valid_message_name(invalid_message_name) with pytest.raises(InvalidResourceName): - is_valid_message_name(None) + is_valid_message_name(None) # type: ignore[arg-type] -def test_is_valid_constant_name(): +def test_is_valid_constant_name() -> None: for valid_constant_name in [ 'FOO', 'FOO_BAR']: assert is_valid_constant_name(valid_constant_name) @@ -62,4 +62,4 @@ def test_is_valid_constant_name(): '_FOO', 'FOO_', 'FOO__BAR', 'Foo']: assert not is_valid_constant_name(invalid_constant_name) with pytest.raises(InvalidResourceName): - is_valid_constant_name(None) + is_valid_constant_name(None) # type: ignore[arg-type] diff --git a/rosidl_adapter/test/test_validate_field_types.py b/rosidl_adapter/test/test_validate_field_types.py index 3db4a91de..f585ad28e 100644 --- a/rosidl_adapter/test/test_validate_field_types.py +++ b/rosidl_adapter/test/test_validate_field_types.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest +from typing import List +import pytest from rosidl_adapter.parser import BaseType from rosidl_adapter.parser import Field from rosidl_adapter.parser import MessageSpecification @@ -22,9 +23,9 @@ from rosidl_adapter.parser import validate_field_types -def test_validate_field_types(): +def test_validate_field_types() -> None: msg_spec = MessageSpecification('pkg', 'Foo', [], []) - known_msg_type = [] + known_msg_type: List[BaseType] = [] validate_field_types(msg_spec, known_msg_type) msg_spec.fields.append(Field(Type('bool'), 'foo')) diff --git a/rosidl_cli/package.xml b/rosidl_cli/package.xml index 46f5c7ee5..e45804c7a 100644 --- a/rosidl_cli/package.xml +++ b/rosidl_cli/package.xml @@ -22,6 +22,7 @@ ament_copyright ament_flake8 + ament_mypy ament_pep257 ament_xmllint python3-pytest diff --git a/rosidl_cli/py.typed b/rosidl_cli/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/rosidl_cli/rosidl_cli/cli.py b/rosidl_cli/rosidl_cli/cli.py index 9c3f8ac28..ce729a816 100644 --- a/rosidl_cli/rosidl_cli/cli.py +++ b/rosidl_cli/rosidl_cli/cli.py @@ -14,13 +14,18 @@ import argparse import signal +from typing import Any, List, Union from rosidl_cli.command.generate import GenerateCommand from rosidl_cli.command.translate import TranslateCommand from rosidl_cli.common import get_first_line_doc -def add_subparsers(parser, cli_name, commands): +def add_subparsers( + parser: argparse.ArgumentParser, + cli_name: str, + commands: List[Union[GenerateCommand, TranslateCommand]] +) -> argparse._SubParsersAction[argparse.ArgumentParser]: """ Create argparse subparser for each command. @@ -63,7 +68,7 @@ def add_subparsers(parser, cli_name, commands): return subparser -def main(): +def main() -> Union[str, signal.Signals, Any]: script_name = 'rosidl' description = f'{script_name} is an extensible command-line tool ' \ 'for ROS interface generation.' @@ -74,7 +79,8 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter ) - commands = [GenerateCommand(), TranslateCommand()] + commands: List[Union[GenerateCommand, TranslateCommand]] = \ + [GenerateCommand(), TranslateCommand()] # add arguments for command extension(s) add_subparsers( diff --git a/rosidl_cli/rosidl_cli/command/__init__.py b/rosidl_cli/rosidl_cli/command/__init__.py index 22d035bb5..187a6c187 100644 --- a/rosidl_cli/rosidl_cli/command/__init__.py +++ b/rosidl_cli/rosidl_cli/command/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse + class Command: """ @@ -22,8 +24,8 @@ class Command: * `add_arguments` """ - def add_arguments(self, parser): + def add_arguments(self, parser: argparse.ArgumentParser) -> None: pass - def main(self, *, parser, args): + def main(self, *, args: argparse.Namespace) -> None: raise NotImplementedError() diff --git a/rosidl_cli/rosidl_cli/command/generate/__init__.py b/rosidl_cli/rosidl_cli/command/generate/__init__.py index ee46a937a..b12b946c1 100644 --- a/rosidl_cli/rosidl_cli/command/generate/__init__.py +++ b/rosidl_cli/rosidl_cli/command/generate/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import pathlib from rosidl_cli.command import Command @@ -24,7 +25,7 @@ class GenerateCommand(Command): name = 'generate' - def add_arguments(self, parser): + def add_arguments(self, parser: argparse.ArgumentParser) -> None: parser.add_argument( '-o', '--output-path', metavar='PATH', type=pathlib.Path, default=None, @@ -50,7 +51,7 @@ def add_arguments(self, parser): "If prefixed by another path followed by a colon ':', " 'path resolution is performed against such path.')) - def main(self, *, args): + def main(self, *, args: argparse.Namespace) -> None: generate( package_name=args.package_name, interface_files=args.interface_files, diff --git a/rosidl_cli/rosidl_cli/command/generate/api.py b/rosidl_cli/rosidl_cli/command/generate/api.py index ebec89144..ff7edbbcb 100644 --- a/rosidl_cli/rosidl_cli/command/generate/api.py +++ b/rosidl_cli/rosidl_cli/command/generate/api.py @@ -14,20 +14,22 @@ import os import pathlib +from typing import List, Optional +from .extensions import GenerateCommandExtension from .extensions import load_type_extensions from .extensions import load_typesupport_extensions def generate( *, - package_name, - interface_files, - include_paths=None, - output_path=None, - types=None, - typesupports=None -): + package_name: str, + interface_files: List[str], + include_paths: Optional[List[str]] = None, + output_path: Optional[pathlib.Path] = None, + types: Optional[List[str]] = None, + typesupports: Optional[List[str]] = None +) -> List[List[str]]: """ Generate source code from interface definition files. @@ -60,7 +62,7 @@ def generate( :returns: list of lists of paths to generated source code files, one group per type or type support extension invoked """ - extensions = [] + extensions: List[GenerateCommandExtension] = [] unspecific_generation = not types and not typesupports diff --git a/rosidl_cli/rosidl_cli/command/generate/extensions.py b/rosidl_cli/rosidl_cli/command/generate/extensions.py index fde1fb6b0..a89630d71 100644 --- a/rosidl_cli/rosidl_cli/command/generate/extensions.py +++ b/rosidl_cli/rosidl_cli/command/generate/extensions.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +from typing import cast, List, Optional + from rosidl_cli.extensions import Extension from rosidl_cli.extensions import load_extensions @@ -26,11 +29,11 @@ class GenerateCommandExtension(Extension): def generate( self, - package_name, - interface_files, - include_paths, - output_path - ): + package_name: str, + interface_files: List[str], + include_paths: List[str], + output_path: Path + ) -> List[str]: """ Generate source code. @@ -48,11 +51,17 @@ def generate( raise NotImplementedError() -def load_type_extensions(**kwargs): +def load_type_extensions(*, specs: Optional[List[str]], + strict: bool) -> List[GenerateCommandExtension]: """Load extensions for type representation source code generation.""" - return load_extensions('rosidl_cli.command.generate.type_extensions', **kwargs) + extensions = load_extensions('rosidl_cli.command.generate.type_extensions', specs=specs, + strict=strict) + return cast(List[GenerateCommandExtension], extensions) -def load_typesupport_extensions(**kwargs): +def load_typesupport_extensions(*, specs: Optional[List[str]], strict: bool + ) -> List[GenerateCommandExtension]: """Load extensions for type support source code generation.""" - return load_extensions('rosidl_cli.command.generate.typesupport_extensions', **kwargs) + extensions = load_extensions('rosidl_cli.command.generate.typesupport_extensions', + specs=specs, strict=strict) + return cast(List[GenerateCommandExtension], extensions) diff --git a/rosidl_cli/rosidl_cli/command/helpers.py b/rosidl_cli/rosidl_cli/command/helpers.py index f23cc9a88..d81b0cdc3 100644 --- a/rosidl_cli/rosidl_cli/command/helpers.py +++ b/rosidl_cli/rosidl_cli/command/helpers.py @@ -17,9 +17,10 @@ import os import pathlib import tempfile +from typing import Generator, List, Tuple -def package_name_from_interface_file_path(path): +def package_name_from_interface_file_path(path: pathlib.Path) -> str: """ Derive ROS package name from a ROS interface definition file path. @@ -29,7 +30,7 @@ def package_name_from_interface_file_path(path): return pathlib.Path(os.path.abspath(path)).parents[1].name -def dependencies_from_include_paths(include_paths): +def dependencies_from_include_paths(include_paths: List[str]) -> List[str]: """ Collect dependencies' ROS interface definition files from include paths. @@ -45,7 +46,7 @@ def dependencies_from_include_paths(include_paths): }) -def interface_path_as_tuple(path): +def interface_path_as_tuple(path: str) -> Tuple[pathlib.Path, pathlib.Path]: """ Express interface definition file path as an (absolute prefix, relative path) tuple. @@ -61,18 +62,20 @@ def interface_path_as_tuple(path): """ path_as_string = str(path) if ':' not in path_as_string: - prefix = pathlib.Path.cwd() + prefix_path = pathlib.Path.cwd() else: prefix, _, path = path_as_string.rpartition(':') - prefix = pathlib.Path(os.path.abspath(prefix)) - path = pathlib.Path(path) - if path.is_absolute(): + prefix_path = pathlib.Path(os.path.abspath(prefix)) + path_as_path = pathlib.Path(path) + if path_as_path.is_absolute(): raise ValueError('Interface definition file path ' - f"'{path}' cannot be absolute") - return prefix, path + f"'{path_as_path}' cannot be absolute") + return prefix_path, path_as_path -def idl_tuples_from_interface_files(interface_files): +def idl_tuples_from_interface_files( + interface_files: List[str] +) -> List[str]: """ Express ROS interface definition file paths as IDL tuples. @@ -80,9 +83,9 @@ def idl_tuples_from_interface_files(interface_files): which to resolve it followed by a colon ':'. This function then applies the same logic as `interface_path_as_tuple`. """ - idl_tuples = [] - for path in interface_files: - prefix, path = interface_path_as_tuple(path) + idl_tuples: List[str] = [] + for interface_path in interface_files: + prefix, path = interface_path_as_tuple(interface_path) idl_tuples.append(f'{prefix}:{path.as_posix()}') return idl_tuples @@ -90,12 +93,12 @@ def idl_tuples_from_interface_files(interface_files): @contextlib.contextmanager def legacy_generator_arguments_file( *, - package_name, - interface_files, - include_paths, - templates_path, - output_path -): + package_name: str, + interface_files: List[str], + include_paths: List[str], + templates_path: str, + output_path: str +) -> Generator[str, None, None]: """ Generate a temporary rosidl generator arguments file. @@ -138,10 +141,10 @@ def legacy_generator_arguments_file( def generate_visibility_control_file( *, - package_name, - template_path, - output_path -): + package_name: str, + template_path: str, + output_path: str +) -> None: """ Generate a visibility control file from a template. diff --git a/rosidl_cli/rosidl_cli/command/translate/__init__.py b/rosidl_cli/rosidl_cli/command/translate/__init__.py index 03798db3e..f5b6368f0 100644 --- a/rosidl_cli/rosidl_cli/command/translate/__init__.py +++ b/rosidl_cli/rosidl_cli/command/translate/__init__.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import pathlib + from rosidl_cli.command import Command from .api import translate @@ -24,7 +26,7 @@ class TranslateCommand(Command): name = 'translate' - def add_arguments(self, parser): + def add_arguments(self, parser: argparse.ArgumentParser) -> None: parser.add_argument( '-o', '--output-path', metavar='PATH', type=pathlib.Path, default=None, @@ -64,7 +66,7 @@ def add_arguments(self, parser): 'path resolution is performed against such path.') ) - def main(self, *, args): + def main(self, *, args: argparse.Namespace) -> None: translate( package_name=args.package_name, interface_files=args.interface_files, diff --git a/rosidl_cli/rosidl_cli/command/translate/api.py b/rosidl_cli/rosidl_cli/command/translate/api.py index b63db278e..a7d18bd45 100644 --- a/rosidl_cli/rosidl_cli/command/translate/api.py +++ b/rosidl_cli/rosidl_cli/command/translate/api.py @@ -15,20 +15,21 @@ import collections import os import pathlib +from typing import DefaultDict, Dict, List, Optional, Union from .extensions import load_translate_extensions def translate( *, - package_name, - interface_files, - output_format, - input_format=None, - include_paths=None, - output_path=None, - translators=None -): + package_name: str, + interface_files: List[str], + output_format: str, + input_format: Optional[str] = None, + include_paths: Optional[List[str]] = None, + output_path: Optional[pathlib.Path] = None, + translators: Optional[List[str]] = None +) -> List[str]: """ Translate interface definition files from one format to another. @@ -64,7 +65,8 @@ def translate( raise RuntimeError('No translate extensions found') if not input_format: - interface_files_per_format = collections.defaultdict(list) + interface_files_per_format: Union[DefaultDict[str, List[str]], + Dict[str, List[str]]] = collections.defaultdict(list) for interface_file in interface_files: input_format = os.path.splitext(interface_file)[-1][1:] interface_files_per_format[input_format].append(interface_file) diff --git a/rosidl_cli/rosidl_cli/command/translate/extensions.py b/rosidl_cli/rosidl_cli/command/translate/extensions.py index f193a4804..10c3ba194 100644 --- a/rosidl_cli/rosidl_cli/command/translate/extensions.py +++ b/rosidl_cli/rosidl_cli/command/translate/extensions.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +from typing import cast, ClassVar, List, Optional from rosidl_cli.extensions import Extension from rosidl_cli.extensions import load_extensions @@ -28,13 +30,16 @@ class TranslateCommandExtension(Extension): * `translate` """ + input_format: ClassVar[str] + output_format: ClassVar[str] + def translate( self, - package_name, - interface_files, - include_paths, - output_path - ): + package_name: str, + interface_files: List[str], + include_paths: List[str], + output_path: Path + ) -> List[str]: """ Translate interface definition files. @@ -57,8 +62,10 @@ def translate( raise NotImplementedError() -def load_translate_extensions(**kwargs): +def load_translate_extensions(*, specs: Optional[List[str]], strict: bool + ) -> List[TranslateCommandExtension]: """Load extensions for interface definition translation.""" - return load_extensions( - 'rosidl_cli.command.translate.extensions', **kwargs + extensions = load_extensions( + 'rosidl_cli.command.translate.extensions', specs=specs, strict=strict ) + return cast(List[TranslateCommandExtension], extensions) diff --git a/rosidl_cli/rosidl_cli/common.py b/rosidl_cli/rosidl_cli/common.py index 1f94c2c63..c0e8c9d42 100644 --- a/rosidl_cli/rosidl_cli/common.py +++ b/rosidl_cli/rosidl_cli/common.py @@ -13,7 +13,7 @@ # limitations under the License. -def get_first_line_doc(any_type): +def get_first_line_doc(any_type: object) -> str: if any_type.__doc__: for line in any_type.__doc__.splitlines(): line = line.strip() diff --git a/rosidl_cli/rosidl_cli/entry_points.py b/rosidl_cli/rosidl_cli/entry_points.py index edab63729..6a79673f9 100644 --- a/rosidl_cli/rosidl_cli/entry_points.py +++ b/rosidl_cli/rosidl_cli/entry_points.py @@ -12,38 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.metadata as importlib_metadata import logging - -try: - import importlib.metadata as importlib_metadata -except ModuleNotFoundError: - import importlib_metadata +import sys +from typing import Any, Dict, List, Optional, Tuple, Union logger = logging.getLogger(__name__) -def get_entry_points(group_name, *, specs=None, strict=False): +def get_entry_points(group_name: str, *, specs: Optional[List[str]] = None, strict: bool = False + ) -> Dict[str, importlib_metadata.EntryPoint]: """ Get entry points from a specific group. - :param str group_name: the name of the entry point group - :param list specs: an optional collection of entry point names to retrieve - :param bool strict: whether to raise or warn on error + :param group_name: the name of the entry point group + :param specs: an optional collection of entry point names to retrieve + :param strict: whether to raise or warn on error :returns: mapping from entry point names to ``EntryPoint`` instances - :rtype: dict """ if specs is not None: - specs = set(specs) + specs_set = set(specs) + else: + specs_set = None entry_points_impl = importlib_metadata.entry_points() - if hasattr(entry_points_impl, 'select'): + # Select does not exist until python 3.10 + if sys.version_info >= (3, 10): groups = entry_points_impl.select(group=group_name) else: - groups = entry_points_impl.get(group_name, []) - entry_points = {} + groups: Union[Tuple[importlib_metadata.EntryPoint, ...], + List[importlib_metadata.EntryPoint]] = entry_points_impl.get(group_name, []) + + entry_points: Dict[str, importlib_metadata.EntryPoint] = {} for entry_point in groups: name = entry_point.name - if specs and name not in specs: + if specs_set and name not in specs_set: continue if name in entry_points: msg = (f"Found duplicate entry point '{name}': " @@ -53,8 +56,8 @@ def get_entry_points(group_name, *, specs=None, strict=False): logger.warning(msg) continue entry_points[name] = entry_point - if specs: - pending = specs - set(entry_points) + if specs_set: + pending = specs_set - set(entry_points) if pending: msg = 'Some specs could not be met: ' msg += ', '.join(map(str, pending)) @@ -64,21 +67,22 @@ def get_entry_points(group_name, *, specs=None, strict=False): return entry_points -def load_entry_points(group_name, *, strict=False, **kwargs): +def load_entry_points(group_name: str, *, specs: Optional[List[str]], + strict: bool = False, + ) -> Dict[str, Any]: """ Load entry points for a specific group. See :py:meth:`get_entry_points` for further reference on additional keyword arguments. - :param str group_name: the name of the entry point group - :param bool strict: whether to raise or warn on error + :param group_name: the name of the entry point group + :param strict: whether to raise or warn on error :returns: mapping from entry point name to loaded entry point - :rtype: dict """ - loaded_entry_points = {} + loaded_entry_points: Dict[str, Any] = {} for name, entry_point in get_entry_points( - group_name, strict=strict, **kwargs + group_name, strict=strict, specs=specs ).items(): try: loaded_entry_points[name] = entry_point.load() diff --git a/rosidl_cli/rosidl_cli/extensions.py b/rosidl_cli/rosidl_cli/extensions.py index 7cc31afa2..bbc28c9e7 100644 --- a/rosidl_cli/rosidl_cli/extensions.py +++ b/rosidl_cli/rosidl_cli/extensions.py @@ -14,10 +14,19 @@ import logging import re +from typing import Any, Dict, Final, List, Optional, Tuple, TYPE_CHECKING, Union from rosidl_cli.entry_points import load_entry_points -import yaml +import yaml # type: ignore[import] + +if TYPE_CHECKING: + from typing import TypedDict + from typing_extensions import NotRequired + + class LoadExtensionsArg(TypedDict): + specs: NotRequired[Optional[List[str]]] + strict: NotRequired[bool] logger = logging.getLogger(__name__) @@ -26,18 +35,18 @@ class Extension: """A generic extension point.""" - def __init__(self, name): + def __init__(self, name: str) -> None: self.__name = name @property - def name(self): + def name(self) -> str: return self.__name -SPECS_PATTERN = re.compile(r'^(\w+)(?:\[(.+)\])?$') +SPECS_PATTERN: Final = re.compile(r'^(\w+)(?:\[(.+)\])?$') -def parse_extension_specification(spec): +def parse_extension_specification(spec: str) -> Tuple[Union[str, Any], Union[Dict[Any, Any], Any]]: """ Parse extension specification. @@ -64,18 +73,18 @@ def parse_extension_specification(spec): return name, kwargs -def load_extensions(group_name, *, specs=None, strict=False): +def load_extensions(group_name: str, *, specs: Optional[List[str]] = None, + strict: bool = False) -> List[Extension]: """ Load extensions for a specific group. - :param str group_name: the name of the extension group - :param list specs: an optional collection of extension specs + :param group_name: the name of the extension group + :param specs: an optional collection of extension specs (see :py:meth:`parse_extension_specification` for spec format) - :param bool strict: whether to raise or warn on error + :param strict: whether to raise or warn on error :returns: a list of :py:class:`Extension` instances - :rtype: list """ - extensions = [] + extensions: List[Extension] = [] if specs is not None: kwargs = dict(map( diff --git a/rosidl_cli/test/rosidl_cli/test_common.py b/rosidl_cli/test/rosidl_cli/test_common.py index 166f8bf0f..ccc4a6cb7 100644 --- a/rosidl_cli/test/rosidl_cli/test_common.py +++ b/rosidl_cli/test/rosidl_cli/test_common.py @@ -15,20 +15,20 @@ from rosidl_cli.common import get_first_line_doc -def test_getting_first_line_from_no_docstring(): +def test_getting_first_line_from_no_docstring() -> None: func = test_getting_first_line_from_no_docstring line = get_first_line_doc(func) assert line == '' -def test_getting_first_line_from_docstring(): +def test_getting_first_line_from_docstring() -> None: """Check it gets the first line.""" func = test_getting_first_line_from_docstring line = get_first_line_doc(func) assert line == 'Check it gets the first line' -def test_getting_first_line_from_multiline_docstring(): +def test_getting_first_line_from_multiline_docstring() -> None: """ Check it really gets the first non-empty line. diff --git a/rosidl_cli/test/rosidl_cli/test_extensions.py b/rosidl_cli/test/rosidl_cli/test_extensions.py index 7e3dd2158..94af493af 100644 --- a/rosidl_cli/test/rosidl_cli/test_extensions.py +++ b/rosidl_cli/test/rosidl_cli/test_extensions.py @@ -17,7 +17,7 @@ from rosidl_cli.extensions import parse_extension_specification -def test_extension_specification_parsing(): +def test_extension_specification_parsing() -> None: with pytest.raises(ValueError): parse_extension_specification('bad[') diff --git a/rosidl_cli/test/rosidl_cli/test_helpers.py b/rosidl_cli/test/rosidl_cli/test_helpers.py index 5a16f684f..bba66d4e9 100644 --- a/rosidl_cli/test/rosidl_cli/test_helpers.py +++ b/rosidl_cli/test/rosidl_cli/test_helpers.py @@ -15,6 +15,7 @@ import json import os import pathlib +from typing import Iterable import pytest @@ -22,7 +23,7 @@ from rosidl_cli.command.helpers import legacy_generator_arguments_file -def test_interface_path_as_tuple(): +def test_interface_path_as_tuple() -> None: prefix, path = interface_path_as_tuple('/tmp:msg/Empty.idl') assert pathlib.Path('msg/Empty.idl') == path assert pathlib.Path(os.path.abspath('/tmp')) == prefix @@ -37,7 +38,7 @@ def test_interface_path_as_tuple(): @pytest.fixture -def current_path(request): +def current_path(request: pytest.FixtureRequest) -> Iterable[pathlib.Path]: path = pathlib.Path(request.module.__file__) path = path.resolve() path = path.parent @@ -49,7 +50,7 @@ def current_path(request): os.chdir(str(cwd)) -def test_legacy_generator_arguments_file(current_path): +def test_legacy_generator_arguments_file(current_path: pathlib.Path) -> None: with legacy_generator_arguments_file( package_name='foo', interface_files=['msg/Foo.idl'], diff --git a/rosidl_cli/test/test_copyright.py b/rosidl_cli/test/test_copyright.py index cf0fae31f..66a7d63eb 100644 --- a/rosidl_cli/test/test_copyright.py +++ b/rosidl_cli/test/test_copyright.py @@ -18,6 +18,6 @@ @pytest.mark.copyright @pytest.mark.linter -def test_copyright(): +def test_copyright() -> None: rc = main(argv=['.', 'test']) assert rc == 0, 'Found errors' diff --git a/rosidl_cli/test/test_flake8.py b/rosidl_cli/test/test_flake8.py index 27ee1078f..eac16eef9 100644 --- a/rosidl_cli/test/test_flake8.py +++ b/rosidl_cli/test/test_flake8.py @@ -18,7 +18,7 @@ @pytest.mark.flake8 @pytest.mark.linter -def test_flake8(): +def test_flake8() -> None: rc, errors = main_with_errors(argv=[]) assert rc == 0, \ 'Found %d code style errors / warnings:\n' % len(errors) + \ diff --git a/rosidl_cli/test/test_mypy.py b/rosidl_cli/test/test_mypy.py new file mode 100644 index 000000000..97e4f502a --- /dev/null +++ b/rosidl_cli/test/test_mypy.py @@ -0,0 +1,23 @@ +# Copyright 2024 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_mypy.main import main +import pytest + + +@pytest.mark.mypy +@pytest.mark.linter +def test_mypy() -> None: + rc = main(argv=[]) + assert rc == 0, 'Found type errors!' diff --git a/rosidl_cli/test/test_pep257.py b/rosidl_cli/test/test_pep257.py index 0e38a6c60..4ae521a5a 100644 --- a/rosidl_cli/test/test_pep257.py +++ b/rosidl_cli/test/test_pep257.py @@ -18,6 +18,6 @@ @pytest.mark.linter @pytest.mark.pep257 -def test_pep257(): +def test_pep257() -> None: rc = main(argv=[]) assert rc == 0, 'Found code style errors / warnings' diff --git a/rosidl_cli/test/test_xmllint.py b/rosidl_cli/test/test_xmllint.py index f46285e71..08bf7fd78 100644 --- a/rosidl_cli/test/test_xmllint.py +++ b/rosidl_cli/test/test_xmllint.py @@ -18,6 +18,6 @@ @pytest.mark.linter @pytest.mark.xmllint -def test_xmllint(): +def test_xmllint() -> None: rc = main(argv=[]) assert rc == 0, 'Found errors' diff --git a/rosidl_pycommon/package.xml b/rosidl_pycommon/package.xml index d63cb3751..83a9cbd11 100644 --- a/rosidl_pycommon/package.xml +++ b/rosidl_pycommon/package.xml @@ -21,6 +21,7 @@ ament_flake8 ament_pep257 ament_mypy + ament_xmllint python3-pytest diff --git a/rosidl_pycommon/test/test_xmllint.py b/rosidl_pycommon/test/test_xmllint.py new file mode 100644 index 000000000..08bf7fd78 --- /dev/null +++ b/rosidl_pycommon/test/test_xmllint.py @@ -0,0 +1,23 @@ +# Copyright 2019 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_xmllint.main import main +import pytest + + +@pytest.mark.linter +@pytest.mark.xmllint +def test_xmllint() -> None: + rc = main(argv=[]) + assert rc == 0, 'Found errors'