Skip to content

Commit

Permalink
fix lint; add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Abacn committed Aug 8, 2024
1 parent bae780b commit 81ab460
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 18 deletions.
10 changes: 10 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from apache_beam.io.gcp.internal.clients import bigquery as bigquery_api
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
from apache_beam.metrics.metric import Lineage
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner
Expand Down Expand Up @@ -508,6 +509,9 @@ def test_load_job_id_used(self):
| "GetJobs" >> beam.Map(lambda x: x[1])

assert_that(jobs, equal_to([job_reference]), label='CheckJobProjectIds')
self.assertSetEqual(
Lineage.query(p.result.metrics(), Lineage.SINK),
set(["bigquery:project1.dataset1.table1"]))

def test_load_job_id_use_for_copy_job(self):
destination = 'project1:dataset1.table1'
Expand Down Expand Up @@ -560,6 +564,9 @@ def test_load_job_id_use_for_copy_job(self):
job_reference
]),
label='CheckCopyJobProjectIds')
self.assertSetEqual(
Lineage.query(p.result.metrics(), Lineage.SINK),
set(["bigquery:project1.dataset1.table1"]))

@mock.patch('time.sleep')
def test_wait_for_load_job_completion(self, sleep_mock):
Expand Down Expand Up @@ -717,6 +724,9 @@ def test_multiple_partition_files(self):
copy_jobs | "CountCopyJobs" >> combiners.Count.Globally(),
equal_to([6]),
label='CheckCopyJobCount')
self.assertSetEqual(
Lineage.query(p.result.metrics(), Lineage.SINK),
set(["bigquery:project1.dataset1.table1"]))

@parameterized.expand([
param(write_disposition=BigQueryDisposition.WRITE_TRUNCATE),
Expand Down
18 changes: 6 additions & 12 deletions sdks/python/apache_beam/io/gcp/bigquery_schema_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,11 @@ def test_bad_schema_public_api_export(self, get_table):
with self.assertRaisesRegex(ValueError,
"Encountered an unsupported type: 'DOUBLE'"):
p = apache_beam.Pipeline()
pipeline = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
_ = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
table="dataset.sample_table",
method="EXPORT",
project="project",
output_type='BEAM_ROW')
pipeline

@mock.patch.object(BigQueryWrapper, 'get_table')
def test_bad_schema_public_api_direct_read(self, get_table):
Expand All @@ -159,21 +158,19 @@ def test_bad_schema_public_api_direct_read(self, get_table):
with self.assertRaisesRegex(ValueError,
"Encountered an unsupported type: 'DOUBLE'"):
p = apache_beam.Pipeline()
pipeline = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
_ = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
table="dataset.sample_table",
method="DIRECT_READ",
project="project",
output_type='BEAM_ROW')
pipeline

def test_unsupported_value_provider(self):
with self.assertRaisesRegex(TypeError,
'ReadFromBigQuery: table must be of type string'
'; got ValueProvider instead'):
p = apache_beam.Pipeline()
pipeline = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
_ = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
table=value_provider.ValueProvider(), output_type='BEAM_ROW')
pipeline

def test_unsupported_callable(self):
def filterTable(table):
Expand All @@ -185,35 +182,32 @@ def filterTable(table):
'ReadFromBigQuery: table must be of type string'
'; got a callable instead'):
p = apache_beam.Pipeline()
pipeline = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
_ = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
table=res, output_type='BEAM_ROW')
pipeline

def test_unsupported_query_export(self):
with self.assertRaisesRegex(
ValueError,
"Both a query and an output type of 'BEAM_ROW' were specified. "
"'BEAM_ROW' is not currently supported with queries."):
p = apache_beam.Pipeline()
pipeline = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
_ = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
table="project:dataset.sample_table",
method="EXPORT",
query='SELECT name FROM dataset.sample_table',
output_type='BEAM_ROW')
pipeline

def test_unsupported_query_direct_read(self):
with self.assertRaisesRegex(
ValueError,
"Both a query and an output type of 'BEAM_ROW' were specified. "
"'BEAM_ROW' is not currently supported with queries."):
p = apache_beam.Pipeline()
pipeline = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
_ = p | apache_beam.io.gcp.bigquery.ReadFromBigQuery(
table="project:dataset.sample_table",
method="DIRECT_READ",
query='SELECT name FROM dataset.sample_table',
output_type='BEAM_ROW')
pipeline

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
43 changes: 43 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from apache_beam.io.gcp.bigquery import TableRowJsonCoder
from apache_beam.io.gcp.bigquery import WriteToBigQuery
from apache_beam.io.gcp.bigquery import _StreamToBigQuery
from apache_beam.io.gcp.bigquery_read_internal import _BigQueryReadSplit
from apache_beam.io.gcp.bigquery_read_internal import _JsonToDictCoder
from apache_beam.io.gcp.bigquery_read_internal import bigquery_export_destination_uri
from apache_beam.io.gcp.bigquery_tools import JSON_COMPLIANCE_ERROR
Expand All @@ -61,6 +62,7 @@
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
from apache_beam.io.gcp.tests.bigquery_matcher import BigQueryTableMatcher
from apache_beam.metrics.metric import Lineage
from apache_beam.options import value_provider
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
Expand All @@ -85,9 +87,11 @@
from apitools.base.py.exceptions import HttpError
from apitools.base.py.exceptions import HttpForbiddenError
from google.cloud import bigquery as gcp_bigquery
from google.cloud import bigquery_storage_v1 as bq_storage
from google.api_core import exceptions
except ImportError:
gcp_bigquery = None
bq_storage = None
HttpError = None
HttpForbiddenError = None
exceptions = None
Expand Down Expand Up @@ -460,6 +464,8 @@ def test_create_temp_dataset_exception(self, exception_type, error_message):
self.assertIn(error_message, exc.exception.args[0])

