Skip to content

Commit

Permalink
compiler alignment (ivy-llc#18315)
Browse files Browse the repository at this point in the history
helper functions
  • Loading branch information
juliagsy authored Jul 12, 2023
1 parent 81eded9 commit db0ddc5
Showing 1 changed file with 130 additions and 0 deletions.
130 changes: 130 additions & 0 deletions ivy/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,141 @@
import functools
from typing import Callable
import traceback as tb
import inspect
import os
import ast
import builtins

# Helpers #
# ------- #

def _remove_so_log(trace):
old_stack_trace = tb.extract_tb(trace)
old_frames = inspect.getinnerframes(trace)

transpile_frame = None
module_frame = None
module_st = None
compiled_lineno = None

new_stack_trace = []
track = False

for idx, st in enumerate(old_stack_trace):
if ".pyx" in repr(st):
continue
if "<string>" in repr(st):
if "compiled_fn" in repr(st) and module_frame:
track = True
compiled_lineno = st.lineno

if "<module>" in repr(st):
module_frame = old_frames[idx]
module_st = st
elif (
transpile_frame is None
and os.path.join("ivy", "compiler") in st.filename
and st.name in ["compile", "transpile"]
):
transpile_frame = old_frames[idx]
elif track:
ret_st = _align_source(
st, transpile_frame, module_frame, module_st, compiled_lineno
)
if ret_st:
[new_stack_trace.append(r) for r in ret_st]

if track:
track = False
else:
new_stack_trace.append(st)

return new_stack_trace


def _align_source(st, transpile_frame, module_frame, module_st, compiled_lineno):
from ivy.compiler.utils.VVX import trace_obj
from ivy.compiler.utils.IIV import Graph

curr_obj = [None, None, "", ""]
if transpile_frame:
t_v = inspect.getargvalues(transpile_frame.frame)
obj = t_v.locals[t_v.varargs][0]

traced_data = trace_obj(obj, t_v.locals["args"], t_v.locals["kwargs"], {})
curr_obj[0] = traced_data[1]
curr_obj[1] = traced_data[2]
curr_obj[2] = traced_data[3]

if module_frame:
t_v = inspect.getargvalues(module_frame.frame)
for k, v in t_v.locals.items():
if k in module_st.line and isinstance(v, Graph):
traced_data = trace_obj(t_v.locals[v.__name__], (), {}, {})
curr_obj[0] = traced_data[1]
curr_obj[1] = traced_data[2]
curr_obj[2] = v.__name__

if compiled_lineno:
line = v._Graph__fn_str.split("\n")[compiled_lineno - 1]
line = line.split("=")[1].strip()
line = line.split("(")[0].strip()
target_name = line.split(".")[-1].strip()
curr_obj[3] = line
area = compiled_lineno / len(v._Graph__fn_str.strip().split("\n"))

curr_obj = _get_traces(curr_obj, area, t_v.locals, target_name)

if curr_obj[0] is None:
return None
if not isinstance(curr_obj[0], list):
curr_obj = [curr_obj]
return curr_obj


def _get_traces(curr_obj, area, local_dict, target_name):
from ivy.compiler.utils.VVX import trace_obj, get_source_code, CallVistior

traces_list = []
func = local_dict[curr_obj[2]]
func_module = inspect.getmodule(func)
rooted_source = get_source_code(func).strip()

try:
module_ast = ast.parse(rooted_source)
visitor = CallVistior(func_module)
visitor.visit(module_ast)
except SyntaxError:
pass

non_lib_objs_name_list = [f.__name__ for f in visitor.non_lib_objs]
rooted_src_list = rooted_source.split("\n")
max_idx = round(len(rooted_src_list) * area) - 1

for i in range(max_idx, 0, -1):
if target_name in rooted_src_list[i]:
curr_obj[3] = rooted_src_list[i]
curr_obj[1] += i
break
elif builtins.any(
[name in rooted_src_list[i] for name in non_lib_objs_name_list]
):
found = False
for name in non_lib_objs_name_list:
if name in rooted_src_list[i]:
traced_data = trace_obj(local_dict[name], (), {}, {})
ret_obj = [traced_data[1], traced_data[2], name, curr_obj[3]]
ret_obj = _get_traces(ret_obj, 1, local_dict, target_name)
if ret_obj:
traces_list += ret_obj
found = True
break
if found:
curr_obj[3] = rooted_src_list[i]
curr_obj[1] += i
break
return [curr_obj] + traces_list


def _check_if_path_found(path , full_path):
"""
Expand Down

0 comments on commit db0ddc5

Please sign in to comment.