diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py index 7793dad74498..cd8a8cd6f2fa 100644 --- a/airflow/hooks/base_hook.py +++ b/airflow/hooks/base_hook.py @@ -18,10 +18,10 @@ """Base class for all hooks""" import logging import random -from typing import List +from typing import Any, List from airflow import secrets -from airflow.models import Connection +from airflow.models.connection import Connection from airflow.utils.log.logging_mixin import LoggingMixin log = logging.getLogger(__name__) @@ -82,6 +82,6 @@ def get_hook(cls, conn_id: str) -> "BaseHook": connection = cls.get_connection(conn_id) return connection.get_hook() - def get_conn(self): + def get_conn(self) -> Any: """Returns connection for the hook.""" raise NotImplementedError() diff --git a/airflow/providers/apache/cassandra/sensors/record.py b/airflow/providers/apache/cassandra/sensors/record.py index 268ca44523f9..39bb1368260b 100644 --- a/airflow/providers/apache/cassandra/sensors/record.py +++ b/airflow/providers/apache/cassandra/sensors/record.py @@ -20,7 +20,7 @@ of a record in a Cassandra cluster. """ -from typing import Dict +from typing import Any, Dict, Tuple from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -52,7 +52,8 @@ class CassandraRecordSensor(BaseSensorOperator): template_fields = ('table', 'keys') @apply_defaults - def __init__(self, table: str, keys: Dict[str, str], cassandra_conn_id: str, *args, **kwargs) -> None: + def __init__(self, table: str, keys: Dict[str, str], cassandra_conn_id: str, + *args: Tuple[Any, ...], **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.cassandra_conn_id = cassandra_conn_id self.table = table diff --git a/airflow/providers/apache/cassandra/sensors/table.py b/airflow/providers/apache/cassandra/sensors/table.py index 47b9679e989b..41a7aa4afa2e 100644 --- a/airflow/providers/apache/cassandra/sensors/table.py +++ b/airflow/providers/apache/cassandra/sensors/table.py @@ -21,7 +21,7 @@ of a table in a Cassandra cluster. """ -from typing import Dict +from typing import Any, Dict, Tuple from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -49,12 +49,13 @@ class CassandraTableSensor(BaseSensorOperator): template_fields = ('table',) @apply_defaults - def __init__(self, table: str, cassandra_conn_id: str, *args, **kwargs) -> None: + def __init__(self, table: str, cassandra_conn_id: str, *args: Tuple[Any, ...], + **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.cassandra_conn_id = cassandra_conn_id self.table = table - def poke(self, context: Dict) -> bool: + def poke(self, context: Dict[Any, Any]) -> bool: self.log.info('Sensor check existence of table: %s', self.table) hook = CassandraHook(self.cassandra_conn_id) return hook.table_exists(self.table) diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py index 0537ca9fa1ba..c55ad11553b4 100644 --- a/airflow/providers/apache/druid/hooks/druid.py +++ b/airflow/providers/apache/druid/hooks/druid.py @@ -17,6 +17,7 @@ # under the License. 
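# Usage sketch (illustrative, not part of the patch): the typed Cassandra sensors above.
# Only the parameter names and annotations come from this diff; the keyspace/table names,
# key values and the 'cassandra_default' connection id are placeholder assumptions.
from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor
from airflow.providers.apache.cassandra.sensors.table import CassandraTableSensor

wait_for_table = CassandraTableSensor(
    task_id='wait_for_table',
    table='my_keyspace.my_table',            # table: str (a templated field)
    cassandra_conn_id='cassandra_default',   # required; the signature defines no default
)

wait_for_record = CassandraRecordSensor(
    task_id='wait_for_record',
    table='my_keyspace.my_table',
    keys={'partition_key': 'p1', 'clustering_key': 'c1'},  # keys: Dict[str, str]
    cassandra_conn_id='cassandra_default',
)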
import time +from typing import Any, Dict, Iterable, Optional, Tuple import requests from pydruid.db import connect @@ -43,11 +44,13 @@ class DruidHook(BaseHook): :param max_ingestion_time: The maximum ingestion time before assuming the job failed :type max_ingestion_time: int """ + def __init__( - self, - druid_ingest_conn_id='druid_ingest_default', - timeout=1, - max_ingestion_time=None): + self, + druid_ingest_conn_id: str = 'druid_ingest_default', + timeout: int = 1, + max_ingestion_time: Optional[int] = None + ) -> None: super().__init__() self.druid_ingest_conn_id = druid_ingest_conn_id @@ -58,7 +61,7 @@ def __init__( if self.timeout < 1: raise ValueError("Druid timeout should be equal or greater than 1") - def get_conn_url(self): + def get_conn_url(self) -> str: """ Get Druid connection url """ @@ -70,7 +73,7 @@ def get_conn_url(self): return "{conn_type}://{host}:{port}/{endpoint}".format( conn_type=conn_type, host=host, port=port, endpoint=endpoint) - def get_auth(self): + def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]: """ Return username and password from connections tab as requests.auth.HTTPBasicAuth object. @@ -84,7 +87,7 @@ def get_auth(self): else: return None - def submit_indexing_job(self, json_index_spec: str): + def submit_indexing_job(self, json_index_spec: Dict[str, Any]) -> None: """ Submit Druid ingestion job """ @@ -144,11 +147,11 @@ class DruidDbApiHook(DbApiHook): default_conn_name = 'druid_broker_default' supports_autocommit = False - def get_conn(self): + def get_conn(self) -> connect: """ Establish a connection to druid broker. """ - conn = self.get_connection(self.druid_broker_conn_id) # pylint: disable=no-member + conn = self.get_connection(self.conn_name_attr) druid_broker_conn = connect( host=conn.host, port=conn.port, @@ -160,7 +163,7 @@ def get_conn(self): self.log.info('Get the connection to druid broker on %s using user %s', conn.host, conn.login) return druid_broker_conn - def get_uri(self): + def get_uri(self) -> str: """ Get the connection uri for druid broker. @@ -175,8 +178,11 @@ def get_uri(self): return '{conn_type}://{host}/{endpoint}'.format( conn_type=conn_type, host=host, endpoint=endpoint) - def set_autocommit(self, conn, autocommit): + def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplemented: raise NotImplementedError() - def insert_rows(self, table, rows, target_fields=None, commit_every=1000): + def insert_rows(self, table: str, rows: Iterable[Tuple[str]], + target_fields: Optional[Iterable[str]] = None, + commit_every: int = 1000, replace: bool = False, + **kwargs: Any) -> NotImplemented: raise NotImplementedError() diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py index 2d768561b0ad..194aa692f6be 100644 --- a/airflow/providers/apache/druid/operators/druid.py +++ b/airflow/providers/apache/druid/operators/druid.py @@ -17,6 +17,7 @@ # under the License. 
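# Usage sketch (illustrative, not part of the patch): DruidHook as typed above.
# submit_indexing_job() now takes the ingestion spec as a parsed dict (Dict[str, Any])
# rather than a JSON string. The spec contents and max_ingestion_time value are placeholders.
from airflow.providers.apache.druid.hooks.druid import DruidHook

druid = DruidHook(
    druid_ingest_conn_id='druid_ingest_default',
    timeout=1,                 # int; values < 1 raise ValueError (see __init__ above)
    max_ingestion_time=3600,   # Optional[int]
)
ingestion_spec = {'type': 'index_hadoop', 'spec': {}}    # placeholder Dict[str, Any]
druid.submit_indexing_job(ingestion_spec)                # -> None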
import json +from typing import Any, Dict, Optional from airflow.models import BaseOperator from airflow.providers.apache.druid.hooks.druid import DruidHook @@ -37,16 +38,16 @@ class DruidOperator(BaseOperator): template_ext = ('.json',) @apply_defaults - def __init__(self, json_index_file, - druid_ingest_conn_id='druid_ingest_default', - max_ingestion_time=None, - *args, **kwargs): + def __init__(self, json_index_file: str, + druid_ingest_conn_id: str = 'druid_ingest_default', + max_ingestion_time: Optional[int] = None, + *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.json_index_file = json_index_file self.conn_id = druid_ingest_conn_id self.max_ingestion_time = max_ingestion_time - def execute(self, context): + def execute(self, context: Dict[Any, Any]) -> None: hook = DruidHook( druid_ingest_conn_id=self.conn_id, max_ingestion_time=self.max_ingestion_time diff --git a/airflow/providers/apache/druid/operators/druid_check.py b/airflow/providers/apache/druid/operators/druid_check.py index eb168770af50..b5a14b3a5e47 100644 --- a/airflow/providers/apache/druid/operators/druid_check.py +++ b/airflow/providers/apache/druid/operators/druid_check.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, Dict, Optional from airflow.exceptions import AirflowException from airflow.operators.check_operator import CheckOperator @@ -57,21 +58,22 @@ class DruidCheckOperator(CheckOperator): @apply_defaults def __init__( - self, - sql: str, - druid_broker_conn_id: str = 'druid_broker_default', - *args, **kwargs) -> None: + self, + sql: str, + druid_broker_conn_id: str = 'druid_broker_default', + *args: Any, **kwargs: Any + ) -> None: super().__init__(sql=sql, *args, **kwargs) self.druid_broker_conn_id = druid_broker_conn_id self.sql = sql - def get_db_hook(self): + def get_db_hook(self) -> DruidDbApiHook: """ Return the druid db api hook. """ return DruidDbApiHook(druid_broker_conn_id=self.druid_broker_conn_id) - def get_first(self, sql): + def get_first(self, sql: str) -> Any: """ Executes the druid sql to druid broker and returns the first resulting row. @@ -82,7 +84,7 @@ def get_first(self, sql): cur.execute(sql) return cur.fetchone() - def execute(self, context=None): + def execute(self, context: Optional[Dict[Any, Any]] = None) -> None: self.log.info('Executing SQL check: %s', self.sql) record = self.get_first(self.sql) self.log.info("Record: %s", str(record)) diff --git a/airflow/providers/apache/druid/transfers/hive_to_druid.py b/airflow/providers/apache/druid/transfers/hive_to_druid.py index e8ba9bc2883a..88f0ae6e3349 100644 --- a/airflow/providers/apache/druid/transfers/hive_to_druid.py +++ b/airflow/providers/apache/druid/transfers/hive_to_druid.py @@ -20,7 +20,7 @@ This module contains operator to move data from Hive to Druid. 
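# Usage sketch (illustrative, not part of the patch): the typed Druid operators above.
# The file path, SQL and datasource name are placeholders.
from airflow.providers.apache.druid.operators.druid import DruidOperator
from airflow.providers.apache.druid.operators.druid_check import DruidCheckOperator

submit_ingest = DruidOperator(
    task_id='submit_ingest',
    json_index_file='specs/index_spec.json',      # str; '.json' paths are template-rendered
    druid_ingest_conn_id='druid_ingest_default',
    max_ingestion_time=3600,                      # Optional[int]
)

check_counts = DruidCheckOperator(
    task_id='check_counts',
    sql='SELECT COUNT(*) FROM my_datasource',
    druid_broker_conn_id='druid_broker_default',
)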
""" -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from airflow.models import BaseOperator from airflow.providers.apache.druid.hooks.druid import DruidHook @@ -84,23 +84,25 @@ class HiveToDruidOperator(BaseOperator): @apply_defaults def __init__( # pylint: disable=too-many-arguments - self, - sql: str, - druid_datasource: str, - ts_dim: str, - metric_spec: Optional[List] = None, - hive_cli_conn_id: str = 'hive_cli_default', - druid_ingest_conn_id: str = 'druid_ingest_default', - metastore_conn_id: str = 'metastore_default', - hadoop_dependency_coordinates: Optional[List[str]] = None, - intervals: Optional[List] = None, - num_shards: float = -1, - target_partition_size: int = -1, - query_granularity: str = "NONE", - segment_granularity: str = "DAY", - hive_tblproperties: Optional[Dict] = None, - job_properties: Optional[Dict] = None, - *args, **kwargs) -> None: + self, + sql: str, + druid_datasource: str, + ts_dim: str, + metric_spec: Optional[List[Any]] = None, + hive_cli_conn_id: str = 'hive_cli_default', + druid_ingest_conn_id: str = 'druid_ingest_default', + metastore_conn_id: str = 'metastore_default', + hadoop_dependency_coordinates: Optional[List[str]] = None, + intervals: Optional[List[Any]] = None, + num_shards: float = -1, + target_partition_size: int = -1, + query_granularity: str = "NONE", + segment_granularity: str = "DAY", + hive_tblproperties: Optional[Dict[Any, Any]] = None, + job_properties: Optional[Dict[Any, Any]] = None, + *args: Any, + **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) self.sql = sql self.druid_datasource = druid_datasource @@ -120,7 +122,7 @@ def __init__( # pylint: disable=too-many-arguments self.hive_tblproperties = hive_tblproperties or {} self.job_properties = job_properties - def execute(self, context): + def execute(self, context: Dict[str, Any]) -> None: hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) self.log.info("Extracting data from Hive") hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_') @@ -172,7 +174,8 @@ def execute(self, context): hql = "DROP TABLE IF EXISTS {}".format(hive_table) hive.run_cli(hql) - def construct_ingest_query(self, static_path, columns): + def construct_ingest_query(self, static_path: str, + columns: List[str]) -> Dict[str, Any]: """ Builds an ingest query for an HDFS TSV load. @@ -199,7 +202,7 @@ def construct_ingest_query(self, static_path, columns): # or a metric, as the dimension columns dimensions = [c for c in columns if c not in metric_names and c != self.ts_dim] - ingest_query_dict = { + ingest_query_dict: Dict[str, Any] = { "type": "index_hadoop", "spec": { "dataSchema": { diff --git a/airflow/providers/apache/hdfs/hooks/hdfs.py b/airflow/providers/apache/hdfs/hooks/hdfs.py index 008f08d024fc..acd0b46aa0d7 100644 --- a/airflow/providers/apache/hdfs/hooks/hdfs.py +++ b/airflow/providers/apache/hdfs/hooks/hdfs.py @@ -16,12 +16,15 @@ # specific language governing permissions and limitations # under the License. 
"""Hook for HDFS operations""" +from typing import Any, Optional + from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook try: from snakebite.client import AutoConfigClient, Client, HAClient, Namenode # pylint: disable=syntax-error + snakebite_loaded = True except ImportError: snakebite_loaded = False @@ -43,8 +46,12 @@ class HDFSHook(BaseHook): :param autoconfig: use snakebite's automatically configured client :type autoconfig: bool """ - def __init__(self, hdfs_conn_id='hdfs_default', proxy_user=None, - autoconfig=False): + + def __init__(self, + hdfs_conn_id: str = 'hdfs_default', + proxy_user: Optional[str] = None, + autoconfig: bool = False + ): super().__init__() if not snakebite_loaded: raise ImportError( @@ -56,7 +63,7 @@ def __init__(self, hdfs_conn_id='hdfs_default', proxy_user=None, self.proxy_user = proxy_user self.autoconfig = autoconfig - def get_conn(self): + def get_conn(self) -> Any: """ Returns a snakebite HDFSClient object. """ diff --git a/airflow/providers/apache/hdfs/hooks/webhdfs.py b/airflow/providers/apache/hdfs/hooks/webhdfs.py index 77c9a029d28a..a72c7b0823a5 100644 --- a/airflow/providers/apache/hdfs/hooks/webhdfs.py +++ b/airflow/providers/apache/hdfs/hooks/webhdfs.py @@ -18,12 +18,14 @@ """Hook for Web HDFS""" import logging import socket +from typing import Any, Optional from hdfs import HdfsError, InsecureClient from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook +from airflow.models.connection import Connection log = logging.getLogger(__name__) @@ -50,12 +52,14 @@ class WebHDFSHook(BaseHook): :type proxy_user: str """ - def __init__(self, webhdfs_conn_id='webhdfs_default', proxy_user=None): + def __init__(self, webhdfs_conn_id: str = 'webhdfs_default', + proxy_user: Optional[str] = None + ): super().__init__() self.webhdfs_conn_id = webhdfs_conn_id self.proxy_user = proxy_user - def get_conn(self): + def get_conn(self) -> Any: """ Establishes a connection depending on the security mode set via config or environment variable. :return: a hdfscli InsecureClient or KerberosClient object. @@ -66,7 +70,7 @@ def get_conn(self): raise AirflowWebHDFSHookException("Failed to locate the valid server.") return connection - def _find_valid_server(self): + def _find_valid_server(self) -> Any: connections = self.get_connections(self.webhdfs_conn_id) for connection in connections: host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -88,7 +92,7 @@ def _find_valid_server(self): connection.host, hdfs_error) return None - def _get_client(self, connection): + def _get_client(self, connection: Connection) -> Any: connection_str = 'http://{host}:{port}'.format(host=connection.host, port=connection.port) if _kerberos_security_mode: @@ -99,7 +103,7 @@ def _get_client(self, connection): return client - def check_for_path(self, hdfs_path): + def check_for_path(self, hdfs_path: str) -> bool: """ Check for the existence of a path in HDFS by querying FileStatus. @@ -113,7 +117,9 @@ def check_for_path(self, hdfs_path): status = conn.status(hdfs_path, strict=False) return bool(status) - def load_file(self, source, destination, overwrite=True, parallelism=1, **kwargs): + def load_file(self, source: str, destination: str, + overwrite: bool = True, parallelism: int = 1, + **kwargs: Any) -> None: r""" Uploads a file to HDFS. 
diff --git a/airflow/providers/apache/hdfs/sensors/hdfs.py b/airflow/providers/apache/hdfs/sensors/hdfs.py index 1ca8ca3d650f..a8cabe2086c7 100644 --- a/airflow/providers/apache/hdfs/sensors/hdfs.py +++ b/airflow/providers/apache/hdfs/sensors/hdfs.py @@ -18,6 +18,7 @@ import logging import re import sys +from typing import Any, Dict, List, Optional, Pattern, Type from airflow import settings from airflow.providers.apache.hdfs.hooks.hdfs import HDFSHook @@ -36,14 +37,14 @@ class HdfsSensor(BaseSensorOperator): @apply_defaults def __init__(self, - filepath, - hdfs_conn_id='hdfs_default', - ignored_ext=None, - ignore_copying=True, - file_size=None, - hook=HDFSHook, - *args, - **kwargs): + filepath: str, + hdfs_conn_id: str = 'hdfs_default', + ignored_ext: Optional[List[str]] = None, + ignore_copying: bool = True, + file_size: Optional[int] = None, + hook: Type[HDFSHook] = HDFSHook, + *args: Any, + **kwargs: Any) -> None: super().__init__(*args, **kwargs) if ignored_ext is None: ignored_ext = ['_COPYING_'] @@ -55,7 +56,10 @@ def __init__(self, self.hook = hook @staticmethod - def filter_for_filesize(result, size=None): + def filter_for_filesize( + result: List[Dict[Any, Any]], + size: Optional[int] = None + ) -> List[Dict[Any, Any]]: """ Will test the filepath result and test if its size is at least self.filesize @@ -74,7 +78,11 @@ def filter_for_filesize(result, size=None): return result @staticmethod - def filter_for_ignored_ext(result, ignored_ext, ignore_copying): + def filter_for_ignored_ext( + result: List[Dict[Any, Any]], + ignored_ext: List[str], + ignore_copying: bool + ) -> List[Dict[Any, Any]]: """ Will filter if instructed to do so the result to remove matching criteria @@ -98,7 +106,7 @@ def filter_for_ignored_ext(result, ignored_ext, ignore_copying): log.debug('HdfsSensor.poke: after ext filter result is %s', result) return result - def poke(self, context): + def poke(self, context: Dict[Any, Any]) -> bool: """Get a snakebite client connection and check for file.""" sb_client = self.hook(self.hdfs_conn_id).get_conn() self.log.info('Poking for file %s', self.filepath) @@ -125,14 +133,15 @@ class HdfsRegexSensor(HdfsSensor): """ Waits for matching files by matching on regex """ + def __init__(self, - regex, - *args, - **kwargs): + regex: Pattern[str], + *args: Any, + **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.regex = regex - def poke(self, context): + def poke(self, context: Dict[Any, Any]) -> bool: """ poke matching files in a directory with self.regex @@ -155,14 +164,15 @@ class HdfsFolderSensor(HdfsSensor): """ Waits for a non-empty directory """ + def __init__(self, - be_empty=False, - *args, - **kwargs): + be_empty: bool = False, + *args: Any, + **kwargs: Any): super().__init__(*args, **kwargs) self.be_empty = be_empty - def poke(self, context): + def poke(self, context: Dict[str, Any]) -> bool: """ poke for a non empty directory diff --git a/airflow/providers/apache/hdfs/sensors/web_hdfs.py b/airflow/providers/apache/hdfs/sensors/web_hdfs.py index 90e4b928fd1e..74e6f22de273 100644 --- a/airflow/providers/apache/hdfs/sensors/web_hdfs.py +++ b/airflow/providers/apache/hdfs/sensors/web_hdfs.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
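# Usage sketch (illustrative, not part of the patch): the typed HDFS sensors above.
# HdfsRegexSensor's ``regex`` parameter is annotated Pattern[str], so a compiled pattern
# is passed here. Paths, the size value and the regex itself are placeholders.
import re

from airflow.providers.apache.hdfs.sensors.hdfs import HdfsRegexSensor, HdfsSensor

wait_for_file = HdfsSensor(
    task_id='wait_for_file',
    filepath='/data/landing/events.csv',
    hdfs_conn_id='hdfs_default',
    ignored_ext=['_COPYING_'],     # Optional[List[str]], defaults to ['_COPYING_']
    ignore_copying=True,
    file_size=1,                   # Optional[int], minimum size filter
)

wait_for_match = HdfsRegexSensor(
    task_id='wait_for_match',
    regex=re.compile(r'events_\d{8}\.csv'),   # Pattern[str]
    filepath='/data/landing/',
)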
+from typing import Any, Dict from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults @@ -28,15 +29,15 @@ class WebHdfsSensor(BaseSensorOperator): @apply_defaults def __init__(self, - filepath, - webhdfs_conn_id='webhdfs_default', - *args, - **kwargs): + filepath: str, + webhdfs_conn_id: str = 'webhdfs_default', + *args: Any, + **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.filepath = filepath self.webhdfs_conn_id = webhdfs_conn_id - def poke(self, context): + def poke(self, context: Dict[Any, Any]) -> bool: from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook hook = WebHDFSHook(self.webhdfs_conn_id) self.log.info('Poking for file %s', self.filepath) diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py index f79fce1f1a9f..f2e9999316dd 100644 --- a/airflow/providers/apache/hive/hooks/hive.py +++ b/airflow/providers/apache/hive/hooks/hive.py @@ -23,7 +23,9 @@ import time from collections import OrderedDict from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Any, Dict, List, Optional, Text, Union +import pandas import unicodecsv as csv from airflow.configuration import conf @@ -37,7 +39,7 @@ HIVE_QUEUE_PRIORITIES = ['VERY_HIGH', 'HIGH', 'NORMAL', 'LOW', 'VERY_LOW'] -def get_context_from_env_var(): +def get_context_from_env_var() -> Dict[Any, Any]: """ Extract context from env variable, e.g. dag_id, task_id and execution_date, so that they can be used inside BashOperator and PythonOperator. @@ -75,20 +77,21 @@ class HiveCliHook(BaseHook): """ def __init__( - self, - hive_cli_conn_id="hive_cli_default", - run_as=None, - mapred_queue=None, - mapred_queue_priority=None, - mapred_job_name=None): + self, + hive_cli_conn_id: str = "hive_cli_default", + run_as: Optional[str] = None, + mapred_queue: Optional[str] = None, + mapred_queue_priority: Optional[str] = None, + mapred_job_name: Optional[str] = None + ) -> None: super().__init__() conn = self.get_connection(hive_cli_conn_id) - self.hive_cli_params = conn.extra_dejson.get('hive_cli_params', '') - self.use_beeline = conn.extra_dejson.get('use_beeline', False) + self.hive_cli_params: str = conn.extra_dejson.get('hive_cli_params', '') + self.use_beeline: bool = conn.extra_dejson.get('use_beeline', False) self.auth = conn.extra_dejson.get('auth', 'noSasl') self.conn = conn self.run_as = run_as - self.sub_process = None + self.sub_process: Any = None if mapred_queue_priority: mapred_queue_priority = mapred_queue_priority.upper() @@ -102,13 +105,13 @@ def __init__( self.mapred_queue_priority = mapred_queue_priority self.mapred_job_name = mapred_job_name - def _get_proxy_user(self): + def _get_proxy_user(self) -> str: """ This function set the proper proxy_user value in case the user overwtire the default. 
""" conn = self.conn - proxy_user_value = conn.extra_dejson.get('proxy_user', "") + proxy_user_value: str = conn.extra_dejson.get('proxy_user', "") if proxy_user_value == "login" and conn.login: return "hive.server2.proxy.user={0}".format(conn.login) if proxy_user_value == "owner" and self.run_as: @@ -117,7 +120,7 @@ def _get_proxy_user(self): return "hive.server2.proxy.user={0}".format(proxy_user_value) return proxy_user_value # The default proxy user (undefined) - def _prepare_cli_cmd(self): + def _prepare_cli_cmd(self) -> List[Any]: """ This function creates the command list from available information """ @@ -156,7 +159,7 @@ def _prepare_cli_cmd(self): return [hive_bin] + cmd_extra + hive_params_list @staticmethod - def _prepare_hiveconf(d): + def _prepare_hiveconf(d: Dict[Any, Any]) -> List[Any]: """ This function prepares a list of hiveconf params from a dictionary of key value pairs. @@ -178,7 +181,12 @@ def _prepare_hiveconf(d): ["{}={}".format(k, v) for k, v in d.items()]) ) - def run_cli(self, hql, schema=None, verbose=True, hive_conf=None): + def run_cli(self, + hql: Union[str, Text], + schema: Optional[str] = None, + verbose: Optional[bool] = True, + hive_conf: Optional[Dict[Any, Any]] = None + ) -> Any: """ Run an hql statement using the hive cli. If hive_conf is specified it should be a dict and the entries will be set as key/value pairs @@ -242,7 +250,7 @@ def run_cli(self, hql, schema=None, verbose=True, hive_conf=None): if verbose: self.log.info("%s", " ".join(hive_cmd)) - sub_process = subprocess.Popen( + sub_process: Any = subprocess.Popen( hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, @@ -264,7 +272,7 @@ def run_cli(self, hql, schema=None, verbose=True, hive_conf=None): return stdout - def test_hql(self, hql): + def test_hql(self, hql: Union[str, Text]) -> None: """ Test an hql statement using the hive cli and EXPLAIN @@ -282,14 +290,14 @@ def test_hql(self, hql): other.append(query_original) elif query.startswith('insert'): insert.append(query_original) - other = ';'.join(other) + other_ = ';'.join(other) for query_set in [create, insert]: for query in query_set: query_preview = ' '.join(query.split())[:50] self.log.info("Testing HQL [%s (...)]", query_preview) if query_set == insert: - query = other + '; explain ' + query + query = other_ + '; explain ' + query else: query = 'explain ' + query try: @@ -308,13 +316,15 @@ def test_hql(self, hql): self.log.info("SUCCESS") def load_df( - self, - df, - table, - field_dict=None, - delimiter=',', - encoding='utf8', - pandas_kwargs=None, **kwargs): + self, + df: pandas.DataFrame, + table: str, + field_dict: Optional[Dict[Any, Any]] = None, + delimiter: str = ',', + encoding: str = 'utf8', + pandas_kwargs: Any = None, + **kwargs: Any + ) -> None: """ Loads a pandas DataFrame into hive. 
@@ -338,18 +348,20 @@ def load_df( :param kwargs: passed to self.load_file """ - def _infer_field_types_from_df(df): + def _infer_field_types_from_df( + df: pandas.DataFrame + ) -> Dict[Any, Any]: dtype_kind_hive_type = { - 'b': 'BOOLEAN', # boolean - 'i': 'BIGINT', # signed integer - 'u': 'BIGINT', # unsigned integer - 'f': 'DOUBLE', # floating-point - 'c': 'STRING', # complex floating-point + 'b': 'BOOLEAN', # boolean + 'i': 'BIGINT', # signed integer + 'u': 'BIGINT', # unsigned integer + 'f': 'DOUBLE', # floating-point + 'c': 'STRING', # complex floating-point 'M': 'TIMESTAMP', # datetime - 'O': 'STRING', # object - 'S': 'STRING', # (byte-)string - 'U': 'STRING', # Unicode - 'V': 'STRING' # void + 'O': 'STRING', # object + 'S': 'STRING', # (byte-)string + 'U': 'STRING', # Unicode + 'V': 'STRING' # void } order_type = OrderedDict() @@ -362,7 +374,6 @@ def _infer_field_types_from_df(df): with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir: with NamedTemporaryFile(dir=tmp_dir, mode="w") as f: - if field_dict is None: field_dict = _infer_field_types_from_df(df) @@ -382,16 +393,17 @@ def _infer_field_types_from_df(df): **kwargs) def load_file( - self, - filepath, - table, - delimiter=",", - field_dict=None, - create=True, - overwrite=True, - partition=None, - recreate=False, - tblproperties=None): + self, + filepath: str, + table: str, + delimiter: str = ",", + field_dict: Optional[Dict[Any, Any]] = None, + create: bool = True, + overwrite: bool = True, + partition: Optional[Dict[str, Any]] = None, + recreate: bool = False, + tblproperties: Optional[Dict[str, Any]] = None + ) -> None: """ Loads a local file into Hive @@ -466,7 +478,7 @@ def load_file( self.log.info(hql) self.run_cli(hql) - def kill(self): + def kill(self) -> None: """ Kill Hive cli command """ @@ -484,23 +496,23 @@ class HiveMetastoreHook(BaseHook): # java short max val MAX_PART_COUNT = 32767 - def __init__(self, metastore_conn_id='metastore_default'): + def __init__(self, metastore_conn_id: str = 'metastore_default') -> None: super().__init__() self.conn_id = metastore_conn_id self.metastore = self.get_metastore_client() - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: # This is for pickling to work despite the thirft hive client not # being pickable state = dict(self.__dict__) del state['metastore'] return state - def __setstate__(self, d): + def __setstate__(self, d: Dict[str, Any]) -> None: self.__dict__.update(d) self.__dict__['metastore'] = self.get_metastore_client() - def get_metastore_client(self): + def get_metastore_client(self) -> Any: """ Returns a Hive thrift client. 
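# Usage sketch (illustrative, not part of the patch): the typed HiveCliHook methods above.
# The query text, table names and DataFrame contents are placeholders.
import pandas

from airflow.providers.apache.hive.hooks.hive import HiveCliHook

hive_cli = HiveCliHook(hive_cli_conn_id='hive_cli_default')

# run_cli(hql: Union[str, Text], schema=None, verbose=True, hive_conf=None) -> Any
stdout = hive_cli.run_cli(
    hql='SELECT COUNT(*) FROM my_table;',
    schema='my_db',
    hive_conf={'mapreduce.job.queuename': 'default'},
)

# load_df(df: pandas.DataFrame, table: str, field_dict=None, delimiter=',', ...) -> None
df = pandas.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
hive_cli.load_df(df=df, table='my_db.my_staging_table', delimiter=',')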
""" @@ -521,14 +533,13 @@ def get_metastore_client(self): conn_socket = TSocket.TSocket(conn.host, conn.port) - if conf.get('core', 'security') == 'kerberos' \ - and auth_mechanism == 'GSSAPI': + if conf.get('core', 'security') == 'kerberos' and auth_mechanism == 'GSSAPI': try: import saslwrapper as sasl except ImportError: import sasl - def sasl_factory(): + def sasl_factory() -> sasl.Client: sasl_client = sasl.Client() sasl_client.setAttr("host", conn.host) sasl_client.setAttr("service", kerberos_service_name) @@ -544,7 +555,7 @@ def sasl_factory(): return hmsclient.HMSClient(iprot=protocol) - def _find_valid_server(self): + def _find_valid_server(self) -> Any: conns = self.get_connections(self.conn_id) for conn in conns: host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -557,10 +568,10 @@ def _find_valid_server(self): self.log.error("Could not connect to %s:%s", conn.host, conn.port) return None - def get_conn(self): + def get_conn(self) -> Any: return self.metastore - def check_for_partition(self, schema, table, partition): + def check_for_partition(self, schema: str, table: str, partition: str) -> bool: """ Checks whether a partition exists @@ -584,16 +595,20 @@ def check_for_partition(self, schema, table, partition): return bool(partitions) - def check_for_named_partition(self, schema, table, partition_name): + def check_for_named_partition(self, + schema: str, + table: str, + partition_name: str + ) -> Any: """ Checks whether a partition with a given name exists :param schema: Name of hive schema (database) @table belongs to :type schema: str :param table: Name of hive table @partition belongs to - :type schema: str + :type table: str :partition: Name of the partitions to check for (eg `a=b/c=d`) - :type schema: str + :type table: str :rtype: bool >>> hh = HiveMetastoreHook() @@ -606,7 +621,7 @@ def check_for_named_partition(self, schema, table, partition_name): with self.metastore as client: return client.check_for_named_partition(schema, table, partition_name) - def get_table(self, table_name, db='default'): + def get_table(self, table_name: str, db: str = 'default') -> Any: """Get a metastore table object >>> hh = HiveMetastoreHook() @@ -621,7 +636,7 @@ def get_table(self, table_name, db='default'): with self.metastore as client: return client.get_table(dbname=db, tbl_name=table_name) - def get_tables(self, db, pattern='*'): + def get_tables(self, db: str, pattern: str = '*') -> Any: """ Get a metastore table object """ @@ -629,14 +644,16 @@ def get_tables(self, db, pattern='*'): tables = client.get_tables(db_name=db, pattern=pattern) return client.get_table_objects_by_name(db, tables) - def get_databases(self, pattern='*'): + def get_databases(self, pattern: str = '*') -> Any: """ Get a metastore table object """ with self.metastore as client: return client.get_databases(pattern) - def get_partitions(self, schema, table_name, partition_filter=None): + def get_partitions(self, schema: str, table_name: str, + partition_filter: Optional[str] = None + ) -> List[Any]: """ Returns a list of all partitions in a table. Works only for tables with less than 32767 (java short max val). 
@@ -668,7 +685,10 @@ def get_partitions(self, schema, table_name, partition_filter=None): return [dict(zip(pnames, p.values)) for p in parts] @staticmethod - def _get_max_partition_from_part_specs(part_specs, partition_key, filter_map): + def _get_max_partition_from_part_specs(part_specs: List[Any], + partition_key: Optional[str], + filter_map: Optional[Dict[str, Any]] + ) -> Any: """ Helper method to get max partition of partitions with partition_key from part specs. key:value pair in filter_map will be used to @@ -693,6 +713,7 @@ def _get_max_partition_from_part_specs(part_specs, partition_key, filter_map): if partition_key not in part_specs[0].keys(): raise AirflowException("Provided partition_key {} " "is not in part_specs.".format(partition_key)) + is_subset = None if filter_map: is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys())) if filter_map and not is_subset: @@ -710,7 +731,10 @@ def _get_max_partition_from_part_specs(part_specs, partition_key, filter_map): else: return max(candidates) - def max_partition(self, schema, table_name, field=None, filter_map=None): + def max_partition(self, schema: str, table_name: str, + field: Optional[str] = None, + filter_map: Optional[Dict[Any, Any]] = None + ) -> Any: """ Returns the maximum value for all partitions with given field in a table. If only one partition key exist in the table, the key will be used as field. @@ -759,7 +783,7 @@ def max_partition(self, schema, table_name, field=None, filter_map=None): field, filter_map) - def table_exists(self, table_name, db='default'): + def table_exists(self, table_name: str, db: str = 'default') -> bool: """ Check if table exists @@ -791,11 +815,15 @@ class HiveServer2Hook(DbApiHook): default_conn_name = 'hiveserver2_default' supports_autocommit = False - def get_conn(self, schema=None): + def get_conn(self, schema: Optional[str] = None + ) -> Any: """ Returns a Hive connection object. """ - db = self.get_connection(self.hiveserver2_conn_id) # pylint: disable=no-member + username: Optional[str] = None + # pylint: disable=no-member + db = self.get_connection(self.hiveserver2_conn_id) # type: ignore + auth_mechanism = db.extra_dejson.get('authMechanism', 'NONE') if auth_mechanism == 'NONE' and db.login is None: # we need to give a username @@ -810,7 +838,7 @@ def get_conn(self, schema=None): self.log.warning( "Detected deprecated 'GSSAPI' for authMechanism " "for %s. Please use 'KERBEROS' instead", - self.hiveserver2_conn_id # pylint: disable=no-member + self.hiveserver2_conn_id # type: ignore ) auth_mechanism = 'KERBEROS' @@ -824,17 +852,23 @@ def get_conn(self, schema=None): password=db.password, database=schema or db.schema or 'default') - def _get_results(self, hql, schema='default', fetch_size=None, hive_conf=None): + # pylint: enable=no-member + + def _get_results(self, hql: Union[str, Text, List[str]], schema: str = 'default', + fetch_size: Optional[int] = None, + hive_conf: Optional[Dict[Any, Any]] = None) -> Any: from pyhive.exc import ProgrammingError if isinstance(hql, str): hql = [hql] previous_description = None - with contextlib.closing(self.get_conn(schema)) as conn, \ - contextlib.closing(conn.cursor()) as cur: + with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur: + cur.arraysize = fetch_size or 1000 # not all query services (e.g. 
impala AIRFLOW-4434) support the set command - db = self.get_connection(self.hiveserver2_conn_id) # pylint: disable=no-member + # pylint: disable=no-member + db = self.get_connection(self.hiveserver2_conn_id) # type: ignore + # pylint: enable=no-member if db.extra_dejson.get('run_set_variable_statements', True): env_context = get_context_from_env_var() if hive_conf: @@ -869,7 +903,10 @@ def _get_results(self, hql, schema='default', fetch_size=None, hive_conf=None): except ProgrammingError: self.log.debug("get_results returned no records") - def get_results(self, hql, schema='default', fetch_size=None, hive_conf=None): + def get_results(self, hql: Union[str, Text], schema: str = 'default', + fetch_size: Optional[int] = None, + hive_conf: Optional[Dict[Any, Any]] = None + ) -> Dict[str, Any]: """ Get results of the provided hql in target schema. @@ -894,15 +931,16 @@ def get_results(self, hql, schema='default', fetch_size=None, hive_conf=None): return results def to_csv( - self, - hql, - csv_filepath, - schema='default', - delimiter=',', - lineterminator='\r\n', - output_header=True, - fetch_size=1000, - hive_conf=None): + self, + hql: Union[str, Text], + csv_filepath: str, + schema: str = 'default', + delimiter: str = ',', + lineterminator: str = '\r\n', + output_header: bool = True, + fetch_size: int = 1000, + hive_conf: Optional[Dict[Any, Any]] = None + ) -> None: """ Execute hql in target schema and write results to a csv file. @@ -955,7 +993,10 @@ def to_csv( self.log.info("Done. Loaded a total of %s rows.", i) - def get_records(self, hql, schema='default', hive_conf=None): + def get_records(self, hql: Union[str, Text], + schema: str = 'default', + hive_conf: Optional[Dict[Any, Any]] = None + ) -> Any: """ Get a set of records from a Hive query. @@ -975,7 +1016,10 @@ def get_records(self, hql, schema='default', hive_conf=None): """ return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data'] - def get_pandas_df(self, hql, schema='default', hive_conf=None): + def get_pandas_df(self, hql: Union[str, Text], + schema: str = 'default', + hive_conf: Optional[Dict[Any, Any]] = None + ) -> pandas.DataFrame: """ Get a pandas dataframe from a Hive query @@ -996,8 +1040,7 @@ def get_pandas_df(self, hql, schema='default', hive_conf=None): :return: pandas.DateFrame """ - import pandas as pd res = self.get_results(hql, schema=schema, hive_conf=hive_conf) - df = pd.DataFrame(res['data']) + df = pandas.DataFrame(res['data']) df.columns = [c[0] for c in res['header']] return df diff --git a/airflow/providers/apache/hive/operators/hive.py b/airflow/providers/apache/hive/operators/hive.py index f82a2446bb1f..6f74ba5d8d38 100644 --- a/airflow/providers/apache/hive/operators/hive.py +++ b/airflow/providers/apache/hive/operators/hive.py @@ -15,10 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
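# Usage sketch (illustrative, not part of the patch): the typed HiveMetastoreHook and
# HiveServer2Hook query helpers above. Schema, table and query text are placeholders.
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook, HiveServer2Hook

metastore = HiveMetastoreHook(metastore_conn_id='metastore_default')
if metastore.table_exists('my_table', db='my_db'):                       # -> bool
    latest_ds = metastore.max_partition(schema='my_db', table_name='my_table', field='ds')

hs2 = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
records = hs2.get_records(                                               # rows from the query
    hql='SELECT part, COUNT(*) FROM my_db.my_table GROUP BY part',
    schema='my_db',
)
df = hs2.get_pandas_df(                                                  # -> pandas.DataFrame
    hql='SELECT * FROM my_db.my_table LIMIT 10',
    schema='my_db',
)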
- import os import re -from typing import Dict, Optional +from typing import Any, Dict, Optional, Tuple from airflow.configuration import conf from airflow.models import BaseOperator @@ -75,14 +74,16 @@ def __init__( hql: str, hive_cli_conn_id: str = 'hive_cli_default', schema: str = 'default', - hiveconfs: Optional[Dict] = None, + hiveconfs: Optional[Dict[Any, Any]] = None, hiveconf_jinja_translate: bool = False, script_begin_tag: Optional[str] = None, run_as_owner: bool = False, mapred_queue: Optional[str] = None, mapred_queue_priority: Optional[str] = None, mapred_job_name: Optional[str] = None, - *args, **kwargs) -> None: + *args: Tuple[Any, ...], + **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) self.hql = hql @@ -105,9 +106,9 @@ def __init__( # `None` initial value, later it will be populated by the execute method. # This also makes `on_kill` implementation consistent since it assumes `self.hook` # is defined. - self.hook = None + self.hook: Optional[HiveCliHook] = None - def get_hook(self): + def get_hook(self) -> HiveCliHook: """ Get Hive cli hook """ @@ -118,14 +119,14 @@ def get_hook(self): mapred_queue_priority=self.mapred_queue_priority, mapred_job_name=self.mapred_job_name) - def prepare_template(self): + def prepare_template(self) -> None: if self.hiveconf_jinja_translate: self.hql = re.sub( r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql) if self.script_begin_tag and self.script_begin_tag in self.hql: self.hql = "\n".join(self.hql.split(self.script_begin_tag)[1:]) - def execute(self, context): + def execute(self, context: Dict[str, Any]) -> None: self.log.info('Executing: %s', self.hql) self.hook = self.get_hook() @@ -145,7 +146,7 @@ def execute(self, context): self.log.info('Passing HiveConf: %s', self.hiveconfs) self.hook.run_cli(hql=self.hql, schema=self.schema, hive_conf=self.hiveconfs) - def dry_run(self): + def dry_run(self) -> None: # Reset airflow environment variables to prevent # existing env vars from impacting behavior. self.clear_airflow_vars() @@ -153,11 +154,11 @@ def dry_run(self): self.hook = self.get_hook() self.hook.test_hql(hql=self.hql) - def on_kill(self): + def on_kill(self) -> None: if self.hook: self.hook.kill() - def clear_airflow_vars(self): + def clear_airflow_vars(self) -> None: """ Reset airflow environment variables to prevent existing ones from impacting behavior. """ diff --git a/airflow/providers/apache/hive/operators/hive_stats.py b/airflow/providers/apache/hive/operators/hive_stats.py index 96504c9451da..7e161d3ed03f 100644 --- a/airflow/providers/apache/hive/operators/hive_stats.py +++ b/airflow/providers/apache/hive/operators/hive_stats.py @@ -15,11 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
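# Usage sketch (illustrative, not part of the patch): HiveOperator with the typed signature
# above. The HQL, queue and job names are placeholders.
from airflow.providers.apache.hive.operators.hive import HiveOperator

run_hql = HiveOperator(
    task_id='run_hql',
    hql='INSERT OVERWRITE TABLE my_db.my_table SELECT * FROM my_db.staging',
    hive_cli_conn_id='hive_cli_default',
    schema='default',
    hiveconfs={'hive.exec.dynamic.partition': 'true'},   # Optional[Dict[Any, Any]]
    mapred_queue='default',
    mapred_job_name='airflow_run_hql',
)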
- import json import warnings from collections import OrderedDict -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -66,14 +65,16 @@ class HiveStatsCollectionOperator(BaseOperator): @apply_defaults def __init__(self, table: str, - partition: str, - extra_exprs: Optional[Dict] = None, - excluded_columns: Optional[List] = None, - assignment_func: Optional[Callable[[str, str], Optional[Dict]]] = None, + partition: Any, + extra_exprs: Optional[Dict[str, Any]] = None, + excluded_columns: Optional[List[str]] = None, + assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None, metastore_conn_id: str = 'metastore_default', presto_conn_id: str = 'presto_default', mysql_conn_id: str = 'airflow_db', - *args, **kwargs) -> None: + *args: Tuple[Any, ...], + **kwargs: Any + ) -> None: if 'col_blacklist' in kwargs: warnings.warn( 'col_blacklist kwarg passed to {c} (task_id: {t}) is deprecated, please rename it to ' @@ -87,7 +88,7 @@ def __init__(self, self.table = table self.partition = partition self.extra_exprs = extra_exprs or {} - self.excluded_columns = excluded_columns or [] # type: List + self.excluded_columns = excluded_columns or [] # type: List[str] self.metastore_conn_id = metastore_conn_id self.presto_conn_id = presto_conn_id self.mysql_conn_id = mysql_conn_id @@ -95,7 +96,7 @@ def __init__(self, self.ds = '{{ ds }}' self.dttm = '{{ execution_date.isoformat() }}' - def get_default_exprs(self, col, col_type): + def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]: """ Get default expressions """ @@ -116,12 +117,12 @@ def get_default_exprs(self, col, col_type): return exp - def execute(self, context=None): + def execute(self, context: Optional[Dict[str, Any]] = None) -> None: metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id) table = metastore.get_table(table_name=self.table) field_types = {col.name: col.type for col in table.sd.cols} - exprs = { + exprs: Any = { ('', 'count'): 'COUNT(*)' } for col, col_type in list(field_types.items()): @@ -138,8 +139,8 @@ def execute(self, context=None): v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()]) - where_clause = ["{} = '{}'".format(k, v) for k, v in self.partition.items()] - where_clause = " AND\n ".join(where_clause) + where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()] + where_clause = " AND\n ".join(where_clause_) sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format( exprs_str=exprs_str, table=self.table, where_clause=where_clause) diff --git a/airflow/providers/apache/hive/sensors/hive_partition.py b/airflow/providers/apache/hive/sensors/hive_partition.py index f705039b1f57..06762331022e 100644 --- a/airflow/providers/apache/hive/sensors/hive_partition.py +++ b/airflow/providers/apache/hive/sensors/hive_partition.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
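# Usage sketch (illustrative, not part of the patch): HiveStatsCollectionOperator as typed
# above. execute() iterates self.partition.items() to build the WHERE clause, so a mapping
# of partition column to value is what gets passed in practice (hence the annotation was
# loosened from str to Any). The table, partition and column values are placeholders.
from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator

collect_stats = HiveStatsCollectionOperator(
    task_id='collect_stats',
    table='my_db.my_table',
    partition={'ds': '{{ ds }}'},        # rendered into "ds = '...'" in the WHERE clause
    excluded_columns=['raw_payload'],    # Optional[List[str]]
    metastore_conn_id='metastore_default',
    presto_conn_id='presto_default',
    mysql_conn_id='airflow_db',
)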
+from typing import Any, Dict, Optional, Tuple from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -46,12 +47,13 @@ class HivePartitionSensor(BaseSensorOperator): @apply_defaults def __init__(self, - table, partition="ds='{{ ds }}'", - metastore_conn_id='metastore_default', - schema='default', - poke_interval=60 * 3, - *args, - **kwargs): + table: str, + partition: Optional[str] = "ds='{{ ds }}'", + metastore_conn_id: str = 'metastore_default', + schema: str = 'default', + poke_interval: int = 60 * 3, + *args: Tuple[Any, ...], + **kwargs: Any): super().__init__( poke_interval=poke_interval, *args, **kwargs) if not partition: @@ -61,7 +63,7 @@ def __init__(self, self.partition = partition self.schema = schema - def poke(self, context): + def poke(self, context: Dict[str, Any]) -> bool: if '.' in self.table: self.schema, self.table = self.table.split('.') self.log.info( diff --git a/airflow/providers/apache/hive/sensors/metastore_partition.py b/airflow/providers/apache/hive/sensors/metastore_partition.py index a39abe9f2a51..04f688ab1990 100644 --- a/airflow/providers/apache/hive/sensors/metastore_partition.py +++ b/airflow/providers/apache/hive/sensors/metastore_partition.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, Dict, Tuple from airflow.sensors.sql_sensor import SqlSensor from airflow.utils.decorators import apply_defaults @@ -45,12 +46,12 @@ class MetastorePartitionSensor(SqlSensor): @apply_defaults def __init__(self, - table, - partition_name, - schema="default", - mysql_conn_id="metastore_mysql", - *args, - **kwargs): + table: str, + partition_name: str, + schema: str = "default", + mysql_conn_id: str = "metastore_mysql", + *args: Tuple[Any, ...], + **kwargs: Any): self.partition_name = partition_name self.table = table @@ -64,7 +65,7 @@ def __init__(self, # constructor below and apply_defaults will no longer throw an exception. super().__init__(*args, **kwargs) - def poke(self, context): + def poke(self, context: Dict[str, Any]) -> Any: if self.first_poke: self.first_poke = False if '.' in self.table: diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py index d1c838e87e8e..acab91823796 100644 --- a/airflow/providers/apache/hive/sensors/named_hive_partition.py +++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
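# Usage sketch (illustrative, not part of the patch): the typed Hive partition sensors above.
# Table and partition values are placeholders.
from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor
from airflow.providers.apache.hive.sensors.metastore_partition import MetastorePartitionSensor

wait_for_partition = HivePartitionSensor(
    task_id='wait_for_partition',
    table='my_db.my_table',
    partition="ds='{{ ds }}'",            # Optional[str], defaults to "ds='{{ ds }}'"
    metastore_conn_id='metastore_default',
    poke_interval=60 * 3,
)

wait_via_metastore_db = MetastorePartitionSensor(
    task_id='wait_via_metastore_db',
    table='my_db.my_table',
    partition_name='ds={{ ds }}',
    mysql_conn_id='metastore_mysql',
)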
+from typing import Any, Dict, List, Tuple from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.utils.decorators import apply_defaults @@ -42,12 +43,12 @@ class NamedHivePartitionSensor(BaseSensorOperator): @apply_defaults def __init__(self, - partition_names, - metastore_conn_id='metastore_default', - poke_interval=60 * 3, - hook=None, - *args, - **kwargs): + partition_names: List[str], + metastore_conn_id: str = 'metastore_default', + poke_interval: int = 60 * 3, + hook: Any = None, + *args: Tuple[Any, ...], + **kwargs: Any): super().__init__( poke_interval=poke_interval, *args, **kwargs) @@ -64,7 +65,7 @@ def __init__(self, ) @staticmethod - def parse_partition_name(partition): + def parse_partition_name(partition: str) -> Tuple[Any, ...]: """Get schema, table, and partition info.""" first_split = partition.split('.', 1) if len(first_split) == 1: @@ -80,7 +81,7 @@ def parse_partition_name(partition): table, partition = second_split return schema, table, partition - def poke_partition(self, partition): + def poke_partition(self, partition: str) -> Any: """Check for a named partition.""" if not self.hook: from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook @@ -93,7 +94,7 @@ def poke_partition(self, partition): return self.hook.check_for_named_partition( schema, table, partition) - def poke(self, context): + def poke(self, context: Dict[str, Any]) -> bool: number_of_partitions = len(self.partition_names) poke_index_start = self.next_index_to_poke diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index c826640e310a..f0030872bddf 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -18,10 +18,10 @@ """ This module contains the Apache Livy hook. 
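# Usage sketch (illustrative, not part of the patch): NamedHivePartitionSensor as typed
# above. parse_partition_name() is now annotated to return a Tuple[Any, ...] of
# (schema, table, partition); the partition names used here are placeholders.
from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor

schema, table, partition = NamedHivePartitionSensor.parse_partition_name(
    'my_db.my_table/ds=2020-01-01')       # -> ('my_db', 'my_table', 'ds=2020-01-01')

wait_for_partitions = NamedHivePartitionSensor(
    task_id='wait_for_partitions',
    partition_names=['my_db.my_table/ds={{ ds }}'],   # List[str]
    metastore_conn_id='metastore_default',
)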
""" - import json import re from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, Union import requests @@ -70,10 +70,10 @@ class LivyHook(HttpHook, LoggingMixin): 'Accept': 'application/json' } - def __init__(self, livy_conn_id='livy_default'): + def __init__(self, livy_conn_id: str = 'livy_default') -> None: super(LivyHook, self).__init__(http_conn_id=livy_conn_id) - def get_conn(self, headers=None): + def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any: """ Returns http session for use with requests @@ -87,7 +87,14 @@ def get_conn(self, headers=None): tmp_headers.update(headers) return super().get_conn(tmp_headers) - def run_method(self, method='GET', endpoint=None, data=None, headers=None, extra_options=None): + def run_method( + self, + endpoint: str, + method: str = 'GET', + data: Optional[Any] = None, + headers: Optional[Dict[str, Any]] = None, + extra_options: Optional[Dict[Any, Any]] = None + ) -> Any: """ Wrapper for HttpHook, allows to change method on the same HttpHook @@ -117,7 +124,7 @@ def run_method(self, method='GET', endpoint=None, data=None, headers=None, extra self.method = back_method return result - def post_batch(self, *args, **kwargs): + def post_batch(self, *args: Any, **kwargs: Any) -> Any: """ Perform request to submit batch @@ -153,7 +160,7 @@ def post_batch(self, *args, **kwargs): return batch_id - def get_batch(self, session_id): + def get_batch(self, session_id: Union[int, str]) -> Any: """ Fetch info about the specified batch @@ -178,7 +185,7 @@ def get_batch(self, session_id): return response.json() - def get_batch_state(self, session_id): + def get_batch_state(self, session_id: Union[int, str]) -> BatchState: """ Fetch the state of the specified batch @@ -206,7 +213,7 @@ def get_batch_state(self, session_id): raise AirflowException("Unable to get state for batch with id: {}".format(session_id)) return BatchState(jresp['state']) - def delete_batch(self, session_id): + def delete_batch(self, session_id: Union[int, str]) -> Any: """ Delete the specified batch @@ -235,7 +242,7 @@ def delete_batch(self, session_id): return response.json() @staticmethod - def _validate_session_id(session_id): + def _validate_session_id(session_id: Union[int, str]) -> None: """ Validate session id is a int @@ -248,36 +255,36 @@ def _validate_session_id(session_id): raise TypeError("'session_id' must be an integer") @staticmethod - def _parse_post_response(response): + def _parse_post_response(response: Dict[Any, Any]) -> Any: """ Parse batch response for batch id :param response: response body :type response: dict :return: session id - :rtype: str + :rtype: int """ return response.get('id') @staticmethod def build_post_batch_body( - file, - args=None, - class_name=None, - jars=None, - py_files=None, - files=None, - archives=None, - name=None, - driver_memory=None, - driver_cores=None, - executor_memory=None, - executor_cores=None, - num_executors=None, - queue=None, - proxy_user=None, - conf=None - ): + file: str, + args: Optional[Sequence[Union[str, int, float]]] = None, + class_name: Optional[str] = None, + jars: Optional[List[str]] = None, + py_files: Optional[List[str]] = None, + files: Optional[List[str]] = None, + archives: Optional[List[str]] = None, + name: Optional[str] = None, + driver_memory: Optional[str] = None, + driver_cores: Optional[Union[int, str]] = None, + executor_memory: Optional[str] = None, + executor_cores: Optional[int] = None, + num_executors: Optional[Union[int, str]] = None, + queue: Optional[str] = None, + 
proxy_user: Optional[str] = None, + conf: Optional[Dict[Any, Any]] = None + ) -> Any: """ Build the post batch request body. For more information about the format refer to @@ -320,7 +327,7 @@ def build_post_batch_body( """ # pylint: disable-msg=too-many-arguments - body = {'file': file} + body: Dict[str, Any] = {'file': file} if proxy_user: body['proxyUser'] = proxy_user @@ -356,7 +363,7 @@ def build_post_batch_body( return body @staticmethod - def _validate_size_format(size): + def _validate_size_format(size: str) -> bool: """ Validate size format. @@ -370,7 +377,7 @@ def _validate_size_format(size): return True @staticmethod - def _validate_list_of_stringables(vals): + def _validate_list_of_stringables(vals: Sequence[Union[str, int, float]]) -> bool: """ Check the values in the provided list can be converted to strings. @@ -386,7 +393,7 @@ def _validate_list_of_stringables(vals): return True @staticmethod - def _validate_extra_conf(conf): + def _validate_extra_conf(conf: Dict[Any, Any]) -> bool: """ Check configuration values are either strings or ints. diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index d105c7a028ce..44cc29d8ee4a 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -18,8 +18,8 @@ """ This module contains the Apache Livy operator. """ - from time import sleep +from typing import Any, Dict, Optional, Sequence, Union from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -75,26 +75,26 @@ class LivyOperator(BaseOperator): @apply_defaults def __init__( self, - file, - class_name=None, - args=None, - conf=None, - jars=None, - py_files=None, - files=None, - driver_memory=None, - driver_cores=None, - executor_memory=None, - executor_cores=None, - num_executors=None, - archives=None, - queue=None, - name=None, - proxy_user=None, - livy_conn_id='livy_default', - polling_interval=0, - **kwargs - ): + file: str, + class_name: Optional[str] = None, + args: Optional[Sequence[Union[str, int, float]]] = None, + conf: Optional[Dict[Any, Any]] = None, + jars: Optional[Sequence[str]] = None, + py_files: Optional[Sequence[str]] = None, + files: Optional[Sequence[str]] = None, + driver_memory: Optional[str] = None, + driver_cores: Optional[Union[int, str]] = None, + executor_memory: Optional[str] = None, + executor_cores: Optional[Union[int, str]] = None, + num_executors: Optional[Union[int, str]] = None, + archives: Optional[Sequence[str]] = None, + queue: Optional[str] = None, + name: Optional[str] = None, + proxy_user: Optional[str] = None, + livy_conn_id: str = 'livy_default', + polling_interval: int = 0, + **kwargs: Any + ) -> None: # pylint: disable-msg=too-many-arguments super().__init__(**kwargs) @@ -121,10 +121,10 @@ def __init__( self._livy_conn_id = livy_conn_id self._polling_interval = polling_interval - self._livy_hook = None - self._batch_id = None + self._livy_hook: Optional[LivyHook] = None + self._batch_id: Union[int, str] - def get_hook(self): + def get_hook(self) -> LivyHook: """ Get valid hook. 
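# Usage sketch (illustrative, not part of the patch): LivyHook as typed above. Note that
# run_method() now takes ``endpoint`` as its first parameter (method defaults to 'GET'),
# so keyword arguments are the safest way to call it. post_batch() is shown forwarding the
# same keyword arguments that build_post_batch_body() accepts, mirroring how LivyOperator
# passes **self.spark_params; the file path, Spark settings and endpoint string are placeholders.
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook

livy = LivyHook(livy_conn_id='livy_default')

batch_id = livy.post_batch(
    file='hdfs://namenode/jobs/my_job.py',
    args=['--date', '2020-01-01'],
    executor_memory='2g',
    num_executors=4,
    conf={'spark.yarn.queue': 'default'},
)

state = livy.get_batch_state(batch_id)    # -> BatchState enum member
if state not in livy.TERMINAL_STATES:
    livy.run_method(endpoint='/batches/{}'.format(batch_id), method='GET')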
@@ -135,7 +135,7 @@ def get_hook(self): self._livy_hook = LivyHook(livy_conn_id=self._livy_conn_id) return self._livy_hook - def execute(self, context): + def execute(self, context: Dict[Any, Any]) -> Any: self._batch_id = self.get_hook().post_batch(**self.spark_params) if self._polling_interval > 0: @@ -143,7 +143,7 @@ def execute(self, context): return self._batch_id - def poll_for_termination(self, batch_id): + def poll_for_termination(self, batch_id: Union[int, str]) -> None: """ Pool Livy for batch termination. @@ -160,10 +160,10 @@ def poll_for_termination(self, batch_id): if state != BatchState.SUCCESS: raise AirflowException("Batch {} did not succeed".format(batch_id)) - def on_kill(self): + def on_kill(self) -> None: self.kill() - def kill(self): + def kill(self) -> None: """ Delete the current batch session. """ diff --git a/airflow/providers/apache/livy/sensors/livy.py b/airflow/providers/apache/livy/sensors/livy.py index c4fb02442134..023946499801 100644 --- a/airflow/providers/apache/livy/sensors/livy.py +++ b/airflow/providers/apache/livy/sensors/livy.py @@ -18,6 +18,7 @@ """ This module contains the Apache Livy sensor. """ +from typing import Any, Dict, Optional, Union from airflow.providers.apache.livy.hooks.livy import LivyHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -39,17 +40,17 @@ class LivySensor(BaseSensorOperator): @apply_defaults def __init__( self, - livy_conn_id='livy_default', - batch_id=None, - *vargs, - **kwargs - ): + batch_id: Union[int, str], + livy_conn_id: str = 'livy_default', + *vargs: Any, + **kwargs: Any + ) -> None: super().__init__(*vargs, **kwargs) self._livy_conn_id = livy_conn_id self._batch_id = batch_id - self._livy_hook = None + self._livy_hook: Optional[LivyHook] = None - def get_hook(self): + def get_hook(self) -> LivyHook: """ Get valid hook. @@ -60,8 +61,8 @@ def get_hook(self): self._livy_hook = LivyHook(livy_conn_id=self._livy_conn_id) return self._livy_hook - def poke(self, context): + def poke(self, context: Dict[Any, Any]) -> bool: batch_id = self._batch_id status = self.get_hook().get_batch_state(batch_id) - return status in self._livy_hook.TERMINAL_STATES + return status in self.get_hook().TERMINAL_STATES diff --git a/airflow/providers/apache/pig/hooks/pig.py b/airflow/providers/apache/pig/hooks/pig.py index 6a938aaa8ed5..8baee6c73717 100644 --- a/airflow/providers/apache/pig/hooks/pig.py +++ b/airflow/providers/apache/pig/hooks/pig.py @@ -15,9 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
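# Usage sketch (illustrative, not part of the patch): LivyOperator and LivySensor after the
# changes above. ``batch_id`` is now a required LivySensor argument (Union[int, str]) rather
# than an optional keyword. The file path and the literal batch id are placeholders; in a
# DAG the id would typically come from the operator's return value.
from airflow.providers.apache.livy.operators.livy import LivyOperator
from airflow.providers.apache.livy.sensors.livy import LivySensor

submit_job = LivyOperator(
    task_id='submit_job',
    file='hdfs://namenode/jobs/my_job.py',
    polling_interval=0,                  # 0 skips poll_for_termination(); execute() returns the batch id
    livy_conn_id='livy_default',
)

wait_for_job = LivySensor(
    task_id='wait_for_job',
    batch_id=42,                         # Union[int, str]; required, no default any more
    livy_conn_id='livy_default',
)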
- import subprocess from tempfile import NamedTemporaryFile, TemporaryDirectory +from typing import Any, List, Optional from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook @@ -35,14 +35,15 @@ class PigCliHook(BaseHook): def __init__( self, - pig_cli_conn_id="pig_cli_default"): + pig_cli_conn_id: str = "pig_cli_default") -> None: super().__init__() conn = self.get_connection(pig_cli_conn_id) self.pig_properties = conn.extra_dejson.get('pig_properties', '') self.conn = conn self.sub_process = None - def run_cli(self, pig, pig_opts=None, verbose=True): + def run_cli(self, pig: str, pig_opts: Optional[str] = None, + verbose: bool = True) -> Any: """ Run an pig script using the pig cli @@ -58,7 +59,7 @@ def run_cli(self, pig, pig_opts=None, verbose=True): f.flush() fname = f.name pig_bin = 'pig' - cmd_extra = [] + cmd_extra: List[str] = [] pig_cmd = [pig_bin] @@ -73,7 +74,7 @@ def run_cli(self, pig, pig_opts=None, verbose=True): if verbose: self.log.info("%s", " ".join(pig_cmd)) - sub_process = subprocess.Popen( + sub_process: Any = subprocess.Popen( pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, @@ -92,11 +93,11 @@ def run_cli(self, pig, pig_opts=None, verbose=True): return stdout - def kill(self): + def kill(self) -> None: """ Kill Pig job """ if self.sub_process: if self.sub_process.poll() is None: - print("Killing the Pig job") + self.log.info("Killing the Pig job") self.sub_process.kill() diff --git a/airflow/providers/apache/pig/operators/pig.py b/airflow/providers/apache/pig/operators/pig.py index 99da407f4071..8d3b08f437f7 100644 --- a/airflow/providers/apache/pig/operators/pig.py +++ b/airflow/providers/apache/pig/operators/pig.py @@ -15,9 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
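# Usage sketch (illustrative, not part of the patch): PigCliHook as typed above. The Pig
# script and options are placeholders.
from airflow.providers.apache.pig.hooks.pig import PigCliHook

pig = PigCliHook(pig_cli_conn_id='pig_cli_default')
stdout = pig.run_cli(
    pig="data = LOAD '/data/input' AS (line:chararray); DUMP data;",
    pig_opts='-x local',                 # Optional[str]
    verbose=True,
)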
- import re -from typing import Optional +from typing import Any, Optional, Tuple from airflow.models import BaseOperator from airflow.providers.apache.pig.hooks.pig import PigCliHook @@ -53,7 +52,8 @@ def __init__( pig_cli_conn_id: str = 'pig_cli_default', pigparams_jinja_translate: bool = False, pig_opts: Optional[str] = None, - *args, **kwargs) -> None: + *args: Tuple[Any, ...], + **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.pigparams_jinja_translate = pigparams_jinja_translate diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py index 89b106fdc19d..24369455f53a 100644 --- a/airflow/providers/apache/pinot/hooks/pinot.py +++ b/airflow/providers/apache/pinot/hooks/pinot.py @@ -18,13 +18,14 @@ import os import subprocess -from typing import Optional +from typing import Any, Dict, Iterable, List, Optional, Union from pinotdb import connect from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow.hooks.dbapi_hook import DbApiHook +from airflow.models import Connection class PinotAdminHook(BaseHook): @@ -54,9 +55,10 @@ class PinotAdminHook(BaseHook): """ def __init__(self, - conn_id="pinot_admin_default", - cmd_path="pinot-admin.sh", - pinot_admin_system_exit=False): + conn_id: str = "pinot_admin_default", + cmd_path: str = "pinot-admin.sh", + pinot_admin_system_exit: bool = False + ) -> None: super().__init__() conn = self.get_connection(conn_id) self.host = conn.host @@ -66,10 +68,12 @@ def __init__(self, pinot_admin_system_exit) self.conn = conn - def get_conn(self): + def get_conn(self) -> Any: return self.conn - def add_schema(self, schema_file: str, with_exec: Optional[bool] = True): + def add_schema(self, schema_file: str, + with_exec: Optional[bool] = True + ) -> Any: """ Add Pinot schema by run AddSchema command @@ -86,7 +90,9 @@ def add_schema(self, schema_file: str, with_exec: Optional[bool] = True): cmd += ["-exec"] self.run_cli(cmd) - def add_table(self, file_path: str, with_exec: Optional[bool] = True): + def add_table(self, file_path: str, + with_exec: Optional[bool] = True + ) -> Any: """ Add Pinot table with run AddTable command @@ -105,24 +111,25 @@ def add_table(self, file_path: str, with_exec: Optional[bool] = True): # pylint: disable=too-many-arguments def create_segment(self, - generator_config_file=None, - data_dir=None, - segment_format=None, - out_dir=None, - overwrite=None, - table_name=None, - segment_name=None, - time_column_name=None, - schema_file=None, - reader_config_file=None, - enable_star_tree_index=None, - star_tree_index_spec_file=None, - hll_size=None, - hll_columns=None, - hll_suffix=None, - num_threads=None, - post_creation_verification=None, - retry=None): + generator_config_file: Optional[str] = None, + data_dir: Optional[str] = None, + segment_format: Optional[str] = None, + out_dir: Optional[str] = None, + overwrite: Optional[str] = None, + table_name: Optional[str] = None, + segment_name: Optional[str] = None, + time_column_name: Optional[str] = None, + schema_file: Optional[str] = None, + reader_config_file: Optional[str] = None, + enable_star_tree_index: Optional[str] = None, + star_tree_index_spec_file: Optional[str] = None, + hll_size: Optional[str] = None, + hll_columns: Optional[str] = None, + hll_suffix: Optional[str] = None, + num_threads: Optional[str] = None, + post_creation_verification: Optional[str] = None, + retry: Optional[str] = None + ) -> Any: """ Create Pinot segment by run CreateSegment command """ @@ -184,7 
+191,8 @@ def create_segment(self, self.run_cli(cmd) - def upload_segment(self, segment_dir, table_name=None): + def upload_segment(self, segment_dir: str, table_name: Optional[str] = None + ) -> Any: """ Upload Segment with run UploadSegment command @@ -200,7 +208,7 @@ def upload_segment(self, segment_dir, table_name=None): cmd += ["-tableName", table_name] self.run_cli(cmd) - def run_cli(self, cmd: list, verbose: Optional[bool] = True): + def run_cli(self, cmd: List[str], verbose: Optional[bool] = True) -> str: """ Run command with pinot-admin.sh @@ -255,11 +263,13 @@ class PinotDbApiHook(DbApiHook): default_conn_name = 'pinot_broker_default' supports_autocommit = False - def get_conn(self): + def get_conn(self) -> Any: """ Establish a connection to pinot broker through pinot dbapi. """ - conn = self.get_connection(self.pinot_broker_conn_id) # pylint: disable=no-member + # pylint: disable=no-member + conn = self.get_connection(self.pinot_broker_conn_id) # type: ignore + # pylint: enable=no-member pinot_broker_conn = connect( host=conn.host, port=conn.port, @@ -270,7 +280,7 @@ def get_conn(self): 'broker on %s', conn.host) return pinot_broker_conn - def get_uri(self): + def get_uri(self) -> str: """ Get the connection uri for pinot broker. @@ -285,32 +295,44 @@ def get_uri(self): return '{conn_type}://{host}/{endpoint}'.format( conn_type=conn_type, host=host, endpoint=endpoint) - def get_records(self, sql): + def get_records(self, sql: str, + parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None + ) -> Any: """ Executes the sql and returns a set of records. :param sql: the sql statement to be executed (str) or a list of sql statements to execute :type sql: str + :param parameters: The parameters to render the SQL query with. + :type parameters: dict or iterable """ with self.get_conn() as cur: cur.execute(sql) return cur.fetchall() - def get_first(self, sql): + def get_first(self, sql: str, + parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None + ) -> Any: """ Executes the sql and returns the first resulting row. :param sql: the sql statement to be executed (str) or a list of sql statements to execute :type sql: str or list + :param parameters: The parameters to render the SQL query with. + :type parameters: dict or iterable """ with self.get_conn() as cur: cur.execute(sql) return cur.fetchone() - def set_autocommit(self, conn, autocommit): + def set_autocommit(self, conn: Connection, autocommit: Any) -> Any: raise NotImplementedError() - def insert_rows(self, table, rows, target_fields=None, commit_every=1000): + def insert_rows(self, table: str, rows: str, + target_fields: Optional[str] = None, + commit_every: int = 1000, + replace: bool = False, + **kwargs: Any) -> Any: raise NotImplementedError() diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc.py b/airflow/providers/apache/spark/hooks/spark_jdbc.py index f0442fdcca7a..8ec3f4996c16 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc.py @@ -17,6 +17,7 @@ # under the License. 
# import os +from typing import Any, Dict, Optional from airflow.exceptions import AirflowException from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook @@ -113,36 +114,36 @@ class SparkJDBCHook(SparkSubmitHook): # pylint: disable=too-many-arguments,too-many-locals def __init__(self, - spark_app_name='airflow-spark-jdbc', - spark_conn_id='spark-default', - spark_conf=None, - spark_py_files=None, - spark_files=None, - spark_jars=None, - num_executors=None, - executor_cores=None, - executor_memory=None, - driver_memory=None, - verbose=False, - principal=None, - keytab=None, - cmd_type='spark_to_jdbc', - jdbc_table=None, - jdbc_conn_id='jdbc-default', - jdbc_driver=None, - metastore_table=None, - jdbc_truncate=False, - save_mode=None, - save_format=None, - batch_size=None, - fetch_size=None, - num_partitions=None, - partition_column=None, - lower_bound=None, - upper_bound=None, - create_table_column_types=None, - *args, - **kwargs + spark_app_name: str = 'airflow-spark-jdbc', + spark_conn_id: str = 'spark-default', + spark_conf: Optional[Dict[str, Any]] = None, + spark_py_files: Optional[str] = None, + spark_files: Optional[str] = None, + spark_jars: Optional[str] = None, + num_executors: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + verbose: bool = False, + principal: Optional[str] = None, + keytab: Optional[str] = None, + cmd_type: str = 'spark_to_jdbc', + jdbc_table: Optional[str] = None, + jdbc_conn_id: str = 'jdbc-default', + jdbc_driver: Optional[str] = None, + metastore_table: Optional[str] = None, + jdbc_truncate: bool = False, + save_mode: Optional[str] = None, + save_format: Optional[str] = None, + batch_size: Optional[int] = None, + fetch_size: Optional[int] = None, + num_partitions: Optional[int] = None, + partition_column: Optional[str] = None, + lower_bound: Optional[str] = None, + upper_bound: Optional[str] = None, + create_table_column_types: Optional[str] = None, + *args: Any, + **kwargs: Any ): super().__init__(*args, **kwargs) self._name = spark_app_name @@ -175,7 +176,7 @@ def __init__(self, self._create_table_column_types = create_table_column_types self._jdbc_connection = self._resolve_jdbc_connection() - def _resolve_jdbc_connection(self): + def _resolve_jdbc_connection(self) -> Dict[str, Any]: conn_data = {'url': '', 'schema': '', 'conn_prefix': '', @@ -200,7 +201,7 @@ def _resolve_jdbc_connection(self): ) return conn_data - def _build_jdbc_application_arguments(self, jdbc_conn): + def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any: arguments = [] arguments += ["-cmdType", self._cmd_type] if self._jdbc_connection['url']: @@ -239,7 +240,7 @@ def _build_jdbc_application_arguments(self, jdbc_conn): arguments += ['-createTableColumnTypes', self._create_table_column_types] return arguments - def submit_jdbc_job(self): + def submit_jdbc_job(self) -> None: """ Submit Spark JDBC job """ @@ -248,5 +249,5 @@ def submit_jdbc_job(self): self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py") - def get_conn(self): + def get_conn(self) -> Any: pass diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py index 442a88224a9f..3a9f56a24a0f 100644 --- a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py +++ b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py @@ -17,20 +17,20 @@ # under the License. 
# import argparse -from typing import List, Optional +from typing import Any, List, Optional from pyspark.sql import SparkSession -SPARK_WRITE_TO_JDBC = "spark_to_jdbc" -SPARK_READ_FROM_JDBC = "jdbc_to_spark" +SPARK_WRITE_TO_JDBC: str = "spark_to_jdbc" +SPARK_READ_FROM_JDBC: str = "jdbc_to_spark" -def set_common_options(spark_source, - url='localhost:5432', - jdbc_table='default.default', - user='root', - password='root', - driver='driver'): +def set_common_options(spark_source: Any, + url: str = 'localhost:5432', + jdbc_table: str = 'default.default', + user: str = 'root', + password: str = 'root', + driver: str = 'driver') -> Any: """ Get Spark source from JDBC connection @@ -53,9 +53,18 @@ def set_common_options(spark_source, # pylint: disable=too-many-arguments -def spark_write_to_jdbc(spark_session, url, user, password, metastore_table, jdbc_table, driver, - truncate, save_mode, batch_size, num_partitions, - create_table_column_types): +def spark_write_to_jdbc(spark_session: SparkSession, + url: str, + user: str, + password: str, + metastore_table: str, + jdbc_table: str, + driver: Any, + truncate: bool, + save_mode: str, + batch_size: int, + num_partitions: int, + create_table_column_types: str) -> None: """ Transfer data from Spark to JDBC source """ @@ -81,9 +90,21 @@ def spark_write_to_jdbc(spark_session, url, user, password, metastore_table, jdb # pylint: disable=too-many-arguments -def spark_read_from_jdbc(spark_session, url, user, password, metastore_table, jdbc_table, driver, - save_mode, save_format, fetch_size, num_partitions, - partition_column, lower_bound, upper_bound): +def spark_read_from_jdbc(spark_session: SparkSession, + url: str, + user: str, + password: str, + metastore_table: str, + jdbc_table: str, + driver: Any, + save_mode: str, + save_format: str, + fetch_size: int, + num_partitions: int, + partition_column: str, + lower_bound: str, + upper_bound: str + ) -> None: """ Transfer data from JDBC source to Spark """ @@ -108,7 +129,7 @@ def spark_read_from_jdbc(spark_session, url, user, password, metastore_table, jd .saveAsTable(metastore_table, format=save_format, mode=save_mode) -def _parse_arguments(args: Optional[List[str]] = None): +def _parse_arguments(args: Optional[List[str]] = None) -> Any: parser = argparse.ArgumentParser(description='Spark-JDBC') parser.add_argument('-cmdType', dest='cmd_type', action='store') parser.add_argument('-url', dest='url', action='store') @@ -132,14 +153,14 @@ def _parse_arguments(args: Optional[List[str]] = None): return parser.parse_args(args=args) -def _create_spark_session(arguments) -> SparkSession: +def _create_spark_session(arguments: Any) -> SparkSession: return SparkSession.builder \ .appName(arguments.name) \ .enableHiveSupport() \ .getOrCreate() -def _run_spark(arguments): +def _run_spark(arguments: Any) -> None: # Disable dynamic allocation by default to allow num_executors to take effect. spark = _create_spark_session(arguments) diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py index 3bc6aced1cc7..c0491dde1d60 100644 --- a/airflow/providers/apache/spark/hooks/spark_sql.py +++ b/airflow/providers/apache/spark/hooks/spark_sql.py @@ -17,6 +17,7 @@ # under the License. 
# import subprocess +from typing import Any, List, Optional, Union from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook @@ -57,20 +58,20 @@ class SparkSqlHook(BaseHook): # pylint: disable=too-many-arguments def __init__(self, - sql, - conf=None, - conn_id='spark_sql_default', - total_executor_cores=None, - executor_cores=None, - executor_memory=None, - keytab=None, - principal=None, - master='yarn', - name='default-name', - num_executors=None, - verbose=True, - yarn_queue='default' - ): + sql: str, + conf: Optional[str] = None, + conn_id: str = 'spark_sql_default', + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + master: str = 'yarn', + name: str = 'default-name', + num_executors: Optional[int] = None, + verbose: bool = True, + yarn_queue: str = 'default' + ) -> None: super().__init__() self._sql = sql self._conf = conf @@ -85,12 +86,12 @@ def __init__(self, self._num_executors = num_executors self._verbose = verbose self._yarn_queue = yarn_queue - self._sp = None + self._sp: Any = None - def get_conn(self): + def get_conn(self) -> Any: pass - def _prepare_command(self, cmd): + def _prepare_command(self, cmd: Union[str, List[str]]) -> List[str]: """ Construct the spark-sql command to execute. Verbose output is enabled as default. @@ -141,7 +142,7 @@ def _prepare_command(self, cmd): return connection_cmd - def run_query(self, cmd="", **kwargs): + def run_query(self, cmd: str = "", **kwargs: Any) -> None: """ Remote Popen (actually execute the Spark-sql query) @@ -156,7 +157,7 @@ def run_query(self, cmd="", **kwargs): stderr=subprocess.STDOUT, **kwargs) - for line in iter(self._sp.stdout): + for line in iter(self._sp.stdout): # type: ignore self.log.info(line) returncode = self._sp.wait() @@ -168,7 +169,7 @@ def run_query(self, cmd="", **kwargs): ) ) - def kill(self): + def kill(self) -> None: """ Kill Spark job """ diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py index e31fbdba29e1..0c3dd6f658b2 100644 --- a/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/airflow/providers/apache/spark/hooks/spark_submit.py @@ -20,6 +20,7 @@ import re import subprocess import time +from typing import Any, Dict, Iterator, List, Optional, Union from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException @@ -105,31 +106,32 @@ class SparkSubmitHook(BaseHook, LoggingMixin): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches def __init__(self, - conf=None, - conn_id='spark_default', - files=None, - py_files=None, - archives=None, - driver_class_path=None, - jars=None, - java_class=None, - packages=None, - exclude_packages=None, - repositories=None, - total_executor_cores=None, - executor_cores=None, - executor_memory=None, - driver_memory=None, - keytab=None, - principal=None, - proxy_user=None, - name='default-name', - num_executors=None, - status_poll_interval=1, - application_args=None, - env_vars=None, - verbose=False, - spark_binary=None): + conf: Optional[Dict[str, Any]] = None, + conn_id: str = 'spark_default', + files: Optional[str] = None, + py_files: Optional[str] = None, + archives: Optional[str] = None, + driver_class_path: Optional[str] = None, + jars: Optional[str] = None, + java_class: Optional[str] = None, + packages: Optional[str] = None, + exclude_packages: 
Optional[str] = None, + repositories: Optional[str] = None, + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + proxy_user: Optional[str] = None, + name: str = 'default-name', + num_executors: Optional[int] = None, + status_poll_interval: int = 1, + application_args: Optional[List[Any]] = None, + env_vars: Optional[Dict[str, Any]] = None, + verbose: bool = False, + spark_binary: Optional[str] = None + ) -> None: super().__init__() self._conf = conf or {} self._conn_id = conn_id @@ -155,9 +157,9 @@ def __init__(self, self._application_args = application_args self._env_vars = env_vars self._verbose = verbose - self._submit_sp = None - self._yarn_application_id = None - self._kubernetes_driver_pod = None + self._submit_sp: Optional[Any] = None + self._yarn_application_id: Optional[str] = None + self._kubernetes_driver_pod: Optional[str] = None self._spark_binary = spark_binary self._connection = self._resolve_connection() @@ -169,12 +171,12 @@ def __init__(self, self._connection['master'])) self._should_track_driver_status = self._resolve_should_track_driver_status() - self._driver_id = None - self._driver_status = None - self._spark_exit_code = None - self._env = None + self._driver_id: Optional[str] = None + self._driver_status: Optional[str] = None + self._spark_exit_code: Optional[int] = None + self._env: Optional[Dict[str, Any]] = None - def _resolve_should_track_driver_status(self): + def _resolve_should_track_driver_status(self) -> bool: """ Determines whether or not this hook should poll the spark driver status through subsequent spark-submit status requests after the initial spark-submit request @@ -183,7 +185,7 @@ def _resolve_should_track_driver_status(self): return ('spark://' in self._connection['master'] and self._connection['deploy_mode'] == 'cluster') - def _resolve_connection(self): + def _resolve_connection(self) -> Dict[str, Any]: # Build from connection master or default to yarn if not available conn_data = {'master': 'yarn', 'queue': None, @@ -220,10 +222,10 @@ def _resolve_connection(self): return conn_data - def get_conn(self): + def get_conn(self) -> Any: pass - def _get_spark_binary_path(self): + def _get_spark_binary_path(self) -> List[str]: # If the spark_home is passed then build the spark-submit executable path using # the spark_home; otherwise assume that spark-submit is present in the path to # the executing user @@ -235,7 +237,7 @@ def _get_spark_binary_path(self): return connection_cmd - def _mask_cmd(self, connection_cmd): + def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str: # Mask any password related fields in application args with key value pair # where key contains password (case insensitive), e.g. HivePassword='abc' connection_cmd_masked = re.sub( @@ -259,7 +261,7 @@ def _mask_cmd(self, connection_cmd): return connection_cmd_masked - def _build_spark_submit_command(self, application): + def _build_spark_submit_command(self, application: str) -> List[str]: """ Construct the spark-submit command to execute. @@ -347,7 +349,7 @@ def _build_spark_submit_command(self, application): return connection_cmd - def _build_track_driver_status_command(self): + def _build_track_driver_status_command(self) -> List[str]: """ Construct the command to poll the driver status. 
@@ -393,7 +395,7 @@ def _build_track_driver_status_command(self): return connection_cmd - def submit(self, application="", **kwargs): + def submit(self, application: str = "", **kwargs: Any) -> None: """ Remote Popen to execute the spark-submit job @@ -415,7 +417,7 @@ def submit(self, application="", **kwargs): universal_newlines=True, **kwargs) - self._process_spark_submit_log(iter(self._submit_sp.stdout)) + self._process_spark_submit_log(iter(self._submit_sp.stdout)) # type: ignore returncode = self._submit_sp.wait() # Check spark-submit return code. In Kubernetes mode, also check the value @@ -449,7 +451,7 @@ def submit(self, application="", **kwargs): .format(self._driver_id, self._driver_status) ) - def _process_spark_submit_log(self, itr): + def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: """ Processes the log files and extracts useful information out of it. @@ -498,7 +500,7 @@ def _process_spark_submit_log(self, itr): self.log.info(line) - def _process_spark_status_log(self, itr): + def _process_spark_status_log(self, itr: Iterator[Any]) -> None: """ parses the logs of the spark driver status query process @@ -520,7 +522,7 @@ def _process_spark_status_log(self, itr): if not driver_found: self._driver_status = "UNKNOWN" - def _start_driver_status_tracking(self): + def _start_driver_status_tracking(self) -> None: """ Polls the driver based on self._driver_id to get the status. Finish successfully when the status is FINISHED. @@ -566,11 +568,12 @@ def _start_driver_status_tracking(self): self.log.debug("polling status of spark driver with id %s", self._driver_id) poll_drive_status_cmd = self._build_track_driver_status_command() - status_process = subprocess.Popen(poll_drive_status_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=-1, - universal_newlines=True) + status_process: Any = subprocess.Popen(poll_drive_status_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=-1, + universal_newlines=True + ) self._process_spark_status_log(iter(status_process.stdout)) returncode = status_process.wait() @@ -584,7 +587,7 @@ def _start_driver_status_tracking(self): .format(max_missed_job_status_reports, returncode) ) - def _build_spark_driver_kill_command(self): + def _build_spark_driver_kill_command(self) -> List[str]: """ Construct the spark-submit command to kill a driver. :return: full command to kill a driver @@ -604,13 +607,14 @@ def _build_spark_driver_kill_command(self): connection_cmd += ["--master", self._connection['master']] # The actual kill command - connection_cmd += ["--kill", self._driver_id] + if self._driver_id: + connection_cmd += ["--kill", self._driver_id] self.log.debug("Spark-Kill cmd: %s", connection_cmd) return connection_cmd - def on_kill(self): + def on_kill(self) -> None: """ Kill Spark submit command """ diff --git a/airflow/providers/apache/spark/operators/spark_jdbc.py b/airflow/providers/apache/spark/operators/spark_jdbc.py index 807f26749221..04a6de7cf4ef 100644 --- a/airflow/providers/apache/spark/operators/spark_jdbc.py +++ b/airflow/providers/apache/spark/operators/spark_jdbc.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
# +from typing import Any, Dict, Optional + from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator from airflow.utils.decorators import apply_defaults @@ -119,36 +121,36 @@ class SparkJDBCOperator(SparkSubmitOperator): # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__(self, - spark_app_name='airflow-spark-jdbc', - spark_conn_id='spark-default', - spark_conf=None, - spark_py_files=None, - spark_files=None, - spark_jars=None, - num_executors=None, - executor_cores=None, - executor_memory=None, - driver_memory=None, - verbose=False, - keytab=None, - principal=None, - cmd_type='spark_to_jdbc', - jdbc_table=None, - jdbc_conn_id='jdbc-default', - jdbc_driver=None, - metastore_table=None, - jdbc_truncate=False, - save_mode=None, - save_format=None, - batch_size=None, - fetch_size=None, - num_partitions=None, - partition_column=None, - lower_bound=None, - upper_bound=None, - create_table_column_types=None, - *args, - **kwargs): + spark_app_name: str = 'airflow-spark-jdbc', + spark_conn_id: str = 'spark-default', + spark_conf: Optional[Dict[str, Any]] = None, + spark_py_files: Optional[str] = None, + spark_files: Optional[str] = None, + spark_jars: Optional[str] = None, + num_executors: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + verbose: bool = False, + principal: Optional[str] = None, + keytab: Optional[str] = None, + cmd_type: str = 'spark_to_jdbc', + jdbc_table: Optional[str] = None, + jdbc_conn_id: str = 'jdbc-default', + jdbc_driver: Optional[str] = None, + metastore_table: Optional[str] = None, + jdbc_truncate: bool = False, + save_mode: Optional[str] = None, + save_format: Optional[str] = None, + batch_size: Optional[int] = None, + fetch_size: Optional[int] = None, + num_partitions: Optional[int] = None, + partition_column: Optional[str] = None, + lower_bound: Optional[str] = None, + upper_bound: Optional[str] = None, + create_table_column_types: Optional[str] = None, + *args: Any, + **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._spark_app_name = spark_app_name self._spark_conn_id = spark_conn_id @@ -178,12 +180,23 @@ def __init__(self, self._lower_bound = lower_bound self._upper_bound = upper_bound self._create_table_column_types = create_table_column_types + self._hook: Optional[SparkJDBCHook] = None - def execute(self, context): + def execute(self, context: Dict[str, Any]) -> None: """ Call the SparkSubmitHook to run the provided spark job """ - self._hook = SparkJDBCHook( + if self._hook is None: + self._hook = self._get_hook() + self._hook.submit_jdbc_job() + + def on_kill(self) -> None: + if self._hook is None: + self._hook = self._get_hook() + self._hook.on_kill() + + def _get_hook(self) -> SparkJDBCHook: + return SparkJDBCHook( spark_app_name=self._spark_app_name, spark_conn_id=self._spark_conn_id, spark_conf=self._spark_conf, @@ -213,7 +226,3 @@ def execute(self, context): upper_bound=self._upper_bound, create_table_column_types=self._create_table_column_types ) - self._hook.submit_jdbc_job() - - def on_kill(self): - self._hook.on_kill() diff --git a/airflow/providers/apache/spark/operators/spark_sql.py b/airflow/providers/apache/spark/operators/spark_sql.py index 7230919790f4..3d0e5a859ff1 100644 --- a/airflow/providers/apache/spark/operators/spark_sql.py +++ b/airflow/providers/apache/spark/operators/spark_sql.py @@ -16,6 +16,8 @@ 
# specific language governing permissions and limitations # under the License. # +from typing import Any, Dict, Optional + from airflow.models import BaseOperator from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook from airflow.utils.decorators import apply_defaults @@ -63,21 +65,21 @@ class SparkSqlOperator(BaseOperator): # pylint: disable=too-many-arguments @apply_defaults def __init__(self, - sql, - conf=None, - conn_id='spark_sql_default', - total_executor_cores=None, - executor_cores=None, - executor_memory=None, - keytab=None, - principal=None, - master='yarn', - name='default-name', - num_executors=None, - verbose=True, - yarn_queue='default', - *args, - **kwargs): + sql: str, + conf: Optional[str] = None, + conn_id: str = 'spark_sql_default', + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + master: str = 'yarn', + name: str = 'default-name', + num_executors: Optional[int] = None, + verbose: bool = True, + yarn_queue: str = 'default', + *args: Any, + **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._sql = sql self._conf = conf @@ -92,27 +94,34 @@ def __init__(self, self._num_executors = num_executors self._verbose = verbose self._yarn_queue = yarn_queue - self._hook = None + self._hook: Optional[SparkSqlHook] = None - def execute(self, context): + def execute(self, context: Dict[str, Any]) -> None: """ Call the SparkSqlHook to run the provided sql query """ - self._hook = SparkSqlHook(sql=self._sql, - conf=self._conf, - conn_id=self._conn_id, - total_executor_cores=self._total_executor_cores, - executor_cores=self._executor_cores, - executor_memory=self._executor_memory, - keytab=self._keytab, - principal=self._principal, - name=self._name, - num_executors=self._num_executors, - master=self._master, - verbose=self._verbose, - yarn_queue=self._yarn_queue - ) + if self._hook is None: + self._hook = self._get_hook() self._hook.run_query() - def on_kill(self): + def on_kill(self) -> None: + if self._hook is None: + self._hook = self._get_hook() self._hook.kill() + + def _get_hook(self) -> SparkSqlHook: + """ Get SparkSqlHook """ + return SparkSqlHook(sql=self._sql, + conf=self._conf, + conn_id=self._conn_id, + total_executor_cores=self._total_executor_cores, + executor_cores=self._executor_cores, + executor_memory=self._executor_memory, + keytab=self._keytab, + principal=self._principal, + name=self._name, + num_executors=self._num_executors, + master=self._master, + verbose=self._verbose, + yarn_queue=self._yarn_queue + ) diff --git a/airflow/providers/apache/spark/operators/spark_submit.py b/airflow/providers/apache/spark/operators/spark_submit.py index 393362463c52..02a448b264a9 100644 --- a/airflow/providers/apache/spark/operators/spark_submit.py +++ b/airflow/providers/apache/spark/operators/spark_submit.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. 
# +from typing import Any, Dict, List, Optional + from airflow.models import BaseOperator from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook from airflow.settings import WEB_COLORS @@ -101,34 +103,34 @@ class SparkSubmitOperator(BaseOperator): # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__(self, - application='', - conf=None, - conn_id='spark_default', - files=None, - py_files=None, - archives=None, - driver_class_path=None, - jars=None, - java_class=None, - packages=None, - exclude_packages=None, - repositories=None, - total_executor_cores=None, - executor_cores=None, - executor_memory=None, - driver_memory=None, - keytab=None, - principal=None, - proxy_user=None, - name='airflow-spark', - num_executors=None, - status_poll_interval=1, - application_args=None, - env_vars=None, - verbose=False, - spark_binary=None, - *args, - **kwargs): + application: str = '', + conf: Optional[Dict[str, Any]] = None, + conn_id: str = 'spark_default', + files: Optional[str] = None, + py_files: Optional[str] = None, + archives: Optional[str] = None, + driver_class_path: Optional[str] = None, + jars: Optional[str] = None, + java_class: Optional[str] = None, + packages: Optional[str] = None, + exclude_packages: Optional[str] = None, + repositories: Optional[str] = None, + total_executor_cores: Optional[int] = None, + executor_cores: Optional[int] = None, + executor_memory: Optional[str] = None, + driver_memory: Optional[str] = None, + keytab: Optional[str] = None, + principal: Optional[str] = None, + proxy_user: Optional[str] = None, + name: str = 'arrow-spark', + num_executors: Optional[int] = None, + status_poll_interval: int = 1, + application_args: Optional[List[Any]] = None, + env_vars: Optional[Dict[str, Any]] = None, + verbose: bool = False, + spark_binary: Optional[str] = None, + *args: Any, + **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._application = application self._conf = conf @@ -155,14 +157,24 @@ def __init__(self, self._env_vars = env_vars self._verbose = verbose self._spark_binary = spark_binary - self._hook = None + self._hook: Optional[SparkSubmitHook] = None self._conn_id = conn_id - def execute(self, context): + def execute(self, context: Dict[str, Any]) -> None: """ Call the SparkSubmitHook to run the provided spark job """ - self._hook = SparkSubmitHook( + if self._hook is None: + self._hook = self._get_hook() + self._hook.submit(self._application) + + def on_kill(self) -> None: + if self._hook is None: + self._hook = self._get_hook() + self._hook.on_kill() + + def _get_hook(self) -> SparkSubmitHook: + return SparkSubmitHook( conf=self._conf, conn_id=self._conn_id, files=self._files, @@ -189,7 +201,3 @@ def execute(self, context): verbose=self._verbose, spark_binary=self._spark_binary ) - self._hook.submit(self._application) - - def on_kill(self): - self._hook.on_kill() diff --git a/airflow/providers/apache/sqoop/hooks/sqoop.py b/airflow/providers/apache/sqoop/hooks/sqoop.py index 9510b7b934d9..6b849356282b 100644 --- a/airflow/providers/apache/sqoop/hooks/sqoop.py +++ b/airflow/providers/apache/sqoop/hooks/sqoop.py @@ -22,6 +22,7 @@ """ import subprocess from copy import deepcopy +from typing import Any, Dict, List, Optional from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook @@ -53,9 +54,14 @@ class SqoopHook(BaseHook): :type properties: dict """ - def __init__(self, conn_id='sqoop_default', verbose=False, - num_mappers=None, hcatalog_database=None, - 
hcatalog_table=None, properties=None): + def __init__(self, + conn_id: str = 'sqoop_default', + verbose: bool = False, + num_mappers: Optional[int] = None, + hcatalog_database: Optional[str] = None, + hcatalog_table: Optional[str] = None, + properties: Optional[Dict[str, Any]] = None + ) -> None: # No mutable types in the default parameters super().__init__() self.conn = self.get_connection(conn_id) @@ -73,12 +79,12 @@ def __init__(self, conn_id='sqoop_default', verbose=False, self.properties = properties or {} self.log.info("Using connection to: %s:%s/%s", self.conn.host, self.conn.port, self.conn.schema) - self.sub_process = None + self.sub_process: Any = None - def get_conn(self): + def get_conn(self) -> Any: return self.conn - def cmd_mask_password(self, cmd_orig): + def cmd_mask_password(self, cmd_orig: List[str]) -> List[str]: """ Mask command password for safety """ @@ -90,7 +96,7 @@ def cmd_mask_password(self, cmd_orig): self.log.debug("No password in sqoop cmd") return cmd - def popen(self, cmd, **kwargs): + def popen(self, cmd: List[str], **kwargs: Any) -> None: """ Remote Popen @@ -106,7 +112,7 @@ def popen(self, cmd, **kwargs): stderr=subprocess.STDOUT, **kwargs) - for line in iter(self.sub_process.stdout): + for line in iter(self.sub_process.stdout): # type: ignore self.log.info(line.strip()) self.sub_process.wait() @@ -116,7 +122,7 @@ def popen(self, cmd, **kwargs): if self.sub_process.returncode: raise AirflowException("Sqoop command failed: {}".format(masked_cmd)) - def _prepare_command(self, export=False): + def _prepare_command(self, export: bool = False) -> List[str]: sqoop_cmd_type = "export" if export else "import" connection_cmd = ["sqoop", sqoop_cmd_type] @@ -158,7 +164,7 @@ def _prepare_command(self, export=False): return connection_cmd @staticmethod - def _get_export_format_argument(file_type='text'): + def _get_export_format_argument(file_type: str = 'text') -> List[str]: if file_type == "avro": return ["--as-avrodatafile"] elif file_type == "sequence": @@ -171,8 +177,9 @@ def _get_export_format_argument(file_type='text'): raise AirflowException("Argument file_type should be 'avro', " "'sequence', 'parquet' or 'text'.") - def _import_cmd(self, target_dir, append, file_type, split_by, direct, - driver, extra_import_options): + def _import_cmd(self, target_dir: Optional[str], append: bool, file_type: str, + split_by: Optional[str], direct: Optional[bool], + driver: Any, extra_import_options: Any) -> List[str]: cmd = self._prepare_command(export=False) @@ -202,9 +209,18 @@ def _import_cmd(self, target_dir, append, file_type, split_by, direct, return cmd # pylint: disable=too-many-arguments - def import_table(self, table, target_dir=None, append=False, file_type="text", - columns=None, split_by=None, where=None, direct=False, - driver=None, extra_import_options=None): + def import_table(self, + table: str, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = "text", + columns: Optional[str] = None, + split_by: Optional[str] = None, + where: Optional[str] = None, + direct: bool = False, + driver: Any = None, + extra_import_options: Optional[Dict[str, Any]] = None + ) -> Any: """ Imports table from remote location to target dir. 
Arguments are copies of direct sqoop command line arguments @@ -235,8 +251,15 @@ def import_table(self, table, target_dir=None, append=False, file_type="text", self.popen(cmd) - def import_query(self, query, target_dir, append=False, file_type="text", - split_by=None, direct=None, driver=None, extra_import_options=None): + def import_query(self, query: str, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = "text", + split_by: Optional[str] = None, + direct: Optional[bool] = None, + driver: Optional[Any] = None, + extra_import_options: Optional[Dict[str, Any]] = None + ) -> Any: """ Imports a specific query from the rdbms to hdfs @@ -259,11 +282,21 @@ def import_query(self, query, target_dir, append=False, file_type="text", self.popen(cmd) # pylint: disable=too-many-arguments - def _export_cmd(self, table, export_dir, input_null_string, - input_null_non_string, staging_table, clear_staging_table, - enclosed_by, escaped_by, input_fields_terminated_by, - input_lines_terminated_by, input_optionally_enclosed_by, - batch, relaxed_isolation, extra_export_options): + def _export_cmd(self, table: str, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + relaxed_isolation: bool = False, + extra_export_options: Optional[Dict[str, Any]] = None + ) -> List[str]: cmd = self._prepare_command(export=True) @@ -316,13 +349,22 @@ def _export_cmd(self, table, export_dir, input_null_string, return cmd # pylint: disable=too-many-arguments - def export_table(self, table, export_dir, input_null_string, - input_null_non_string, staging_table, - clear_staging_table, enclosed_by, - escaped_by, input_fields_terminated_by, - input_lines_terminated_by, - input_optionally_enclosed_by, batch, - relaxed_isolation, extra_export_options=None): + def export_table(self, + table: str, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + relaxed_isolation: bool = False, + extra_export_options: Optional[Dict[str, Any]] = None + ) -> None: """ Exports Hive table to remote location. 
Arguments are copies of direct sqoop command line Arguments diff --git a/airflow/providers/apache/sqoop/operators/sqoop.py b/airflow/providers/apache/sqoop/operators/sqoop.py index ec8419ff3a61..0db5c0903f36 100644 --- a/airflow/providers/apache/sqoop/operators/sqoop.py +++ b/airflow/providers/apache/sqoop/operators/sqoop.py @@ -22,6 +22,7 @@ """ import os import signal +from typing import Any, Dict, Optional, Tuple from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -95,40 +96,41 @@ class SqoopOperator(BaseOperator): # pylint: disable=too-many-arguments,too-many-locals @apply_defaults def __init__(self, - conn_id='sqoop_default', - cmd_type='import', - table=None, - query=None, - target_dir=None, - append=None, - file_type='text', - columns=None, - num_mappers=None, - split_by=None, - where=None, - export_dir=None, - input_null_string=None, - input_null_non_string=None, - staging_table=None, - clear_staging_table=False, - enclosed_by=None, - escaped_by=None, - input_fields_terminated_by=None, - input_lines_terminated_by=None, - input_optionally_enclosed_by=None, - batch=False, - direct=False, - driver=None, - verbose=False, - relaxed_isolation=False, - properties=None, - hcatalog_database=None, - hcatalog_table=None, - create_hcatalog_table=False, - extra_import_options=None, - extra_export_options=None, - *args, - **kwargs): + conn_id: str = 'sqoop_default', + cmd_type: str = 'import', + table: Optional[str] = None, + query: Optional[str] = None, + target_dir: Optional[str] = None, + append: bool = False, + file_type: str = 'text', + columns: Optional[str] = None, + num_mappers: Optional[int] = None, + split_by: Optional[str] = None, + where: Optional[str] = None, + export_dir: Optional[str] = None, + input_null_string: Optional[str] = None, + input_null_non_string: Optional[str] = None, + staging_table: Optional[str] = None, + clear_staging_table: bool = False, + enclosed_by: Optional[str] = None, + escaped_by: Optional[str] = None, + input_fields_terminated_by: Optional[str] = None, + input_lines_terminated_by: Optional[str] = None, + input_optionally_enclosed_by: Optional[str] = None, + batch: bool = False, + direct: bool = False, + driver: Optional[Any] = None, + verbose: bool = False, + relaxed_isolation: bool = False, + properties: Optional[Dict[str, Any]] = None, + hcatalog_database: Optional[str] = None, + hcatalog_table: Optional[str] = None, + create_hcatalog_table: bool = False, + extra_import_options: Optional[Dict[str, Any]] = None, + extra_export_options: Optional[Dict[str, Any]] = None, + *args: Tuple[str, Any], + **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) self.conn_id = conn_id self.cmd_type = cmd_type @@ -162,24 +164,18 @@ def __init__(self, self.properties = properties self.extra_import_options = extra_import_options or {} self.extra_export_options = extra_export_options or {} - self.hook = None + self.hook: Optional[SqoopHook] = None - def execute(self, context): + def execute(self, context: Dict[str, Any]) -> None: """ Execute sqoop job """ - self.hook = SqoopHook( - conn_id=self.conn_id, - verbose=self.verbose, - num_mappers=self.num_mappers, - hcatalog_database=self.hcatalog_database, - hcatalog_table=self.hcatalog_table, - properties=self.properties - ) + if self.hook is None: + self.hook = self._get_hook() if self.cmd_type == 'export': self.hook.export_table( - table=self.table, + table=self.table, # type: ignore export_dir=self.export_dir, input_null_string=self.input_null_string, 
input_null_non_string=self.input_null_non_string, @@ -234,6 +230,18 @@ def execute(self, context): else: raise AirflowException("cmd_type should be 'import' or 'export'") - def on_kill(self): + def on_kill(self) -> None: + if self.hook is None: + self.hook = self._get_hook() self.log.info('Sending SIGTERM signal to bash process group') os.killpg(os.getpgid(self.hook.sub_process.pid), signal.SIGTERM) + + def _get_hook(self) -> SqoopHook: + return SqoopHook( + conn_id=self.conn_id, + verbose=self.verbose, + num_mappers=self.num_mappers, + hcatalog_database=self.hcatalog_database, + hcatalog_table=self.hcatalog_table, + properties=self.properties + ) diff --git a/airflow/providers/discord/hooks/discord_webhook.py b/airflow/providers/discord/hooks/discord_webhook.py index 031943b28ce0..1d5f199663ee 100644 --- a/airflow/providers/discord/hooks/discord_webhook.py +++ b/airflow/providers/discord/hooks/discord_webhook.py @@ -62,10 +62,11 @@ def __init__(self, avatar_url: Optional[str] = None, tts: bool = False, proxy: Optional[str] = None, - *args, - **kwargs) -> None: + *args: Any, + **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) - self.http_conn_id = http_conn_id + self.http_conn_id: Any = http_conn_id self.webhook_endpoint = self._get_webhook_endpoint(http_conn_id, webhook_endpoint) self.message = message self.username = username diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py index 6c29178c82b2..e8afce18749a 100644 --- a/airflow/providers/http/hooks/http.py +++ b/airflow/providers/http/hooks/http.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, Callable, Dict, Optional, Union import requests import tenacity @@ -40,20 +41,20 @@ class HttpHook(BaseHook): def __init__( self, - method='POST', - http_conn_id='http_default', - auth_type=HTTPBasicAuth, + method: str = 'POST', + http_conn_id: str = 'http_default', + auth_type: Any = HTTPBasicAuth, ) -> None: super().__init__() self.http_conn_id = http_conn_id self.method = method.upper() - self.base_url = None - self._retry_obj = None - self.auth_type = auth_type + self.base_url: str = "" + self._retry_obj: Callable[..., Any] + self.auth_type: Any = auth_type # headers may be passed through directly or in the "extra" field in the connection # definition - def get_conn(self, headers=None): + def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session: """ Returns http session for use with requests @@ -61,6 +62,7 @@ def get_conn(self, headers=None): :type headers: dict """ session = requests.Session() + if self.http_conn_id: conn = self.get_connection(self.http_conn_id) @@ -86,7 +88,12 @@ def get_conn(self, headers=None): return session - def run(self, endpoint, data=None, headers=None, extra_options=None, **request_kwargs): + def run(self, + endpoint: str, + data: Optional[Union[Dict[str, Any], str]] = None, + headers: Optional[Dict[str, Any]] = None, + extra_options: Optional[Dict[str, Any]] = None, + **request_kwargs: Any) -> Any: r""" Performs the request @@ -113,7 +120,6 @@ def run(self, endpoint, data=None, headers=None, extra_options=None, **request_k else: url = (self.base_url or '') + (endpoint or '') - req = None if self.method == 'GET': # GET uses params req = requests.Request(self.method, @@ -139,7 +145,7 @@ def run(self, endpoint, data=None, headers=None, extra_options=None, **request_k self.log.info("Sending '%s' to url: %s", 
self.method, url) return self.run_and_check(session, prepped_request, extra_options) - def check_response(self, response): + def check_response(self, response: requests.Response) -> None: """ Checks the status code and raise an AirflowException exception on non 2XX or 3XX status codes @@ -154,7 +160,11 @@ def check_response(self, response): self.log.error(response.text) raise AirflowException(str(response.status_code) + ":" + response.reason) - def run_and_check(self, session, prepped_request, extra_options): + def run_and_check(self, + session: requests.Session, + prepped_request: requests.PreparedRequest, + extra_options: Dict[Any, Any] + ) -> Any: """ Grabs extra options like timeout and actually runs the request, checking for the result @@ -188,7 +198,8 @@ def run_and_check(self, session, prepped_request, extra_options): self.log.warning('%s Tenacity will retry to execute the operation', ex) raise ex - def run_with_advanced_retry(self, _retry_args, *args, **kwargs): + def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], + *args: Any, **kwargs: Any) -> Any: """ Runs Hook.run() with a Tenacity decorator attached to it. This is useful for connectors which might be disturbed by intermittent issues and should not diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py index 87df78df6998..a317e4bee028 100644 --- a/airflow/providers/http/operators/http.py +++ b/airflow/providers/http/operators/http.py @@ -64,11 +64,11 @@ def __init__(self, method: str = 'POST', data: Any = None, headers: Optional[Dict[str, str]] = None, - response_check: Optional[Callable] = None, + response_check: Optional[Callable[..., Any]] = None, extra_options: Optional[Dict[str, Any]] = None, http_conn_id: str = 'http_default', log_response: bool = False, - *args, **kwargs) -> None: + *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.http_conn_id = http_conn_id self.method = method @@ -81,7 +81,7 @@ def __init__(self, if kwargs.get('xcom_push') is not None: raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead") - def execute(self, context): + def execute(self, context: Dict[str, Any]) -> Any: http = HttpHook(self.method, http_conn_id=self.http_conn_id) self.log.info("Calling HTTP method") diff --git a/airflow/providers/http/sensors/http.py b/airflow/providers/http/sensors/http.py index f4eb00a53791..05a9bffe4340 100644 --- a/airflow/providers/http/sensors/http.py +++ b/airflow/providers/http/sensors/http.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional from airflow.exceptions import AirflowException from airflow.operators.python import PythonOperator @@ -74,10 +74,12 @@ def __init__(self, endpoint: str, http_conn_id: str = 'http_default', method: str = 'GET', - request_params: Optional[Dict] = None, - headers: Optional[Dict] = None, - response_check: Optional[Callable] = None, - extra_options: Optional[Dict] = None, *args, **kwargs): + request_params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + response_check: Optional[Callable[..., Any]] = None, + extra_options: Optional[Dict[str, Any]] = None, + *args: Any, **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) self.endpoint = endpoint self.http_conn_id = http_conn_id @@ -90,7 +92,7 @@ def __init__(self, method=method, http_conn_id=http_conn_id) - def poke(self, context: Dict): + def poke(self, context: Dict[Any, Any]) -> bool: self.log.info('Poking: %s', self.endpoint) try: response = self.hook.run(self.endpoint, diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 4ab1faed5e5f..e3aeb9d17ec5 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -21,13 +21,14 @@ import os from copy import copy from functools import wraps +from typing import Any, Callable, Dict from airflow.exceptions import AirflowException signature = inspect.signature -def apply_defaults(func): +def apply_defaults(func: Callable[..., Any]) -> Any: """ Function decorator that Looks for an argument named "default_args", and fills the unspecified arguments from it. @@ -46,17 +47,17 @@ def apply_defaults(func): non_optional_args = { name for (name, param) in sig_cache.parameters.items() if param.default == param.empty and - param.name != 'self' and - param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)} + param.name != 'self' and + param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)} @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: from airflow.models.dag import DagContext if len(args) > 1: raise AirflowException( "Use keyword arguments when initializing operators") - dag_args = {} - dag_params = {} + dag_args: Dict[str, Any] = {} + dag_params: Dict[str, Any] = {} dag = kwargs.get('dag', None) or DagContext.get_current_dag() if dag: @@ -89,6 +90,7 @@ def wrapper(*args, **kwargs): result = func(*args, **kwargs) return result + return wrapper
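A recurring change in the operator diffs above (SparkSubmitOperator, SparkSqlOperator, SparkJDBCOperator, SqoopOperator) is to annotate the cached hook as Optional and build it lazily through a private _get_hook() helper shared by execute() and on_kill(). Below is a minimal, framework-free sketch of that pattern; DemoHook, DemoOperator, and submit() are illustrative names and not part of the Airflow API. The point is that the Optional annotation plus the None check keeps a type checker satisfied while letting on_kill() work even if execute() never ran.

from typing import Any, Dict, Optional


class DemoHook:
    """Stand-in for a typed hook such as SparkSubmitHook (illustrative only)."""

    def __init__(self, conn_id: str) -> None:
        self.conn_id = conn_id

    def submit(self, application: str = "") -> None:
        # A real hook would spawn spark-submit here; this sketch just logs.
        print("submitting {!r} via connection {!r}".format(application, self.conn_id))

    def on_kill(self) -> None:
        print("terminating the submitted job")


class DemoOperator:
    """Stand-in for an operator that caches its hook lazily (illustrative only)."""

    def __init__(self, application: str, conn_id: str = "demo_default") -> None:
        self._application = application
        self._conn_id = conn_id
        # Annotated as Optional so a type checker knows the hook may not exist yet.
        self._hook: Optional[DemoHook] = None

    def _get_hook(self) -> DemoHook:
        # Single construction path shared by execute() and on_kill().
        return DemoHook(conn_id=self._conn_id)

    def execute(self, context: Dict[str, Any]) -> None:
        if self._hook is None:
            self._hook = self._get_hook()
        self._hook.submit(self._application)

    def on_kill(self) -> None:
        # Rebuilds the hook on demand, so a kill request never hits a None attribute.
        if self._hook is None:
            self._hook = self._get_hook()
        self._hook.on_kill()


if __name__ == "__main__":
    op = DemoOperator(application="job.py")
    op.execute(context={})
    op.on_kill()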