Skip to content

Commit

Permalink
frontends: light: capture typedefs and extract rows using TS (#1860)
Browse files Browse the repository at this point in the history
Signed-off-by: David Korczynski <david@adalogics.com>
  • Loading branch information
DavidKorczynski authored Nov 30, 2024
1 parent 85abfb1 commit bad9c03
Showing 1 changed file with 47 additions and 16 deletions.
63 changes: 47 additions & 16 deletions frontends/light/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def dump_module_logic(self, report_name):
'function_names':
source_code.get_defined_function_names(),
'types': {
'structs': source_code.struct_defs
'structs': source_code.struct_defs,
'typedefs': source_code.typedefs
}
})

Expand All @@ -70,16 +71,14 @@ def dump_module_logic(self, report_name):
func_dict['functionName'] = func_def.name()
#func_dict['source_file'] = source_code.source_file
func_dict['functionSourceFile'] = source_code.source_file
func_dict['functionLinenumber'] = source_code.get_linenumber(
func_def.position()[0])
func_dict[
'functionLinbernumberEnd'] = source_code.get_linenumber(
func_def.position()[1])
'functionLinenumber'] = source_code.root.start_point.row
func_dict[
'functionLinbernumberEnd'] = source_code.root.end_point.row
func_dict['linkageType'] = ''
func_dict['func_position'] = {
'start':
source_code.get_linenumber(func_def.position()[0]),
'end': source_code.get_linenumber(func_def.position()[1])
'start': source_code.root.start_point.row,
'end': source_code.root.end_point.row,
}
func_dict[
'CyclomaticComplexity'] = func_def.get_function_complexity(
Expand Down Expand Up @@ -365,9 +364,7 @@ def detailed_callsites(self):
for call_expr in call_exprs:
for call_child in call_expr.children:
if call_child.type == 'identifier':
# TODO(David): fix remaining column value
src_line = self.parent_source.get_linenumber(
call_child.byte_range[0])
src_line = call_child.start_point.row
src_loc = self.parent_source.source_file + ':%d,1' % (
src_line)
callsites.append({
Expand Down Expand Up @@ -409,6 +406,7 @@ def __init__(self, source_file, language):
self.function_names = []
self.line_range_pairs = []
self.struct_defs = []
self.typedefs = []
self.includes = set()

# List of function definitions in the source file.
Expand Down Expand Up @@ -440,13 +438,10 @@ def extract_types(self):
continue
if struct.child_by_field_name('name') is None:
continue
print(struct.text.decode())
# Go through each of the field declarations
fields = []
for child in struct.child_by_field_name('body').children:
print("- child %s" % (child.type))
if child.type == 'field_declaration':
print(child.text.decode())
fields.append({
'type':
child.child_by_field_name('type').text.decode(),
Expand All @@ -458,9 +453,45 @@ def extract_types(self):
'name':
struct.child_by_field_name('name').text.decode(),
'fields':
fields
fields,
'pos': {
'line_start': struct.start_point.row,
'line_end': struct.end_point.row,
}
})

type_query = self.tree_sitter_lang.query('( type_definition ) @tp')
type_query_res = type_query.captures(self.root)
for _, types in type_query_res.items():
for typedef in types:
print(typedef.text.decode())
# Skip if this is an anonymous struct.
# TODO(David): handle this
if typedef.child_by_field_name('declarator') is None:
continue
typedef_struct = {
'name':
typedef.child_by_field_name('declarator').text.decode()
}

typedef_struct['pos'] = {
'line_start': typedef.start_point.row,
'line_end': typedef.end_point.row,
}
typedef_type = typedef.child_by_field_name('type')
if typedef_type.type == 'struct_specifier':
if typedef.child_by_field_name('name') is not None:
typedef_struct[
'type'] = typedef_type.child_by_field_name(
'name').text.decode()
# TODO(David): handle the else branch here.
elif typedef_type.type == 'primitive_type':
typedef_struct['type'] = typedef_type.text.decode()
elif typedef_type.type == 'sized_type_specifier':
typedef_struct['type'] = typedef_type.text.decode()

self.typedefs.append(typedef_struct)

def extract_imported_header_files(self):
"""Sets the header files imported by a given module."""
if not self.root:
Expand Down Expand Up @@ -526,7 +557,7 @@ def get_linenumber(self, bytepos):
with open(self.source_file, 'r', encoding='utf-8') as f:
source_content = f.read()

payload_range = 0
payload_range = 1
for line in source_content.split('\n'):
end_line_pos = payload_range + len(line) + 1
self.line_range_pairs.append((payload_range, end_line_pos))
Expand Down

0 comments on commit bad9c03

Please sign in to comment.