Skip to content

Commit

Permalink
[JVM] Fix mypy error for frontend-jvm (#1925)
Browse files Browse the repository at this point in the history
* [JVM]: Fix mypy error for frontend-jvm

Signed-off-by: Arthur Chan <arthur.chan@adalogics.com>

* Fix main.py

Signed-off-by: Arthur Chan <arthur.chan@adalogics.com>

* Fix formatting

Signed-off-by: Arthur Chan <arthur.chan@adalogics.com>

---------

Signed-off-by: Arthur Chan <arthur.chan@adalogics.com>
  • Loading branch information
arthurscchan authored Dec 27, 2024
1 parent 44a3abd commit c19b7d9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 45 deletions.
94 changes: 49 additions & 45 deletions src/fuzz_introspector/frontends/frontend_jvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __init__(self,

# List of definitions in the source file.
self.package = ''
self.classes = []
self.imports = {}
self.classes: list['JavaClassInterface'] = []
self.imports: dict[str, str] = {}

# Initialization ruotines
self.load_tree()
Expand Down Expand Up @@ -233,34 +233,35 @@ def __init__(self,
self.root = root
self.class_interface = class_interface
self.tree_sitter_lang = self.class_interface.tree_sitter_lang
self.parent_source = self.class_interface.parent_source
self.parent_source: Optional[
SourceCodeFile] = self.class_interface.parent_source
self.is_constructor = is_constructor

# Store method line information
self.start_line = self.root.start_point.row + 1
self.end_line = self.root.end_point.row + 1

# Other properties
self.name = ''
self.name: str = ''
self.complexity = 0
self.icount = 0
self.arg_names = []
self.arg_types = []
self.exceptions = []
self.arg_names: list[str] = []
self.arg_types: list[str] = []
self.exceptions: list[str] = []
self.return_type = ''
self.sig = ''
self.function_uses = 0
self.function_depth = 0
self.base_callsites = []
self.detailed_callsites = []
self.base_callsites: list[tuple[str, int]] = []
self.detailed_callsites: list[dict[str, str]] = []
self.public = False
self.concrete = True
self.static = False
self.is_entry_method = False

# Other properties
self.stmts = []
self.var_map = {}
self.stmts: list[Node] = []
self.var_map: dict[str, str] = {}

# Process method declaration
self._process_declaration()
Expand Down Expand Up @@ -447,7 +448,7 @@ def _process_invoke_object(
self, stmt: Node, classes: dict[str, 'JavaClassInterface']
) -> tuple[str, list[tuple[str, int, int]]]:
"""Internal helper for processing the object from a invocation."""
callsites = []
callsites: list[tuple[str, int, int]] = []
return_value = ''
# Determine the type of the object
if stmt.child_count == 0:
Expand All @@ -465,7 +466,7 @@ def _process_invoke_object(
if not return_value:
return_value = self.class_interface.class_fields.get(
stmt.text.decode(), '')
if not return_value:
if not return_value and self.parent_source:
return_value = self.parent_source.imports.get(
stmt.text.decode(), self.class_interface.name)
else:
Expand Down Expand Up @@ -501,7 +502,7 @@ def _process_invoke_object(
# Casting expression in Parenthesized statement
elif stmt.type == 'parenthesized_expression':
for cast in stmt.children:
if cast.type == 'cast_expression':
if cast.type == 'cast_expression' and self.parent_source:
value = cast.child_by_field_name('value')
cast_type = cast.child_by_field_name(
'type').text.decode()
Expand Down Expand Up @@ -574,7 +575,7 @@ def _process_invoke_args(
return_values.append(return_value)

# Type casting expression
elif argument.type == 'cast_expression':
elif argument.type == 'cast_expression' and self.parent_source:
value = argument.child_by_field_name('value')
cast_type = argument.child_by_field_name('type').text.decode()
return_value = self.parent_source.get_full_qualified_name(
Expand All @@ -600,7 +601,7 @@ def _process_invoke(
expr: Node,
classes: dict[str, 'JavaClassInterface'],
is_constructor_call: bool = False
) -> tuple[list[str], list[tuple[str, int, int]]]:
) -> tuple[str, list[tuple[str, int, int]]]:
"""Internal helper for processing the method invocation statement."""
callsites = []

Expand All @@ -619,7 +620,7 @@ def _process_invoke(
argument_types = []

# Process constructor call
if is_constructor_call:
if is_constructor_call and self.parent_source:
object_type = ''
for cls_type in expr.children:
if cls_type.type == 'this':
Expand Down Expand Up @@ -717,31 +718,29 @@ class JavaClassInterface():

def __init__(self,
root: Node,
tree_sitter_lang: Optional[Language] = None,
source_code: Optional[SourceCodeFile] = None,
tree_sitter_lang: Language,
source_code: SourceCodeFile,
parent: Optional['JavaClassInterface'] = None):
self.root = root
self.parent = parent
self.tree_sitter_lang = tree_sitter_lang
self.parent_source = source_code

if parent:
self.tree_sitter_lang = parent.tree_sitter_lang
self.parent_source = parent.parent_source
if self.parent:
self.package = self.parent.name
else:
self.tree_sitter_lang = tree_sitter_lang
self.parent_source = source_code
self.package = self.parent_source.package

# Properties
self.name = ''
self.name: str = ''
self.class_public = False
self.class_concrete = True
self.is_interface = False
self.methods = []
self.inner_classes = []
self.class_fields = {}
self.methods: list[JavaMethod] = []
self.inner_classes: list[JavaClassInterface] = []
self.class_fields: dict[str, str] = {}
self.super_class = 'Object'
self.super_interfaces = []
self.super_interfaces: list[str] = []

# Process the class/interface tree
inner_class_nodes = self._process_node()
Expand Down Expand Up @@ -846,7 +845,8 @@ def _process_inner_classes(self, inner_class_nodes: list[Node]):
"""Internal helper to recursively process inner classes"""
for node in inner_class_nodes:
self.inner_classes.append(
JavaClassInterface(node, None, None, self))
JavaClassInterface(node, self.tree_sitter_lang,
self.parent_source, self))

def get_all_methods(self) -> list[JavaMethod]:
all_methods = self.methods
Expand Down Expand Up @@ -903,11 +903,11 @@ def dump_module_logic(self,
harness_name: Optional[str] = None):
"""Dumps the data for the module in full."""
logger.info('Dumping project-wide logic.')
report = {'report': 'name'}
report['sources']: dict[str, Any] = []
report: dict[str, Any] = {'report': 'name'}
report['sources'] = []

all_classes = {}
project_methods = []
project_methods: list[JavaMethod] = []

# Post process source code files with full qualified names
# Retrieve full project methods, classes and information
Expand Down Expand Up @@ -941,7 +941,7 @@ def dump_module_logic(self,
# Process all project methods
method_list = []
for method in project_methods:
method_dict = {}
method_dict: dict[str, Any] = {}
method_dict['functionName'] = method.name
method_dict['functionSourceFile'] = method.class_interface.name
method_dict['functionLinenumber'] = method.start_line
Expand Down Expand Up @@ -974,7 +974,7 @@ def dump_module_logic(self,
method_dict['functionsReached'] = list(reached)

# Handles Java method properties
java_method_info = {}
java_method_info: dict[str, Any] = {}
java_method_info['exceptions'] = method.exceptions
java_method_info[
'interfaces'] = method.class_interface.super_interfaces[:]
Expand Down Expand Up @@ -1058,7 +1058,7 @@ def _recursive_method_depth(method: JavaMethod) -> int:

return depth

visited = []
visited: list[str] = []
method_dict = {method.name: method for method in all_methods}
method_depth = _recursive_method_depth(target_method)

Expand All @@ -1067,37 +1067,40 @@ def _recursive_method_depth(method: JavaMethod) -> int:
def extract_calltree(self,
source_file: str,
source_code: Optional[SourceCodeFile] = None,
method: str = None,
visited_methods: set[str] = None,
method: Optional[str] = None,
visited_methods: Optional[set[str]] = None,
depth: int = 0,
line_number: int = -1) -> str:
"""Extracts calltree string of a calltree so that FI core can use it."""
if not visited_methods:
visited_methods = set()

if not method:
if not source_code and method:
source_code = self.find_source_with_method(method)

if not method and source_code:
method = source_code.get_entry_method_name(True)

if not method or not source_code:
return ''

line_to_print = ' ' * depth
line_to_print += method
line_to_print += ' '
line_to_print += source_file

if not source_code:
source_code = self.find_source_with_method(method)

line_to_print += ' '
line_to_print += str(line_number)

line_to_print += '\n'
if not source_code:
return line_to_print

method = source_code.get_method_node(method)
if not method:
method_node = source_code.get_method_node(method)
if not method_node:
return line_to_print

callsites = method.base_callsites
callsites = method_node.base_callsites

if method in visited_methods:
return line_to_print
Expand All @@ -1110,6 +1113,7 @@ def extract_calltree(self,
visited_methods=visited_methods,
depth=depth + 1,
line_number=line_number)

return line_to_print


Expand Down
1 change: 1 addition & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

def main() -> int:
cli.main()
return 0


if __name__ == "__main__":
Expand Down

0 comments on commit c19b7d9

Please sign in to comment.