diff --git a/src/fuzz_introspector/frontends/core.py b/src/fuzz_introspector/frontends/core.py index 31cd3ad88..277d99a34 100644 --- a/src/fuzz_introspector/frontends/core.py +++ b/src/fuzz_introspector/frontends/core.py @@ -17,11 +17,12 @@ from fuzz_introspector.frontends import (frontend_c, frontend_cpp, frontend_go, frontend_jvm, frontend_rust) +from fuzz_introspector.frontends.datatypes import SourceCodeFile logger = logging.getLogger(name=__name__) -def analyse_source_file(code: bytes, language: str): +def analyse_source_file(code: bytes, language: str) -> SourceCodeFile: """Runs frontend analysis on a code snippet. The code snippet should correspond to what you'd normally find in @@ -32,14 +33,18 @@ def analyse_source_file(code: bytes, language: str): if language == 'c': return frontend_c.analyse_source_code(code) - elif language == 'cpp': + + if language == 'cpp': return frontend_cpp.analyse_source_code(code) - elif language == 'go': + + if language == 'go': return frontend_go.analyse_source_code(code) - elif language == 'jvm': + + if language == 'jvm': return frontend_jvm.analyse_source_code(code) - elif language == 'rust': + + if language == 'rust': return frontend_rust.analyse_source_code(code) - else: - logger.info('Language %s not supported', language) + + logger.info('Language %s not supported', language) return None diff --git a/src/fuzz_introspector/frontends/frontend_jvm.py b/src/fuzz_introspector/frontends/frontend_jvm.py index 637c33895..8607c19de 100644 --- a/src/fuzz_introspector/frontends/frontend_jvm.py +++ b/src/fuzz_introspector/frontends/frontend_jvm.py @@ -71,7 +71,7 @@ class JvmSourceCodeFile(SourceCodeFile): """Class for holding file-specific information.""" - def language_specific_process(self): + def language_specific_process(self) -> None: """Perform some language specific processes in subclasses.""" # List of definitions in the source file. self.package = '' @@ -110,7 +110,7 @@ def _set_package_declaration(self): def _set_class_interface_declaration(self): """Internal helper for retrieving all classes.""" for node in self.root.children: - if node.type == 'class_declaration' or node.type == 'interface_declaration': + if node.type in ['class_declaration', 'interface_declaration']: self.classes.append( JavaClassInterface(node, self.tree_sitter_lang, self)) @@ -280,7 +280,8 @@ def post_process_full_qualified_name(self): class_name = self.parent_source.get_full_qualified_name( self.class_interface.name) if '[' not in self.name and '].' not in self.name: - self.name = f'[{class_name}].{self.name}({",".join(self.arg_types)})' + self.name = (f'[{class_name}].{self.name}' + f'({",".join(self.arg_types)})') # Refine variable map for key in self.var_map: @@ -343,7 +344,7 @@ def _process_declaration(self): self.return_type = child.text.decode() # Process body and store statment nodes - elif child.type == 'block' or child.type == 'constructor_body': + elif child.type in ['block', 'constructor_body']: for stmt in child.children: if stmt.type not in ['{', '}' ] and 'comment' not in stmt.type: @@ -451,24 +452,24 @@ def _process_invoke_object( # Variable call or static call else: - return_value = self.var_map.get(stmt.text.decode(), '') + var_name = stmt.text.decode() if stmt.text else '' + return_value = self.var_map.get(var_name, '') if not return_value: return_value = self.class_interface.class_fields.get( - stmt.text.decode(), '') + var_name, '') if not return_value and self.parent_source: - return_value = self.parent_source.imports.get( - stmt.text.decode(), '') + return_value = self.parent_source.imports.get(var_name, '') else: # Field access if stmt.type == 'field_access': - object = stmt.child_by_field_name('object') + obj = stmt.child_by_field_name('object') field = stmt.child_by_field_name('field') - if object and field: + if obj and field: object_class, callsites = self._process_invoke_object( - object, classes) + obj, classes) cls = classes.get(object_class) - if cls: + if cls and field.text: return_value = cls.class_fields.get( field.text.decode(), self.class_interface.name) @@ -493,19 +494,22 @@ def _process_invoke_object( for cast in stmt.children: if cast.type == 'cast_expression' and self.parent_source: value = cast.child_by_field_name('value') - cast_type = cast.child_by_field_name( - 'type').text.decode() - return_value = self.parent_source.get_full_qualified_name( - cast_type) - if value and value.type == 'method_invocation': + cast_type = cast.child_by_field_name('type') + if not value or not cast_type or not cast_type.text: + continue + return_value = ( + self.parent_source.get_full_qualified_name( + cast_type.text.decode())) + + if value.type == 'method_invocation': _, invoke_callsites = self._process_invoke( value, classes) callsites.extend(invoke_callsites) - if value and value.type == 'object_creation_expression': + elif value.type == 'object_creation_expression': _, invoke_callsites = self._process_invoke( value, classes, True) callsites.extend(invoke_callsites) - if value and value.type == 'explicit_constructor_invocation': + elif value.type == 'explicit_constructor_invocation': _, invoke_callsites = self._process_invoke( value, classes, True) callsites.extend(invoke_callsites) @@ -550,10 +554,11 @@ def _process_invoke_args( # Variables elif argument.type == 'identifier': - return_value = self.var_map.get(argument.text.decode(), '') + arg_name = argument.text.decode() if argument.text else '' + return_value = self.var_map.get(arg_name, '') if not return_value: return_value = self.class_interface.class_fields.get( - argument.text.decode(), self.class_interface.name) + arg_name, self.class_interface.name) return_values.append(return_value) # Method invocation @@ -577,14 +582,14 @@ def _process_invoke_args( # Field or static variable access elif argument.type == 'field_access': - object = argument.child_by_field_name('object') + obj = argument.child_by_field_name('object') field = argument.child_by_field_name('field') - if object and field: + if obj and field: object_class, callsites = self._process_invoke_object( - object, classes) + obj, classes) cls = classes.get(object_class) - if cls: + if cls and field.text: return_value = cls.class_fields.get( field.text.decode(), self.class_interface.name) return_values.append(return_value) @@ -592,17 +597,21 @@ def _process_invoke_args( # Type casting expression elif argument.type == 'cast_expression' and self.parent_source: value = argument.child_by_field_name('value') - cast_type = argument.child_by_field_name('type').text.decode() + cast_type = argument.child_by_field_name('type') + if not value or not cast_type or not cast_type.text: + continue + return_value = self.parent_source.get_full_qualified_name( - cast_type) - if value and value.type == 'method_invocation': + cast_type.text.decode()) + + if value.type == 'method_invocation': _, invoke_callsites = self._process_invoke(value, classes) callsites.extend(invoke_callsites) - if value and value.type == 'object_creation_expression': + elif value.type == 'object_creation_expression': _, invoke_callsites = self._process_invoke( value, classes, True) callsites.extend(invoke_callsites) - if value and value.type == 'explicit_constructor_invocation': + elif value.type == 'explicit_constructor_invocation': _, invoke_callsites = self._process_invoke( value, classes, True) callsites.extend(invoke_callsites) @@ -646,7 +655,8 @@ def _process_invoke( elif cls_type.type.endswith( 'type_identifier') or cls_type.type.endswith('_type'): - object_type = cls_type.text.decode().split('<')[0] + cls_name = cls_type.text.decode() if cls_type.text else '' + object_type = cls_name.split('<')[0] object_type = self.parent_source.get_full_qualified_name( object_type) @@ -673,33 +683,36 @@ def _process_invoke( # Process this method invocation target_name = '' - if object_type and name: + if object_type and name and name.text: for cls in classes.values(): packaged_type = cls.add_package_to_class_name(object_type) if packaged_type: object_type = packaged_type break - target_name = f'[{object_type}].{name.text.decode()}({",".join(argument_types)})' + target_name = (f'[{object_type}].{name.text.decode()}' + f'({",".join(argument_types)})') callsites.append( (target_name, expr.byte_range[1], expr.start_point.row + 1)) # Calling to library outside of project # Preserve the full method call - elif name: - if objects: + elif name and name.text: + if objects and objects.text: target_name = (f'{objects.text.decode()}.{name.text.decode()}' f'({",".join(argument_types)})') else: - target_name = f'{name.text.decode()}({",".join(argument_types)})' + target_name = (f'{name.text.decode()}' + f'({",".join(argument_types)})') callsites.append( (target_name, expr.byte_range[1], expr.start_point.row + 1)) # Determine return value from method invocation if object_type == 'com.code_intelligence.jazzer.api.FuzzedDataProvider': - return_type = FUZZING_METHOD_RETURN_TYPE_MAP.get( - name.text.decode(), '') + if name and name.text: + return_type = FUZZING_METHOD_RETURN_TYPE_MAP.get( + name.text.decode(), '') else: return_type = self.class_interface.name if object_type in classes and target_name: @@ -719,52 +732,63 @@ def _process_callsites( self, stmt: Node, classes: dict[str, 'JavaClassInterface'] ) -> tuple[str, list[tuple[str, int, int]]]: """Process and store the callsites of the method.""" - type = '' + type_str = '' callsites: list[tuple[str, int, int]] = [] if not stmt: - return type, callsites + return type_str, callsites if stmt.type == 'method_invocation': - type, invoke_callsites = self._process_invoke(stmt, classes) + type_str, invoke_callsites = self._process_invoke(stmt, classes) callsites.extend(invoke_callsites) elif stmt.type == 'object_creation_expression': - type, invoke_callsites = self._process_invoke(stmt, classes, True) + type_str, invoke_callsites = self._process_invoke( + stmt, classes, True) callsites.extend(invoke_callsites) elif stmt.type == 'explicit_constructor_invocation': - type, invoke_callsites = self._process_invoke(stmt, classes, True) + type_str, invoke_callsites = self._process_invoke( + stmt, classes, True) callsites.extend(invoke_callsites) elif stmt.type == 'assignment_expression': left = stmt.child_by_field_name('left') right = stmt.child_by_field_name('right') + if not left or not left.text or not right: + return type_str, callsites var_name = left.text.decode().split(' ')[-1] - type, invoke_callsites = self._process_callsites(right, classes) - self.var_map[var_name] = type + type_str, invoke_callsites = self._process_callsites( + right, classes) + self.var_map[var_name] = type_str callsites.extend(invoke_callsites) elif stmt.type.endswith('local_variable_declarattion'): - for vars in stmt.children: - if vars.type == 'variable_declarator': - var_name = vars.child_by_field_name('name').text.decode() - value_node = vars.child_by_field_name('value') - - type, invoke_callsites = self._process_callsites( + for var_del in stmt.children: + if var_del.type == 'variable_declarator': + name_node = var_del.child_by_field_name('name') + value_node = var_del.child_by_field_name('value') + if not name_node or not name_node.text or not value_node: + continue + + var_name = name_node.text.decode() + type_str, invoke_callsites = self._process_callsites( value_node, classes) - self.var_map[var_name] = type + self.var_map[var_name] = type_str callsites.extend(invoke_callsites) elif stmt.type.endswith('variable_declarator'): - var_name = stmt.child_by_field_name('name').text.decode() + name_node = stmt.child_by_field_name('name') value_node = stmt.child_by_field_name('value') + if not name_node or not name_node.text or not value_node: + return type_str, callsites - type, invoke_callsites = self._process_callsites( + var_name = name_node.text.decode() + type_str, invoke_callsites = self._process_callsites( value_node, classes) - self.var_map[var_name] = type + self.var_map[var_name] = type_str callsites.extend(invoke_callsites) else: for child in stmt.children: callsites.extend(self._process_callsites(child, classes)[1]) - return type, callsites + return type_str, callsites def extract_callsites(self, classes: dict[str, 'JavaClassInterface']): """Extract callsites.""" @@ -781,7 +805,7 @@ def extract_callsites(self, classes: dict[str, 'JavaClassInterface']): if not self.detailed_callsites: for dst, src_line in self.base_callsites: - src_loc = self.class_interface.name + ':%d,1' % (src_line) + src_loc = f'{self.class_interface.name}:{src_line},1' self.detailed_callsites.append({'Src': src_loc, 'Dst': dst}) @@ -866,25 +890,29 @@ def _process_node(self) -> list[Node]: # Process super class if child.type == 'superclass': for cls in child.children: - if cls.type.endswith('type_identifier'): + if cls.type.endswith('type_identifier') and cls.text: self.super_class = cls.text.decode() # Process super interfaces elif child.type == 'super_interfaces': for interfaces in child.children: - if interfaces.type == 'type_list': - type_set = set() - for interface in interfaces.children: - if interface.type.endswith('type_identifier'): - type_set.add(interface.text.decode()) - self.super_interfaces = list(type_set) + if interfaces.type != 'type_list': + continue + + type_set = set() + for interface in interfaces.children: + if (interface.type.endswith('type_identifier') + and interface.text): + type_set.add(interface.text.decode()) + self.super_interfaces = list(type_set) # Process modifiers elif child.type == 'modifiers': for modifier in child.children: - if modifier.text.decode() == 'public': + modi_txt = modifier.text.decode() if modifier.text else '' + if modi_txt == 'public': self.class_public = True - if modifier.text.decode() == 'abstract': + if modi_txt == 'abstract': self.class_concrete = False # Process modifiers for interface @@ -894,12 +922,12 @@ def _process_node(self) -> list[Node]: # Process name elif child.type == 'identifier': - self.name = child.text.decode() + self.name = child.text.decode() if child.text else '' if self.package: self.name = f'{self.package}.{self.name}' # Process body - elif child.type == 'class_body' or child.type == 'interface_body': + elif child.type in ['class_body', 'interface_body']: for body in child.children: # Process constructors if body.type == 'constructor_declaration': @@ -912,20 +940,28 @@ def _process_node(self) -> list[Node]: # Process class fields elif body.type == 'field_declaration': field_name = None - field_type = body.child_by_field_name( - 'type').text.decode() - for fields in body.children: + type_node = body.child_by_field_name('type') + if not type_node or not type_node.text: + continue + + field_type = type_node.text.decode() + fields = [ + field for field in body.children + if field.type == 'variable_declarator' + ] + for field in fields: # Process field_name - if fields.type == 'variable_declarator': - self.constructor_callsites.append(fields) - field_name = fields.child_by_field_name( - 'name').text.decode() + self.constructor_callsites.append(field) + name_node = field.child_by_field_name('name') - if field_name and field_type: + if name_node and name_node.text and field_type: + field_name = name_node.text.decode() self.class_fields[field_name] = field_type # Process inner classes or interfaces - elif body.type == 'class_declaration' or body.type == 'interface_declaration': + elif body.type in [ + 'class_declaration', 'interface_declaration' + ]: inner_class_nodes.append(body) return inner_class_nodes @@ -954,7 +990,8 @@ def get_all_methods(self) -> list[JavaMethod]: def get_entry_method_name(self) -> Optional[str]: """Get the entry method name for this class. - It can be the provided entrypoint of method with @FuzzTest annotation.""" + It can be the provided entrypoint or method with + @FuzzTest annotation.""" for method in self.get_all_methods(): if method.is_entry_method: return method.name @@ -1209,13 +1246,13 @@ def extract_calltree(self, return line_to_print visited_functions.add(function) - for cs, line_number in callsites: + for cs, line in callsites: line_to_print += self.extract_calltree( source_code.source_file, function=cs, visited_functions=visited_functions, depth=depth + 1, - line_number=line_number) + line_number=line) return line_to_print @@ -1264,7 +1301,8 @@ def get_reachable_functions( def load_treesitter_trees(source_files: list[str], entrypoint: str, is_log: bool = True) -> JvmProject: - """Creates treesitter trees for all files in a given list of source files.""" + """Creates treesitter trees for all files in a given list of + source files.""" results = [] for code_file in source_files: @@ -1277,8 +1315,7 @@ def load_treesitter_trees(source_files: list[str], return JvmProject(results) -def analyse_source_code(source_content: str, - entrypoint: str) -> JvmSourceCodeFile: +def analyse_source_code(source_content: str) -> JvmSourceCodeFile: """Returns a source abstraction based on a single source string.""" source_code = JvmSourceCodeFile('jvm', source_file='in-memory string', diff --git a/src/fuzz_introspector/frontends/frontend_rust.py b/src/fuzz_introspector/frontends/frontend_rust.py index 6b7c0104a..4415590ec 100644 --- a/src/fuzz_introspector/frontends/frontend_rust.py +++ b/src/fuzz_introspector/frontends/frontend_rust.py @@ -866,10 +866,8 @@ def load_treesitter_trees(source_files: list[str], return RustProject(results) -def analyse_source_code(source_content: str, - entrypoint: str) -> RustSourceCodeFile: +def analyse_source_code(source_content: str) -> RustSourceCodeFile: """Returns a source abstraction based on a single source string.""" - # pylint: disable=unused-argument source_code = RustSourceCodeFile('rust', source_file='in-memory string', source_content=source_content.encode())