
Fix #708 Check fastmath flags #1068

Open
wants to merge 17 commits into main

Conversation

@NimaSarajpoor (Collaborator) commented Jan 28, 2025

This PR is to fix #708. An initial inspection was done by @seanlaw in this comment; I am copying that list here for transparency and easier tracking (a short sketch after the list illustrates why the np.inf handling matters for the fastmath flags).

  • aamp._compute_diagonal - P, PL, PR contain np.inf and p can be np.inf
  • aamp._aamp - P, PL, PR contain np.inf
  • core._sliding_dot_product - Should be okay
  • core._calculate_squared_distance_profile - Should be okay (Returned value in D_squared might be np.inf but no arithmetic operation)
  • core.calculate_distance_profile - Should be okay (Returned value in D_squared might be np.inf)
  • core._p_norm_distance_profile - Should be okay (can p be np.inf? not supported. See: Add support for p=np.inf for non-normalized p-norm distance #1071 )
  • core._mass Should be okay
  • core._apply_exclusion_zone - val contains np.inf
  • core._count_diagonal_ndist - Should be okay
  • core._get_array_ranges - Should be okay
  • core._get_ranges - Should be okay
  • core._total_diagonal_ndists - Should be okay
  • fastmath._add_assoc - Should be okay
  • maamp._compute_multi_p_norm - p could possibly be np.inf (not supported. See: Add support for p=np.inf for non-normalized p-norm distance #1071). p_norm array can contain np.inf
  • mstump._compute_multi_D - Might be okay??
  • scraamp._compute_PI - P_NORM contains np.inf
  • scraamp._prescraamp - P_NORM contains np.inf
  • scrump._compute_PI - references np.inf values, so likely bad
  • scrump._prescrump - P_squared is np.inf
  • stump._compute_diagonal - ρ, ρL, and ρR contain np.inf
  • stump._stump - ρ, ρL, and ρR contain np.inf
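
To make the concern concrete, here is a minimal sketch (toy functions, not STUMPY's actual decorators) of why arrays that hold np.inf matter for the fastmath flag. With fastmath=True, numba enables the full set of LLVM fast-math flags, including "ninf" (assume no infinities), so storing or comparing against np.inf becomes undefined; passing an explicit subset of flags that omits "ninf" (and "nnan") keeps the np.inf sentinel meaningful:

import numpy as np
from numba import njit

# fastmath=True implies the "ninf" flag: LLVM may assume no infinities,
# so initializing `best` to np.inf and comparing against it is unsafe here
@njit(fastmath=True)
def _min_unsafe(P):
    best = np.inf
    for i in range(P.shape[0]):
        if P[i] < best:
            best = P[i]
    return best

# Restricting fastmath to a subset of flags (omitting "ninf"/"nnan")
# keeps the np.inf sentinel and the comparison well defined
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _min_safe(P):
    best = np.inf
    for i in range(P.shape[0]):
        if P[i] < best:
            best = P[i]
    return best

print(_min_safe(np.array([3.0, 1.0, 2.0])))  # 1.0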

@NimaSarajpoor NimaSarajpoor changed the title Fix #708 Fix #708 Check fastmath flags Jan 28, 2025

codecov bot commented Jan 28, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.31%. Comparing base (bbc97e4) to head (da39fa3).
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1068      +/-   ##
==========================================
- Coverage   97.33%   97.31%   -0.02%     
==========================================
  Files          93       93              
  Lines       15219    15239      +20     
==========================================
+ Hits        14813    14830      +17     
- Misses        406      409       +3     


@NimaSarajpoor (Collaborator, Author)

After going through core._p_norm_distance_profile, I noticed that we should not set p to np.inf. In fact, we should not set p to np.inf in any of the non-normalized functions. The usual Minkowski distance calculation does not apply when p is np.inf, since that is a limiting case. If we ignore that limit and plug p = np.inf into the formula anyway, the problem shows up in the following example:

arr = np.array([0.9, 1.0, 1.1])
out = np.power(arr, np.inf)  # [0.0, 1.0, np.inf]

Any value less than 1.0 becomes 0.0 and any value greater than 1.0 becomes np.inf, which gives us the wrong output. So, we can add a note to the docstring(s) that p=np.inf is not supported.

Maybe we should open another issue for this case, known as the Chebyshev distance, and then decide whether to add support for it. The distance can be computed via a rolling-max approach, but I am not sure it can be added to the current code base without hurting the design.
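
As a small illustration (toy arrays, nothing STUMPY-specific), plugging p = np.inf into the usual Minkowski formula collapses the result, while the Chebyshev distance, i.e. the p → ∞ limit, is just the maximum absolute difference:

import numpy as np

x = np.array([0.0, 0.5, 1.0])
y = np.array([0.9, 1.5, 2.1])

diff = np.abs(x - y)  # [0.9, 1.0, 1.1]

# Naive Minkowski with p = np.inf: np.power(diff, np.inf) -> [0.0, 1.0, inf],
# the sum is inf, and inf ** (1 / inf) == inf ** 0.0 == 1.0 regardless of the data
p = np.inf
minkowski = np.sum(np.power(diff, p)) ** (1.0 / p)  # 1.0 (wrong)

# Chebyshev distance (the p -> inf limit) is just the elementwise max
chebyshev = np.max(diff)  # 1.1 (correct)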


For now, we should not consider p=np.inf when making a decision about the fastmath flag, and we can revise the flag later if needed once the function supports p=np.inf.

@seanlaw (Contributor) commented Feb 2, 2025

@NimaSarajpoor I had noticed it as well when I looked the other day. I agree, we should not allow p = np.inf for now and simply ignore it (after adding a note to the docstring(s))

@NimaSarajpoor (Collaborator, Author)

@seanlaw

@NimaSarajpoor I had noticed it as well when I looked the other day. I agree, we should not allow p = np.inf for now and simply ignore it (after adding a note to the docstring(s))

Created issue #1071.

@seanlaw (Contributor) commented Feb 7, 2025

@NimaSarajpoor Is this ready to be merged?

@NimaSarajpoor (Collaborator, Author) commented Feb 9, 2025

@seanlaw

@NimaSarajpoor Is this ready to be merged?

Not yet. I am trying to get the chains of caller-callees.

I have a script that produces a dictionary whose keys are (module_name, func_name) pairs and whose values are lists of callees (some of the code is based on the work I did initially in #1025). The code is provided below. Do you think we should add it to STUMPY? In that case, I need to clean it up and push it. If not, I can just use it locally to get the chains of caller-callees, list them here, and then check each chain with respect to the fastmath flag. The reason I am a bit hesitant to add it to STUMPY is that I am not sure whether we are going to make any exceptions. The goal is to check for a switch from fastmath == True to fastmath != True (or vice versa) in a caller-callee chain. But if we decide to allow that switch in some cases, then adding such a check to STUMPY may not make sense. A rough sketch of that check appears after the script below.

Code
# in STUMPY's root directory 

import ast
import pathlib


def _get_func_callees(node, so_far_callees):
    for n in ast.iter_child_nodes(node):
        if isinstance(n, ast.Call):
            obj = n.func
            if isinstance(obj, ast.Attribute):  # e.g., np.sum
                name = obj.attr 
            elif isinstance(obj, ast.Name):  # e.g., sum
                name = obj.id
            else:
                msg = f"The type {type(obj)} is not supported"
                raise ValueError(msg)

            so_far_callees.append(name)

        _get_func_callees(n, so_far_callees)


def get_func_callees(func_node):
    """
    For a given node of type ast.FunctionDef, visit all of its child nodes,
    and return a list of all of its callees
    """
    out = []
    _get_func_callees(func_node, so_far_callees=out)

    return out


def get_func_nodes(filepath):
    """
    For the given `filepath`, parse the module and return a list of its
    top-level function definition nodes (of type `ast.FunctionDef`)
    """
    file_contents = ""
    with open(filepath, encoding="utf8") as f:
        file_contents = f.read()
    module = ast.parse(file_contents)

    func_nodes = [
        node for node in module.body if isinstance(node, ast.FunctionDef)
    ]

    return func_nodes


def get_callees():
    ignore = ["__init__.py", "__pycache__"]

    stumpy_path = pathlib.Path(__file__).parent / "stumpy"
    filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file())

    all_callees = {}
    for filepath in filepaths:
        file_name = filepath.name
        if (
            file_name not in ignore 
            and not file_name.startswith("gpu")
            and str(filepath).endswith(".py")
        ):
            module_name = file_name.replace(".py", "")
            
            func_nodes = get_func_nodes(filepath)
            for node in func_nodes:
                all_callees[(module_name, node.name)] = get_func_callees(node)

    
    # clean all_callees to only include callees that are in stumpy
    all_stumpy_funcs = set(item[1] for item in all_callees.keys())

    out = {}
    for (module_name, func_name), callees in all_callees.items():
        lst = []
        for callee in callees:
            if callee in all_stumpy_funcs:
                lst.append(callee)
        out[(module_name, func_name)] = lst

    return out


out = get_callees()
print(out)
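
For reference, here is a rough sketch (hypothetical names, not part of the script above) of the kind of check described earlier: given a caller-callee chain and each function's fastmath setting, report any adjacent pair whose settings differ.

def check_chain(chain, fastmath_settings):
    """
    `chain` is a list like ["stump._stump", "stump._compute_diagonal"] and
    `fastmath_settings` maps "module.func" to its fastmath value (e.g., True
    or a set of flags). Return the caller-callee pairs whose values differ.
    """
    mismatches = []
    for caller, callee in zip(chain, chain[1:]):
        if fastmath_settings.get(caller) != fastmath_settings.get(callee):
            mismatches.append((caller, callee))
    return mismatches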

@NimaSarajpoor mentioned this pull request Feb 9, 2025
@seanlaw (Contributor) commented Feb 10, 2025

The reason I am a bit hesitant to add it to STUMPY is that I am not sure whether we are going to make any exceptions. The goal is to check for a switch from fastmath == True to fastmath != True (or vice versa) in a caller-callee chain. But if we decide to allow that switch in some cases, then adding such a check to STUMPY may not make sense.

At this point, I don't anticipate allowing any exceptions. Since this is tedious to check by hand, I would prefer to automate it and add it to test.sh like all of the other checks. Naturally, I think we would now add this to the fastmath.py script, right?

One thing to consider is that ast may not easily allow you to traverse across different Python modules

@NimaSarajpoor (Collaborator, Author) commented Feb 11, 2025

@seanlaw

Since this is tedious to check by hand, I would prefer to automate it and add it to test.sh like all of the other checks

👍

Naturally, I think we would now add this to the fastmath.py script, right?

Right. This should be placed in ./fastmath.py, and then it can be used in the testing process via test.sh.

One thing to consider is that ast may not easily allow you to traverse across different Python modules

Currently I am not trying to jump between modules. What I do is collect the one-level-deep callees of ALL stumpy functions; that is all I need to create the chain for a given caller. However, if I can find a tool that can jump from one module to a different module, then finding the chains should become easier. I am going to look for one.
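
For what it's worth, here is a minimal sketch (toy data, assumed shapes) of how one-level-deep callee lists can be expanded into full caller-callee chains with a depth-first walk:

def build_chains(func, one_level_callees, prefix=None):
    """
    `one_level_callees` maps a function name to the list of functions it
    calls directly. Return every caller-callee chain rooted at `func`.
    """
    prefix = (prefix or []) + [func]
    children = one_level_callees.get(func, [])
    if not children:
        return [prefix]
    chains = []
    for callee in children:
        chains.extend(build_chains(callee, one_level_callees, prefix))
    return chains

# Toy example (not the actual STUMPY call relationships):
one_level_callees = {
    "_stump": ["_compute_diagonal", "_get_array_ranges"],
    "_compute_diagonal": ["_shift_insert_at_index"],
}
print(build_chains("_stump", one_level_callees))
# [['_stump', '_compute_diagonal', '_shift_insert_at_index'],
#  ['_stump', '_get_array_ranges']]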

@seanlaw (Contributor) commented Feb 11, 2025

Currently I am not trying to jump between modules. What I do is collect the one-level-deep callees of ALL stumpy functions; that is all I need to create the chain for a given caller. However, if I can find a tool that can jump from one module to a different module, then finding the chains should become easier. I am going to look for one.

I have some ideas and will be able to share them soon

@seanlaw (Contributor) commented Feb 11, 2025

@NimaSarajpoor While quite verbose, I believe that this will work nicely to generate a list of njit call stacks that is able to jump ACROSS modules:

import fastmath
import ast
import pathlib

class FunctionCallVisitor(ast.NodeVisitor):
    def __init__(self):
        super().__init__()
        self.module_names = []
        self.call_stack = []
        self.last_depth = 0
        self.out = []

        # Setup lists, dicts, and ast objects
        self.njit_funcs = fastmath.get_njit_funcs('stumpy')
        self.njit_modules = set(mod_name for mod_name, func_name in self.njit_funcs)
        self.njit_nodes = {}
        self.ast_modules = {}
        
        stumpy_path = pathlib.Path(__file__).parent / "stumpy"
        filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file())
        ignore = ["__init__.py", "__pycache__"]
        
        for filepath in filepaths:
            file_name = filepath.name
            if (
                file_name not in ignore 
                and not file_name.startswith("gpu")
                and str(filepath).endswith(".py")
            ):
                module_name = file_name.replace(".py", "")
                file_contents = ""
                with open(filepath, encoding="utf8") as f:
                    file_contents = f.read()
                self.ast_modules[module_name] = ast.parse(file_contents)
        
                for node in self.ast_modules[module_name].body:
                    if isinstance(node, ast.FunctionDef):
                        func_name = node.name
                        if (module_name, func_name) in self.njit_funcs:
                            self.njit_nodes[f'{module_name}.{func_name}'] = node


    def push_module(self, module_name):
        self.module_names.append(module_name)
        
    def pop_module(self):
        if self.module_names:
            self.module_names.pop()

    def push_call_stack(self, module_name, func_name):
        self.call_stack.append(f'{module_name}.{func_name}')
        
    def pop_call_stack(self):
        if self.call_stack:
            self.call_stack.pop()

    def goto_deeper_func(self, node):
        self.generic_visit(node)

    def goto_next_func(self, node):
        self.generic_visit(node)

    def push_out(self):
        unique = True
        for cs in self.out:
            if ' '.join(self.call_stack) in ' '.join(cs):
                unique = False
                break

        if unique:
            self.out.append(self.call_stack.copy())
        
    def visit_Call(self, node):
        callee_name = ast.unparse(node.func)

        if "." in callee_name:
            new_module_name, new_func_name = callee_name.split('.')[:2]

            if new_module_name in self.njit_modules:
                self.push_module(new_module_name)
        else:
            if self.module_names:
                new_module_name = self.module_names[-1]
                new_func_name = callee_name
                callee_name = f'{new_module_name}.{new_func_name}'

        if callee_name in self.njit_nodes.keys():
            callee_node = self.njit_nodes[callee_name]
            self.push_call_stack(new_module_name, new_func_name)
            self.goto_deeper_func(callee_node)
            self.pop_module()
            self.push_out()
            self.pop_call_stack()

        self.goto_next_func(node)


def get_njit_call_stacks():
    visitor = FunctionCallVisitor()
    
    for module_name in visitor.njit_modules:
        visitor.push_module(module_name)
        
        for node in visitor.ast_modules[module_name].body:
            if isinstance(node, ast.FunctionDef):
                func_name = node.name
                if (module_name, func_name) in visitor.njit_funcs:
                    visitor.push_call_stack(module_name, func_name)
                    visitor.visit(node)
                    visitor.pop_call_stack()

        visitor.pop_module()

    return visitor.out


if __name__ == '__main__':
    for cs in get_njit_call_stacks():
        print(cs)

The output should be:

['core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['maamp._compute_multi_p_norm', 'core._apply_exclusion_zone']
['stump._compute_diagonal', 'core._shift_insert_at_index']
['stump._stump', 'core._count_diagonal_ndist']
['stump._stump', 'core._get_array_ranges']
['stump._stump', 'stump._compute_diagonal', 'core._shift_insert_at_index']
['stump._stump', 'core._merge_topk_ρI']
['aamp._compute_diagonal', 'core._shift_insert_at_index']
['aamp._aamp', 'core._count_diagonal_ndist']
['aamp._aamp', 'core._get_array_ranges']
['aamp._aamp', 'aamp._compute_diagonal', 'core._shift_insert_at_index']
['aamp._aamp', 'core._merge_topk_PI']
['scraamp._compute_PI', 'core._p_norm_distance_profile', 'core._sliding_dot_product']
['scraamp._compute_PI', 'core._apply_exclusion_zone']
['scraamp._compute_PI', 'core._shift_insert_at_index']
['scraamp._prescraamp', 'core._get_ranges']
['scraamp._prescraamp', 'core._merge_topk_PI']
['mstump._compute_multi_D', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['mstump._compute_multi_D', 'core._apply_exclusion_zone']
['scrump._compute_PI', 'core._sliding_dot_product']
['scrump._compute_PI', 'core._calculate_squared_distance_profile', 'core._calculate_squared_distance']
['scrump._compute_PI', 'core._apply_exclusion_zone']
['scrump._compute_PI', 'core._shift_insert_at_index']
['scrump._prescrump', 'core._get_ranges']
['scrump._prescrump', 'core._merge_topk_PI']

One immediate observation is that our njit call stacks are very, very flat/shallow, which is a GREAT thing! In my head, I was dreading finding that we have very deeply nested call stacks, so this is a pleasant surprise. Please verify that we haven't missed any edge cases.

Hopefully, you are able to take these call stacks and check the fastmath flags accordingly. Please let me know if you have any questions.
