Skip to content

Commit

Permalink
Pass a context argument to astroid.Arguments to prevent recursion…
Browse files Browse the repository at this point in the history
… errors

Close pylint-dev/pylint#3414
  • Loading branch information
PCManticore committed Mar 2, 2020
1 parent 5bde219 commit 9543362
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 14 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Release Date: TBA

Close PyCQA/pylint#3417

* Pass a context argument to ``astroid.Arguments`` to prevent recursion errors

Close PyCQA/pylint#3414

* Numpy `datetime64.astype` return value is inferred as a `ndarray`.

Close PyCQA/pylint#3332
Expand Down
41 changes: 28 additions & 13 deletions astroid/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,27 @@ class CallSite:
It needs a call context, which contains the arguments and the
keyword arguments that were passed into a given call site.
In order to infer what an argument represents, call
:meth:`infer_argument` with the corresponding function node
and the argument name.
In order to infer what an argument represents, call :meth:`infer_argument`
with the corresponding function node and the argument name.
:param callcontext:
An instance of :class:`astroid.context.CallContext`, that holds
the arguments for the call site.
:param argument_context_map:
Additional contexts per node, passed in from :attr:`astroid.context.Context.extra_context`
:param context:
An instance of :class:`astroid.context.Context`.
"""

def __init__(self, callcontext, argument_context_map=None):
def __init__(self, callcontext, argument_context_map=None, context=None):
if argument_context_map is None:
argument_context_map = {}
self.argument_context_map = argument_context_map
args = callcontext.args
keywords = callcontext.keywords
self.duplicated_keywords = set()
self._unpacked_args = self._unpack_args(args)
self._unpacked_kwargs = self._unpack_keywords(keywords)
self._unpacked_args = self._unpack_args(args, context=context)
self._unpacked_kwargs = self._unpack_keywords(keywords, context=context)

self.positional_arguments = [
arg for arg in self._unpacked_args if arg is not util.Uninferable
Expand All @@ -45,10 +52,18 @@ def __init__(self, callcontext, argument_context_map=None):
}

@classmethod
def from_call(cls, call_node):
"""Get a CallSite object from the given Call node."""
def from_call(cls, call_node, context=None):
"""Get a CallSite object from the given Call node.
:param context:
An instance of :class:`astroid.context.Context` that will be used
to force a single inference path.
"""

# Determine the callcontext from the given `context` object if any.
context = context or contextmod.InferenceContext()
callcontext = contextmod.CallContext(call_node.args, call_node.keywords)
return cls(callcontext)
return cls(callcontext, context=context)

def has_invalid_arguments(self):
"""Check if in the current CallSite were passed *invalid* arguments
Expand All @@ -70,9 +85,9 @@ def has_invalid_keywords(self):
"""
return len(self.keyword_arguments) != len(self._unpacked_kwargs)

def _unpack_keywords(self, keywords):
def _unpack_keywords(self, keywords, context=None):
values = {}
context = contextmod.InferenceContext()
context = context or contextmod.InferenceContext()
context.extra_context = self.argument_context_map
for name, value in keywords:
if name is None:
Expand Down Expand Up @@ -110,9 +125,9 @@ def _unpack_keywords(self, keywords):
values[name] = value
return values

def _unpack_args(self, args):
def _unpack_args(self, args, context=None):
values = []
context = contextmod.InferenceContext()
context = context or contextmod.InferenceContext()
context.extra_context = self.argument_context_map
for arg in args:
if isinstance(arg, nodes.Starred):
Expand Down
2 changes: 1 addition & 1 deletion astroid/brain/brain_builtin_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def infer_dict(node, context=None):
If a case can't be inferred, we'll fallback to default inference.
"""
call = arguments.CallSite.from_call(node)
call = arguments.CallSite.from_call(node, context=context)
if call.has_invalid_arguments() or call.has_invalid_keywords():
raise UseInferenceDefault

Expand Down
11 changes: 11 additions & 0 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5555,5 +5555,16 @@ def bar2(cls, text):
assert isinstance(next(second_node.infer()), BoundMethod)


def test_infer_dict_passes_context():
code = """
k = {}
(_ for k in __(dict(**k)))
"""
node = extract_node(code)
inferred = next(node.infer())
assert isinstance(inferred, Instance)
assert inferred.qname() == "builtins.dict"


if __name__ == "__main__":
unittest.main()

0 comments on commit 9543362

Please sign in to comment.