Skip to content

Commit

Permalink
Merge pull request #724 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 authored Nov 3, 2020
2 parents b3c6da8 + 9327ad1 commit 7755e3c
Show file tree
Hide file tree
Showing 30 changed files with 563 additions and 183 deletions.
101 changes: 94 additions & 7 deletions docs/developers/typegraph.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# The Typegraph

<!--*
freshness: { owner: 'tsudol' reviewed: '2020-09-11' }
freshness: { owner: 'tsudol' reviewed: '2020-11-02' }
*-->

<!--ts-->
Expand All @@ -12,8 +12,10 @@ freshness: { owner: 'tsudol' reviewed: '2020-09-11' }
* [Default Data](#default-data)
* [Sets in the Typegraph](#sets-in-the-typegraph)
* [std::set or std::unordered_set?](#stdset-or-stdunordered_set)
* [Reachability](#reachability)
* [Implementation](#implementation)

<!-- Added by: rechen, at: 2020-09-14T12:22-07:00 -->
<!-- Added by: tsudol, at: 2020-11-02T11:52-08:00 -->

<!--te-->

Expand Down Expand Up @@ -150,12 +152,97 @@ cannot guarantee that the items in a set won't be modified, which violates the
`std::unordered_set` invariants. Finally, iterating over and comparing sets is
more efficient for `std::set`.

<!--
Reachability (reachable.h and .cc)
## Reachability

A basic operation when analyzing the CFG is checking if there is a path between
two nodes. This is used when checking if a binding's origin is visible at a
given node, for example. Queries always progress from a child node to the
parent, in the opposite direction of the edges of the CFG. To make these lookups
faster, the `Program` class uses a **backwards reachability** cache that's
handled by the `ReachabilityAnalyzer` class. Backwards reachability means that
the reachability analysis proceeds from child to parent, in the opposite
direction of the directed edges of the graph.

The `ReachabilityAnalyzer` tracks the list of adjacent nodes for each CFG node.
For node `i`, `reacahble[i][j]` indicates whether `j` is reachable from `i`.
When an edge is added to the reachability cache, the cache updates every node to
see if connections are possible.

- we know it's to make parts of the solver faster.
- But how
- and does it actually?
For example, consider three nodes `A -> B -> C`. The cache would be initialised
as:

```
reachable[A] = [True, False, False]
reachable[B] = [False, True, False]
reacahble[C] = [False, False, True]
```

(A node can always be reached from itself.) Since there's a connection from `B`
to `C`, the backwards edge `C -> B` is added to the cache:
```
reachable[A] = [True, False, False]
reachable[B] = [False, True, False]
reacahble[C] = [False, True, True]
```
Then the backwards edge `B -> A` is added:
```
reachable[A] = [True, False, False]
reachable[B] = [True, True, False]
reacahble[C] = [True, True, True]
```
The cache now shows that `A` can only reach itself, `B` can reach itself and
`A`, and `C` can reach all three nodes.
### Implementation
Every CFG node has an ID number of type `size_t`, which these docs will assume
is 64 bits. A naive implementation would make `reachable[n]` a list of the IDs
of the nodes reachable from node `n`, but at 8 bytes per node and potentially
thousands of nodes, that will get too expensive quickly. Instead, nodes are
split into _buckets_ based on their ID. Each bucket tracks 64 nodes, so the top
58 bits of the ID determine the node's bucket and the bottom 6 bits determine
the node's index within the bucket.
These buckets are implemented using bit vectors. Since a bucket covers 64 nodes,
it's represented by an `int64_t`. A node's reachable list is then a
`std::vector<int64_t>`, and the reachability cache that tracks every node's
reachable list is a `std::vector<std::vector<int64_t>>`. All together, node `j`
is reachable from node `i` if `reachable[i][bucket(j)] & bit(j) == 1`.

For example, consider a CFG with 100 nodes, which have IDs from 0 to 99. There
will be 100 entries in the reachability cache, one for each node, such that
`reachable[n]` corresponds to the nodes that are backwards reachable from node
`n`. `reachable[n]` is a `std::vector<int64_t>` with two elements, the first
tracking nodes 0 - 63 and the second tracking nodes 64 - 99, with room to track
another 28 nodes.

Let's check if node 75 is backwards reachable from node 30: `is_reachable(30,
75)`.

1. Find node 75's bucket: `bucket = 75 >> 6 = 1`. (This is equivalent to `75 /
64`.)
1. Find node 75's bit: `bit = 1 << (75 & 0x3F) = 1 << 11`.
1. Check node 30's reachability: `reachable[30][bucket] & bit`.

Adding a new node to the reachability cache is accomplished by adding another
entry to `reachable`. There is a catch: the cache must check if a new bucket is
needed to track the new node. If one is, then every node's reachable list is
extended with one more bucket. Finally, `reachable[n][bucket(n)] = bit[n]` is
set, indicating that node `n` is reachable from itself.

Adding an edge between two nodes is only slightly more complex. Because the
cache tracks reachability, adding an edge may update every node. Remember the `A
-> B -> C` example previously: `add_edge(B, A)` updated both `B` and `C`,
because `B` is reachable from `C`. For `add_edge(src, dst)`, the cache checks if
`src` is reachable from each node `i`, and if so, bitwise-ORs `reachable[i]` and
`reachable[dst]` together. Because `src` is reachable from itself, this will
also update `reachable[src]` when `i == src`.

<!--
Hashing and Sets
Expand Down
1 change: 1 addition & 0 deletions pytype/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ py_library(
SRCS
pytd/abc_hierarchy.py
pytd/booleq.py
pytd/escape.py
pytd/mro.py
pytd/optimize.py
pytd/pep484.py
Expand Down
4 changes: 3 additions & 1 deletion pytype/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pytype import mixin
from pytype import utils
from pytype.pyc import opcodes
from pytype.pytd import escape
from pytype.pytd import optimize
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
Expand Down Expand Up @@ -3876,6 +3877,7 @@ def __init__(self, vm):
super().__init__("__build_class__", vm)

def call(self, node, _, args, alias_map=None):
args = args.simplify(node, self.vm)
funcvar, name = args.posargs[0:2]
if isinstance(args.namedargs, dict):
kwargs = args.namedargs
Expand Down Expand Up @@ -3971,7 +3973,7 @@ class Unknown(AtomicAbstractValue):
IGNORED_ATTRIBUTES = ["__get__", "__set__", "__getattribute__"]

def __init__(self, vm):
name = "~unknown%d" % Unknown._current_id
name = escape.unknown(Unknown._current_id)
super().__init__(name, vm)
self.members = datatypes.MonitorDict()
self.owner = None
Expand Down
11 changes: 4 additions & 7 deletions pytype/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytype import state as frame_state
from pytype import vm
from pytype.overlays import typing_overlay
from pytype.pytd import escape
from pytype.pytd import optimize
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
Expand Down Expand Up @@ -555,12 +556,8 @@ def _call_traces_to_function(call_traces, name_transform=lambda x: x):
def _is_builtin(self, name, data):
return self._builtin_map.get(name) == data

def _pack_name(self, name):
"""Pack a name, for unpacking with type_match.unpack_name_of_partial()."""
return "~" + name.replace(".", "~")

def pytd_functions_for_call_traces(self):
return self._call_traces_to_function(self._calls, self._pack_name)
return self._call_traces_to_function(self._calls, escape.pack_partial)

def pytd_classes_for_call_traces(self):
class_to_records = collections.defaultdict(list)
Expand All @@ -577,7 +574,7 @@ def pytd_classes_for_call_traces(self):
for cls, call_records in class_to_records.items():
full_name = cls.module + "." + cls.name if cls.module else cls.name
classes.append(pytd.Class(
name=self._pack_name(full_name),
name=escape.pack_partial(full_name),
metaclass=None,
parents=(pytd.NamedType("__builtin__.object"),), # not used in solver
methods=tuple(self._call_traces_to_function(call_records)),
Expand Down Expand Up @@ -605,7 +602,7 @@ def compute_types(self, defs):
ty = ty.Visit(optimize.CombineReturnsAndExceptions())
ty = ty.Visit(optimize.PullInMethodClasses())
ty = ty.Visit(visitors.DefaceUnresolved(
[ty, self.loader.concat_all()], "~unknown"))
[ty, self.loader.concat_all()], escape.UNKNOWN))
return ty.Visit(visitors.AdjustTypeParameters())

def _check_return(self, node, actual, formal):
Expand Down
11 changes: 6 additions & 5 deletions pytype/convert_structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from pytype.pytd import booleq
from pytype.pytd import escape
from pytype.pytd import optimize
from pytype.pytd import pytd
from pytype.pytd import pytd_utils
Expand All @@ -18,8 +19,8 @@
MAX_DEPTH = 1

is_unknown = type_match.is_unknown
is_partial = type_match.is_partial
is_complete = type_match.is_complete
is_partial = escape.is_partial
is_complete = escape.is_complete


class FlawedQuery(Exception): # pylint: disable=g-bad-exception-name
Expand Down Expand Up @@ -104,7 +105,7 @@ def match_call_record(self, matcher, solver, call_record, complete):
else:
faulty_signature = ""
raise FlawedQuery("Bad call\n%s%s\nagainst:\n%s" % (
type_match.unpack_name_of_partial(call_record.name),
escape.unpack_partial(call_record.name),
faulty_signature, pytd_utils.Print(complete)))
solver.always_true(formula)

Expand Down Expand Up @@ -150,7 +151,7 @@ def solve(self):
# also solve partial equations
for complete in complete_classes.union(self.builtins.classes):
for partial in partial_classes:
if type_match.unpack_name_of_partial(partial.name) == complete.name:
if escape.unpack_partial(partial.name) == complete.name:
self.match_partial_against_complete(
factory_partial, solver_partial, partial, complete)

Expand All @@ -163,7 +164,7 @@ def solve(self):
complete_functions.add(f)
for partial in partial_functions:
for complete in complete_functions.union(self.builtins.functions):
if type_match.unpack_name_of_partial(partial.name) == complete.name:
if escape.unpack_partial(partial.name) == complete.name:
self.match_call_record(
factory_partial, solver_partial, partial, complete)

Expand Down
4 changes: 2 additions & 2 deletions pytype/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ def _pytd_print(self, pytd_type):
pytd_type.Visit(visitors.RemoveUnknownClasses()))))
# Clean up autogenerated namedtuple names, e.g. "namedtuple-X-a-_0-c"
# becomes just "X", by extracting out just the type name.
if "namedtuple-" in name:
return re.sub(r"\bnamedtuple-([^-]+)-[-_\w]*", r"\1", name)
if "namedtuple" in name:
return re.sub(r"\bnamedtuple[-_]([^-_]+)[-_\w]*", r"\1", name)
nested_class_match = re.search(r"_(?:\w+)_DOT_", name)
if nested_class_match:
# Pytype doesn't have true support for nested classes. Instead, for
Expand Down
15 changes: 14 additions & 1 deletion pytype/pyi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,9 @@ def _build_type_decl_unit(self, defs):
assert not aliases # We handle top-level aliases in add_alias_or_constant.
constants.extend(self._constants)

if self._ast_name == "__builtin__":
constants.extend(_builtin_keyword_constants())

generated_classes = sum(self._generated_classes.values(), [])

classes = generated_classes + classes
Expand Down Expand Up @@ -1043,7 +1046,7 @@ def new_named_tuple(self, base_name, fields):
fields = [(_handle_string_literal(n), t) for n, t in fields]
# Handle previously defined NamedTuples with the same name
prev_list = self._generated_classes[base_name]
class_name = "namedtuple-%s-%d" % (base_name, len(prev_list))
class_name = "namedtuple_%s_%d" % (base_name, len(prev_list))
class_parent = self._heterogeneous_tuple(pytd.NamedType("tuple"),
tuple(t for _, t in fields))
class_constants = tuple(pytd.Constant(n, t) for n, t in fields)
Expand Down Expand Up @@ -1674,3 +1677,13 @@ def _handle_string_literal(value):
if not match:
return value
return match.groups()[1][1:-1]


def _builtin_keyword_constants():
defs = [
("True", "bool"),
("False", "bool"),
("None", "NoneType"),
("__debug__", "bool")
]
return [pytd.Constant(name, pytd.NamedType(typ)) for name, typ in defs]
Loading

0 comments on commit 7755e3c

Please sign in to comment.