@parameterized.expand([
# read without exception
param(responses=[], expected_retries=0),
# first attempt returns a Http 500 blank error and retries
# second attempt returns a Http 408 blank error and retries,
# third attempt passes
Expand Down Expand Up @@ -540,6 +546,9 @@ def store_callback(unused_request):
# metadata (numBytes), and once to retrieve the table's schema
# Any additional calls are retries
self.assertEqual(expected_retries, mock_get_table.call_count - 2)
self.assertSetEqual(
Lineage.query(p.result.metrics(), Lineage.SOURCE),
set(["bigquery:project.dataset.table"]))

@parameterized.expand([
# first attempt returns a Http 429 with transient reason and retries
Expand Down Expand Up @@ -719,6 +728,40 @@ def test_read_export_exception(self, exception_type, error_message):
mock_query_job.assert_called()
self.assertIn(error_message, exc.exception.args[0])

def test_read_direct_lineage(self):
with mock.patch.object(bigquery_tools.BigQueryWrapper,
'_bigquery_client'),\
mock.patch.object(bq_storage.BigQueryReadClient,
'create_read_session'),\
beam.Pipeline() as p:

_ = p | ReadFromBigQuery(
method=ReadFromBigQuery.Method.DIRECT_READ,
table='project:dataset.table')
self.assertSetEqual(
Lineage.query(p.result.metrics(), Lineage.SOURCE),
set(["bigquery:project.dataset.table"]))

def test_read_all_lineage(self):
with mock.patch.object(_BigQueryReadSplit, '_export_files') as export, \
beam.Pipeline() as p:

export.return_value = (None, [])

_ = (
p
| beam.Create([
beam.io.ReadFromBigQueryRequest(table='project1:dataset1.table1'),
beam.io.ReadFromBigQueryRequest(table='project2:dataset2.table2')
])
| beam.io.ReadAllFromBigQuery(gcs_location='gs://bucket/tmp'))
self.assertSetEqual(
Lineage.query(p.result.metrics(), Lineage.SOURCE),
set([
'bigquery:project1.dataset1.table1',
'bigquery:project2.dataset2.table2'
]))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBigQuerySink(unittest.TestCase):
Expand Down
13 changes: 7 additions & 6 deletions sdks/python/apache_beam/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from typing import Union

from apache_beam.metrics import cells
from apache_beam.metrics.execution import MetricResult
from apache_beam.metrics.execution import MetricUpdater
from apache_beam.metrics.metricbase import Counter
from apache_beam.metrics.metricbase import Distribution
Expand Down Expand Up @@ -224,7 +225,7 @@ def matches(
def query(
self,
filter: Optional['MetricsFilter'] = None
) -> Dict[str, List['MetricResults']]:
) -> Dict[str, List['MetricResult']]:
"""Queries the runner for existing user metrics that match the filter.
It should return a dictionary, with lists of each kind of metric, and
Expand Down Expand Up @@ -321,24 +322,24 @@ class Lineage:
SINK: Metrics.string_set(LINEAGE_NAMESPACE, SINK)
}

def __init__(self, label):
def __init__(self, label: str) -> None:
"""Create a Lineage with valid babel (:data:`~Lineage.SOURCE` or
:data:`~Lineage.SINK`)
"""
self.metric = Lineage._METRICS[label]

@classmethod
def sources(cls):
def sources(cls) -> 'Lineage':
return cls(Lineage.SOURCE)

@classmethod
def sinks(cls):
def sinks(cls) -> 'Lineage':
return cls(Lineage.SINK)

_RESERVED_CHARS = re.compile(r'[:\s.]')

@staticmethod
def wrap_segment(segment: str):
def wrap_segment(segment: str) -> str:
"""Wrap segment to valid segment name.
Specifically, If there are reserved chars (colon, whitespace, dot), escape
Expand Down Expand Up @@ -372,7 +373,7 @@ def add(
self.metric.add(self.get_fq_name(system, *segments, route=route))

@staticmethod
def query(results: MetricResults, label: str):
def query(results: MetricResults, label: str) -> Set[str]:
if not label in Lineage._METRICS:
raise ValueError("Label {} does not exist for Lineage", label)
response = results.query(
Expand Down

0 comments on commit 81ab460

Please sign in to comment.