Add type annotations to bulk of python codebase
There are no meaningful runtime changes in this commit.
chadrik committed Oct 27, 2019
1 parent c6ec72d commit c25792c
Showing 92 changed files with 3,006 additions and 603 deletions.
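
The annotations added throughout use the comment-based syntax from PEP 484 rather than inline annotations, which keeps the files importable under Python 2. A minimal sketch of the two forms that appear in the diffs below (the functions here are illustrative, not from the commit):

    from typing import List, Optional

    def write_record(value, nested=False):
        # type: (bytes, bool) -> None
        """Single-line form: argument types and return type in one comment."""

    def merge_records(values,        # type: List[bytes]
                      default=None,  # type: Optional[bytes]
                     ):
        # type: (...) -> bytes
        """Per-argument form, used for longer signatures."""
        return b''.join(values) if values else (default or b'')
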
106 changes: 96 additions & 10 deletions sdks/python/apache_beam/coders/coder_impl.py

Large diffs are not rendered by default.

124 changes: 114 additions & 10 deletions sdks/python/apache_beam/coders/coders.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion sdks/python/apache_beam/coders/observable_test.py
@@ -20,14 +20,16 @@

import logging
import unittest
from typing import List
from typing import Optional

from apache_beam.coders import observable


class ObservableMixinTest(unittest.TestCase):
observed_count = 0
observed_sum = 0
observed_keys = []
observed_keys = [] # type: List[Optional[str]]

def observer(self, value, key=None):
self.observed_count += 1
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/coders/slow_stream.py
@@ -25,6 +25,7 @@
import sys
from builtins import chr
from builtins import object
from typing import List


class OutputStream(object):
@@ -33,10 +34,11 @@ class OutputStream(object):
A pure Python implementation of stream.OutputStream."""

def __init__(self):
self.data = []
self.data = [] # type: List[bytes]
self.byte_count = 0

def write(self, b, nested=False):
# type: (bytes, bool) -> None
assert isinstance(b, bytes)
if nested:
self.write_var_int64(len(b))
@@ -48,6 +50,7 @@ def write_byte(self, val):
self.byte_count += 1

def write_var_int64(self, v):
# type: (int) -> None
if v < 0:
v += 1 << 64
if v <= 0:
@@ -74,12 +77,15 @@ def write_bigendian_double(self, v):
self.write(struct.pack('>d', v))

def get(self):
# type: () -> bytes
return b''.join(self.data)

def size(self):
# type: () -> int
return self.byte_count

def _clear(self):
# type: () -> None
self.data = []
self.byte_count = 0

@@ -95,6 +101,7 @@ def __init__(self):
self.count = 0

def write(self, byte_array, nested=False):
# type: (bytes, bool) -> None
blen = len(byte_array)
if nested:
self.write_var_int64(blen)
@@ -119,6 +126,7 @@ class InputStream(object):
A pure Python implementation of stream.InputStream."""

def __init__(self, data):
# type: (bytes) -> None
self.data = data
self.pos = 0

@@ -139,18 +147,22 @@ def size(self):
return len(self.data) - self.pos

def read(self, size):
# type: (int) -> bytes
self.pos += size
return self.data[self.pos - size : self.pos]

def read_all(self, nested):
# type: (bool) -> bytes
return self.read(self.read_var_int64() if nested else self.size())

def read_byte_py2(self):
# type: () -> int
self.pos += 1
# mypy tests against python 3.x, where this is an error:
return ord(self.data[self.pos - 1]) # type: ignore[arg-type]

def read_byte_py3(self):
# type: () -> int
self.pos += 1
return self.data[self.pos - 1]

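
As a quick illustration of how the newly annotated stream classes pair up, a round trip through the pure Python implementations might look like this (a sketch, assuming the import path matches the file location):

    from apache_beam.coders.slow_stream import InputStream, OutputStream

    out = OutputStream()
    out.write(b'spam', nested=True)   # nested=True length-prefixes via write_var_int64
    out.write(b'!', nested=False)
    data = out.get()                  # returns bytes, as the new annotation states

    inp = InputStream(data)
    assert inp.read_all(nested=True) == b'spam'  # reads the var-int length, then the payload
    assert inp.read(1) == b'!'
    assert inp.size() == 0                       # everything consumed
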
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/coders/standard_coders_test.py
@@ -27,6 +27,8 @@
import sys
import unittest
from builtins import map
from typing import Dict
from typing import Tuple

import yaml

@@ -147,7 +149,7 @@ def json_value_parser(self, coder_spec):
# Used when --fix is passed.

fix = False
to_fix = {}
to_fix = {} # type: Dict[Tuple[int, bytes], bytes]

@classmethod
def tearDownClass(cls):
13 changes: 11 additions & 2 deletions sdks/python/apache_beam/coders/typecoders.py
@@ -66,6 +66,11 @@ def MakeXyzs(v):
from __future__ import absolute_import

from builtins import object
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Type

from past.builtins import unicode

@@ -79,8 +84,8 @@ class CoderRegistry(object):
"""A coder registry for typehint/coder associations."""

def __init__(self, fallback_coder=None):
self._coders = {}
self.custom_types = []
self._coders = {} # type: Dict[Any, Type[coders.Coder]]
self.custom_types = [] # type: List[Any]
self.register_standard_coders(fallback_coder)

def register_standard_coders(self, fallback_coder):
@@ -97,9 +102,11 @@ def register_standard_coders(self, fallback_coder):
self._fallback_coder = fallback_coder or FirstOf(default_fallback_coders)

def _register_coder_internal(self, typehint_type, typehint_coder_class):
# type: (Any, Type[coders.Coder]) -> None
self._coders[typehint_type] = typehint_coder_class

def register_coder(self, typehint_type, typehint_coder_class):
# type: (Any, Type[coders.Coder]) -> None
if not isinstance(typehint_coder_class, type):
raise TypeError('Coder registration requires a coder class object. '
'Received %r instead.' % typehint_coder_class)
@@ -108,6 +115,7 @@ def register_coder(self, typehint_type, typehint_coder_class):
self._register_coder_internal(typehint_type, typehint_coder_class)

def get_coder(self, typehint):
# type: (Any) -> coders.Coder
coder = self._coders.get(
typehint.__class__ if isinstance(typehint, typehints.TypeConstraint)
else typehint, None)
@@ -164,6 +172,7 @@ class FirstOf(object):
A class used to get the first matching coder from a list of coders."""

def __init__(self, coders):
# type: (Iterable[Type[coders.Coder]]) -> None
self._coders = coders

def from_type_hint(self, typehint, registry):
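
The new Dict[Any, Type[coders.Coder]] annotation makes the registry contract explicit: register_coder takes a coder class (not an instance), and get_coder hands back a coder instance. A sketch (MyType is hypothetical; PickleCoder stands in for any registered coder class):

    from apache_beam.coders import coders
    from apache_beam.coders.typecoders import CoderRegistry

    class MyType(object):
        pass

    reg = CoderRegistry()
    reg.register_coder(MyType, coders.PickleCoder)  # a class object, or TypeError is raised
    my_coder = reg.get_coder(MyType)                # returns a coders.Coder instance
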
@@ -24,6 +24,8 @@
import string
import unittest
import uuid
from typing import TYPE_CHECKING
from typing import List

import pytz

@@ -47,8 +49,10 @@
_microseconds_from_datetime = lambda label_stamp: label_stamp
_datetime_from_microseconds = lambda micro: micro

if TYPE_CHECKING:
import google.cloud.bigtable.instance

EXISTING_INSTANCES = []
EXISTING_INSTANCES = [] # type: List[google.cloud.bigtable.instance.Instance]
LABEL_KEY = u'python-bigtable-beam'
label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC)
label_stamp_micros = _microseconds_from_datetime(label_stamp)
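
The TYPE_CHECKING guard above is the pattern used across this commit for imports needed only by annotations: the module is visible to mypy but never imported at runtime, so the optional dependency is not required to run the tests. A generic sketch of the same idiom (module and attribute names are placeholders):

    from typing import TYPE_CHECKING, List

    if TYPE_CHECKING:
        # Imported only while type checking; never executed at runtime.
        import some.optional.dependency

    INSTANCES = []  # type: List[some.optional.dependency.Thing]
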
6 changes: 5 additions & 1 deletion sdks/python/apache_beam/internal/pickler.py
@@ -36,6 +36,9 @@
import traceback
import types
import zlib
from typing import Any
from typing import Dict
from typing import Tuple

import dill

@@ -157,7 +160,7 @@ def save_module(pickler, obj):
# Pickle module dictionaries (commonly found in lambda's globals)
# by referencing their module.
old_save_module_dict = dill.dill.save_module_dict
known_module_dicts = {}
known_module_dicts = {} # type: Dict[int, Tuple[types.ModuleType, Dict[str, Any]]]

@dill.dill.register(dict)
def new_save_module_dict(pickler, obj):
@@ -227,6 +230,7 @@ def new_log_info(msg, *args, **kwargs):
# pickler.loads() being used for data, which results in an unnecessary base64
# encoding. This should be cleaned up.
def dumps(o, enable_trace=True):
# type: (...) -> bytes
"""For internal use only; no backwards-compatibility guarantees."""

try:
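
The (...) form in the dumps annotation leaves the argument types unspecified and declares only the return type, so callers now see that dumps produces bytes. A round-trip sketch (assuming the module's loads counterpart accepts those bytes, as its name suggests):

    from apache_beam.internal import pickler

    blob = pickler.dumps({'values': [1, 2, 3]})    # bytes, per the new annotation
    assert isinstance(blob, bytes)
    assert pickler.loads(blob) == {'values': [1, 2, 3]}
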
16 changes: 15 additions & 1 deletion sdks/python/apache_beam/internal/util.py
@@ -27,6 +27,16 @@
import weakref
from builtins import object
from multiprocessing.pool import ThreadPool
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union

T = TypeVar('T')


class ArgumentPlaceholder(object):
@@ -62,7 +72,11 @@ def __hash__(self):
return hash(type(self))


def remove_objects_from_args(args, kwargs, pvalue_class):
def remove_objects_from_args(args, # type: Iterable[Any]
kwargs, # type: Dict[str, Any]
pvalue_class # type: Union[Type[T], Tuple[Type[T], ...]]
):
# type: (...) -> Tuple[List[Any], Dict[str, Any], List[T]]
"""For internal use only; no backwards-compatibility guarantees.
Replaces all objects of a given type in args/kwargs with a placeholder.
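
Per the new signature, the helper takes positional args, keyword args, and a type (or tuple of types) to strip, and returns cleaned args, cleaned kwargs, and the removed objects. A sketch of that contract as described by the docstring and annotations (Marker is a hypothetical stand-in for the pvalue class):

    from apache_beam.internal.util import ArgumentPlaceholder, remove_objects_from_args

    class Marker(object):
        pass

    m = Marker()
    new_args, new_kwargs, removed = remove_objects_from_args([1, m, 3], {'x': 'keep'}, Marker)

    assert removed == [m]
    assert isinstance(new_args[1], ArgumentPlaceholder)
    assert new_kwargs == {'x': 'keep'}
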
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/io/avroio_test.py
@@ -24,6 +24,7 @@
import tempfile
import unittest
from builtins import range
from typing import List
import sys

# patches unittest.TestCase to be python3 compatible
@@ -90,7 +91,7 @@

class AvroBase(object):

_temp_files = []
_temp_files = [] # type: List[str]

def __init__(self, methodName='runTest'):
super(AvroBase, self).__init__(methodName)
16 changes: 12 additions & 4 deletions sdks/python/apache_beam/io/filebasedsource.py
@@ -28,6 +28,8 @@

from __future__ import absolute_import

from typing import Callable

from past.builtins import long
from past.builtins import unicode

@@ -71,7 +73,7 @@ def __init__(self,
file_pattern (str): the file glob to read a string or a
:class:`~apache_beam.options.value_provider.ValueProvider`
(placeholder to inject a runtime value).
min_bundle_size (str): minimum size of bundles that should be generated
min_bundle_size (int): minimum size of bundles that should be generated
when performing initial splitting on this source.
compression_type (str): Used to handle compressed output files.
Typical value is :attr:`CompressionTypes.AUTO
@@ -128,6 +130,7 @@ def display_data(self):

@check_accessible(['_pattern'])
def _get_concat_source(self):
# type: () -> concat_source.ConcatSource


robertwb (Contributor) commented on Oct 28, 2019:

    Why can't things like this be inferred?

chadrik (Author) replied on Oct 28, 2019:

    There's no feature in mypy for inferring return types. ¯\_(ツ)_/¯
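
In practice, an unannotated def is treated by mypy as returning Any at its call sites, so the return type has to be written out for it to propagate; a tiny hypothetical illustration:

    def make_source():
        # unannotated: mypy treats calls to this as returning Any
        return 'gs://bucket/file'

    def make_source_typed():
        # type: () -> str
        return 'gs://bucket/file'

    s = make_source()        # mypy sees Any here, so misuse of s goes unnoticed
    t = make_source_typed()  # mypy sees str here and can flag misuse
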

if self._concat_source is None:
pattern = self._pattern.get()

@@ -358,6 +361,7 @@ def process(self, element, *args, **kwargs):
class _ReadRange(DoFn):

def __init__(self, source_from_file):
# type: (Callable[[str], iobase.BoundedSource]) -> None

robertwb (Contributor) commented on Oct 28, 2019:

    -> None reads a bit odd (though correct i guess).

chadrik (Author) replied on Oct 28, 2019:

    Yeah, it's a bit pedantic. By default mypy assigns Any to an unannotated type, but in the
    case of __init__ nothing would break or error if it were omitted. However, it's good for
    people to get into the habit of fully annotating their functions so that eventually we
    can reach complete coverage by enabling --disallow-untyped-defs:

      --disallow-untyped-defs   Disallow defining functions without type
                                annotations or with incomplete type annotations
                                (inverse: --allow-untyped-defs)
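
Concretely, under that flag the first definition below would be rejected and the second accepted (hypothetical functions, written in the same comment-based syntax):

    def untyped(x, y):        # flagged by --disallow-untyped-defs: no annotations at all
        return x + y

    def typed(x, y):
        # type: (int, int) -> int
        return x + y          # fully annotated: passes under the stricter flag
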
self._source_from_file = source_from_file

def process(self, element, *args, **kwargs):
@@ -380,9 +384,13 @@ class ReadAllFiles(PTransform):
read a PCollection of files.
"""

def __init__(
self, splittable, compression_type, desired_bundle_size, min_bundle_size,
source_from_file):
def __init__(self,
splittable, # type: bool
compression_type,
desired_bundle_size, # type: int
min_bundle_size, # type: int
source_from_file, # type: Callable[[str], iobase.BoundedSource]
):
"""
Args:
splittable: If False, files won't be split into sub-ranges. If True,