diff --git a/.asf.yaml b/.asf.yaml index 95e7add23914..3fba2f41a85f 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -49,6 +49,7 @@ github: protected_branches: master: {} + release-2.58.0: {} release-2.57.0: {} release-2.56.0: {} release-2.55.1: {} diff --git a/.github/ACTIONS.md b/.github/ACTIONS.md index 7432a3d28b1c..aa7d3b04fa87 100644 --- a/.github/ACTIONS.md +++ b/.github/ACTIONS.md @@ -98,3 +98,7 @@ Phrases self-assign, close, or manage labels on an issue: | `.add-labels` | Add comma separated labels to the issue (e.g. `add-labels l1, 'l2 with spaces'`) | | `.remove-labels` | Remove comma separated labels to the issue (e.g. `remove-labels l1, 'l2 with spaces'`) | | `.set-labels` | Sets comma separated labels to the issue and removes any other labels (e.g. `set-labels l1, 'l2 with spaces'`) | + +## Security Model + +For information on the Beam CI security model, see https://cwiki.apache.org/confluence/display/BEAM/CI+Security+Model diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml index 67f8b21445dc..a2fbeae1319c 100644 --- a/.github/ISSUE_TEMPLATE/bug.yml +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -50,6 +50,7 @@ body: - "Priority: 2 (default / most bugs should be filed as P2)" - "Priority: 1 (data loss / total loss of function)" - "Priority: 0 (outage / urgent vulnerability)" + default: 1 validations: required: true - type: checkboxes @@ -68,6 +69,7 @@ body: - label: "Component: Beam playground" - label: "Component: Beam katas" - label: "Component: Website" + - label: "Component: Infrastructure" - label: "Component: Spark Runner" - label: "Component: Flink Runner" - label: "Component: Samza Runner" diff --git a/.github/ISSUE_TEMPLATE/failing_test.yml b/.github/ISSUE_TEMPLATE/failing_test.yml index 44e1cd720745..4295624995fb 100644 --- a/.github/ISSUE_TEMPLATE/failing_test.yml +++ b/.github/ISSUE_TEMPLATE/failing_test.yml @@ -56,6 +56,7 @@ body: - "Priority: 2 (backlog / disabled test but we think the product is healthy)" - "Priority: 1 (unhealthy code / failing or flaky postcommit so we cannot be sure the product is healthy)" - "Priority: 0 (outage / failing precommit test impacting development)" + default: 1 validations: required: true - type: checkboxes @@ -74,6 +75,7 @@ body: - label: "Component: Beam playground" - label: "Component: Beam katas" - label: "Component: Website" + - label: "Component: Infrastructure" - label: "Component: Spark Runner" - label: "Component: Flink Runner" - label: "Component: Samza Runner" diff --git a/.github/ISSUE_TEMPLATE/feature.yml b/.github/ISSUE_TEMPLATE/feature.yml index 11234a5e1501..e47c7c0751ce 100644 --- a/.github/ISSUE_TEMPLATE/feature.yml +++ b/.github/ISSUE_TEMPLATE/feature.yml @@ -44,6 +44,7 @@ body: options: - "Priority: 3 (nice-to-have improvement)" - "Priority: 2 (default / most feature requests should be filed as P2)" + default: 1 validations: required: true - type: checkboxes @@ -62,6 +63,7 @@ body: - label: "Component: Beam playground" - label: "Component: Beam katas" - label: "Component: Website" + - label: "Component: Infrastructure" - label: "Component: Spark Runner" - label: "Component: Flink Runner" - label: "Component: Samza Runner" diff --git a/.github/ISSUE_TEMPLATE/task.yml b/.github/ISSUE_TEMPLATE/task.yml index 477b91b181be..8da74a65d8f2 100644 --- a/.github/ISSUE_TEMPLATE/task.yml +++ b/.github/ISSUE_TEMPLATE/task.yml @@ -45,6 +45,7 @@ body: - "Priority: 3 (nice-to-have improvement)" - "Priority: 2 (default / most normal work should be filed as P2)" - "Priority: 1 (urgent / mostly reserved for critical 
bugs)" + default: 1 validations: required: true - type: checkboxes @@ -63,6 +64,7 @@ body: - label: "Component: Beam playground" - label: "Component: Beam katas" - label: "Component: Website" + - label: "Component: Infrastructure" - label: "Component: Spark Runner" - label: "Component: Flink Runner" - label: "Component: Samza Runner" diff --git a/.github/actions/setup-environment-action/action.yml b/.github/actions/setup-environment-action/action.yml index 912aca0e16f9..3a14112cf0ef 100644 --- a/.github/actions/setup-environment-action/action.yml +++ b/.github/actions/setup-environment-action/action.yml @@ -52,9 +52,10 @@ runs: - name: Setup Gradle uses: gradle/gradle-build-action@v2 with: - cache-read-only: ${{ inputs.disable-cache }} + cache-disabled: ${{ inputs.disable-cache }} - name: Install Go if: ${{ inputs.go-version != '' }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: ${{ inputs.go-version == 'default' && '1.21' || inputs.go-version }} # never set patch, to get latest patch releases. + go-version: ${{ inputs.go-version == 'default' && '1.22' || inputs.go-version }} # never set patch, to get latest patch releases. + cache-dependency-path: $${{ inputs.disable-cache && '' || 'sdks/go.sum' }} diff --git a/.github/gh-actions-self-hosted-runners/arc/images/Dockerfile b/.github/gh-actions-self-hosted-runners/arc/images/Dockerfile index 7fa8f631729c..3737492f6179 100644 --- a/.github/gh-actions-self-hosted-runners/arc/images/Dockerfile +++ b/.github/gh-actions-self-hosted-runners/arc/images/Dockerfile @@ -31,7 +31,7 @@ RUN curl -OL https://nodejs.org/dist/v18.16.0/node-v18.16.0-linux-x64.tar.xz && mv /usr/local/node-v18.16.0-linux-x64 /usr/local/node ENV PATH="${PATH}:/usr/local/node/bin" #Install Go -ARG go_version=1.22.4 +ARG go_version=1.22.5 RUN curl -OL https://go.dev/dl/go${go_version}.linux-amd64.tar.gz && \ tar -C /usr/local -xzf go${go_version}.linux-amd64.tar.gz && \ rm go${go_version}.linux-amd64.tar.gz diff --git a/.github/issue-rules.yml b/.github/issue-rules.yml index b01a22dafd78..c4acb2945575 100644 --- a/.github/issue-rules.yml +++ b/.github/issue-rules.yml @@ -46,6 +46,8 @@ rules: addLabels: ['katas'] - contains: '[x] Component: Website' addLabels: ['website'] +- contains: '[x] Component: Infrastructure' + addLabels: ['infra'] - contains: '[x] Component: Spark' addLabels: ['spark'] - contains: '[x] Component: Flink' diff --git a/.github/trigger_files/IO_Iceberg_Integration_Tests.json b/.github/trigger_files/IO_Iceberg_Integration_Tests.json index a03c067d2c4e..3f63c0c9975f 100644 --- a/.github/trigger_files/IO_Iceberg_Integration_Tests.json +++ b/.github/trigger_files/IO_Iceberg_Integration_Tests.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json index b970762c8397..b60f5c4cc3c8 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "modification": 0 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json 
b/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json index a937ef2fc07d..b60f5c4cc3c8 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json @@ -1,5 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", - "https://github.com/apache/beam/pull/31270": "re-adds specialized Samza translation of Redistribute" + "modification": 0 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json index b970762c8397..b60f5c4cc3c8 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "modification": 0 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json index e3d6056a5de9..b60f5c4cc3c8 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 0 } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json index b970762c8397..38ae94aee2fa 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31761": "noting that PR #31761 should run this test" } diff --git a/.github/workflows/IO_Iceberg_Integration_Tests.yml b/.github/workflows/IO_Iceberg_Integration_Tests.yml index d7c9c6d95746..20d1f4bb60fd 100644 --- a/.github/workflows/IO_Iceberg_Integration_Tests.yml +++ b/.github/workflows/IO_Iceberg_Integration_Tests.yml @@ -72,12 +72,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Run IcebergIO Integration Test uses: ./.github/actions/gradle-command-self-hosted-action with: diff --git a/.github/workflows/IO_Iceberg_Performance_Tests.yml b/.github/workflows/IO_Iceberg_Performance_Tests.yml index 40bd43aa17ed..976fbedeadad 100644 --- a/.github/workflows/IO_Iceberg_Performance_Tests.yml +++ b/.github/workflows/IO_Iceberg_Performance_Tests.yml @@ -72,12 +72,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Run IcebergIO Performance 
Test uses: ./.github/actions/gradle-command-self-hosted-action with: diff --git a/.github/workflows/IO_Iceberg_Unit_Tests.yml b/.github/workflows/IO_Iceberg_Unit_Tests.yml index 1787756ab68b..0d72b0da8597 100644 --- a/.github/workflows/IO_Iceberg_Unit_Tests.yml +++ b/.github/workflows/IO_Iceberg_Unit_Tests.yml @@ -91,12 +91,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: run IcebergIO build script uses: ./.github/actions/gradle-command-self-hosted-action with: diff --git a/.github/workflows/beam_CancelStaleDataflowJobs.yml b/.github/workflows/beam_CancelStaleDataflowJobs.yml index b568b91dd34a..e8dfee525e31 100644 --- a/.github/workflows/beam_CancelStaleDataflowJobs.yml +++ b/.github/workflows/beam_CancelStaleDataflowJobs.yml @@ -73,12 +73,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: disable-cache: true - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: run cancel stale dataflow jobs uses: ./.github/actions/gradle-command-self-hosted-action with: diff --git a/.github/workflows/beam_CleanUpGCPResources.yml b/.github/workflows/beam_CleanUpGCPResources.yml index cf77dd68a92e..29d602357d6f 100644 --- a/.github/workflows/beam_CleanUpGCPResources.yml +++ b/.github/workflows/beam_CleanUpGCPResources.yml @@ -73,13 +73,8 @@ jobs: uses: ./.github/actions/setup-environment-action with: disable-cache: true - - name: Authenticate on GCP - id: auth - uses: google-github-actions/setup-gcloud@v0 - with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} + - name: Setup gcloud + uses: google-github-actions/setup-gcloud@v2 - name: Install gcloud bigtable cli run: gcloud components install cbt - name: run cleanup GCP resources diff --git a/.github/workflows/beam_CleanUpPrebuiltSDKImages.yml b/.github/workflows/beam_CleanUpPrebuiltSDKImages.yml index 14e5f9783b61..20de04854282 100644 --- a/.github/workflows/beam_CleanUpPrebuiltSDKImages.yml +++ b/.github/workflows/beam_CleanUpPrebuiltSDKImages.yml @@ -73,12 +73,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: disable-cache: true - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: GCloud Docker credential helper run: | gcloud auth configure-docker gcr.io && \ diff --git a/.github/workflows/beam_PerformanceTests_AvroIOIT_HDFS.yml b/.github/workflows/beam_PerformanceTests_AvroIOIT_HDFS.yml index 5fde65bdccd1..6dfb0634d7f5 100644 --- a/.github/workflows/beam_PerformanceTests_AvroIOIT_HDFS.yml +++ b/.github/workflows/beam_PerformanceTests_AvroIOIT_HDFS.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git 
a/.github/workflows/beam_PerformanceTests_Cdap.yml b/.github/workflows/beam_PerformanceTests_Cdap.yml index 2848b555185d..3b32129b761c 100644 --- a/.github/workflows/beam_PerformanceTests_Cdap.yml +++ b/.github/workflows/beam_PerformanceTests_Cdap.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_Compressed_TextIOIT_HDFS.yml b/.github/workflows/beam_PerformanceTests_Compressed_TextIOIT_HDFS.yml index 58b96f7e3526..071577ade0f1 100644 --- a/.github/workflows/beam_PerformanceTests_Compressed_TextIOIT_HDFS.yml +++ b/.github/workflows/beam_PerformanceTests_Compressed_TextIOIT_HDFS.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_HadoopFormat.yml b/.github/workflows/beam_PerformanceTests_HadoopFormat.yml index 03c2732ce4ff..a30ef9aab510 100644 --- a/.github/workflows/beam_PerformanceTests_HadoopFormat.yml +++ b/.github/workflows/beam_PerformanceTests_HadoopFormat.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_JDBC.yml b/.github/workflows/beam_PerformanceTests_JDBC.yml index 2305a779a09c..c65be423d156 100644 --- a/.github/workflows/beam_PerformanceTests_JDBC.yml +++ b/.github/workflows/beam_PerformanceTests_JDBC.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_Kafka_IO.yml b/.github/workflows/beam_PerformanceTests_Kafka_IO.yml index 2b620043c37b..39e49db09196 100644 --- a/.github/workflows/beam_PerformanceTests_Kafka_IO.yml +++ b/.github/workflows/beam_PerformanceTests_Kafka_IO.yml @@ -73,12 +73,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git 
a/.github/workflows/beam_PerformanceTests_ManyFiles_TextIOIT_HDFS.yml b/.github/workflows/beam_PerformanceTests_ManyFiles_TextIOIT_HDFS.yml index 63b70cb810e9..6329c8ce8f5e 100644 --- a/.github/workflows/beam_PerformanceTests_ManyFiles_TextIOIT_HDFS.yml +++ b/.github/workflows/beam_PerformanceTests_ManyFiles_TextIOIT_HDFS.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_MongoDBIO_IT.yml b/.github/workflows/beam_PerformanceTests_MongoDBIO_IT.yml index 655c61ea373a..e2ce7ed94e5b 100644 --- a/.github/workflows/beam_PerformanceTests_MongoDBIO_IT.yml +++ b/.github/workflows/beam_PerformanceTests_MongoDBIO_IT.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_ParquetIOIT_HDFS.yml b/.github/workflows/beam_PerformanceTests_ParquetIOIT_HDFS.yml index d76dd3061ac8..929d214bd676 100644 --- a/.github/workflows/beam_PerformanceTests_ParquetIOIT_HDFS.yml +++ b/.github/workflows/beam_PerformanceTests_ParquetIOIT_HDFS.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_SingleStoreIO.yml b/.github/workflows/beam_PerformanceTests_SingleStoreIO.yml index f0f7ec1d373f..f842d31fba0c 100644 --- a/.github/workflows/beam_PerformanceTests_SingleStoreIO.yml +++ b/.github/workflows/beam_PerformanceTests_SingleStoreIO.yml @@ -72,12 +72,6 @@ jobs: comment_phrase: ${{ matrix.job_phrase }} github_token: ${{ secrets.GITHUB_TOKEN }} github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_SparkReceiver_IO.yml b/.github/workflows/beam_PerformanceTests_SparkReceiver_IO.yml index 5872cc332417..ec3bc1a23fd9 100644 --- a/.github/workflows/beam_PerformanceTests_SparkReceiver_IO.yml +++ b/.github/workflows/beam_PerformanceTests_SparkReceiver_IO.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access 
uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_TFRecordIOIT_HDFS.yml b/.github/workflows/beam_PerformanceTests_TFRecordIOIT_HDFS.yml index 103775b034ff..f6c0ddece2f6 100644 --- a/.github/workflows/beam_PerformanceTests_TFRecordIOIT_HDFS.yml +++ b/.github/workflows/beam_PerformanceTests_TFRecordIOIT_HDFS.yml @@ -73,12 +73,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_XmlIOIT_HDFS.yml b/.github/workflows/beam_PerformanceTests_XmlIOIT_HDFS.yml index 5808ddad6572..362cbdaedd64 100644 --- a/.github/workflows/beam_PerformanceTests_XmlIOIT_HDFS.yml +++ b/.github/workflows/beam_PerformanceTests_XmlIOIT_HDFS.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PerformanceTests_xlang_KafkaIO_Python.yml b/.github/workflows/beam_PerformanceTests_xlang_KafkaIO_Python.yml index d4204ab09451..8abc8a3199dd 100644 --- a/.github/workflows/beam_PerformanceTests_xlang_KafkaIO_Python.yml +++ b/.github/workflows/beam_PerformanceTests_xlang_KafkaIO_Python.yml @@ -73,12 +73,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: python-version: default - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PostCommit_Java_InfluxDbIO_IT.yml b/.github/workflows/beam_PostCommit_Java_InfluxDbIO_IT.yml index bfbfc6c04119..353f86e082c5 100644 --- a/.github/workflows/beam_PostCommit_Java_InfluxDbIO_IT.yml +++ b/.github/workflows/beam_PostCommit_Java_InfluxDbIO_IT.yml @@ -74,12 +74,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml b/.github/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml index b6662f3d6595..3925bd924714 100644 --- a/.github/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml +++ b/.github/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml @@ -76,25 +76,19 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: 
./.github/actions/setup-k8s-access with: cluster_name: io-datastores k8s_namespace: ${{ matrix.job_name }}-${{ github.run_id }} remove_finalizer: memsqlclusters.memsql.com/sdb-cluster - - name: Install Singlestore operator + - name: Install SingleStore operator run: | kubectl apply -f ${{github.workspace}}/.test-infra/kubernetes/singlestore/sdb-rbac.yaml kubectl apply -f ${{github.workspace}}/.test-infra/kubernetes/singlestore/sdb-cluster-crd.yaml kubectl apply -f ${{github.workspace}}/.test-infra/kubernetes/singlestore/sdb-operator.yaml kubectl wait --for=condition=Ready pod -l name=sdb-operator --timeout=120s - - name: Install Singlestore cluster + - name: Install SingleStore cluster id: install_singlestore run: | kubectl apply -f ${{github.workspace}}/.test-infra/kubernetes/singlestore/sdb-cluster.yaml diff --git a/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml index 40a533b42a3b..5aeaaec11dec 100644 --- a/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml @@ -95,7 +95,7 @@ jobs: - name: run Solace IO IT script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :sdks:java:io:solace:integrationTest + gradle-command: :sdks:java:io:solace:integrationTest --info arguments: | -PdisableSpotlessCheck=true \ -PdisableCheckStyle=true \ diff --git a/.github/workflows/beam_StressTests_Java_KafkaIO.yml b/.github/workflows/beam_StressTests_Java_KafkaIO.yml index e84c49f01478..9e4550338992 100644 --- a/.github/workflows/beam_StressTests_Java_KafkaIO.yml +++ b/.github/workflows/beam_StressTests_Java_KafkaIO.yml @@ -71,12 +71,6 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - name: Set k8s access uses: ./.github/actions/setup-k8s-access with: diff --git a/.github/workflows/build_release_candidate.yml b/.github/workflows/build_release_candidate.yml index eafbf369fe88..f944ce90c9f1 100644 --- a/.github/workflows/build_release_candidate.yml +++ b/.github/workflows/build_release_candidate.yml @@ -51,8 +51,6 @@ on: env: DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} - GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} - GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: publish_java_artifacts: @@ -186,6 +184,7 @@ jobs: uses: ./.github/actions/setup-environment-action with: python-version: default + disable-cache: true - name: Import GPG key id: import_gpg uses: crazy-max/ghaction-import-gpg@111c56156bcc6918c056dbef52164cfa583dc549 @@ -435,8 +434,6 @@ jobs: - uses: actions/setup-go@v5 with: go-version: '1.22' - cache-dependency-path: | - sdks/go.sum - name: Import GPG key id: import_gpg uses: crazy-max/ghaction-import-gpg@111c56156bcc6918c056dbef52164cfa583dc549 diff --git a/.github/workflows/build_runner_image.yml b/.github/workflows/build_runner_image.yml index 0492622f8847..0f17a9073daf 100644 --- a/.github/workflows/build_runner_image.yml +++ b/.github/workflows/build_runner_image.yml @@ -41,14 +41,6 @@ jobs: uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - - name: Authenticate on GCP - if: github.ref == 'refs/heads/master' - uses: google-github-actions/setup-gcloud@v0 - with: - 
service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - export_default_credentials: true - name: GCloud Docker credential helper run: | gcloud auth configure-docker ${{env.docker_registry}} diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 5142b0b22c30..1275b38b9d23 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -45,7 +45,7 @@ jobs: check_env_variables: timeout-minutes: 5 name: "Check environment variables" - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] env: EVENT_NAME: ${{ github.event_name }} PY_VERSIONS_FULL: "cp38-* cp39-* cp310-* cp311-* cp312-*" @@ -59,8 +59,8 @@ jobs: run: "./scripts/ci/ci_check_are_gcp_variables_set.sh" id: check_gcp_variables env: - GCP_SA_EMAIL: ${{ secrets.GCP_SA_EMAIL }} - GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} + GCP_SA_EMAIL: "not used by self hosted runner" + GCP_SA_KEY: "not used by self hosted runner" GCP_PYTHON_WHEELS_BUCKET: ${{ secrets.GCP_PYTHON_WHEELS_BUCKET }} GCP_PROJECT_ID: "not-needed-here" GCP_REGION: "not-needed-here" @@ -80,7 +80,7 @@ jobs: echo "py-versions-full=$PY_VERSIONS_FULL" >> $GITHUB_OUTPUT build_source: - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] name: Build python source distribution outputs: is_rc: ${{ steps.is_rc.outputs.is_rc }} @@ -190,14 +190,9 @@ jobs: needs: - build_source - check_env_variables - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] if: needs.check_env_variables.outputs.gcp-variables-set == 'true' && github.event_name != 'pull_request' steps: - - name: Authenticate on GCP - uses: google-github-actions/setup-gcloud@v0 - with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} - name: Remove existing files on GCS bucket run: gsutil rm -r ${{ env.GCP_PATH }} || true @@ -206,7 +201,7 @@ jobs: needs: - prepare_gcs - check_env_variables - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] if: needs.check_env_variables.outputs.gcp-variables-set == 'true' steps: - name: Download compressed sources from artifacts @@ -215,11 +210,6 @@ jobs: with: name: source_zip path: source/ - - name: Authenticate on GCP - uses: google-github-actions/setup-gcloud@v0 - with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} - name: Copy sources to GCS bucket run: gsutil cp -r -a public-read source/* ${{ env.GCP_PATH }} @@ -230,19 +220,20 @@ jobs: - build_source env: CIBW_ARCHS_LINUX: ${{matrix.arch}} - runs-on: ${{ matrix.os_python.os }} + runs-on: ${{ matrix.os_python.runner }} + timeout-minutes: 480 strategy: matrix: os_python: [ - {"os": "ubuntu-latest", "python": "${{ needs.check_env_variables.outputs.py-versions-full }}" }, + {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-full }}" }, # Temporarily pin to macos-13 because macos-latest breaks this build # TODO(https://github.com/apache/beam/issues/31114) - {"os": "macos-13", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}" }, - {"os": "windows-latest", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}" }, + {"os": "macos-13", "runner": "macos-13", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}" }, + {"os": "windows-latest", "runner": "windows-latest", "python": "${{ 
needs.check_env_variables.outputs.py-versions-test }}" }, ] arch: [auto] include: - - os_python: {"os": "ubuntu-latest", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}" } + - os_python: {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-test }}" } arch: aarch64 steps: - name: Download python source distribution from artifacts @@ -324,16 +315,16 @@ jobs: needs: - build_wheels - check_env_variables - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] if: needs.check_env_variables.outputs.gcp-variables-set == 'true' && github.event_name != 'pull_request' strategy: matrix: # Temporarily pin to macos-13 because macos-latest breaks this build # TODO(https://github.com/apache/beam/issues/31114) - os : [ubuntu-latest, macos-13, windows-latest] + os : [ubuntu-20.04, macos-13, windows-latest] arch: [auto] include: - - os: "ubuntu-latest" + - os: ubuntu-20.04 arch: aarch64 steps: - name: Download wheels from artifacts @@ -342,11 +333,6 @@ jobs: with: name: wheelhouse-${{ matrix.os }}${{ (matrix.arch == 'aarch64' && '-aarch64') || '' }} path: wheelhouse/ - - name: Authenticate on GCP - uses: google-github-actions/setup-gcloud@v0 - with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} - name: Copy wheels to GCS bucket run: gsutil cp -r -a public-read wheelhouse/* ${{ env.GCP_PATH }} - name: Create github action information file on GCS bucket @@ -375,14 +361,9 @@ jobs: needs: - upload_wheels_to_gcs - check_env_variables - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] if: needs.check_env_variables.outputs.gcp-variables-set == 'true' && github.event_name != 'pull_request' steps: - - name: Authenticate on GCP - uses: google-github-actions/setup-gcloud@v0 - with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} - name: List file on Google Cloud Storage Bucket run: gsutil ls "${{ env.GCP_PATH }}*" @@ -393,7 +374,7 @@ jobs: needs: - build_source - build_wheels - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] timeout-minutes: 60 if: github.repository_owner == 'apache' && github.event_name == 'schedule' steps: diff --git a/.github/workflows/go_tests.yml b/.github/workflows/go_tests.yml index db30bac68ec4..6818e92bc677 100644 --- a/.github/workflows/go_tests.yml +++ b/.github/workflows/go_tests.yml @@ -43,11 +43,10 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 2 - - uses: actions/setup-go@v5 + - name: Setup environment + uses: ./.github/actions/setup-environment-action with: - go-version: '1.22' - cache-dependency-path: | - sdks/go.sum + go-version: default - name: Delete old coverage run: "cd sdks && rm -rf .coverage.txt || :" - name: Run coverage diff --git a/.github/workflows/java_tests.yml b/.github/workflows/java_tests.yml index 0e9d862f91a3..1d6441b24681 100644 --- a/.github/workflows/java_tests.yml +++ b/.github/workflows/java_tests.yml @@ -162,11 +162,10 @@ jobs: fail-fast: false matrix: os: [[self-hosted, ubuntu-20.04, main], windows-latest] + # TODO(https://github.com/apache/beam/issues/31848) run on Dataflow after fixes credential on macOS/win GHA runner if: | - needs.check_gcp_variables.outputs.gcp-variables-set == 'true' && ( - (github.event_name == 'push' || github.event_name == 'schedule') || + needs.check_gcp_variables.outputs.gcp-variables-set == 'true' && (github.event_name == 'workflow_dispatch' && 
github.event.inputs.runDataflow == 'true') - ) steps: - name: Check out code uses: actions/checkout@v4 @@ -179,12 +178,10 @@ jobs: java-version: 11 go-version: default - name: Authenticate on GCP - uses: google-github-actions/setup-gcloud@v0 + uses: google-github-actions/auth@v1 with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} + credentials_json: ${{ secrets.GCP_SA_KEY }} project_id: ${{ secrets.GCP_PROJECT_ID }} - export_default_credentials: true - name: Run WordCount uses: ./.github/actions/gradle-command-self-hosted-action with: diff --git a/.github/workflows/local_env_tests.yml b/.github/workflows/local_env_tests.yml index adfef66b8591..ae2f159710d9 100644 --- a/.github/workflows/local_env_tests.yml +++ b/.github/workflows/local_env_tests.yml @@ -46,12 +46,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 + - name: Setup environment + uses: ./.github/actions/setup-environment-action with: - go-version: '1.22' - - uses: actions/setup-python@v5 - with: - python-version: '3.8' + go-version: default + python-version: default - name: "Installing local env dependencies" run: "sudo ./local-env-setup.sh" id: local_env_install_ubuntu @@ -64,12 +63,11 @@ jobs: runs-on: macos-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version: '1.22' - - uses: actions/setup-python@v5 + - name: Setup environment + uses: ./.github/actions/setup-environment-action with: - python-version: '3.8' + go-version: default + python-version: default - name: "Installing local env dependencies" run: "./local-env-setup.sh" id: local_env_install_mac diff --git a/.github/workflows/playground_backend_precommit.yml b/.github/workflows/playground_backend_precommit.yml index 4c45547f4698..79517e705c27 100644 --- a/.github/workflows/playground_backend_precommit.yml +++ b/.github/workflows/playground_backend_precommit.yml @@ -60,7 +60,7 @@ jobs: sudo apt-get install sbt --yes sudo wget https://codeload.github.com/spotify/scio.g8/zip/7c1ba7c1651dfd70976028842e721da4107c0d6d -O scio.g8.zip && unzip scio.g8.zip && mv scio.g8-7c1ba7c1651dfd70976028842e721da4107c0d6d /opt/scio.g8 - name: Set up Cloud SDK and its components - uses: google-github-actions/setup-gcloud@v0 + uses: google-github-actions/setup-gcloud@v2 with: install_components: 'beta,cloud-datastore-emulator' version: '${{ env.DATASTORE_EMULATOR_VERSION }}' diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index efc5af84c4bc..a65b26645533 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -153,6 +153,8 @@ jobs: python_wordcount_dataflow: name: 'Python Wordcount Dataflow' + # TODO(https://github.com/apache/beam/issues/31848) run on Dataflow after fixes credential on macOS/win GHA runner + if: (github.event_name == 'workflow_dispatch' && github.event.inputs.runDataflow == 'true') needs: - build_python_sdk_source runs-on: ${{ matrix.os }} @@ -175,12 +177,11 @@ jobs: name: python_sdk_source path: apache-beam-source - name: Authenticate on GCP - uses: google-github-actions/setup-gcloud@v0 + id: auth + uses: google-github-actions/auth@v1 with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} + credentials_json: ${{ secrets.GCP_SA_KEY }} project_id: ${{ secrets.GCP_PROJECT_ID }} - export_default_credentials: true - name: Install requirements working-directory: ./sdks/python run: pip install setuptools --upgrade 
&& pip install -e ".[gcp]" diff --git a/.github/workflows/run_perf_alert_tool.yml b/.github/workflows/run_perf_alert_tool.yml index 1f623571acde..4bb5df41dcfb 100644 --- a/.github/workflows/run_perf_alert_tool.yml +++ b/.github/workflows/run_perf_alert_tool.yml @@ -40,12 +40,6 @@ jobs: uses: actions/setup-python@v5 with: python-version: 3.8 - - name: Authenticate on GCP - if: github.event_name != 'pull_request' - uses: google-github-actions/setup-gcloud@v0 - with: - service_account_key: ${{ secrets.GCP_SA_KEY }} - export_default_credentials: true - name: Install Apache Beam working-directory: ./sdks/python run: pip install -e .[gcp,test] diff --git a/.github/workflows/typescript_tests.yml b/.github/workflows/typescript_tests.yml index 0fdcfb070a22..1b45ea67b5c6 100644 --- a/.github/workflows/typescript_tests.yml +++ b/.github/workflows/typescript_tests.yml @@ -147,12 +147,10 @@ jobs: pip install 'pandas>=1.0,<1.5' pip install -e ".[gcp]" - name: Authenticate on GCP - uses: google-github-actions/setup-gcloud@v0 + uses: google-github-actions/auth@v1 with: - service_account_email: ${{ secrets.GCP_SA_EMAIL }} - service_account_key: ${{ secrets.GCP_SA_KEY }} + credentials_json: ${{ secrets.GCP_SA_KEY }} project_id: ${{ secrets.GCP_PROJECT_ID }} - export_default_credentials: true - run: npm ci working-directory: ./sdks/typescript - run: npm run build diff --git a/.test-infra/mock-apis/poetry.lock b/.test-infra/mock-apis/poetry.lock index b36baff7a74b..98985df7ea4a 100644 --- a/.test-infra/mock-apis/poetry.lock +++ b/.test-infra/mock-apis/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "beautifulsoup4" @@ -188,19 +188,18 @@ files = [ [[package]] name = "setuptools" -version = "68.2.2" +version = "70.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"}, - {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"}, + {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, + {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", 
"wheel"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "soupsieve" diff --git a/.test-infra/tools/stale_bq_datasets_cleaner.sh b/.test-infra/tools/stale_bq_datasets_cleaner.sh index c4afabe11e9a..326000fdc754 100755 --- a/.test-infra/tools/stale_bq_datasets_cleaner.sh +++ b/.test-infra/tools/stale_bq_datasets_cleaner.sh @@ -18,7 +18,7 @@ # Deletes stale and old BQ datasets that are left after tests. # -set -exuo pipefail +set -euo pipefail PROJECT=apache-beam-testing MAX_RESULT=1500 @@ -51,7 +51,7 @@ for dataset in ${BQ_DATASETS[@]}; do # date command usage depending on OS echo "Deleted $dataset (modified `date -d @$LAST_MODIFIED`)" elif [[ $OSTYPE == "darwin"* ]]; then - echo "Deleted $dataset (modified `date -r @$LAST_MODIFIED`)" + echo "Deleted $dataset (modified `date -r $LAST_MODIFIED`)" fi else echo "Tried and failed to delete $dataset" diff --git a/CHANGES.md b/CHANGES.md index 0a620038f11e..2bc2ecc49970 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -53,7 +53,7 @@ * ([#X](https://github.com/apache/beam/issues/X)). --> -# [2.58.0] - Unreleased +# [2.59.0] - Unreleased ## Highlights @@ -63,16 +63,56 @@ ## I/Os * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Improvements to the performance of BigqueryIO when using withPropagateSuccessfulStorageApiWrites(true) method (Java) ([#31840](https://github.com/apache/beam/pull/31840)). + +## New Features / Improvements + +* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). + +## Breaking Changes + +* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). + +## Deprecations + +* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). + +## Bugfixes + +* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). + +## Security Fixes +* Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). + +## Known Issues + +* ([#X](https://github.com/apache/beam/issues/X)). + +# [2.58.0] - Unreleased + +## Highlights + +* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). +* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). + +## I/Os + +* Support for [Solace](https://solace.com/) source (`SolaceIO.Read`) added (Java) ([#31440](https://github.com/apache/beam/issues/31440)). ## New Features / Improvements * Multiple RunInference instances can now share the same model instance by setting the model_identifier parameter (Python) ([#31665](https://github.com/apache/beam/issues/31665)). 
+* Added options to control the number of Storage API multiplexing connections ([#31721](https://github.com/apache/beam/pull/31721)). +* [BigQueryIO] Better handling for batch Storage Write API when it hits the AppendRows throughput quota ([#31837](https://github.com/apache/beam/pull/31837)). +* [IcebergIO] All specified catalog properties are passed through to the connector ([#31726](https://github.com/apache/beam/pull/31726)). * Removed a 3rd party LGPL dependency from the Go SDK ([#31765](https://github.com/apache/beam/issues/31765)). * Support for MapState and SetState when using Dataflow Runner v1 with Streaming Engine (Java) ([[#18200](https://github.com/apache/beam/issues/18200)]) ## Breaking Changes * X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). +* [IcebergIO] IcebergCatalogConfig was changed to support specifying catalog properties in a key-store fashion ([#31726](https://github.com/apache/beam/pull/31726)). +* [SpannerIO] Added validation that query and table cannot be specified at the same time for SpannerIO.read(). Previously, withQuery overrode withTable if both were set ([#24956](https://github.com/apache/beam/issues/24956)). ## Deprecations @@ -80,6 +120,7 @@ ## Bugfixes +* [BigQueryIO] Fixed a bug in batch Storage Write API that frequently exhausted the concurrent connections quota ([#31710](https://github.com/apache/beam/pull/31710)). * Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Security Fixes diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index ba6279a13490..e603e49f842f 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -758,6 +758,7 @@ class BeamModulePlugin implements Plugin { // [bomupgrader] the BOM version is set by scripts/tools/bomupgrader.py. 
If update manually, also update // libraries-bom version on sdks/java/container/license_scripts/dep_urls_java.yaml google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:26.39.0", + google_cloud_secret_manager : "com.google.cloud:google-cloud-secretmanager", // google_cloud_platform_libraries_bom sets version google_cloud_spanner : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version google_cloud_spanner_test : "com.google.cloud:google-cloud-spanner:$google_cloud_spanner_version:tests", google_cloud_vertexai : "com.google.cloud:google-cloud-vertexai", // google_cloud_platform_libraries_bom sets version @@ -858,6 +859,7 @@ class BeamModulePlugin implements Plugin { proto_google_cloud_firestore_v1 : "com.google.api.grpc:proto-google-cloud-firestore-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_pubsub_v1 : "com.google.api.grpc:proto-google-cloud-pubsub-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_pubsublite_v1 : "com.google.api.grpc:proto-google-cloud-pubsublite-v1", // google_cloud_platform_libraries_bom sets version + proto_google_cloud_secret_manager_v1 : "com.google.api.grpc:proto-google-cloud-secretmanager-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_spanner_v1 : "com.google.api.grpc:proto-google-cloud-spanner-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_spanner_admin_database_v1: "com.google.api.grpc:proto-google-cloud-spanner-admin-database-v1", // google_cloud_platform_libraries_bom sets version proto_google_common_protos : "com.google.api.grpc:proto-google-common-protos", // google_cloud_platform_libraries_bom sets version @@ -900,6 +902,7 @@ class BeamModulePlugin implements Plugin { testcontainers_oracle : "org.testcontainers:oracle-xe:$testcontainers_version", testcontainers_postgresql : "org.testcontainers:postgresql:$testcontainers_version", testcontainers_rabbitmq : "org.testcontainers:rabbitmq:$testcontainers_version", + testcontainers_solace : "org.testcontainers:solace:$testcontainers_version", truth : "com.google.truth:truth:1.1.5", threetenbp : "org.threeten:threetenbp:1.6.8", vendored_grpc_1_60_1 : "org.apache.beam:beam-vendor-grpc-1_60_1:0.2", @@ -2223,7 +2226,7 @@ class BeamModulePlugin implements Plugin { // This sets the whole project Go version. // The latest stable Go version can be checked at https://go.dev/dl/ - project.ext.goVersion = "go1.22.4" + project.ext.goVersion = "go1.22.5" // Minor TODO: Figure out if we can pull out the GOCMD env variable after goPrepare script // completion, and avoid this GOBIN substitution. diff --git a/contributor-docs/code-change-guide.md b/contributor-docs/code-change-guide.md index 935a2c6276c5..f0785d3509d0 100644 --- a/contributor-docs/code-change-guide.md +++ b/contributor-docs/code-change-guide.md @@ -286,14 +286,27 @@ Integration tests differ from standard pipelines in the following ways: * They have a default timeout of 15 minutes. * The pipeline options are set in the system property `beamTestPipelineOptions`. -To configure the test, you need to set the property `-DbeamTestPipelineOptions=[...]`. This property sets the runner that the test uses. - -The following example demonstrates how to run an integration test by using the command line. This example includes the options required to run the pipeline on the Dataflow runner. +To configure the test pipeline, you need to set the property `-DbeamTestPipelineOptions=[...]`. 
This property sets the pipeline options that the test uses, for example, ``` -DbeamTestPipelineOptions='["--runner=TestDataflowRunner","--project=mygcpproject","--region=us-central1","--stagingLocation=gs://mygcsbucket/path"]' ``` +For some projects, `beamTestPipelineOptions` is explicitly configured in `build.gradle`. +Check the source of the corresponding build file for these settings. For example, +in `sdks/java/io/google-cloud-platform/build.gradle`, it sets `beamTestPipelineOptions` +from the project properties 'gcpProject', 'gcpTempRoot', etc., and when these are not assigned, +it defaults to the `apache-beam-testing` GCP project. To run the test in your own project, +assign these project properties on the command line: + +``` +./gradlew :sdks:java:io:google-cloud-platform:integrationTest -PgcpProject= -PgcpTempRoot= +``` + +Some other projects (e.g. `sdks/java/io/jdbc`, `sdks/java/io/kafka`) do not +assemble (overwrite) `beamTestPipelineOptions` in `build.gradle`; for those, just set +it explicitly with `-DbeamTestPipelineOptions='[...]'`, as described above. + #### Write integration tests To set up a `TestPipeline` object in an integration test, use the following code: @@ -423,6 +436,17 @@ If you're using Dataflow Runner v2 and `sdks/java/harness` or its dependencies ( --sdkContainerImage="us.gcr.io/apache-beam-testing/beam_java11_sdk:2.49.0-custom" ``` +#### Snapshot Version Containers + +By default, a Snapshot version for an SDK under development will use the containers published to the [apache-beam-testing project's container registry](https://us.gcr.io/apache-beam-testing/github-actions). For example, the most recent snapshot container for Java 17 can be found [here](https://us.gcr.io/apache-beam-testing/github-actions/beam_java17_sdk). + +When a version is entering the [release candidate stage](https://github.com/apache/beam/blob/master/contributor-docs/release-guide.md), one final SNAPSHOT version will be published. +This SNAPSHOT version will use the final containers published on [DockerHub](https://hub.docker.com/search?q=apache%2Fbeam). + +**NOTE:** During the release process, there may be some downtime where a container is not available for use for a SNAPSHOT version. To avoid this, it is recommended to either switch to the latest SNAPSHOT version available or to use [custom containers](https://beam.apache.org/documentation/runtime/environments/#custom-containers). You should also rely on snapshot versions for important workloads only if absolutely necessary. + +Certain runners may override this snapshot behavior; for example, the Dataflow runner overrides all SNAPSHOT containers into a [single registry](https://console.cloud.google.com/gcr/images/cloud-dataflow/GLOBAL/v1beta3). The same downtime will still be incurred, however, when switching to the final container. + ## Python guide The Beam Python SDK is distributed as a single wheel, which is more straightforward than the Java SDK. diff --git a/contributor-docs/release-guide.md b/contributor-docs/release-guide.md index c0e8e7c67ce7..b3d3c77d25df 100644 --- a/contributor-docs/release-guide.md +++ b/contributor-docs/release-guide.md @@ -1006,18 +1006,18 @@ write to BigQuery, and create a cluster of machines for running containers (for In comment area, type in `Run Python ReleaseCandidate` to trigger validation. 
* **Python Leaderboard & GameStats** - * **Get staging RC** `wget https://dist.apache.org/repos/dist/dev/beam/2.5.0/* ` + * **Get staging RC** `wget https://dist.apache.org/repos/dist/dev/beam/2.XX.0/* ` * **Verify the hashes** ``` - sha512sum -c apache-beam-2.5.0-python.tar.gz.sha512 - sha512sum -c apache-beam-2.5.0-source-release.tar.gz.sha512 + sha512sum -c apache_beam-2.XX.0-python.tar.gz.sha512 + sha512sum -c apache_beam-2.XX.0-source-release.tar.gz.sha512 ``` * **Build SDK** ``` sudo apt-get install unzip - unzip apache-beam-2.5.0-source-release.tar.gz + unzip apache_beam-2.XX.0-source-release.tar.gz python setup.py sdist ``` * **Setup virtual environment** @@ -1030,8 +1030,8 @@ write to BigQuery, and create a cluster of machines for running containers (for * **Install SDK** ``` - pip install dist/apache-beam-2.5.0.tar.gz - pip install dist/apache-beam-2.5.0.tar.gz[gcp] + pip install dist/apache_beam-2.XX.0.tar.gz + pip install dist/apache_beam-2.XX.0.tar.gz[gcp] ``` * **Setup GCP** diff --git a/dev-support/docker/Dockerfile b/dev-support/docker/Dockerfile index 5cabba8be41a..d7887561b710 100644 --- a/dev-support/docker/Dockerfile +++ b/dev-support/docker/Dockerfile @@ -78,7 +78,7 @@ RUN pip3 install distlib==0.3.1 yapf==0.29.0 pytest ### # Install Go ### -ENV DOWNLOAD_GO_VERSION=1.22.4 +ENV DOWNLOAD_GO_VERSION=1.22.5 RUN wget https://golang.org/dl/go${DOWNLOAD_GO_VERSION}.linux-amd64.tar.gz && \ tar -C /usr/local -xzf go${DOWNLOAD_GO_VERSION}.linux-amd64.tar.gz ENV GOROOT /usr/local/go diff --git a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb index 7510831bac40..95be8b1d957c 100644 --- a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb @@ -139,41 +139,33 @@ }, "source": [ "### Authenticate with Google Cloud\n", - "This notebook reads data from Pub/Sub and Bigtable. To use your Google Cloud account, authenticate this notebook." + "This notebook reads data from Pub/Sub and Bigtable. To use your Google Cloud account, authenticate this notebook.\n", + "To prepare for this step, replace ``, ``, and `` with the appropriate values for your setup. These fields are used with Bigtable." ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "Kz9sccyGBqz3" + "id": "wEXucyi2liij" }, "outputs": [], "source": [ - "from google.colab import auth\n", - "auth.authenticate_user()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nAmGgUMt48o9" - }, - "source": [ - "Replace ``, ``, and `` with the appropriate values for your setup. These fields are used with Bigtable." + "PROJECT_ID = \"\"\n", + "INSTANCE_ID = \"\"\n", + "TABLE_ID = \"\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "wEXucyi2liij" + "id": "Kz9sccyGBqz3" }, "outputs": [], "source": [ - "PROJECT_ID = \"\"\n", - "INSTANCE_ID = \"\"\n", - "TABLE_ID = \"\"" + "from google.colab import auth\n", + "auth.authenticate_user(project_id=PROJECT_ID)" ] }, { @@ -879,4 +871,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/gradle.properties b/gradle.properties index 931c925c0efa..9db48fd21e8f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -30,8 +30,8 @@ signing.gnupg.useLegacyGpg=true # buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy. 
# To build a custom Beam version make sure you change it in both places, see # https://github.com/apache/beam/issues/21302. -version=2.58.0-SNAPSHOT -sdk_version=2.58.0.dev +version=2.59.0-SNAPSHOT +sdk_version=2.59.0.dev javaVersion=1.8 diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto index 40956ddf856a..4ec189e4637f 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/metrics.proto @@ -187,6 +187,17 @@ message MonitoringInfoSpecs { }] }]; + // Represents a set of strings seen across bundles. + USER_SET_STRING = 21 [(monitoring_info_spec) = { + urn: "beam:metric:user:set_string:v1", + type: "beam:metrics:set_string:v1", + required_labels: ["PTRANSFORM", "NAMESPACE", "NAME"], + annotations: [{ + key: "description", + value: "URN utilized to report user metric." + }] + }]; + // General monitored state information which contains structured information // which does not fit into a typical metric format. See MonitoringTableData // for more details. @@ -557,6 +568,14 @@ message MonitoringInfoTypeUrns { PROGRESS_TYPE = 10 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:metrics:progress:v1"]; + // Represents a set of strings. + // + // Encoding: ... + // - iter: beam:coder:iterable:v1 + // - valueX: beam:coder:stringutf8:v1 + SET_STRING_TYPE = 11 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:metrics:set_string:v1"]; + // General monitored state information which contains structured information // which does not fit into a typical metric format. See MonitoringTableData // for more details. diff --git a/playground/backend/containers/python/Dockerfile b/playground/backend/containers/python/Dockerfile index ca57bef9025f..fd7d8b7f8958 100644 --- a/playground/backend/containers/python/Dockerfile +++ b/playground/backend/containers/python/Dockerfile @@ -65,8 +65,7 @@ RUN cd /opt/playground/backend/kafka-emulator/ && tar -xvf kafka-emulator.tar && mv kafka-emulator/*.jar . && rmdir kafka-emulator/ &&\ mv beam-playground-kafka-emulator-*.jar beam-playground-kafka-emulator.jar RUN apt-get update && \ - wget http://http.us.debian.org/debian/pool/main/o/openjdk-11/openjdk-11-jre-headless_11.0.23+9-1~deb11u1_amd64.deb && \ - apt install -y ./openjdk-11-jre-headless_11.0.23+9-1~deb11u1_amd64.deb + apt install openjdk-17-jre-headless -y # Create a user group `appgroup` and a user `appuser` RUN groupadd --gid 20000 appgroup \ diff --git a/release/src/main/scripts/run_rc_validation.sh b/release/src/main/scripts/run_rc_validation.sh index 0f2bfe4aaec2..91bfa9e2f8bb 100755 --- a/release/src/main/scripts/run_rc_validation.sh +++ b/release/src/main/scripts/run_rc_validation.sh @@ -300,14 +300,14 @@ if [[ ("$python_leaderboard_direct" = true \ cd ${LOCAL_BEAM_DIR} echo "---------------------Downloading Python Staging RC----------------------------" - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz.sha512 - if [[ ! -f apache-beam-${RELEASE_VER}.tar.gz ]]; then + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz.sha512 + if [[ ! -f apache_beam-${RELEASE_VER}.tar.gz ]]; then { echo "Fail to download Python Staging RC files." 
;exit 1; } fi echo "--------------------------Verifying Hashes------------------------------------" - sha512sum -c apache-beam-${RELEASE_VER}.tar.gz.sha512 + sha512sum -c apache_beam-${RELEASE_VER}.tar.gz.sha512 echo "--------------------------Updating ~/.m2/settings.xml-------------------------" cd ~ @@ -378,7 +378,7 @@ if [[ ("$python_leaderboard_direct" = true \ pip install --upgrade pip setuptools wheel echo "--------------------------Installing Python SDK-------------------------------" - pip install apache-beam-${RELEASE_VER}.tar.gz[gcp] + pip install apache_beam-${RELEASE_VER}.tar.gz[gcp] echo "----------------Starting Leaderboard with DirectRunner-----------------------" if [[ "$python_leaderboard_direct" = true ]]; then @@ -434,7 +434,7 @@ if [[ ("$python_leaderboard_direct" = true \ --dataset ${LEADERBOARD_DF_DATASET} \ --runner DataflowRunner \ --temp_location=${USER_GCS_BUCKET}/temp/ \ - --sdk_location apache-beam-${RELEASE_VER}.tar.gz; \ + --sdk_location apache_beam-${RELEASE_VER}.tar.gz; \ exec bash" echo "***************************************************************" @@ -509,7 +509,7 @@ if [[ ("$python_leaderboard_direct" = true \ --dataset ${GAMESTATS_DF_DATASET} \ --runner DataflowRunner \ --temp_location=${USER_GCS_BUCKET}/temp/ \ - --sdk_location apache-beam-${RELEASE_VER}.tar.gz \ + --sdk_location apache_beam-${RELEASE_VER}.tar.gz \ --fixed_window_duration ${FIXED_WINDOW_DURATION}; exec bash" echo "***************************************************************" @@ -566,14 +566,14 @@ if [[ ("$python_xlang_quickstart" = true) \ cd ${LOCAL_BEAM_DIR} echo "---------------------Downloading Python Staging RC----------------------------" - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz.sha512 - if [[ ! -f apache-beam-${RELEASE_VER}.tar.gz ]]; then + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz.sha512 + if [[ ! -f apache_beam-${RELEASE_VER}.tar.gz ]]; then { echo "Failed to download Python Staging RC files." ;exit 1; } fi echo "--------------------------Verifying Hashes------------------------------------" - sha512sum -c apache-beam-${RELEASE_VER}.tar.gz.sha512 + sha512sum -c apache_beam-${RELEASE_VER}.tar.gz.sha512 `which pip` install --upgrade pip `which pip` install --upgrade setuptools @@ -593,7 +593,7 @@ if [[ ("$python_xlang_quickstart" = true) \ ln -s ${LOCAL_BEAM_DIR}/sdks beam_env_${py_version}/lib/sdks echo "--------------------------Installing Python SDK-------------------------------" - pip install apache-beam-${RELEASE_VER}.tar.gz + pip install apache_beam-${RELEASE_VER}.tar.gz echo '************************************************************'; echo '* Running Python Multi-language Quickstart with DirectRunner'; @@ -672,14 +672,14 @@ if [[ ("$java_xlang_quickstart" = true) \ cd ${LOCAL_BEAM_DIR} echo "---------------------Downloading Python Staging RC----------------------------" - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz.sha512 - if [[ ! 
-f apache-beam-${RELEASE_VER}.tar.gz ]]; then + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz.sha512 + if [[ ! -f apache_beam-${RELEASE_VER}.tar.gz ]]; then { echo "Failed to download Python Staging RC files." ;exit 1; } fi echo "--------------------------Verifying Hashes------------------------------------" - sha512sum -c apache-beam-${RELEASE_VER}.tar.gz.sha512 + sha512sum -c apache_beam-${RELEASE_VER}.tar.gz.sha512 `which pip` install --upgrade pip `which pip` install --upgrade setuptools @@ -699,7 +699,7 @@ if [[ ("$java_xlang_quickstart" = true) \ ln -s ${LOCAL_BEAM_DIR}/sdks beam_env_${py_version}/lib/sdks echo "--------------------------Installing Python SDK-------------------------------" - pip install apache-beam-${RELEASE_VER}.tar.gz[dataframe] + pip install apache_beam-${RELEASE_VER}.tar.gz[dataframe] # Deacrivating in the main shell. We will reactivate the virtual environment new shells # for the expansion service and the job server. @@ -768,14 +768,14 @@ if [[ ("$python_xlang_kafka_taxi_dataflow" = true cd ${LOCAL_BEAM_DIR} echo "---------------------Downloading Python Staging RC----------------------------" - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz - wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache-beam-${RELEASE_VER}.tar.gz.sha512 - if [[ ! -f apache-beam-${RELEASE_VER}.tar.gz ]]; then + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz + wget ${PYTHON_RC_DOWNLOAD_URL}/${RELEASE_VER}/python/apache_beam-${RELEASE_VER}.tar.gz.sha512 + if [[ ! -f apache_beam-${RELEASE_VER}.tar.gz ]]; then { echo "Fail to download Python Staging RC files." 
;exit 1; } fi echo "--------------------------Verifying Hashes------------------------------------" - sha512sum -c apache-beam-${RELEASE_VER}.tar.gz.sha512 + sha512sum -c apache_beam-${RELEASE_VER}.tar.gz.sha512 `which pip` install --upgrade pip `which pip` install --upgrade setuptools @@ -807,7 +807,7 @@ if [[ ("$python_xlang_kafka_taxi_dataflow" = true ln -s ${LOCAL_BEAM_DIR}/sdks beam_env_${py_version}/lib/sdks echo "--------------------------Installing Python SDK-------------------------------" - pip install apache-beam-${RELEASE_VER}.tar.gz[gcp] + pip install apache_beam-${RELEASE_VER}.tar.gz[gcp] echo "----------------Starting XLang Kafka Taxi with DataflowRunner---------------------" if [[ "$python_xlang_kafka_taxi_dataflow" = true ]]; then @@ -837,7 +837,7 @@ if [[ ("$python_xlang_kafka_taxi_dataflow" = true --temp_location=${USER_GCS_BUCKET}/temp/ \ --with_metadata \ --beam_services=\"{\\\"sdks:java:io:expansion-service:shadowJar\\\": \\\"${KAFKA_EXPANSION_SERVICE_JAR}\\\"}\" \ - --sdk_location apache-beam-${RELEASE_VER}.tar.gz; \ + --sdk_location apache_beam-${RELEASE_VER}.tar.gz; \ exec bash" echo "***************************************************************" @@ -882,7 +882,7 @@ if [[ ("$python_xlang_kafka_taxi_dataflow" = true --temp_location=${USER_GCS_BUCKET}/temp/ \ --output_topic projects/${USER_GCP_PROJECT}/topics/${SQL_TAXI_TOPIC} \ --beam_services=\"{\\\":sdks:java:extensions:sql:expansion-service:shadowJar\\\": \\\"${SQL_EXPANSION_SERVICE_JAR}\\\"}\" \ - --sdk_location apache-beam-${RELEASE_VER}.tar.gz; \ + --sdk_location apache_beam-${RELEASE_VER}.tar.gz; \ exec bash" echo "***************************************************************" diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DefaultMetricResults.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DefaultMetricResults.java index a77f3947b529..ea8a333d397b 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DefaultMetricResults.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DefaultMetricResults.java @@ -24,13 +24,14 @@ import org.apache.beam.sdk.metrics.MetricResult; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.checkerframework.checker.nullness.qual.Nullable; /** * Default implementation of {@link org.apache.beam.sdk.metrics.MetricResults}, which takes static - * {@link Iterable}s of counters, distributions, and gauges, and serves queries by applying {@link - * org.apache.beam.sdk.metrics.MetricsFilter}s linearly to them. + * {@link Iterable}s of counters, distributions, gauges, and stringsets, and serves queries by + * applying {@link org.apache.beam.sdk.metrics.MetricsFilter}s linearly to them. 
*/ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -40,14 +41,17 @@ public class DefaultMetricResults extends MetricResults { private final Iterable> counters; private final Iterable> distributions; private final Iterable> gauges; + private final Iterable> stringSets; public DefaultMetricResults( Iterable> counters, Iterable> distributions, - Iterable> gauges) { + Iterable> gauges, + Iterable> stringSets) { this.counters = counters; this.distributions = distributions; this.gauges = gauges; + this.stringSets = stringSets; } @Override @@ -56,6 +60,8 @@ public MetricQueryResults queryMetrics(@Nullable MetricsFilter filter) { Iterables.filter(counters, counter -> MetricFiltering.matches(filter, counter.getKey())), Iterables.filter( distributions, distribution -> MetricFiltering.matches(filter, distribution.getKey())), - Iterables.filter(gauges, gauge -> MetricFiltering.matches(filter, gauge.getKey()))); + Iterables.filter(gauges, gauge -> MetricFiltering.matches(filter, gauge.getKey())), + Iterables.filter( + stringSets, stringSets -> MetricFiltering.matches(filter, stringSets.getKey()))); } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricUpdates.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricUpdates.java index 7ef936c8552d..ada5bda4df4a 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricUpdates.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricUpdates.java @@ -29,10 +29,12 @@ "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) }) public abstract class MetricUpdates { - public static final MetricUpdates EMPTY = MetricUpdates.create( - Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList()); /** * Representation of a single metric update. @@ -52,25 +54,33 @@ public static MetricUpdate create(MetricKey key, T update) { } } - /** Returns true if there are no updates in this MetricUpdates object. */ - public boolean isEmpty() { - return Iterables.isEmpty(counterUpdates()) && Iterables.isEmpty(distributionUpdates()); - } - - /** All of the counter updates. */ + /** All the counter updates. */ public abstract Iterable> counterUpdates(); - /** All of the distribution updates. */ + /** All the distribution updates. */ public abstract Iterable> distributionUpdates(); - /** All of the gauges updates. */ + /** All the gauges updates. */ public abstract Iterable> gaugeUpdates(); + /** All the sets updates. */ + public abstract Iterable> stringSetUpdates(); + /** Create a new {@link MetricUpdates} bundle. */ public static MetricUpdates create( Iterable> counterUpdates, Iterable> distributionUpdates, - Iterable> gaugeUpdates) { - return new AutoValue_MetricUpdates(counterUpdates, distributionUpdates, gaugeUpdates); + Iterable> gaugeUpdates, + Iterable> stringSetUpdates) { + return new AutoValue_MetricUpdates( + counterUpdates, distributionUpdates, gaugeUpdates, stringSetUpdates); + } + + /** Returns true if there are no updates in this MetricUpdates object. 
*/ + public boolean isEmpty() { + return Iterables.isEmpty(counterUpdates()) + && Iterables.isEmpty(distributionUpdates()) + && Iterables.isEmpty(gaugeUpdates()) + && Iterables.isEmpty(stringSetUpdates()); } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java index abf3bb2f886b..99cf98508505 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java @@ -19,13 +19,16 @@ import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.DISTRIBUTION_INT64_TYPE; import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.LATEST_INT64_TYPE; +import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.SET_STRING_TYPE; import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.SUM_INT64_TYPE; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeStringSet; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeStringSet; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import java.io.Serializable; @@ -85,6 +88,8 @@ public class MetricsContainerImpl implements Serializable, MetricsContainer { private MetricsMap gauges = new MetricsMap<>(GaugeCell::new); + private MetricsMap stringSets = new MetricsMap<>(StringSetCell::new); + private MetricsMap, HistogramCell> histograms = new MetricsMap<>(HistogramCell::new); @@ -123,6 +128,7 @@ public void reset() { distributions.forEachValue(DistributionCell::reset); gauges.forEachValue(GaugeCell::reset); histograms.forEachValue(HistogramCell::reset); + stringSets.forEachValue(StringSetCell::reset); } /** @@ -193,6 +199,23 @@ public GaugeCell getGauge(MetricName metricName) { return gauges.tryGet(metricName); } + /** + * Return a {@code StringSetCell} named {@code metricName}. If it doesn't exist, create a {@code + * Metric} with the specified name. + */ + @Override + public StringSetCell getStringSet(MetricName metricName) { + return stringSets.get(metricName); + } + + /** + * Return a {@code StringSetCell} named {@code metricName}. If it doesn't exist, return {@code + * null}. 
+ */ + public @Nullable StringSetCell tryGetStringSet(MetricName metricName) { + return stringSets.tryGet(metricName); + } + private > ImmutableList> extractUpdates(MetricsMap cells) { ImmutableList.Builder> updates = ImmutableList.builder(); @@ -212,7 +235,10 @@ ImmutableList> extractUpdates(MetricsMap metricUpdate) { + SimpleMonitoringInfoBuilder builder = stringSetToMonitoringMetadata(metricUpdate.getKey()); + if (builder == null) { + return null; + } + builder.setStringSetValue(metricUpdate.getUpdate()); + return builder.build(); + } + /** Return the cumulative values for any metrics in this container as MonitoringInfos. */ @Override public Iterable getMonitoringInfos() { @@ -335,6 +383,13 @@ public Iterable getMonitoringInfos() { monitoringInfos.add(mi); } } + + for (MetricUpdate metricUpdate : metricUpdates.stringSetUpdates()) { + MonitoringInfo mi = stringSetUpdateToMonitoringInfo(metricUpdate); + if (mi != null) { + monitoringInfos.add(mi); + } + } return monitoringInfos; } @@ -368,6 +423,15 @@ public Map getMonitoringData(ShortIdMap shortIds) { } } }); + stringSets.forEach( + (metricName, stringSetCell) -> { + if (stringSetCell.getDirty().beforeCommit()) { + String shortId = getShortId(metricName, this::stringSetToMonitoringMetadata, shortIds); + if (shortId != null) { + builder.put(shortId, encodeStringSet(stringSetCell.getCumulative())); + } + } + }); return builder.build(); } @@ -402,6 +466,7 @@ public void commitUpdates() { counters.forEachValue(counter -> counter.getDirty().afterCommit()); distributions.forEachValue(distribution -> distribution.getDirty().afterCommit()); gauges.forEachValue(gauge -> gauge.getDirty().afterCommit()); + stringSets.forEachValue(sSets -> sSets.getDirty().afterCommit()); } private > @@ -423,7 +488,8 @@ public MetricUpdates getCumulative() { return MetricUpdates.create( extractCumulatives(counters), extractCumulatives(distributions), - extractCumulatives(gauges)); + extractCumulatives(gauges), + extractCumulatives(stringSets)); } /** Update values of this {@link MetricsContainerImpl} by merging the value of another cell. */ @@ -432,6 +498,7 @@ public void update(MetricsContainerImpl other) { updateDistributions(distributions, other.distributions); updateGauges(gauges, other.gauges); updateHistograms(histograms, other.histograms); + updateStringSets(stringSets, other.stringSets); } private void updateForSumInt64Type(MonitoringInfo monitoringInfo) { @@ -454,6 +521,12 @@ private void updateForLatestInt64Type(MonitoringInfo monitoringInfo) { gauge.update(decodeInt64Gauge(monitoringInfo.getPayload())); } + private void updateForStringSetType(MonitoringInfo monitoringInfo) { + MetricName metricName = MonitoringInfoMetricName.of(monitoringInfo); + StringSetCell stringSet = getStringSet(metricName); + stringSet.update(decodeStringSet(monitoringInfo.getPayload())); + } + /** Update values of this {@link MetricsContainerImpl} by reading from {@code monitoringInfos}. 
*/ public void update(Iterable monitoringInfos) { for (MonitoringInfo monitoringInfo : monitoringInfos) { @@ -474,6 +547,10 @@ public void update(Iterable monitoringInfos) { updateForLatestInt64Type(monitoringInfo); break; + case SET_STRING_TYPE: + updateForStringSetType(monitoringInfo); + break; + default: LOG.warn("Unsupported metric type {}", monitoringInfo.getType()); } @@ -502,6 +579,12 @@ private void updateHistograms( updates.forEach((key, value) -> current.get(key).update(value)); } + private void updateStringSets( + MetricsMap current, + MetricsMap updates) { + updates.forEach((key, value) -> current.get(key).update(value.getCumulative())); + } + @Override public boolean equals(@Nullable Object object) { if (object instanceof MetricsContainerImpl) { @@ -509,14 +592,15 @@ public boolean equals(@Nullable Object object) { return Objects.equals(stepName, metricsContainerImpl.stepName) && Objects.equals(counters, metricsContainerImpl.counters) && Objects.equals(distributions, metricsContainerImpl.distributions) - && Objects.equals(gauges, metricsContainerImpl.gauges); + && Objects.equals(gauges, metricsContainerImpl.gauges) + && Objects.equals(stringSets, metricsContainerImpl.stringSets); } return false; } @Override public int hashCode() { - return Objects.hash(stepName, counters, distributions, gauges); + return Objects.hash(stepName, counters, distributions, gauges, stringSets); } /** @@ -588,6 +672,16 @@ public String getCumulativeString(@Nullable Set allowedMetricUrns) { } message.append("\n"); } + for (Map.Entry cell : stringSets.entries()) { + if (!matchMetric(cell.getKey(), allowedMetricUrns)) { + continue; + } + message.append(cell.getKey().toString()); + message.append(" = "); + StringSetData data = cell.getValue().getCumulative(); + message.append(data.stringSet().toString()); + message.append("\n"); + } return message.toString(); } @@ -628,6 +722,10 @@ public static MetricsContainerImpl deltaContainer( deltaValueCell.incTopBucketCount( currValue.getTopBucketCount() - prevValue.getTopBucketCount()); } + for (Map.Entry cell : curr.stringSets.entries()) { + // Simply take the most recent value for stringSets, no need to count deltas. 
+ deltaContainer.stringSets.get(cell.getKey()).update(cell.getValue().getCumulative()); + } return deltaContainer; } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java index b59e58956a12..688491184e67 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java @@ -136,6 +136,7 @@ public static MetricResults asMetricResults( Map> counters = new HashMap<>(); Map> distributions = new HashMap<>(); Map> gauges = new HashMap<>(); + Map> sets = new HashMap<>(); attemptedMetricsContainers.forEachMetricContainer( container -> { @@ -144,6 +145,7 @@ public static MetricResults asMetricResults( mergeAttemptedResults( distributions, cumulative.distributionUpdates(), DistributionData::combine); mergeAttemptedResults(gauges, cumulative.gaugeUpdates(), GaugeData::combine); + mergeAttemptedResults(sets, cumulative.stringSetUpdates(), StringSetData::combine); }); committedMetricsContainers.forEachMetricContainer( container -> { @@ -152,6 +154,7 @@ public static MetricResults asMetricResults( mergeCommittedResults( distributions, cumulative.distributionUpdates(), DistributionData::combine); mergeCommittedResults(gauges, cumulative.gaugeUpdates(), GaugeData::combine); + mergeCommittedResults(sets, cumulative.stringSetUpdates(), StringSetData::combine); }); return new DefaultMetricResults( @@ -161,6 +164,9 @@ public static MetricResults asMetricResults( .collect(toList()), gauges.values().stream() .map(result -> result.transform(GaugeData::extractResult)) + .collect(toList()), + sets.values().stream() + .map(result -> result.transform(StringSetData::extractResult)) .collect(toList())); } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java index 44d1b4f53071..2bb935111d38 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java @@ -52,6 +52,8 @@ public static final class Urns { extractUrn(MonitoringInfoSpecs.Enum.USER_DISTRIBUTION_INT64); public static final String USER_DISTRIBUTION_DOUBLE = extractUrn(MonitoringInfoSpecs.Enum.USER_DISTRIBUTION_DOUBLE); + public static final String USER_SET_STRING = + extractUrn(MonitoringInfoSpecs.Enum.USER_SET_STRING); public static final String SAMPLED_BYTE_SIZE = extractUrn(MonitoringInfoSpecs.Enum.SAMPLED_BYTE_SIZE); public static final String WORK_COMPLETED = extractUrn(MonitoringInfoSpecs.Enum.WORK_COMPLETED); @@ -162,6 +164,7 @@ public static final class TypeUrns { public static final String BOTTOM_N_INT64_TYPE = "beam:metrics:bottom_n_int64:v1"; public static final String BOTTOM_N_DOUBLE_TYPE = "beam:metrics:bottom_n_double:v1"; public static final String PROGRESS_TYPE = "beam:metrics:progress:v1"; + public static final String SET_STRING_TYPE = "beam:metrics:set_string:v1"; static { // Validate that compile time constants match the values stored in the protos. 
@@ -187,6 +190,7 @@ public static final class TypeUrns { checkArgument( BOTTOM_N_DOUBLE_TYPE.equals(getUrn(MonitoringInfoTypeUrns.Enum.BOTTOM_N_DOUBLE_TYPE))); checkArgument(PROGRESS_TYPE.equals(getUrn(MonitoringInfoTypeUrns.Enum.PROGRESS_TYPE))); + checkArgument(SET_STRING_TYPE.equals(getUrn(MonitoringInfoTypeUrns.Enum.SET_STRING_TYPE))); } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java index 12e7b41650dd..433e7f4fb20b 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java @@ -19,17 +19,23 @@ import java.io.IOException; import java.io.InputStream; +import java.util.Set; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; import org.joda.time.Instant; /** A set of functions used to encode and decode common monitoring info types. */ public class MonitoringInfoEncodings { private static final Coder VARINT_CODER = VarLongCoder.of(); private static final Coder DOUBLE_CODER = DoubleCoder.of(); + private static final IterableCoder STRING_SET_CODER = + IterableCoder.of(StringUtf8Coder.of()); /** Encodes to {@link MonitoringInfoConstants.TypeUrns#DISTRIBUTION_INT64_TYPE}. */ public static ByteString encodeInt64Distribution(DistributionData data) { @@ -98,6 +104,26 @@ public static GaugeData decodeInt64Gauge(ByteString payload) { } } + /** Encodes to {@link MonitoringInfoConstants.TypeUrns#SET_STRING_TYPE}. */ + public static ByteString encodeStringSet(StringSetData data) { + try (ByteStringOutputStream output = new ByteStringOutputStream()) { + STRING_SET_CODER.encode(data.stringSet(), output); + return output.toByteString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** Decodes from {@link MonitoringInfoConstants.TypeUrns#SET_STRING_TYPE}. */ + public static StringSetData decodeStringSet(ByteString payload) { + try (InputStream input = payload.newInput()) { + Set elements = Sets.newHashSet(STRING_SET_CODER.decode(input)); + return StringSetData.create(elements); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + /** Encodes to {@link MonitoringInfoConstants.TypeUrns#SUM_INT64_TYPE}. 
*/ public static ByteString encodeInt64Counter(long value) { ByteStringOutputStream output = new ByteStringOutputStream(); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleMonitoringInfoBuilder.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleMonitoringInfoBuilder.java index c44a2621ee6c..e0f5092e6b1f 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleMonitoringInfoBuilder.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleMonitoringInfoBuilder.java @@ -23,6 +23,7 @@ import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeStringSet; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import java.util.HashMap; @@ -148,6 +149,16 @@ public SimpleMonitoringInfoBuilder setDoubleDistributionValue( return this; } + /** + * Encodes the value and sets the type to {@link + * MonitoringInfoConstants.TypeUrns#SET_STRING_TYPE}. + */ + public SimpleMonitoringInfoBuilder setStringSetValue(StringSetData value) { + this.builder.setPayload(encodeStringSet(value)); + this.builder.setType(MonitoringInfoConstants.TypeUrns.SET_STRING_TYPE); + return this; + } + /** Sets the MonitoringInfo label to the given name and value. */ public SimpleMonitoringInfoBuilder setLabel(String labelName, String labelValue) { this.builder.putLabels(labelName, labelValue); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java new file mode 100644 index 000000000000..8455f154c0f8 --- /dev/null +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.metrics; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.metrics.StringSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Tracks the current value for a {@link StringSet} metric. + * + *

This class generally shouldn't be used directly. The only exception is within a runner where a + * counter is being reported for a specific step (rather than the counter in the current context). + * In that case retrieving the underlying cell and reporting directly to it avoids a step of + * indirection. + */ +public class StringSetCell implements StringSet, MetricCell { + + private final DirtyState dirty = new DirtyState(); + private final AtomicReference setValue = + new AtomicReference<>(StringSetData.empty()); + private final MetricName name; + + /** + * Generally, runners should construct instances using the methods in {@link + * MetricsContainerImpl}, unless they need to define their own version of {@link + * MetricsContainer}. These constructors are *only* public so runners can instantiate. + */ + public StringSetCell(MetricName name) { + this.name = name; + } + + @Override + public void reset() { + setValue.set(StringSetData.empty()); + dirty.reset(); + } + + void update(StringSetData data) { + StringSetData original; + do { + original = setValue.get(); + } while (!setValue.compareAndSet(original, original.combine(data))); + dirty.afterModification(); + } + + @Override + public DirtyState getDirty() { + return dirty; + } + + @Override + public StringSetData getCumulative() { + return setValue.get(); + } + + @Override + public MetricName getName() { + return name; + } + + @Override + public boolean equals(@Nullable Object object) { + if (object instanceof StringSetCell) { + StringSetCell stringSetCell = (StringSetCell) object; + return Objects.equals(dirty, stringSetCell.dirty) + && Objects.equals(setValue.get(), stringSetCell.setValue.get()) + && Objects.equals(name, stringSetCell.name); + } + + return false; + } + + @Override + public int hashCode() { + return Objects.hash(dirty, setValue.get(), name); + } + + @Override + public void add(String value) { + // if the given value is already present in the StringSet then skip this add for efficiency + if (this.setValue.get().stringSet().contains(value)) { + return; + } + update(StringSetData.create(ImmutableSet.of(value))); + } + + @Override + public void add(String... values) { + update(StringSetData.create(ImmutableSet.copyOf(values))); + } +} diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java new file mode 100644 index 000000000000..93dfb8e3ebc8 --- /dev/null +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.core.metrics; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.beam.sdk.metrics.StringSetResult; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; + +/** + * Data describing the StringSet. The {@link StringSetData} hold an immutable copy of the set from + * which it was initially created. This should retain enough detail that it can be combined with + * other {@link StringSetData}. + */ +@AutoValue +public abstract class StringSetData implements Serializable { + + public abstract Set stringSet(); + + /** Returns a {@link StringSetData} which is made from an immutable copy of the given set. */ + public static StringSetData create(Set set) { + return new AutoValue_StringSetData(ImmutableSet.copyOf(set)); + } + + /** Return a {@link EmptyStringSetData#INSTANCE} representing an empty {@link StringSetData}. */ + public static StringSetData empty() { + return EmptyStringSetData.INSTANCE; + } + + /** + * Combines this {@link StringSetData} with other, both original StringSetData are left intact. + */ + public StringSetData combine(StringSetData other) { + // do not merge other on this as this StringSetData might hold an immutable set like in case + // of EmptyStringSetData + Set combined = new HashSet<>(); + combined.addAll(this.stringSet()); + combined.addAll(other.stringSet()); + return StringSetData.create(combined); + } + + /** + * Combines this {@link StringSetData} with others, all original StringSetData are left intact. + */ + public StringSetData combine(Iterable others) { + Set combined = + StreamSupport.stream(others.spliterator(), true) + .flatMap(other -> other.stringSet().stream()) + .collect(Collectors.toSet()); + combined.addAll(this.stringSet()); + return StringSetData.create(combined); + } + + /** Returns a {@link StringSetResult} representing this {@link StringSetData}. */ + public StringSetResult extractResult() { + return StringSetResult.create(stringSet()); + } + + /** Empty {@link StringSetData}, representing no values reported and is immutable. */ + public static class EmptyStringSetData extends StringSetData { + + private static final EmptyStringSetData INSTANCE = new EmptyStringSetData(); + + private EmptyStringSetData() {} + + /** Returns an immutable empty set. */ + @Override + public Set stringSet() { + return ImmutableSet.of(); + } + + /** Return a {@link StringSetResult#empty()} which is immutable empty set. 
*/ + @Override + public StringSetResult extractResult() { + return StringSetResult.empty(); + } + } +} diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerImplTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerImplTest.java index 146b7df10f0c..5b3d71f4873e 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerImplTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerImplTest.java @@ -37,6 +37,7 @@ import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.util.HistogramData; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -269,6 +270,38 @@ public void testMonitoringInfosArePopulatedForUserGauges() { assertThat(actualMonitoringInfos, containsInAnyOrder(builder1.build(), builder2.build())); } + @Test + public void testMonitoringInfosArePopulatedForUserStringSets() { + MetricsContainerImpl testObject = new MetricsContainerImpl("step1"); + StringSetCell stringSetCellA = testObject.getStringSet(MetricName.named("ns", "nameA")); + StringSetCell stringSetCellB = testObject.getStringSet(MetricName.named("ns", "nameB")); + stringSetCellA.add("A"); + stringSetCellB.add("BBB"); + + SimpleMonitoringInfoBuilder builder1 = new SimpleMonitoringInfoBuilder(); + builder1 + .setUrn(MonitoringInfoConstants.Urns.USER_SET_STRING) + .setLabel(MonitoringInfoConstants.Labels.NAMESPACE, "ns") + .setLabel(MonitoringInfoConstants.Labels.NAME, "nameA") + .setStringSetValue(stringSetCellA.getCumulative()) + .setLabel(MonitoringInfoConstants.Labels.PTRANSFORM, "step1"); + + SimpleMonitoringInfoBuilder builder2 = new SimpleMonitoringInfoBuilder(); + builder2 + .setUrn(MonitoringInfoConstants.Urns.USER_SET_STRING) + .setLabel(MonitoringInfoConstants.Labels.NAMESPACE, "ns") + .setLabel(MonitoringInfoConstants.Labels.NAME, "nameB") + .setStringSetValue(stringSetCellB.getCumulative()) + .setLabel(MonitoringInfoConstants.Labels.PTRANSFORM, "step1"); + + List actualMonitoringInfos = new ArrayList<>(); + for (MonitoringInfo mi : testObject.getMonitoringInfos()) { + actualMonitoringInfos.add(mi); + } + + assertThat(actualMonitoringInfos, containsInAnyOrder(builder1.build(), builder2.build())); + } + @Test public void testMonitoringInfosArePopulatedForSystemDistributions() { MetricsContainerImpl testObject = new MetricsContainerImpl("step1"); @@ -338,10 +371,12 @@ public void testDeltaCounters() { MetricName gName = MetricName.named("namespace", "gauge"); HistogramData.BucketType bucketType = HistogramData.LinearBuckets.of(0, 2, 5); MetricName hName = MetricName.named("namespace", "histogram"); + MetricName stringSetName = MetricName.named("namespace", "stringset"); MetricsContainerImpl prevContainer = new MetricsContainerImpl(null); prevContainer.getCounter(cName).inc(2L); prevContainer.getGauge(gName).set(4L); + prevContainer.getStringSet(stringSetName).add("ab"); // Set buckets counts to: [1,1,1,0,0,0,1] prevContainer.getHistogram(hName, bucketType).update(-1); prevContainer.getHistogram(hName, bucketType).update(1); @@ -351,6 +386,8 @@ public void testDeltaCounters() { MetricsContainerImpl nextContainer = new MetricsContainerImpl(null); nextContainer.getCounter(cName).inc(9L); nextContainer.getGauge(gName).set(8L); + 
nextContainer.getStringSet(stringSetName).add("cd"); + nextContainer.getStringSet(stringSetName).add("ab"); // Set buckets counts to: [2,4,5,0,0,0,3] nextContainer.getHistogram(hName, bucketType).update(-1); nextContainer.getHistogram(hName, bucketType).update(-1); @@ -374,6 +411,10 @@ public void testDeltaCounters() { GaugeData gValue = deltaContainer.getGauge(gName).getCumulative(); assertEquals(8L, gValue.value()); + // Expect most recent value of string set which is all unique strings + StringSetData stringSetData = deltaContainer.getStringSet(stringSetName).getCumulative(); + assertEquals(ImmutableSet.of("ab", "cd"), stringSetData.stringSet()); + // Expect bucket counts: [1,3,4,0,0,0,2] assertEquals( 1, deltaContainer.getHistogram(hName, bucketType).getCumulative().getBottomBucketCount()); @@ -411,6 +452,11 @@ public void testNotEquals() { differentGauges.getGauge(MetricName.named("namespace", "name")); Assert.assertNotEquals(metricsContainerImpl, differentGauges); Assert.assertNotEquals(metricsContainerImpl.hashCode(), differentGauges.hashCode()); + + MetricsContainerImpl differentStringSets = new MetricsContainerImpl("stepName"); + differentStringSets.getStringSet(MetricName.named("namespace", "name")); + Assert.assertNotEquals(metricsContainerImpl, differentStringSets); + Assert.assertNotEquals(metricsContainerImpl.hashCode(), differentStringSets.hashCode()); } @Test diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java index 4718a6f2fed3..868c47f6a2e6 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java @@ -40,6 +40,9 @@ import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSet; +import org.apache.beam.sdk.metrics.StringSetResult; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.hamcrest.collection.IsIterableWithSize; import org.joda.time.Instant; import org.junit.Assert; @@ -60,14 +63,22 @@ public class MetricsContainerStepMapTest { private static final String DISTRIBUTION_NAME = "myDistribution"; private static final String GAUGE_NAME = "myGauge"; + private static final String STRING_SET_NAME = "myStringSet"; + private static final long VALUE = 100; + private static final String FIRST_STRING = "first"; + private static final String SECOND_STRING = "second"; + private static final Counter counter = Metrics.counter(MetricsContainerStepMapTest.class, COUNTER_NAME); private static final Distribution distribution = Metrics.distribution(MetricsContainerStepMapTest.class, DISTRIBUTION_NAME); private static final Gauge gauge = Metrics.gauge(MetricsContainerStepMapTest.class, GAUGE_NAME); + private static final StringSet stringSet = + Metrics.stringSet(MetricsContainerStepMapTest.class, STRING_SET_NAME); + private static final MetricsContainerImpl metricsContainer; static { @@ -77,6 +88,7 @@ public class MetricsContainerStepMapTest { distribution.update(VALUE); distribution.update(VALUE * 2); gauge.set(VALUE); + stringSet.add(FIRST_STRING, SECOND_STRING); } catch (IOException e) { LOG.error(e.getMessage(), e); } @@ -99,6 +111,7 @@ public void 
testAttemptedAccumulatedMetricResults() { assertIterableSize(step1res.getCounters(), 1); assertIterableSize(step1res.getDistributions(), 1); assertIterableSize(step1res.getGauges(), 1); + assertIterableSize(step1res.getStringSets(), 1); assertCounter(COUNTER_NAME, step1res, STEP1, VALUE, false); assertDistribution( @@ -109,12 +122,20 @@ public void testAttemptedAccumulatedMetricResults() { false); assertGauge(GAUGE_NAME, step1res, STEP1, GaugeResult.create(VALUE, Instant.now()), false); + assertStringSet( + STRING_SET_NAME, + step1res, + STEP1, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + false); + MetricQueryResults step2res = metricResults.queryMetrics(MetricsFilter.builder().addStep(STEP2).build()); assertIterableSize(step2res.getCounters(), 1); assertIterableSize(step2res.getDistributions(), 1); assertIterableSize(step2res.getGauges(), 1); + assertIterableSize(step2res.getStringSets(), 1); assertCounter(COUNTER_NAME, step2res, STEP2, VALUE * 2, false); assertDistribution( @@ -125,11 +146,19 @@ public void testAttemptedAccumulatedMetricResults() { false); assertGauge(GAUGE_NAME, step2res, STEP2, GaugeResult.create(VALUE, Instant.now()), false); + assertStringSet( + STRING_SET_NAME, + step2res, + STEP2, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + false); + MetricQueryResults allres = metricResults.allMetrics(); assertIterableSize(allres.getCounters(), 2); assertIterableSize(allres.getDistributions(), 2); assertIterableSize(allres.getGauges(), 2); + assertIterableSize(allres.getStringSets(), 2); } @Test @@ -178,6 +207,21 @@ public void testGaugeCommittedUnsupportedInAttemptedAccumulatedMetricResults() { assertGauge(GAUGE_NAME, step1res, STEP1, GaugeResult.empty(), true); } + @Test + public void testStringSetCommittedUnsupportedInAttemptedAccumulatedMetricResults() { + MetricsContainerStepMap attemptedMetrics = new MetricsContainerStepMap(); + attemptedMetrics.update(STEP1, metricsContainer); + MetricResults metricResults = asAttemptedOnlyMetricResults(attemptedMetrics); + + MetricQueryResults step1res = + metricResults.queryMetrics(MetricsFilter.builder().addStep(STEP1).build()); + + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage("This runner does not currently support committed metrics results."); + + assertStringSet(STRING_SET_NAME, step1res, STEP1, StringSetResult.empty(), true); + } + @Test public void testUserMetricDroppedOnUnbounded() { MetricsContainerStepMap testObject = new MetricsContainerStepMap(); @@ -248,6 +292,7 @@ public void testAttemptedAndCommittedAccumulatedMetricResults() { assertIterableSize(step1res.getCounters(), 1); assertIterableSize(step1res.getDistributions(), 1); assertIterableSize(step1res.getGauges(), 1); + assertIterableSize(step1res.getStringSets(), 1); assertCounter(COUNTER_NAME, step1res, STEP1, VALUE * 2, false); assertDistribution( @@ -257,6 +302,12 @@ public void testAttemptedAndCommittedAccumulatedMetricResults() { DistributionResult.create(VALUE * 6, 4, VALUE, VALUE * 2), false); assertGauge(GAUGE_NAME, step1res, STEP1, GaugeResult.create(VALUE, Instant.now()), false); + assertStringSet( + STRING_SET_NAME, + step1res, + STEP1, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + false); assertCounter(COUNTER_NAME, step1res, STEP1, VALUE, true); assertDistribution( @@ -266,6 +317,12 @@ public void testAttemptedAndCommittedAccumulatedMetricResults() { DistributionResult.create(VALUE * 3, 2, VALUE, VALUE * 2), true); assertGauge(GAUGE_NAME, 
step1res, STEP1, GaugeResult.create(VALUE, Instant.now()), true); + assertStringSet( + STRING_SET_NAME, + step1res, + STEP1, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + true); MetricQueryResults step2res = metricResults.queryMetrics(MetricsFilter.builder().addStep(STEP2).build()); @@ -273,6 +330,7 @@ public void testAttemptedAndCommittedAccumulatedMetricResults() { assertIterableSize(step2res.getCounters(), 1); assertIterableSize(step2res.getDistributions(), 1); assertIterableSize(step2res.getGauges(), 1); + assertIterableSize(step2res.getStringSets(), 1); assertCounter(COUNTER_NAME, step2res, STEP2, VALUE * 3, false); assertDistribution( @@ -282,6 +340,12 @@ public void testAttemptedAndCommittedAccumulatedMetricResults() { DistributionResult.create(VALUE * 9, 6, VALUE, VALUE * 2), false); assertGauge(GAUGE_NAME, step2res, STEP2, GaugeResult.create(VALUE, Instant.now()), false); + assertStringSet( + STRING_SET_NAME, + step2res, + STEP2, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + false); assertCounter(COUNTER_NAME, step2res, STEP2, VALUE * 2, true); assertDistribution( @@ -291,12 +355,25 @@ public void testAttemptedAndCommittedAccumulatedMetricResults() { DistributionResult.create(VALUE * 6, 4, VALUE, VALUE * 2), true); assertGauge(GAUGE_NAME, step2res, STEP2, GaugeResult.create(VALUE, Instant.now()), true); + assertStringSet( + STRING_SET_NAME, + step2res, + STEP2, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + true); + assertStringSet( + STRING_SET_NAME, + step2res, + STEP2, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + true); MetricQueryResults allres = metricResults.queryMetrics(MetricsFilter.builder().build()); assertIterableSize(allres.getCounters(), 2); assertIterableSize(allres.getDistributions(), 2); assertIterableSize(allres.getGauges(), 2); + assertIterableSize(allres.getStringSets(), 2); } @Test @@ -345,6 +422,12 @@ public void testReset() { DistributionResult.create(VALUE * 3, 2, VALUE, VALUE * 2), false); assertGauge(GAUGE_NAME, allres, STEP1, GaugeResult.create(VALUE, Instant.now()), false); + assertStringSet( + STRING_SET_NAME, + allres, + STEP1, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + false); assertCounter(COUNTER_NAME, allres, STEP2, VALUE * 2, false); assertDistribution( @@ -354,6 +437,12 @@ public void testReset() { DistributionResult.create(VALUE * 6, 4, VALUE, VALUE * 2), false); assertGauge(GAUGE_NAME, allres, STEP2, GaugeResult.create(VALUE, Instant.now()), false); + assertStringSet( + STRING_SET_NAME, + allres, + STEP2, + StringSetResult.create(ImmutableSet.of(FIRST_STRING, SECOND_STRING)), + false); attemptedMetrics.reset(); metricResults = asAttemptedOnlyMetricResults(attemptedMetrics); @@ -364,12 +453,14 @@ public void testReset() { assertDistribution( DISTRIBUTION_NAME, allres, STEP1, DistributionResult.IDENTITY_ELEMENT, false); assertGauge(GAUGE_NAME, allres, STEP1, GaugeResult.empty(), false); + assertStringSet(STRING_SET_NAME, allres, STEP1, StringSetResult.empty(), false); // Check that the metrics container for STEP2 is reset assertCounter(COUNTER_NAME, allres, STEP2, 0L, false); assertDistribution( DISTRIBUTION_NAME, allres, STEP2, DistributionResult.IDENTITY_ELEMENT, false); assertGauge(GAUGE_NAME, allres, STEP2, GaugeResult.empty(), false); + assertStringSet(STRING_SET_NAME, allres, STEP2, StringSetResult.empty(), false); } private void assertIterableSize(Iterable iterable, int size) { @@ -408,4 +499,15 @@ 
private void assertGauge( metricQueryResults.getGauges(), hasItem(metricsResult(NAMESPACE, name, step, expected, isCommitted))); } + + private void assertStringSet( + String name, + MetricQueryResults metricQueryResults, + String step, + StringSetResult expected, + boolean isCommitted) { + assertThat( + metricQueryResults.getStringSets(), + hasItem(metricsResult(NAMESPACE, name, step, expected, isCommitted))); + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java index a1b73781cd6c..8a43eef5883d 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java @@ -21,14 +21,18 @@ import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeStringSet; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeDoubleCounter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeDoubleDistribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeStringSet; import static org.junit.Assert.assertEquals; +import java.util.Collections; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Instant; import org.junit.Test; import org.junit.runner.RunWith; @@ -64,6 +68,30 @@ public void testInt64GaugeEncoding() { assertEquals(data, decodeInt64Gauge(payload)); } + @Test + public void testStringSetEncoding() { + + // test empty string set encoding + StringSetData data = StringSetData.create(Collections.emptySet()); + ByteString payload = encodeStringSet(data); + assertEquals(data, decodeStringSet(payload)); + + // test single element string set encoding + data = StringSetData.create(ImmutableSet.of("ab")); + payload = encodeStringSet(data); + assertEquals(data, decodeStringSet(payload)); + + // test multiple element string set encoding + data = StringSetData.create(ImmutableSet.of("ab", "cd", "ef")); + payload = encodeStringSet(data); + assertEquals(data, decodeStringSet(payload)); + + // test empty string encoding + data = StringSetData.create(ImmutableSet.of("ab", "", "ef")); + payload = encodeStringSet(data); + assertEquals(data, decodeStringSet(payload)); + } + @Test public void testInt64CounterEncoding() { ByteString payload = encodeInt64Counter(1L); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java new file mode 100644 index 000000000000..f78ed01603fb --- /dev/null +++ 
b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.metrics; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; + +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.junit.Assert; +import org.junit.Test; + +/** Tests for {@link StringSetCell}. */ +public class StringSetCellTest { + private final StringSetCell cell = new StringSetCell(MetricName.named("lineage", "sources")); + + @Test + public void testDeltaAndCumulative() { + cell.add("pubsub"); + cell.add("bq", "spanner"); + assertEquals(cell.getCumulative().stringSet(), ImmutableSet.of("spanner", "pubsub", "bq")); + assertEquals( + "getCumulative is idempotent", + cell.getCumulative().stringSet(), + ImmutableSet.of("spanner", "pubsub", "bq")); + + assertThat(cell.getDirty().beforeCommit(), equalTo(true)); + cell.getDirty().afterCommit(); + assertThat(cell.getDirty().beforeCommit(), equalTo(false)); + + cell.add("gcs"); + assertEquals( + cell.getCumulative().stringSet(), ImmutableSet.of("spanner", "pubsub", "bq", "gcs")); + + assertThat( + "Adding a new value made the cell dirty", cell.getDirty().beforeCommit(), equalTo(true)); + } + + @Test + public void testEquals() { + StringSetCell stringSetCell = new StringSetCell(MetricName.named("namespace", "name")); + StringSetCell equal = new StringSetCell(MetricName.named("namespace", "name")); + assertEquals(stringSetCell, equal); + assertEquals(stringSetCell.hashCode(), equal.hashCode()); + } + + @Test + public void testNotEquals() { + StringSetCell stringSetCell = new StringSetCell(MetricName.named("namespace", "name")); + + Assert.assertNotEquals(stringSetCell, new Object()); + + StringSetCell differentDirty = new StringSetCell(MetricName.named("namespace", "name")); + differentDirty.getDirty().afterModification(); + Assert.assertNotEquals(stringSetCell, differentDirty); + Assert.assertNotEquals(stringSetCell.hashCode(), differentDirty.hashCode()); + + StringSetCell differentSetValues = new StringSetCell(MetricName.named("namespace", "name")); + differentSetValues.update(StringSetData.create(ImmutableSet.of("hello"))); + Assert.assertNotEquals(stringSetCell, differentSetValues); + Assert.assertNotEquals(stringSetCell.hashCode(), differentSetValues.hashCode()); + + StringSetCell differentName = new StringSetCell(MetricName.named("DIFFERENT", "DIFFERENT")); + Assert.assertNotEquals(stringSetCell, differentName); + Assert.assertNotEquals(stringSetCell.hashCode(), differentName.hashCode()); + } + + @Test + 
public void testReset() { + StringSetCell stringSetCell = new StringSetCell(MetricName.named("namespace", "name")); + stringSetCell.add("hello"); + Assert.assertNotEquals(stringSetCell.getDirty(), new DirtyState()); + assertThat( + stringSetCell.getCumulative().stringSet(), + equalTo(StringSetData.create(ImmutableSet.of("hello")).stringSet())); + + stringSetCell.reset(); + assertThat(stringSetCell.getCumulative(), equalTo(StringSetData.empty())); + assertThat(stringSetCell.getDirty(), equalTo(new DirtyState())); + } +} diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java new file mode 100644 index 000000000000..665ce3743c51 --- /dev/null +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.metrics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.util.Collections; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** Tests for {@link StringSetData}. 
*/ +public class StringSetDataTest { + @Rule public ExpectedException exception = ExpectedException.none(); + + @Test + public void testCreate() { + // test empty stringset creation + assertTrue(StringSetData.create(Collections.emptySet()).stringSet().isEmpty()); + // single element test + ImmutableSet singleElement = ImmutableSet.of("ab"); + StringSetData setData = StringSetData.create(singleElement); + assertEquals(setData.stringSet(), singleElement); + + // multiple element test + ImmutableSet multipleElement = ImmutableSet.of("cd", "ef"); + setData = StringSetData.create(multipleElement); + assertEquals(setData.stringSet(), multipleElement); + } + + @Test + public void testCombine() { + StringSetData singleElement = StringSetData.create(ImmutableSet.of("ab")); + StringSetData multipleElement = StringSetData.create(ImmutableSet.of("cd", "ef")); + StringSetData result = singleElement.combine(multipleElement); + assertEquals(result.stringSet(), ImmutableSet.of("cd", "ef", "ab")); + + // original sets in stringsetdata should have remained the same + assertEquals(singleElement.stringSet(), ImmutableSet.of("ab")); + assertEquals(multipleElement.stringSet(), ImmutableSet.of("cd", "ef")); + } + + @Test + public void testCombineWithEmpty() { + StringSetData empty = StringSetData.empty(); + StringSetData multipleElement = StringSetData.create(ImmutableSet.of("cd", "ef")); + StringSetData result = empty.combine(multipleElement); + assertEquals(result.stringSet(), ImmutableSet.of("cd", "ef")); + // original sets in stringsetdata should have remained the same + assertTrue(empty.stringSet().isEmpty()); + assertEquals(multipleElement.stringSet(), ImmutableSet.of("cd", "ef")); + } + + @Test + public void testEmpty() { + StringSetData empty = StringSetData.empty(); + assertTrue(empty.stringSet().isEmpty()); + } + + @Test + public void testStringSetDataEmptyIsImmutable() { + StringSetData empty = StringSetData.empty(); + assertThrows(UnsupportedOperationException.class, () -> empty.stringSet().add("aa")); + } + + @Test + public void testEmptyExtract() { + assertTrue(StringSetData.empty().extractResult().getStringSet().isEmpty()); + } + + @Test + public void testExtract() { + ImmutableSet contents = ImmutableSet.of("ab", "cd"); + StringSetData stringSetData = StringSetData.create(contents); + assertEquals(stringSetData.stringSet(), contents); + } + + @Test + public void testExtractReturnsImmutable() { + StringSetData stringSetData = StringSetData.create(ImmutableSet.of("ab", "cd")); + // check that immutable copy is returned + assertThrows(UnsupportedOperationException.class, () -> stringSetData.stringSet().add("aa")); + } +} diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectMetrics.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectMetrics.java index 5b286dc0b2e0..b02c4f030b27 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectMetrics.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectMetrics.java @@ -33,6 +33,7 @@ import org.apache.beam.runners.core.metrics.MetricUpdates; import org.apache.beam.runners.core.metrics.MetricUpdates.MetricUpdate; import org.apache.beam.runners.core.metrics.MetricsMap; +import org.apache.beam.runners.core.metrics.StringSetData; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.GaugeResult; import org.apache.beam.sdk.metrics.MetricFiltering; @@ -41,6 +42,7 @@ import org.apache.beam.sdk.metrics.MetricResult; 
import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; @@ -217,6 +219,26 @@ public GaugeResult extract(GaugeData data) { } }; + private static final MetricAggregation STRING_SET = + new MetricAggregation() { + @Override + public StringSetData zero() { + return StringSetData.empty(); + } + + @Override + public StringSetData combine(Iterable updates) { + StringSetData result = StringSetData.empty(); + result = result.combine(updates); + return result; + } + + @Override + public StringSetResult extract(StringSetData data) { + return data.extractResult(); + } + }; + /** The current values of counters in memory. */ private final MetricsMap> counters; @@ -224,12 +246,14 @@ public GaugeResult extract(GaugeData data) { distributions; private final MetricsMap> gauges; + private final MetricsMap> stringSet; DirectMetrics(ExecutorService executorService) { this.counters = new MetricsMap<>(unusedKey -> new DirectMetric<>(COUNTER, executorService)); this.distributions = new MetricsMap<>(unusedKey -> new DirectMetric<>(DISTRIBUTION, executorService)); this.gauges = new MetricsMap<>(unusedKey -> new DirectMetric<>(GAUGE, executorService)); + this.stringSet = new MetricsMap<>(unusedKey -> new DirectMetric<>(STRING_SET, executorService)); } @Override @@ -249,8 +273,17 @@ public MetricQueryResults queryMetrics(@Nullable MetricsFilter filter) { maybeExtractResult(filter, gaugeResults, gauge); } + ImmutableList.Builder> stringSetResult = ImmutableList.builder(); + for (Entry> stringSet : + stringSet.entries()) { + maybeExtractResult(filter, stringSetResult, stringSet); + } + return MetricQueryResults.create( - counterResults.build(), distributionResults.build(), gaugeResults.build()); + counterResults.build(), + distributionResults.build(), + gaugeResults.build(), + stringSetResult.build()); } private void maybeExtractResult( @@ -277,6 +310,10 @@ public void updatePhysical(CommittedBundle bundle, MetricUpdates updates) { for (MetricUpdate gauge : updates.gaugeUpdates()) { gauges.get(gauge.getKey()).updatePhysical(bundle, gauge.getUpdate()); } + + for (MetricUpdate sSet : updates.stringSetUpdates()) { + stringSet.get(sSet.getKey()).updatePhysical(bundle, sSet.getUpdate()); + } } public void commitPhysical(CommittedBundle bundle, MetricUpdates updates) { @@ -289,6 +326,9 @@ public void commitPhysical(CommittedBundle bundle, MetricUpdates updates) { for (MetricUpdate gauge : updates.gaugeUpdates()) { gauges.get(gauge.getKey()).commitPhysical(bundle, gauge.getUpdate()); } + for (MetricUpdate sSet : updates.stringSetUpdates()) { + stringSet.get(sSet.getKey()).commitPhysical(bundle, sSet.getUpdate()); + } } /** Apply metric updates that represent new logical values from a bundle being committed. 
*/ @@ -302,5 +342,8 @@ public void commitLogical(CommittedBundle bundle, MetricUpdates updates) { for (MetricUpdate gauge : updates.gaugeUpdates()) { gauges.get(gauge.getKey()).commitLogical(bundle, gauge.getUpdate()); } + for (MetricUpdate sSet : updates.stringSetUpdates()) { + stringSet.get(sSet.getKey()).commitLogical(bundle, sSet.getUpdate()); + } } } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java index 46f74d6b7e05..00df20c4ac39 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectMetricsTest.java @@ -30,13 +30,16 @@ import org.apache.beam.runners.core.metrics.GaugeData; import org.apache.beam.runners.core.metrics.MetricUpdates; import org.apache.beam.runners.core.metrics.MetricUpdates.MetricUpdate; +import org.apache.beam.runners.core.metrics.StringSetData; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.GaugeResult; import org.apache.beam.sdk.metrics.MetricKey; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricQueryResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Instant; import org.junit.After; import org.junit.Before; @@ -85,7 +88,11 @@ public void testApplyCommittedNoFilter() { MetricUpdate.create( MetricKey.create("step1", NAME1), DistributionData.create(8, 2, 3, 5))), ImmutableList.of( - MetricUpdate.create(MetricKey.create("step1", NAME4), GaugeData.create(15L))))); + MetricUpdate.create(MetricKey.create("step1", NAME4), GaugeData.create(15L))), + ImmutableList.of( + MetricUpdate.create( + MetricKey.create("step1", NAME4), + StringSetData.create(ImmutableSet.of("ab")))))); metrics.commitLogical( bundle1, MetricUpdates.create( @@ -96,7 +103,11 @@ public void testApplyCommittedNoFilter() { MetricUpdate.create( MetricKey.create("step1", NAME1), DistributionData.create(4, 1, 4, 4))), ImmutableList.of( - MetricUpdate.create(MetricKey.create("step1", NAME4), GaugeData.create(27L))))); + MetricUpdate.create(MetricKey.create("step1", NAME4), GaugeData.create(27L))), + ImmutableList.of( + MetricUpdate.create( + MetricKey.create("step1", NAME4), + StringSetData.create(ImmutableSet.of("cd")))))); MetricQueryResults results = metrics.allMetrics(); assertThat( @@ -128,6 +139,11 @@ public void testApplyCommittedNoFilter() { contains( committedMetricsResult( "ns2", "name2", "step1", GaugeResult.create(27L, Instant.now())))); + assertThat( + results.getStringSets(), + contains( + committedMetricsResult( + "ns2", "name2", "step1", StringSetResult.create(ImmutableSet.of("ab", "cd"))))); } @SuppressWarnings("unchecked") @@ -140,6 +156,7 @@ public void testApplyAttemptedCountersQueryOneNamespace() { MetricUpdate.create(MetricKey.create("step1", NAME1), 5L), MetricUpdate.create(MetricKey.create("step1", NAME3), 8L)), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of())); metrics.updatePhysical( bundle1, @@ -148,6 +165,7 @@ public void testApplyAttemptedCountersQueryOneNamespace() { MetricUpdate.create(MetricKey.create("step2", NAME1), 7L), MetricUpdate.create(MetricKey.create("step1", 
NAME3), 4L)), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of())); MetricQueryResults results = @@ -176,6 +194,7 @@ public void testApplyAttemptedQueryCompositeScope() { MetricUpdate.create(MetricKey.create("Outer1/Inner1", NAME1), 5L), MetricUpdate.create(MetricKey.create("Outer1/Inner2", NAME1), 8L)), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of())); metrics.updatePhysical( bundle1, @@ -184,6 +203,7 @@ public void testApplyAttemptedQueryCompositeScope() { MetricUpdate.create(MetricKey.create("Outer1/Inner1", NAME1), 12L), MetricUpdate.create(MetricKey.create("Outer2/Inner2", NAME1), 18L)), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of())); MetricQueryResults results = @@ -212,6 +232,7 @@ public void testPartialScopeMatchingInMetricsQuery() { MetricUpdate.create(MetricKey.create("Top1/Outer1/Inner1", NAME1), 5L), MetricUpdate.create(MetricKey.create("Top1/Outer1/Inner2", NAME1), 8L)), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of())); metrics.updatePhysical( bundle1, @@ -220,6 +241,7 @@ public void testPartialScopeMatchingInMetricsQuery() { MetricUpdate.create(MetricKey.create("Top2/Outer1/Inner1", NAME1), 12L), MetricUpdate.create(MetricKey.create("Top1/Outer2/Inner2", NAME1), 18L)), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of())); MetricQueryResults results = diff --git a/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/CustomMetricQueryResults.java b/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/CustomMetricQueryResults.java index a9cea996680b..96c0374067cf 100644 --- a/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/CustomMetricQueryResults.java +++ b/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/CustomMetricQueryResults.java @@ -26,6 +26,8 @@ import org.apache.beam.sdk.metrics.MetricQueryResults; import org.apache.beam.sdk.metrics.MetricResult; import org.apache.beam.sdk.metrics.MetricsSink; +import org.apache.beam.sdk.metrics.StringSetResult; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Instant; /** Test class to be used as a input to {@link MetricsSink} implementations tests. 
*/ @@ -71,4 +73,13 @@ public List> getGauges() { GaugeResult.create(100L, new Instant(345862800L)), GaugeResult.create(120L, new Instant(345862800L))); } + + @Override + public Iterable> getStringSets() { + return makeResults( + "s3", + "n3", + StringSetResult.create(ImmutableSet.of("ab")), + StringSetResult.create(ImmutableSet.of("cd"))); + } } diff --git a/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/MetricsHttpSinkTest.java b/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/MetricsHttpSinkTest.java index afbe77bdb885..10e9481d271b 100644 --- a/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/MetricsHttpSinkTest.java +++ b/runners/extensions-java/metrics/src/test/java/org/apache/beam/runners/extensions/metrics/MetricsHttpSinkTest.java @@ -94,7 +94,9 @@ public void testWriteMetricsWithCommittedSupported() throws Exception { + "\"namespace\":\"ns1\"},\"step\":\"s2\"}],\"gauges\":[{\"attempted\":{\"timestamp\":" + "\"1970-01-05T00:04:22.800Z\",\"value\":120},\"committed\":{\"timestamp\":" + "\"1970-01-05T00:04:22.800Z\",\"value\":100},\"name\":{\"name\":\"n3\",\"namespace\":" - + "\"ns1\"},\"step\":\"s3\"}]}"; + + "\"ns1\"},\"step\":\"s3\"}],\"stringSets\":[{\"attempted\":{\"stringSet\":[\"cd" + + "\"]},\"committed\":{\"stringSet\":[\"ab\"]},\"name\":{\"name\":\"n3\"," + + "\"namespace\":\"ns1\"},\"step\":\"s3\"}]}"; assertEquals("Wrong number of messages sent to HTTP server", 1, messages.size()); assertEquals("Wrong messages sent to HTTP server", expected, messages.get(0)); } @@ -114,7 +116,8 @@ public void testWriteMetricsWithCommittedUnSupported() throws Exception { + "{\"count\":4,\"max\":9,\"mean\":6.25,\"min\":3,\"sum\":25},\"name\":{\"name\":\"n2\"" + ",\"namespace\":\"ns1\"},\"step\":\"s2\"}],\"gauges\":[{\"attempted\":{\"timestamp\":" + "\"1970-01-05T00:04:22.800Z\",\"value\":120},\"name\":{\"name\":\"n3\",\"namespace\":" - + "\"ns1\"},\"step\":\"s3\"}]}"; + + "\"ns1\"},\"step\":\"s3\"}],\"stringSets\":[{\"attempted\":{\"stringSet\":[\"cd\"]}," + + "\"name\":{\"name\":\"n3\",\"namespace\":\"ns1\"},\"step\":\"s3\"}]}"; assertEquals("Wrong number of messages sent to HTTP server", 1, messages.size()); assertEquals("Wrong messages sent to HTTP server", expected, messages.get(0)); } diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle index 56a58df4fb09..9b565f119a62 100644 --- a/runners/flink/job-server/flink_job_server.gradle +++ b/runners/flink/job-server/flink_job_server.gradle @@ -171,6 +171,7 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean checkpoi excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesStringSetMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterCoderUtils.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterCoderUtils.java index d20d5e2fc035..f20988fef8ff 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterCoderUtils.java +++ 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterCoderUtils.java @@ -21,6 +21,7 @@ import java.util.Map; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.fnexecution.wire.LengthPrefixUnknownCoders; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; @@ -80,4 +81,19 @@ static Coder lookupCoder(RunnerApi.Pipeline p, String pCollectionId) { throw new RuntimeException(exn); } } + + static void registerKnownCoderFor(RunnerApi.Pipeline p, String pCollectionId) { + registerAsKnownCoder(p, p.getComponents().getPcollectionsOrThrow(pCollectionId).getCoderId()); + } + + static void registerAsKnownCoder(RunnerApi.Pipeline p, String coderId) { + RunnerApi.Coder coder = p.getComponents().getCodersOrThrow(coderId); + // It'd be more targeted to note the coder id rather than the URN, + // but the length prefixing code is invoked within a deeply nested + // sequence of static method calls. + LengthPrefixUnknownCoders.addKnownCoderUrn(coder.getSpec().getUrn()); + for (String componentCoderId : coder.getComponentCoderIdsList()) { + registerAsKnownCoder(p, componentCoderId); + } + } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterUtils.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterUtils.java index 982dabe2dd78..0e642c96cdc0 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterUtils.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamAdapterUtils.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.util.construction.Environments; +import org.apache.beam.sdk.util.construction.PTransformTranslation; import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.util.construction.SdkComponents; import org.apache.beam.sdk.values.PBegin; @@ -37,6 +38,7 @@ import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -108,6 +110,26 @@ Map applyBeamPTransformInternal( // Extract the pipeline definition so that we can apply or Flink translation logic. SdkComponents components = SdkComponents.create(pipelineOptions); RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, components); + + // Avoid swapping input and output coders for BytesCoders. + // As we have instantiated the actual coder objects here, there is no need to length prefix them + // anyway. + // TODO(robertwb): Even better would be to avoid coding and decoding along these edges via a + // direct + // in-memory channel for embedded mode. As well as improving performance, there could be + // control-flow advantages too.
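+ // For example, given an (assumed) composition flinkInput -> beam transforms -> flinkOutput,
+ // the loop below registers the coder of the FlinkInput's output PCollection and of the
+ // FlinkOutput's input PCollection (plus their component coders) as known, so
+ // LengthPrefixUnknownCoders leaves them unwrapped during portable translation.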
+ for (RunnerApi.PTransform transformProto : + pipelineProto.getComponents().getTransforms().values()) { + if (FlinkInput.URN.equals(PTransformTranslation.urnForTransformOrNull(transformProto))) { + BeamAdapterCoderUtils.registerKnownCoderFor( + pipelineProto, Iterables.getOnlyElement(transformProto.getOutputs().values())); + } else if (FlinkOutput.URN.equals( + PTransformTranslation.urnForTransformOrNull(transformProto))) { + BeamAdapterCoderUtils.registerKnownCoderFor( + pipelineProto, Iterables.getOnlyElement(transformProto.getInputs().values())); + } + } + return translator.translate(inputs, pipelineProto, executionEnvironment); } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapter.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapter.java index d7840067e7c4..4865a25f70eb 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapter.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapter.java @@ -246,7 +246,6 @@ private FlinkBatchPortablePipelineTranslator.PTransformTranslator flink Coder outputCoder = BeamAdapterCoderUtils.lookupCoder( p, Iterables.getOnlyElement(t.getTransform().getInputsMap().values())); - // TODO(robertwb): Also handle or disable length prefix coding (for embedded mode at least). outputMap.put( outputId, new MapOperator, InputT>( diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapter.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapter.java index d21bf89d40b2..0a7a1fec803b 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapter.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataStreamAdapter.java @@ -271,7 +271,6 @@ public void processElement( Coder outputCoder = BeamAdapterCoderUtils.lookupCoder( p, Iterables.getOnlyElement(transform.getInputsMap().values())); - // TODO(robertwb): Also handle or disable length prefix coding (for embedded mode at least). 
outputMap.put( outputId, inputDataStream.transform( diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapterTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapterTest.java index b69106e44f3e..6b41178d1264 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapterTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/adapter/BeamFlinkDataSetAdapterTest.java @@ -20,7 +20,13 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; import java.util.Map; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; @@ -132,4 +138,48 @@ public void testApplyGroupingTransform() throws Exception { assertThat(result.collect(), containsInAnyOrder(KV.of("a", 2L), KV.of("b", 1L))); } + + @Test + public void testCustomCoder() throws Exception { + ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment(); + + DataSet input = env.fromCollection(ImmutableList.of("a", "b", "c")); + DataSet result = + new BeamFlinkDataSetAdapter() + .applyBeamPTransform( + input, + new PTransform, PCollection>() { + @Override + public PCollection expand(PCollection input) { + return input.apply(withPrefix("x")).setCoder(new MyCoder()); + } + }); + + assertThat(result.collect(), containsInAnyOrder("xa", "xb", "xc")); + } + + private static class MyCoder extends Coder { + + private static final int CUSTOM_MARKER = 3; + + @Override + public void encode(String value, OutputStream outStream) throws IOException { + outStream.write(CUSTOM_MARKER); + StringUtf8Coder.of().encode(value, outStream); + } + + @Override + public String decode(InputStream inStream) throws IOException { + assert inStream.read() == CUSTOM_MARKER; + return StringUtf8Coder.of().decode(inStream); + } + + @Override + public List> getCoderArguments() { + return null; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException {} + } } diff --git a/runners/google-cloud-dataflow-java/arm/build.gradle b/runners/google-cloud-dataflow-java/arm/build.gradle index 4771ee5efe82..ae63cdf8bdb7 100644 --- a/runners/google-cloud-dataflow-java/arm/build.gradle +++ b/runners/google-cloud-dataflow-java/arm/build.gradle @@ -84,19 +84,19 @@ def javaVer = "java8" if (project.hasProperty('testJavaVersion')) { javaVer = "java${project.getProperty('testJavaVersion')}" } -def dataflowProject = project.findProperty('dataflowProject') ?: 'apache-beam-testing' -def dataflowRegion = project.findProperty('dataflowRegion') ?: 'us-central1' -def dataflowValidatesTempRoot = project.findProperty('dataflowTempRoot') ?: 'gs://temp-storage-for-validates-runner-tests' +def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' +def gcpRegion = project.findProperty('gcpRegion') ?: 'us-central1' +def dataflowValidatesTempRoot = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-validates-runner-tests' def firestoreDb = project.findProperty('firestoreDb') ?: 'firestoredb' -def dockerImageRoot = project.findProperty('docker-repository-root') ?: "us.gcr.io/${dataflowProject}/java-postcommit-it" +def dockerImageRoot = 
project.findProperty('docker-repository-root') ?: "us.gcr.io/${gcpProject}/java-postcommit-it" def DockerJavaMultiarchImageContainer = "${dockerImageRoot}/${project.docker_image_default_repo_prefix}${javaVer}_sdk" def dockerTag = project.findProperty('docker-tag') ?: new Date().format('yyyyMMddHHmmss') ext.DockerJavaMultiarchImageName = "${DockerJavaMultiarchImageContainer}:${dockerTag}" as String def runnerV2PipelineOptionsARM = [ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--tempRoot=${dataflowValidatesTempRoot}", "--sdkContainerImage=${project.ext.DockerJavaMultiarchImageName}", "--experiments=use_unified_worker,use_runner_v2", diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 09ab59b3f4bc..4ddbb6ea41f4 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -51,8 +51,8 @@ evaluationDependsOn(":sdks:java:container:java11") ext.dataflowLegacyEnvironmentMajorVersion = '8' ext.dataflowFnapiEnvironmentMajorVersion = '8' -ext.dataflowLegacyContainerVersion = 'beam-master-20240306' -ext.dataflowFnapiContainerVersion = 'beam-master-20240306' +ext.dataflowLegacyContainerVersion = 'beam-master-20240718' +ext.dataflowFnapiContainerVersion = 'beam-master-20240718' ext.dataflowContainerBaseRepository = 'gcr.io/cloud-dataflow/v1beta3' processResources { @@ -138,18 +138,18 @@ dependencies { googleCloudPlatformIntegrationTest project(path: ":sdks:java:io:google-cloud-platform", configuration: "testRuntimeMigration") } -def dataflowProject = project.findProperty('dataflowProject') ?: 'apache-beam-testing' -def dataflowRegion = project.findProperty('dataflowRegion') ?: 'us-central1' -def dataflowValidatesTempRoot = project.findProperty('dataflowTempRoot') ?: 'gs://temp-storage-for-validates-runner-tests' -def dataflowPostCommitTempRoot = project.findProperty('dataflowTempRoot') ?: 'gs://temp-storage-for-end-to-end-tests' -def dataflowPostCommitTempRootKms = project.findProperty('dataflowTempRootKms') ?: 'gs://temp-storage-for-end-to-end-tests-cmek' -def dataflowUploadTemp = project.findProperty('dataflowTempRoot') ?: 'gs://temp-storage-for-upload-tests' +def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' +def gcpRegion = project.findProperty('gcpRegion') ?: 'us-central1' +def dataflowValidatesTempRoot = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-validates-runner-tests' +def dataflowPostCommitTempRoot = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-end-to-end-tests' +def dataflowPostCommitTempRootKms = project.findProperty('gcpTempRootKms') ?: 'gs://temp-storage-for-end-to-end-tests-cmek' +def dataflowUploadTemp = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-upload-tests' def testFilesToStage = project.findProperty('filesToStage') ?: 'test.txt' def dataflowLegacyWorkerJar = project.findProperty('dataflowWorkerJar') ?: project(":runners:google-cloud-dataflow-java:worker").shadowJar.archivePath def dataflowKmsKey = project.findProperty('dataflowKmsKey') ?: "projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test" def firestoreDb = project.findProperty('firestoreDb') ?: 'firestoredb' -def dockerImageRoot = project.findProperty('dockerImageRoot') ?: "us.gcr.io/${dataflowProject}/java-postcommit-it" +def dockerImageRoot = project.findProperty('dockerImageRoot') ?: 
"us.gcr.io/${gcpProject}/java-postcommit-it" def dockerJavaImageContainer = "${dockerImageRoot}/java" def dockerPythonImageContainer = "${dockerImageRoot}/python" def dockerTag = new Date().format('yyyyMMddHHmmss') @@ -158,8 +158,8 @@ ext.dockerPythonImageName = "${dockerPythonImageContainer}:${dockerTag}" def legacyPipelineOptions = [ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--tempRoot=${dataflowValidatesTempRoot}", "--dataflowWorkerJar=${dataflowLegacyWorkerJar}", "--workerHarnessContainerImage=", @@ -167,8 +167,8 @@ def legacyPipelineOptions = [ def runnerV2PipelineOptions = [ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--tempRoot=${dataflowValidatesTempRoot}", "--sdkContainerImage=${dockerJavaImageContainer}:${dockerTag}", "--experiments=use_unified_worker,use_runner_v2", @@ -193,6 +193,7 @@ def commonLegacyExcludeCategories = [ def commonRunnerV2ExcludeCategories = [ 'org.apache.beam.sdk.testing.UsesExternalService', 'org.apache.beam.sdk.testing.UsesGaugeMetrics', + 'org.apache.beam.sdk.testing.UsesStringSetMetrics', 'org.apache.beam.sdk.testing.UsesSetState', 'org.apache.beam.sdk.testing.UsesMapState', 'org.apache.beam.sdk.testing.UsesMultimapState', @@ -433,14 +434,14 @@ createCrossLanguageValidatesRunnerTask( semiPersistDir: "/var/opt/google", pythonPipelineOptions: [ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--sdk_harness_container_image_overrides=.*java.*,${dockerJavaImageContainer}:${dockerTag}", ], javaPipelineOptions: [ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--tempRoot=${dataflowValidatesTempRoot}", "--sdkContainerImage=${dockerJavaImageContainer}:${dockerTag}", "--sdkHarnessContainerImageOverrides=.*python.*,${dockerPythonImageContainer}:${dockerTag}", @@ -453,9 +454,9 @@ createCrossLanguageValidatesRunnerTask( ], goScriptOptions: [ "--runner dataflow", - "--project ${dataflowProject}", - "--dataflow_project ${dataflowProject}", - "--region ${dataflowRegion}", + "--project ${gcpProject}", + "--dataflow_project ${gcpProject}", + "--region ${gcpRegion}", "--tests \"./test/integration/xlang ./test/integration/io/xlang/...\"", "--sdk_overrides \".*java.*,${dockerJavaImageContainer}:${dockerTag}\"", ], @@ -552,8 +553,8 @@ task googleCloudPlatformLegacyWorkerIntegrationTest(type: Test, dependsOn: copyG dependsOn ":runners:google-cloud-dataflow-java:worker:shadowJar" systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--tempRoot=${dataflowPostCommitTempRoot}", "--dataflowWorkerJar=${dataflowLegacyWorkerJar}", "--workerHarnessContainerImage=", @@ -580,8 +581,8 @@ task googleCloudPlatformLegacyWorkerKmsIntegrationTest(type: Test) { dependsOn ":runners:google-cloud-dataflow-java:worker:shadowJar" systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--tempRoot=${dataflowPostCommitTempRootKms}", 
"--dataflowWorkerJar=${dataflowLegacyWorkerJar}", "--workerHarnessContainerImage=", @@ -666,8 +667,8 @@ task coreSDKJavaLegacyWorkerIntegrationTest(type: Test) { systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--tempRoot=${dataflowPostCommitTempRoot}", "--dataflowWorkerJar=${dataflowLegacyWorkerJar}", "--workerHarnessContainerImage=", @@ -712,8 +713,6 @@ task postCommitRunnerV2 { dependsOn coreSDKJavaRunnerV2IntegrationTest } -def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' -def gcpRegion = project.findProperty('gcpRegion') ?: 'us-central1' def gcsBucket = project.findProperty('gcsBucket') ?: 'temp-storage-for-release-validation-tests/nightly-snapshot-validation' def bqDataset = project.findProperty('bqDataset') ?: 'beam_postrelease_mobile_gaming' def pubsubTopic = project.findProperty('pubsubTopic') ?: 'java_mobile_gaming_topic' @@ -743,7 +742,7 @@ createJavaExamplesArchetypeValidationTask(type: 'MobileGaming', bqDataset: bqDataset, pubsubTopic: pubsubTopic) -// Standalone task for testing GCS upload, use with -PfilesToStage and -PdataflowTempRoot. +// Standalone task for testing GCS upload, use with -PfilesToStage and -PgcpTempRoot. task GCSUpload(type: JavaExec) { mainClass = 'org.apache.beam.runners.dataflow.util.GCSUploadMain' classpath = sourceSets.test.runtimeClasspath diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java index c5023a57d8d6..46fdce507c3d 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowMetrics.java @@ -24,6 +24,7 @@ import com.google.api.services.dataflow.model.JobMetrics; import com.google.api.services.dataflow.model.MetricUpdate; import java.io.IOException; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -38,10 +39,12 @@ import org.apache.beam.sdk.metrics.MetricResult; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Objects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.BiMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -99,12 +102,13 @@ public MetricQueryResults queryMetrics(MetricsFilter filter) { ImmutableList> counters = ImmutableList.of(); ImmutableList> distributions = ImmutableList.of(); ImmutableList> gauges = ImmutableList.of(); + ImmutableList> stringSets = ImmutableList.of(); JobMetrics jobMetrics; try { jobMetrics = getJobMetrics(); } catch (IOException e) { LOG.warn("Unable to query job metrics.\n"); - return MetricQueryResults.create(counters, distributions, gauges); + return MetricQueryResults.create(counters, distributions, gauges, stringSets); } metricUpdates = 
firstNonNull(jobMetrics.getMetrics(), Collections.emptyList()); return populateMetricQueryResults(metricUpdates, filter); @@ -127,12 +131,19 @@ private static class DataflowMetricResultExtractor { private final ImmutableList.Builder> counterResults; private final ImmutableList.Builder> distributionResults; private final ImmutableList.Builder> gaugeResults; + private final ImmutableList.Builder> stringSetResults; private final boolean isStreamingJob; DataflowMetricResultExtractor(boolean isStreamingJob) { counterResults = ImmutableList.builder(); distributionResults = ImmutableList.builder(); gaugeResults = ImmutableList.builder(); + stringSetResults = ImmutableList.builder(); + /* In Dataflow streaming jobs, only ATTEMPTED metrics are available. + * In Dataflow batch jobs, only COMMITTED metrics are available, but + * we must provide ATTEMPTED, so we use COMMITTED as a good approximation. + * Reporting the appropriate metric depending on whether it's a batch/streaming job. + */ this.isStreamingJob = isStreamingJob; } @@ -148,20 +159,14 @@ public void addMetricResult( // distribution metric DistributionResult value = getDistributionValue(committed); distributionResults.add(MetricResult.create(metricKey, !isStreamingJob, value)); - /* In Dataflow streaming jobs, only ATTEMPTED metrics are available. - * In Dataflow batch jobs, only COMMITTED metrics are available, but - * we must provide ATTEMPTED, so we use COMMITTED as a good approximation. - * Reporting the appropriate metric depending on whether it's a batch/streaming job. - */ } else if (committed.getScalar() != null && attempted.getScalar() != null) { // counter metric Long value = getCounterValue(committed); counterResults.add(MetricResult.create(metricKey, !isStreamingJob, value)); - /* In Dataflow streaming jobs, only ATTEMPTED metrics are available. - * In Dataflow batch jobs, only COMMITTED metrics are available, but - * we must provide ATTEMPTED, so we use COMMITTED as a good approximation. - * Reporting the appropriate metric depending on whether it's a batch/streaming job. - */ + } else if (committed.getSet() != null && attempted.getSet() != null) { + // stringset metric + StringSetResult value = getStringSetValue(committed); + stringSetResults.add(MetricResult.create(metricKey, !isStreamingJob, value)); } else { // This is exceptionally unexpected. We expect matching user metrics to only have the // value types provided by the Metrics API. 
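The stringSetResults gathered here surface to pipeline authors through MetricQueryResults.getStringSets(). As a brief usage sketch (the "lineage"/"sources" metric names below are illustrative, not taken from this change), string-set metrics can be queried from a PipelineResult the same way counters are:

import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.metrics.MetricNameFilter;
import org.apache.beam.sdk.metrics.MetricQueryResults;
import org.apache.beam.sdk.metrics.MetricResult;
import org.apache.beam.sdk.metrics.MetricsFilter;
import org.apache.beam.sdk.metrics.StringSetResult;

class StringSetQuerySketch {
  /** Prints the attempted value of every string-set metric matching the name filter. */
  static void printStringSets(PipelineResult result) {
    MetricQueryResults metrics =
        result
            .metrics()
            .queryMetrics(
                MetricsFilter.builder()
                    .addNameFilter(MetricNameFilter.named("lineage", "sources"))
                    .build());
    for (MetricResult<StringSetResult> stringSet : metrics.getStringSets()) {
      // On Dataflow batch jobs only COMMITTED values exist, so attempted mirrors committed.
      System.out.println(stringSet.getName() + ": " + stringSet.getAttempted().getStringSet());
    }
  }
}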
@@ -182,6 +187,13 @@ private Long getCounterValue(MetricUpdate metricUpdate) { return ((Number) metricUpdate.getScalar()).longValue(); } + private StringSetResult getStringSetValue(MetricUpdate metricUpdate) { + if (metricUpdate.getSet() == null) { + return StringSetResult.empty(); + } + return StringSetResult.create(ImmutableSet.copyOf(((Collection) metricUpdate.getSet()))); + } + private DistributionResult getDistributionValue(MetricUpdate metricUpdate) { if (metricUpdate.getDistribution() == null) { return DistributionResult.IDENTITY_ELEMENT; @@ -205,6 +217,10 @@ public Iterable> getCounterResults() { public Iterable> getGaugeResults() { return gaugeResults.build(); } + + public Iterable> geStringSetResults() { + return stringSetResults.build(); + } } private static class DataflowMetricQueryResultsFactory { @@ -369,7 +385,8 @@ public MetricQueryResults build() { return MetricQueryResults.create( extractor.getCounterResults(), extractor.getDistributionResults(), - extractor.getGaugeResults()); + extractor.getGaugeResults(), + extractor.geStringSetResults()); } } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java index 527273abb42e..9b8e3cc871da 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowMetricsTest.java @@ -40,6 +40,7 @@ import com.google.api.services.dataflow.model.MetricUpdate; import java.io.IOException; import java.math.BigDecimal; +import java.util.Set; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.runners.dataflow.util.DataflowTemplateJob; import org.apache.beam.sdk.PipelineResult.State; @@ -48,12 +49,14 @@ import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.MetricQueryResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.BiMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBiMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -115,6 +118,7 @@ public void testEmptyMetricUpdates() throws IOException { MetricQueryResults result = dataflowMetrics.allMetrics(); assertThat(ImmutableList.copyOf(result.getCounters()), is(empty())); assertThat(ImmutableList.copyOf(result.getDistributions()), is(empty())); + assertThat(ImmutableList.copyOf(result.getStringSets()), is(empty())); } @Test @@ -184,6 +188,13 @@ private MetricUpdate makeCounterMetricUpdate( return setStructuredName(update, name, namespace, step, tentative); } + private MetricUpdate makeStringSetMetricUpdate( + String name, String namespace, String step, Set setValues, boolean tentative) { + MetricUpdate update = new MetricUpdate(); + update.setSet(setValues); + return setStructuredName(update, name, namespace, step, tentative); 
+ } + @Test public void testSingleCounterUpdates() throws IOException { AppliedPTransform myStep = mock(AppliedPTransform.class); @@ -226,6 +237,54 @@ public void testSingleCounterUpdates() throws IOException { committedMetricsResult("counterNamespace", "counterName", "myStepName", 1234L))); } + @Test + public void testSingleStringSetUpdates() throws IOException { + AppliedPTransform myStep = mock(AppliedPTransform.class); + when(myStep.getFullName()).thenReturn("myStepName"); + BiMap, String> transformStepNames = HashBiMap.create(); + transformStepNames.put(myStep, "s2"); + + JobMetrics jobMetrics = new JobMetrics(); + DataflowPipelineJob job = mock(DataflowPipelineJob.class); + DataflowPipelineOptions options = mock(DataflowPipelineOptions.class); + when(options.isStreaming()).thenReturn(false); + when(job.getDataflowOptions()).thenReturn(options); + when(job.getState()).thenReturn(State.RUNNING); + when(job.getJobId()).thenReturn(JOB_ID); + when(job.getTransformStepNames()).thenReturn(transformStepNames); + + // The parser relies on the fact that one tentative and one committed metric update exist in + // the job metrics results. + MetricUpdate mu1 = + makeStringSetMetricUpdate( + "counterName", "counterNamespace", "s2", ImmutableSet.of("ab", "cd"), false); + MetricUpdate mu1Tentative = + makeStringSetMetricUpdate( + "counterName", "counterNamespace", "s2", ImmutableSet.of("ab", "cd"), true); + jobMetrics.setMetrics(ImmutableList.of(mu1, mu1Tentative)); + DataflowClient dataflowClient = mock(DataflowClient.class); + when(dataflowClient.getJobMetrics(JOB_ID)).thenReturn(jobMetrics); + + DataflowMetrics dataflowMetrics = new DataflowMetrics(job, dataflowClient); + MetricQueryResults result = dataflowMetrics.allMetrics(); + assertThat( + result.getStringSets(), + containsInAnyOrder( + attemptedMetricsResult( + "counterNamespace", + "counterName", + "myStepName", + StringSetResult.create(ImmutableSet.of("ab", "cd"))))); + assertThat( + result.getStringSets(), + containsInAnyOrder( + committedMetricsResult( + "counterNamespace", + "counterName", + "myStepName", + StringSetResult.create(ImmutableSet.of("ab", "cd"))))); + } + @Test public void testIgnoreDistributionButGetCounterUpdates() throws IOException { AppliedPTransform myStep = mock(AppliedPTransform.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java index 901e305b22b1..aeef7784c2c3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContext.java @@ -39,6 +39,7 @@ import org.apache.beam.runners.dataflow.worker.profiler.ScopedProfiler.ProfileScope; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.TimeDomain; @@ -67,9 +68,6 @@ public class BatchModeExecutionContext private Object key; private final MetricsContainerRegistry containerRegistry; - - // TODO(https://github.com/apache/beam/issues/19632): Move throttle time Metric to a dedicated - // namespace. 
protected static final String DATASTORE_THROTTLE_TIME_NAMESPACE = "org.apache.beam.sdk.io.gcp.datastore.DatastoreV1$DatastoreWriterFn"; protected static final String HTTP_CLIENT_API_THROTTLE_TIME_NAMESPACE = @@ -78,7 +76,6 @@ public class BatchModeExecutionContext "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$DatasetServiceImpl"; protected static final String BIGQUERY_READ_THROTTLE_TIME_NAMESPACE = "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$StorageClientImpl"; - protected static final String THROTTLE_TIME_COUNTER_NAME = "throttling-msecs"; private BatchModeExecutionContext( CounterFactory counterFactory, @@ -517,7 +514,12 @@ public Iterable extractMetricUpdates(boolean isFinalUpdate) { .transform( update -> MetricsToCounterUpdateConverter.fromDistribution( - update.getKey(), true, update.getUpdate()))); + update.getKey(), true, update.getUpdate())), + FluentIterable.from(updates.stringSetUpdates()) + .transform( + update -> + MetricsToCounterUpdateConverter.fromStringSet( + update.getKey(), update.getUpdate()))); }); } @@ -534,11 +536,18 @@ public Iterable extractMsecCounters(boolean isFinalUpdate) { public Long extractThrottleTime() { long totalThrottleMsecs = 0L; for (MetricsContainerImpl container : containerRegistry.getContainers()) { - // TODO(https://github.com/apache/beam/issues/19632): Update throttling counters to use - // generic throttling-msecs metric. + CounterCell userThrottlingTime = + container.tryGetCounter( + MetricName.named( + Metrics.THROTTLE_TIME_NAMESPACE, Metrics.THROTTLE_TIME_COUNTER_NAME)); + if (userThrottlingTime != null) { + totalThrottleMsecs += userThrottlingTime.getCumulative(); + } + CounterCell dataStoreThrottlingTime = container.tryGetCounter( - MetricName.named(DATASTORE_THROTTLE_TIME_NAMESPACE, THROTTLE_TIME_COUNTER_NAME)); + MetricName.named( + DATASTORE_THROTTLE_TIME_NAMESPACE, Metrics.THROTTLE_TIME_COUNTER_NAME)); if (dataStoreThrottlingTime != null) { totalThrottleMsecs += dataStoreThrottlingTime.getCumulative(); } @@ -546,7 +555,7 @@ public Long extractThrottleTime() { CounterCell httpClientApiThrottlingTime = container.tryGetCounter( MetricName.named( - HTTP_CLIENT_API_THROTTLE_TIME_NAMESPACE, THROTTLE_TIME_COUNTER_NAME)); + HTTP_CLIENT_API_THROTTLE_TIME_NAMESPACE, Metrics.THROTTLE_TIME_COUNTER_NAME)); if (httpClientApiThrottlingTime != null) { totalThrottleMsecs += httpClientApiThrottlingTime.getCumulative(); } @@ -554,14 +563,16 @@ public Long extractThrottleTime() { CounterCell bigqueryStreamingInsertThrottleTime = container.tryGetCounter( MetricName.named( - BIGQUERY_STREAMING_INSERT_THROTTLE_TIME_NAMESPACE, THROTTLE_TIME_COUNTER_NAME)); + BIGQUERY_STREAMING_INSERT_THROTTLE_TIME_NAMESPACE, + Metrics.THROTTLE_TIME_COUNTER_NAME)); if (bigqueryStreamingInsertThrottleTime != null) { totalThrottleMsecs += bigqueryStreamingInsertThrottleTime.getCumulative(); } CounterCell bigqueryReadThrottleTime = container.tryGetCounter( - MetricName.named(BIGQUERY_READ_THROTTLE_TIME_NAMESPACE, THROTTLE_TIME_COUNTER_NAME)); + MetricName.named( + BIGQUERY_READ_THROTTLE_TIME_NAMESPACE, Metrics.THROTTLE_TIME_COUNTER_NAME)); if (bigqueryReadThrottleTime != null) { totalThrottleMsecs += bigqueryReadThrottleTime.getCumulative(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowMetricsContainer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowMetricsContainer.java index c3e4fb1388b0..f9cd098edaa6 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowMetricsContainer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowMetricsContainer.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.metrics.StringSet; import org.apache.beam.sdk.util.HistogramData; /** @@ -73,6 +74,11 @@ public Gauge getGauge(MetricName metricName) { return getCurrentContainer().getGauge(metricName); } + @Override + public StringSet getStringSet(MetricName metricName) { + return getCurrentContainer().getStringSet(metricName); + } + @Override public Histogram getPerWorkerHistogram( MetricName metricName, HistogramData.BucketType bucketType) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java index 640febc616ba..c5a24df192eb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowSystemMetrics.java @@ -20,15 +20,14 @@ import org.apache.beam.runners.dataflow.worker.counters.CounterName; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; /** This holds system metrics related constants used in Batch and Streaming. */ public class DataflowSystemMetrics { public static final MetricName THROTTLING_MSECS_METRIC_NAME = - MetricName.named("dataflow-throttling-metrics", "throttling-msecs"); - - // TODO: Provide an utility in SDK 'ThrottlingReporter' to update throttling time. + MetricName.named("dataflow-throttling-metrics", Metrics.THROTTLE_TIME_COUNTER_NAME); /** System counters populated by streaming dataflow workers. 
*/ public enum StreamingSystemCounterNames { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToCounterUpdateConverter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToCounterUpdateConverter.java index 22b55058d4f5..dbedc51528a5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToCounterUpdateConverter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToCounterUpdateConverter.java @@ -25,7 +25,10 @@ import com.google.api.services.dataflow.model.CounterUpdate; import com.google.api.services.dataflow.model.DistributionUpdate; import com.google.api.services.dataflow.model.IntegerGauge; +import com.google.api.services.dataflow.model.StringList; +import java.util.ArrayList; import org.apache.beam.runners.core.metrics.DistributionData; +import org.apache.beam.runners.core.metrics.StringSetData; import org.apache.beam.sdk.metrics.MetricKey; import org.apache.beam.sdk.metrics.MetricName; @@ -58,7 +61,8 @@ public enum Kind { DISTRIBUTION("DISTRIBUTION"), MEAN("MEAN"), SUM("SUM"), - LATEST_VALUE("LATEST_VALUE"); + LATEST_VALUE("LATEST_VALUE"), + SET("SET"); private final String kind; @@ -94,6 +98,18 @@ public static CounterUpdate fromGauge( .setIntegerGauge(integerGaugeProto); } + public static CounterUpdate fromStringSet(MetricKey key, StringSetData stringSetData) { + CounterStructuredNameAndMetadata name = structuredNameAndMetadata(key, Kind.SET); + + StringList stringList = new StringList(); + stringList.setElements(new ArrayList<>(stringSetData.stringSet())); + + return new CounterUpdate() + .setStructuredNameAndMetadata(name) + .setCumulative(false) + .setStringList(stringList); + } + public static CounterUpdate fromDistribution( MetricKey key, boolean isCumulative, DistributionData update) { CounterStructuredNameAndMetadata name = structuredNameAndMetadata(key, Kind.DISTRIBUTION); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OperationalLimits.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OperationalLimits.java new file mode 100644 index 000000000000..e9ee8f39cba4 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OperationalLimits.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker; + +import com.google.auto.value.AutoBuilder; + +/** Keep track of any operational limits required by the backend. 
*/ +public class OperationalLimits { + // Maximum size of a commit from a single work item. + public final long maxWorkItemCommitBytes; + // Maximum size of a single output element's serialized key. + public final long maxOutputKeyBytes; + // Maximum size of a single output element's serialized value. + public final long maxOutputValueBytes; + // Whether to throw an exception when processing output that violates any of the given limits. + public final boolean throwExceptionOnLargeOutput; + + OperationalLimits( + long maxWorkItemCommitBytes, + long maxOutputKeyBytes, + long maxOutputValueBytes, + boolean throwExceptionOnLargeOutput) { + this.maxWorkItemCommitBytes = maxWorkItemCommitBytes; + this.maxOutputKeyBytes = maxOutputKeyBytes; + this.maxOutputValueBytes = maxOutputValueBytes; + this.throwExceptionOnLargeOutput = throwExceptionOnLargeOutput; + } + + @AutoBuilder(ofClass = OperationalLimits.class) + public interface Builder { + Builder setMaxWorkItemCommitBytes(long bytes); + + Builder setMaxOutputKeyBytes(long bytes); + + Builder setMaxOutputValueBytes(long bytes); + + Builder setThrowExceptionOnLargeOutput(boolean shouldThrow); + + OperationalLimits build(); + } + + public static Builder builder() { + return new AutoBuilder_OperationalLimits_Builder() + .setMaxWorkItemCommitBytes(Long.MAX_VALUE) + .setMaxOutputKeyBytes(Long.MAX_VALUE) + .setMaxOutputValueBytes(Long.MAX_VALUE) + .setThrowExceptionOnLargeOutput(false); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OutputTooLargeException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OutputTooLargeException.java new file mode 100644 index 000000000000..9f4b413841c5 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/OutputTooLargeException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker; + +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Indicates that an output element was too large. */ +public class OutputTooLargeException extends RuntimeException { + public OutputTooLargeException(String reason) { + super(reason); + } + + /** Returns whether an exception was caused by a {@link OutputTooLargeException}. 
*/ + public static boolean isCausedByOutputTooLargeException(@Nullable Throwable t) { + while (t != null) { + if (t instanceof OutputTooLargeException) { + return true; + } + t = t.getCause(); + } + return false; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 0c51c381b360..f196852b2253 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -35,7 +35,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -100,13 +100,13 @@ import org.apache.beam.sdk.fn.JvmInitializers; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; -import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.*; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; @@ -120,14 +120,6 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class StreamingDataflowWorker { - - // TODO(https://github.com/apache/beam/issues/19632): Update throttling counters to use generic - // throttling-msecs metric. - public static final MetricName BIGQUERY_STREAMING_INSERT_THROTTLE_TIME = - MetricName.named( - "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$DatasetServiceImpl", - "throttling-msecs"); - /** * Sinks are marked 'full' in {@link StreamingModeExecutionContext} once the amount of data sinked * (across all the sinks, if there are more than one) reaches this limit. 
This serves as hint for @@ -194,7 +186,7 @@ private StreamingDataflowWorker( WorkFailureProcessor workFailureProcessor, StreamingCounters streamingCounters, MemoryMonitor memoryMonitor, - AtomicInteger maxWorkItemCommitBytes, + AtomicReference operationalLimits, GrpcWindmillStreamFactory windmillStreamFactory, Function executorSupplier, ConcurrentMap stageInfoMap) { @@ -330,7 +322,7 @@ private StreamingDataflowWorker( streamingCounters, hotKeyLogger, sampler, - maxWorkItemCommitBytes, + operationalLimits, ID_GENERATOR, stageInfoMap); @@ -338,7 +330,6 @@ private StreamingDataflowWorker( LOG.debug("WindmillServiceEndpoint: {}", options.getWindmillServiceEndpoint()); LOG.debug("WindmillServicePort: {}", options.getWindmillServicePort()); LOG.debug("LocalWindmillHostport: {}", options.getLocalWindmillHostport()); - LOG.debug("maxWorkItemCommitBytes: {}", maxWorkItemCommitBytes.get()); } public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions options) { @@ -348,7 +339,8 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o StreamingCounters streamingCounters = StreamingCounters.create(); WorkUnitClient dataflowServiceClient = new DataflowWorkUnitClient(options, LOG); BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); - AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(Integer.MAX_VALUE); + AtomicReference operationalLimits = + new AtomicReference<>(OperationalLimits.builder().build()); WindmillStateCache windmillStateCache = WindmillStateCache.builder() .setSizeMb(options.getWorkerCacheMb()) @@ -366,7 +358,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o createConfigFetcherComputationStateCacheAndWindmillClient( options, dataflowServiceClient, - maxWorkItemCommitBytes, + operationalLimits, windmillStreamFactoryBuilder, configFetcher -> ComputationStateCache.create( @@ -424,7 +416,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o workFailureProcessor, streamingCounters, memoryMonitor, - maxWorkItemCommitBytes, + operationalLimits, configFetcherComputationStateCacheAndWindmillClient.windmillStreamFactory(), executorSupplier, stageInfo); @@ -440,7 +432,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o createConfigFetcherComputationStateCacheAndWindmillClient( DataflowWorkerHarnessOptions options, WorkUnitClient dataflowServiceClient, - AtomicInteger maxWorkItemCommitBytes, + AtomicReference operationalLimits, GrpcWindmillStreamFactory.Builder windmillStreamFactoryBuilder, Function computationStateCacheFactory) { ComputationConfig.Fetcher configFetcher; @@ -456,8 +448,9 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o config -> onPipelineConfig( config, + options, dispatcherClient::consumeWindmillDispatcherEndpoints, - maxWorkItemCommitBytes)); + operationalLimits::set)); computationStateCache = computationStateCacheFactory.apply(configFetcher); windmillStreamFactory = windmillStreamFactoryBuilder @@ -503,9 +496,9 @@ static StreamingDataflowWorker forTesting( Supplier clock, Function executorSupplier, int localRetryTimeoutMs, - int maxWorkItemCommitBytesOverrides) { + OperationalLimits limits) { ConcurrentMap stageInfo = new ConcurrentHashMap<>(); - AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(maxWorkItemCommitBytesOverrides); + AtomicReference operationalLimits = new AtomicReference<>(limits); BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); 
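
For orientation (not part of the patch), here is a minimal sketch of how the `OperationalLimits` value introduced above is intended to be built and shared. The numeric limits below are made-up placeholders; any field left unset keeps the `Long.MAX_VALUE` / `false` defaults installed by `builder()`.

```java
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.runners.dataflow.worker.OperationalLimits;

public class OperationalLimitsSketch {
  public static void main(String[] args) {
    // Fields left unset keep the defaults from OperationalLimits.builder():
    // Long.MAX_VALUE for the byte limits and false for throwing on large output.
    OperationalLimits limits =
        OperationalLimits.builder()
            .setMaxWorkItemCommitBytes(180L << 20) // hypothetical ~180 MiB commit cap
            .setMaxOutputKeyBytes(1L << 20)        // hypothetical 1 MiB serialized-key cap
            .setThrowExceptionOnLargeOutput(true)  // fail fast instead of only logging
            .build();

    // The worker keeps the current limits in an AtomicReference so a streaming
    // config refresh can swap in a new immutable snapshot atomically.
    AtomicReference<OperationalLimits> operationalLimits = new AtomicReference<>(limits);
    System.out.println(operationalLimits.get().maxWorkItemCommitBytes);
  }
}
```

Because readers only ever observe a consistent snapshot, the patch replaces the previous per-field `AtomicInteger` with a single `AtomicReference` that `onPipelineConfig` overwrites wholesale.
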
WindmillStateCache stateCache = WindmillStateCache.builder() @@ -522,8 +515,9 @@ static StreamingDataflowWorker forTesting( config -> onPipelineConfig( config, + options, windmillServer::setWindmillServiceEndpoints, - maxWorkItemCommitBytes)) + operationalLimits::set)) : new StreamingApplianceComputationConfigFetcher(windmillServer::getConfig); ConcurrentMap stateNameMap = new ConcurrentHashMap<>(prePopulatedStateNameMappings); @@ -591,7 +585,7 @@ static StreamingDataflowWorker forTesting( workFailureProcessor, streamingCounters, memoryMonitor, - maxWorkItemCommitBytes, + operationalLimits, options.isEnableStreamingEngine() ? windmillStreamFactory .setHealthCheckIntervalMillis( @@ -604,12 +598,18 @@ static StreamingDataflowWorker forTesting( private static void onPipelineConfig( StreamingEnginePipelineConfig config, + DataflowWorkerHarnessOptions options, Consumer> consumeWindmillServiceEndpoints, - AtomicInteger maxWorkItemCommitBytes) { - if (config.maxWorkItemCommitBytes() != maxWorkItemCommitBytes.get()) { - LOG.info("Setting maxWorkItemCommitBytes to {}", maxWorkItemCommitBytes); - maxWorkItemCommitBytes.set((int) config.maxWorkItemCommitBytes()); - } + Consumer operationalLimits) { + + operationalLimits.accept( + OperationalLimits.builder() + .setMaxWorkItemCommitBytes(config.maxWorkItemCommitBytes()) + .setMaxOutputKeyBytes(config.maxOutputKeyBytes()) + .setMaxOutputValueBytes(config.maxOutputValueBytes()) + .setThrowExceptionOnLargeOutput( + DataflowRunner.hasExperiment(options, "throw_exceptions_on_large_output")) + .build()); if (!config.windmillServiceEndpoints().isEmpty()) { consumeWindmillServiceEndpoints.accept(config.windmillServiceEndpoints()); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index dd6353060abc..a594dbb1e0f7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -129,6 +129,10 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext gauges = new MetricsMap<>(GaugeCell::new); + private MetricsMap stringSet = new MetricsMap<>(StringSetCell::new); + private MetricsMap distributions = new MetricsMap<>(DeltaDistributionCell::new); @@ -159,6 +163,11 @@ public Gauge getGauge(MetricName metricName) { return gauges.get(metricName); } + @Override + public StringSet getStringSet(MetricName metricName) { + return stringSet.get(metricName); + } + @Override public Histogram getPerWorkerHistogram( MetricName metricName, HistogramData.BucketType bucketType) { @@ -176,7 +185,9 @@ public Histogram getPerWorkerHistogram( } public Iterable extractUpdates() { - return counterUpdates().append(distributionUpdates()).append(gaugeUpdates()); + return counterUpdates() + .append(distributionUpdates()) + .append(gaugeUpdates().append(stringSetUpdates())); } private FluentIterable counterUpdates() { @@ -218,6 +229,20 @@ private FluentIterable gaugeUpdates() { .filter(Predicates.notNull()); } + private FluentIterable stringSetUpdates() { + return FluentIterable.from(stringSet.entries()) + .transform( + new Function, CounterUpdate>() { + @Override + public @Nullable CounterUpdate apply( + @Nonnull Map.Entry entry) { + 
return MetricsToCounterUpdateConverter.fromStringSet( + MetricKey.create(stepName, entry.getKey()), entry.getValue().getCumulative()); + } + }) + .filter(Predicates.notNull()); + } + private FluentIterable distributionUpdates() { return FluentIterable.from(distributions.entries()) .transform( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java index 1f26572941a0..78d0c6b4550a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillSink.java @@ -44,6 +44,8 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @SuppressWarnings({ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) @@ -54,6 +56,7 @@ class WindmillSink extends Sink> { private final Coder valueCoder; private final Coder> windowsCoder; private StreamingModeExecutionContext context; + private static final Logger LOG = LoggerFactory.getLogger(WindmillSink.class); WindmillSink( String destinationName, @@ -172,6 +175,28 @@ public long add(WindowedValue data) throws IOException { key = context.getSerializedKey(); value = encode(valueCoder, data.getValue()); } + if (key.size() > context.getMaxOutputKeyBytes()) { + if (context.throwExceptionsForLargeOutput()) { + throw new OutputTooLargeException("Key too large: " + key.size()); + } else { + LOG.error( + "Trying to output too large key with size " + + key.size() + + ". Limit is " + + context.getMaxOutputKeyBytes()); + } + } + if (value.size() > context.getMaxOutputValueBytes()) { + if (context.throwExceptionsForLargeOutput()) { + throw new OutputTooLargeException("Value too large: " + value.size()); + } else { + LOG.error( + "Trying to output too large value with size " + + value.size() + + ". 
Limit is " + + context.getMaxOutputValueBytes()); + } + } Windmill.KeyedMessageBundle.Builder keyedOutput = productionMap.get(key); if (keyedOutput == null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index dd34e85bc93c..8a00194887da 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -24,6 +24,7 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutor; import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; +import org.apache.beam.runners.dataflow.worker.OperationalLimits; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter; @@ -45,7 +46,7 @@ * @implNote Once closed, it cannot be reused. */ // TODO(m-trieu): See if this can be combined/cleaned up with StreamingModeExecutionContext as the -// seperation of responsibilities are unclear. +// separation of responsibilities are unclear. @AutoValue @Internal @NotThreadSafe @@ -72,9 +73,11 @@ public final void executeWork( Work work, WindmillStateReader stateReader, SideInputStateFetcher sideInputStateFetcher, + OperationalLimits operationalLimits, Windmill.WorkItemCommitRequest.Builder outputBuilder) throws Exception { - context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder); + context() + .start(key, work, stateReader, sideInputStateFetcher, operationalLimits, outputBuilder); workExecutor().execute(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java index 8f14ea26a461..a18ca8cfd6dc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java @@ -17,7 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming; -import static org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics.THROTTLING_MSECS_METRIC_NAME; +import static org.apache.beam.sdk.metrics.Metrics.THROTTLE_TIME_COUNTER_NAME; import com.google.api.services.dataflow.model.CounterStructuredName; import com.google.api.services.dataflow.model.CounterUpdate; @@ -28,7 +28,6 @@ import java.util.List; import org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics; import org.apache.beam.runners.dataflow.worker.MetricsContainerRegistry; -import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker; import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext.StreamingModeExecutionStateRegistry; import org.apache.beam.runners.dataflow.worker.StreamingStepMetricsContainer; import org.apache.beam.runners.dataflow.worker.counters.Counter; @@ -93,20 +92,13 @@ 
public List extractCounterUpdates() { } /** - * Checks if the step counter affects any per-stage counters. Currently 'throttled_millis' is the + * Checks if the step counter affects any per-stage counters. Currently 'throttled-msecs' is the * only counter updated. */ private void translateKnownStepCounters(CounterUpdate stepCounterUpdate) { CounterStructuredName structuredName = stepCounterUpdate.getStructuredNameAndMetadata().getName(); - if ((THROTTLING_MSECS_METRIC_NAME.getNamespace().equals(structuredName.getOriginNamespace()) - && THROTTLING_MSECS_METRIC_NAME.getName().equals(structuredName.getName())) - || (StreamingDataflowWorker.BIGQUERY_STREAMING_INSERT_THROTTLE_TIME - .getNamespace() - .equals(structuredName.getOriginNamespace()) - && StreamingDataflowWorker.BIGQUERY_STREAMING_INSERT_THROTTLE_TIME - .getName() - .equals(structuredName.getName()))) { + if (THROTTLE_TIME_COUNTER_NAME.equals(structuredName.getName())) { long msecs = DataflowCounterUpdateExtractor.splitIntToLong(stepCounterUpdate.getInteger()); if (msecs > 0) { throttledMsecs().addValue(msecs); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcher.java index 51d1507af5fe..850e8c3f24bd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcher.java @@ -157,7 +157,7 @@ private static Optional fetchConfigWithRetry( } } - private static StreamingEnginePipelineConfig createPipelineConfig(StreamingConfigTask config) { + private StreamingEnginePipelineConfig createPipelineConfig(StreamingConfigTask config) { StreamingEnginePipelineConfig.Builder pipelineConfig = StreamingEnginePipelineConfig.builder(); if (config.getUserStepToStateFamilyNameMap() != null) { pipelineConfig.setUserStepToStateFamilyNameMap(config.getUserStepToStateFamilyNameMap()); @@ -187,6 +187,18 @@ private static StreamingEnginePipelineConfig createPipelineConfig(StreamingConfi pipelineConfig.setMaxWorkItemCommitBytes(config.getMaxWorkItemCommitBytes().intValue()); } + if (config.getOperationalLimits() != null) { + if (config.getOperationalLimits().getMaxKeyBytes() > 0 + && config.getOperationalLimits().getMaxKeyBytes() <= Integer.MAX_VALUE) { + pipelineConfig.setMaxOutputKeyBytes(config.getOperationalLimits().getMaxKeyBytes()); + } + if (config.getOperationalLimits().getMaxProductionOutputBytes() > 0 + && config.getOperationalLimits().getMaxProductionOutputBytes() <= Integer.MAX_VALUE) { + pipelineConfig.setMaxOutputValueBytes( + config.getOperationalLimits().getMaxProductionOutputBytes()); + } + } + return pipelineConfig.build(); } @@ -273,7 +285,7 @@ private synchronized void fetchInitialPipelineGlobalConfig() { private Optional fetchGlobalConfig() { return fetchConfigWithRetry(dataflowServiceClient::getGlobalStreamingConfigWorkItem) - .map(StreamingEngineComputationConfigFetcher::createPipelineConfig); + .map(config -> createPipelineConfig(config)); } @FunctionalInterface diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEnginePipelineConfig.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEnginePipelineConfig.java index b5b761ada703..8f1ff93f6a49 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEnginePipelineConfig.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEnginePipelineConfig.java @@ -34,12 +34,18 @@ public abstract class StreamingEnginePipelineConfig { public static StreamingEnginePipelineConfig.Builder builder() { return new AutoValue_StreamingEnginePipelineConfig.Builder() .setMaxWorkItemCommitBytes(DEFAULT_MAX_WORK_ITEM_COMMIT_BYTES) + .setMaxOutputKeyBytes(Long.MAX_VALUE) + .setMaxOutputValueBytes(Long.MAX_VALUE) .setUserStepToStateFamilyNameMap(new HashMap<>()) .setWindmillServiceEndpoints(ImmutableSet.of()); } public abstract long maxWorkItemCommitBytes(); + public abstract long maxOutputKeyBytes(); + + public abstract long maxOutputValueBytes(); + public abstract Map userStepToStateFamilyNameMap(); public abstract ImmutableSet windmillServiceEndpoints(); @@ -48,6 +54,10 @@ public static StreamingEnginePipelineConfig.Builder builder() { public abstract static class Builder { public abstract Builder setMaxWorkItemCommitBytes(long value); + public abstract Builder setMaxOutputKeyBytes(long value); + + public abstract Builder setMaxOutputValueBytes(long value); + public abstract Builder setUserStepToStateFamilyNameMap(Map value); public abstract Builder setWindmillServiceEndpoints(ImmutableSet value); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 95e5c42bf59c..b0b6377dd8b1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -23,7 +23,7 @@ import java.util.Optional; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; @@ -31,6 +31,7 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.DataflowMapTaskExecutorFactory; import org.apache.beam.runners.dataflow.worker.HotKeyLogger; +import org.apache.beam.runners.dataflow.worker.OperationalLimits; import org.apache.beam.runners.dataflow.worker.ReaderCache; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.logging.DataflowWorkerLoggingMDC; @@ -82,7 +83,7 @@ public final class StreamingWorkScheduler { private final HotKeyLogger hotKeyLogger; private final ConcurrentMap stageInfoMap; private final DataflowExecutionStateSampler sampler; - private final AtomicInteger 
maxWorkItemCommitBytes; + private final AtomicReference operationalLimits; public StreamingWorkScheduler( DataflowWorkerHarnessOptions options, @@ -96,7 +97,7 @@ public StreamingWorkScheduler( HotKeyLogger hotKeyLogger, ConcurrentMap stageInfoMap, DataflowExecutionStateSampler sampler, - AtomicInteger maxWorkItemCommitBytes) { + AtomicReference operationalLimits) { this.options = options; this.clock = clock; this.computationWorkExecutorFactory = computationWorkExecutorFactory; @@ -108,7 +109,7 @@ public StreamingWorkScheduler( this.hotKeyLogger = hotKeyLogger; this.stageInfoMap = stageInfoMap; this.sampler = sampler; - this.maxWorkItemCommitBytes = maxWorkItemCommitBytes; + this.operationalLimits = operationalLimits; } public static StreamingWorkScheduler create( @@ -123,7 +124,7 @@ public static StreamingWorkScheduler create( StreamingCounters streamingCounters, HotKeyLogger hotKeyLogger, DataflowExecutionStateSampler sampler, - AtomicInteger maxWorkItemCommitBytes, + AtomicReference operationalLimits, IdGenerator idGenerator, ConcurrentMap stageInfoMap) { ComputationWorkExecutorFactory computationWorkExecutorFactory = @@ -148,7 +149,7 @@ public static StreamingWorkScheduler create( hotKeyLogger, stageInfoMap, sampler, - maxWorkItemCommitBytes); + operationalLimits); } private static long computeShuffleBytesRead(Windmill.WorkItem workItem) { @@ -292,7 +293,7 @@ private Windmill.WorkItemCommitRequest validateCommitRequestSize( Windmill.WorkItemCommitRequest commitRequest, String computationId, Windmill.WorkItem workItem) { - int byteLimit = maxWorkItemCommitBytes.get(); + long byteLimit = operationalLimits.get().maxWorkItemCommitBytes; int commitSize = commitRequest.getSerializedSize(); int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize; @@ -376,7 +377,12 @@ private ExecuteWorkResult executeWork( // Blocks while executing work. 
computationWorkExecutor.executeWork( - executionKey, work, stateReader, localSideInputStateFetcher, outputBuilder); + executionKey, + work, + stateReader, + localSideInputStateFetcher, + operationalLimits.get(), + outputBuilder); if (work.isFailed()) { throw new WorkItemCancelledException(workItem.getShardingKey()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContextTest.java index 0e516b3ffb49..4062fbf6ebed 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/BatchModeExecutionContextTest.java @@ -30,6 +30,8 @@ import com.google.api.services.dataflow.model.CounterStructuredNameAndMetadata; import com.google.api.services.dataflow.model.CounterUpdate; import com.google.api.services.dataflow.model.DistributionUpdate; +import com.google.api.services.dataflow.model.StringList; +import java.util.Arrays; import org.apache.beam.runners.core.metrics.ExecutionStateSampler; import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; @@ -41,7 +43,9 @@ import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Distribution; import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.metrics.StringSet; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.hamcrest.Matchers; import org.junit.Test; @@ -158,6 +162,37 @@ public void extractMetricUpdatesDistribution() { assertThat(executionContext.extractMetricUpdates(false), containsInAnyOrder(expected)); } + @Test + public void extractMetricUpdatesStringSet() { + BatchModeExecutionContext executionContext = + BatchModeExecutionContext.forTesting(PipelineOptionsFactory.create(), "testStage"); + DataflowOperationContext operationContext = + executionContext.createOperationContext(NameContextsForTests.nameContextForTest()); + + StringSet stringSet = + operationContext + .metricsContainer() + .getStringSet(MetricName.named("namespace", "some-stringset")); + stringSet.add("ab"); + stringSet.add("cd"); + + final CounterUpdate expected = + new CounterUpdate() + .setStructuredNameAndMetadata( + new CounterStructuredNameAndMetadata() + .setName( + new CounterStructuredName() + .setOrigin("USER") + .setOriginNamespace("namespace") + .setName("some-stringset") + .setOriginalStepName("originalName")) + .setMetadata(new CounterMetadata().setKind(Kind.SET.toString()))) + .setCumulative(false) + .setStringList(new StringList().setElements(Arrays.asList("ab", "cd"))); + + assertThat(executionContext.extractMetricUpdates(false), containsInAnyOrder(expected)); + } + @Test public void extractMsecCounters() { BatchModeExecutionContext executionContext = @@ -232,7 +267,7 @@ public void extractThrottleTimeCounters() { .getCounter( MetricName.named( BatchModeExecutionContext.DATASTORE_THROTTLE_TIME_NAMESPACE, - BatchModeExecutionContext.THROTTLE_TIME_COUNTER_NAME)); + Metrics.THROTTLE_TIME_COUNTER_NAME)); counter.inc(12000); counter.inc(17000); counter.inc(1000); diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 0d2eb2997550..5855057c4210 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -549,7 +549,6 @@ private Windmill.GetWorkResponse buildSessionInput( List inputs, List timers) throws Exception { - // Windmill.GetWorkResponse.Builder builder = Windmill.GetWorkResponse.newBuilder(); Windmill.WorkItem.Builder builder = Windmill.WorkItem.newBuilder(); builder.setKey(DEFAULT_KEY_BYTES); builder.setShardingKey(DEFAULT_SHARDING_KEY); @@ -849,7 +848,7 @@ private StreamingDataflowWorker makeWorker( streamingDataflowWorkerTestParams.clock(), streamingDataflowWorkerTestParams.executorSupplier(), streamingDataflowWorkerTestParams.localRetryTimeoutMs(), - streamingDataflowWorkerTestParams.maxWorkItemCommitBytes()); + streamingDataflowWorkerTestParams.operationalLimits()); this.computationStateCache = worker.getComputationStateCache(); return worker; } @@ -1211,7 +1210,8 @@ public void testKeyCommitTooLargeException() throws Exception { makeWorker( defaultWorkerParams() .setInstructions(instructions) - .setMaxWorkItemCommitBytes(1000) + .setOperationalLimits( + OperationalLimits.builder().setMaxWorkItemCommitBytes(1000).build()) .publishCounters() .build()); worker.start(); @@ -1266,6 +1266,80 @@ public void testKeyCommitTooLargeException() throws Exception { assertTrue(foundErrors); } + @Test + public void testOutputKeyTooLargeException() throws Exception { + KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new ExceptionCatchingFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); + + server.setExpectedExceptionCount(1); + + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams() + .setInstructions(instructions) + .setOperationalLimits( + OperationalLimits.builder() + .setMaxOutputKeyBytes(15) + .setThrowExceptionOnLargeOutput(true) + .build()) + .build()); + worker.start(); + + // This large key will cause the ExceptionCatchingFn to throw an exception, which will then + // cause it to output a smaller key. 
+ String bigKey = "some_much_too_large_output_key"; + server.whenGetWorkCalled().thenReturn(makeInput(1, 0, bigKey, DEFAULT_SHARDING_KEY)); + server.waitForEmptyWorkQueue(); + + Map result = server.waitForAndGetCommits(1); + assertEquals(1, result.size()); + assertEquals( + makeExpectedOutput(1, 0, bigKey, DEFAULT_SHARDING_KEY, "smaller_key").build(), + removeDynamicFields(result.get(1L))); + } + + @Test + public void testOutputValueTooLargeException() throws Exception { + KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); + + List instructions = + Arrays.asList( + makeSourceInstruction(kvCoder), + makeDoFnInstruction(new ExceptionCatchingFn(), 0, kvCoder), + makeSinkInstruction(kvCoder, 1)); + + server.setExpectedExceptionCount(1); + + StreamingDataflowWorker worker = + makeWorker( + defaultWorkerParams() + .setInstructions(instructions) + .setOperationalLimits( + OperationalLimits.builder() + .setMaxOutputValueBytes(15) + .setThrowExceptionOnLargeOutput(true) + .build()) + .build()); + worker.start(); + + // The first time processing will have value "data1_a_bunch_more_data_output", which is above + // the limit. After throwing the exception, the output should be just "data1", which is small + // enough. + server.whenGetWorkCalled().thenReturn(makeInput(1, 0, "key", DEFAULT_SHARDING_KEY)); + server.waitForEmptyWorkQueue(); + + Map result = server.waitForAndGetCommits(1); + assertEquals(1, result.size()); + assertEquals( + makeExpectedOutput(1, 0, "key", DEFAULT_SHARDING_KEY, "smaller_key").build(), + removeDynamicFields(result.get(1L))); + } + @Test public void testKeyChange() throws Exception { KvCoder kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()); @@ -4017,6 +4091,18 @@ public void processElement(ProcessContext c) { } } + static class ExceptionCatchingFn extends DoFn, KV> { + + @ProcessElement + public void processElement(ProcessContext c) { + try { + c.output(KV.of(c.element().getKey(), c.element().getValue() + "_a_bunch_more_data_output")); + } catch (Exception e) { + c.output(KV.of("smaller_key", c.element().getValue())); + } + } + } + static class ChangeKeysFn extends DoFn, KV> { @ProcessElement @@ -4429,7 +4515,7 @@ private static StreamingDataflowWorkerTestParams.Builder builder() { .setLocalRetryTimeoutMs(-1) .setPublishCounters(false) .setClock(Instant::now) - .setMaxWorkItemCommitBytes(Integer.MAX_VALUE); + .setOperationalLimits(OperationalLimits.builder().build()); } abstract ImmutableMap stateNameMappings(); @@ -4446,7 +4532,7 @@ private static StreamingDataflowWorkerTestParams.Builder builder() { abstract int localRetryTimeoutMs(); - abstract int maxWorkItemCommitBytes(); + abstract OperationalLimits operationalLimits(); @AutoValue.Builder abstract static class Builder { @@ -4480,7 +4566,7 @@ final Builder publishCounters() { abstract Builder setLocalRetryTimeoutMs(int value); - abstract Builder setMaxWorkItemCommitBytes(int maxWorkItemCommitBytes); + abstract Builder setOperationalLimits(OperationalLimits operationalLimits); abstract StreamingDataflowWorkerTestParams build(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 2bd6621dd4f4..8445e8ede852 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -157,6 +157,7 @@ public void testTimerInternalsSetTimer() { Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), stateReader, sideInputStateFetcher, + OperationalLimits.builder().build(), outputBuilder); TimerInternals timerInternals = stepContext.timerInternals(); @@ -206,6 +207,7 @@ public void testTimerInternalsProcessingTimeSkew() { Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), stateReader, sideInputStateFetcher, + OperationalLimits.builder().build(), outputBuilder); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java index a9b8abdca93c..2d5a8d8266ae 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java @@ -37,11 +37,13 @@ import com.google.api.services.dataflow.model.Linear; import com.google.api.services.dataflow.model.MetricValue; import com.google.api.services.dataflow.model.PerStepNamespaceMetrics; +import com.google.api.services.dataflow.model.StringList; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import java.time.ZoneOffset; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -57,6 +59,7 @@ import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.metrics.NoOpCounter; import org.apache.beam.sdk.metrics.NoOpHistogram; +import org.apache.beam.sdk.metrics.StringSet; import org.apache.beam.sdk.util.HistogramData; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.hamcrest.collection.IsEmptyIterable; @@ -267,6 +270,61 @@ public void testGaugeUpdateExtraction() { DateTimeUtils.setCurrentMillisSystem(); } + @Test + public void testStringSetUpdateExtraction() { + StringSet stringSet = c1.getStringSet(name1); + stringSet.add("ab"); + stringSet.add("cd", "ef"); + stringSet.add("gh"); + stringSet.add("gh"); + + CounterUpdate name1Update = + new CounterUpdate() + .setStructuredNameAndMetadata( + new CounterStructuredNameAndMetadata() + .setName( + new CounterStructuredName() + .setOrigin(Origin.USER.toString()) + .setOriginNamespace("ns") + .setName("name1") + .setOriginalStepName("s1")) + .setMetadata(new CounterMetadata().setKind(Kind.SET.toString()))) + .setCumulative(false) + .setStringList(new StringList().setElements(Arrays.asList("ab", "cd", "ef", "gh"))); + + Iterable updates = StreamingStepMetricsContainer.extractMetricUpdates(registry); + assertThat(updates, containsInAnyOrder(name1Update)); + + stringSet = c2.getStringSet(name2); + stringSet.add("ij"); + stringSet.add("kl", "mn"); + stringSet.add("mn"); + + CounterUpdate name2Update = + new CounterUpdate() + .setStructuredNameAndMetadata( + 
new CounterStructuredNameAndMetadata() + .setName( + new CounterStructuredName() + .setOrigin(Origin.USER.toString()) + .setOriginNamespace("ns") + .setName("name2") + .setOriginalStepName("s2")) + .setMetadata(new CounterMetadata().setKind(Kind.SET.toString()))) + .setCumulative(false) + .setStringList(new StringList().setElements(Arrays.asList("ij", "kl", "mn"))); + + updates = StreamingStepMetricsContainer.extractMetricUpdates(registry); + assertThat(updates, containsInAnyOrder(name1Update, name2Update)); + + c1.getStringSet(name1).add("op"); + name1Update.setStringList( + new StringList().setElements(Arrays.asList("ab", "cd", "ef", "gh", "op"))); + + updates = StreamingStepMetricsContainer.extractMetricUpdates(registry); + assertThat(updates, containsInAnyOrder(name1Update, name2Update)); + } + @Test public void testPerWorkerMetrics() { StreamingStepMetricsContainer.setEnablePerWorkerMetrics(false); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 98302c512256..5c149a65f4ce 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -634,6 +634,7 @@ public void testReadUnboundedReader() throws Exception { Watermarks.builder().setInputDataWatermark(new Instant(0)).build()), mock(WindmillStateReader.class), mock(SideInputStateFetcher.class), + OperationalLimits.builder().build(), Windmill.WorkItemCommitRequest.newBuilder()); @SuppressWarnings({"unchecked", "rawtypes"}) @@ -1009,6 +1010,7 @@ public void testFailedWorkItemsAbort() throws Exception { dummyWork, mock(WindmillStateReader.class), mock(SideInputStateFetcher.class), + OperationalLimits.builder().build(), Windmill.WorkItemCommitRequest.newBuilder()); @SuppressWarnings({"unchecked", "rawtypes"}) diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/wire/LengthPrefixUnknownCoders.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/wire/LengthPrefixUnknownCoders.java index bde97f0b9d19..bd5a159efeac 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/wire/LengthPrefixUnknownCoders.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/wire/LengthPrefixUnknownCoders.java @@ -17,7 +17,9 @@ */ package org.apache.beam.runners.fnexecution.wire; +import java.util.HashSet; import java.util.Map.Entry; +import java.util.Set; import java.util.function.Predicate; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.Coder; @@ -28,6 +30,17 @@ /** Utilities for replacing or wrapping unknown coders with {@link LengthPrefixCoder}. */ public class LengthPrefixUnknownCoders { + private static Set otherKnownCoderUrns = new HashSet<>(); + + /** + * Registers a coder as being of known type and as such not meriting length prefixing. + * + * @param urn The urn of the coder that should not be length prefixed. 
+ */ + public static void addKnownCoderUrn(String urn) { + otherKnownCoderUrns.add(urn); + } + /** * Recursively traverses the coder tree and wraps the first unknown coder in every branch with a * {@link LengthPrefixCoder} unless an ancestor coder is itself a {@link LengthPrefixCoder}. If @@ -59,7 +72,7 @@ public static String addLengthPrefixedCoder( // with a length prefix coder or replace it with a length prefix byte array coder. if (ModelCoders.LENGTH_PREFIX_CODER_URN.equals(urn)) { return replaceWithByteArrayCoder ? lengthPrefixedByteArrayCoderId : coderId; - } else if (ModelCoders.urns().contains(urn)) { + } else if (ModelCoders.urns().contains(urn) || otherKnownCoderUrns.contains(urn)) { return addForModelCoder(coderId, components, replaceWithByteArrayCoder); } else { return replaceWithByteArrayCoder @@ -71,6 +84,9 @@ public static String addLengthPrefixedCoder( private static String addForModelCoder( String coderId, RunnerApi.Components.Builder components, boolean replaceWithByteArrayCoder) { Coder coder = components.getCodersOrThrow(coderId); + if (coder.getComponentCoderIdsCount() == 0) { + return coderId; + } RunnerApi.Coder.Builder builder = coder.toBuilder().clearComponentCoderIds(); for (String componentCoderId : coder.getComponentCoderIdsList()) { builder.addComponentCoderIds( diff --git a/runners/jet/src/main/java/org/apache/beam/runners/jet/FailedRunningPipelineResults.java b/runners/jet/src/main/java/org/apache/beam/runners/jet/FailedRunningPipelineResults.java index b6dae10da6bc..67cf3280a83c 100644 --- a/runners/jet/src/main/java/org/apache/beam/runners/jet/FailedRunningPipelineResults.java +++ b/runners/jet/src/main/java/org/apache/beam/runners/jet/FailedRunningPipelineResults.java @@ -25,6 +25,7 @@ import org.apache.beam.sdk.metrics.MetricResult; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; @@ -84,6 +85,11 @@ public Iterable> getDistributions() { public Iterable> getGauges() { return Collections.emptyList(); } + + @Override + public Iterable> getStringSets() { + return Collections.emptyList(); + } }; } }; diff --git a/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricResults.java b/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricResults.java index 8e28f3fda0e8..44681a626cc0 100644 --- a/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricResults.java +++ b/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricResults.java @@ -25,6 +25,7 @@ import org.apache.beam.runners.core.metrics.GaugeData; import org.apache.beam.runners.core.metrics.MetricUpdates; import org.apache.beam.runners.core.metrics.MetricUpdates.MetricUpdate; +import org.apache.beam.runners.core.metrics.StringSetData; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.GaugeResult; import org.apache.beam.sdk.metrics.MetricFiltering; @@ -33,6 +34,7 @@ import org.apache.beam.sdk.metrics.MetricResult; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicate; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable; import org.checkerframework.checker.nullness.qual.Nullable; @@ -52,6 +54,9 @@ 
public class JetMetricResults extends MetricResults { @GuardedBy("this") private final Gauges gauges = new Gauges(); + @GuardedBy("this") + private final StringSets stringSet = new StringSets(); + @GuardedBy("this") private IMap metricsAccumulator; @@ -70,18 +75,23 @@ public synchronized MetricQueryResults queryMetrics(@Nullable MetricsFilter filt updateLocalMetrics(metricsAccumulator); } return new QueryResults( - counters.filter(filter), distributions.filter(filter), gauges.filter(filter)); + counters.filter(filter), + distributions.filter(filter), + gauges.filter(filter), + stringSet.filter(filter)); } private synchronized void updateLocalMetrics(IMap metricsAccumulator) { counters.clear(); distributions.clear(); gauges.clear(); + stringSet.clear(); for (MetricUpdates metricUpdates : metricsAccumulator.values()) { counters.merge(metricUpdates.counterUpdates()); distributions.merge(metricUpdates.distributionUpdates()); gauges.merge(metricUpdates.gaugeUpdates()); + stringSet.merge(metricUpdates.stringSetUpdates()); } } @@ -93,14 +103,17 @@ private static class QueryResults extends MetricQueryResults { private final Iterable> counters; private final Iterable> distributions; private final Iterable> gauges; + private final Iterable> stringSets; private QueryResults( Iterable> counters, Iterable> distributions, - Iterable> gauges) { + Iterable> gauges, + Iterable> stringSets) { this.counters = counters; this.distributions = distributions; this.gauges = gauges; + this.stringSets = stringSets; } @Override @@ -117,6 +130,11 @@ public Iterable> getDistributions() { public Iterable> getGauges() { return gauges; } + + @Override + public Iterable> getStringSets() { + return stringSets; + } } private static class Counters { @@ -212,4 +230,36 @@ private MetricResult toUpdateResult(Map.Entry return MetricResult.create(key, gaugeResult, gaugeResult); } } + + private static class StringSets { + + private final Map stringSets = new HashMap<>(); + + void merge(Iterable> updates) { + for (MetricUpdate update : updates) { + MetricKey key = update.getKey(); + StringSetData oldStringSet = stringSets.getOrDefault(key, StringSetData.empty()); + StringSetData updatedStringSet = update.getUpdate().combine(oldStringSet); + stringSets.put(key, updatedStringSet); + } + } + + void clear() { + stringSets.clear(); + } + + Iterable> filter(MetricsFilter filter) { + return FluentIterable.from(stringSets.entrySet()) + .filter(matchesFilter(filter)) + .transform(this::toUpdateResult) + .toList(); + } + + private MetricResult toUpdateResult( + Map.Entry entry) { + MetricKey key = entry.getKey(); + StringSetResult stringSetResult = entry.getValue().extractResult(); + return MetricResult.create(key, stringSetResult, stringSetResult); + } + } } diff --git a/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricsContainer.java b/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricsContainer.java index 5441d05dcf76..64455d704c9b 100644 --- a/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricsContainer.java +++ b/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/JetMetricsContainer.java @@ -26,12 +26,14 @@ import org.apache.beam.runners.core.metrics.DistributionData; import org.apache.beam.runners.core.metrics.GaugeData; import org.apache.beam.runners.core.metrics.MetricUpdates; +import org.apache.beam.runners.core.metrics.StringSetData; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Distribution; import 
org.apache.beam.sdk.metrics.Gauge; import org.apache.beam.sdk.metrics.MetricKey; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.metrics.StringSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; /** Jet specific implementation of {@link MetricsContainer}. */ @@ -47,6 +49,7 @@ public static String getMetricsMapName(long jobId) { private final Map counters = new HashMap<>(); private final Map distributions = new HashMap<>(); private final Map gauges = new HashMap<>(); + private final Map stringSets = new HashMap<>(); private final IMap accumulator; @@ -71,9 +74,14 @@ public Gauge getGauge(MetricName metricName) { return gauges.computeIfAbsent(metricName, GaugeImpl::new); } + @Override + public StringSet getStringSet(MetricName metricName) { + return stringSets.computeIfAbsent(metricName, StringSetImpl::new); + } + @SuppressWarnings("FutureReturnValueIgnored") public void flush(boolean async) { - if (counters.isEmpty() && distributions.isEmpty() && gauges.isEmpty()) { + if (counters.isEmpty() && distributions.isEmpty() && gauges.isEmpty() && stringSets.isEmpty()) { return; } @@ -81,7 +89,9 @@ public void flush(boolean async) { ImmutableList> distributions = extractUpdates(this.distributions); ImmutableList> gauges = extractUpdates(this.gauges); - MetricUpdates updates = new MetricUpdatesImpl(counters, distributions, gauges); + ImmutableList> stringSets = + extractUpdates(this.stringSets); + MetricUpdates updates = new MetricUpdatesImpl(counters, distributions, gauges, stringSets); if (async) { accumulator.setAsync(metricsKey, updates); @@ -110,14 +120,17 @@ private static class MetricUpdatesImpl extends MetricUpdates implements Serializ private final Iterable> counters; private final Iterable> distributions; private final Iterable> gauges; + private final Iterable> stringSets; MetricUpdatesImpl( Iterable> counters, Iterable> distributions, - Iterable> gauges) { + Iterable> gauges, + Iterable> stringSets) { this.counters = counters; this.distributions = distributions; this.gauges = gauges; + this.stringSets = stringSets; } @Override @@ -134,5 +147,10 @@ public Iterable> distributionUpdates() { public Iterable> gaugeUpdates() { return gauges; } + + @Override + public Iterable> stringSetUpdates() { + return stringSets; + } } } diff --git a/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/StringSetImpl.java b/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/StringSetImpl.java new file mode 100644 index 000000000000..4fd67042e3cf --- /dev/null +++ b/runners/jet/src/main/java/org/apache/beam/runners/jet/metrics/StringSetImpl.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.jet.metrics; + +import org.apache.beam.runners.core.metrics.StringSetData; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.StringSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; + +/** Implementation of {@link StringSet}. */ +public class StringSetImpl extends AbstractMetric implements StringSet { + + private final StringSetData stringSetData = StringSetData.empty(); + + public StringSetImpl(MetricName name) { + super(name); + } + + @Override + StringSetData getValue() { + return stringSetData; + } + + @Override + public void add(String value) { + if (stringSetData.stringSet().contains(value)) { + return; + } + stringSetData.combine(StringSetData.create(ImmutableSet.of(value))); + } + + @Override + public void add(String... values) { + stringSetData.combine(StringSetData.create(ImmutableSet.copyOf(values))); + } +} diff --git a/runners/portability/java/build.gradle b/runners/portability/java/build.gradle index 04f7b53ced21..b684299c3174 100644 --- a/runners/portability/java/build.gradle +++ b/runners/portability/java/build.gradle @@ -159,6 +159,7 @@ def createUlrValidatesRunnerTask = { name, environmentType, dockerImageTask = "" excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesStringSetMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesOnWindowExpiration' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' diff --git a/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableMetrics.java b/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableMetrics.java index fc94e408bfd3..1d45a83b1e79 100644 --- a/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableMetrics.java +++ b/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableMetrics.java @@ -19,10 +19,12 @@ import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.DISTRIBUTION_INT64_TYPE; import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.LATEST_INT64_TYPE; +import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.SET_STRING_TYPE; import static org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns.SUM_INT64_TYPE; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeStringSet; import java.util.ArrayList; import java.util.List; @@ -32,6 +34,7 @@ import org.apache.beam.model.pipeline.v1.MetricsApi; import org.apache.beam.runners.core.metrics.DistributionData; import org.apache.beam.runners.core.metrics.GaugeData; +import org.apache.beam.runners.core.metrics.StringSetData; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.GaugeResult; import org.apache.beam.sdk.metrics.MetricFiltering; @@ -41,6 +44,7 @@ import org.apache.beam.sdk.metrics.MetricResult; import 
org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsFilter; +import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; @SuppressWarnings({ @@ -53,14 +57,17 @@ public class PortableMetrics extends MetricResults { private Iterable> counters; private Iterable> distributions; private Iterable> gauges; + private Iterable> stringSets; private PortableMetrics( Iterable> counters, Iterable> distributions, - Iterable> gauges) { + Iterable> gauges, + Iterable> stringSets) { this.counters = counters; this.distributions = distributions; this.gauges = gauges; + this.stringSets = stringSets; } public static PortableMetrics of(JobApi.MetricResults jobMetrics) { @@ -75,7 +82,9 @@ public MetricQueryResults queryMetrics(MetricsFilter filter) { Iterables.filter( this.distributions, (distribution) -> MetricFiltering.matches(filter, distribution.getKey())), - Iterables.filter(this.gauges, (gauge) -> MetricFiltering.matches(filter, gauge.getKey()))); + Iterables.filter(this.gauges, (gauge) -> MetricFiltering.matches(filter, gauge.getKey())), + Iterables.filter( + this.stringSets, (stringSet) -> MetricFiltering.matches(filter, stringSet.getKey()))); } private static PortableMetrics convertMonitoringInfosToMetricResults( @@ -89,7 +98,10 @@ private static PortableMetrics convertMonitoringInfosToMetricResults( extractDistributionMetricsFromJobMetrics(monitoringInfoList); Iterable> gaugesFromMetrics = extractGaugeMetricsFromJobMetrics(monitoringInfoList); - return new PortableMetrics(countersFromJobMetrics, distributionsFromMetrics, gaugesFromMetrics); + Iterable> stringSetFromMetrics = + extractStringSetMetricsFromJobMetrics(monitoringInfoList); + return new PortableMetrics( + countersFromJobMetrics, distributionsFromMetrics, gaugesFromMetrics, stringSetFromMetrics); } private static Iterable> @@ -123,6 +135,28 @@ private static MetricResult convertGaugeMonitoringInfoToGauge( return MetricResult.create(key, false, result); } + private static Iterable> extractStringSetMetricsFromJobMetrics( + List monitoringInfoList) { + return monitoringInfoList.stream() + .filter(item -> SET_STRING_TYPE.equals(item.getType())) + .filter(item -> item.getLabelsMap().get(NAMESPACE_LABEL) != null) + .map(PortableMetrics::convertStringSetMonitoringInfoToStringSet) + .collect(Collectors.toList()); + } + + private static MetricResult convertStringSetMonitoringInfoToStringSet( + MetricsApi.MonitoringInfo monitoringInfo) { + Map labelsMap = monitoringInfo.getLabelsMap(); + MetricKey key = + MetricKey.create( + labelsMap.get(STEP_NAME_LABEL), + MetricName.named(labelsMap.get(NAMESPACE_LABEL), labelsMap.get(METRIC_NAME_LABEL))); + + StringSetData data = decodeStringSet(monitoringInfo.getPayload()); + StringSetResult result = StringSetResult.create(data.stringSet()); + return MetricResult.create(key, false, result); + } + private static MetricResult convertDistributionMonitoringInfoToDistribution( MetricsApi.MonitoringInfo monitoringInfo) { Map labelsMap = monitoringInfo.getLabelsMap(); diff --git a/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java b/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java index 25353437a2ec..788d4a43319d 100644 --- a/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java +++ 
b/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java @@ -20,6 +20,7 @@ import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeStringSet; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -27,6 +28,7 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import org.apache.beam.model.jobmanagement.v1.JobApi; import org.apache.beam.model.jobmanagement.v1.JobApi.JobState; @@ -34,6 +36,7 @@ import org.apache.beam.model.pipeline.v1.MetricsApi; import org.apache.beam.runners.core.metrics.DistributionData; import org.apache.beam.runners.core.metrics.GaugeData; +import org.apache.beam.runners.core.metrics.StringSetData; import org.apache.beam.runners.fnexecution.artifact.ArtifactStagingService; import org.apache.beam.runners.portability.testing.TestJobService; import org.apache.beam.sdk.PipelineResult; @@ -50,6 +53,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.joda.time.Duration; import org.joda.time.Instant; @@ -68,6 +72,7 @@ public class PortableRunnerTest implements Serializable { private static final String COUNTER_TYPE = "beam:metrics:sum_int64:v1"; private static final String DIST_TYPE = "beam:metrics:distribution_int64:v1"; private static final String GAUGE_TYPE = "beam:metrics:latest_int64:v1"; + private static final String STRING_SET_TYPE = "beam:metrics:set_string:v1"; private static final String NAMESPACE_LABEL = "NAMESPACE"; private static final String METRIC_NAME_LABEL = "NAME"; private static final String STEP_NAME_LABEL = "PTRANSFORM"; @@ -76,6 +81,7 @@ public class PortableRunnerTest implements Serializable { private static final String STEP_NAME = "testStep"; private static final Long COUNTER_VALUE = 42L; private static final Long GAUGE_VALUE = 64L; + private static final Set STRING_SET_VALUE = ImmutableSet.of("ab", "cd"); private static final Instant GAUGE_TIME = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardSeconds(1)); private static final Long DIST_SUM = 1000L; @@ -124,6 +130,9 @@ public void extractsMetrics() throws Exception { assertThat( metricQueryResults.getGauges().iterator().next().getAttempted().getValue(), is(GAUGE_VALUE)); + assertThat( + metricQueryResults.getStringSets().iterator().next().getAttempted().getStringSet(), + is(STRING_SET_VALUE)); } private JobApi.MetricResults generateMetricResults() throws Exception { @@ -155,10 +164,18 @@ private JobApi.MetricResults generateMetricResults() throws Exception { .setPayload(encodeInt64Gauge(GaugeData.create(GAUGE_VALUE, GAUGE_TIME))) .build(); + MetricsApi.MonitoringInfo stringSetMonitoringInfo = + MetricsApi.MonitoringInfo.newBuilder() + .setType(STRING_SET_TYPE) + .putAllLabels(labelMap) + 
.setPayload(encodeStringSet(StringSetData.create(STRING_SET_VALUE))) + .build(); + return JobApi.MetricResults.newBuilder() .addAttempted(counterMonitoringInfo) .addAttempted(distMonitoringInfo) .addAttempted(gaugeMonitoringInfo) + .addAttempted(stringSetMonitoringInfo) .build(); } diff --git a/runners/prism/build.gradle b/runners/prism/build.gradle index 7169f91ea156..711a1aa2dd75 100644 --- a/runners/prism/build.gradle +++ b/runners/prism/build.gradle @@ -36,6 +36,7 @@ def modDir = project.rootDir.toPath().resolve("sdks") // prismDir is the directory containing the prism executable. def prismDir = modDir.resolve("go/cmd/prism") +ext.set('buildTarget', buildTarget) // Overrides the gradle build task to build the prism executable. def buildTask = tasks.named("build") { diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle new file mode 100644 index 000000000000..93d151f3e058 --- /dev/null +++ b/runners/prism/java/build.gradle @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { id 'org.apache.beam.module' } + +applyJavaNature( + automaticModuleName: 'org.apache.beam.runners.prism', +) + +description = "Apache Beam :: Runners :: Prism :: Java" +ext.summary = "Support for executing a pipeline on Prism." + +dependencies { + implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation project(":runners:portability:java") + + implementation library.java.joda_time + implementation library.java.slf4j_api + implementation library.java.vendored_guava_32_1_2_jre + + testImplementation library.java.junit + testImplementation library.java.mockito_core + testImplementation library.java.truth +} + +tasks.test { + var prismBuildTask = dependsOn(':runners:prism:build') + systemProperty 'prism.buildTarget', prismBuildTask.project.property('buildTarget').toString() +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java new file mode 100644 index 000000000000..620d5508f22a --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.google.auto.value.AutoValue; +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link PrismExecutor} builds and executes a {@link ProcessBuilder} for use by the {@link + * PrismRunner}. Prism is a {@link org.apache.beam.runners.portability.PortableRunner} maintained at + * sdks/go/cmd/prism. + */ +@AutoValue +abstract class PrismExecutor { + + private static final Logger LOG = LoggerFactory.getLogger(PrismExecutor.class); + + protected @MonotonicNonNull Process process; + protected ExecutorService executorService = Executors.newSingleThreadExecutor(); + protected @MonotonicNonNull Future future = null; + + static Builder builder() { + return new AutoValue_PrismExecutor.Builder(); + } + + /** The command to execute the Prism binary. */ + abstract String getCommand(); + + /** + * Additional arguments to pass when invoking the Prism binary. Defaults to an {@link + * Collections#emptyList()}. + */ + abstract List getArguments(); + + /** Stops the execution of the {@link Process}, created as a result of {@link #execute}. */ + void stop() { + LOG.info("Stopping Prism..."); + if (future != null) { + future.cancel(true); + } + executorService.shutdown(); + try { + boolean ignored = executorService.awaitTermination(1000L, TimeUnit.MILLISECONDS); + } catch (InterruptedException ignored) { + } + if (process == null) { + return; + } + if (!process.isAlive()) { + return; + } + process.destroy(); + try { + process.waitFor(); + } catch (InterruptedException ignored) { + } + } + + /** Reports whether the Prism executable {@link Process#isAlive()}. */ + boolean isAlive() { + if (process == null) { + return false; + } + return process.isAlive(); + } + + /** + * Execute the {@link ProcessBuilder} that starts the Prism service. Redirects output to STDOUT. + */ + void execute() throws IOException { + execute(createProcessBuilder().inheritIO()); + } + + /** + * Execute the {@link ProcessBuilder} that starts the Prism service. Redirects output to the + * {@param outputStream}. + */ + void execute(OutputStream outputStream) throws IOException { + execute(createProcessBuilder().redirectErrorStream(true)); + this.future = + executorService.submit( + () -> { + try { + ByteStreams.copy(checkStateNotNull(process).getInputStream(), outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Execute the {@link ProcessBuilder} that starts the Prism service. Redirects output to the + * {@param file}. 
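+ *
+ * <p>A minimal usage sketch (illustrative only; the command and log file below are
+ * placeholders for a locally resolved Prism build):
+ *
+ * <pre>{@code
+ * // placeholder path to a locally built Prism binary
+ * PrismExecutor executor =
+ *     PrismExecutor.builder().setCommand("/path/to/prism").build();
+ * executor.execute(new File("prism.log"));
+ * // ... submit pipelines against the Prism job endpoint ...
+ * executor.stop();
+ * }</pre>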
+ */ + void execute(File file) throws IOException { + execute( + createProcessBuilder() + .redirectErrorStream(true) + .redirectOutput(ProcessBuilder.Redirect.appendTo(file))); + } + + private void execute(ProcessBuilder processBuilder) throws IOException { + this.process = processBuilder.start(); + LOG.info("started {}", String.join(" ", getCommandWithArguments())); + } + + private List getCommandWithArguments() { + List commandWithArguments = new ArrayList<>(); + commandWithArguments.add(getCommand()); + commandWithArguments.addAll(getArguments()); + + return commandWithArguments; + } + + private ProcessBuilder createProcessBuilder() { + return new ProcessBuilder(getCommandWithArguments()); + } + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setCommand(String command); + + abstract Builder setArguments(List arguments); + + abstract Optional> getArguments(); + + abstract PrismExecutor autoBuild(); + + final PrismExecutor build() { + if (!getArguments().isPresent()) { + setArguments(Collections.emptyList()); + } + return autoBuild(); + } + } +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java new file mode 100644 index 000000000000..f32e4d88f42b --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.attribute.PosixFilePermission; +import java.nio.file.attribute.PosixFilePermissions; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; +import org.apache.beam.sdk.util.ReleaseInfo; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; + +/** + * Locates a Prism executable based on a user's default operating system and architecture + * environment or a {@link PrismPipelineOptions#getPrismLocation()} override. Handles the download, + * unzip, {@link PosixFilePermissions}, as needed. 
For {@link #GITHUB_DOWNLOAD_PREFIX} sources, + * additionally performs a SHA512 verification. + */ +class PrismLocator { + static final String OS_NAME_PROPERTY = "os.name"; + static final String ARCH_PROPERTY = "os.arch"; + static final String USER_HOME_PROPERTY = "user.home"; + + private static final String ZIP_EXT = "zip"; + private static final ReleaseInfo RELEASE_INFO = ReleaseInfo.getReleaseInfo(); + private static final String PRISM_BIN_PATH = ".apache_beam/cache/prism/bin"; + private static final Set PERMS = + PosixFilePermissions.fromString("rwxr-xr-x"); + private static final String GITHUB_DOWNLOAD_PREFIX = + "https://github.com/apache/beam/releases/download"; + private static final String GITHUB_TAG_PREFIX = "https://github.com/apache/beam/releases/tag"; + + private final PrismPipelineOptions options; + + PrismLocator(PrismPipelineOptions options) { + this.options = options; + } + + /** + * Downloads and prepares a Prism executable for use with the {@link PrismRunner}. The returned + * {@link String} is the absolute path to the Prism executable. + */ + String resolve() throws IOException { + + String from = + String.format("%s/v%s/%s.zip", GITHUB_DOWNLOAD_PREFIX, getSDKVersion(), buildFileName()); + + if (!Strings.isNullOrEmpty(options.getPrismLocation())) { + checkArgument( + !options.getPrismLocation().startsWith(GITHUB_TAG_PREFIX), + "Provided --prismLocation URL is not an Apache Beam Github " + + "Release page URL or download URL: ", + from); + + from = options.getPrismLocation(); + } + + String fromFileName = getNameWithoutExtension(from); + Path to = Paths.get(userHome(), PRISM_BIN_PATH, fromFileName); + + if (Files.exists(to)) { + return to.toString(); + } + + createDirectoryIfNeeded(to); + + if (from.startsWith("http")) { + String result = resolve(new URL(from), to); + checkState(Files.exists(to), "Resolved location does not exist: %s", result); + return result; + } + + String result = resolve(Paths.get(from), to); + checkState(Files.exists(to), "Resolved location does not exist: %s", result); + return result; + } + + static Path prismBinDirectory() { + return Paths.get(userHome(), PRISM_BIN_PATH); + } + + private String resolve(URL from, Path to) throws IOException { + BiConsumer downloadFn = PrismLocator::download; + if (from.getPath().endsWith(ZIP_EXT)) { + downloadFn = PrismLocator::unzip; + } + downloadFn.accept(from, to); + + Files.setPosixFilePermissions(to, PERMS); + + return to.toString(); + } + + private String resolve(Path from, Path to) throws IOException { + + BiConsumer copyFn = PrismLocator::copy; + if (from.endsWith(ZIP_EXT)) { + copyFn = PrismLocator::unzip; + } + + copyFn.accept(from.toUri().toURL().openStream(), to); + ByteStreams.copy(from.toUri().toURL().openStream(), Files.newOutputStream(to)); + Files.setPosixFilePermissions(to, PERMS); + + return to.toString(); + } + + String buildFileName() { + String version = getSDKVersion(); + return String.format("apache_beam-v%s-prism-%s-%s", version, os(), arch()); + } + + private static void unzip(URL from, Path to) { + try { + unzip(from.openStream(), to); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void unzip(InputStream from, Path to) { + try (OutputStream out = Files.newOutputStream(to)) { + ZipInputStream zis = new ZipInputStream(from); + for (ZipEntry entry = zis.getNextEntry(); entry != null; entry = zis.getNextEntry()) { + InputStream in = ByteStreams.limit(zis, entry.getSize()); + ByteStreams.copy(in, out); + } + } catch (IOException e) { + throw new 
RuntimeException(e); + } + } + + private static void copy(InputStream from, Path to) { + try { + ByteStreams.copy(from, Files.newOutputStream(to)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void download(URL from, Path to) { + try { + ByteStreams.copy(from.openStream(), Files.newOutputStream(to)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static String getNameWithoutExtension(String path) { + return org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Files + .getNameWithoutExtension(path); + } + + private String getSDKVersion() { + if (Strings.isNullOrEmpty(options.getPrismVersionOverride())) { + return RELEASE_INFO.getSdkVersion(); + } + return options.getPrismVersionOverride(); + } + + private static String os() { + String result = mustGetPropertyAsLowerCase(OS_NAME_PROPERTY); + if (result.contains("mac")) { + return "darwin"; + } + return result; + } + + private static String arch() { + String result = mustGetPropertyAsLowerCase(ARCH_PROPERTY); + if (result.contains("aarch")) { + return "arm64"; + } + return result; + } + + private static String userHome() { + return mustGetPropertyAsLowerCase(USER_HOME_PROPERTY); + } + + private static String mustGetPropertyAsLowerCase(String name) { + return checkStateNotNull(System.getProperty(name), "System property: " + name + " not set") + .toLowerCase(); + } + + private static void createDirectoryIfNeeded(Path path) throws IOException { + Path parent = path.getParent(); + if (parent == null) { + return; + } + Files.createDirectories(parent); + } +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java new file mode 100644 index 000000000000..ec0f8beb620a --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PortablePipelineOptions; + +/** + * {@link org.apache.beam.sdk.options.PipelineOptions} for running a {@link + * org.apache.beam.sdk.Pipeline} on the {@link PrismRunner}. + */ +public interface PrismPipelineOptions extends PortablePipelineOptions { + @Description( + "Path or URL to a prism binary, or zipped binary for the current " + + "platform (Operating System and Architecture). May also be an Apache " + + "Beam Github Release page URL, with a matching --prismVersionOverride " + + "set. 
This option overrides all others for finding a prism binary.") + String getPrismLocation(); + + void setPrismLocation(String prismLocation); + + @Description( + "Override the SDK's version for deriving the Github Release URLs for " + + "downloading a zipped prism binary, for the current platform. If " + + "set to a Github Release page URL, then it will use that release page as a base when constructing the download URL.") + String getPrismVersionOverride(); + + void setPrismVersionOverride(String prismVersionOverride); +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineResult.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineResult.java new file mode 100644 index 000000000000..a551196c9b6f --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineResult.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import java.io.IOException; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.metrics.MetricResults; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; + +/** + * The {@link PipelineResult} of executing a {@link org.apache.beam.sdk.Pipeline} using the {@link + * PrismRunner} and an internal {@link PipelineResult} delegate. + */ +class PrismPipelineResult implements PipelineResult { + + static PrismPipelineResult of(PipelineResult delegate, PrismExecutor executor) { + return new PrismPipelineResult(delegate, executor::stop); + } + + private final PipelineResult delegate; + private final Runnable cancel; + private @Nullable MetricResults terminalMetrics; + private @Nullable State terminalState; + + /** + * Instantiate the {@link PipelineResult} from the {@param delegate} and a {@param cancel} to be + * called when stopping the underlying executable Job management service. + */ + PrismPipelineResult(PipelineResult delegate, Runnable cancel) { + this.delegate = delegate; + this.cancel = cancel; + } + + /** Forwards the result of the delegate {@link PipelineResult#getState}. */ + @Override + public State getState() { + if (terminalState != null) { + return terminalState; + } + return delegate.getState(); + } + + /** + * Forwards the result of the delegate {@link PipelineResult#cancel}. Invokes {@link + * PrismExecutor#stop()} before returning the resulting {@link + * org.apache.beam.sdk.PipelineResult.State}. 
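+ *
+ * <p>After this call returns, {@link #getState()} and {@link #metrics()} serve the cached
+ * terminal values instead of querying the delegate.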
+ */ + @Override + public State cancel() throws IOException { + State state = delegate.cancel(); + this.terminalMetrics = delegate.metrics(); + this.terminalState = state; + this.cancel.run(); + return state; + } + + /** + * Forwards the result of the delegate {@link PipelineResult#waitUntilFinish(Duration)}. Invokes + * {@link PrismExecutor#stop()} before returning the resulting {@link + * org.apache.beam.sdk.PipelineResult.State}. + */ + @Override + public State waitUntilFinish(Duration duration) { + State state = delegate.waitUntilFinish(duration); + this.terminalMetrics = delegate.metrics(); + this.terminalState = state; + this.cancel.run(); + return state; + } + + /** + * Forwards the result of the delegate {@link PipelineResult#waitUntilFinish}. Invokes {@link + * PrismExecutor#stop()} before returning the resulting {@link + * org.apache.beam.sdk.PipelineResult.State}. + */ + @Override + public State waitUntilFinish() { + State state = delegate.waitUntilFinish(); + this.terminalMetrics = delegate.metrics(); + this.terminalState = state; + this.cancel.run(); + return state; + } + + /** Forwards the result of the delegate {@link PipelineResult#metrics}. */ + @Override + public MetricResults metrics() { + if (terminalMetrics != null) { + return terminalMetrics; + } + return delegate.metrics(); + } +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java new file mode 100644 index 000000000000..1ea4367292b0 --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import org.apache.beam.runners.portability.PortableRunner; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.PipelineRunner; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PortablePipelineOptions; +import org.apache.beam.sdk.util.construction.Environments; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link PipelineRunner} executed on Prism. Downloads, prepares, and executes the Prism service + * on behalf of the developer when {@link PipelineRunner#run}ning the pipeline. If users want to + * submit to an already running Prism service, use the {@link PortableRunner} with the {@link + * PortablePipelineOptions#getJobEndpoint()} option instead. Prism is a {@link + * org.apache.beam.runners.portability.PortableRunner} maintained at sdks/go/cmd/prism. 
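+ *
+ * <p>A minimal usage sketch (illustrative only; mirrors the configuration used in
+ * PrismRunnerTest and assumes a Prism binary is resolvable for the current platform):
+ *
+ * <pre>{@code
+ * PrismPipelineOptions options =
+ *     PipelineOptionsFactory.create().as(PrismPipelineOptions.class);
+ * options.setRunner(PrismRunner.class);
+ * Pipeline pipeline = Pipeline.create(options);
+ * pipeline.apply(Create.of(1, 2, 3));
+ * pipeline.run().waitUntilFinish();
+ * }</pre>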
+ */ +// TODO(https://github.com/apache/beam/issues/31793): add public modifier after finalizing +// PrismRunner. Depends on: https://github.com/apache/beam/issues/31402 and +// https://github.com/apache/beam/issues/31792. +class PrismRunner extends PipelineRunner { + + private static final Logger LOG = LoggerFactory.getLogger(PrismRunner.class); + + private static final String DEFAULT_PRISM_ENDPOINT = "localhost:8073"; + + private final PortableRunner internal; + private final PrismPipelineOptions prismPipelineOptions; + + private PrismRunner(PortableRunner internal, PrismPipelineOptions prismPipelineOptions) { + this.internal = internal; + this.prismPipelineOptions = prismPipelineOptions; + } + + /** + * Invoked from {@link Pipeline#run} where {@link PrismRunner} instantiates using {@link + * PrismPipelineOptions} configuration details. + */ + public static PrismRunner fromOptions(PipelineOptions options) { + PrismPipelineOptions prismPipelineOptions = options.as(PrismPipelineOptions.class); + assignDefaultsIfNeeded(prismPipelineOptions); + PortableRunner internal = PortableRunner.fromOptions(options); + return new PrismRunner(internal, prismPipelineOptions); + } + + @Override + public PipelineResult run(Pipeline pipeline) { + LOG.info( + "running Pipeline using {}: defaultEnvironmentType: {}, jobEndpoint: {}", + PortableRunner.class.getName(), + prismPipelineOptions.getDefaultEnvironmentType(), + prismPipelineOptions.getJobEndpoint()); + + return internal.run(pipeline); + } + + private static void assignDefaultsIfNeeded(PrismPipelineOptions prismPipelineOptions) { + if (Strings.isNullOrEmpty(prismPipelineOptions.getDefaultEnvironmentType())) { + prismPipelineOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_LOOPBACK); + } + if (Strings.isNullOrEmpty(prismPipelineOptions.getJobEndpoint())) { + prismPipelineOptions.setJobEndpoint(DEFAULT_PRISM_ENDPOINT); + } + } +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/package-info.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/package-info.java new file mode 100644 index 000000000000..2642f3e59951 --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Support for executing a pipeline on Prism. 
*/ +package org.apache.beam.runners.prism; diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java new file mode 100644 index 000000000000..315e585a0c5f --- /dev/null +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.runners.prism.PrismRunnerTest.getLocalPrismBuildOrIgnoreTest; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.Collections; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PrismExecutor}. 
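+ * These tests rely on the {@code prism.buildTarget} system property (wired up in
+ * build.gradle) and expect Prism to serve its JobManagement endpoint on localhost:8073
+ * unless overridden with {@code -job_port}.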
*/ +@RunWith(JUnit4.class) +public class PrismExecutorTest { + @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Rule public TestName testName = new TestName(); + + @Test + public void executeThenStop() throws IOException { + PrismExecutor executor = underTest().build(); + executor.execute(); + sleep(3000L); + executor.stop(); + } + + @Test + public void executeWithStreamRedirectThenStop() throws IOException { + PrismExecutor executor = underTest().build(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + executor.execute(outputStream); + sleep(3000L); + executor.stop(); + String output = outputStream.toString(StandardCharsets.UTF_8.name()); + assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + } + + @Test + public void executeWithFileOutputThenStop() throws IOException { + PrismExecutor executor = underTest().build(); + File log = temporaryFolder.newFile(testName.getMethodName()); + executor.execute(log); + sleep(3000L); + executor.stop(); + try (Stream stream = Files.lines(log.toPath(), StandardCharsets.UTF_8)) { + String output = stream.collect(Collectors.joining("\n")); + assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + } + } + + @Test + public void executeWithCustomArgumentsThenStop() throws IOException { + PrismExecutor executor = + underTest().setArguments(Collections.singletonList("-job_port=5555")).build(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + executor.execute(outputStream); + sleep(3000L); + executor.stop(); + String output = outputStream.toString(StandardCharsets.UTF_8.name()); + assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:5555"); + } + + private PrismExecutor.Builder underTest() { + return PrismExecutor.builder().setCommand(getLocalPrismBuildOrIgnoreTest()); + } + + private void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException ignored) { + } + } +} diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java new file mode 100644 index 000000000000..982a8bfd657c --- /dev/null +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.prism; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.runners.prism.PrismLocator.prismBinDirectory; +import static org.apache.beam.runners.prism.PrismRunnerTest.getLocalPrismBuildOrIgnoreTest; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PrismLocator}. */ +@RunWith(JUnit4.class) +public class PrismLocatorTest { + + private static final Path DESTINATION_DIRECTORY = prismBinDirectory(); + + @Before + public void setup() throws IOException { + if (Files.exists(DESTINATION_DIRECTORY)) { + Files.walkFileTree( + DESTINATION_DIRECTORY, + new SimpleFileVisitor() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) + throws IOException { + Files.delete(file); + return FileVisitResult.CONTINUE; + } + }); + + Files.delete(DESTINATION_DIRECTORY); + } + } + + @Test + public void givenVersionOverride_thenResolves() throws IOException { + assertThat(Files.exists(DESTINATION_DIRECTORY)).isFalse(); + PrismPipelineOptions options = options(); + options.setPrismVersionOverride("2.57.0"); + PrismLocator underTest = new PrismLocator(options); + String got = underTest.resolve(); + assertThat(got).contains(DESTINATION_DIRECTORY.toString()); + assertThat(got).contains("2.57.0"); + Path gotPath = Paths.get(got); + assertThat(Files.exists(gotPath)).isTrue(); + } + + @Test + public void givenHttpPrismLocationOption_thenResolves() throws IOException { + assertThat(Files.exists(DESTINATION_DIRECTORY)).isFalse(); + PrismPipelineOptions options = options(); + options.setPrismLocation( + "https://github.com/apache/beam/releases/download/v2.57.0/apache_beam-v2.57.0-prism-darwin-arm64.zip"); + PrismLocator underTest = new PrismLocator(options); + String got = underTest.resolve(); + assertThat(got).contains(DESTINATION_DIRECTORY.toString()); + Path gotPath = Paths.get(got); + assertThat(Files.exists(gotPath)).isTrue(); + } + + @Test + public void givenFilePrismLocationOption_thenResolves() throws IOException { + assertThat(Files.exists(DESTINATION_DIRECTORY)).isFalse(); + PrismPipelineOptions options = options(); + options.setPrismLocation(getLocalPrismBuildOrIgnoreTest()); + PrismLocator underTest = new PrismLocator(options); + String got = underTest.resolve(); + assertThat(got).contains(DESTINATION_DIRECTORY.toString()); + Path gotPath = Paths.get(got); + assertThat(Files.exists(gotPath)).isTrue(); + } + + @Test + public void givenGithubTagPrismLocationOption_thenThrows() { + PrismPipelineOptions options = options(); + options.setPrismLocation( + "https://github.com/apache/beam/releases/tag/v2.57.0/apache_beam-v2.57.0-prism-darwin-amd64.zip"); + PrismLocator underTest = new PrismLocator(options); + IllegalArgumentException error = + assertThrows(IllegalArgumentException.class, underTest::resolve); + assertThat(error.getMessage()) + .contains( + "Provided --prismLocation URL is not an Apache Beam Github Release page URL or download URL"); + } + + @Test + public void givenPrismLocation404_thenThrows() { + PrismPipelineOptions options = options(); + 
options.setPrismLocation("https://example.com/i/dont/exist.zip"); + PrismLocator underTest = new PrismLocator(options); + RuntimeException error = assertThrows(RuntimeException.class, underTest::resolve); + assertThat(error.getMessage()).contains("NotFoundException"); + } + + private static PrismPipelineOptions options() { + return PipelineOptionsFactory.create().as(PrismPipelineOptions.class); + } +} diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismPipelineResultTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismPipelineResultTest.java new file mode 100644 index 000000000000..2ad7e2eb3dd9 --- /dev/null +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismPipelineResultTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.runners.prism.PrismRunnerTest.getLocalPrismBuildOrIgnoreTest; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.metrics.MetricResults; +import org.joda.time.Duration; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PrismPipelineResult}. */ +@RunWith(JUnit4.class) +public class PrismPipelineResultTest { + + final PrismExecutor exec = executor(); + + @Before + public void setUp() throws IOException { + exec.execute(); + assertThat(exec.isAlive()).isTrue(); + } + + @After + public void tearDown() { + assertThat(exec.isAlive()).isFalse(); + } + + @Test + public void givenTerminated_reportsState() { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.waitUntilFinish()).thenReturn(PipelineResult.State.FAILED); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + // Assigns terminal state. 
+ underTest.waitUntilFinish(); + assertThat(underTest.getState()).isEqualTo(PipelineResult.State.FAILED); + } + + @Test + public void givenNotTerminated_reportsState() { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.getState()).thenReturn(PipelineResult.State.RUNNING); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + assertThat(underTest.getState()).isEqualTo(PipelineResult.State.RUNNING); + exec.stop(); + } + + @Test + public void cancelStopsExecutable_reportsTerminalState() throws IOException { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.cancel()).thenReturn(PipelineResult.State.CANCELLED); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + assertThat(underTest.cancel()).isEqualTo(PipelineResult.State.CANCELLED); + } + + @Test + public void givenTerminated_cancelIsNoop_reportsTerminalState() throws IOException { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.cancel()).thenReturn(PipelineResult.State.FAILED); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + assertThat(underTest.cancel()).isEqualTo(PipelineResult.State.FAILED); + } + + @Test + public void givenPipelineRunWithDuration_waitUntilFinish_reportsTerminalState() { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.waitUntilFinish(Duration.millis(3000L))) + .thenReturn(PipelineResult.State.CANCELLED); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + assertThat(underTest.waitUntilFinish(Duration.millis(3000L))) + .isEqualTo(PipelineResult.State.CANCELLED); + } + + @Test + public void givenTerminated_waitUntilFinishIsNoop_reportsTerminalState() { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.waitUntilFinish()).thenReturn(PipelineResult.State.DONE); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + // Terminate Job as setup for additional call. + underTest.waitUntilFinish(); + assertThat(underTest.waitUntilFinish()).isEqualTo(PipelineResult.State.DONE); + } + + @Test + public void givenNotTerminated_reportsMetrics() { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.metrics()).thenReturn(mock(MetricResults.class)); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + assertThat(underTest.metrics()).isNotNull(); + exec.stop(); + } + + @Test + public void givenTerminated_reportsTerminatedMetrics() { + PipelineResult delegate = mock(PipelineResult.class); + when(delegate.metrics()).thenReturn(mock(MetricResults.class)); + when(delegate.waitUntilFinish()).thenReturn(PipelineResult.State.DONE); + PrismPipelineResult underTest = new PrismPipelineResult(delegate, exec::stop); + // Terminate Job as setup for additional call. + underTest.waitUntilFinish(); + assertThat(underTest.metrics()).isNotNull(); + } + + private static PrismExecutor executor() { + return PrismExecutor.builder().setCommand(getLocalPrismBuildOrIgnoreTest()).build(); + } +} diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismRunnerTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismRunnerTest.java new file mode 100644 index 000000000000..2cacb671be3e --- /dev/null +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismRunnerTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assume.assumeTrue; + +import java.io.IOException; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PeriodicImpulse; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PrismRunner}. */ + +// TODO(https://github.com/apache/beam/issues/31793): Remove @Ignore after finalizing PrismRunner. +// Depends on: https://github.com/apache/beam/issues/31402 and +// https://github.com/apache/beam/issues/31792. +@Ignore +@RunWith(JUnit4.class) +public class PrismRunnerTest { + // See build.gradle for test task configuration. + private static final String PRISM_BUILD_TARGET_PROPERTY_NAME = "prism.buildTarget"; + + @Test + public void givenBoundedSource_runsUntilDone() { + Pipeline pipeline = Pipeline.create(options()); + pipeline.apply(Create.of(1, 2, 3)); + PipelineResult.State state = pipeline.run().waitUntilFinish(); + assertThat(state).isEqualTo(PipelineResult.State.DONE); + } + + @Test + public void givenUnboundedSource_runsUntilCancel() throws IOException { + Pipeline pipeline = Pipeline.create(options()); + pipeline.apply(PeriodicImpulse.create()); + PipelineResult result = pipeline.run(); + assertThat(result.getState()).isEqualTo(PipelineResult.State.RUNNING); + PipelineResult.State state = result.cancel(); + assertThat(state).isEqualTo(PipelineResult.State.CANCELLED); + } + + private static PrismPipelineOptions options() { + PrismPipelineOptions opts = PipelineOptionsFactory.create().as(PrismPipelineOptions.class); + + opts.setRunner(PrismRunner.class); + opts.setPrismLocation(getLocalPrismBuildOrIgnoreTest()); + + return opts; + } + + /** + * Drives ignoring of tests via checking {@link org.junit.Assume#assumeTrue} that the {@link + * System#getProperty} for {@link #PRISM_BUILD_TARGET_PROPERTY_NAME} is not null or empty. 
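+ *
+ * <p>Returns the local Prism build target (the {@code prism.buildTarget} system property),
+ * used as the {@link PrismExecutor} command or as a {@code --prismLocation} value in these
+ * tests.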
+ */ + static String getLocalPrismBuildOrIgnoreTest() { + String command = System.getProperty(PRISM_BUILD_TARGET_PROPERTY_NAME); + assumeTrue( + "System property: " + + PRISM_BUILD_TARGET_PROPERTY_NAME + + " is not set; see build.gradle for test task configuration", + !Strings.isNullOrEmpty(command)); + return command; + } +} diff --git a/runners/samza/build.gradle b/runners/samza/build.gradle index a50e0d62e59a..fc590172f11c 100644 --- a/runners/samza/build.gradle +++ b/runners/samza/build.gradle @@ -111,6 +111,8 @@ def sickbayTests = [ 'org.apache.beam.sdk.coders.PCollectionCustomCoderTest.testDecodingIOException', // https://github.com/apache/beam/issues/19344 'org.apache.beam.sdk.io.BoundedReadFromUnboundedSourceTest.testTimeBound', + // https://github.com/apache/beam/issues/31725 + 'org.apache.beam.sdk.io.TextIOWriteTest.testWriteUnboundedWithCustomBatchParameters', ] tasks.register("validatesRunner", Test) { group = "Verification" diff --git a/runners/samza/job-server/build.gradle b/runners/samza/job-server/build.gradle index f972f376e5c8..4be206727121 100644 --- a/runners/samza/job-server/build.gradle +++ b/runners/samza/job-server/build.gradle @@ -90,6 +90,7 @@ def portableValidatesRunnerTask(String name, boolean docker) { excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesStringSetMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' diff --git a/runners/spark/job-server/spark_job_server.gradle b/runners/spark/job-server/spark_job_server.gradle index 6d2d4b2bafbf..bd00c8cf52ac 100644 --- a/runners/spark/job-server/spark_job_server.gradle +++ b/runners/spark/job-server/spark_job_server.gradle @@ -117,6 +117,7 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean docker, excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesStringSetMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesPerKeyOrderedDelivery' excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' @@ -185,6 +186,7 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean docker, excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' + excludeCategories 'org.apache.beam.sdk.testing.UsesStringSetMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesPerKeyOrderedDelivery' excludeCategories 'org.apache.beam.sdk.testing.UsesPerKeyOrderInBundle' excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' diff --git a/scripts/ci/pr-bot/package-lock.json b/scripts/ci/pr-bot/package-lock.json index 336a8d45677d..7cb764a43795 100644 --- a/scripts/ci/pr-bot/package-lock.json +++ b/scripts/ci/pr-bot/package-lock.json @@ -273,12 +273,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": 
"sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -469,9 +469,9 @@ } }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "dependencies": { "to-regex-range": "^5.0.1" @@ -1421,12 +1421,12 @@ } }, "braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "requires": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" } }, "browser-stdout": { @@ -1563,9 +1563,9 @@ "dev": true }, "fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "requires": { "to-regex-range": "^5.0.1" diff --git a/scripts/ci/pr-bot/processPrUpdate.ts b/scripts/ci/pr-bot/processPrUpdate.ts index f9aa15713216..d38fa452a4b6 100644 --- a/scripts/ci/pr-bot/processPrUpdate.ts +++ b/scripts/ci/pr-bot/processPrUpdate.ts @@ -83,6 +83,18 @@ async function processPrComment( stateClient, reviewerConfig ); + + // Check to see if notifications have been stopped before processing further. + // Notifications can be stopped by an "R: reviewer" comment, + // and then restarted by adding "assign set of reviewers" comment. 
+ if ( + (await stateClient.getPrState(getPullNumberFromPayload(payload))) + .stopReviewerNotifications + ) { + console.log("Notifications have been paused for this pull - skipping"); + return; + } + // If there's been a comment by a non-author, we can remove the slow review label if (commentAuthor !== pullAuthor && commentAuthor !== BOT_NAME) { await removeSlowReviewLabel(payload); @@ -140,11 +152,6 @@ async function processPrUpdate() { const pullNumber = getPullNumberFromPayload(payload); const stateClient = new PersistentState(); - const prState = await stateClient.getPrState(pullNumber); - if (prState.stopReviewerNotifications) { - console.log("Notifications have been paused for this pull - skipping"); - return; - } switch (github.context.eventName) { case "issue_comment": @@ -156,6 +163,12 @@ async function processPrUpdate() { await processPrComment(payload, stateClient, reviewerConfig); break; case "pull_request_target": + if ( + (await stateClient.getPrState(pullNumber)).stopReviewerNotifications + ) { + console.log("Notifications have been paused for this pull - skipping"); + return; + } if (payload.action === "synchronize") { console.log("Processing synchronize action"); await setNextActionReviewers(payload, stateClient); diff --git a/scripts/ci/pr-bot/shared/commentStrings.ts b/scripts/ci/pr-bot/shared/commentStrings.ts index 138a494feb27..d1b366bcf773 100644 --- a/scripts/ci/pr-bot/shared/commentStrings.ts +++ b/scripts/ci/pr-bot/shared/commentStrings.ts @@ -58,7 +58,7 @@ export function someChecksFailing(reviewersToNotify: string[]): string { } export function stopNotifications(reason: string): string { - return `Stopping reviewer notifications for this pull request: ${reason}`; + return `Stopping reviewer notifications for this pull request: ${reason}. 
If you'd like to restart, comment \`assign set of reviewers\``; } export function remindReviewerAfterTestsPass(requester: string): string { diff --git a/scripts/ci/pr-bot/shared/userCommand.ts b/scripts/ci/pr-bot/shared/userCommand.ts index e32746eb7fce..6980468c3b19 100644 --- a/scripts/ci/pr-bot/shared/userCommand.ts +++ b/scripts/ci/pr-bot/shared/userCommand.ts @@ -39,32 +39,42 @@ export async function processCommand( const pullNumber = payload.issue?.number || payload.pull_request?.number; commentText = commentText.toLowerCase(); - if (commentText.indexOf("r: @") > -1) { - await manuallyAssignedToReviewer(pullNumber, stateClient); - } else if (commentText.indexOf("assign to next reviewer") > -1) { - await assignToNextReviewer( - payload, - commentAuthor, - pullNumber, - stateClient, - reviewerConfig - ); - } else if (commentText.indexOf("stop reviewer notifications") > -1) { - await stopReviewerNotifications( - pullNumber, - stateClient, - "requested by reviewer" - ); - } else if (commentText.indexOf("remind me after tests pass") > -1) { - await remindAfterTestsPass(pullNumber, commentAuthor, stateClient); - } else if (commentText.indexOf("waiting on author") > -1) { - await waitOnAuthor(payload, pullNumber, stateClient); - } else if (commentText.indexOf("assign set of reviewers") > -1) { - await assignReviewerSet(payload, pullNumber, stateClient, reviewerConfig); + + let prState = await stateClient.getPrState(pullNumber); + if(prState.stopReviewerNotifications) { + // Notifications stopped, only "allow assign set of reviewers" + if (commentText.indexOf("assign set of reviewers") > -1) { + await assignReviewerSet(payload, pullNumber, stateClient, reviewerConfig); + } else { + return false; + } } else { - return false; + if (commentText.indexOf("r: @") > -1) { + await manuallyAssignedToReviewer(pullNumber, stateClient); + } else if (commentText.indexOf("assign to next reviewer") > -1) { + await assignToNextReviewer( + payload, + commentAuthor, + pullNumber, + stateClient, + reviewerConfig + ); + } else if (commentText.indexOf("stop reviewer notifications") > -1) { + await stopReviewerNotifications( + pullNumber, + stateClient, + "requested by reviewer" + ); + } else if (commentText.indexOf("remind me after tests pass") > -1) { + await remindAfterTestsPass(pullNumber, commentAuthor, stateClient); + } else if (commentText.indexOf("waiting on author") > -1) { + await waitOnAuthor(payload, pullNumber, stateClient); + } else if (commentText.indexOf("assign set of reviewers") > -1) { + await assignReviewerSet(payload, pullNumber, stateClient, reviewerConfig); + } else { + return false; + } } - return true; } @@ -175,6 +185,12 @@ async function assignReviewerSet( reviewerConfig: typeof ReviewerConfig ) { let prState = await stateClient.getPrState(pullNumber); + if(prState.stopReviewerNotifications) { + // Restore notifications, and clear any existing reviewer set to + // allow new reviewers to be assigned. 
+ prState.stopReviewerNotifications = false; + prState.reviewersAssignedForLabels = {}; + } if (Object.values(prState.reviewersAssignedForLabels).length > 0) { await github.addPrComment( pullNumber, diff --git a/sdks/go.mod b/sdks/go.mod index b156c8beac2e..6d42e02296c7 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -27,9 +27,9 @@ require ( cloud.google.com/go/bigtable v1.25.0 cloud.google.com/go/datastore v1.17.1 cloud.google.com/go/profiler v0.4.0 - cloud.google.com/go/pubsub v1.38.0 + cloud.google.com/go/pubsub v1.39.0 cloud.google.com/go/spanner v1.63.0 - cloud.google.com/go/storage v1.41.0 + cloud.google.com/go/storage v1.43.0 github.com/aws/aws-sdk-go-v2 v1.30.0 github.com/aws/aws-sdk-go-v2/config v1.27.4 github.com/aws/aws-sdk-go-v2/credentials v1.17.18 @@ -45,8 +45,8 @@ require ( github.com/johannesboyne/gofakes3 v0.0.0-20221110173912-32fb85c5aed6 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 - github.com/nats-io/nats-server/v2 v2.10.12 - github.com/nats-io/nats.go v1.33.1 + github.com/nats-io/nats-server/v2 v2.10.16 + github.com/nats-io/nats.go v1.35.0 github.com/proullon/ramsql v0.1.3 github.com/spf13/cobra v1.8.1 github.com/testcontainers/testcontainers-go v0.26.0 @@ -59,9 +59,9 @@ require ( golang.org/x/sync v0.7.0 golang.org/x/sys v0.21.0 golang.org/x/text v0.16.0 - google.golang.org/api v0.184.0 - google.golang.org/genproto v0.0.0-20240604185151-ef581f913117 - google.golang.org/grpc v1.64.0 + google.golang.org/api v0.187.0 + google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d + google.golang.org/grpc v1.64.1 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 @@ -74,7 +74,7 @@ require ( ) require ( - cloud.google.com/go/auth v0.5.1 // indirect + cloud.google.com/go/auth v0.6.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect @@ -89,7 +89,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/minio/highwayhash v1.0.2 // indirect github.com/moby/sys/user v0.1.0 // indirect - github.com/nats-io/jwt/v2 v2.5.5 // indirect + github.com/nats-io/jwt/v2 v2.5.7 // indirect github.com/nats-io/nkeys v0.4.7 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect @@ -151,12 +151,12 @@ require ( github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.4 // indirect + github.com/googleapis/gax-go/v2 v2.12.5 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/klauspost/compress v1.17.7 // indirect + github.com/klauspost/compress v1.17.8 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -183,6 +183,6 @@ require ( golang.org/x/mod v0.17.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240610135401-a8a62080eff3 // indirect + 
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect ) diff --git a/sdks/go.sum b/sdks/go.sum index 2b2862bb2f0b..098f858488b7 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -99,8 +99,8 @@ cloud.google.com/go/assuredworkloads v1.7.0/go.mod h1:z/736/oNmtGAyU47reJgGN+KVo cloud.google.com/go/assuredworkloads v1.8.0/go.mod h1:AsX2cqyNCOvEQC8RMPnoc0yEarXQk6WEKkxYfL6kGIo= cloud.google.com/go/assuredworkloads v1.9.0/go.mod h1:kFuI1P78bplYtT77Tb1hi0FMxM0vVpRC7VVoJC3ZoT0= cloud.google.com/go/assuredworkloads v1.10.0/go.mod h1:kwdUQuXcedVdsIaKgKTp9t0UJkE5+PAVNhdQm4ZVq2E= -cloud.google.com/go/auth v0.5.1 h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw= -cloud.google.com/go/auth v0.5.1/go.mod h1:vbZT8GjzDf3AVqCcQmqeeM32U9HBFc32vVVAbwDsa6s= +cloud.google.com/go/auth v0.6.1 h1:T0Zw1XM5c1GlpN2HYr2s+m3vr1p2wy+8VN+Z1FKxW38= +cloud.google.com/go/auth v0.6.1/go.mod h1:eFHG7zDzbXHKmjJddFG/rBlcGp6t25SwRUiEQSlO4x4= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= cloud.google.com/go/automl v1.5.0/go.mod h1:34EjfoFGMZ5sgJ9EoLsRtdPSNZLcfflJR39VbVNS2M0= @@ -345,7 +345,7 @@ cloud.google.com/go/kms v1.8.0/go.mod h1:4xFEhYFqvW+4VMELtZyxomGSYtSQKzM178ylFW4 cloud.google.com/go/kms v1.9.0/go.mod h1:qb1tPTgfF9RQP8e1wq4cLFErVuTJv7UsSC915J8dh3w= cloud.google.com/go/kms v1.10.0/go.mod h1:ng3KTUtQQU9bPX3+QGLsflZIHlkbn8amFAMY63m8d24= cloud.google.com/go/kms v1.10.1/go.mod h1:rIWk/TryCkR59GMC3YtHtXeLzd634lBbKenvyySAyYI= -cloud.google.com/go/kms v1.17.1 h1:5k0wXqkxL+YcXd4viQzTqCgzzVKKxzgrK+rCZJytEQs= +cloud.google.com/go/kms v1.18.0 h1:pqNdaVmZJFP+i8OVLocjfpdTWETTYa20FWOegSCdrRo= cloud.google.com/go/language v1.4.0/go.mod h1:F9dRpNFQmJbkaop6g0JhSBXCNlO90e1KWx5iDdxbWic= cloud.google.com/go/language v1.6.0/go.mod h1:6dJ8t3B+lUYfStgls25GusK04NLh3eDLQnWM3mdEbhI= cloud.google.com/go/language v1.7.0/go.mod h1:DJ6dYN/W+SQOjF8e1hLQXMF21AkH2w9wiPzPCJa2MIE= @@ -443,8 +443,8 @@ cloud.google.com/go/pubsub v1.26.0/go.mod h1:QgBH3U/jdJy/ftjPhTkyXNj543Tin1pRYcd cloud.google.com/go/pubsub v1.27.1/go.mod h1:hQN39ymbV9geqBnfQq6Xf63yNhUAhv9CZhzp5O6qsW0= cloud.google.com/go/pubsub v1.28.0/go.mod h1:vuXFpwaVoIPQMGXqRyUQigu/AX1S3IWugR9xznmcXX8= cloud.google.com/go/pubsub v1.30.0/go.mod h1:qWi1OPS0B+b5L+Sg6Gmc9zD1Y+HaM0MdUr7LsupY1P4= -cloud.google.com/go/pubsub v1.38.0 h1:J1OT7h51ifATIedjqk/uBNPh+1hkvUaH4VKbz4UuAsc= -cloud.google.com/go/pubsub v1.38.0/go.mod h1:IPMJSWSus/cu57UyR01Jqa/bNOQA+XnPF6Z4dKW4fAA= +cloud.google.com/go/pubsub v1.39.0 h1:qt1+S6H+wwW8Q/YvDwM8lJnq+iIFgFEgaD/7h3lMsAI= +cloud.google.com/go/pubsub v1.39.0/go.mod h1:FrEnrSGU6L0Kh3iBaAbIUM8KMR7LqyEkMboVxGXCT+s= cloud.google.com/go/pubsublite v1.5.0/go.mod h1:xapqNQ1CuLfGi23Yda/9l4bBCKz/wC3KIJ5gKcxveZg= cloud.google.com/go/pubsublite v1.6.0/go.mod h1:1eFCS0U11xlOuMFV/0iBqw3zP12kddMeCbj/F3FSj9k= cloud.google.com/go/pubsublite v1.7.0/go.mod h1:8hVMwRXfDfvGm3fahVbtDbiLePT3gpoiJYJY+vxWxVM= @@ -553,8 +553,8 @@ cloud.google.com/go/storage v1.23.0/go.mod h1:vOEEDNFnciUMhBeT6hsJIn3ieU5cFRmzeL cloud.google.com/go/storage v1.27.0/go.mod h1:x9DOL8TK/ygDUMieqwfhdpQryTeEkhGKMi80i/iqR2s= cloud.google.com/go/storage v1.28.1/go.mod h1:Qnisd4CqDdo6BGs2AD5LLnEsmSQ80wQ5ogcBBKhU86Y= cloud.google.com/go/storage v1.29.0/go.mod h1:4puEjyTKnku6gfKoTfNOU/W+a9JyuVNxjpS5GBrB8h4= -cloud.google.com/go/storage v1.41.0 
h1:RusiwatSu6lHeEXe3kglxakAmAbfV+rhtPqA6i8RBx0= -cloud.google.com/go/storage v1.41.0/go.mod h1:J1WCa/Z2FcgdEDuPUY8DxT5I+d9mFKsCepp5vR6Sq80= +cloud.google.com/go/storage v1.43.0 h1:CcxnSohZwizt4LCzQHWvBf1/kvtHUn7gk9QERXPyXFs= +cloud.google.com/go/storage v1.43.0/go.mod h1:ajvxEa7WmZS1PxvKRq4bq0tFT3vMd502JwstCcYv0Q0= cloud.google.com/go/storagetransfer v1.5.0/go.mod h1:dxNzUopWy7RQevYFHewchb29POFv3/AaBgnhqzqiK0w= cloud.google.com/go/storagetransfer v1.6.0/go.mod h1:y77xm4CQV/ZhFZH75PLEXY0ROiS7Gh6pSKrM8dJyg6I= cloud.google.com/go/storagetransfer v1.7.0/go.mod h1:8Giuj1QNb1kfLAiWM1bN6dHzfdlDAVC9rv9abHot2W4= @@ -954,8 +954,8 @@ github.com/googleapis/gax-go/v2 v2.5.1/go.mod h1:h6B0KMMFNtI2ddbGJn3T3ZbwkeT6yqE github.com/googleapis/gax-go/v2 v2.6.0/go.mod h1:1mjbznJAPHFpesgE5ucqfYEscaz5kMdcIDwU/6+DDoY= github.com/googleapis/gax-go/v2 v2.7.0/go.mod h1:TEop28CZZQ2y+c0VxMUmu1lV+fQx57QpBWsYpwqHJx8= github.com/googleapis/gax-go/v2 v2.7.1/go.mod h1:4orTrqY6hXxxaUL4LHIPl6lGo8vAE38/qKbhSAKP6QI= -github.com/googleapis/gax-go/v2 v2.12.4 h1:9gWcmF85Wvq4ryPFvGFaOgPIs1AQX0d0bcbGw4Z96qg= -github.com/googleapis/gax-go/v2 v2.12.4/go.mod h1:KYEYLorsnIGDi/rPC8b5TdlB9kbKoFubselGIoBMCwI= +github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= +github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= @@ -1002,8 +1002,8 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= +github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= @@ -1052,12 +1052,12 @@ github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6f github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= -github.com/nats-io/jwt/v2 v2.5.5 h1:ROfXb50elFq5c9+1ztaUbdlrArNFl2+fQWP6B8HGEq4= -github.com/nats-io/jwt/v2 v2.5.5/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= -github.com/nats-io/nats-server/v2 v2.10.12 h1:G6u+RDrHkw4bkwn7I911O5jqys7jJVRY6MwgndyUsnE= -github.com/nats-io/nats-server/v2 v2.10.12/go.mod h1:H1n6zXtYLFCgXcf/SF8QNTSIFuS8tyZQMN9NguUHdEs= -github.com/nats-io/nats.go v1.33.1 h1:8TxLZZ/seeEfR97qV0/Bl939tpDnt2Z2fK3HkPypj70= 
-github.com/nats-io/nats.go v1.33.1/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= +github.com/nats-io/jwt/v2 v2.5.7 h1:j5lH1fUXCnJnY8SsQeB/a/z9Azgu2bYIDvtPVNdxe2c= +github.com/nats-io/jwt/v2 v2.5.7/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= +github.com/nats-io/nats-server/v2 v2.10.16 h1:2jXaiydp5oB/nAx/Ytf9fdCi9QN6ItIc9eehX8kwVV0= +github.com/nats-io/nats-server/v2 v2.10.16/go.mod h1:Pksi38H2+6xLe1vQx0/EA4bzetM0NqyIHcIbmgXSkIU= +github.com/nats-io/nats.go v1.35.0 h1:XFNqNM7v5B+MQMKqVGAyHwYhyKb48jrenXNxIU20ULk= +github.com/nats-io/nats.go v1.35.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -1661,8 +1661,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.184.0 h1:dmEdk6ZkJNXy1JcDhn/ou0ZUq7n9zropG2/tR4z+RDg= -google.golang.org/api v0.184.0/go.mod h1:CeDTtUEiYENAf8PPG5VZW2yNp2VM3VWbCeTioAZBTBA= +google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo= +google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1802,12 +1802,12 @@ google.golang.org/genproto v0.0.0-20230323212658-478b75c54725/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230330154414-c0448cd141ea/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/genproto v0.0.0-20240604185151-ef581f913117 h1:HCZ6DlkKtCDAtD8ForECsY3tKuaR+p4R3grlK80uCCc= -google.golang.org/genproto v0.0.0-20240604185151-ef581f913117/go.mod h1:lesfX/+9iA+3OdqeCpoDddJaNxVB1AB6tD7EfqMmprc= -google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 h1:+rdxYoE3E5htTEWIe15GlN6IfvbURM//Jt0mmkmm6ZU= -google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117/go.mod h1:OimBR/bc1wPO9iV4NC2bpyjy3VnAwZh5EBPQdtaE5oo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240610135401-a8a62080eff3 h1:9Xyg6I9IWQZhRVfCWjKK+l6kI0jHcPesVlMnT//aHNo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240610135401-a8a62080eff3/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= +google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d h1:PksQg4dV6Sem3/HkBX+Ltq8T0ke0PKIRBNBatoDTVls= +google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:s7iA721uChleev562UJO2OYB0PPT9CMFjV+Ce7VJH5M= +google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc= +google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c= 
+google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1849,8 +1849,8 @@ google.golang.org/grpc v1.52.3/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5v google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw= google.golang.org/grpc v1.54.0/go.mod h1:PUSEXI6iWghWaB6lXM4knEgpJNu2qUcKfDtNci3EC2g= google.golang.org/grpc v1.56.3/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= -google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= diff --git a/sdks/go/pkg/beam/core/core.go b/sdks/go/pkg/beam/core/core.go index 284db342eef3..2b449534409f 100644 --- a/sdks/go/pkg/beam/core/core.go +++ b/sdks/go/pkg/beam/core/core.go @@ -27,7 +27,7 @@ const ( // SdkName is the human readable name of the SDK for UserAgents. SdkName = "Apache Beam SDK for Go" // SdkVersion is the current version of the SDK. - SdkVersion = "2.58.0.dev" + SdkVersion = "2.59.0.dev" // DefaultDockerImage represents the associated image for this release. DefaultDockerImage = "apache/beam_go_sdk:" + SdkVersion diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go index a205c768731b..a1eeeba02c4b 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -291,16 +291,32 @@ func windowingStrategy(comps *pipepb.Components, tid string) *pipepb.WindowingSt // gbkBytes re-encodes gbk inputs in a gbk result. func gbkBytes(ws *pipepb.WindowingStrategy, wc, kc, vc *pipepb.Coder, toAggregate [][]byte, coders map[string]*pipepb.Coder, watermark mtime.Time) []byte { - var outputTime func(typex.Window, mtime.Time) mtime.Time + // Pick how the timestamp of the aggregated output is computed. + var outputTime func(typex.Window, mtime.Time, mtime.Time) mtime.Time switch ws.GetOutputTime() { case pipepb.OutputTime_END_OF_WINDOW: - outputTime = func(w typex.Window, et mtime.Time) mtime.Time { + outputTime = func(w typex.Window, _, _ mtime.Time) mtime.Time { return w.MaxTimestamp() } + case pipepb.OutputTime_EARLIEST_IN_PANE: + outputTime = func(_ typex.Window, cur, et mtime.Time) mtime.Time { + if et < cur { + return et + } + return cur + } + case pipepb.OutputTime_LATEST_IN_PANE: + outputTime = func(_ typex.Window, cur, et mtime.Time) mtime.Time { + if et > cur { + return et + } + return cur + } default: // TODO need to correct session logic if output time is different. 
panic(fmt.Sprintf("unsupported OutputTime behavior: %v", ws.GetOutputTime())) } + wDec, wEnc := makeWindowCoders(wc) type keyTime struct { @@ -336,14 +352,18 @@ func gbkBytes(ws *pipepb.WindowingStrategy, wc, kc, vc *pipepb.Coder, toAggregat key := string(keyByt) value := vd(buf) for _, w := range ws { - ft := outputTime(w, tm) wk, ok := windows[w] if !ok { wk = make(map[string]keyTime) windows[w] = wk } - kt := wk[key] - kt.time = ft + kt, ok := wk[key] + if !ok { + // If the window+key map doesn't have a value, inititialize time with the element time. + // This allows earliest or latest to work properly in the outputTime function's first use. + kt.time = tm + } + kt.time = outputTime(w, kt.time, tm) kt.key = keyByt kt.w = w kt.values = append(kt.values, value) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index 737a1b22276a..3efe48e23119 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -22,6 +22,7 @@ import ( "sync" "sync/atomic" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" @@ -195,7 +196,8 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo // Inspect Windowing strategies for unsupported features. for wsID, ws := range job.Pipeline.GetComponents().GetWindowingStrategies() { - check("WindowingStrategy.AllowedLateness", ws.GetAllowedLateness(), int64(0)) + check("WindowingStrategy.AllowedLateness", ws.GetAllowedLateness(), int64(0), mtime.MaxTimestamp.Milliseconds()) + // Both Closing behaviors are identical without additional trigger firings. check("WindowingStrategy.ClosingBehaviour", ws.GetClosingBehavior(), pipepb.ClosingBehavior_EMIT_IF_NONEMPTY, pipepb.ClosingBehavior_EMIT_ALWAYS) check("WindowingStrategy.AccumulationMode", ws.GetAccumulationMode(), pipepb.AccumulationMode_DISCARDING) diff --git a/sdks/go/run_with_go_version.sh b/sdks/go/run_with_go_version.sh index 7de5f339011d..2e542821969c 100755 --- a/sdks/go/run_with_go_version.sh +++ b/sdks/go/run_with_go_version.sh @@ -37,7 +37,7 @@ set -e # # This variable is also used as the execution command downscript. # The list of downloadable versions are at https://go.dev/dl/ -GOVERS=go1.22.4 +GOVERS=go1.22.5 if ! command -v go &> /dev/null then diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Lineage.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Lineage.java new file mode 100644 index 000000000000..7890a9f74b94 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Lineage.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.metrics; + +/** + * Standard collection of metrics used to record source and sinks information for lineage tracking. + */ +public class Lineage { + + public static final String LINEAGE_NAMESPACE = "lineage"; + public static final String SOURCE_METRIC_NAME = "sources"; + public static final String SINK_METRIC_NAME = "sinks"; + + private static final StringSet SOURCES = Metrics.stringSet(LINEAGE_NAMESPACE, SOURCE_METRIC_NAME); + private static final StringSet SINKS = Metrics.stringSet(LINEAGE_NAMESPACE, SINK_METRIC_NAME); + + /** {@link StringSet} representing sources and optionally side inputs. */ + public static StringSet getSources() { + return SOURCES; + } + + /** {@link StringSet} representing sinks. */ + public static StringSet getSinks() { + return SINKS; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricQueryResults.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricQueryResults.java index 86b1c1092824..9f60ce3d6c07 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricQueryResults.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricQueryResults.java @@ -21,9 +21,7 @@ import java.util.List; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -/** - * The results of a query for metrics. Allows accessing all of the metrics that matched the filter. - */ +/** The results of a query for metrics. Allows accessing all the metrics that matched the filter. */ @AutoValue public abstract class MetricQueryResults { /** Return the metric results for the counters that matched the filter. */ @@ -35,6 +33,9 @@ public abstract class MetricQueryResults { /** Return the metric results for the gauges that matched the filter. */ public abstract Iterable> getGauges(); + /** Return the metric results for the sets that matched the filter. 
*/ + public abstract Iterable> getStringSets(); + static void printMetrics(String type, Iterable> metrics, StringBuilder sb) { List> metricsList = ImmutableList.copyOf(metrics); if (!metricsList.isEmpty()) { @@ -63,6 +64,7 @@ public final String toString() { printMetrics("Counters", getCounters(), sb); printMetrics("Distributions", getDistributions(), sb); printMetrics("Gauges", getGauges(), sb); + printMetrics("StringSets", getStringSets(), sb); sb.append(")"); return sb.toString(); } @@ -70,7 +72,8 @@ public final String toString() { public static MetricQueryResults create( Iterable> counters, Iterable> distributions, - Iterable> gauges) { - return new AutoValue_MetricQueryResults(counters, distributions, gauges); + Iterable> gauges, + Iterable> stringSets) { + return new AutoValue_MetricQueryResults(counters, distributions, gauges, stringSets); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricResult.java index 25f4d8d9e626..b9cbc8d755ee 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricResult.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricResult.java @@ -37,7 +37,7 @@ public abstract class MetricResult { /** Return the name of the metric. */ public MetricName getName() { return getKey().metricName(); - }; + } public abstract MetricKey getKey(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java index 056141284655..a963015e98a7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java @@ -93,6 +93,23 @@ public static Gauge gauge(Class namespace, String name) { return new DelegatingGauge(MetricName.named(namespace, name)); } + /** Create a metric that accumulates and reports set of unique string values. */ + public static StringSet stringSet(String namespace, String name) { + return new DelegatingStringSet(MetricName.named(namespace, name)); + } + + /** Create a metric that accumulates and reports set of unique string values. */ + public static StringSet stringSet(Class namespace, String name) { + return new DelegatingStringSet(MetricName.named(namespace, name)); + } + + /* + * A dedicated namespace for client throttling time. User DoFn can increment this metrics and then + * runner will put back pressure on scaling decision, if supported. + */ + public static final String THROTTLE_TIME_NAMESPACE = "beam-throttling-metrics"; + public static final String THROTTLE_TIME_COUNTER_NAME = "throttling-msecs"; + /** * Implementation of {@link Distribution} that delegates to the instance for the current context. */ @@ -146,4 +163,34 @@ public MetricName getName() { return name; } } + + /** Implementation of {@link StringSet} that delegates to the instance for the current context. */ + private static class DelegatingStringSet implements Metric, StringSet, Serializable { + private final MetricName name; + + private DelegatingStringSet(MetricName name) { + this.name = name; + } + + @Override + public void add(String value) { + MetricsContainer container = MetricsEnvironment.getCurrentContainer(); + if (container != null) { + container.getStringSet(name).add(value); + } + } + + @Override + public void add(String... 
value) { + MetricsContainer container = MetricsEnvironment.getCurrentContainer(); + if (container != null) { + container.getStringSet(name).add(value); + } + } + + @Override + public MetricName getName() { + return name; + } + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsContainer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsContainer.java index f48b9195c37c..0c4766bb2c0b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsContainer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsContainer.java @@ -53,6 +53,12 @@ default Counter getPerWorkerCounter(MetricName metricName) { */ Gauge getGauge(MetricName metricName); + /** + * Return the {@link StringSet} that should be used for implementing the given {@code metricName} + * in this container. + */ + StringSet getStringSet(MetricName metricName); + /** * Return the {@link Histogram} that should be used for implementing the given {@code metricName} * in this container. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java index 7f8f2a436433..3421bb4afc85 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/MetricsEnvironment.java @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.StringUtils; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; @@ -134,10 +135,14 @@ public void close() throws IOException { if (container == null && REPORTED_MISSING_CONTAINER.compareAndSet(false, true)) { if (isMetricsSupported()) { LOG.error( - "Unable to update metrics on the current thread. " - + "Most likely caused by using metrics outside the managed work-execution thread."); + "Unable to update metrics on the current thread. Most likely caused by using metrics " + + "outside the managed work-execution thread:\n {}", + StringUtils.arrayToNewlines(Thread.currentThread().getStackTrace(), 10)); } else { - LOG.warn("Reporting metrics are not supported in the current execution environment."); + // rate limiting this log as it can be emitted each time metrics incremented + LOG.warn( + "Reporting metrics are not supported in the current execution environment:\n {}", + StringUtils.arrayToNewlines(Thread.currentThread().getStackTrace(), 10)); } } return container; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/StringSet.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/StringSet.java new file mode 100644 index 000000000000..42e8f2388e38 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/StringSet.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.metrics; + +/** + * A metric that reports set of unique string values. This metric is backed by {@link + * java.util.HashSet} and hence it does not maintain any ordering. + */ +public interface StringSet extends Metric { + + /** Add a value to this set. */ + void add(String value); + + /** Add values to this set. */ + default void add(String... values) { + for (String value : values) { + add(value); + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/StringSetResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/StringSetResult.java new file mode 100644 index 000000000000..f2ad6292a5aa --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/StringSetResult.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.metrics; + +import com.google.auto.value.AutoValue; +import java.util.Set; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; + +/** + * The result of a {@link StringSet} metric. The {@link StringSetResult} hold an immutable copy of + * the set from which it was initially created representing that a result cannot be modified once + * created. + */ +@AutoValue +public abstract class StringSetResult { + public abstract Set getStringSet(); + + /** + * Creates a {@link StringSetResult} from the given {@link Set} by making an immutable copy. + * + * @param s the set from which the {@link StringSetResult} should be created. + * @return {@link StringSetResult} containing an immutable copy of the given set. + */ + public static StringSetResult create(Set s) { + return new AutoValue_StringSetResult(ImmutableSet.copyOf(s)); + } + + /** @return a {@link EmptyStringSetResult} */ + public static StringSetResult empty() { + return EmptyStringSetResult.INSTANCE; + } + + /** Empty {@link StringSetResult}, representing no values reported and is immutable. */ + public static class EmptyStringSetResult extends StringSetResult { + + private static final EmptyStringSetResult INSTANCE = new EmptyStringSetResult(); + + private EmptyStringSetResult() {} + + /** Returns an empty immutable set. 
*/ + @Override + public Set getStringSet() { + return ImmutableSet.of(); + } + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java index aa7b2630cce2..7a102747b9f7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java @@ -31,6 +31,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.PipelineRunner; @@ -1610,8 +1611,18 @@ private SingletonCheckerDoFn( @ProcessElement public void processElement(ProcessContext c) { - ActualT actualContents = Iterables.getOnlyElement(c.element()); - c.output(doChecks(site, actualContents, checkerFn)); + try { + ActualT actualContents = Iterables.getOnlyElement(c.element()); + c.output(doChecks(site, actualContents, checkerFn)); + } catch (NoSuchElementException e) { + c.output( + SuccessOrFailure.failure( + site, + new IllegalArgumentException( + "expected singleton PCollection but was: empty PCollection", e))); + } catch (IllegalArgumentException e) { + c.output(SuccessOrFailure.failure(site, e)); + } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SerializableMatchers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SerializableMatchers.java index 749d95960263..ad3506045995 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SerializableMatchers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SerializableMatchers.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.SerializableSupplier; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.KV; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; @@ -734,17 +735,6 @@ public static SerializableMatcher fromSupplier(SerializableSupplier(supplier); } - /** - * Supplies values of type {@code T}, and is serializable. Thus, even if {@code T} is not - * serializable, the supplier can be serialized and provide a {@code T} wherever it is - * deserialized. - * - * @param the type of value supplied. - */ - public interface SerializableSupplier extends Serializable { - T get(); - } - /** * Since the delegate {@link Matcher} is not generally serializable, instead this takes a nullary * SerializableFunction to return such a matcher. 
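Taken together, the StringSet additions above (Metrics.stringSet, StringSetResult, MetricQueryResults.getStringSets) and the Lineage helper are used roughly as follows. This is a minimal illustrative sketch, not part of the change itself: the pipeline, the RecordSourcesFn class, and the "hypothetical://..." source strings are invented for the example, and whether attempted and/or committed results are reported depends on the runner.

    import java.util.Set;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.PipelineResult;
    import org.apache.beam.sdk.metrics.Lineage;
    import org.apache.beam.sdk.metrics.MetricQueryResults;
    import org.apache.beam.sdk.metrics.MetricResult;
    import org.apache.beam.sdk.metrics.Metrics;
    import org.apache.beam.sdk.metrics.MetricsFilter;
    import org.apache.beam.sdk.metrics.StringSet;
    import org.apache.beam.sdk.metrics.StringSetResult;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.ParDo;

    public class StringSetMetricSketch {
      static class RecordSourcesFn extends DoFn<String, String> {
        // A user-defined StringSet metric; repeated values are reported once per step.
        private final StringSet tables = Metrics.stringSet(RecordSourcesFn.class, "tables");

        @ProcessElement
        public void processElement(ProcessContext c) {
          tables.add(c.element());
          // Lineage.getSources()/getSinks() are StringSets in the "lineage" namespace.
          Lineage.getSources().add("hypothetical://warehouse/" + c.element());
          c.output(c.element());
        }
      }

      public static void main(String[] args) {
        Pipeline p = Pipeline.create();
        p.apply(Create.of("orders", "orders", "users")).apply(ParDo.of(new RecordSourcesFn()));
        PipelineResult result = p.run();
        result.waitUntilFinish();

        // StringSet results are queried alongside counters, distributions and gauges.
        MetricQueryResults metrics = result.metrics().queryMetrics(MetricsFilter.builder().build());
        for (MetricResult<StringSetResult> stringSets : metrics.getStringSets()) {
          Set<String> values = stringSets.getAttempted().getStringSet();
          System.out.println(stringSets.getName() + " -> " + values);
        }
      }
    }

The deduplication behavior relied on here (repeated adds of the same value appear once) is what the MetricsTest changes further down assert for the "sources", "sinks" and "sideinputs" sets.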
@@ -752,7 +742,7 @@ public interface SerializableSupplier extends Serializable { private static class SerializableMatcherFromSupplier extends BaseMatcher implements SerializableMatcher { - private SerializableSupplier> supplier; + private final SerializableSupplier> supplier; public SerializableMatcherFromSupplier(SerializableSupplier> supplier) { this.supplier = supplier; diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/SerializableSupplier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesStringSetMetrics.java similarity index 70% rename from sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/SerializableSupplier.java rename to sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesStringSetMetrics.java index f9ebaf815605..e645db801e48 100644 --- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/SerializableSupplier.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesStringSetMetrics.java @@ -15,14 +15,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.io.requestresponse; +package org.apache.beam.sdk.testing; -import java.io.Serializable; -import java.util.function.Supplier; +import org.apache.beam.sdk.annotations.Internal; /** - * A union of a {@link Supplier} and {@link Serializable}, enabling configuration with {@link T} - * types that are not {@link Serializable}. + * Category tag for validation tests which utilize {@link org.apache.beam.sdk.metrics.StringSet}. + * Tests tagged with {@link UsesStringSetMetrics} should be run for runners which support StringSet. */ -@FunctionalInterface -public interface SerializableSupplier extends Supplier, Serializable {} +@Internal +public class UsesStringSetMetrics {} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StringUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StringUtils.java index 13105fb6c02c..ccd58857da04 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StringUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StringUtils.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import org.apache.beam.sdk.annotations.Internal; +import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for working with JSON and other human-readable string formats. */ @Internal @@ -143,4 +144,38 @@ public static int getLevenshteinDistance(final String s, final String t) { return v1[t.length()]; } + + /** + * Convert Array to new lined String. Truncate to first {@code maxLine} elements. + * + *

Useful to truncate stacktrace and for logging. + */ + public static String arrayToNewlines(Object[] array, int maxLine) { + int n = (maxLine > 0 && array.length > maxLine) ? maxLine : array.length; + StringBuilder b = new StringBuilder(); + for (int i = 0; i < n; i++) { + b.append(array[i]); + b.append("\n"); + } + if (array.length > maxLine) { + b.append("...\n"); + } + return b.toString(); + } + + /** + * Truncate String if length greater than maxLen, and append "..." to the end. Handles null. + * + *

Useful to truncate long logging message. + */ + public static String leftTruncate(@Nullable Object element, int maxLen) { + if (element == null) { + return ""; + } + String s = element.toString(); + if (s.length() > maxLen) { + return s.substring(0, maxLen) + "..."; + } + return s; + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java index 089d67993314..79709c89963b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java @@ -22,6 +22,7 @@ import static org.apache.beam.sdk.metrics.MetricResultsMatchers.metricsResult; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.hasItem; import static org.junit.Assert.assertNull; import static org.mockito.Mockito.verify; @@ -37,12 +38,14 @@ import org.apache.beam.sdk.testing.UsesCounterMetrics; import org.apache.beam.sdk.testing.UsesDistributionMetrics; import org.apache.beam.sdk.testing.UsesGaugeMetrics; +import org.apache.beam.sdk.testing.UsesStringSetMetrics; import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.After; @@ -85,6 +88,7 @@ public void tearDown() { protected PipelineResult runPipelineWithMetrics() { final Counter count = Metrics.counter(MetricsTest.class, "count"); + StringSet sideinputs = Metrics.stringSet(MetricsTest.class, "sideinputs"); final TupleTag output1 = new TupleTag() {}; final TupleTag output2 = new TupleTag() {}; pipeline @@ -104,11 +108,16 @@ public void startBundle() { @ProcessElement public void processElement(ProcessContext c) { Distribution values = Metrics.distribution(MetricsTest.class, "input"); + StringSet sources = Metrics.stringSet(MetricsTest.class, "sources"); count.inc(); values.update(c.element()); c.output(c.element()); c.output(c.element()); + sources.add("gcs"); + sources.add("gcs"); // repeated should appear once + sources.add("gcs", "gcs"); // repeated should appear once + sideinputs.add("bigtable", "spanner"); } @DoFn.FinishBundle @@ -125,11 +134,14 @@ public void finishBundle() { public void processElement(ProcessContext c) { Distribution values = Metrics.distribution(MetricsTest.class, "input"); Gauge gauge = Metrics.gauge(MetricsTest.class, "my-gauge"); + StringSet sinks = Metrics.stringSet(MetricsTest.class, "sinks"); Integer element = c.element(); count.inc(); values.update(element); gauge.set(12L); c.output(element); + sinks.add("bq", "kafka", "kafka"); // repeated should appear once + sideinputs.add("bigtable", "sql"); c.output(output2, element); } }) @@ -233,7 +245,8 @@ public static class CommittedMetricTests extends SharedTestBase { UsesCommittedMetrics.class, UsesCounterMetrics.class, UsesDistributionMetrics.class, - UsesGaugeMetrics.class + UsesGaugeMetrics.class, + UsesStringSetMetrics.class }) @Test public void testAllCommittedMetrics() { @@ -267,6 +280,14 @@ public void testCommittedGaugeMetrics() { assertGaugeMetrics(metrics, true); } + 
@Category({ValidatesRunner.class, UsesCommittedMetrics.class, UsesStringSetMetrics.class}) + @Test + public void testCommittedStringSetMetrics() { + PipelineResult result = runPipelineWithMetrics(); + MetricQueryResults metrics = queryTestMetrics(result); + assertStringSetMetrics(metrics, true); + } + @Test @Category({NeedsRunner.class, UsesAttemptedMetrics.class, UsesCounterMetrics.class}) public void testBoundedSourceMetrics() { @@ -352,7 +373,8 @@ public static class AttemptedMetricTests extends SharedTestBase { UsesAttemptedMetrics.class, UsesCounterMetrics.class, UsesDistributionMetrics.class, - UsesGaugeMetrics.class + UsesGaugeMetrics.class, + UsesStringSetMetrics.class }) @Test public void testAllAttemptedMetrics() { @@ -386,6 +408,14 @@ public void testAttemptedGaugeMetrics() { MetricQueryResults metrics = queryTestMetrics(result); assertGaugeMetrics(metrics, false); } + + @Category({ValidatesRunner.class, UsesAttemptedMetrics.class, UsesStringSetMetrics.class}) + @Test + public void testAttemptedStringSetMetrics() { + PipelineResult result = runPipelineWithMetrics(); + MetricQueryResults metrics = queryTestMetrics(result); + assertStringSetMetrics(metrics, false); + } } private static void assertCounterMetrics(MetricQueryResults metrics, boolean isCommitted) { @@ -415,6 +445,36 @@ private static void assertGaugeMetrics(MetricQueryResults metrics, boolean isCom isCommitted))); } + private static void assertStringSetMetrics(MetricQueryResults metrics, boolean isCommitted) { + assertThat( + metrics.getStringSets(), + containsInAnyOrder( + metricsResult( + NAMESPACE, + "sources", + "MyStep1", + StringSetResult.create(ImmutableSet.of("gcs")), + isCommitted), + metricsResult( + NAMESPACE, + "sinks", + "MyStep2", + StringSetResult.create(ImmutableSet.of("kafka", "bq")), + isCommitted), + metricsResult( + NAMESPACE, + "sideinputs", + "MyStep1", + StringSetResult.create(ImmutableSet.of("bigtable", "spanner")), + isCommitted), + metricsResult( + NAMESPACE, + "sideinputs", + "MyStep2", + StringSetResult.create(ImmutableSet.of("sql", "bigtable")), + isCommitted))); + } + private static void assertDistributionMetrics(MetricQueryResults metrics, boolean isCommitted) { assertThat( metrics.getDistributions(), @@ -458,5 +518,6 @@ private static void assertAllMetrics(MetricQueryResults metrics, boolean isCommi assertCounterMetrics(metrics, isCommitted); assertDistributionMetrics(metrics, isCommitted); assertGaugeMetrics(metrics, isCommitted); + assertStringSetMetrics(metrics, isCommitted); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/StringSetResultTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/StringSetResultTest.java new file mode 100644 index 000000000000..85c819b4a9cb --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/StringSetResultTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.metrics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets.SetView; +import org.junit.Test; + +public class StringSetResultTest { + + @Test + public void getStringSet() { + // Test that getStringSet gives an immutable set + HashSet initialSet = new HashSet<>(Arrays.asList("ab", "cd")); + Set stringSetResultSet = StringSetResult.create(initialSet).getStringSet(); + assertEquals(initialSet, stringSetResultSet); + assertThrows(UnsupportedOperationException.class, () -> stringSetResultSet.add("should-fail")); + } + + @Test + public void create() { + // Test that create makes an immutable copy of the given set + HashSet modifiableSet = new HashSet<>(Arrays.asList("ab", "cd")); + StringSetResult stringSetResult = StringSetResult.create(modifiableSet); + // change the initial set. + modifiableSet.add("ef"); + SetView difference = Sets.difference(modifiableSet, stringSetResult.getStringSet()); + assertEquals(1, difference.size()); + assertEquals("ef", difference.iterator().next()); + assertTrue(Sets.difference(stringSetResult.getStringSet(), modifiableSet).isEmpty()); + } + + @Test + public void empty() { + // Test empty returns an immutable set + StringSetResult empptyStringSetResult = StringSetResult.empty(); + assertTrue(empptyStringSetResult.getStringSet().isEmpty()); + assertThrows( + UnsupportedOperationException.class, + () -> empptyStringSetResult.getStringSet().add("should-fail")); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java index dfdb6282b549..a02196bb2c05 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java @@ -37,6 +37,7 @@ import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.testing.PAssert.MatcherCheckerFn; @@ -386,6 +387,36 @@ public void testPAssertEqualsSingletonFalse() throws Exception { assertThat(message, containsString("but: was <42>")); } + @Test + @Category({ValidatesRunner.class, UsesFailureMessage.class}) + public void testPAssertEqualsSingletonFailsForEmptyPCollection() throws Exception { + PCollection pcollection = pipeline.apply(Create.empty(VarIntCoder.of())); + PAssert.thatSingleton("The value was not equal to 44", pcollection).isEqualTo(44); + + Throwable thrown = runExpectingAssertionFailure(pipeline); + + String message = thrown.getMessage(); + + assertThat(message, containsString("The value was not 
equal to 44")); + assertThat(message, containsString("expected singleton PCollection")); + assertThat(message, containsString("but was: empty PCollection")); + } + + @Test + @Category({ValidatesRunner.class, UsesFailureMessage.class}) + public void testPAssertEqualsSingletonFailsForNonSingletonPCollection() throws Exception { + PCollection pcollection = pipeline.apply(Create.of(44, 44)); + PAssert.thatSingleton("The value was not equal to 44", pcollection).isEqualTo(44); + + Throwable thrown = runExpectingAssertionFailure(pipeline); + + String message = thrown.getMessage(); + + assertThat(message, containsString("The value was not equal to 44")); + assertThat(message, containsString("expected one element")); + assertThat(message, containsString("but was: <44, 44>")); + } + /** Test that we throw an error for false assertion on singleton. */ @Test @Category({ValidatesRunner.class, UsesFailureMessage.class}) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StringUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StringUtilsTest.java index 9e9686ca2011..e8b0e7ecd470 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StringUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StringUtilsTest.java @@ -17,9 +17,13 @@ */ package org.apache.beam.sdk.util; +import static org.apache.commons.lang3.StringUtils.countMatches; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import java.util.UUID; +import java.util.stream.IntStream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -54,4 +58,23 @@ public void testLevenshteinDistance() { assertEquals(1, StringUtils.getLevenshteinDistance("abc", "ab1c")); // insertion assertEquals(1, StringUtils.getLevenshteinDistance("abc", "a1c")); // modification } + + @Test + public void testArrayToNewlines() { + Object[] uuids = IntStream.range(1, 10).mapToObj(unused -> UUID.randomUUID()).toArray(); + + String r1 = StringUtils.arrayToNewlines(uuids, 6); + assertTrue(r1.endsWith("...\n")); + assertEquals(7, countMatches(r1, "\n")); + String r2 = StringUtils.arrayToNewlines(uuids, 15); + String r3 = StringUtils.arrayToNewlines(uuids, 10); + assertEquals(r3, r2); + } + + @Test + public void testLeftTruncate() { + assertEquals("", StringUtils.leftTruncate(null, 3)); + assertEquals("", StringUtils.leftTruncate("", 3)); + assertEquals("abc...", StringUtils.leftTruncate("abcd", 3)); + } } diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java index d053a5f4bf80..b48dc6368050 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java @@ -75,7 +75,7 @@ private static class LoggingHttpBackOffHandler private final Set ignoredResponseCodes; // aggregate the total time spent in exponential backoff private final Counter throttlingMsecs = - Metrics.counter(LoggingHttpBackOffHandler.class, "throttling-msecs"); + Metrics.counter(LoggingHttpBackOffHandler.class, Metrics.THROTTLE_TIME_COUNTER_NAME); private int ioExceptionRetries; 
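  // Illustrative sketch only; MyIO and backOffMillis are hypothetical names, not part of this
  // change. Any connector can report time spent backing off under the shared counter name added
  // to Metrics above, so a runner that recognizes "throttling-msecs" can, if supported, feed the
  // throttling signal back into its scaling decisions:
  //
  //   Counter throttlingMsecs = Metrics.counter(MyIO.class, Metrics.THROTTLE_TIME_COUNTER_NAME);
  //   throttlingMsecs.inc(backOffMillis);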
private int unsuccessfulResponseRetries; private @Nullable CustomHttpErrors customHttpErrors; diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java index 5d856ee63063..bcd243ba746d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java @@ -42,6 +42,7 @@ import org.apache.beam.sdk.metrics.Histogram; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.metrics.StringSet; import org.apache.beam.sdk.options.ExecutorOptions; import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; @@ -216,6 +217,14 @@ public Gauge getGauge(MetricName metricName) { return tracker.metricsContainerRegistry.getUnboundContainer().getGauge(metricName); } + @Override + public StringSet getStringSet(MetricName metricName) { + if (tracker.currentState != null) { + return tracker.currentState.metricsContainer.getStringSet(metricName); + } + return tracker.metricsContainerRegistry.getUnboundContainer().getStringSet(metricName); + } + @Override public Histogram getHistogram(MetricName metricName, HistogramData.BucketType bucketType) { if (tracker.currentState != null) { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java index 93f89301d158..5b304890b354 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java @@ -813,14 +813,8 @@ public WatermarkHoldState bindWatermark( private Cache getCacheFor(StateKey stateKey) { switch (stateKey.getTypeCase()) { case BAG_USER_STATE: - for (CacheToken token : cacheTokens.get()) { - if (!token.hasUserState()) { - continue; - } - return Caches.subCache(processWideCache, token, stateKey); - } - break; case MULTIMAP_KEYS_USER_STATE: + case ORDERED_LIST_USER_STATE: for (CacheToken token : cacheTokens.get()) { if (!token.hasUserState()) { continue; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java index c2fd308205a1..1f4341860295 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java @@ -42,10 +42,12 @@ import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.metrics.StringSet; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.util.HistogramData; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.DateTimeUtils.MillisProvider; import org.joda.time.Duration; import org.junit.After; @@ -65,6 +67,8 @@ public class ExecutionStateSamplerTest { private static 
final Distribution TEST_USER_DISTRIBUTION = Metrics.distribution("foo", "distribution"); private static final Gauge TEST_USER_GAUGE = Metrics.gauge("foo", "gauge"); + + private static final StringSet TEST_USER_STRING_SET = Metrics.stringSet("foo", "stringset"); private static final Histogram TEST_USER_HISTOGRAM = new DelegatingHistogram( MetricName.named("foo", "histogram"), HistogramData.LinearBuckets.of(0, 100, 1), false); @@ -375,12 +379,14 @@ public void testCountersReturnedAreBasedUponCurrentExecutionState() throws Excep TEST_USER_COUNTER.inc(); TEST_USER_DISTRIBUTION.update(2); TEST_USER_GAUGE.set(3); + TEST_USER_STRING_SET.add("ab"); TEST_USER_HISTOGRAM.update(4); state.deactivate(); TEST_USER_COUNTER.inc(11); TEST_USER_DISTRIBUTION.update(12); TEST_USER_GAUGE.set(13); + TEST_USER_STRING_SET.add("cd"); TEST_USER_HISTOGRAM.update(14); TEST_USER_HISTOGRAM.update(14); @@ -411,6 +417,14 @@ public void testCountersReturnedAreBasedUponCurrentExecutionState() throws Excep .getGauge(TEST_USER_GAUGE.getName()) .getCumulative() .value()); + assertEquals( + ImmutableSet.of("ab"), + tracker + .getMetricsContainerRegistry() + .getContainer("ptransformId") + .getStringSet(TEST_USER_STRING_SET.getName()) + .getCumulative() + .stringSet()); assertEquals( 1L, (long) @@ -449,6 +463,14 @@ public void testCountersReturnedAreBasedUponCurrentExecutionState() throws Excep .getGauge(TEST_USER_GAUGE.getName()) .getCumulative() .value()); + assertEquals( + ImmutableSet.of("cd"), + tracker + .getMetricsContainerRegistry() + .getUnboundContainer() + .getStringSet(TEST_USER_STRING_SET.getName()) + .getCumulative() + .stringSet()); assertEquals( 2L, (long) diff --git a/sdks/java/io/common/src/main/java/org/apache/beam/sdk/io/common/SchemaAwareJavaBeans.java b/sdks/java/io/common/src/main/java/org/apache/beam/sdk/io/common/SchemaAwareJavaBeans.java index b97d4ab8c5f8..76535c3e17f6 100644 --- a/sdks/java/io/common/src/main/java/org/apache/beam/sdk/io/common/SchemaAwareJavaBeans.java +++ b/sdks/java/io/common/src/main/java/org/apache/beam/sdk/io/common/SchemaAwareJavaBeans.java @@ -137,7 +137,7 @@ public static DoublyNestedDataTypes doublyNestedDataTypes( .build(); } - private static final TypeDescriptor + public static final TypeDescriptor ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR = TypeDescriptor.of(AllPrimitiveDataTypes.class); /** The schema for {@link AllPrimitiveDataTypes}. */ @@ -160,7 +160,7 @@ public static SerializableFunction allPrimitiveDataT return DEFAULT_SCHEMA_PROVIDER.fromRowFunction(ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR); } - private static final TypeDescriptor + public static final TypeDescriptor NULLABLE_ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR = TypeDescriptor.of(NullableAllPrimitiveDataTypes.class); @@ -187,7 +187,7 @@ public static SerializableFunction allPrimitiveDataT NULLABLE_ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR); } - private static final TypeDescriptor TIME_CONTAINING_TYPE_DESCRIPTOR = + public static final TypeDescriptor TIME_CONTAINING_TYPE_DESCRIPTOR = TypeDescriptor.of(TimeContaining.class); /** The schema for {@link TimeContaining}. */ @@ -250,7 +250,7 @@ public static SerializableFunction byteSequenceTypeFromRo return DEFAULT_SCHEMA_PROVIDER.fromRowFunction(BYTE_SEQUENCE_TYPE_TYPE_DESCRIPTOR); } - private static final TypeDescriptor + public static final TypeDescriptor ARRAY_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR = TypeDescriptor.of(ArrayPrimitiveDataTypes.class); /** The schema for {@link ArrayPrimitiveDataTypes}. 
*/ diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java index 5bed0186e0d6..04141e5c677a 100644 --- a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java @@ -55,6 +55,72 @@ *

 * <p>Reading from CSV files is not yet implemented. Please see
 * https://github.com/apache/beam/issues/24552.
 *
+ * <h3>Valid CSVFormat Configuration</h3>
+ *
+ * <p>A {@code CSVFormat} must meet the following conditions to be considered valid when reading
+ * CSV:
+ *
+ * <ul>
+ *   <li>{@code String[] header} - must be set, non-empty, and contain only non-null, non-empty
+ *       column names.
+ *   <li>{@code boolean allowMissingColumnNames} - must be {@code false}.
+ *   <li>{@code boolean ignoreHeaderCase} - must be {@code false}.
+ *   <li>{@code boolean allowDuplicateHeaderNames} - must be {@code false}.
+ *   <li>{@code boolean skipHeaderRecord} - must be {@code false}, since the header is already
+ *       accounted for.
+ * </ul>
+ *
+ * <h4>Ignored CSVFormat parameters</h4>
+ *
+ * <p>The following {@code CSVFormat} parameters are either not relevant for parsing CSV or are
+ * validated satisfactorily by the Apache Commons CSV library.
+ *
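To make the validity conditions above concrete, the following is a minimal illustrative sketch, not part of the diff itself, showing a CSVFormat that the read-side validation added later in this change (CsvIOParseHelpers.validateCsvFormat) would accept. The class name and column names are invented for illustration.

import org.apache.commons.csv.CSVFormat;

public class ValidCsvFormatExample {
  public static void main(String[] args) {
    // Settings mirror the checks performed by CsvIOParseHelpers.validateCsvFormat.
    CSVFormat format =
        CSVFormat.DEFAULT
            .withHeader("a_string", "an_integer") // header is required and must be non-empty
            .withAllowDuplicateHeaderNames(false) // duplicate header names are rejected
            .withAllowMissingColumnNames(false) // missing column names are rejected
            .withIgnoreHeaderCase(false) // header case must not be ignored
            .withSkipHeaderRecord(false); // the header record must not be skipped
    System.out.println(format.getHeader().length); // prints 2
  }
}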

 * <h3>Writing CSV files</h3>
 *

To write a {@link PCollection} to one or more CSV files, use {@link CsvIO.Write}, using {@link diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseConfiguration.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseConfiguration.java index 22f06edc8322..87e0128d73eb 100644 --- a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseConfiguration.java +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseConfiguration.java @@ -18,26 +18,34 @@ package org.apache.beam.sdk.io.csv; import com.google.auto.value.AutoValue; +import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.apache.commons.csv.CSVFormat; /** Stores parameters needed for CSV record parsing. */ @AutoValue -abstract class CsvIOParseConfiguration { +abstract class CsvIOParseConfiguration implements Serializable { - static Builder builder() { - return new AutoValue_CsvIOParseConfiguration.Builder(); + /** A Dead Letter Queue that returns potential errors with {@link BadRecord}. */ + final PTransform, PCollection> errorHandlerTransform = + new BadRecordOutput(); + + static Builder builder() { + return new AutoValue_CsvIOParseConfiguration.Builder<>(); } - /** - * The expected CSVFormat - * of the parsed CSV record. - */ + /** The expected {@link CSVFormat} of the parsed CSV record. */ abstract CSVFormat getCsvFormat(); /** The expected {@link Schema} of the target type. */ @@ -46,24 +54,50 @@ static Builder builder() { /** A map of the {@link Schema.Field#getName()} to the custom CSV processing lambda. */ abstract Map> getCustomProcessingMap(); + /** The expected {@link Coder} of the target type. */ + abstract Coder getCoder(); + + /** A {@link SerializableFunction} that converts from Row to the target type. 
*/ + abstract SerializableFunction getFromRowFn(); + @AutoValue.Builder - abstract static class Builder { - abstract Builder setCsvFormat(CSVFormat csvFormat); + abstract static class Builder implements Serializable { + abstract Builder setCsvFormat(CSVFormat csvFormat); - abstract Builder setSchema(Schema schema); + abstract Builder setSchema(Schema schema); - abstract Builder setCustomProcessingMap( + abstract Builder setCustomProcessingMap( Map> customProcessingMap); + abstract Builder setCoder(Coder coder); + + abstract Builder setFromRowFn(SerializableFunction fromRowFn); + abstract Optional>> getCustomProcessingMap(); - abstract CsvIOParseConfiguration autoBuild(); + abstract CsvIOParseConfiguration autoBuild(); - final CsvIOParseConfiguration build() { + final CsvIOParseConfiguration build() { if (!getCustomProcessingMap().isPresent()) { setCustomProcessingMap(new HashMap<>()); } return autoBuild(); } } + + private static class BadRecordOutput + extends PTransform, PCollection> { + + @Override + public PCollection expand(PCollection input) { + return input.apply(ParDo.of(new BadRecordTransformFn())); + } + + private static class BadRecordTransformFn extends DoFn { + @ProcessElement + public void process(@Element BadRecord input, OutputReceiver receiver) { + receiver.output(input); + } + } + } } diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseHelpers.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseHelpers.java index 042e284cd527..15a398d3c557 100644 --- a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseHelpers.java +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseHelpers.java @@ -17,39 +17,146 @@ */ package org.apache.beam.sdk.io.csv; -import java.util.ArrayList; +import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.commons.csv.CSVFormat; +import org.joda.time.Instant; /** A utility class containing shared methods for parsing CSV records. */ final class CsvIOParseHelpers { - /** Validate the {@link CSVFormat} for CSV record parsing requirements. */ - // TODO(https://github.com/apache/beam/issues/31712): implement method. - static void validate(CSVFormat format) {} + /** + * Validate the {@link CSVFormat} for CSV record parsing requirements. See the public-facing + * "Reading CSV Files" section of the {@link CsvIO} documentation for information regarding which + * {@link CSVFormat} parameters are checked during validation. 
+ */ + static void validateCsvFormat(CSVFormat format) { + String[] header = + checkArgumentNotNull(format.getHeader(), "Illegal %s: header is required", CSVFormat.class); + + checkArgument(header.length > 0, "Illegal %s: header cannot be empty", CSVFormat.class); + + checkArgument( + !format.getAllowMissingColumnNames(), + "Illegal %s: cannot allow missing column names", + CSVFormat.class); + + checkArgument( + !format.getIgnoreHeaderCase(), "Illegal %s: cannot ignore header case", CSVFormat.class); + + checkArgument( + !format.getAllowDuplicateHeaderNames(), + "Illegal %s: cannot allow duplicate header names", + CSVFormat.class); + + for (String columnName : header) { + checkArgument( + !Strings.isNullOrEmpty(columnName), + "Illegal %s: column name is required", + CSVFormat.class); + } + checkArgument( + !format.getSkipHeaderRecord(), + "Illegal %s: cannot skip header record because the header is already accounted for", + CSVFormat.class); + } /** * Validate the {@link CSVFormat} in relation to the {@link Schema} for CSV record parsing * requirements. */ - // TODO(https://github.com/apache/beam/issues/31716): implement method. - static void validate(CSVFormat format, Schema schema) {} + static void validateCsvFormatWithSchema(CSVFormat format, Schema schema) { + List header = Arrays.asList(format.getHeader()); + for (Schema.Field field : schema.getFields()) { + String fieldName = field.getName(); + if (!field.getType().getNullable()) { + checkArgument( + header.contains(fieldName), + "Illegal %s: required %s field '%s' not found in header", + CSVFormat.class, + Schema.class.getTypeName(), + fieldName); + } + } + } /** * Build a {@link List} of {@link Schema.Field}s corresponding to the expected position of each * field within the CSV record. */ - // TODO(https://github.com/apache/beam/issues/31718): implement method. - static List mapFieldPositions(CSVFormat format, Schema schema) { - return new ArrayList<>(); + static Map mapFieldPositions(CSVFormat format, Schema schema) { + List header = Arrays.asList(format.getHeader()); + Map indexToFieldMap = new HashMap<>(); + for (Schema.Field field : schema.getFields()) { + int index = getIndex(header, field); + if (index >= 0) { + indexToFieldMap.put(index, field); + } + } + return indexToFieldMap; + } + + /** + * Attains expected index from {@link CSVFormat's} header matching a given {@link Schema.Field}. + */ + private static int getIndex(List header, Schema.Field field) { + String fieldName = field.getName(); + boolean presentInHeader = header.contains(fieldName); + boolean isNullable = field.getType().getNullable(); + if (presentInHeader) { + return header.indexOf(fieldName); + } + if (isNullable) { + return -1; + } + + throw new IllegalArgumentException( + String.format("header does not contain required %s field: %s", Schema.class, fieldName)); } /** * Parse the given {@link String} cell of the CSV record based on the given field's {@link * Schema.FieldType}. */ - // TODO(https://github.com/apache/beam/issues/31719): implement method. 
static Object parseCell(String cell, Schema.Field field) { - return ""; + Schema.FieldType fieldType = field.getType(); + try { + switch (fieldType.getTypeName()) { + case STRING: + return cell; + case INT16: + return Short.parseShort(cell); + case INT32: + return Integer.parseInt(cell); + case INT64: + return Long.parseLong(cell); + case BOOLEAN: + return Boolean.parseBoolean(cell); + case BYTE: + return Byte.parseByte(cell); + case DECIMAL: + return new BigDecimal(cell); + case DOUBLE: + return Double.parseDouble(cell); + case FLOAT: + return Float.parseFloat(cell); + case DATETIME: + return Instant.parse(cell); + default: + throw new UnsupportedOperationException( + "Unsupported type: " + fieldType + ", consider using withCustomRecordParsing"); + } + + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + e.getMessage() + " field " + field.getName() + " was received -- type mismatch"); + } } } diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseKV.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseKV.java new file mode 100644 index 000000000000..1b8e43314b14 --- /dev/null +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseKV.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.csv; + +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.commons.csv.CSVRecord; + +/** + * A {@link PTransform} that takes an input {@link PCollection>} and outputs a + * {@link PCollection} of custom type. + */ +// TODO(https://github.com/apache/beam/issues/31873): implement class after all dependencies are +// completed. +class CsvIOParseKV + extends PTransform>>, PCollection> { + + // TODO(https://github.com/apache/beam/issues/31873): implement method. + @Override + public PCollection expand(PCollection>> input) { + return input.apply(ParDo.of(new DoFn>, T>() {})); + } +} diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseResult.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseResult.java deleted file mode 100644 index 5d4d4c8c02e9..000000000000 --- a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOParseResult.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.csv; - -import java.util.Map; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; - -/** - * The {@link T} and {@link org.apache.beam.sdk.io.csv.CsvIOParseError} {@link PCollection} results - * of parsing CSV records. Use {@link #getOutput()} and {@link #getErrors()} to apply these results - * in a pipeline. - */ -public class CsvIOParseResult implements POutput { - - static CsvIOParseResult of( - TupleTag outputTag, TupleTag errorTag, PCollectionTuple pct) { - return new CsvIOParseResult<>(outputTag, errorTag, pct); - } - - private final Pipeline pipeline; - private final TupleTag outputTag; - private final PCollection output; - private final TupleTag errorTag; - private final PCollection errors; - - private CsvIOParseResult( - TupleTag outputTag, TupleTag errorTag, PCollectionTuple pct) { - this.outputTag = outputTag; - this.errorTag = errorTag; - this.pipeline = pct.getPipeline(); - this.output = pct.get(outputTag); - this.errors = pct.get(errorTag); - } - - /** The {@link T} {@link PCollection} as a result of successfully parsing CSV records. */ - public PCollection getOutput() { - return output; - } - - /** - * The {@link org.apache.beam.sdk.io.csv.CsvIOParseError} {@link PCollection} as a result of - * errors associated with parsing CSV records. 
- */ - public PCollection getErrors() { - return errors; - } - - @Override - public Pipeline getPipeline() { - return pipeline; - } - - @Override - public Map, PValue> expand() { - return ImmutableMap.of( - outputTag, output, - errorTag, errors); - } - - @Override - public void finishSpecifyingOutput( - String transformName, PInput input, PTransform transform) {} -} diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOReadFiles.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOReadFiles.java index 0f6267c6b34c..b28072091326 100644 --- a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOReadFiles.java +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOReadFiles.java @@ -17,38 +17,32 @@ */ package org.apache.beam.sdk.io.csv; -import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.FileIO; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.TupleTag; /** * Skeleton for error handling in CsvIO that transforms a {@link FileIO.ReadableFile} into the * result of parsing. */ // TODO(https://github.com/apache/beam/issues/31736): Plan completion in future PR after -// dependencies are completed. -class CsvIOReadFiles extends PTransform, CsvIOParseResult> { +// dependencies are completed. +class CsvIOReadFiles extends PTransform, PCollection> { /** Stores required parameters for parsing. */ - private final CsvIOParseConfiguration.Builder configBuilder; + private final CsvIOParseConfiguration.Builder configBuilder; - CsvIOReadFiles(CsvIOParseConfiguration.Builder configBuilder) { + CsvIOReadFiles(CsvIOParseConfiguration.Builder configBuilder) { this.configBuilder = configBuilder; } /** {@link PTransform} that parses and relays the filename associated with each error. */ - // TODO: complete expand method to unsure parsing from FileIO.ReadableFile to CsvIOParseResult. @Override - public CsvIOParseResult expand(PCollection input) { + public PCollection expand(PCollection input) { // TODO(https://github.com/apache/beam/issues/31736): Needed to prevent check errors, will - // remove with future PR. + // remove with future PR. configBuilder.build(); - TupleTag outputTag = new TupleTag<>(); - TupleTag errorTag = new TupleTag<>(); - Pipeline p = input.getPipeline(); - PCollectionTuple tuple = PCollectionTuple.empty(p); - return CsvIOParseResult.of(outputTag, errorTag, tuple); + return input.apply(ParDo.of(new DoFn() {})); } } diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIORecordToObjects.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIORecordToObjects.java new file mode 100644 index 000000000000..4340b68f3c49 --- /dev/null +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIORecordToObjects.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.csv; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; + +/** + * {@link CsvIORecordToObjects} is a class that takes an input of {@link PCollection>} + * and outputs custom type {@link PCollection}. + */ +class CsvIORecordToObjects extends PTransform>, PCollection> { + + /** The expected {@link Schema} of the target type. */ + private final Schema schema; + + /** A map of the {@link Schema.Field#getName()} to the custom CSV processing lambda. */ + private final Map> customProcessingMap; + + /** A {@link Map} of {@link Schema.Field}s to their expected positions within the CSV record. */ + private final Map indexToFieldMap; + + /** + * A {@link SerializableFunction} that converts from {@link Row} to {@link Schema} mapped custom + * type. + */ + private final SerializableFunction fromRowFn; + + /** The expected coder of target type. */ + private final Coder coder; + + CsvIORecordToObjects(CsvIOParseConfiguration configuration) { + this.schema = configuration.getSchema(); + this.customProcessingMap = configuration.getCustomProcessingMap(); + this.indexToFieldMap = + CsvIOParseHelpers.mapFieldPositions(configuration.getCsvFormat(), schema); + this.fromRowFn = configuration.getFromRowFn(); + this.coder = configuration.getCoder(); + } + + @Override + public PCollection expand(PCollection> input) { + return input.apply(ParDo.of(new RecordToObjectsFn())).setCoder(coder); + } + + private class RecordToObjectsFn extends DoFn, T> { + @ProcessElement + public void process(@Element List record, OutputReceiver receiver) { + Map fieldNamesToValues = new HashMap<>(); + for (Map.Entry entry : indexToFieldMap.entrySet()) { + Schema.Field field = entry.getValue(); + int index = entry.getKey(); + String cell = record.get(index); + Object value = parseCell(cell, field); + fieldNamesToValues.put(field.getName(), value); + } + Row row = Row.withSchema(schema).withFieldValues(fieldNamesToValues).build(); + receiver.output(fromRowFn.apply(row)); + } + } + + /** Parses cell to emit the value, as well as potential errors with filename. 
*/ + Object parseCell(String cell, Schema.Field field) { + if (cell == null) { + if (!field.getType().getNullable()) { + throw new IllegalArgumentException( + "Required org.apache.beam.sdk.schemas.Schema field " + + field.getName() + + " has null value"); + } + return cell; + } + if (customProcessingMap.containsKey(field.getName())) { + return customProcessingMap.get(field.getName()).apply(cell); + } + return CsvIOParseHelpers.parseCell(cell, field); + } +} diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOStringToCsvRecord.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOStringToCsvRecord.java new file mode 100644 index 000000000000..b5ce6a0fec22 --- /dev/null +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIOStringToCsvRecord.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.csv; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVRecord; + +/** + * {@link CsvIOStringToCsvRecord} is a class that takes a {@link PCollection} input and + * outputs a {@link PCollection} with potential {@link PCollection} for + * targeted error detection. + */ +final class CsvIOStringToCsvRecord + extends PTransform, PCollection>> { + private final CSVFormat csvFormat; + + CsvIOStringToCsvRecord(CSVFormat csvFormat) { + this.csvFormat = csvFormat; + } + + /** + * Creates {@link PCollection} from {@link PCollection} for future processing + * to Row or custom type. + */ + @Override + public PCollection> expand(PCollection input) { + return input + .apply(ParDo.of(new ProcessLineToRecordFn())) + .setCoder(ListCoder.of(NullableCoder.of(StringUtf8Coder.of()))); + } + + /** Processes each line in order to convert it to a {@link CSVRecord}. */ + private class ProcessLineToRecordFn extends DoFn> { + private final String headerLine = headerLine(csvFormat); + + @ProcessElement + public void process(@Element String line, OutputReceiver> receiver) + throws IOException { + if (headerLine.equals(line)) { + return; + } + for (CSVRecord record : CSVParser.parse(line, csvFormat).getRecords()) { + receiver.output(csvRecordtoList(record)); + } + } + } + + /** Creates a {@link List} containing {@link CSVRecord} values. 
*/ + private static List csvRecordtoList(CSVRecord record) { + List cells = new ArrayList<>(); + for (String cell : record) { + cells.add(cell); + } + return cells; + } + + /** Returns a formatted line of the CSVFormat header. */ + static String headerLine(CSVFormat csvFormat) { + return String.join(String.valueOf(csvFormat.getDelimiter()), csvFormat.getHeader()); + } +} diff --git a/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOParseHelpersTest.java b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOParseHelpersTest.java new file mode 100644 index 000000000000..5276fa008c7c --- /dev/null +++ b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOParseHelpersTest.java @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.csv; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.math.BigDecimal; +import java.util.Map; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.commons.collections.keyvalue.DefaultMapEntry; +import org.apache.commons.csv.CSVFormat; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link CsvIOParseHelpers}. */ +@RunWith(JUnit4.class) +public class CsvIOParseHelpersTest { + + /** Tests for {@link CsvIOParseHelpers#validateCsvFormat(CSVFormat)}. 
*/ + @Test + public void givenCSVFormatWithHeader_validates() { + CSVFormat format = csvFormatWithHeader(); + CsvIOParseHelpers.validateCsvFormat(format); + } + + @Test + public void givenCSVFormatWithNullHeader_throwsException() { + CSVFormat format = csvFormat(); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals("Illegal class org.apache.commons.csv.CSVFormat: header is required", gotMessage); + } + + @Test + public void givenCSVFormatWithEmptyHeader_throwsException() { + CSVFormat format = csvFormat().withHeader(); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: header cannot be empty", gotMessage); + } + + @Test + public void givenCSVFormatWithHeaderContainingEmptyString_throwsException() { + CSVFormat format = csvFormat().withHeader("", "bar"); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: column name is required", gotMessage); + } + + @Test + public void givenCSVFormatWithHeaderContainingNull_throwsException() { + CSVFormat format = csvFormat().withHeader(null, "bar"); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: column name is required", gotMessage); + } + + @Test + public void givenCSVFormatThatAllowsMissingColumnNames_throwsException() { + CSVFormat format = csvFormatWithHeader().withAllowMissingColumnNames(true); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: cannot allow missing column names", + gotMessage); + } + + @Test + public void givenCSVFormatThatIgnoresHeaderCase_throwsException() { + CSVFormat format = csvFormatWithHeader().withIgnoreHeaderCase(true); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: cannot ignore header case", gotMessage); + } + + @Test + public void givenCSVFormatThatAllowsDuplicateHeaderNames_throwsException() { + CSVFormat format = csvFormatWithHeader().withAllowDuplicateHeaderNames(true); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: cannot allow duplicate header names", + gotMessage); + } + + @Test + public void givenCSVFormatThatSkipsHeaderRecord_throwsException() { + CSVFormat format = csvFormatWithHeader().withSkipHeaderRecord(true); + String gotMessage = + assertThrows( + IllegalArgumentException.class, () -> CsvIOParseHelpers.validateCsvFormat(format)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: cannot skip header record because the header is already accounted for", + gotMessage); + } + + /** End of tests for {@link CsvIOParseHelpers#validateCsvFormat(CSVFormat)}. 
*/ + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** Tests for {@link CsvIOParseHelpers#validateCsvFormatWithSchema(CSVFormat, Schema)}. */ + @Test + public void givenNullableSchemaFieldNotPresentInHeader_validates() { + CSVFormat format = csvFormat().withHeader("foo", "bar"); + Schema schema = + Schema.of( + Schema.Field.of("foo", Schema.FieldType.STRING), + Schema.Field.of("bar", Schema.FieldType.STRING), + Schema.Field.nullable("baz", Schema.FieldType.STRING)); + CsvIOParseHelpers.validateCsvFormatWithSchema(format, schema); + } + + @Test + public void givenRequiredSchemaFieldNotPresentInHeader_throwsException() { + CSVFormat format = csvFormat().withHeader("foo", "bar"); + Schema schema = + Schema.of( + Schema.Field.of("foo", Schema.FieldType.STRING), + Schema.Field.of("bar", Schema.FieldType.STRING), + Schema.Field.of("baz", Schema.FieldType.STRING)); + String gotMessage = + assertThrows( + IllegalArgumentException.class, + () -> CsvIOParseHelpers.validateCsvFormatWithSchema(format, schema)) + .getMessage(); + assertEquals( + "Illegal class org.apache.commons.csv.CSVFormat: required org.apache.beam.sdk.schemas.Schema field 'baz' not found in header", + gotMessage); + } + + /** End of tests for {@link CsvIOParseHelpers#validateCsvFormatWithSchema(CSVFormat, Schema)}. */ + ////////////////////////////////////////////////////////////////////////////////////////////// + /** Tests for {@link CsvIOParseHelpers#mapFieldPositions(CSVFormat, Schema)}. */ + @Test + public void testHeaderWithComments() { + String[] comments = {"first line", "second line", "third line"}; + Schema schema = + Schema.builder().addStringField("a_string").addStringField("another_string").build(); + ImmutableMap want = + ImmutableMap.of(0, schema.getField("a_string"), 1, schema.getField("another_string")); + Map got = + CsvIOParseHelpers.mapFieldPositions( + csvFormat() + .withHeader("a_string", "another_string") + .withHeaderComments((Object) comments), + schema); + assertEquals(want, got); + } + + @Test + public void givenMatchingHeaderAndSchemaField_mapsPositions() { + Schema schema = + Schema.builder() + .addStringField("a_string") + .addDoubleField("a_double") + .addInt32Field("an_integer") + .build(); + ImmutableMap want = + ImmutableMap.of( + 0, + schema.getField("a_string"), + 1, + schema.getField("an_integer"), + 2, + schema.getField("a_double")); + Map got = + CsvIOParseHelpers.mapFieldPositions( + csvFormat().withHeader("a_string", "an_integer", "a_double"), schema); + assertEquals(want, got); + } + + @Test + public void givenSchemaContainsNullableFieldTypes() { + Schema schema = + Schema.builder() + .addNullableStringField("a_string") + .addDoubleField("a_double") + .addInt32Field("an_integer") + .addDateTimeField("a_datetime") + .addNullableStringField("another_string") + .build(); + ImmutableMap want = + ImmutableMap.of( + 0, + schema.getField("an_integer"), + 1, + schema.getField("a_double"), + 2, + schema.getField("a_datetime")); + Map got = + CsvIOParseHelpers.mapFieldPositions( + csvFormat().withHeader("an_integer", "a_double", "a_datetime"), schema); + assertEquals(want, got); + } + + @Test + public void givenNonNullableHeaderAndSchemaFieldMismatch_throws() { + Schema schema = + Schema.builder() + .addStringField("another_string") + .addInt32Field("an_integer") + .addStringField("a_string") + .build(); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> + CsvIOParseHelpers.mapFieldPositions( + 
csvFormat().withHeader("an_integer", "a_string"), schema)); + assertEquals( + "header does not contain required class org.apache.beam.sdk.schemas.Schema field: " + + schema.getField("another_string").getName(), + e.getMessage()); + } + + /** End of tests for {@link CsvIOParseHelpers#mapFieldPositions(CSVFormat, Schema)} */ + + //////////////////////////////////////////////////////////////////////////////////////////// + + /** Tests for {@link CsvIOParseHelpers#parseCell(String, Schema.Field)}. */ + @Test + public void ignoresCaseFormat() { + String allCapsBool = "TRUE"; + Schema schema = Schema.builder().addBooleanField("a_boolean").build(); + assertEquals(true, CsvIOParseHelpers.parseCell(allCapsBool, schema.getField("a_boolean"))); + } + + @Test + public void givenIntegerWithSurroundingSpaces_throws() { + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" 12 ", 12); + Schema schema = Schema.builder().addInt32Field("an_integer").addStringField("a_string").build(); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("an_integer"))); + assertEquals( + "For input string: \"" + + cellToExpectedValue.getKey() + + "\" field " + + schema.getField("an_integer").getName() + + " was received -- type mismatch", + e.getMessage()); + } + + @Test + public void givenDoubleWithSurroundingSpaces_parses() { + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" 20.04 ", 20.04); + Schema schema = Schema.builder().addDoubleField("a_double").addInt32Field("an_integer").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_double"))); + } + + @Test + public void givenStringWithSurroundingSpaces_parsesIncorrectly() { + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" a ", "a"); + Schema schema = Schema.builder().addStringField("a_string").addInt64Field("a_long").build(); + assertEquals( + cellToExpectedValue.getKey(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_string"))); + } + + @Test + public void givenBigDecimalWithSurroundingSpaces_throws() { + BigDecimal decimal = new BigDecimal("123.456"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" 123.456 ", decimal); + Schema schema = + Schema.builder().addDecimalField("a_decimal").addStringField("a_string").build(); + assertThrows( + IllegalArgumentException.class, + () -> + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_decimal"))); + } + + @Test + public void givenShortWithSurroundingSpaces_throws() { + Short shortNum = Short.parseShort("12"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" 12 ", shortNum); + Schema schema = + Schema.builder() + .addInt16Field("a_short") + .addInt32Field("an_integer") + .addInt64Field("a_long") + .build(); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_short"))); + assertEquals( + "For input string: \"" + + cellToExpectedValue.getKey() + + "\" field " + + schema.getField("a_short").getName() + + " was received -- type mismatch", + e.getMessage()); + } + + @Test + public void givenLongWithSurroundingSpaces_throws() { + Long longNum = Long.parseLong("3400000000"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" 12 ", 
longNum); + Schema schema = + Schema.builder() + .addInt16Field("a_short") + .addInt32Field("an_integer") + .addInt64Field("a_long") + .build(); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_long"))); + assertEquals( + "For input string: \"" + + cellToExpectedValue.getKey() + + "\" field " + + schema.getField("a_long").getName() + + " was received -- type mismatch", + e.getMessage()); + } + + @Test + public void givenFloatWithSurroundingSpaces_parses() { + Float floatNum = Float.parseFloat("3.141592"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" 3.141592 ", floatNum); + Schema schema = + Schema.builder() + .addFloatField("a_float") + .addInt32Field("an_integer") + .addStringField("a_string") + .build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_float"))); + } + + @Test + public void givenDatetimeWithSurroundingSpaces() { + Instant datetime = Instant.parse("1234-01-23T10:00:05.000Z"); + DefaultMapEntry cellToExpectedValue = + new DefaultMapEntry(" 1234-01-23T10:00:05.000Z ", datetime); + Schema schema = + Schema.builder().addDateTimeField("a_datetime").addStringField("a_string").build(); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_datetime"))); + assertEquals( + "Invalid format: \" 1234-01-23T10:00:05.000Z \" field a_datetime was received -- type mismatch", + e.getMessage()); + } + + @Test + public void givenByteWithSurroundingSpaces_throws() { + Byte byteNum = Byte.parseByte("40"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" 40 ", byteNum); + Schema schema = Schema.builder().addByteField("a_byte").addInt32Field("an_integer").build(); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_byte"))); + assertEquals( + "For input string: \"" + + cellToExpectedValue.getKey() + + "\" field " + + schema.getField("a_byte").getName() + + " was received -- type mismatch", + e.getMessage()); + } + + @Test + public void givenBooleanWithSurroundingSpaces_returnsInverse() { + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry(" true ", true); + Schema schema = + Schema.builder() + .addBooleanField("a_boolean") + .addInt32Field("an_integer") + .addStringField("a_string") + .build(); + assertEquals( + false, + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_boolean"))); + } + + @Test + public void givenMultiLineCell_parses() { + String multiLineString = "a\na\na\na\na\na\na\na\na\nand"; + Schema schema = Schema.builder().addStringField("a_string").addDoubleField("a_double").build(); + assertEquals( + multiLineString, CsvIOParseHelpers.parseCell(multiLineString, schema.getField("a_string"))); + } + + @Test + public void givenValidIntegerCell_parses() { + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("12", 12); + Schema schema = Schema.builder().addInt32Field("an_integer").addInt64Field("a_long").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("an_integer"))); + } + + @Test + public void givenValidDoubleCell_parses() { 
+ DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("10.05", 10.05); + Schema schema = Schema.builder().addDoubleField("a_double").addStringField("a_string").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_double"))); + } + + @Test + public void givenValidStringCell_parses() { + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("lithium", "lithium"); + Schema schema = + Schema.builder().addStringField("a_string").addDateTimeField("a_datetime").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_string"))); + } + + @Test + public void givenValidDecimalCell_parses() { + BigDecimal decimal = new BigDecimal("127.99"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("127.99", decimal); + Schema schema = + Schema.builder().addDecimalField("a_decimal").addDoubleField("a_double").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_decimal"))); + } + + @Test + public void givenValidShortCell_parses() { + Short shortNum = Short.parseShort("36"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("36", shortNum); + Schema schema = + Schema.builder() + .addInt32Field("an_integer") + .addInt64Field("a_long") + .addInt16Field("a_short") + .build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_short"))); + } + + @Test + public void givenValidLongCell_parses() { + Long longNum = Long.parseLong("1234567890"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("1234567890", longNum); + Schema schema = + Schema.builder() + .addInt32Field("an_integer") + .addInt64Field("a_long") + .addInt16Field("a_short") + .build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_long"))); + } + + @Test + public void givenValidFloatCell_parses() { + Float floatNum = Float.parseFloat("3.141592"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("3.141592", floatNum); + Schema schema = Schema.builder().addFloatField("a_float").addDoubleField("a_double").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_float"))); + } + + @Test + public void givenValidDateTimeCell_parses() { + Instant datetime = Instant.parse("2020-01-01T00:00:00.000Z"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("2020-01-01T00:00:00.000Z", datetime); + Schema schema = + Schema.builder().addDateTimeField("a_datetime").addStringField("a_string").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_datetime"))); + } + + @Test + public void givenValidByteCell_parses() { + Byte byteNum = Byte.parseByte("4"); + DefaultMapEntry cellToExpectedValue = new DefaultMapEntry("4", byteNum); + Schema schema = Schema.builder().addByteField("a_byte").addInt32Field("an_integer").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_byte"))); + } + + @Test + public void givenValidBooleanCell_parses() { + DefaultMapEntry 
cellToExpectedValue = new DefaultMapEntry("false", false); + Schema schema = + Schema.builder().addBooleanField("a_boolean").addStringField("a_string").build(); + assertEquals( + cellToExpectedValue.getValue(), + CsvIOParseHelpers.parseCell( + cellToExpectedValue.getKey().toString(), schema.getField("a_boolean"))); + } + + @Test + public void givenCellSchemaFieldMismatch_throws() { + String boolTrue = "true"; + Schema schema = Schema.builder().addBooleanField("a_boolean").addFloatField("a_float").build(); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> CsvIOParseHelpers.parseCell(boolTrue, schema.getField("a_float"))); + assertEquals( + "For input string: \"" + boolTrue + "\" field a_float was received -- type mismatch", + e.getMessage()); + } + + @Test + public void givenCellUnsupportedType_throws() { + String counting = "[one,two,three]"; + Schema schema = + Schema.builder() + .addField("an_array", Schema.FieldType.array(Schema.FieldType.STRING)) + .addStringField("a_string") + .build(); + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> CsvIOParseHelpers.parseCell(counting, schema.getField("an_array"))); + assertEquals( + "Unsupported type: " + + schema.getField("an_array").getType() + + ", consider using withCustomRecordParsing", + e.getMessage()); + } + + /** End of tests for {@link CsvIOParseHelpers#parseCell(String, Schema.Field)}. */ + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** Return a {@link CSVFormat} with a header and with no duplicate header names allowed. */ + private static CSVFormat csvFormatWithHeader() { + return csvFormat().withHeader("foo", "bar"); + } + + /** Return a {@link CSVFormat} with no header and with no duplicate header names allowed. */ + private static CSVFormat csvFormat() { + return CSVFormat.DEFAULT.withAllowDuplicateHeaderNames(false); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/SerializableSupplier.java b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOParseKVTest.java similarity index 71% rename from sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/SerializableSupplier.java rename to sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOParseKVTest.java index 2b09adbc75dd..c20a29174503 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/SerializableSupplier.java +++ b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOParseKVTest.java @@ -15,11 +15,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dofn; +package org.apache.beam.sdk.io.csv; -import java.io.Serializable; -import java.util.function.Supplier; +import org.apache.beam.sdk.util.SerializableUtils; +import org.junit.Test; -/** Union of Supplier and Serializable interfaces to allow serialized supplier for testing. */ -@FunctionalInterface -interface SerializableSupplier extends Supplier, Serializable {} +/** Contains tests for {@link CsvIOParseKV}. 
*/ +public class CsvIOParseKVTest { + @Test + public void isSerializable() { + SerializableUtils.ensureSerializable(CsvIOParseKV.class); + } +} diff --git a/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIORecordToObjectsTest.java b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIORecordToObjectsTest.java new file mode 100644 index 000000000000..eb8cacdec5ab --- /dev/null +++ b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIORecordToObjectsTest.java @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.csv; + +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.ALL_PRIMITIVE_DATA_TYPES_SCHEMA; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.AllPrimitiveDataTypes; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.NULLABLE_ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.NullableAllPrimitiveDataTypes; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.TIME_CONTAINING_SCHEMA; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.TIME_CONTAINING_TYPE_DESCRIPTOR; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.TimeContaining; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.allPrimitiveDataTypes; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.allPrimitiveDataTypesFromRowFn; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.allPrimitiveDataTypesToRowFn; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.nullableAllPrimitiveDataTypes; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.nullableAllPrimitiveDataTypesFromRowFn; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.nullableAllPrimitiveDataTypesToRowFn; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.timeContaining; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.timeContainingFromRowFn; +import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.timeContainingToRowFn; +import static org.junit.Assert.assertThrows; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.NullableCoder; +import 
org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.io.common.SchemaAwareJavaBeans; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.commons.csv.CSVFormat; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link CsvIORecordToObjects}. */ +@RunWith(JUnit4.class) +public class CsvIORecordToObjectsTest { + + @Rule public final TestPipeline pipeline = TestPipeline.create(); + private static final SerializableFunction ROW_ROW_SERIALIZABLE_FUNCTION = row -> row; + private static final RowCoder ALL_PRIMITIVE_DATA_TYPES_ROW_CODER = + RowCoder.of(ALL_PRIMITIVE_DATA_TYPES_SCHEMA); + private static final Coder NULLABLE_ALL_PRIMITIVE_DATA_TYPES_ROW_CODER = + NullableCoder.of(RowCoder.of(NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA)); + private static final Coder + ALL_PRIMITIVE_DATA_TYPES_CODER = + SchemaCoder.of( + ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR, + allPrimitiveDataTypesToRowFn(), + allPrimitiveDataTypesFromRowFn()); + private static final Coder + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_CODER = + SchemaCoder.of( + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR, + nullableAllPrimitiveDataTypesToRowFn(), + nullableAllPrimitiveDataTypesFromRowFn()); + private static final Coder TIME_CONTAINING_ROW_CODER = RowCoder.of(TIME_CONTAINING_SCHEMA); + private static final Coder TIME_CONTAINING_POJO_CODER = + SchemaCoder.of( + TIME_CONTAINING_SCHEMA, + TIME_CONTAINING_TYPE_DESCRIPTOR, + timeContainingToRowFn(), + timeContainingFromRowFn()); + + @Test + public void isSerializable() { + SerializableUtils.ensureSerializable(CsvIORecordToObjects.class); + } + + @Test + public void parsesToRows() { + PCollection> input = + csvRecords(pipeline, "true", "1.0", "2.0", "3.0", "4", "5", "foo"); + Row want = + Row.withSchema(ALL_PRIMITIVE_DATA_TYPES_SCHEMA) + .withFieldValues( + ImmutableMap.of( + "aBoolean", + true, + "aDecimal", + BigDecimal.valueOf(1.0), + "aDouble", + 2.0, + "aFloat", + 3.0f, + "anInteger", + 4, + "aLong", + 5L, + "aString", + "foo")) + .build(); + CsvIORecordToObjects underTest = + underTest( + ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + allPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + ROW_ROW_SERIALIZABLE_FUNCTION, + ALL_PRIMITIVE_DATA_TYPES_ROW_CODER); + PAssert.that(input.apply(underTest)).containsInAnyOrder(want); + pipeline.run(); + } + + @Test + public void parsesToPojos() { + PCollection> input = + csvRecords(pipeline, "true", "1.0", "2.0", "3.0", "4", "5", "foo"); + SchemaAwareJavaBeans.AllPrimitiveDataTypes want = + allPrimitiveDataTypes(true, BigDecimal.valueOf(1.0), 2.0d, 3.0f, 4, 5L, "foo"); + CsvIORecordToObjects underTest = + underTest( + ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + allPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + allPrimitiveDataTypesFromRowFn(), + ALL_PRIMITIVE_DATA_TYPES_CODER); + 
PAssert.that(input.apply(underTest)).containsInAnyOrder(want); + pipeline.run(); + } + + @Test + public void givenNullableField_containsNullCell_parsesToRows() { + PCollection> input = csvRecords(pipeline, "true", "1.0", "2.0", "3", "4", null); + Row want = + Row.withSchema(NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA) + .withFieldValue("aBoolean", true) + .withFieldValue("aDouble", 1.0) + .withFieldValue("aFloat", 2.0f) + .withFieldValue("anInteger", 3) + .withFieldValue("aLong", 4L) + .withFieldValue("aString", null) + .build(); + + CsvIORecordToObjects underTest = + underTest( + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + nullableAllPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + ROW_ROW_SERIALIZABLE_FUNCTION, + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_ROW_CODER); + PAssert.that(input.apply(underTest)).containsInAnyOrder(want); + pipeline.run(); + } + + @Test + public void givenNullableField_containsNullCell_parsesToPojos() { + PCollection> input = csvRecords(pipeline, "true", "1.0", "2.0", "3", "4", null); + SchemaAwareJavaBeans.NullableAllPrimitiveDataTypes want = + nullableAllPrimitiveDataTypes(true, 1.0, 2.0f, 3, 4L, null); + + CsvIORecordToObjects underTest = + underTest( + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + nullableAllPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + nullableAllPrimitiveDataTypesFromRowFn(), + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_CODER); + PAssert.that(input.apply(underTest)).containsInAnyOrder(want); + pipeline.run(); + } + + @Test + public void givenNoNullableField_containsNullCell_throws() { + PCollection> input = + csvRecords(pipeline, "true", "1.0", "2.0", "3.0", "4", "5", null); + pipeline.apply( + "Null Cell with No Nullable Fields", + Create.of( + Collections.singletonList( + Arrays.asList("true", "1.0", "2.0", "3.0", "4", "5", null))) + .withCoder(ListCoder.of(NullableCoder.of(StringUtf8Coder.of())))); + CsvIORecordToObjects underTest = + underTest( + ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + allPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + allPrimitiveDataTypesFromRowFn(), + ALL_PRIMITIVE_DATA_TYPES_CODER); + input.apply(underTest); + assertThrows(Pipeline.PipelineExecutionException.class, pipeline::run); + } + + @Test + public void givenAllNullableFields_emptyRecord_parsesToRows() { + PCollection> input = emptyCsvRecords(pipeline); + CsvIORecordToObjects underTest = + underTest( + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + nullableAllPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + ROW_ROW_SERIALIZABLE_FUNCTION, + NULLABLE_ALL_PRIMITIVE_DATA_TYPES_ROW_CODER); + PAssert.that(input.apply(underTest)).empty(); + pipeline.run(); + } + + @Test + public void givenAllNullableFields_emptyRecord_parsesToPojos() { + PCollection> input = emptyCsvRecords(pipeline); + CsvIORecordToObjects underTest = + underTest( + ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + allPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + allPrimitiveDataTypesFromRowFn(), + ALL_PRIMITIVE_DATA_TYPES_CODER); + PAssert.that(input.apply(underTest)).empty(); + pipeline.run(); + } + + @Test + public void givenFieldHasCustomProcessing_parsesToRows() { + PCollection> input = + csvRecords( + pipeline, + "2024-07-25T11:25:14.000Z", + "2024-07-25T11:26:01.000Z,2024-07-25T11:26:22.000Z,2024-07-25T11:26:38.000Z"); + Row want = + Row.withSchema(TIME_CONTAINING_SCHEMA) + .withFieldValue("instant", Instant.parse("2024-07-25T11:25:14.000Z")) + .withFieldValue( + "instantList", + Arrays.asList( + Instant.parse("2024-07-25T11:26:01.000Z"), + 
Instant.parse("2024-07-25T11:26:22.000Z"), + Instant.parse("2024-07-25T11:26:38.000Z"))) + .build(); + CsvIORecordToObjects underTest = + underTest( + TIME_CONTAINING_SCHEMA, + timeContainingCsvFormat(), + timeContainingCustomProcessingMap(), + ROW_ROW_SERIALIZABLE_FUNCTION, + TIME_CONTAINING_ROW_CODER); + PAssert.that(input.apply(underTest)).containsInAnyOrder(want); + pipeline.run(); + } + + @Test + public void givenFieldHasCustomProcessing_parsesToPojos() { + PCollection> input = + csvRecords( + pipeline, + "2024-07-25T11:25:14.000Z", + "2024-07-25T11:26:01.000Z,2024-07-25T11:26:22.000Z,2024-07-25T11:26:38.000Z"); + TimeContaining want = + timeContaining( + Instant.parse("2024-07-25T11:25:14.000Z"), + Arrays.asList( + Instant.parse("2024-07-25T11:26:01.000Z"), + Instant.parse("2024-07-25T11:26:22.000Z"), + Instant.parse("2024-07-25T11:26:38.000Z"))); + CsvIORecordToObjects underTest = + underTest( + TIME_CONTAINING_SCHEMA, + timeContainingCsvFormat(), + timeContainingCustomProcessingMap(), + timeContainingFromRowFn(), + TIME_CONTAINING_POJO_CODER); + PAssert.that(input.apply(underTest)).containsInAnyOrder(want); + pipeline.run(); + } + + @Test + public void givenInvalidCell_throws() { + PCollection> input = + csvRecords(pipeline, "true", "invalid cell for Decimal", "2.0", "3.0", "4", "5", "foo"); + CsvIORecordToObjects underTest = + underTest( + ALL_PRIMITIVE_DATA_TYPES_SCHEMA, + allPrimitiveDataTypesCsvFormat(), + emptyCustomProcessingMap(), + allPrimitiveDataTypesFromRowFn(), + ALL_PRIMITIVE_DATA_TYPES_CODER); + input.apply(underTest); + assertThrows(Pipeline.PipelineExecutionException.class, pipeline::run); + } + + @Test + public void givenInvalidCustomProcessing_throws() { + PCollection> input = + csvRecords( + pipeline, + "2024-07-25T11:25:14.000Z", + "2024-15-25T11:26:01.000Z,2024-24-25T11:26:22.000Z,2024-96-25T11:26:38.000Z"); + CsvIORecordToObjects underTest = + underTest( + TIME_CONTAINING_SCHEMA, + timeContainingCsvFormat(), + timeContainingCustomProcessingMap(), + timeContainingFromRowFn(), + TIME_CONTAINING_POJO_CODER); + input.apply(underTest); + assertThrows(Pipeline.PipelineExecutionException.class, pipeline::run); + } + + private static PCollection> csvRecords(Pipeline pipeline, String... 
cells) { + return pipeline.apply( + Create.of(Collections.singletonList(Arrays.asList(cells))) + .withCoder(ListCoder.of(NullableCoder.of(StringUtf8Coder.of())))); + } + + private static PCollection> emptyCsvRecords(Pipeline pipeline) { + return pipeline.apply(Create.empty(ListCoder.of(StringUtf8Coder.of()))); + } + + private static CsvIORecordToObjects underTest( + Schema schema, + CSVFormat csvFormat, + Map> customProcessingMap, + SerializableFunction fromRowFn, + Coder coder) { + CsvIOParseConfiguration configuration = + CsvIOParseConfiguration.builder() + .setSchema(schema) + .setCsvFormat(csvFormat) + .setCustomProcessingMap(customProcessingMap) + .setFromRowFn(fromRowFn) + .setCoder(coder) + .build(); + return new CsvIORecordToObjects<>(configuration); + } + + private static Map> emptyCustomProcessingMap() { + return new HashMap<>(); + } + + private static Map> + timeContainingCustomProcessingMap() { + Map> customProcessingMap = new HashMap<>(); + customProcessingMap.put( + "instantList", + input -> { + List cells = Arrays.asList(input.split(",")); + List output = new ArrayList<>(); + cells.forEach(cell -> output.add(Instant.parse(cell))); + return output; + }); + return customProcessingMap; + } + + private static CSVFormat allPrimitiveDataTypesCsvFormat() { + return CSVFormat.DEFAULT + .withAllowDuplicateHeaderNames(false) + .withHeader("aBoolean", "aDecimal", "aDouble", "aFloat", "anInteger", "aLong", "aString"); + } + + private static CSVFormat nullableAllPrimitiveDataTypesCsvFormat() { + return CSVFormat.DEFAULT + .withAllowDuplicateHeaderNames(false) + .withHeader("aBoolean", "aDouble", "aFloat", "anInteger", "aLong", "aString") + .withNullString("null"); + } + + private static CSVFormat timeContainingCsvFormat() { + return CSVFormat.DEFAULT + .withAllowDuplicateHeaderNames(false) + .withHeader("instant", "instantList"); + } +} diff --git a/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOStringToCsvRecordTest.java b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOStringToCsvRecordTest.java new file mode 100644 index 000000000000..1b81391c4fb0 --- /dev/null +++ b/sdks/java/io/csv/src/test/java/org/apache/beam/sdk/io/csv/CsvIOStringToCsvRecordTest.java @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.csv; + +import static org.apache.beam.sdk.io.csv.CsvIOStringToCsvRecord.headerLine; + +import java.util.Arrays; +import java.util.Collections; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.QuoteMode; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link CsvIOStringToCsvRecord}. */ +@RunWith(JUnit4.class) +public class CsvIOStringToCsvRecordTest { + @Rule public final TestPipeline pipeline = TestPipeline.create(); + + private static final String[] header = {"a_string", "an_integer", "a_double"}; + + @Test + public void givenCommentMarker_skipsLine() { + CSVFormat csvFormat = csvFormat().withCommentMarker('#'); + PCollection input = + pipeline.apply( + Create.of(headerLine(csvFormat), "#should skip me", "a,1,1.1", "b,2,2.2", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenNoCommentMarker_doesntSkipLine() { + CSVFormat csvFormat = csvFormat(); + PCollection input = + pipeline.apply( + Create.of(headerLine(csvFormat), "#comment", "a,1,1.1", "b,2,2.2", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Collections.singletonList("#comment"), + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenCustomDelimiter_splitsCells() { + CSVFormat csvFormat = csvFormat().withDelimiter(';'); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a;1;1.1", "b;2;2.2", "c;3;3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenEscapeCharacter_includeInCell() { + CSVFormat csvFormat = csvFormat().withEscape('$'); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a$,b,1,1.1", "b,2,2.2", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a,b", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenHeaderComment_isNoop() { + CSVFormat csvFormat = csvFormat().withHeaderComments("abc", "def", "xyz"); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,1.1", "b,2,2.2", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenIgnoreEmptyLines_shouldSkip() { + CSVFormat csvFormat = 
csvFormat().withIgnoreEmptyLines(true); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,1.1", "", "b,2,2.2", "", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenNoIgnoreEmptyLines_isNoop() { + CSVFormat csvFormat = csvFormat().withIgnoreEmptyLines(false); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,1.1", "", "b,2,2.2", "", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenIgnoreSurroundingSpaces_removesSpaces() { + CSVFormat csvFormat = csvFormat().withIgnoreSurroundingSpaces(true); + PCollection input = + pipeline.apply( + Create.of( + headerLine(csvFormat), + " a ,1,1.1", + "b, 2 ,2.2", + "c,3, 3.3 ")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenNotIgnoreSurroundingSpaces_keepsSpaces() { + CSVFormat csvFormat = csvFormat().withIgnoreSurroundingSpaces(false); + PCollection input = + pipeline.apply( + Create.of( + headerLine(csvFormat), + " a ,1,1.1", + "b, 2 ,2.2", + "c,3, 3.3 ")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList(" a ", "1", "1.1"), + Arrays.asList("b", " 2 ", "2.2"), + Arrays.asList("c", "3", " 3.3 "))); + + pipeline.run(); + } + + @Test + public void givenNullString_parsesNullCells() { + CSVFormat csvFormat = csvFormat().withNullString("🐼"); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,🐼", "b,🐼,2.2", "🐼,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", null), + Arrays.asList("b", null, "2.2"), + Arrays.asList(null, "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenNoNullString_isNoop() { + CSVFormat csvFormat = csvFormat(); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,🐼", "b,🐼,2.2", "🐼,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "🐼"), + Arrays.asList("b", "🐼", "2.2"), + Arrays.asList("🐼", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenCustomQuoteCharacter_includesSpecialCharacters() { + CSVFormat csvFormat = csvFormat().withQuote(':'); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), ":a,:,1,1.1", "b,2,2.2", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a,", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", 
"3", "3.3"))); + pipeline.run(); + } + + @Test + public void givenQuoteModeAll_isNoop() { + CSVFormat csvFormat = csvFormat().withQuoteMode(QuoteMode.ALL); + PCollection input = + pipeline.apply( + Create.of( + headerLine(csvFormat), + "\"a\",\"1\",\"1.1\"", + "\"b\",\"2\",\"2.2\"", + "\"c\",\"3\",\"3.3\"")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenQuoteModeAllNonNull_isNoop() { + CSVFormat csvFormat = csvFormat().withNullString("N/A").withQuoteMode(QuoteMode.ALL_NON_NULL); + PCollection input = + pipeline.apply( + Create.of( + headerLine(csvFormat), + "\"a\",\"1\",N/A", + "\"b\",\"2\",\"2.2\"", + "\"c\",\"3\",\"3.3\"")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", null), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenQuoteModeMinimal_isNoop() { + CSVFormat csvFormat = csvFormat().withQuoteMode(QuoteMode.MINIMAL); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "\"a,\",1,1.1", "b,2,2.2", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a,", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + pipeline.run(); + } + + @Test + public void givenQuoteModeNonNumeric_isNoop() { + CSVFormat csvFormat = csvFormat().withQuoteMode(QuoteMode.NON_NUMERIC); + PCollection input = + pipeline.apply( + Create.of(headerLine(csvFormat), "\"a\",1,1.1", "\"b\",2,2.2", "\"c\",3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + pipeline.run(); + } + + @Test + public void givenQuoteModeNone_isNoop() { + CSVFormat csvFormat = csvFormat().withEscape('$').withQuoteMode(QuoteMode.NONE); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,1.1", "b,2,2.2", "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + pipeline.run(); + } + + @Test + public void givenCustomRecordSeparator_isNoop() { + CSVFormat csvFormat = csvFormat().withRecordSeparator("😆"); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,1.1😆b,2,2.2😆c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Collections.singletonList( + Arrays.asList("a", "1", "1.1😆b", "2", "2.2😆c", "3", "3.3"))); + pipeline.run(); + } + + @Test + public void givenSystemRecordSeparator_isNoop() { + CSVFormat csvFormat = csvFormat().withSystemRecordSeparator(); + String systemRecordSeparator = csvFormat.getRecordSeparator(); + PCollection input = + pipeline.apply( + Create.of( + headerLine(csvFormat), + "a,1,1.1" + 
systemRecordSeparator + "b,2,2.2" + systemRecordSeparator + "c,3,3.3")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + pipeline.run(); + } + + @Test + public void givenTrailingDelimiter_skipsEndingDelimiter() { + CSVFormat csvFormat = csvFormat().withTrailingDelimiter(true); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,1.1,", "b,2,2.2,", "c,3,3.3,")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + pipeline.run(); + } + + @Test + public void givenNoTrailingDelimiter_includesEndingCell() { + CSVFormat csvFormat = csvFormat().withTrailingDelimiter(false); + PCollection input = + pipeline.apply(Create.of(headerLine(csvFormat), "a,1,1.1,", "b,2,2.2,", "c,3,3.3,")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1", ""), + Arrays.asList("b", "2", "2.2", ""), + Arrays.asList("c", "3", "3.3", ""))); + pipeline.run(); + } + + @Test + public void givenTrim_removesSpaces() { + CSVFormat csvFormat = csvFormat().withTrim(true); + PCollection input = + pipeline.apply( + Create.of( + headerLine(csvFormat), + " a ,1,1.1", + "b, 2 ,2.2", + "c,3, 3.3 ")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a", "1", "1.1"), + Arrays.asList("b", "2", "2.2"), + Arrays.asList("c", "3", "3.3"))); + + pipeline.run(); + } + + @Test + public void givenNoTrim_keepsSpaces() { + CSVFormat csvFormat = csvFormat().withTrim(false); + PCollection input = + pipeline.apply( + Create.of( + headerLine(csvFormat), + " a ,1,1.1", + "b, 2 ,2.2", + "c,3, 3.3 ")); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList(" a ", "1", "1.1"), + Arrays.asList("b", " 2 ", "2.2"), + Arrays.asList("c", "3", " 3.3 "))); + + pipeline.run(); + } + + @Test + public void testSingleLineCsvRecord() { + String csvRecord = "a,1"; + PCollection input = pipeline.apply(Create.of(csvRecord)); + + CsvIOStringToCsvRecord underTest = new CsvIOStringToCsvRecord(csvFormat()); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder(Collections.singletonList(Arrays.asList("a", "1"))); + + pipeline.run(); + } + + @Test + public void testMultiLineCsvRecord() { + String csvRecords = + "\"a\r\n1\",\"a\r\n2\"" + "\n" + "\"b\r\n1\",\"b\r\n2\"" + "\n" + "\"c\r\n1\",\"c\r\n2\""; + PCollection input = pipeline.apply(Create.of(csvRecords)); + + CsvIOStringToCsvRecord underTest = + new CsvIOStringToCsvRecord(csvFormat().withRecordSeparator('\n')); + PAssert.that(input.apply(underTest)) + .containsInAnyOrder( + Arrays.asList( + Arrays.asList("a\r\n1", "a\r\n2"), + Arrays.asList("b\r\n1", "b\r\n2"), + Arrays.asList("c\r\n1", "c\r\n2"))); + + pipeline.run(); + } + + private static CSVFormat csvFormat() { + return CSVFormat.DEFAULT.withAllowDuplicateHeaderNames(false).withHeader(header); + } +} diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AppendClientInfo.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AppendClientInfo.java index 5a12e81ea79d..7505f77fb5b4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AppendClientInfo.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AppendClientInfo.java @@ -40,8 +40,8 @@ */ @AutoValue abstract class AppendClientInfo { - private final Counter activeConnections = - Metrics.counter(AppendClientInfo.class, "activeConnections"); + private final Counter activeStreamAppendClients = + Metrics.counter(AppendClientInfo.class, "activeStreamAppendClients"); abstract @Nullable BigQueryServices.StreamAppendClient getStreamAppendClient(); @@ -123,7 +123,7 @@ public AppendClientInfo withAppendClient( writeStreamService.getStreamAppendClient( streamName, getDescriptor(), useConnectionPool, missingValueInterpretation); - activeConnections.inc(); + activeStreamAppendClients.inc(); return toBuilder().setStreamName(streamName).setStreamAppendClient(client).build(); } @@ -133,7 +133,7 @@ public void close() { BigQueryServices.StreamAppendClient client = getStreamAppendClient(); if (client != null) { getCloseAppendClient().accept(client); - activeConnections.dec(); + activeStreamAppendClients.dec(); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryHelpers.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryHelpers.java index 8c600cf780ae..61bed66a3365 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryHelpers.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryHelpers.java @@ -54,6 +54,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; @@ -412,6 +413,25 @@ public static String toTableSpec(TableReference ref) { return sb.toString(); } + public static String dataCatalogName(TableReference ref, BigQueryOptions options) { + String tableIdBase; + int ix = ref.getTableId().indexOf('$'); + if (ix == -1) { + tableIdBase = ref.getTableId(); + } else { + tableIdBase = ref.getTableId().substring(0, ix); + } + String projectId; + if (!Strings.isNullOrEmpty(ref.getProjectId())) { + projectId = ref.getProjectId(); + } else if (!Strings.isNullOrEmpty(options.getBigQueryProject())) { + projectId = options.getBigQueryProject(); + } else { + projectId = options.getProject(); + } + return String.format("bigquery:%s.%s.%s", projectId, ref.getDatasetId(), tableIdBase); + } + static List getOrCreateMapListValue(Map> map, K key) { return map.computeIfAbsent(key, k -> new ArrayList<>()); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java index cd1fc6d3842c..ba76f483f774 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java @@ -109,6 +109,28 @@ public interface BigQueryOptions void setNumStorageWriteApiStreamAppendClients(Integer value); + @Description( + "When using the STORAGE_API_AT_LEAST_ONCE write method with multiplexing (ie. useStorageApiConnectionPool=true), " + + "this option sets the minimum number of connections each pool creates before any connections are shared. This is " + + "on a per worker, per region basis. Note that in practice, the minimum number of connections created is the minimum " + + "of this value and (numStorageWriteApiStreamAppendClients x num destinations). BigQuery will create this many " + + "connections at first and will only create more connections if the current ones are \"overwhelmed\". Consider " + + "increasing this value if you are running into performance issues.") + @Default.Integer(2) + Integer getMinConnectionPoolConnections(); + + void setMinConnectionPoolConnections(Integer value); + + @Description( + "When using the STORAGE_API_AT_LEAST_ONCE write method with multiplexing (ie. useStorageApiConnectionPool=true), " + + "this option sets the maximum number of connections each pool creates. This is on a per worker, per region basis. " + + "If writing to many dynamic destinations (>20) and experiencing performance issues or seeing append operations competing" + + "for streams, consider increasing this value.") + @Default.Integer(20) + Integer getMaxConnectionPoolConnections(); + + void setMaxConnectionPoolConnections(Integer value); + @Description("The max number of messages inflight that we expect each connection will retain.") @Default.Long(1000) Long getStorageWriteMaxInflightRequests(); @@ -122,6 +144,11 @@ public interface BigQueryOptions void setStorageWriteMaxInflightBytes(Long value); + @Description( + "Enables multiplexing mode, where multiple tables can share the same connection. Only available when writing with STORAGE_API_AT_LEAST_ONCE" + + " mode. This is recommended if your write operation is creating 20+ connections. When using multiplexing, consider tuning " + + "the number of connections created by the connection pool with minConnectionPoolConnections and maxConnectionPoolConnections. 
" + + "For more information, see https://cloud.google.com/bigquery/docs/write-api-best-practices#connection_pool_management") @Default.Boolean(false) Boolean getUseStorageApiConnectionPool(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java index 2bdba0b053c8..b87b6a222a4d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java @@ -69,6 +69,7 @@ import com.google.cloud.bigquery.storage.v1.BigQueryReadSettings; import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.ConnectionWorkerPool; import com.google.cloud.bigquery.storage.v1.CreateReadSessionRequest; import com.google.cloud.bigquery.storage.v1.CreateWriteStreamRequest; import com.google.cloud.bigquery.storage.v1.FinalizeWriteStreamRequest; @@ -574,7 +575,7 @@ public static class DatasetServiceImpl implements DatasetService { private final long maxRowBatchSize; // aggregate the total time spent in exponential backoff private final Counter throttlingMsecs = - Metrics.counter(DatasetServiceImpl.class, "throttling-msecs"); + Metrics.counter(DatasetServiceImpl.class, Metrics.THROTTLE_TIME_COUNTER_NAME); private @Nullable BoundedExecutorService executor; @@ -1423,6 +1424,14 @@ public StreamAppendClient getStreamAppendClient( bqIOMetadata.getBeamJobId() == null ? "" : bqIOMetadata.getBeamJobId(), bqIOMetadata.getBeamWorkerId() == null ? 
"" : bqIOMetadata.getBeamWorkerId()); + ConnectionWorkerPool.setOptions( + ConnectionWorkerPool.Settings.builder() + .setMinConnectionsPerRegion( + options.as(BigQueryOptions.class).getMinConnectionPoolConnections()) + .setMaxConnectionsPerRegion( + options.as(BigQueryOptions.class).getMaxConnectionPoolConnections()) + .build()); + StreamWriter streamWriter = StreamWriter.newBuilder(streamName, newWriteClient) .setExecutorProvider( @@ -1654,7 +1663,7 @@ public void cancel() { static class StorageClientImpl implements StorageClient { public static final Counter THROTTLING_MSECS = - Metrics.counter(StorageClientImpl.class, "throttling-msecs"); + Metrics.counter(StorageClientImpl.class, Metrics.THROTTLE_TIME_COUNTER_NAME); private transient long unreportedDelay = 0L; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java index 96abde5dc357..998c82ab8d83 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.Status; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryResourceNaming.JobType; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; @@ -119,6 +120,8 @@ protected ExtractResult extractFiles(PipelineOptions options) throws Exception { "Cannot start an export job since table %s does not exist", BigQueryHelpers.toTableSpec(tableToExtract))); } + // emit this table ID as a lineage source + Lineage.getSources().add(BigQueryHelpers.dataCatalogName(tableToExtract, bqOptions)); TableSchema schema = table.getSchema(); JobService jobService = bqServices.getJobService(bqOptions); @@ -152,7 +155,6 @@ public List> split(long desiredBundleSizeBytes, PipelineOptions if (cachedSplitResult == null) { ExtractResult res = extractFiles(options); LOG.info("Extract job produced {} files", res.extractedFiles.size()); - if (res.extractedFiles.size() > 0) { BigQueryOptions bqOptions = options.as(BigQueryOptions.class); final String extractDestinationDir = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java index fd41000ca5b9..3852d18ec12d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java @@ -20,6 +20,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableReference; import com.google.api.services.bigquery.model.TableSchema; import com.google.cloud.bigquery.storage.v1.CreateReadSessionRequest; import com.google.cloud.bigquery.storage.v1.DataFormat; @@ -33,6 +34,8 @@ import 
org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient; +import org.apache.beam.sdk.metrics.Lineage; +import org.apache.beam.sdk.metrics.StringSet; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.transforms.SerializableFunction; @@ -106,15 +109,21 @@ public List> split( @Nullable Table targetTable = getTargetTable(bqOptions); ReadSession.Builder readSessionBuilder = ReadSession.newBuilder(); + StringSet lineageSources = Lineage.getSources(); if (targetTable != null) { - readSessionBuilder.setTable( - BigQueryHelpers.toTableResourceName(targetTable.getTableReference())); + TableReference tableReference = targetTable.getTableReference(); + readSessionBuilder.setTable(BigQueryHelpers.toTableResourceName(tableReference)); + // register the table as lineage source + lineageSources.add(BigQueryHelpers.dataCatalogName(tableReference, bqOptions)); } else { // If the table does not exist targetTable will be null. // Construct the table id if we can generate it. For error recording/logging. @Nullable String tableReferenceId = getTargetTableId(bqOptions); if (tableReferenceId != null) { readSessionBuilder.setTable(tableReferenceId); + // register the table as lineage source + TableReference tableReference = BigQueryHelpers.parseTableUrn(tableReferenceId); + lineageSources.add(BigQueryHelpers.dataCatalogName(tableReference, bqOptions)); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java index 7e5299b7e674..a55bcc3fe025 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/CreateTables.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -117,8 +118,13 @@ public void processElement(ProcessContext context) { Supplier<@Nullable TableConstraints> tableConstraintsSupplier = () -> dynamicDestinations.getTableConstraints(dest); + BigQueryOptions bqOptions = context.getPipelineOptions().as(BigQueryOptions.class); + Lineage.getSinks() + .add( + BigQueryHelpers.dataCatalogName( + tableDestination1.getTableReference(), bqOptions)); return CreateTableHelpers.possiblyCreateTable( - context.getPipelineOptions().as(BigQueryOptions.class), + bqOptions, tableDestination1, schemaSupplier, tableConstraintsSupplier, diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java index ce5e7b4854e9..8a902ec6d264 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java @@ -32,6 +32,8 @@ import 
com.google.cloud.bigquery.storage.v1.WriteStream.Type; import com.google.protobuf.ByteString; import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.DescriptorValidationException; import com.google.protobuf.DynamicMessage; import io.grpc.Status; import java.io.IOException; @@ -61,6 +63,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.StorageApiDynamicDestinations.MessageConverter; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; @@ -267,6 +270,7 @@ public AppendRowsContext( } class DestinationState { + private final TableDestination tableDestination; private final String tableUrn; private final String shortTableUrn; private String streamName = ""; @@ -298,6 +302,7 @@ class DestinationState { private final boolean includeCdcColumns; public DestinationState( + TableDestination tableDestination, String tableUrn, String shortTableUrn, MessageConverter messageConverter, @@ -309,6 +314,7 @@ public DestinationState( Callable tryCreateTable, boolean includeCdcColumns) throws Exception { + this.tableDestination = tableDestination; this.tableUrn = tableUrn; this.shortTableUrn = shortTableUrn; this.pendingMessages = Lists.newArrayList(); @@ -327,6 +333,10 @@ public DestinationState( } } + public TableDestination getTableDestination() { + return tableDestination; + } + void teardown() { maybeTickleCache(); if (appendClientInfo != null) { @@ -763,7 +773,7 @@ long flush( invalidateWriteStream(); allowedRetry = 5; } else { - allowedRetry = 10; + allowedRetry = 35; } // Maximum number of times we retry before we fail the work item. @@ -826,21 +836,28 @@ long flush( c, BigQuerySinkMetrics.RpcMethod.APPEND_ROWS, shortTableUrn); if (successfulRowsReceiver != null) { - for (int i = 0; i < c.protoRows.getSerializedRowsCount(); ++i) { - ByteString rowBytes = c.protoRows.getSerializedRowsList().get(i); - try { - TableRow row = - TableRowToStorageApiProto.tableRowFromMessage( - DynamicMessage.parseFrom( - TableRowToStorageApiProto.wrapDescriptorProto( - Preconditions.checkStateNotNull(appendClientInfo) - .getDescriptor()), - rowBytes), - true); - org.joda.time.Instant timestamp = c.timestamps.get(i); - successfulRowsReceiver.outputWithTimestamp(row, timestamp); - } catch (Exception e) { - LOG.warn("Failure parsing TableRow", e); + Descriptor descriptor = null; + try { + descriptor = + TableRowToStorageApiProto.wrapDescriptorProto( + Preconditions.checkStateNotNull(appendClientInfo).getDescriptor()); + } catch (DescriptorValidationException e) { + LOG.warn( + "Failure getting proto descriptor. 
Successful output will not be produced.", + e); + } + if (descriptor != null) { + for (int i = 0; i < c.protoRows.getSerializedRowsCount(); ++i) { + ByteString rowBytes = c.protoRows.getSerializedRowsList().get(i); + try { + TableRow row = + TableRowToStorageApiProto.tableRowFromMessage( + DynamicMessage.parseFrom(descriptor, rowBytes), true); + org.joda.time.Instant timestamp = c.timestamps.get(i); + successfulRowsReceiver.outputWithTimestamp(row, timestamp); + } catch (Exception e) { + LOG.warn("Failure parsing TableRow", e); + } } } } @@ -1050,6 +1067,7 @@ DestinationState createDestinationState( try { messageConverter = messageConverters.get(destination, dynamicDestinations, datasetService); return new DestinationState( + tableDestination1, tableDestination1.getTableUrn(bigQueryOptions), tableDestination1.getShortTableUrn(), messageConverter, @@ -1089,6 +1107,11 @@ public void process( initializedDatasetService, initializedWriteStreamService, pipelineOptions.as(BigQueryOptions.class))); + Lineage.getSinks() + .add( + BigQueryHelpers.dataCatalogName( + state.getTableDestination().getTableReference(), + pipelineOptions.as(BigQueryOptions.class))); OutputReceiver failedRowsReceiver = o.get(failedRowsTag); @Nullable diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java index 1232e1a7097c..a7da19a75f85 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java @@ -62,6 +62,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.StorageApiFlushAndFinalizeDoFn.Operation; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; @@ -469,6 +470,11 @@ public void process( final DatasetService datasetService = getDatasetService(pipelineOptions); final WriteStreamService writeStreamService = getWriteStreamService(pipelineOptions); + Lineage.getSinks() + .add( + BigQueryHelpers.dataCatalogName( + tableDestination.getTableReference(), bigQueryOptions)); + Coder destinationCoder = dynamicDestinations.getDestinationCoder(); Callable tryCreateTable = () -> { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java index c1f452ba93f9..fbc17fb59704 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java @@ -1099,7 +1099,7 @@ public static TableRow tableRowFromMessage(Message message, boolean includeCdcCo FieldDescriptor fieldDescriptor = field.getKey(); Object fieldValue = field.getValue(); if (includeCdcColumns || !StorageApiCDC.COLUMNS.contains(fieldDescriptor.getName())) { - tableRow.putIfAbsent( + tableRow.put( fieldDescriptor.getName(), jsonValueFromMessageValue(fieldDescriptor, fieldValue, true)); } 
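Note on the Lineage calls added across these BigQuery files: they all build their identifier with the new BigQueryHelpers.dataCatalogName helper. A minimal sketch of what gets registered, using a hypothetical TableReference and assuming a BigQueryOptions instance named options is in scope; the '$' partition decorator is dropped, and the project falls back to --bigQueryProject or --project when the reference carries none:

    // com.google.api.services.bigquery.model.TableReference; hypothetical values, for illustration only.
    TableReference ref =
        new TableReference()
            .setProjectId("my-project")
            .setDatasetId("my_dataset")
            .setTableId("events$20240801");
    // dataCatalogName(ref, options) evaluates to "bigquery:my-project.my_dataset.events".
    Lineage.getSinks().add(BigQueryHelpers.dataCatalogName(ref, options));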
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java index a7177613c60d..1a6a6a4db70d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteRename.java @@ -34,6 +34,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService; import org.apache.beam.sdk.io.gcp.bigquery.WriteTables.Result; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.transforms.DoFn; @@ -207,6 +208,11 @@ public void processElement( // Process each destination table. // Do not copy if no temp tables are provided. if (!entry.getValue().isEmpty()) { + Lineage.getSinks() + .add( + BigQueryHelpers.dataCatalogName( + entry.getKey().getTableReference(), + c.getPipelineOptions().as(BigQueryOptions.class))); pendingJobs.add(startWriteRename(entry.getKey(), entry.getValue(), c, window)); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java index c6a7d32e2486..ace0bc5a74cd 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java @@ -49,6 +49,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.transforms.DoFn; @@ -259,6 +260,11 @@ public void processElement( } // This is a temp table. Create a new one for each partition and each pane. 
tableReference.setTableId(jobIdPrefix); + } else { + Lineage.getSinks() + .add( + BigQueryHelpers.dataCatalogName( + tableReference, c.getPipelineOptions().as(BigQueryOptions.class))); } WriteDisposition writeDisposition = firstPaneWriteDisposition; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java index d25ad7d4871d..d78ae2cb6c57 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java @@ -21,13 +21,16 @@ import static org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.api.gax.batching.BatchingException; import com.google.api.gax.rpc.ApiException; +import com.google.api.gax.rpc.DeadlineExceededException; import com.google.api.gax.rpc.InvalidArgumentException; import com.google.api.gax.rpc.NotFoundException; +import com.google.api.gax.rpc.ResourceExhaustedException; import com.google.auto.value.AutoValue; import com.google.bigtable.v2.MutateRowResponse; import com.google.bigtable.v2.Mutation; @@ -38,6 +41,7 @@ import com.google.cloud.bigtable.data.v2.models.ChangeStreamRecord; import com.google.cloud.bigtable.data.v2.models.KeyOffset; import com.google.protobuf.ByteString; +import io.grpc.StatusRuntimeException; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -69,6 +73,8 @@ import org.apache.beam.sdk.io.range.ByteKey; import org.apache.beam.sdk.io.range.ByteKeyRange; import org.apache.beam.sdk.io.range.ByteKeyRangeTracker; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; @@ -82,6 +88,7 @@ import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.StringUtils; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -1109,12 +1116,51 @@ public Write withMaxOutstandingBytes(long bytes) { * always enabled on batch writes and limits the number of outstanding requests to the Bigtable * server. * + *

<p>When enabled, will also set the default {@link #withThrottlingReportTargetMs} to 1 minute. + * This lets the runner react to increased latency in flush calls caused by flow control. + * + *
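<p>Illustrative usage on a hypothetical pipeline (not part of this change): + * {@code BigtableIO.write().withProjectId("project").withInstanceId("instance").withTableId("table").withFlowControl(true)}. + * + *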

Does not modify this object. */ public Write withFlowControl(boolean enableFlowControl) { + BigtableWriteOptions options = getBigtableWriteOptions(); + BigtableWriteOptions.Builder builder = options.toBuilder().setFlowControl(enableFlowControl); + if (enableFlowControl) { + builder = builder.setThrottlingReportTargetMs(60_000); + } + return toBuilder().setBigtableWriteOptions(builder.build()).build(); + } + + /** + * Returns a new {@link BigtableIO.Write} with client side latency based throttling enabled. + * + *

<p>Will also set {@link #withThrottlingReportTargetMs} to the same value. + */ + public Write withThrottlingTargetMs(int throttlingTargetMs) { + BigtableWriteOptions options = getBigtableWriteOptions(); + return toBuilder() + .setBigtableWriteOptions( + options + .toBuilder() + .setThrottlingTargetMs(throttlingTargetMs) + .setThrottlingReportTargetMs(throttlingTargetMs) + .build()) + .build(); + } + + /** + * Returns a new {@link BigtableIO.Write} with throttling time reporting enabled. When write + * request latency exceeds the set value, the amount greater than the target will be considered + * throttling time and reported back to the runner. + * + *

If not set, defaults to 3 min for completed batch request. Client side flowing control + * configurations (e.g. {@link #withFlowControl}, {@link #withThrottlingTargetMs} will adjust + * the default value accordingly. Set to 0 to disable throttling time reporting. + */ + public Write withThrottlingReportTargetMs(int throttlingReportTargetMs) { BigtableWriteOptions options = getBigtableWriteOptions(); return toBuilder() - .setBigtableWriteOptions(options.toBuilder().setFlowControl(enableFlowControl).build()) + .setBigtableWriteOptions( + options.toBuilder().setThrottlingReportTargetMs(throttlingReportTargetMs).build()) .build(); } @@ -1283,7 +1329,14 @@ private static class BigtableWriterFn private final Coder>> inputCoder; private final BadRecordRouter badRecordRouter; + private final Counter throttlingMsecs = + Metrics.counter(Metrics.THROTTLE_TIME_NAMESPACE, Metrics.THROTTLE_TIME_COUNTER_NAME); + + private final int throttleReportThresMsecs; + private transient Set> badRecords = null; + // Due to callback thread not supporting Beam metrics, Record pending metrics and report later. + private transient long pendingThrottlingMsecs; // Assign serviceEntry in startBundle and clear it in tearDown. @Nullable private BigtableServiceEntry serviceEntry; @@ -1301,6 +1354,8 @@ private static class BigtableWriterFn this.badRecordRouter = badRecordRouter; this.failures = new ConcurrentLinkedQueue<>(); this.id = factory.newId(); + // a request completed more than this time will be considered throttled. Disabled if set to 0 + throttleReportThresMsecs = firstNonNull(writeOptions.getThrottlingReportTargetMs(), 180_000); LOG.debug("Created Bigtable Write Fn with writeOptions {} ", writeOptions); } @@ -1322,20 +1377,52 @@ public void startBundle(StartBundleContext c) throws IOException { public void processElement(ProcessContext c, BoundedWindow window) throws Exception { checkForFailures(); KV> record = c.element(); - bigtableWriter.writeRecord(record).whenComplete(handleMutationException(record, window)); + Instant writeStart = Instant.now(); + pendingThrottlingMsecs = 0; + bigtableWriter + .writeRecord(record) + .whenComplete(handleMutationException(record, window, writeStart)); + if (pendingThrottlingMsecs > 0) { + throttlingMsecs.inc(pendingThrottlingMsecs); + } ++recordsWritten; seenWindows.compute(window, (key, count) -> (count != null ? count : 0) + 1); } private BiConsumer handleMutationException( - KV> record, BoundedWindow window) { + KV> record, BoundedWindow window, Instant writeStart) { return (MutateRowResponse result, Throwable exception) -> { if (exception != null) { if (isDataException(exception)) { retryIndividualRecord(record, window); } else { + // Exception due to resource unavailable or rate limited, + // including DEADLINE_EXCEEDED and RESOURCE_EXHAUSTED. 
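+ // The elapsed write time is staged in pendingThrottlingMsecs rather than reported here, because + // this completion callback may run on a client thread that cannot update Beam metrics; + // processElement increments throttlingMsecs with the staged value.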
+ boolean isResourceException = false; + if (exception instanceof StatusRuntimeException) { + StatusRuntimeException se = (StatusRuntimeException) exception; + if (io.grpc.Status.DEADLINE_EXCEEDED.equals(se.getStatus()) + || io.grpc.Status.RESOURCE_EXHAUSTED.equals(se.getStatus())) { + isResourceException = true; + } + } else if (exception instanceof DeadlineExceededException + || exception instanceof ResourceExhaustedException) { + isResourceException = true; + } + if (isResourceException) { + pendingThrottlingMsecs = new Duration(writeStart, Instant.now()).getMillis(); + } failures.add(new BigtableWriteException(record, exception)); } + } else { + // add the excessive amount to throttling metrics if elapsed time > target latency + if (throttleReportThresMsecs > 0) { + long excessTime = + new Duration(writeStart, Instant.now()).getMillis() - throttleReportThresMsecs; + if (excessTime > 0) { + pendingThrottlingMsecs = excessTime; + } + } } }; } @@ -1371,8 +1458,8 @@ private static boolean isDataException(Throwable e) { @FinishBundle public void finishBundle(FinishBundleContext c) throws Exception { try { - if (bigtableWriter != null) { + Instant closeStart = Instant.now(); try { bigtableWriter.close(); } catch (IOException e) { @@ -1381,9 +1468,18 @@ public void finishBundle(FinishBundleContext c) throws Exception { // to the error queue. Bigtable will successfully write other failures in the batch, // so this exception should be ignored if (!(e.getCause() instanceof BatchingException)) { + throttlingMsecs.inc(new Duration(closeStart, Instant.now()).getMillis()); throw e; } } + // add the excessive amount to throttling metrics if elapsed time > target latency + if (throttleReportThresMsecs > 0) { + long excessTime = + new Duration(closeStart, Instant.now()).getMillis() - throttleReportThresMsecs; + if (excessTime > 0) { + throttlingMsecs.inc(excessTime); + } + } bigtableWriter = null; } @@ -2015,7 +2111,7 @@ public BigtableWriteException(KV> record, Throwab super( String.format( "Error mutating row %s with mutations %s", - record.getKey().toStringUtf8(), record.getValue()), + record.getKey().toStringUtf8(), StringUtils.leftTruncate(record.getValue(), 100)), cause); this.record = record; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java index 06e0108259d5..10cfa724c2ad 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableServiceImpl.java @@ -24,6 +24,7 @@ import com.google.api.gax.batching.BatchingException; import com.google.api.gax.grpc.GrpcCallContext; import com.google.api.gax.rpc.ApiException; +import com.google.api.gax.rpc.DeadlineExceededException; import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.StreamController; @@ -611,6 +612,9 @@ public void onFailure(Throwable throwable) { if (throwable instanceof StatusRuntimeException) { serviceCallMetric.call( ((StatusRuntimeException) throwable).getStatus().getCode().value()); + } else if (throwable instanceof DeadlineExceededException) { + // incoming throwable can be a StatusRuntimeException or a specific grpc ApiException + serviceCallMetric.call(504); } else { serviceCallMetric.call("unknown"); } diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteOptions.java index a63cc575809b..5963eb6be3ce 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteOptions.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteOptions.java @@ -57,6 +57,9 @@ abstract class BigtableWriteOptions implements Serializable { /** Returns the target latency if latency based throttling is enabled. */ abstract @Nullable Integer getThrottlingTargetMs(); + /** Returns the target latency if latency based throttling report to runner is enabled. */ + abstract @Nullable Integer getThrottlingReportTargetMs(); + /** Returns true if batch write flow control is enabled. Otherwise return false. */ abstract @Nullable Boolean getFlowControl(); @@ -88,6 +91,8 @@ abstract static class Builder { abstract Builder setThrottlingTargetMs(int targetMs); + abstract Builder setThrottlingReportTargetMs(int targetMs); + abstract Builder setFlowControl(boolean enableFlowControl); abstract Builder setCloseWaitTimeout(Duration timeout); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java index 826710d9c588..c90ec97bfe35 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java @@ -42,6 +42,7 @@ import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Manual; +import org.apache.beam.sdk.util.SerializableSupplier; import org.apache.beam.sdk.values.KV; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java index 86cd7a3439aa..1563b0b059f2 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/DatastoreV1.java @@ -1711,7 +1711,7 @@ static class DatastoreWriterFn extends DoFn { private WriteBatcher writeBatcher; private transient AdaptiveThrottler adaptiveThrottler; private final Counter throttlingMsecs = - Metrics.counter(DatastoreWriterFn.class, "throttling-msecs"); + Metrics.counter(DatastoreWriterFn.class, Metrics.THROTTLE_TIME_COUNTER_NAME); private final Counter rpcErrors = Metrics.counter(DatastoreWriterFn.class, "datastoreRpcErrors"); private final Counter rpcSuccesses = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/RampupThrottlingFn.java 
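With the counter renames above, Bigtable, Datastore and ramp-up throttling all report under the shared counter name, which runners can aggregate uniformly. A hedged sketch of a user DoFn reporting throttled time the same way; the DoFn and its backoff helper are hypothetical:

import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.transforms.DoFn;

class ThrottledWriteFn extends DoFn<String, Void> { // hypothetical DoFn
  private final Counter throttlingMsecs =
      Metrics.counter(ThrottledWriteFn.class, Metrics.THROTTLE_TIME_COUNTER_NAME);

  @ProcessElement
  public void processElement(@Element String element) {
    long throttledMs = writeWithBackoff(element); // placeholder: milliseconds spent blocked
    if (throttledMs > 0) {
      throttlingMsecs.inc(throttledMs);
    }
  }

  private long writeWithBackoff(String element) {
    return 0; // placeholder for a rate-limited write
  }
}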
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/RampupThrottlingFn.java index db098c0a5166..ae94d4b612d0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/RampupThrottlingFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/datastore/RampupThrottlingFn.java @@ -53,7 +53,8 @@ public class RampupThrottlingFn extends DoFn implements Serializable { private final PCollectionView firstInstantSideInput; @VisibleForTesting - Counter throttlingMsecs = Metrics.counter(RampupThrottlingFn.class, "throttling-msecs"); + Counter throttlingMsecs = + Metrics.counter(RampupThrottlingFn.class, Metrics.THROTTLE_TIME_COUNTER_NAME); // Initialized on every setup. private transient MovingFunction successfulOps; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreOptions.java index a292a106e51f..5adc9ef38f36 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreOptions.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreOptions.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.gcp.firestore; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; @@ -75,4 +76,14 @@ public interface FirestoreOptions extends PipelineOptions { * @param host the host and port to connect to */ void setFirestoreHost(String host); + + /** The Firestore project ID to connect to. */ + @Description("Firestore project ID") + @Nullable + String getFirestoreProject(); + + /** + * Set the Firestore project ID, it will override the value from {@link GcpOptions#getProject()}. 
+ */ + void setFirestoreProject(String firestoreProject); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java index b4a334b75c99..51e5efa380e8 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1ReadFn.java @@ -635,9 +635,14 @@ public void setup() { /** {@inheritDoc} */ @Override public final void startBundle(StartBundleContext c) { - String project = c.getPipelineOptions().as(GcpOptions.class).getProject(); + String project = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject(); + if (project == null) { + project = c.getPipelineOptions().as(GcpOptions.class).getProject(); + } projectId = - requireNonNull(project, "project must be defined on GcpOptions of PipelineOptions"); + requireNonNull( + project, + "project must be defined on FirestoreOptions or GcpOptions of PipelineOptions"); firestoreStub = firestoreStatefulComponentFactory.getFirestoreStub(c.getPipelineOptions()); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java index 3e9e1890c1e3..09378d4f80c5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1WriteFn.java @@ -202,11 +202,16 @@ public void setup() { @Override public final void startBundle(StartBundleContext c) { - String project = c.getPipelineOptions().as(GcpOptions.class).getProject(); + String project = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject(); + if (project == null) { + project = c.getPipelineOptions().as(GcpOptions.class).getProject(); + } String databaseId = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb(); databaseRootName = DatabaseRootName.of( - requireNonNull(project, "project must be defined on GcpOptions of PipelineOptions"), + requireNonNull( + project, + "project must be defined on FirestoreOptions or GcpOptions of PipelineOptions"), requireNonNull( databaseId, "firestoreDb must be defined on FirestoreOptions of PipelineOptions")); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java index 01848d92d928..6233cf669080 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java @@ -1488,7 +1488,7 @@ public PDone expand(PCollection input) { .get(BAD_RECORD_TAG) .setCoder(BadRecord.getCoder(input.getPipeline()))); PCollection pubsubMessages = - pubsubMessageTuple.get(pubsubMessageTupleTag).setCoder(new PubsubMessageWithTopicCoder()); + pubsubMessageTuple.get(pubsubMessageTupleTag).setCoder(PubsubMessageWithTopicCoder.of()); switch (input.isBounded()) { case BOUNDED: pubsubMessages.apply( diff --git 
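The Firestore read and write Fns above now prefer the new firestoreProject option and fall back to the GCP project. A minimal sketch of setting it, with hypothetical project ids:

import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

FirestoreOptions options =
    PipelineOptionsFactory.fromArgs("--project=host-project", "--firestoreProject=data-project")
        .as(FirestoreOptions.class);

// Resolution mirrors the Fn logic above: FirestoreOptions first, then GcpOptions.
String project = options.getFirestoreProject();
if (project == null) {
  project = options.as(GcpOptions.class).getProject();
}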
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessageWithTopicCoder.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessageWithTopicCoder.java index d10b9a2f1066..768aebe54e65 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessageWithTopicCoder.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessageWithTopicCoder.java @@ -45,8 +45,8 @@ public static Coder of(TypeDescriptor ignored) { return of(); } - public static PubsubMessageWithAttributesAndMessageIdCoder of() { - return new PubsubMessageWithAttributesAndMessageIdCoder(); + public static PubsubMessageWithTopicCoder of() { + return new PubsubMessageWithTopicCoder(); } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java index 589af5236de4..60ba6e0c65c0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java @@ -24,7 +24,6 @@ import com.google.cloud.pubsublite.PartitionLookupUtils; import com.google.cloud.pubsublite.SubscriptionPath; import com.google.cloud.pubsublite.TopicPath; -import org.apache.beam.sdk.testing.SerializableMatchers.SerializableSupplier; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Impulse; import org.apache.beam.sdk.transforms.PTransform; @@ -34,6 +33,7 @@ import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.splittabledofn.SplitResult; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators; +import org.apache.beam.sdk.util.SerializableSupplier; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java index dc3844504218..943efc9883b6 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java @@ -30,6 +30,7 @@ import com.google.cloud.spanner.DatabaseAdminClient; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; +import com.google.cloud.spanner.SessionPoolOptions; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; import com.google.cloud.spanner.v1.stub.SpannerStubSettings; @@ -229,7 +230,9 @@ static SpannerOptions buildSpannerOptions(SpannerConfig spannerConfig) { if (credentials != null && credentials.get() != null) { builder.setCredentials(credentials.get()); } - + SessionPoolOptions sessionPoolOptions = + SessionPoolOptions.newBuilder().setFailIfPoolExhausted().build(); + builder.setSessionPoolOption(sessionPoolOptions); return builder.build(); } diff --git 
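Because the static factory above now returns the topic-aware coder, dynamic-topic writes can use it directly. A hedged sketch with a hypothetical topic and payload (imports for Create, ImmutableMap and StandardCharsets assumed):

PCollection<PubsubMessage> messages =
    pipeline.apply(
        Create.of(
                new PubsubMessage("payload".getBytes(StandardCharsets.UTF_8), ImmutableMap.of())
                    .withTopic("projects/my-project/topics/my-topic"))
            .withCoder(PubsubMessageWithTopicCoder.of()));
messages.apply(PubsubIO.writeMessagesDynamic());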
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java index b290cee89e28..5c43666e79e5 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java @@ -20,6 +20,7 @@ import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryResourceNaming.createTempTableReference; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -45,6 +46,7 @@ import java.util.concurrent.ExecutionException; import org.apache.avro.specific.SpecificDatumReader; import org.apache.avro.specific.SpecificRecordBase; +import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.SerializableCoder; @@ -58,6 +60,10 @@ import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.metrics.Lineage; +import org.apache.beam.sdk.metrics.MetricNameFilter; +import org.apache.beam.sdk.metrics.MetricQueryResults; +import org.apache.beam.sdk.metrics.MetricsFilter; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; @@ -344,6 +350,21 @@ private void checkTypedReadQueryObjectWithValidate( assertEquals(validate, read.getValidate()); } + private void checkLineageSourceMetric(PipelineResult pipelineResult, String tableName) { + MetricQueryResults lineageMetrics = + pipelineResult + .metrics() + .queryMetrics( + MetricsFilter.builder() + .addNameFilter( + MetricNameFilter.named( + Lineage.LINEAGE_NAMESPACE, Lineage.SOURCE_METRIC_NAME)) + .build()); + assertThat( + lineageMetrics.getStringSets().iterator().next().getCommitted().getStringSet(), + contains("bigquery:" + tableName.replace(':', '.'))); + } + @Before public void setUp() throws ExecutionException, IOException, InterruptedException { FakeDatasetService.setUp(); @@ -578,7 +599,11 @@ public void processElement(ProcessContext c) throws Exception { new MyData("a", 1L, bd1, bd2), new MyData("b", 2L, bd1, bd2), new MyData("c", 3L, bd1, bd2))); - p.run(); + PipelineResult result = p.run(); + // Skip when direct runner splits outside of a counters context. 
+ if (useTemplateCompatibility) { + checkLineageSourceMetric(result, "non-executing-project:somedataset.sometable"); + } } @Test diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index d303948fe443..bc90d4c8bae7 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -28,6 +28,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -93,6 +94,7 @@ import org.apache.avro.io.DatumWriter; import org.apache.avro.io.Encoder; import org.apache.beam.runners.direct.DirectOptions; +import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; @@ -115,6 +117,10 @@ import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.metrics.Lineage; +import org.apache.beam.sdk.metrics.MetricNameFilter; +import org.apache.beam.sdk.metrics.MetricQueryResults; +import org.apache.beam.sdk.metrics.MetricsFilter; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.schemas.JavaFieldSchema; @@ -278,6 +284,20 @@ public void evaluate() throws Throwable { .withDatasetService(fakeDatasetService) .withJobService(fakeJobService); + private void checkLineageSinkMetric(PipelineResult pipelineResult, String tableName) { + MetricQueryResults lineageMetrics = + pipelineResult + .metrics() + .queryMetrics( + MetricsFilter.builder() + .addNameFilter( + MetricNameFilter.named(Lineage.LINEAGE_NAMESPACE, Lineage.SINK_METRIC_NAME)) + .build()); + assertThat( + lineageMetrics.getStringSets().iterator().next().getCommitted().getStringSet(), + hasItem("bigquery:" + tableName.replace(':', '.'))); + } + @Before public void setUp() throws ExecutionException, IOException, InterruptedException { FakeDatasetService.setUp(); @@ -488,7 +508,7 @@ private void verifySideInputs() { .containsInAnyOrder(expectedTables); } - p.run(); + PipelineResult pipelineResult = p.run(); Map> expectedTableRows = Maps.newHashMap(); for (String anUserList : userList) { @@ -505,6 +525,7 @@ private void verifySideInputs() { assertThat( fakeDatasetService.getAllRows("project-id", "dataset-id", "userid-" + entry.getKey()), containsInAnyOrder(Iterables.toArray(entry.getValue(), TableRow.class))); + checkLineageSinkMetric(pipelineResult, "project-id.dataset-id.userid-" + entry.getKey()); } } @@ -680,7 +701,7 @@ public void runStreamingFileLoads(String tableRef, boolean useTempTables, boolea } p.apply(testStream).apply(writeTransform); - p.run(); + PipelineResult pipelineResult = p.run(); final int projectIdSplitter = tableRef.indexOf(':'); final String projectId = @@ -689,6 +710,9 @@ public void runStreamingFileLoads(String tableRef, boolean useTempTables, boolea assertThat( 
fakeDatasetService.getAllRows(projectId, "dataset-id", "table-id"), containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); + + checkLineageSinkMetric( + pipelineResult, tableRef.contains(projectId) ? tableRef : projectId + ":" + tableRef); } public void runStreamingFileLoads(String tableRef) throws Exception { @@ -828,11 +852,12 @@ public void testBatchFileLoads() throws Exception { PAssert.that(result.getSuccessfulTableLoads()) .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); - p.run(); + PipelineResult pipelineResult = p.run(); assertThat( fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); + checkLineageSinkMetric(pipelineResult, "project-id.dataset-id.table-id"); } @Test @@ -861,11 +886,12 @@ public void testBatchFileLoadsWithTempTables() throws Exception { PAssert.that(result.getSuccessfulTableLoads()) .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); - p.run(); + PipelineResult pipelineResult = p.run(); assertThat( fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); + checkLineageSinkMetric(pipelineResult, "project-id.dataset-id.table-id"); } @Test diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySinkMetricsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySinkMetricsTest.java index 50660326275c..8695a445c118 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySinkMetricsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySinkMetricsTest.java @@ -37,6 +37,7 @@ import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Histogram; import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.HistogramData; import org.apache.beam.sdk.values.KV; @@ -178,7 +179,8 @@ public void testThrottledTimeCounter() throws Exception { testContainer.assertPerWorkerCounterValue(counterName, 1L); counterName = - MetricName.named(BigQueryServicesImpl.StorageClientImpl.class, "throttling-msecs"); + MetricName.named( + BigQueryServicesImpl.StorageClientImpl.class, Metrics.THROTTLE_TIME_COUNTER_NAME); assertEquals(1L, (long) testContainer.getCounter(counterName).getCumulative()); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java index dd6a55ff4378..e5049b037010 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java @@ -436,6 +436,21 @@ public void testWriteValidationFailsMissingOptionsAndInstanceAndProject() { write.expand(null); } + @Test + public void testWriteClientRateLimitingAlsoSetReportMsecs() { + // client side flow control + BigtableIO.Write write = BigtableIO.write().withTableId("table").withFlowControl(true); + assertEquals( + 60_000, (int) checkNotNull(write.getBigtableWriteOptions().getThrottlingReportTargetMs())); + + // client side latency based 
throttling + int targetMs = 30_000; + write = BigtableIO.write().withTableId("table").withThrottlingTargetMs(targetMs); + assertEquals( + targetMs, + (int) checkNotNull(write.getBigtableWriteOptions().getThrottlingReportTargetMs())); + } + /** Helper function to make a single row mutation to be written. */ private static KV> makeWrite(String key, String value) { ByteString rowKey = ByteString.copyFromUtf8(key); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java index d4f9da768088..582d9e709c9a 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.io.gcp.bigtable.changestreams.restriction.StreamProgress; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; +import org.apache.beam.sdk.util.SerializableSupplier; import org.apache.beam.sdk.values.KV; import org.joda.time.Duration; import org.joda.time.Instant; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreTestingHelper.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreTestingHelper.java index a57dd688d4af..d3a82bf24ced 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreTestingHelper.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/it/FirestoreTestingHelper.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.firestore.it; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull; + import com.google.api.core.ApiFunction; import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; @@ -132,7 +134,8 @@ public FirestoreTestingHelper(CleanupMode cleanupMode) { firestoreOptions = FirestoreOptions.newBuilder() .setCredentials(gcpOptions.getGcpCredential()) - .setProjectId(gcpOptions.getProject()) + .setProjectId( + firstNonNull(firestoreBeamOptions.getFirestoreProject(), gcpOptions.getProject())) .setDatabaseId(firestoreBeamOptions.getFirestoreDb()) .setHost(firestoreBeamOptions.getFirestoreHost()) .build(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java index fe6338a501c4..3027db6aee9d 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java @@ -732,7 +732,7 @@ public void testWriteMalformedMessagesWithErrorHandler() throws Exception { PCollection messages = pipeline.apply( Create.timestamped(ImmutableList.of(pubsubMsg, failingPubsubMsg)) - .withCoder(new PubsubMessageWithTopicCoder())); + .withCoder(PubsubMessageWithTopicCoder.of())); messages.setIsBoundedInternal(PCollection.IsBounded.BOUNDED); ErrorHandler> 
badRecordErrorHandler = pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); @@ -882,7 +882,7 @@ public void testDynamicTopics(boolean isBounded) throws IOException { PCollection messages = pipeline.apply( - Create.timestamped(pubsubMessages).withCoder(new PubsubMessageWithTopicCoder())); + Create.timestamped(pubsubMessages).withCoder(PubsubMessageWithTopicCoder.of())); if (!isBounded) { messages = messages.setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED); } @@ -919,7 +919,7 @@ public void testBigMessageBounded() throws IOException { PCollection messages = pipeline.apply( Create.timestamped(ImmutableList.of(pubsubMsg)) - .withCoder(new PubsubMessageWithTopicCoder())); + .withCoder(PubsubMessageWithTopicCoder.of())); messages.setIsBoundedInternal(PCollection.IsBounded.BOUNDED); messages.apply(PubsubIO.writeMessagesDynamic().withClientFactory(factory)); pipeline.run(); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java index c9b6bae45b98..be68083bb28c 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java @@ -223,7 +223,7 @@ public void testDynamicTopics() throws IOException { Instant.ofEpochMilli(o.getTimestampMsSinceEpoch()))) .collect(Collectors.toList()); - p.apply(Create.timestamped(pubsubMessages).withCoder(new PubsubMessageWithTopicCoder())) + p.apply(Create.timestamped(pubsubMessages).withCoder(PubsubMessageWithTopicCoder.of())) .apply(sink); p.run(); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/FakeSerializable.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/FakeSerializable.java index 13d44ddfeebf..ded551bb9b05 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/FakeSerializable.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/FakeSerializable.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.beam.sdk.testing.SerializableMatchers.SerializableSupplier; +import org.apache.beam.sdk.util.SerializableSupplier; /** * A FakeSerializable hides a non-serializable object in a static map and returns a handle into the diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java index 31b1ad34179c..5d4d99beaeab 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java @@ -25,9 +25,9 @@ import com.google.cloud.pubsublite.SubscriptionPath; import com.google.cloud.pubsublite.TopicPath; import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.SerializableMatchers.SerializableSupplier; import 
org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.SerializableSupplier; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Duration; import org.junit.Before; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java index b80fba31d3a2..70105f820536 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.spanner; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -163,5 +164,6 @@ public void testBuildSpannerOptionsWithCredential() { assertEquals("project", options.getProjectId()); assertEquals("test-role", options.getDatabaseRole()); assertEquals(testCredential, options.getCredentials()); + assertNotNull(options.getSessionPoolOptions()); } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergCatalogConfig.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergCatalogConfig.java index fefef4aa4917..2956d75a266e 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergCatalogConfig.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergCatalogConfig.java @@ -19,214 +19,35 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; -import javax.annotation.Nullable; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import java.util.Properties; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.hadoop.conf.Configuration; -import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.CatalogUtil; import org.checkerframework.dataflow.qual.Pure; @AutoValue public abstract class IcebergCatalogConfig implements Serializable { - - @Pure - public abstract String getName(); - - /* Core Properties */ - @Pure - public abstract @Nullable String getIcebergCatalogType(); - - @Pure - public abstract @Nullable String getCatalogImplementation(); - - @Pure - public abstract @Nullable String getFileIOImplementation(); - - @Pure - public abstract @Nullable String getWarehouseLocation(); - - @Pure - public abstract @Nullable String getMetricsReporterImplementation(); - - /* Caching */ - @Pure - public abstract boolean getCacheEnabled(); - - @Pure - public abstract boolean getCacheCaseSensitive(); - - @Pure - public abstract long getCacheExpirationIntervalMillis(); - - @Pure - public abstract boolean getIOManifestCacheEnabled(); - - @Pure - public abstract long getIOManifestCacheExpirationIntervalMillis(); - - @Pure - public abstract long getIOManifestCacheMaxTotalBytes(); - - @Pure - public abstract long getIOManifestCacheMaxContentLength(); - - @Pure - public abstract @Nullable String getUri(); - - @Pure - public abstract int getClientPoolSize(); - - @Pure - public abstract long getClientPoolEvictionIntervalMs(); - - @Pure - public abstract @Nullable String getClientPoolCacheKeys(); - - @Pure - public abstract 
@Nullable String getLockImplementation(); - - @Pure - public abstract long getLockHeartbeatIntervalMillis(); - - @Pure - public abstract long getLockHeartbeatTimeoutMillis(); - - @Pure - public abstract int getLockHeartbeatThreads(); - - @Pure - public abstract long getLockAcquireIntervalMillis(); - - @Pure - public abstract long getLockAcquireTimeoutMillis(); - - @Pure - public abstract @Nullable String getAppIdentifier(); - - @Pure - public abstract @Nullable String getUser(); - @Pure - public abstract long getAuthSessionTimeoutMillis(); + public abstract String getCatalogName(); @Pure - public abstract @Nullable Configuration getConfiguration(); + public abstract Properties getProperties(); @Pure public static Builder builder() { - return new AutoValue_IcebergCatalogConfig.Builder() - .setIcebergCatalogType(null) - .setCatalogImplementation(null) - .setFileIOImplementation(null) - .setWarehouseLocation(null) - .setMetricsReporterImplementation(null) // TODO: Set this to our implementation - .setCacheEnabled(CatalogProperties.CACHE_ENABLED_DEFAULT) - .setCacheCaseSensitive(CatalogProperties.CACHE_CASE_SENSITIVE_DEFAULT) - .setCacheExpirationIntervalMillis(CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS_DEFAULT) - .setIOManifestCacheEnabled(CatalogProperties.IO_MANIFEST_CACHE_ENABLED_DEFAULT) - .setIOManifestCacheExpirationIntervalMillis( - CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS_DEFAULT) - .setIOManifestCacheMaxTotalBytes( - CatalogProperties.IO_MANIFEST_CACHE_MAX_TOTAL_BYTES_DEFAULT) - .setIOManifestCacheMaxContentLength( - CatalogProperties.IO_MANIFEST_CACHE_MAX_CONTENT_LENGTH_DEFAULT) - .setUri(null) - .setClientPoolSize(CatalogProperties.CLIENT_POOL_SIZE_DEFAULT) - .setClientPoolEvictionIntervalMs( - CatalogProperties.CLIENT_POOL_CACHE_EVICTION_INTERVAL_MS_DEFAULT) - .setClientPoolCacheKeys(null) - .setLockImplementation(null) - .setLockHeartbeatIntervalMillis(CatalogProperties.LOCK_HEARTBEAT_INTERVAL_MS_DEFAULT) - .setLockHeartbeatTimeoutMillis(CatalogProperties.LOCK_HEARTBEAT_TIMEOUT_MS_DEFAULT) - .setLockHeartbeatThreads(CatalogProperties.LOCK_HEARTBEAT_THREADS_DEFAULT) - .setLockAcquireIntervalMillis(CatalogProperties.LOCK_ACQUIRE_INTERVAL_MS_DEFAULT) - .setLockAcquireTimeoutMillis(CatalogProperties.LOCK_HEARTBEAT_TIMEOUT_MS_DEFAULT) - .setAppIdentifier(null) - .setUser(null) - .setAuthSessionTimeoutMillis(CatalogProperties.AUTH_SESSION_TIMEOUT_MS_DEFAULT) - .setConfiguration(null); - } - - @Pure - public ImmutableMap properties() { - return new PropertyBuilder() - .put(CatalogUtil.ICEBERG_CATALOG_TYPE, getIcebergCatalogType()) - .put(CatalogProperties.CATALOG_IMPL, getCatalogImplementation()) - .put(CatalogProperties.FILE_IO_IMPL, getFileIOImplementation()) - .put(CatalogProperties.WAREHOUSE_LOCATION, getWarehouseLocation()) - .put(CatalogProperties.METRICS_REPORTER_IMPL, getMetricsReporterImplementation()) - .put(CatalogProperties.CACHE_ENABLED, getCacheEnabled()) - .put(CatalogProperties.CACHE_CASE_SENSITIVE, getCacheCaseSensitive()) - .put(CatalogProperties.CACHE_EXPIRATION_INTERVAL_MS, getCacheExpirationIntervalMillis()) - .build(); + return new AutoValue_IcebergCatalogConfig.Builder(); } public org.apache.iceberg.catalog.Catalog catalog() { - Configuration conf = getConfiguration(); - if (conf == null) { - conf = new Configuration(); - } - return CatalogUtil.buildIcebergCatalog(getName(), properties(), conf); + return CatalogUtil.buildIcebergCatalog( + getCatalogName(), Maps.fromProperties(getProperties()), new Configuration()); } @AutoValue.Builder public abstract 
static class Builder { + public abstract Builder setCatalogName(String catalogName); - /* Core Properties */ - public abstract Builder setName(String name); - - public abstract Builder setIcebergCatalogType(@Nullable String icebergType); - - public abstract Builder setCatalogImplementation(@Nullable String catalogImpl); - - public abstract Builder setFileIOImplementation(@Nullable String fileIOImpl); - - public abstract Builder setWarehouseLocation(@Nullable String warehouse); - - public abstract Builder setMetricsReporterImplementation(@Nullable String metricsImpl); - - /* Caching */ - public abstract Builder setCacheEnabled(boolean cacheEnabled); - - public abstract Builder setCacheCaseSensitive(boolean cacheCaseSensitive); - - public abstract Builder setCacheExpirationIntervalMillis(long expiration); - - public abstract Builder setIOManifestCacheEnabled(boolean enabled); - - public abstract Builder setIOManifestCacheExpirationIntervalMillis(long expiration); - - public abstract Builder setIOManifestCacheMaxTotalBytes(long bytes); - - public abstract Builder setIOManifestCacheMaxContentLength(long length); - - public abstract Builder setUri(@Nullable String uri); - - public abstract Builder setClientPoolSize(int size); - - public abstract Builder setClientPoolEvictionIntervalMs(long interval); - - public abstract Builder setClientPoolCacheKeys(@Nullable String keys); - - public abstract Builder setLockImplementation(@Nullable String lockImplementation); - - public abstract Builder setLockHeartbeatIntervalMillis(long interval); - - public abstract Builder setLockHeartbeatTimeoutMillis(long timeout); - - public abstract Builder setLockHeartbeatThreads(int threads); - - public abstract Builder setLockAcquireIntervalMillis(long interval); - - public abstract Builder setLockAcquireTimeoutMillis(long timeout); - - public abstract Builder setAppIdentifier(@Nullable String id); - - public abstract Builder setUser(@Nullable String user); - - public abstract Builder setAuthSessionTimeoutMillis(long timeout); - - public abstract Builder setConfiguration(@Nullable Configuration conf); + public abstract Builder setProperties(Properties props); public abstract IcebergCatalogConfig build(); } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java index 75a35e6f8a30..50e0ea8b63d1 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java @@ -22,6 +22,7 @@ import com.google.auto.value.AutoValue; import java.util.Arrays; import java.util.List; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PBegin; @@ -33,6 +34,13 @@ import org.apache.iceberg.catalog.TableIdentifier; import org.checkerframework.checker.nullness.qual.Nullable; +/** + * The underlying Iceberg connector used by {@link org.apache.beam.sdk.managed.Managed#ICEBERG}. Not + * intended to be used directly. + * + *
<p>
For internal use only; no backwards compatibility guarantees + */ +@Internal public class IcebergIO { public static WriteRows writeRows(IcebergCatalogConfig catalog) { diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java index fb32e18d9374..ef535353efd0 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java @@ -21,19 +21,21 @@ import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Properties; import org.apache.beam.sdk.io.iceberg.IcebergReadSchemaTransformProvider.Config; import org.apache.beam.sdk.managed.ManagedTransformConstants; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.iceberg.catalog.TableIdentifier; /** @@ -47,7 +49,6 @@ public class IcebergReadSchemaTransformProvider extends TypedSchemaTransformProv @Override protected SchemaTransform from(Config configuration) { - configuration.validate(); return new IcebergReadSchemaTransform(configuration); } @@ -68,21 +69,24 @@ public static Builder builder() { return new AutoValue_IcebergReadSchemaTransformProvider_Config.Builder(); } + @SchemaFieldDescription("Identifier of the Iceberg table to write to.") public abstract String getTable(); - public abstract IcebergSchemaTransformCatalogConfig getCatalogConfig(); + @SchemaFieldDescription("Name of the catalog containing the table.") + public abstract String getCatalogName(); + + @SchemaFieldDescription("Configuration properties used to set up the Iceberg catalog.") + public abstract Map getCatalogProperties(); @AutoValue.Builder public abstract static class Builder { - public abstract Builder setTable(String tables); + public abstract Builder setTable(String table); - public abstract Builder setCatalogConfig(IcebergSchemaTransformCatalogConfig catalogConfig); + public abstract Builder setCatalogName(String catalogName); - public abstract Config build(); - } + public abstract Builder setCatalogProperties(Map catalogProperties); - public void validate() { - getCatalogConfig().validate(); + public abstract Config build(); } } @@ -109,17 +113,13 @@ Row getConfigurationRow() { @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { - IcebergSchemaTransformCatalogConfig catalogConfig = configuration.getCatalogConfig(); + Properties properties = new Properties(); + properties.putAll(configuration.getCatalogProperties()); IcebergCatalogConfig.Builder catalogBuilder = - IcebergCatalogConfig.builder().setName(catalogConfig.getCatalogName()); - - if 
(!Strings.isNullOrEmpty(catalogConfig.getCatalogType())) { - catalogBuilder = catalogBuilder.setIcebergCatalogType(catalogConfig.getCatalogType()); - } - if (!Strings.isNullOrEmpty(catalogConfig.getWarehouseLocation())) { - catalogBuilder = catalogBuilder.setWarehouseLocation(catalogConfig.getWarehouseLocation()); - } + IcebergCatalogConfig.builder() + .setCatalogName(configuration.getCatalogName()) + .setProperties(properties); PCollection output = input diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergSchemaTransformCatalogConfig.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergSchemaTransformCatalogConfig.java deleted file mode 100644 index 87b3d2041985..000000000000 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergSchemaTransformCatalogConfig.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.iceberg; - -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - -import com.google.auto.value.AutoValue; -import java.util.Set; -import org.apache.beam.sdk.schemas.AutoValueSchema; -import org.apache.beam.sdk.schemas.NoSuchSchemaException; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.SchemaRegistry; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; -import org.apache.beam.sdk.util.Preconditions; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.apache.iceberg.CatalogUtil; -import org.checkerframework.checker.nullness.qual.Nullable; - -@DefaultSchema(AutoValueSchema.class) -@AutoValue -public abstract class IcebergSchemaTransformCatalogConfig { - public static Builder builder() { - return new AutoValue_IcebergSchemaTransformCatalogConfig.Builder(); - } - - public abstract String getCatalogName(); - - @SchemaFieldDescription("Valid types are: {hadoop, hive, rest}") - public abstract @Nullable String getCatalogType(); - - public abstract @Nullable String getCatalogImplementation(); - - public abstract @Nullable String getWarehouseLocation(); - - @AutoValue.Builder - public abstract static class Builder { - - public abstract Builder setCatalogName(String catalogName); - - public abstract Builder setCatalogType(String catalogType); - - public abstract Builder setCatalogImplementation(String catalogImplementation); - - public abstract Builder setWarehouseLocation(String warehouseLocation); - - public abstract IcebergSchemaTransformCatalogConfig build(); - 
} - - public static final Schema SCHEMA; - - static { - try { - // To stay consistent with our SchemaTransform configuration naming conventions, - // we sort lexicographically and convert field names to snake_case - SCHEMA = - SchemaRegistry.createDefault() - .getSchema(IcebergSchemaTransformCatalogConfig.class) - .sorted() - .toSnakeCase(); - } catch (NoSuchSchemaException e) { - throw new RuntimeException(e); - } - } - - @SuppressWarnings("argument") - public Row toRow() { - return Row.withSchema(SCHEMA) - .withFieldValue("catalog_name", getCatalogName()) - .withFieldValue("catalog_type", getCatalogType()) - .withFieldValue("catalog_implementation", getCatalogImplementation()) - .withFieldValue("warehouse_location", getWarehouseLocation()) - .build(); - } - - public static final Set VALID_CATALOG_TYPES = - Sets.newHashSet( - CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP, - CatalogUtil.ICEBERG_CATALOG_TYPE_HIVE, - CatalogUtil.ICEBERG_CATALOG_TYPE_REST); - - public void validate() { - if (!Strings.isNullOrEmpty(getCatalogType())) { - checkArgument( - VALID_CATALOG_TYPES.contains(Preconditions.checkArgumentNotNull(getCatalogType())), - "Invalid catalog type. Please pick one of %s", - VALID_CATALOG_TYPES); - } - } -} diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java index b490693a9adb..b3de7a88c541 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java @@ -21,6 +21,8 @@ import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Properties; import org.apache.beam.sdk.io.iceberg.IcebergWriteSchemaTransformProvider.Config; import org.apache.beam.sdk.managed.ManagedTransformConstants; import org.apache.beam.sdk.schemas.AutoValueSchema; @@ -39,7 +41,6 @@ import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.iceberg.catalog.TableIdentifier; /** @@ -64,7 +65,6 @@ public String description() { @Override protected SchemaTransform from(Config configuration) { - configuration.validate(); return new IcebergWriteSchemaTransform(configuration); } @@ -93,20 +93,21 @@ public static Builder builder() { @SchemaFieldDescription("Identifier of the Iceberg table to write to.") public abstract String getTable(); - @SchemaFieldDescription("Configuration parameters used to set up the Iceberg catalog.") - public abstract IcebergSchemaTransformCatalogConfig getCatalogConfig(); + @SchemaFieldDescription("Name of the catalog containing the table.") + public abstract String getCatalogName(); + + @SchemaFieldDescription("Configuration properties used to set up the Iceberg catalog.") + public abstract Map getCatalogProperties(); @AutoValue.Builder public abstract static class Builder { - public abstract Builder setTable(String tables); + public abstract Builder setTable(String table); - public abstract Builder setCatalogConfig(IcebergSchemaTransformCatalogConfig catalogConfig); + public abstract Builder setCatalogName(String catalogName); - public abstract Config build(); 
- } + public abstract Builder setCatalogProperties(Map catalogProperties); - public void validate() { - getCatalogConfig().validate(); + public abstract Config build(); } } @@ -133,26 +134,21 @@ Row getConfigurationRow() { @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { - PCollection rows = input.get(INPUT_TAG); - IcebergSchemaTransformCatalogConfig catalogConfig = configuration.getCatalogConfig(); - - IcebergCatalogConfig.Builder catalogBuilder = - IcebergCatalogConfig.builder().setName(catalogConfig.getCatalogName()); + Properties properties = new Properties(); + properties.putAll(configuration.getCatalogProperties()); - if (!Strings.isNullOrEmpty(catalogConfig.getCatalogType())) { - catalogBuilder = catalogBuilder.setIcebergCatalogType(catalogConfig.getCatalogType()); - } - if (!Strings.isNullOrEmpty(catalogConfig.getWarehouseLocation())) { - catalogBuilder = catalogBuilder.setWarehouseLocation(catalogConfig.getWarehouseLocation()); - } + IcebergCatalogConfig catalog = + IcebergCatalogConfig.builder() + .setCatalogName(configuration.getCatalogName()) + .setProperties(properties) + .build(); // TODO: support dynamic destinations IcebergWriteResult result = rows.apply( - IcebergIO.writeRows(catalogBuilder.build()) - .to(TableIdentifier.parse(configuration.getTable()))); + IcebergIO.writeRows(catalog).to(TableIdentifier.parse(configuration.getTable()))); PCollection snapshots = result diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java index 467a2cbaf242..1c5686bfde91 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java @@ -108,8 +108,7 @@ public static void beforeClass() { catalogHadoopConf = new Configuration(); catalogHadoopConf.set("fs.gs.project.id", options.getProject()); - catalogHadoopConf.set( - "fs.gs.auth.service.account.json.keyfile", System.getenv("GOOGLE_APPLICATION_CREDENTIALS")); + catalogHadoopConf.set("fs.gs.auth.type", "APPLICATION_DEFAULT"); } @Before @@ -206,12 +205,12 @@ public void testRead() throws Exception { Map config = ImmutableMap.builder() .put("table", tableId.toString()) + .put("catalog_name", "test-name") .put( - "catalog_config", + "catalog_properties", ImmutableMap.builder() - .put("catalog_name", "hadoop") - .put("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .put("warehouse_location", warehouseLocation) + .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) + .put("warehouse", warehouseLocation) .build()) .build(); @@ -246,12 +245,12 @@ public void testWrite() { Map config = ImmutableMap.builder() .put("table", tableId.toString()) + .put("catalog_name", "test-name") .put( - "catalog_config", + "catalog_properties", ImmutableMap.builder() - .put("catalog_name", "hadoop") - .put("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .put("warehouse_location", warehouseLocation) + .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) + .put("warehouse", warehouseLocation) .build()) .build(); diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java index 12d86811e604..d6db3f689117 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java +++ 
b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import java.util.List; +import java.util.Properties; import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -93,12 +94,12 @@ public void testSimpleScan() throws Exception { .map(record -> SchemaAndRowConversions.recordToRow(schema, record)) .collect(Collectors.toList()); + Properties props = new Properties(); + props.setProperty("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + props.setProperty("warehouse", warehouse.location); + IcebergCatalogConfig catalogConfig = - IcebergCatalogConfig.builder() - .setName("hadoop") - .setIcebergCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) - .build(); + IcebergCatalogConfig.builder().setCatalogName("name").setProperties(props).build(); PCollection output = testPipeline diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java index e04eaf48cb3d..e0a584ec9da9 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java @@ -23,6 +23,7 @@ import java.io.Serializable; import java.util.List; import java.util.Map; +import java.util.Properties; import java.util.UUID; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.testing.TestPipeline; @@ -75,12 +76,12 @@ public void testSimpleAppend() throws Exception { // Create a table and add records to it. Table table = warehouse.createTable(tableId, TestFixtures.SCHEMA); + Properties props = new Properties(); + props.setProperty("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + props.setProperty("warehouse", warehouse.location); + IcebergCatalogConfig catalog = - IcebergCatalogConfig.builder() - .setName("hadoop") - .setIcebergCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) - .build(); + IcebergCatalogConfig.builder().setCatalogName("name").setProperties(props).build(); testPipeline .apply("Records To Add", Create.of(TestFixtures.asRows(TestFixtures.FILE1SNAPSHOT1))) @@ -109,12 +110,12 @@ public void testDynamicDestinationsWithoutSpillover() throws Exception { Table table2 = warehouse.createTable(table2Id, TestFixtures.SCHEMA); Table table3 = warehouse.createTable(table3Id, TestFixtures.SCHEMA); + Properties props = new Properties(); + props.setProperty("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + props.setProperty("warehouse", warehouse.location); + IcebergCatalogConfig catalog = - IcebergCatalogConfig.builder() - .setName("hadoop") - .setIcebergCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) - .build(); + IcebergCatalogConfig.builder().setCatalogName("name").setProperties(props).build(); DynamicDestinations dynamicDestinations = new DynamicDestinations() { @@ -199,12 +200,12 @@ public void testDynamicDestinationsWithSpillover() throws Exception { elementsPerTable.computeIfAbsent(tableId, ignored -> Lists.newArrayList()).add(element); } + Properties props = new Properties(); + props.setProperty("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + props.setProperty("warehouse", warehouse.location); + IcebergCatalogConfig catalog = - IcebergCatalogConfig.builder() - 
.setName("hadoop") - .setIcebergCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) - .build(); + IcebergCatalogConfig.builder().setCatalogName("name").setProperties(props).build(); DynamicDestinations dynamicDestinations = new DynamicDestinations() { diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java index 46168a487dda..bc15021fa2b0 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -52,16 +53,15 @@ public class IcebergReadSchemaTransformProviderTest { @Test public void testBuildTransformWithRow() { - Row catalogConfigRow = - Row.withSchema(IcebergSchemaTransformCatalogConfig.SCHEMA) - .withFieldValue("catalog_name", "test_name") - .withFieldValue("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .withFieldValue("warehouse_location", "test_location") - .build(); + Map properties = new HashMap<>(); + properties.put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + properties.put("warehouse", "test_location"); + Row transformConfigRow = Row.withSchema(new IcebergReadSchemaTransformProvider().configurationSchema()) .withFieldValue("table", "test_table_identifier") - .withFieldValue("catalog_config", catalogConfigRow) + .withFieldValue("catalog_name", "test-name") + .withFieldValue("catalog_properties", properties) .build(); new IcebergReadSchemaTransformProvider().from(transformConfigRow); @@ -97,17 +97,15 @@ public void testSimpleScan() throws Exception { .map(record -> SchemaAndRowConversions.recordToRow(schema, record)) .collect(Collectors.toList()); - IcebergSchemaTransformCatalogConfig catalogConfig = - IcebergSchemaTransformCatalogConfig.builder() - .setCatalogName("hadoop") - .setCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) - .build(); + Map properties = new HashMap<>(); + properties.put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + properties.put("warehouse", warehouse.location); IcebergReadSchemaTransformProvider.Config readConfig = IcebergReadSchemaTransformProvider.Config.builder() .setTable(identifier) - .setCatalogConfig(catalogConfig) + .setCatalogName("name") + .setCatalogProperties(properties) .build(); PCollection output = @@ -158,10 +156,10 @@ public void testReadUsingManagedTransform() throws Exception { String yamlConfig = String.format( "table: %s\n" - + "catalog_config: \n" - + " catalog_name: hadoop\n" - + " catalog_type: %s\n" - + " warehouse_location: %s", + + "catalog_name: test-name\n" + + "catalog_properties: \n" + + " type: %s\n" + + " warehouse: %s", identifier, CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP, warehouse.location); Map configMap = new Yaml().load(yamlConfig); diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergSchemaTransformTranslationTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergSchemaTransformTranslationTest.java index fb4c98cb0bdf..7863f7812a13 100644 --- 
a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergSchemaTransformTranslationTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergSchemaTransformTranslationTest.java @@ -25,7 +25,9 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.stream.Collectors; import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; @@ -42,6 +44,7 @@ import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.iceberg.CatalogUtil; import org.apache.iceberg.catalog.TableIdentifier; import org.junit.ClassRule; @@ -63,18 +66,19 @@ public class IcebergSchemaTransformTranslationTest { static final IcebergReadSchemaTransformProvider READ_PROVIDER = new IcebergReadSchemaTransformProvider(); + private static final Map CATALOG_PROPERTIES = + ImmutableMap.builder() + .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) + .put("warehouse", "test_location") + .build(); + @Test public void testReCreateWriteTransformFromRow() { - Row catalogConfigRow = - Row.withSchema(IcebergSchemaTransformCatalogConfig.SCHEMA) - .withFieldValue("catalog_name", "test_name") - .withFieldValue("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .withFieldValue("warehouse_location", "test_location") - .build(); Row transformConfigRow = Row.withSchema(WRITE_PROVIDER.configurationSchema()) .withFieldValue("table", "test_table_identifier") - .withFieldValue("catalog_config", catalogConfigRow) + .withFieldValue("catalog_name", "test-name") + .withFieldValue("catalog_properties", CATALOG_PROPERTIES) .build(); IcebergWriteSchemaTransform writeTransform = (IcebergWriteSchemaTransform) WRITE_PROVIDER.from(transformConfigRow); @@ -101,17 +105,11 @@ public void testWriteTransformProtoTranslation() Collections.singletonList(Row.withSchema(inputSchema).addValue("a").build()))) .setRowSchema(inputSchema); - Row catalogConfigRow = - Row.withSchema(IcebergSchemaTransformCatalogConfig.SCHEMA) - .withFieldValue("catalog_name", "test_catalog") - .withFieldValue("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .withFieldValue("catalog_implementation", "test_implementation") - .withFieldValue("warehouse_location", warehouse.location) - .build(); Row transformConfigRow = Row.withSchema(WRITE_PROVIDER.configurationSchema()) .withFieldValue("table", "test_identifier") - .withFieldValue("catalog_config", catalogConfigRow) + .withFieldValue("catalog_name", "test-name") + .withFieldValue("catalog_properties", CATALOG_PROPERTIES) .build(); IcebergWriteSchemaTransform writeTransform = @@ -158,16 +156,11 @@ public void testWriteTransformProtoTranslation() @Test public void testReCreateReadTransformFromRow() { // setting a subset of fields here. 
- Row catalogConfigRow = - Row.withSchema(IcebergSchemaTransformCatalogConfig.SCHEMA) - .withFieldValue("catalog_name", "test_name") - .withFieldValue("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .withFieldValue("warehouse_location", "test_location") - .build(); Row transformConfigRow = Row.withSchema(READ_PROVIDER.configurationSchema()) .withFieldValue("table", "test_table_identifier") - .withFieldValue("catalog_config", catalogConfigRow) + .withFieldValue("catalog_name", "test-name") + .withFieldValue("catalog_properties", CATALOG_PROPERTIES) .build(); IcebergReadSchemaTransform readTransform = @@ -188,19 +181,17 @@ public void testReadTransformProtoTranslation() throws InvalidProtocolBufferException, IOException { // First build a pipeline Pipeline p = Pipeline.create(); - Row catalogConfigRow = - Row.withSchema(IcebergSchemaTransformCatalogConfig.SCHEMA) - .withFieldValue("catalog_name", "test_catalog") - .withFieldValue("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .withFieldValue("warehouse_location", warehouse.location) - .build(); String identifier = "default.table_" + Long.toString(UUID.randomUUID().hashCode(), 16); warehouse.createTable(TableIdentifier.parse(identifier), TestFixtures.SCHEMA); + Map properties = new HashMap<>(CATALOG_PROPERTIES); + properties.put("warehouse", warehouse.location); + Row transformConfigRow = Row.withSchema(READ_PROVIDER.configurationSchema()) .withFieldValue("table", identifier) - .withFieldValue("catalog_config", catalogConfigRow) + .withFieldValue("catalog_name", "test-name") + .withFieldValue("catalog_properties", properties) .build(); IcebergReadSchemaTransform readTransform = diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java index 9ef3e9945ec9..75884f4bcf70 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java @@ -23,6 +23,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertEquals; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -61,16 +62,15 @@ public class IcebergWriteSchemaTransformProviderTest { @Test public void testBuildTransformWithRow() { - Row catalogConfigRow = - Row.withSchema(IcebergSchemaTransformCatalogConfig.SCHEMA) - .withFieldValue("catalog_name", "test_name") - .withFieldValue("catalog_type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .withFieldValue("warehouse_location", "test_location") - .build(); + Map properties = new HashMap<>(); + properties.put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + properties.put("warehouse", "test_location"); + Row transformConfigRow = Row.withSchema(new IcebergWriteSchemaTransformProvider().configurationSchema()) .withFieldValue("table", "test_table_identifier") - .withFieldValue("catalog_config", catalogConfigRow) + .withFieldValue("catalog_name", "test-name") + .withFieldValue("catalog_properties", properties) .build(); new IcebergWriteSchemaTransformProvider().from(transformConfigRow); @@ -85,15 +85,15 @@ public void testSimpleAppend() { // Create a table and add records to it. 
Table table = warehouse.createTable(tableId, TestFixtures.SCHEMA); + Map properties = new HashMap<>(); + properties.put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + properties.put("warehouse", warehouse.location); + Config config = Config.builder() .setTable(identifier) - .setCatalogConfig( - IcebergSchemaTransformCatalogConfig.builder() - .setCatalogName("hadoop") - .setCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) - .build()) + .setCatalogName("name") + .setCatalogProperties(properties) .build(); PCollectionRowTuple input = @@ -127,10 +127,10 @@ public void testWriteUsingManagedTransform() { String yamlConfig = String.format( "table: %s\n" - + "catalog_config: \n" - + " catalog_name: hadoop\n" - + " catalog_type: %s\n" - + " warehouse_location: %s", + + "catalog_name: test-name\n" + + "catalog_properties: \n" + + " type: %s\n" + + " warehouse: %s", identifier, CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP, warehouse.location); Map configMap = new Yaml().load(yamlConfig); diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/ScanSourceTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/ScanSourceTest.java index c7d5353428c2..143687e3c999 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/ScanSourceTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/ScanSourceTest.java @@ -20,6 +20,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import java.util.List; +import java.util.Properties; import java.util.UUID; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.options.PipelineOptions; @@ -64,14 +65,17 @@ public void testUnstartedReaderReadsSamesItsSource() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); + Properties props = new Properties(); + props.setProperty("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + props.setProperty("warehouse", warehouse.location); + BoundedSource source = new ScanSource( IcebergScanConfig.builder() .setCatalogConfig( IcebergCatalogConfig.builder() - .setName("hadoop") - .setIcebergCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) + .setCatalogName("name") + .setProperties(props) .build()) .setScanType(IcebergScanConfig.ScanType.TABLE) .setTableIdentifier(simpleTable.name().replace("hadoop.", "").split("\\.")) @@ -103,14 +107,17 @@ public void testInitialSplitting() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); + Properties props = new Properties(); + props.setProperty("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + props.setProperty("warehouse", warehouse.location); + BoundedSource source = new ScanSource( IcebergScanConfig.builder() .setCatalogConfig( IcebergCatalogConfig.builder() - .setName("hadoop") - .setIcebergCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) - .setWarehouseLocation(warehouse.location) + .setCatalogName("name") + .setProperties(props) .build()) .setScanType(IcebergScanConfig.ScanType.TABLE) .setTableIdentifier(simpleTable.name().replace("hadoop.", "").split("\\.")) @@ -146,14 +153,17 @@ public void testDoubleInitialSplitting() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); + Properties props = new Properties(); + props.setProperty("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); + props.setProperty("warehouse", warehouse.location); + BoundedSource source = new ScanSource( IcebergScanConfig.builder() 
                    .setCatalogConfig(
                        IcebergCatalogConfig.builder()
-                            .setName("hadoop")
-                            .setIcebergCatalogType(CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP)
-                            .setWarehouseLocation(warehouse.location)
+                            .setCatalogName("name")
+                            .setProperties(props)
                            .build())
                    .setScanType(IcebergScanConfig.ScanType.TABLE)
                    .setTableIdentifier(simpleTable.name().replace("hadoop.", "").split("\\."))
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
index 1fa459a733dd..060e660b0847 100644
--- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
@@ -1011,12 +1011,13 @@ public Write withPassword(String password) {
    * Specify the JMS topic destination name where to send messages to dynamically. The {@link
    * JmsIO.Write} acts as a publisher on the topic.
    *
-   * <p>This method is exclusive with {@link JmsIO.Write#withQueue(String) and
-   * {@link JmsIO.Write#withTopic(String)}. The user has to specify a {@link SerializableFunction}
-   * that takes {@code EventT} object as a parameter, and returns the topic name depending of the content
-   * of the event object.
+   * <p>This method is exclusive with {@link JmsIO.Write#withQueue(String)} and {@link
+   * JmsIO.Write#withTopic(String)}. The user has to specify a {@link SerializableFunction} that
+   * takes {@code EventT} object as a parameter, and returns the topic name depending of the
+   * content of the event object.
    *
    * <p>For example:
+   *
    * <pre>{@code
    * SerializableFunction topicNameMapper =
    *   (event ->
diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/CommonJms.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/CommonJms.java
index c0f8cf258d21..1d1245e6877d 100644
--- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/CommonJms.java
+++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/CommonJms.java
@@ -22,6 +22,7 @@
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.function.Supplier;
 import javax.jms.BytesMessage;
 import javax.jms.ConnectionFactory;
 import javax.jms.Message;
@@ -34,8 +35,8 @@
 import org.apache.activemq.transport.TransportFactory;
 import org.apache.activemq.transport.amqp.AmqpTransportFactory;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.util.SerializableSupplier;
 import org.apache.beam.sdk.util.ThrowingSupplier;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
 
 /**
  * A common test fixture to create a broker and connection factories for {@link JmsIOIT} & {@link
@@ -47,8 +48,6 @@ public class CommonJms implements Serializable {
   // convenient typedefs and a helper conversion functions
   interface ThrowingSerializableSupplier extends ThrowingSupplier, Serializable {}
 
-  private interface SerializableSupplier extends Serializable, Supplier {}
-
   private static  SerializableSupplier toSerializableSupplier(
       ThrowingSerializableSupplier throwingSerializableSupplier) {
     return () -> {
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 5e9ff9ab80c6..68400e83106f 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -75,6 +75,7 @@
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Redistribute;
 import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -600,6 +601,9 @@ public static  Read read() {
         .setDynamicRead(false)
         .setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime())
         .setConsumerPollingTimeout(2L)
+        .setRedistributed(false)
+        .setAllowDuplicates(false)
+        .setRedistributeNumKeys(0)
         .build();
   }
 
@@ -698,6 +702,15 @@ public abstract static class Read
     @Pure
     public abstract boolean isDynamicRead();
 
+    @Pure
+    public abstract boolean isRedistributed();
+
+    @Pure
+    public abstract boolean isAllowDuplicates();
+
+    @Pure
+    public abstract int getRedistributeNumKeys();
+
     @Pure
     public abstract @Nullable Duration getWatchTopicPartitionDuration();
 
@@ -757,6 +770,12 @@ abstract Builder setConsumerFactoryFn(
 
       abstract Builder setWatchTopicPartitionDuration(Duration duration);
 
+      abstract Builder setRedistributed(boolean withRedistribute);
+
+      abstract Builder setAllowDuplicates(boolean allowDuplicates);
+
+      abstract Builder setRedistributeNumKeys(int redistributeNumKeys);
+
       abstract Builder setTimestampPolicyFactory(
           TimestampPolicyFactory timestampPolicyFactory);
 
@@ -852,6 +871,22 @@ static  void setupExternalBuilder(
         } else {
           builder.setConsumerPollingTimeout(2L);
         }
+
+        if (config.redistribute != null) {
+          builder.setRedistributed(config.redistribute);
+          if (config.redistributeNumKeys != null) {
+            builder.setRedistributeNumKeys((int) config.redistributeNumKeys);
+          }
+          if (config.allowDuplicates != null) {
+            builder.setAllowDuplicates(config.allowDuplicates);
+          }
+
+        } else {
+          builder.setRedistributed(false);
+          builder.setRedistributeNumKeys(0);
+          builder.setAllowDuplicates(false);
+        }
+        System.out.println("xxx builder service" + builder.toString());
       }
 
       private static  Coder resolveCoder(Class> deserializer) {
@@ -916,6 +951,9 @@ public static class Configuration {
         private Boolean commitOffsetInFinalize;
         private Long consumerPollingTimeout;
         private String timestampPolicy;
+        private Integer redistributeNumKeys;
+        private Boolean redistribute;
+        private Boolean allowDuplicates;
 
         public void setConsumerConfig(Map consumerConfig) {
           this.consumerConfig = consumerConfig;
@@ -960,6 +998,18 @@ public void setTimestampPolicy(String timestampPolicy) {
         public void setConsumerPollingTimeout(Long consumerPollingTimeout) {
           this.consumerPollingTimeout = consumerPollingTimeout;
         }
+
+        public void setRedistributeNumKeys(Integer redistributeNumKeys) {
+          this.redistributeNumKeys = redistributeNumKeys;
+        }
+
+        public void setRedistribute(Boolean redistribute) {
+          this.redistribute = redistribute;
+        }
+
+        public void setAllowDuplicates(Boolean allowDuplicates) {
+          this.allowDuplicates = allowDuplicates;
+        }
       }
     }
 
@@ -1007,6 +1057,30 @@ public Read withTopicPartitions(List topicPartitions) {
       return toBuilder().setTopicPartitions(ImmutableList.copyOf(topicPartitions)).build();
     }
 
+    /**
+     * Sets a redistribute transform that hints to the runner to try to redistribute the work evenly.
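+     *
+     * <p>A minimal usage sketch (illustrative only; the broker address, topic, and deserializers
+     * are placeholders):
+     *
+     * <pre>{@code
+     * KafkaIO.<Long, String>read()
+     *     .withBootstrapServers("broker:9092")
+     *     .withTopic("topic_a")
+     *     .withKeyDeserializer(LongDeserializer.class)
+     *     .withValueDeserializer(StringDeserializer.class)
+     *     .withRedistribute()
+     *     .withRedistributeNumKeys(32)
+     *     .withAllowDuplicates(true);
+     * }</pre>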
+     */
+    public Read withRedistribute() {
+      if (getRedistributeNumKeys() == 0 && isRedistributed()) {
+        LOG.warn("This will create a key per record, which is sub-optimal for most use cases.");
+      }
+      return toBuilder().setRedistributed(true).build();
+    }
+
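+    /**
+     * Sets whether duplicate records are acceptable when redistributing (defaults to false). Only
+     * has an effect when {@link #withRedistribute()} is also set.
+     */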
+    public Read withAllowDuplicates(Boolean allowDuplicates) {
+      if (!isRedistributed()) {
+        LOG.warn("Setting this value without setting withRedistribute() will have no effect.");
+      }
+      return toBuilder().setAllowDuplicates(allowDuplicates).build();
+    }
+
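+    /**
+     * Sets the number of keys to redistribute across. A value of 0 (the default) creates a key per
+     * record. Requires {@link #withRedistribute()} to be enabled.
+     */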
+    public Read withRedistributeNumKeys(int redistributeNumKeys) {
+      checkState(
+          isRedistributed(),
+          "withRedistributeNumKeys is ignored if withRedistribute() is not enabled on the transform.");
+      return toBuilder().setRedistributeNumKeys(redistributeNumKeys).build();
+    }
+
     /**
      * Internally sets a {@link java.util.regex.Pattern} of topics to read from. All the partitions
      * from each of the matching topics are read.
@@ -1618,6 +1692,25 @@ public PCollection> expand(PBegin input) {
                   .withMaxNumRecords(kafkaRead.getMaxNumRecords());
         }
 
+        if (kafkaRead.isRedistributed()) {
+          // fail here instead.
+          checkArgument(
+              kafkaRead.isCommitOffsetsInFinalizeEnabled(),
+              "commitOffsetsInFinalize() can't be enabled with isRedistributed");
+          PCollection> output = input.getPipeline().apply(transform);
+          if (kafkaRead.getRedistributeNumKeys() == 0) {
+            return output.apply(
+                "Insert Redistribute",
+                Redistribute.>arbitrarily()
+                    .withAllowDuplicates(kafkaRead.isAllowDuplicates()));
+          } else {
+            return output.apply(
+                "Insert Redistribute with Shards",
+                Redistribute.>arbitrarily()
+                    .withAllowDuplicates(kafkaRead.isAllowDuplicates())
+                    .withNumBuckets((int) kafkaRead.getRedistributeNumKeys()));
+          }
+        }
         return input.getPipeline().apply(transform);
       }
     }
@@ -1637,6 +1730,8 @@ public PCollection> expand(PBegin input) {
                 .withKeyDeserializerProvider(kafkaRead.getKeyDeserializerProvider())
                 .withValueDeserializerProvider(kafkaRead.getValueDeserializerProvider())
                 .withManualWatermarkEstimator()
+                .withRedistribute()
+                .withAllowDuplicates() // must be set with withRedistribute option.
                 .withTimestampPolicyFactory(kafkaRead.getTimestampPolicyFactory())
                 .withCheckStopReadingFn(kafkaRead.getCheckStopReadingFn())
                 .withConsumerPollingTimeout(kafkaRead.getConsumerPollingTimeout());
@@ -1650,6 +1745,15 @@ public PCollection> expand(PBegin input) {
           readTransform =
               readTransform.withBadRecordErrorHandler(kafkaRead.getBadRecordErrorHandler());
         }
+        if (kafkaRead.isRedistributed()) {
+          readTransform = readTransform.withRedistribute();
+        }
+        if (kafkaRead.isAllowDuplicates()) {
+          readTransform = readTransform.withAllowDuplicates();
+        }
+        if (kafkaRead.getRedistributeNumKeys() > 0) {
+          readTransform = readTransform.withRedistributeNumKeys(kafkaRead.getRedistributeNumKeys());
+        }
         PCollection output;
         if (kafkaRead.isDynamicRead()) {
           Set topics = new HashSet<>();
@@ -1679,6 +1783,22 @@ public PCollection> expand(PBegin input) {
                   .apply(Impulse.create())
                   .apply(ParDo.of(new GenerateKafkaSourceDescriptor(kafkaRead)));
         }
+        if (kafkaRead.isRedistributed()) {
+          PCollection> pcol =
+              output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
+          if (kafkaRead.getRedistributeNumKeys() == 0) {
+            return pcol.apply(
+                "Insert Redistribute",
+                Redistribute.>arbitrarily()
+                    .withAllowDuplicates(kafkaRead.isAllowDuplicates()));
+          } else {
+            return pcol.apply(
+                "Insert Redistribute with Shards",
+                Redistribute.>arbitrarily()
+                    .withAllowDuplicates(kafkaRead.isAllowDuplicates())
+                    .withNumBuckets((int) kafkaRead.getRedistributeNumKeys()));
+          }
+        }
         return output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
       }
     }
@@ -2070,6 +2190,15 @@ public abstract static class ReadSourceDescriptors
     @Pure
     abstract boolean isCommitOffsetEnabled();
 
+    @Pure
+    abstract boolean isRedistribute();
+
+    @Pure
+    abstract boolean isAllowDuplicates();
+
+    @Pure
+    abstract int getRedistributeNumKeys();
+
     @Pure
     abstract @Nullable TimestampPolicyFactory getTimestampPolicyFactory();
 
@@ -2136,6 +2265,12 @@ abstract ReadSourceDescriptors.Builder setBadRecordErrorHandler(
 
       abstract ReadSourceDescriptors.Builder setBounded(boolean bounded);
 
+      abstract ReadSourceDescriptors.Builder setRedistribute(boolean withRedistribute);
+
+      abstract ReadSourceDescriptors.Builder setAllowDuplicates(boolean allowDuplicates);
+
+      abstract ReadSourceDescriptors.Builder setRedistributeNumKeys(int redistributeNumKeys);
+
       abstract ReadSourceDescriptors build();
     }
 
@@ -2148,6 +2283,9 @@ public static  ReadSourceDescriptors read() {
           .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
           .setBadRecordErrorHandler(new ErrorHandler.DefaultErrorHandler<>())
           .setConsumerPollingTimeout(2L)
+          .setRedistribute(false)
+          .setAllowDuplicates(false)
+          .setRedistributeNumKeys(0)
           .build()
           .withProcessingTime()
           .withMonotonicallyIncreasingWatermarkEstimator();
@@ -2307,6 +2445,19 @@ public ReadSourceDescriptors withProcessingTime() {
           ReadSourceDescriptors.ExtractOutputTimestampFns.useProcessingTime());
     }
 
+    /** Enable Redistribute. */
+    public ReadSourceDescriptors withRedistribute() {
+      return toBuilder().setRedistribute(true).build();
+    }
+
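+    /** Marks the redistributed read as allowing duplicate records. */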
+    public ReadSourceDescriptors withAllowDuplicates() {
+      return toBuilder().setAllowDuplicates(true).build();
+    }
+
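+    /** Sets the number of keys to redistribute across; 0 (the default) means a key per record. */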
+    public ReadSourceDescriptors withRedistributeNumKeys(int redistributeNumKeys) {
+      return toBuilder().setRedistributeNumKeys(redistributeNumKeys).build();
+    }
+
     /** Use the creation time of {@link KafkaRecord} as the output timestamp. */
     public ReadSourceDescriptors withCreateTime() {
       return withExtractOutputTimestampFn(
@@ -2497,6 +2648,12 @@ public PCollection> expand(PCollection
         }
       }
 
+      if (isRedistribute()) {
+        if (getRedistributeNumKeys() == 0) {
+          LOG.warn("This will create a key per record, which is sub-optimal for most use cases.");
+        }
+      }
+
       if (getConsumerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) == null) {
         LOG.warn(
             "The bootstrapServers is not set. It must be populated through the KafkaSourceDescriptor during runtime otherwise the pipeline will fail.");
@@ -2527,7 +2684,7 @@ public PCollection> expand(PCollection
                             .getSchemaRegistry()
                             .getSchemaCoder(KafkaSourceDescriptor.class),
                         recordCoder));
-        if (isCommitOffsetEnabled() && !configuredKafkaCommit()) {
+        if (isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute()) {
           outputWithDescriptor =
               outputWithDescriptor
                   .apply(Reshuffle.viaRandomKey())
@@ -2538,6 +2695,7 @@ public PCollection> expand(PCollection
                               .getSchemaRegistry()
                               .getSchemaCoder(KafkaSourceDescriptor.class),
                           recordCoder));
+
           PCollection unused = outputWithDescriptor.apply(new KafkaCommitOffset(this));
           unused.setCoder(VoidCoder.of());
         }
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
index ce0434ee88d1..457e0003705e 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java
@@ -119,6 +119,24 @@ Object getDefaultValue() {
         return Long.valueOf(2);
       }
     },
+    REDISTRIBUTE_NUM_KEYS {
+      @Override
+      Object getDefaultValue() {
+        return Integer.valueOf(0);
+      }
+    },
+    REDISTRIBUTED {
+      @Override
+      Object getDefaultValue() {
+        return false;
+      }
+    },
+    ALLOW_DUPLICATES {
+      @Override
+      Object getDefaultValue() {
+        return false;
+      }
+    },
     ;
 
     private final @NonNull ImmutableSet supportedImplementations;
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
index 100f06d42d07..d6ec9015a95f 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java
@@ -436,19 +436,6 @@ public ProcessContinuation processElement(
         "Creating Kafka consumer for process continuation for {}",
         kafkaSourceDescriptor.getTopicPartition());
     try (Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig)) {
-      // Check whether current TopicPartition is still available to read.
-      Set existingTopicPartitions = new HashSet<>();
-      for (List topicPartitionList : consumer.listTopics().values()) {
-        topicPartitionList.forEach(
-            partitionInfo -> {
-              existingTopicPartitions.add(
-                  new TopicPartition(partitionInfo.topic(), partitionInfo.partition()));
-            });
-      }
-      if (!existingTopicPartitions.contains(kafkaSourceDescriptor.getTopicPartition())) {
-        return ProcessContinuation.stop();
-      }
-
       ConsumerSpEL.evaluateAssign(
           consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition()));
       long startOffset = tracker.currentRestriction().getFrom();
@@ -462,6 +449,10 @@ public ProcessContinuation processElement(
         // When there are no records available for the current TopicPartition, self-checkpoint
         // and move to process the next element.
         if (rawRecords.isEmpty()) {
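+          // Only verify that the topic-partition still exists when a poll comes back empty; this
+          // avoids a listTopics() metadata call on every processElement invocation.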
+          if (!topicPartitionExists(
+              kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) {
+            return ProcessContinuation.stop();
+          }
           if (timestampPolicy != null) {
             updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker);
           }
@@ -522,6 +513,23 @@ public ProcessContinuation processElement(
     }
   }
 
+  private boolean topicPartitionExists(
+      TopicPartition topicPartition, Map> topicListMap) {
+    // Check if the current TopicPartition still exists.
+    Set existingTopicPartitions = new HashSet<>();
+    for (List topicPartitionList : topicListMap.values()) {
+      topicPartitionList.forEach(
+          partitionInfo -> {
+            existingTopicPartitions.add(
+                new TopicPartition(partitionInfo.topic(), partitionInfo.partition()));
+          });
+    }
+    return existingTopicPartitions.contains(topicPartition);
+  }
+
   // see https://github.com/apache/beam/issues/25962
   private ConsumerRecords poll(
       Consumer consumer, TopicPartition topicPartition) {
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
index 246fdd80d739..f021789a912c 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
@@ -108,7 +108,10 @@ public void testConstructKafkaRead() throws Exception {
                         Field.of("start_read_time", FieldType.INT64),
                         Field.of("commit_offset_in_finalize", FieldType.BOOLEAN),
                         Field.of("timestamp_policy", FieldType.STRING),
-                        Field.of("consumer_polling_timeout", FieldType.INT64)))
+                        Field.of("consumer_polling_timeout", FieldType.INT64),
+                        Field.of("redistribute_num_keys", FieldType.INT32),
+                        Field.of("redistribute", FieldType.BOOLEAN),
+                        Field.of("allow_duplicates", FieldType.BOOLEAN)))
                 .withFieldValue("topics", topics)
                 .withFieldValue("consumer_config", consumerConfig)
                 .withFieldValue("key_deserializer", keyDeserializer)
@@ -117,6 +120,9 @@ public void testConstructKafkaRead() throws Exception {
                 .withFieldValue("commit_offset_in_finalize", false)
                 .withFieldValue("timestamp_policy", "ProcessingTime")
                 .withFieldValue("consumer_polling_timeout", 5L)
+                .withFieldValue("redistribute_num_keys", 0)
+                .withFieldValue("redistribute", false)
+                .withFieldValue("allow_duplicates", false)
                 .build());
 
     RunnerApi.Components defaultInstance = RunnerApi.Components.getDefaultInstance();
@@ -139,6 +145,7 @@ public void testConstructKafkaRead() throws Exception {
     expansionService.expand(request, observer);
     ExpansionApi.ExpansionResponse result = observer.result;
     RunnerApi.PTransform transform = result.getTransform();
+    System.out.println("xxx : " + result.toString());
     assertThat(
         transform.getSubtransformsList(),
         Matchers.hasItem(MatchesPattern.matchesPattern(".*KafkaIO-Read.*")));
@@ -237,7 +244,10 @@ public void testConstructKafkaReadWithoutMetadata() throws Exception {
                         Field.of("value_deserializer", FieldType.STRING),
                         Field.of("start_read_time", FieldType.INT64),
                         Field.of("commit_offset_in_finalize", FieldType.BOOLEAN),
-                        Field.of("timestamp_policy", FieldType.STRING)))
+                        Field.of("timestamp_policy", FieldType.STRING),
+                        Field.of("redistribute_num_keys", FieldType.INT32),
+                        Field.of("redistribute", FieldType.BOOLEAN),
+                        Field.of("allow_duplicates", FieldType.BOOLEAN)))
                 .withFieldValue("topics", topics)
                 .withFieldValue("consumer_config", consumerConfig)
                 .withFieldValue("key_deserializer", keyDeserializer)
@@ -245,6 +255,9 @@ public void testConstructKafkaReadWithoutMetadata() throws Exception {
                 .withFieldValue("start_read_time", startReadTime)
                 .withFieldValue("commit_offset_in_finalize", false)
                 .withFieldValue("timestamp_policy", "ProcessingTime")
+                .withFieldValue("redistribute_num_keys", 0)
+                .withFieldValue("redistribute", false)
+                .withFieldValue("allow_duplicates", false)
                 .build());
 
     RunnerApi.Components defaultInstance = RunnerApi.Components.getDefaultInstance();
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java
index 2d306b0d7798..ae939d66c210 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibilityTest.java
@@ -103,7 +103,9 @@ public void testPrimitiveKafkaIOReadPropertiesDefaultValueExistence() {
 
   private void testReadTransformCreationWithImplementationBoundProperties(
       Function, KafkaIO.Read> kafkaReadDecorator) {
-    p.apply(kafkaReadDecorator.apply(mkKafkaReadTransform(1000, null, new ValueAsTimestampFn())));
+    p.apply(
+        kafkaReadDecorator.apply(
+            mkKafkaReadTransform(1000, null, new ValueAsTimestampFn(), false, 0)));
     p.run();
   }
 
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
index 07e5b519c013..73aee5aeeef0 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
@@ -376,7 +376,7 @@ public Consumer apply(Map config) {
 
   static KafkaIO.Read mkKafkaReadTransform(
       int numElements, @Nullable SerializableFunction, Instant> timestampFn) {
-    return mkKafkaReadTransform(numElements, numElements, timestampFn);
+    return mkKafkaReadTransform(numElements, numElements, timestampFn, false, 0);
   }
 
   /**
@@ -386,7 +386,9 @@ static KafkaIO.Read mkKafkaReadTransform(
   static KafkaIO.Read mkKafkaReadTransform(
       int numElements,
       @Nullable Integer maxNumRecords,
-      @Nullable SerializableFunction, Instant> timestampFn) {
+      @Nullable SerializableFunction, Instant> timestampFn,
+      @Nullable Boolean redistribute,
+      @Nullable Integer numKeys) {
 
     List topics = ImmutableList.of("topic_a", "topic_b");
 
@@ -404,10 +406,16 @@ static KafkaIO.Read mkKafkaReadTransform(
     }
 
     if (timestampFn != null) {
-      return reader.withTimestampFn(timestampFn);
-    } else {
-      return reader;
+      reader = reader.withTimestampFn(timestampFn);
     }
+
+    if (redistribute != null && redistribute) {
+      reader = reader.withRedistribute();
+      if (numKeys != null && numKeys > 0) {
+        reader = reader.withRedistributeNumKeys(numKeys);
+      }
+    }
+    return reader;
   }
 
   private static class AssertMultipleOf implements SerializableFunction, Void> {
@@ -616,6 +624,42 @@ public void testRiskyConfigurationWarnsProperly() {
     p.run();
   }
 
+  @Test
+  public void testCommitOffsetsInFinalizeAndRedistributeErrors() {
+    thrown.expect(Exception.class);
+    thrown.expectMessage("commitOffsetsInFinalize() can't be enabled with isRedistributed");
+
+    int numElements = 1000;
+
+    PCollection input =
+        p.apply(
+                mkKafkaReadTransform(numElements, numElements, new ValueAsTimestampFn(), true, 0)
+                    .withConsumerConfigUpdates(
+                        ImmutableMap.of(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true))
+                    .withoutMetadata())
+            .apply(Values.create());
+
+    addCountingAsserts(input, numElements);
+    p.run();
+  }
+
+  @Test
+  public void testNumKeysIgnoredWithRedistributeNotEnabled() {
+    int numElements = 1000;
+
+    PCollection input =
+        p.apply(
+                mkKafkaReadTransform(numElements, numElements, new ValueAsTimestampFn(), false, 0)
+                    .withConsumerConfigUpdates(
+                        ImmutableMap.of(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true))
+                    .withoutMetadata())
+            .apply(Values.create());
+
+    addCountingAsserts(input, numElements);
+
+    p.run();
+  }
+
   @Test
   public void testUnreachableKafkaBrokers() {
     // Expect an exception when the Kafka brokers are not reachable on the workers.
@@ -1905,7 +1949,7 @@ public void testUnboundedSourceStartReadTime() {
 
     PCollection input =
         p.apply(
-                mkKafkaReadTransform(numElements, maxNumRecords, new ValueAsTimestampFn())
+                mkKafkaReadTransform(numElements, maxNumRecords, new ValueAsTimestampFn(), false, 0)
                     .withStartReadTime(new Instant(startTime))
                     .withoutMetadata())
             .apply(Values.create());
@@ -1929,7 +1973,7 @@ public void testUnboundedSourceStartReadTimeException() {
     int startTime = numElements / 20;
 
     p.apply(
-            mkKafkaReadTransform(numElements, numElements, new ValueAsTimestampFn())
+            mkKafkaReadTransform(numElements, numElements, new ValueAsTimestampFn(), false, 0)
                 .withStartReadTime(new Instant(startTime))
                 .withoutMetadata())
         .apply(Values.create());
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
index b8ff08485c3b..612b20393d78 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
@@ -515,7 +515,7 @@ public void testProcessElementWithEmptyPoll() throws Exception {
   public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception {
     MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setRemoved();
-    consumer.setNumOfRecordsPerPoll(10);
+    consumer.setNumOfRecordsPerPoll(-1);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
     ProcessContinuation result =
         dofnInstance.processElement(
diff --git a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java
index 9848e429e215..db7b172170a1 100644
--- a/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java
+++ b/sdks/java/io/kafka/upgrade/src/main/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslation.java
@@ -98,6 +98,9 @@ static class KafkaIOReadWithMetadataTranslator implements TransformPayloadTransl
             .addNullableLogicalTypeField("stop_read_time", new NanosInstant())
             .addBooleanField("is_commit_offset_finalize_enabled")
             .addBooleanField("is_dynamic_read")
+            .addBooleanField("redistribute")
+            .addBooleanField("allows_duplicates")
+            .addNullableInt32Field("redistribute_num_keys")
             .addNullableLogicalTypeField("watch_topic_partition_duration", new NanosDuration())
             .addByteArrayField("timestamp_policy_factory")
             .addNullableMapField("offset_consumer_config", FieldType.STRING, FieldType.BYTES)
@@ -215,6 +218,9 @@ public Row toConfigRow(Read transform) {
                 + " is not supported yet.");
       }
 
+      fieldValues.put("redistribute", transform.isRedistributed());
+      fieldValues.put("redistribute_num_keys", transform.getRedistributeNumKeys());
+      fieldValues.put("allows_duplicates", transform.isAllowDuplicates());
       return Row.withSchema(schema).withFieldValues(fieldValues).build();
     }
 
@@ -325,6 +331,22 @@ public Row toConfigRow(Read transform) {
         if (maxNumRecords != null) {
           transform = transform.withMaxNumRecords(maxNumRecords);
         }
+
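+        // Re-apply the redistribute options if they were set on the original transform.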
+        Boolean isRedistributed = configRow.getBoolean("redistribute");
+        if (isRedistributed != null && isRedistributed) {
+          transform = transform.withRedistribute();
+          Integer redistributeNumKeys =
+              configRow.getValue("redistribute_num_keys") == null
+                  ? Integer.valueOf(0)
+                  : configRow.getInt32("redistribute_num_keys");
+          if (redistributeNumKeys != null && !redistributeNumKeys.equals(0)) {
+            transform = transform.withRedistributeNumKeys(redistributeNumKeys);
+          }
+          Boolean allowDuplicates = configRow.getBoolean("allows_duplicates");
+          if (allowDuplicates != null && allowDuplicates) {
+            transform = transform.withAllowDuplicates(allowDuplicates);
+          }
+        }
         Duration maxReadTime = configRow.getValue("max_read_time");
         if (maxReadTime != null) {
           transform =
diff --git a/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java b/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java
index f69b9c3649b4..095702a5c6ff 100644
--- a/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java
+++ b/sdks/java/io/kafka/upgrade/src/test/java/org/apache/beam/sdk/io/kafka/upgrade/KafkaIOTranslationTest.java
@@ -64,6 +64,7 @@ public class KafkaIOTranslationTest {
     READ_TRANSFORM_SCHEMA_MAPPING.put("getMaxReadTime", "max_read_time");
     READ_TRANSFORM_SCHEMA_MAPPING.put("getStartReadTime", "start_read_time");
     READ_TRANSFORM_SCHEMA_MAPPING.put("getStopReadTime", "stop_read_time");
+    READ_TRANSFORM_SCHEMA_MAPPING.put("getRedistributeNumKeys", "redistribute_num_keys");
     READ_TRANSFORM_SCHEMA_MAPPING.put(
         "isCommitOffsetsInFinalizeEnabled", "is_commit_offset_finalize_enabled");
     READ_TRANSFORM_SCHEMA_MAPPING.put("isDynamicRead", "is_dynamic_read");
diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java
index b6941f8fcbb6..f9c1a23e64fe 100644
--- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java
+++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java
@@ -40,6 +40,7 @@
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.SerializableSupplier;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.Sleeper;
 import org.apache.beam.sdk.values.PCollection;
diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/DefaultSerializableBackoffSupplier.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/DefaultSerializableBackoffSupplier.java
index 89e9400854d7..b92bce53438f 100644
--- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/DefaultSerializableBackoffSupplier.java
+++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/DefaultSerializableBackoffSupplier.java
@@ -19,6 +19,7 @@
 
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.SerializableSupplier;
 import org.apache.beam.sdk.util.SerializableUtils;
 
 /**
diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java
index 9c5c6128c29a..1bac1dd07386 100644
--- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java
+++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/RequestResponseIO.java
@@ -35,6 +35,7 @@
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.SerializableSupplier;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.Sleeper;
 import org.apache.beam.sdk.values.KV;
diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoff.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoff.java
index fbbafeb906f7..ab078154b8c2 100644
--- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoff.java
+++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoff.java
@@ -19,6 +19,7 @@
 
 import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 
+import org.apache.beam.sdk.util.SerializableSupplier;
 import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
diff --git a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java
index f54d3e595b03..4cbadf237336 100644
--- a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java
+++ b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/RequestResponseIOTest.java
@@ -43,6 +43,7 @@
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.SerializableSupplier;
 import org.apache.beam.sdk.util.Sleeper;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
diff --git a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoffTest.java b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoffTest.java
index 5316f251200a..18d452451838 100644
--- a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoffTest.java
+++ b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/WindowedCallShouldBackoffTest.java
@@ -20,6 +20,7 @@
 import static org.apache.beam.sdk.testing.SerializableMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 
+import org.apache.beam.sdk.util.SerializableSupplier;
 import org.joda.time.Duration;
 import org.junit.Test;
 
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
index 6fad4a89b635..b36b678c0a33 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/SnowflakeIO.java
@@ -1079,13 +1079,17 @@ public Object[] apply(T element) {
                   ParDo.of(new MapObjectsArrayToCsvFn(getQuotationMark())))
               .setCoder(StringUtf8Coder.of());
 
+      String filePrefix = getFileNameTemplate();
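+      // Fall back to a random 8-character prefix (the previous behavior) when no file name
+      // template has been configured.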
+      if (filePrefix == null) {
+        filePrefix = UUID.randomUUID().toString().subSequence(0, 8).toString();
+      }
       WriteFilesResult filesResult =
           mappedUserData.apply(
               "Write files to specified location",
               FileIO.write()
                   .via(TextIO.sink())
                   .to(stagingBucketDir)
-                  .withPrefix(UUID.randomUUID().toString().subSequence(0, 8).toString())
+                  .withPrefix(filePrefix)
                   .withSuffix(".csv")
                   .withNumShards(numShards)
                   .withCompression(Compression.GZIP));
diff --git a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeBatchServiceImpl.java b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeBatchServiceImpl.java
index e317e44d4901..d86689670b39 100644
--- a/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeBatchServiceImpl.java
+++ b/sdks/java/io/snowflake/src/main/java/org/apache/beam/sdk/io/snowflake/services/SnowflakeBatchServiceImpl.java
@@ -82,7 +82,7 @@ private String copyIntoStage(SnowflakeBatchServiceConfig config) throws SQLExcep
 
     String copyQuery =
         String.format(
-            "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
+            "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s' ESCAPE='\\\\');",
             getProperBucketDir(stagingBucketDir),
             source,
             storageIntegrationName,
@@ -138,7 +138,7 @@ private void copyToTable(SnowflakeBatchServiceConfig config) throws SQLException
     if (!storageIntegrationName.isEmpty()) {
       query =
           String.format(
-              "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP) STORAGE_INTEGRATION=%s;",
+              "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' ESCAPE='\\\\' COMPRESSION=GZIP) STORAGE_INTEGRATION=%s;",
               getTablePath(database, schema, table),
               getProperBucketDir(source),
               files,
@@ -147,7 +147,7 @@ private void copyToTable(SnowflakeBatchServiceConfig config) throws SQLException
     } else {
       query =
           String.format(
-              "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP);",
+              "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' ESCAPE='\\\\' COMPRESSION=GZIP);",
               table, source, files, getASCIICharRepresentation(config.getQuotationMark()));
     }
 
diff --git a/sdks/java/io/solace/build.gradle b/sdks/java/io/solace/build.gradle
index 7c643dc91278..741db51a5772 100644
--- a/sdks/java/io/solace/build.gradle
+++ b/sdks/java/io/solace/build.gradle
@@ -18,9 +18,7 @@
 
 plugins { id 'org.apache.beam.module' }
 applyJavaNature(
-        automaticModuleName: 'org.apache.beam.sdk.io.solace',
-        exportJavadoc: false,
-        publish: false,
+        automaticModuleName: 'org.apache.beam.sdk.io.solace'
 )
 provideIntegrationTestingDependencies()
 enableJavaPerformanceTesting()
@@ -36,17 +34,26 @@ dependencies {
     implementation library.java.joda_time
     implementation library.java.solace
     implementation library.java.google_cloud_core
+    implementation library.java.google_cloud_secret_manager
+    implementation library.java.proto_google_cloud_secret_manager_v1
+    implementation library.java.protobuf_java
     implementation library.java.vendored_guava_32_1_2_jre
     implementation project(":sdks:java:extensions:avro")
     implementation library.java.avro
     permitUnusedDeclared library.java.avro
     implementation library.java.google_api_common
     implementation library.java.gax
     implementation library.java.threetenbp
+    implementation library.java.google_http_client
+    implementation library.java.google_http_client_gson
+    implementation library.java.jackson_core
+    implementation library.java.jackson_databind
 
     testImplementation library.java.junit
     testImplementation project(path: ":sdks:java:io:common")
     testImplementation project(path: ":sdks:java:testing:test-utils")
     testRuntimeOnly library.java.slf4j_jdk14
+    testImplementation library.java.testcontainers_solace
     testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java
index e6b0dd34b184..bb9f0c6ea689 100644
--- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java
+++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java
@@ -22,6 +22,7 @@
 
 import com.google.auto.value.AutoValue;
 import com.solacesystems.jcsmp.BytesXMLMessage;
+import com.solacesystems.jcsmp.DeliveryMode;
 import com.solacesystems.jcsmp.Destination;
 import com.solacesystems.jcsmp.JCSMPFactory;
 import com.solacesystems.jcsmp.Queue;
@@ -31,18 +32,22 @@
 import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.io.solace.broker.BasicAuthJcsmpSessionServiceFactory;
+import org.apache.beam.sdk.io.solace.broker.GCPSecretSessionServiceFactory;
 import org.apache.beam.sdk.io.solace.broker.SempClientFactory;
 import org.apache.beam.sdk.io.solace.broker.SessionService;
 import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory;
 import org.apache.beam.sdk.io.solace.data.Solace;
 import org.apache.beam.sdk.io.solace.data.Solace.SolaceRecordMapper;
 import org.apache.beam.sdk.io.solace.read.UnboundedSolaceSource;
+import org.apache.beam.sdk.io.solace.write.SolaceOutput;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -194,6 +199,186 @@
  * 

For the authentication to the SEMP API ({@link Read#withSempClientFactory(SempClientFactory)}) * the connector provides {@link org.apache.beam.sdk.io.solace.broker.BasicAuthSempClientFactory} to * authenticate using the Basic Authentication. + * + *

Writing

+ * + *

To write to Solace, use {@link #write()} with a {@link PCollection}. You can + * also use {@link #write(SerializableFunction)} to specify a format function to convert the input + * type to {@link Solace.Record}. + * + *

Writing to a static topic or queue

+ * + *

The connector uses the Solace JCSMP API. + * The clients will write to an SMF + * topic on port 55555 of the host. If you want to use a different port, specify it in the + * host property with the format "X.X.X.X:PORT". + * + *

Once you have a {@link PCollection} of {@link Solace.Record}, you can write to Solace using: + * + *

{@code
+ * PCollection<Solace.Record> solaceRecs = ...;
+ *
+ * PCollection<Solace.PublishResult> results =
+ *         solaceRecs.apply(
+ *                 "Write to Solace",
+ *                 SolaceIO.write()
+ *                         .to(SolaceIO.topicFromName("some-topic"))
+ *                         .withSessionServiceFactory(
+ *                            BasicAuthJcsmpSessionServiceFactory.builder()
+ *                              .username("username")
+ *                              .password("password")
+ *                              .host("host:port")
+ *                              .build()));
+ * }
+ * + *

The above code snippet will write to the VPN named "default", using 4 clients per worker (VM + * in Dataflow), and a maximum of 20 workers/VMs for writing (default values). You can change the + * default VPN name by setting the required JCSMP property in the session factory (in this case, + * with {@link BasicAuthJcsmpSessionServiceFactory#vpnName()}), the number of clients per worker + * with {@link Write#withNumberOfClientsPerWorker(int)}, and the maximum number of parallel + * workers with {@link Write#withMaxNumOfUsedWorkers(int)}. + * + *
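As a sketch of overriding those defaults (placeholder values; the builder and method names are the ones referenced in this section):

PCollection<Solace.Record> solaceRecs = ...;

SolaceOutput results =
    solaceRecs.apply(
        "Write to Solace",
        SolaceIO.write()
            .to(SolaceIO.topicFromName("some-topic"))
            .withSessionServiceFactory(
                BasicAuthJcsmpSessionServiceFactory.builder()
                    .username("username")
                    .password("password")
                    .host("host:port")
                    .vpnName("my-vpn")           // overrides the default "default" VPN
                    .build())
            .withNumberOfClientsPerWorker(8)     // clients created on each worker/VM
            .withMaxNumOfUsedWorkers(10));       // cap on the number of parallel writers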

Writing to dynamic destinations

+ * + * To write to dynamic destinations, don't set the {@link Write#to(Solace.Queue)} or {@link + * Write#to(Solace.Topic)} property and make sure that all the {@link Solace.Record}s have their + * destination field set to either a topic or a queue. You can do this prior to calling the write + * connector, or by using a format function and {@link #write(SerializableFunction)}. + * + *

For instance, you can create a function like the following: + * + *

{@code
+ * // Generate Record with different destinations
+ * SerializableFunction<MyType, Solace.Record> formatFn =
+ *    (MyType msg) -> {
+ *       int queue = ... // some random number
+ *       return Solace.Record.builder()
+ *         .setDestination(Solace.Destination.builder()
+ *                        .setName(String.format("q%d", queue))
+ *                        .setType(Solace.DestinationType.QUEUE)
+ *                        .build())
+ *         .setMessageId(msg.getMessageId())
+ *         .build();
+ * };
+ * }
+ * + * And then use the connector as follows: + * + *
{@code
+ * // Ignore "to" method to use dynamic destinations
+ * SolaceOutput solaceResponses = msgs.apply("Write to Solace",
+ *   SolaceIO.write(formatFn)
+ *        .withDeliveryMode(DeliveryMode.PERSISTENT)
+ *        .withWriterType(SolaceIO.WriterType.STREAMING)
+ * ...
+ * }
+ * + *

Direct and persistent messages, and latency metrics

+ * + *

The connector can write either direct or persistent messages. The default mode is DIRECT. + * + *

The connector returns a {@link PCollection} of {@link Solace.PublishResult}, which you can use + * to get a confirmation for messages that have been published or rejected, but only when it is + * publishing persistent messages. + * + *

If you are publishing persistent messages, you can get feedback about whether each + * message has been published, as well as publishing latency metrics. If the message has been + * published, {@link Solace.PublishResult#getPublished()} will be true. If it is false, the message + * could not be published, and {@link Solace.PublishResult#getError()} will contain + * more details about why. To get latency metrics as well as the + * results, set the property {@link Write#publishLatencyMetrics()}. + * + *
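As a hedged sketch (assuming you already have the PCollection of Solace.PublishResult described above, named results here; getPublished() and getError() are the accessors mentioned in this section, the rest is standard Beam ParDo and Metrics):

results.apply(
    "Inspect publish results",
    ParDo.of(
        new DoFn<Solace.PublishResult, Void>() {
          private final Counter failed = Metrics.counter("solace-write", "failed-publishes");

          @ProcessElement
          public void processElement(@Element Solace.PublishResult result) {
            if (!Boolean.TRUE.equals(result.getPublished())) {
              failed.inc();
              // getError() carries the details of why the message could not be published.
              LoggerFactory.getLogger("SolaceWriteResults")
                  .warn("Message not published: {}", result.getError());
            }
          }
        }));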

Throughput and latency

+ * + *

This connector can work in two main modes: lower latency or higher throughput. The default mode + * favors throughput over latency. You can control this behavior with the methods {@link + * Write#withSubmissionMode(SubmissionMode)} and {@link Write#withWriterType(WriterType)}. + * + *

The default mode works like the following options: + * + *

{@code
+ * PCollection<Solace.Record> solaceRecs = ...;
+ *
+ * PCollection<Solace.PublishResult> results =
+ *         solaceRecs.apply(
+ *                 "Write to Solace",
+ *                 SolaceIO.write()
+ *                         .to(SolaceIO.topicFromName("some-topic"))
+ *                         .withSessionServiceFactory(
+ *                            BasicAuthJcsmpSessionServiceFactory.builder()
+ *                              .username("username")
+ *                              .password("password")
+ *                              .host("host:port")
+ *                              .build())
+ *                         .withSubmissionMode(SubmissionMode.HIGHER_THROUGHPUT)
+ *                         .withWriterType(WriterType.BATCHED));
+ * }
+ * + *

{@link SubmissionMode#HIGHER_THROUGHPUT} and {@link WriterType#BATCHED} are the default + * values, and offer the highest possible throughput and the lowest usage of resources on the runner + * side (due to the lower backpressure). + * + *

In this mode, the connector writes bundles of 50 messages, using a bulk publish JCSMP method. This + * increases latency, since a message needs to "wait" until 50 messages have accumulated before + * they are submitted to Solace. + * + *

For the lowest latency possible, use {@link SubmissionMode#LOWER_LATENCY} and {@link + * WriterType#STREAMING}. + * + *

{@code
+ * PCollection<Solace.PublishResult> results =
+ *         solaceRecs.apply(
+ *                 "Write to Solace",
+ *                 SolaceIO.write()
+ *                         .to(SolaceIO.topicFromName("some-topic"))
+ *                         .withSessionServiceFactory(
+ *                            BasicAuthJcsmpSessionServiceFactory.builder()
+ *                              .username("username")
+ *                              .password("password")
+ *                              .host("host:port")
+ *                              .build())
+ *                         .withSubmissionMode(SubmissionMode.LOWER_LATENCY)
+ *                         .withWriterType(WriterType.STREAMING));
+ * }
+ * + *

The streaming connector publishes each message individually, without holding messages back or + * batching them before they are sent to Solace. This ensures the lowest possible latency, but + * offers much lower throughput. The streaming connector does not use state & timers. + * + *

Both connectors use state & timers to control the level of parallelism. If you are using + * Cloud Dataflow, it is recommended that you enable Streaming Engine to use this + * connector. + * + *
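If you run on Dataflow, one way to enable Streaming Engine is through the pipeline options at launch time (a sketch; the flag names follow the public Dataflow documentation and your launcher setup may differ):

PipelineOptions options =
    PipelineOptionsFactory.fromArgs(
            "--runner=DataflowRunner",
            "--streaming=true",
            "--enableStreamingEngine=true")
        .withValidation()
        .create();
Pipeline pipeline = Pipeline.create(options);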

Authentication

+ * + *

When writing to Solace, the user must use {@link + * Write#withSessionServiceFactory(SessionServiceFactory)} to create a JCSMP session. + * + *

See {@link Write#withSessionServiceFactory(SessionServiceFactory)} for session authentication. + * The connector provides an implementation of the {@link SessionServiceFactory} using basic + * authentication ({@link BasicAuthJcsmpSessionServiceFactory}), and another implementation using + * basic authentication but with the password stored as a secret in Google Cloud Secret Manager + * ({@link GCPSecretSessionServiceFactory}). + * + *
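For example, a sketch using the Secret Manager backed factory (the user, host and secret name are placeholders; the builder methods are the ones documented in GCPSecretSessionServiceFactory):

SolaceOutput results =
    solaceRecs.apply(
        "Write to Solace",
        SolaceIO.write()
            .to(SolaceIO.topicFromName("some-topic"))
            .withSessionServiceFactory(
                GCPSecretSessionServiceFactory.builder()
                    .username("username")
                    .host("host:port")
                    .passwordSecretName("solace-password") // Secret Manager secret holding the password
                    .build()));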

Connector retries

+ * + *

When the worker using the connector is created, the connector will attempt to connect to + * Solace. + * + *

If the client cannot connect to Solace for whatever reason, the connector will retry the + * connection using the following strategy. There will be a maximum of 4 retries. The first retry + * is attempted 1 second after the first connection attempt. Every subsequent retry + * multiplies that delay by a factor of two, up to a maximum of 10 seconds. + * + *
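To make those numbers concrete, the policy corresponds roughly to the following gax RetrySettings (com.google.api.gax.retrying.RetrySettings with org.threeten.bp.Duration), shown only as an illustration of the parameters; the connector configures this internally through its retry utilities:

RetrySettings retrySettings =
    RetrySettings.newBuilder()
        .setInitialRetryDelay(Duration.ofSeconds(1)) // first retry 1 second after the first attempt
        .setRetryDelayMultiplier(2.0)                // double the delay on every subsequent retry
        .setMaxRetryDelay(Duration.ofSeconds(10))    // never wait more than 10 seconds
        .setMaxAttempts(5)                           // initial attempt plus 4 retries (gax counts the first attempt)
        .build();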

If after those retries the client is still unable to connect to Solace, the connector will + * attempt to reconnect using the same strategy, repeated for every single incoming message. If, for + * some reason, there is a persistent issue that prevents the connection (e.g. client quota + * exhausted), you will need to stop your job manually, or the connector will keep retrying. + * + *

This strategy is applied to all the remote calls sent to Solace, either to connect, pull + * messages, push messages, etc. */ @Internal public class SolaceIO { @@ -209,6 +394,14 @@ public class SolaceIO { }; private static final boolean DEFAULT_DEDUPLICATE_RECORDS = false; + public static final int DEFAULT_WRITER_MAX_NUMBER_OF_WORKERS = 20; + public static final int DEFAULT_WRITER_CLIENTS_PER_WORKER = 4; + public static final Boolean DEFAULT_WRITER_PUBLISH_LATENCY_METRICS = false; + public static final SubmissionMode DEFAULT_WRITER_SUBMISSION_MODE = + SubmissionMode.HIGHER_THROUGHPUT; + public static final DeliveryMode DEFAULT_WRITER_DELIVERY_MODE = DeliveryMode.DIRECT; + public static final WriterType DEFAULT_WRITER_TYPE = WriterType.BATCHED; + /** Get a {@link Topic} object from the topic name. */ static Topic topicFromName(String topicName) { return JCSMPFactory.onlyInstance().createTopic(topicName); @@ -277,13 +470,31 @@ public static Read read( .setDeduplicateRecords(DEFAULT_DEDUPLICATE_RECORDS)); } + /** + * Create a {@link Write} transform, to write to Solace with a custom type. + * + *

If you are using a custom data class, the format function should return a {@link + * Solace.Record} corresponding to your custom data class instance. + * + *

If you are using this formatting function with dynamic destinations, you must ensure that + * you set the right value in the destination value of the {@link Solace.Record} messages. + */ + public static Write write(SerializableFunction formatFunction) { + return Write.builder().setFormatFunction(formatFunction).build(); + } + + /** Create a {@link Write} transform, to write to Solace using {@link Solace.Record} objects. */ + public static Write write() { + return Write.builder().build(); + } + public static class Read extends PTransform> { private static final Logger LOG = LoggerFactory.getLogger(Read.class); @VisibleForTesting final Configuration.Builder configurationBuilder; - public Read(Configuration.Builder configurationBuilder) { + private Read(Configuration.Builder configurationBuilder) { this.configurationBuilder = configurationBuilder; } @@ -569,4 +780,232 @@ private Queue initializeQueueForTopicIfNeeded( } } } + + public enum SubmissionMode { + HIGHER_THROUGHPUT, + LOWER_LATENCY + } + + public enum WriterType { + STREAMING, + BATCHED + } + + @AutoValue + public abstract static class Write extends PTransform, SolaceOutput> { + + public static final TupleTag FAILED_PUBLISH_TAG = + new TupleTag() {}; + public static final TupleTag SUCCESSFUL_PUBLISH_TAG = + new TupleTag() {}; + + /** + * Write to a Solace topic. + * + *

The topic does not need to exist before launching the pipeline. + * + *

This will write all records to the same topic, ignoring their destination field. + * + *

Optional. If not specified, the connector will use dynamic destinations based on the + * destination field of {@link Solace.Record}. + */ + public Write to(Solace.Topic topic) { + return toBuilder().setDestination(topicFromName(topic.getName())).build(); + } + + /** + * Write to a Solace queue. + * + *

The queue must exist prior to launching the pipeline. + * + *

This will write all records to the same queue, ignoring their destination field. + * + *

Optional. If not specified, the connector will use dynamic destinations based on the + * destination field of {@link Solace.Record}. + */ + public Write to(Solace.Queue queue) { + return toBuilder().setDestination(queueFromName(queue.getName())).build(); + } + + /** + * The number of workers used by the job to write to Solace. + * + *

This is optional, the default value is 20. + * + *

This is the maximum value that the job would use, but depending on the amount of data, the + * actual number of writers may be lower than this value. With the Dataflow runner, the + * connector will use at most this number of VMs in the job (but the job itself may use more + * VMs). + * + *

Set this number taking into account the limits in the number of clients in your Solace + * cluster, and the need for performance when writing to Solace (more workers will achieve + * higher throughput). + */ + public Write withMaxNumOfUsedWorkers(int maxNumOfUsedWorkers) { + return toBuilder().setMaxNumOfUsedWorkers(maxNumOfUsedWorkers).build(); + } + + /** + * The number of clients that each worker will create. + * + *

This is optional, the default number is 4. + * + *

The number of clients is per worker. If there is more than one worker, the number of + * clients will be multiplied by the number of workers. With the Dataflow runner, this will be + * the number of clients created per VM. The clients will be re-used across different threads in + * the same worker. + * + *

Set this number in combination with {@link #withMaxNumOfUsedWorkers}, to ensure that the + * limit for number of clients in your Solace cluster is not exceeded. + * + *

Normally, using a higher number of clients with fewer workers will achieve better + * throughput at a lower cost, since the workers are better utilized. A good rule of thumb + * is to set as many clients per worker as the worker has vCPUs. + */ + public Write withNumberOfClientsPerWorker(int numberOfClientsPerWorker) { + return toBuilder().setNumberOfClientsPerWorker(numberOfClientsPerWorker).build(); + } + + /** + * Set the delivery mode. This is optional, the default value is DIRECT. + * + *

For more details, see https://docs.solace.com/API/API-Developer-Guide/Message-Delivery-Modes.htm + */ + public Write withDeliveryMode(DeliveryMode deliveryMode) { + return toBuilder().setDeliveryMode(deliveryMode).build(); + } + + /** + * Publish latency metrics using Beam metrics. + * + *

Latency metrics are only available if {@link #withDeliveryMode(DeliveryMode)} is set to + * PERSISTENT. In that mode, latency is measured for each single message, as the time difference + * between the message creation and the reception of the publishing confirmation. + * + *

For the batched writer, the creation time is set for every message in a batch shortly + * before the batch is submitted. So the latency is very close to the actual publishing latency, + * and it does not take into account the time spent waiting for the batch to be submitted. + * + *

This is optional, the default value is false (don't publish latency metrics). + */ + public Write publishLatencyMetrics() { + return toBuilder().setPublishLatencyMetrics(true).build(); + } + + /** + * This setting controls the JCSMP property MESSAGE_CALLBACK_ON_REACTOR. Optional. + * + *

For full details, please check https://docs.solace.com/API/API-Developer-Guide/Java-API-Best-Practices.htm. + * + *

The Solace JCSMP client libraries can dispatch messages using two different modes: + * + *

One of the modes dispatches messages directly from the same thread that is doing the rest + * of I/O work. This mode favors lower latency at the cost of lower throughput. Set this to LOWER_LATENCY + * to use that mode (MESSAGE_CALLBACK_ON_REACTOR set to True). + * + *

The other mode uses a parallel thread to accumulate and dispatch messages. This mode + * favors higher throughput but also has higher latency. Set this to HIGHER_THROUGHPUT to use + * that mode. This is the default mode (MESSAGE_CALLBACK_ON_REACTOR set to False). + * + *

This is optional, the default value is HIGHER_THROUGHPUT. + */ + public Write withSubmissionMode(SubmissionMode submissionMode) { + return toBuilder().setDispatchMode(submissionMode).build(); + } + + /** + * Set the type of writer used by the connector. Optional. + * + *

The Solace writer can either use the JCSMP modes in streaming or batched. + * + *

In streaming mode, the publishing latency will be lower, but the throughput will also be + * lower. + * + *

With the batched mode, messages are accumulated until a batch size of 50 is reached, or 5 + * seconds have elapsed since the first message in the batch was received. The accumulated messages are + * sent to Solace in a single batch. This writer offers higher throughput but higher publishing + * latency, as messages can be held up for up to 5 seconds until they are published. + * + *

Notice that this is the message publishing latency, not the end-to-end latency. For very + * large scale pipelines, you will probably prefer to use the HIGHER_THROUGHPUT mode, since with + * lower throughput, messages will accumulate in the pipeline and the end-to-end latency may + * actually be higher. + * + *

This is optional, the default is the BATCHED writer. + */ + public Write withWriterType(WriterType writerType) { + return toBuilder().setWriterType(writerType).build(); + } + + /** + * Set the provider used to obtain the properties to initialize a new session in the broker. + * + *

This provider should define the destination host where the broker is listening, and all + * the properties related to authentication (base auth, client certificate, etc.). + */ + public Write withSessionServiceFactory(SessionServiceFactory factory) { + return toBuilder().setSessionServiceFactory(factory).build(); + } + + abstract int getMaxNumOfUsedWorkers(); + + abstract int getNumberOfClientsPerWorker(); + + abstract @Nullable Destination getDestination(); + + abstract DeliveryMode getDeliveryMode(); + + abstract boolean getPublishLatencyMetrics(); + + abstract SubmissionMode getDispatchMode(); + + abstract WriterType getWriterType(); + + abstract @Nullable SerializableFunction getFormatFunction(); + + abstract @Nullable SessionServiceFactory getSessionServiceFactory(); + + static Builder builder() { + return new AutoValue_SolaceIO_Write.Builder() + .setDeliveryMode(DEFAULT_WRITER_DELIVERY_MODE) + .setMaxNumOfUsedWorkers(DEFAULT_WRITER_MAX_NUMBER_OF_WORKERS) + .setNumberOfClientsPerWorker(DEFAULT_WRITER_CLIENTS_PER_WORKER) + .setPublishLatencyMetrics(DEFAULT_WRITER_PUBLISH_LATENCY_METRICS) + .setDispatchMode(DEFAULT_WRITER_SUBMISSION_MODE) + .setWriterType(DEFAULT_WRITER_TYPE); + } + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setMaxNumOfUsedWorkers(int maxNumOfUsedWorkers); + + abstract Builder setNumberOfClientsPerWorker(int numberOfClientsPerWorker); + + abstract Builder setDestination(Destination topicOrQueue); + + abstract Builder setDeliveryMode(DeliveryMode deliveryMode); + + abstract Builder setPublishLatencyMetrics(Boolean publishLatencyMetrics); + + abstract Builder setDispatchMode(SubmissionMode submissionMode); + + abstract Builder setWriterType(WriterType writerType); + + abstract Builder setFormatFunction(SerializableFunction formatFunction); + + abstract Builder setSessionServiceFactory(SessionServiceFactory factory); + + abstract Write build(); + } + + @Override + public SolaceOutput expand(PCollection input) { + // TODO: will be sent in upcoming PR + return SolaceOutput.in(input.getPipeline(), null, null); + } + } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java index 7863dbd129ce..2137d574b09a 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java @@ -39,13 +39,14 @@ *

This class provides a way to connect to a Solace broker and receive messages from a queue. The * connection is established using basic authentication. */ -public class BasicAuthJcsmpSessionService implements SessionService { +public class BasicAuthJcsmpSessionService extends SessionService { private final String queueName; private final String host; private final String username; private final String password; private final String vpnName; @Nullable private JCSMPSession jcsmpSession; + @Nullable private MessageReceiver messageReceiver; private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); /** @@ -73,12 +74,14 @@ public void connect() { @Override public void close() { - if (isClosed()) { - return; - } retryCallableManager.retryCallable( () -> { - checkStateNotNull(jcsmpSession).closeSession(); + if (messageReceiver != null) { + messageReceiver.close(); + } + if (!isClosed()) { + checkStateNotNull(jcsmpSession).closeSession(); + } return 0; }, ImmutableSet.of(IOException.class)); @@ -86,8 +89,10 @@ public void close() { @Override public MessageReceiver createReceiver() { - return retryCallableManager.retryCallable( - this::createFlowReceiver, ImmutableSet.of(JCSMPException.class)); + this.messageReceiver = + retryCallableManager.retryCallable( + this::createFlowReceiver, ImmutableSet.of(JCSMPException.class)); + return this.messageReceiver; } @Override @@ -137,12 +142,19 @@ private int connectSession() throws JCSMPException { } private JCSMPSession createSessionObject() throws InvalidPropertiesException { - JCSMPProperties properties = new JCSMPProperties(); - properties.setProperty(JCSMPProperties.HOST, host); - properties.setProperty(JCSMPProperties.USERNAME, username); - properties.setProperty(JCSMPProperties.PASSWORD, password); - properties.setProperty(JCSMPProperties.VPN_NAME, vpnName); - + JCSMPProperties properties = initializeSessionProperties(new JCSMPProperties()); return JCSMPFactory.onlyInstance().createSession(properties); } + + @Override + public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProps) { + baseProps.setProperty(JCSMPProperties.VPN_NAME, vpnName); + + baseProps.setProperty( + JCSMPProperties.AUTHENTICATION_SCHEME, JCSMPProperties.AUTHENTICATION_SCHEME_BASIC); + baseProps.setProperty(JCSMPProperties.USERNAME, username); + baseProps.setProperty(JCSMPProperties.PASSWORD, password); + baseProps.setProperty(JCSMPProperties.HOST, host); + return baseProps; + } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java index 8cb4ff0af053..2084e61b7e38 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.solace.broker; +import static org.apache.beam.sdk.io.solace.broker.SessionService.DEFAULT_VPN_NAME; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import com.google.auto.value.AutoValue; @@ -39,7 +40,7 @@ public abstract class BasicAuthJcsmpSessionServiceFactory extends SessionService public abstract String vpnName(); public static Builder builder() { - return new AutoValue_BasicAuthJcsmpSessionServiceFactory.Builder(); + return new 
AutoValue_BasicAuthJcsmpSessionServiceFactory.Builder().vpnName(DEFAULT_VPN_NAME); } @AutoValue.Builder diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java new file mode 100644 index 000000000000..4884bb61e628 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.api.client.http.HttpRequestFactory; +import com.solacesystems.jcsmp.JCSMPFactory; +import java.io.IOException; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Semp.Queue; +import org.apache.beam.sdk.util.SerializableSupplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A class that manages REST calls to the Solace Element Management Protocol (SEMP) using basic + * authentication. + * + *

This class provides methods to check necessary information, such as if the queue is + * non-exclusive, remaining backlog bytes of a queue. It can also create and execute calls to create + * queue for a topic. + */ +@Internal +public class BasicAuthSempClient implements SempClient { + private static final Logger LOG = LoggerFactory.getLogger(BasicAuthSempClient.class); + private final ObjectMapper objectMapper = + new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + + private final SempBasicAuthClientExecutor sempBasicAuthClientExecutor; + + public BasicAuthSempClient( + String host, + String username, + String password, + String vpnName, + SerializableSupplier httpRequestFactorySupplier) { + sempBasicAuthClientExecutor = + new SempBasicAuthClientExecutor( + host, username, password, vpnName, httpRequestFactorySupplier.get()); + } + + @Override + public boolean isQueueNonExclusive(String queueName) throws IOException { + LOG.info("SolaceIO.Read: SempOperations: query SEMP if queue {} is nonExclusive", queueName); + BrokerResponse response = sempBasicAuthClientExecutor.getQueueResponse(queueName); + if (response.content == null) { + throw new IOException("SolaceIO: response from SEMP is empty!"); + } + Queue q = mapJsonToClass(response.content, Queue.class); + return q.data().accessType().equals("non-exclusive"); + } + + @Override + public com.solacesystems.jcsmp.Queue createQueueForTopic(String queueName, String topicName) + throws IOException { + createQueue(queueName); + createSubscription(queueName, topicName); + return JCSMPFactory.onlyInstance().createQueue(queueName); + } + + @Override + public long getBacklogBytes(String queueName) throws IOException { + BrokerResponse response = sempBasicAuthClientExecutor.getQueueResponse(queueName); + if (response.content == null) { + throw new IOException("SolaceIO: response from SEMP is empty!"); + } + Queue q = mapJsonToClass(response.content, Queue.class); + return q.data().msgSpoolUsage(); + } + + private void createQueue(String queueName) throws IOException { + LOG.info("SolaceIO.Read: Creating new queue {}.", queueName); + sempBasicAuthClientExecutor.createQueueResponse(queueName); + } + + private void createSubscription(String queueName, String topicName) throws IOException { + LOG.info("SolaceIO.Read: Creating new subscription {} for topic {}.", queueName, topicName); + sempBasicAuthClientExecutor.createSubscriptionResponse(queueName, topicName); + } + + private T mapJsonToClass(String content, Class mapSuccessToClass) + throws JsonProcessingException { + return objectMapper.readValue(content, mapSuccessToClass); + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClientFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClientFactory.java new file mode 100644 index 000000000000..4c01257373b4 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClientFactory.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import com.google.api.client.http.HttpRequestFactory; +import com.google.api.client.http.javanet.NetHttpTransport; +import com.google.auto.value.AutoValue; +import org.apache.beam.sdk.util.SerializableSupplier; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A factory for creating {@link BasicAuthSempClient} instances. + * + *

This factory provides a way to create {@link BasicAuthSempClient} instances with different + * configurations. + */ +@AutoValue +public abstract class BasicAuthSempClientFactory implements SempClientFactory { + + abstract String host(); + + abstract String username(); + + abstract String password(); + + abstract String vpnName(); + + abstract @Nullable SerializableSupplier httpRequestFactorySupplier(); + + public static Builder builder() { + return new AutoValue_BasicAuthSempClientFactory.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + /** Set Solace SEMP host, format: [Protocol://]Host[:Port]. e.g. "http://127.0.0.1:8080" */ + public abstract Builder host(String host); + + /** Set Solace username. */ + public abstract Builder username(String username); + /** Set Solace password. */ + public abstract Builder password(String password); + + /** Set Solace vpn name. */ + public abstract Builder vpnName(String vpnName); + + @VisibleForTesting + abstract Builder httpRequestFactorySupplier( + SerializableSupplier httpRequestFactorySupplier); + + public abstract BasicAuthSempClientFactory build(); + } + + @Override + public SempClient create() { + return new BasicAuthSempClient( + host(), username(), password(), vpnName(), getHttpRequestFactorySupplier()); + } + + SerializableSupplier getHttpRequestFactorySupplier() { + SerializableSupplier httpRequestSupplier = httpRequestFactorySupplier(); + return httpRequestSupplier != null + ? httpRequestSupplier + : () -> new NetHttpTransport().createRequestFactory(); + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BrokerResponse.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BrokerResponse.java new file mode 100644 index 000000000000..1a47f8012285 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BrokerResponse.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.broker; + +import com.google.api.client.http.HttpResponse; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.stream.Collectors; +import org.checkerframework.checker.nullness.qual.Nullable; + +public class BrokerResponse { + final int code; + final String message; + @Nullable String content; + + public BrokerResponse(int responseCode, String message, @Nullable InputStream content) { + this.code = responseCode; + this.message = message; + if (content != null) { + this.content = + new BufferedReader(new InputStreamReader(content, StandardCharsets.UTF_8)) + .lines() + .collect(Collectors.joining("\n")); + } + } + + public static BrokerResponse fromHttpResponse(HttpResponse response) throws IOException { + return new BrokerResponse( + response.getStatusCode(), response.getStatusMessage(), response.getContent()); + } + + @Override + public String toString() { + return "BrokerResponse{" + + "code=" + + code + + ", message='" + + message + + '\'' + + ", content=" + + content + + '}'; + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java new file mode 100644 index 000000000000..dd87e1d75fa5 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import static org.apache.beam.sdk.io.solace.broker.SessionService.DEFAULT_VPN_NAME; + +import com.google.auto.value.AutoValue; +import com.google.cloud.secretmanager.v1.SecretManagerServiceClient; +import com.google.cloud.secretmanager.v1.SecretVersionName; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.Optional; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class implements a {@link SessionServiceFactory} that retrieve the basic authentication + * credentials from a Google Cloud Secret Manager secret. + * + *

It can be used to avoid having to pass the password as an option of your pipeline. For this + * provider to work, the worker where the job runs needs to have the necessary credentials to access + * the secret. In Dataflow, this implies adding the necessary permissions to the worker service + * account. For other runners, set the credentials in the pipeline options using {@link + * org.apache.beam.sdk.extensions.gcp.options.GcpOptions}. + * + *

It also shows how to implement a {@link SessionServiceFactory} that depends on external + * resources to retrieve the Solace session properties. In this case, it uses the Google Cloud Secret + * Manager client. + * + *

Example of how to create the provider object: + * + *

{@code
+ * GCPSecretSessionServiceFactory factory =
+ *     GCPSecretSessionServiceFactory.builder()
+ *         .username("user")
+ *         .host("host:port")
+ *         .passwordSecretName("secret-name")
+ *         .build();
+ *
+ * SessionService serviceUsingGCPSecret = factory.create();
+ * }
+ */ +@AutoValue +public abstract class GCPSecretSessionServiceFactory extends SessionServiceFactory { + + private static final Logger LOG = LoggerFactory.getLogger(GCPSecretSessionServiceFactory.class); + + private static final String PROJECT_NOT_FOUND = "PROJECT-NOT-FOUND"; + + public abstract String username(); + + public abstract String host(); + + public abstract String passwordSecretName(); + + public abstract String vpnName(); + + public abstract @Nullable String secretManagerProjectId(); + + public abstract String passwordSecretVersion(); + + public static GCPSecretSessionServiceFactory.Builder builder() { + return new AutoValue_GCPSecretSessionServiceFactory.Builder() + .passwordSecretVersion("latest") + .vpnName(DEFAULT_VPN_NAME); + } + + @AutoValue.Builder + public abstract static class Builder { + + /** Username to be used to authenticate with the broker. */ + public abstract GCPSecretSessionServiceFactory.Builder username(String username); + + /** + * The location of the broker, including port details if it is not listening in the default + * port. + */ + public abstract GCPSecretSessionServiceFactory.Builder host(String host); + + /** The Secret Manager secret name where the password is stored. */ + public abstract GCPSecretSessionServiceFactory.Builder passwordSecretName(String name); + + /** Optional. Solace broker VPN name. If not set, "default" is used. */ + public abstract GCPSecretSessionServiceFactory.Builder vpnName(String name); + + /** + * Optional for Dataflow or VMs running on Google Cloud. The project id of the project where the + * secret is stored. If not set, the project id where the job is running is used. + */ + public abstract GCPSecretSessionServiceFactory.Builder secretManagerProjectId(String id); + + /** Optional. Solace broker password secret version. If not set, "latest" is used. 
*/ + public abstract GCPSecretSessionServiceFactory.Builder passwordSecretVersion(String version); + + public abstract GCPSecretSessionServiceFactory build(); + } + + @Override + public SessionService create() { + String password = null; + try { + password = retrieveSecret(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + BasicAuthJcsmpSessionServiceFactory factory = + BasicAuthJcsmpSessionServiceFactory.builder() + .username(username()) + .host(host()) + .password(password) + .vpnName(vpnName()) + .build(); + + return factory.create(); + } + + private String retrieveSecret() throws IOException { + try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) { + String projectId = + Optional.ofNullable(secretManagerProjectId()).orElse(getProjectIdFromVmMetadata()); + SecretVersionName secretVersionName = + SecretVersionName.of(projectId, passwordSecretName(), passwordSecretVersion()); + return client.accessSecretVersion(secretVersionName).getPayload().getData().toStringUtf8(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private String getProjectIdFromVmMetadata() throws IOException { + URL metadataUrl = + new URL("http://metadata.google.internal/computeMetadata/v1/project/project-id"); + HttpURLConnection connection = (HttpURLConnection) metadataUrl.openConnection(); + connection.setRequestProperty("Metadata-Flavor", "Google"); + + String output; + try (BufferedReader reader = + new BufferedReader( + new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) { + output = reader.readLine(); + } + + if (output == null || output.isEmpty()) { + LOG.error( + "Cannot retrieve project id from VM metadata, please set a project id in your GoogleCloudSecretProvider."); + } + return output != null ? output : PROJECT_NOT_FOUND; + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java index 199a83e322bd..95f989bd1be9 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageReceiver.java @@ -49,6 +49,9 @@ public interface MessageReceiver { */ BytesXMLMessage receive() throws IOException; + /** Closes the message receiver. */ + void close(); + /** * Test clients may return {@literal true} to signal that all expected messages have been pulled * and the test may complete. Real clients should always return {@literal false}. diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java new file mode 100644 index 000000000000..62a492775e7c --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpContent; +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpRequest; +import com.google.api.client.http.HttpRequestFactory; +import com.google.api.client.http.HttpResponse; +import com.google.api.client.http.HttpResponseException; +import com.google.api.client.http.json.JsonHttpContent; +import com.google.api.client.json.gson.GsonFactory; +import java.io.IOException; +import java.io.Serializable; +import java.net.CookieManager; +import java.net.HttpCookie; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A class to execute requests to SEMP v2 with Basic Auth authentication. + * + *

This approach takes advantage of SEMP Sessions. The + * session is established when a user authenticates with HTTP Basic authentication. When the + * response is 401 Unauthorized, the client will execute an additional request with Basic Auth + * header to refresh the token. + */ +class SempBasicAuthClientExecutor implements Serializable { + // Every request will be repeated 2 times in case of abnormal connection failures. + private static final int REQUEST_NUM_RETRIES = 2; + private static final Map COOKIE_MANAGER_MAP = + new ConcurrentHashMap(); + private static final String COOKIES_HEADER = "Set-Cookie"; + + private final String username; + private final String messageVpn; + private final String baseUrl; + private final String password; + private final CookieManagerKey cookieManagerKey; + private final transient HttpRequestFactory requestFactory; + + SempBasicAuthClientExecutor( + String host, + String username, + String password, + String vpnName, + HttpRequestFactory httpRequestFactory) { + this.baseUrl = String.format("%s/SEMP/v2", host); + this.username = username; + this.messageVpn = vpnName; + this.password = password; + this.requestFactory = httpRequestFactory; + this.cookieManagerKey = new CookieManagerKey(this.baseUrl, this.username); + COOKIE_MANAGER_MAP.putIfAbsent(this.cookieManagerKey, new CookieManager()); + } + + private static String getQueueEndpoint(String messageVpn, String queueName) { + return String.format("/monitor/msgVpns/%s/queues/%s", messageVpn, queueName); + } + + private static String createQueueEndpoint(String messageVpn) { + return String.format("/config/msgVpns/%s/queues", messageVpn); + } + + private static String subscriptionEndpoint(String messageVpn, String queueName) { + return String.format("/config/msgVpns/%s/queues/%s/subscriptions", messageVpn, queueName); + } + + BrokerResponse getQueueResponse(String queueName) throws IOException { + String queryUrl = getQueueEndpoint(messageVpn, queueName); + HttpResponse response = executeGet(new GenericUrl(baseUrl + queryUrl)); + return BrokerResponse.fromHttpResponse(response); + } + + BrokerResponse createQueueResponse(String queueName) throws IOException { + String queryUrl = createQueueEndpoint(messageVpn); + ImmutableMap params = + ImmutableMap.builder() + .put("accessType", "non-exclusive") + .put("queueName", queueName) + .put("owner", username) + .put("permission", "consume") + .put("ingressEnabled", true) + .put("egressEnabled", true) + .build(); + + HttpResponse response = executePost(new GenericUrl(baseUrl + queryUrl), params); + return BrokerResponse.fromHttpResponse(response); + } + + BrokerResponse createSubscriptionResponse(String queueName, String topicName) throws IOException { + String queryUrl = subscriptionEndpoint(messageVpn, queueName); + + ImmutableMap params = + ImmutableMap.builder() + .put("subscriptionTopic", topicName) + .put("queueName", queueName) + .build(); + HttpResponse response = executePost(new GenericUrl(baseUrl + queryUrl), params); + return BrokerResponse.fromHttpResponse(response); + } + + private HttpResponse executeGet(GenericUrl url) throws IOException { + HttpRequest request = requestFactory.buildGetRequest(url); + return execute(request); + } + + private HttpResponse executePost(GenericUrl url, ImmutableMap parameters) + throws IOException { + HttpContent content = new JsonHttpContent(GsonFactory.getDefaultInstance(), parameters); + HttpRequest request = requestFactory.buildPostRequest(url, content); + return execute(request); + } + + private HttpResponse 
execute(HttpRequest request) throws IOException { + request.setNumberOfRetries(REQUEST_NUM_RETRIES); + HttpHeaders httpHeaders = new HttpHeaders(); + boolean authFromCookie = + !checkStateNotNull(COOKIE_MANAGER_MAP.get(cookieManagerKey)) + .getCookieStore() + .getCookies() + .isEmpty(); + if (authFromCookie) { + setCookiesFromCookieManager(httpHeaders); + request.setHeaders(httpHeaders); + } else { + httpHeaders.setBasicAuthentication(username, password); + request.setHeaders(httpHeaders); + } + + HttpResponse response; + try { + response = request.execute(); + } catch (HttpResponseException e) { + if (authFromCookie && e.getStatusCode() == 401) { + checkStateNotNull(COOKIE_MANAGER_MAP.get(cookieManagerKey)).getCookieStore().removeAll(); + // execute again without cookies to refresh the token. + return execute(request); + } else { // we might need to handle other response codes here. + throw e; + } + } + + storeCookiesInCookieManager(response.getHeaders()); + return response; + } + + private void setCookiesFromCookieManager(HttpHeaders httpHeaders) { + httpHeaders.setCookie( + checkStateNotNull(COOKIE_MANAGER_MAP.get(cookieManagerKey)).getCookieStore().getCookies() + .stream() + .map(s -> s.getName() + "=" + s.getValue()) + .collect(Collectors.joining(";"))); + } + + private void storeCookiesInCookieManager(HttpHeaders headers) { + List cookiesHeader = headers.getHeaderStringValues(COOKIES_HEADER); + if (cookiesHeader != null) { + for (String cookie : cookiesHeader) { + checkStateNotNull(COOKIE_MANAGER_MAP.get(cookieManagerKey)) + .getCookieStore() + .add(null, HttpCookie.parse(cookie).get(0)); + } + } + } + + private static class CookieManagerKey implements Serializable { + private final String baseUrl; + private final String username; + + CookieManagerKey(String baseUrl, String username) { + this.baseUrl = baseUrl; + this.username = username; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CookieManagerKey)) { + return false; + } + CookieManagerKey that = (CookieManagerKey) o; + return Objects.equals(baseUrl, that.baseUrl) && Objects.equals(username, that.username); + } + + @Override + public int hashCode() { + return Objects.hash(baseUrl, username); + } + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java index cd368865f0c3..aed700a71ded 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java @@ -17,34 +17,220 @@ */ package org.apache.beam.sdk.io.solace.broker; +import com.solacesystems.jcsmp.JCSMPProperties; import java.io.Serializable; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * The SessionService interface provides a set of methods for managing a session with the Solace * messaging system. It allows for establishing a connection, creating a message-receiver object, * checking if the connection is closed or not, and gracefully closing the session. + * + *

Override this class and the method {@link #initializeSessionProperties(JCSMPProperties)} with + * your specific properties, including all those related to authentication. + * + *

The connector will call the method only once per session created, so you can perform + * relatively heavy operations in that method (e.g. connect to a store or vault to retrieve + * credentials). + * + *

There are some properties that are set by default and can be overridden in this + * provider. They are relevant for the writer connector, and not used in the case of the read + * connector (since they are not necessary for reading): + * + *

    + *
  • VPN_NAME: default + *
  • GENERATE_SEND_TIMESTAMPS: true + *
  • PUB_MULTI_THREAD: true + *
+ * + *

The connector overrides other properties, regardless of what this provider sends to the + * connector. Those properties are the following. Again, these properties are only relevant for the + * write connector. + * + *

    + *
  • PUB_ACK_WINDOW_SIZE + *
  • MESSAGE_CALLBACK_ON_REACTOR + *
+ * + * Those properties are set by the connector based on the values of {@link + * org.apache.beam.sdk.io.solace.SolaceIO.Write#withWriterType(SolaceIO.WriterType)} and {@link + * org.apache.beam.sdk.io.solace.SolaceIO.Write#withSubmissionMode(SolaceIO.SubmissionMode)}. + * + *

The method will always run in a worker thread or task, and not in the driver program. If you + * need to access any resource to set the properties, you need to make sure that the worker has the + * network connectivity required for that, and that any credential or configuration is passed to the + * provider through the constructor. + * + *

The connector ensures that no two threads will be calling that method at the same time, so you + * don't have to take any specific precautions to avoid race conditions. + * + *

For basic authentication, use {@link BasicAuthJcsmpSessionService} and {@link + * BasicAuthJcsmpSessionServiceFactory}. + * + *

For other situations, you need to extend this class. For instance: + * + *

{@code
+ * public class MySessionService extends SessionService {
+ *   private final String authToken;
+ *
+ *   public MySessionService(String token) {
+ *    this.authToken = token;
+ *    ...
+ *   }
+ *
+ *   {@literal }@Override
+ *   public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProps) {
+ *     baseProps.setProperty(JCSMPProperties.AUTHENTICATION_SCHEME, JCSMPProperties.AUTHENTICATION_SCHEME_OAUTH2);
+ *     baseProps.setProperty(JCSMPProperties.OAUTH2_ACCESS_TOKEN, authToken);
+ *     return baseProps;
+ *   }
+ *
+ *   {@literal }@Override
+ *   public void connect() {
+ *       ...
+ *   }
+ *
+ *   ...
+ * }
+ * }
*/ -public interface SessionService extends Serializable { +public abstract class SessionService implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(SessionService.class); + + public static final String DEFAULT_VPN_NAME = "default"; + + private static final int STREAMING_PUB_ACK_WINDOW = 50; + private static final int BATCHED_PUB_ACK_WINDOW = 255; /** * Establishes a connection to the service. This could involve providing connection details like * host, port, VPN name, username, and password. */ - void connect(); + public abstract void connect(); /** Gracefully closes the connection to the service. */ - void close(); + public abstract void close(); /** * Checks whether the connection to the service is currently closed. This method is called when an * `UnboundedSolaceReader` is starting to read messages - a session will be created if this * returns true. */ - boolean isClosed(); + public abstract boolean isClosed(); /** * Creates a MessageReceiver object for receiving messages from Solace. Typically, this object is * created from the session instance. */ - MessageReceiver createReceiver(); + public abstract MessageReceiver createReceiver(); + + /** + * Override this method and provide your specific properties, including all those related to + * authentication, and possibly others too. The {@code}baseProperties{@code} parameter sets the + * Solace VPN to "default" if none is specified. + * + *

You should add your properties to the parameter {@code baseProperties}, and return the + * result. + * + *

The method will be used whenever the session needs to be created or refreshed. If you are + * setting credentials with expiration, just make sure that the latest available credentials (e.g. + * renewed token) are set when the method is called. + * + *
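For example, a provider holding expiring credentials could resolve them inside this method, so every (re)connection picks up the latest value (a sketch; tokenSupplier is a hypothetical member injected through the constructor):

@Override
public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProps) {
  // Called whenever the session is (re)created, so the freshest token is always used.
  String freshToken = tokenSupplier.get(); // hypothetical supplier set in the constructor
  baseProps.setProperty(
      JCSMPProperties.AUTHENTICATION_SCHEME, JCSMPProperties.AUTHENTICATION_SCHEME_OAUTH2);
  baseProps.setProperty(JCSMPProperties.OAUTH2_ACCESS_TOKEN, freshToken);
  return baseProps;
}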

For a list of all the properties that can be set, please check the following link: + * + *

+ */ + public abstract JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties); + + /** + * This method will be called by the write connector when a new session is started. + * + *

This call happens in the worker, so make sure that the worker has access to + * the resources needed to set the properties. + * + *

The call will happen only once per session initialization. Typically, that will be when the + * worker and the client are created. But if for any reason the session is lost (e.g. expired auth + * token), this method will be called again. + */ + public final JCSMPProperties initializeWriteSessionProperties(SolaceIO.SubmissionMode mode) { + JCSMPProperties jcsmpProperties = initializeSessionProperties(getDefaultProperties()); + return overrideConnectorProperties(jcsmpProperties, mode); + } + + private static JCSMPProperties getDefaultProperties() { + JCSMPProperties props = new JCSMPProperties(); + props.setProperty(JCSMPProperties.VPN_NAME, DEFAULT_VPN_NAME); + // Outgoing messages will have a sender timestamp field populated + props.setProperty(JCSMPProperties.GENERATE_SEND_TIMESTAMPS, true); + // Make XMLProducer safe to access from several threads. This is the default value, setting + // it just in case. + props.setProperty(JCSMPProperties.PUB_MULTI_THREAD, true); + + return props; + } + + /** + * This method overrides some properties for the broker session to prevent misconfiguration, + * taking into account how the write connector works. + */ + private static JCSMPProperties overrideConnectorProperties( + JCSMPProperties props, SolaceIO.SubmissionMode mode) { + + // PUB_ACK_WINDOW_SIZE heavily affects performance when publishing persistent + // messages. It can be a value between 1 and 255. This is the batch size for the ack + // received from Solace. A value of 1 will have the lowest latency, but a very low + // throughput and a monumental backpressure. + + // This controls how the messages are sent to Solace + if (mode == SolaceIO.SubmissionMode.HIGHER_THROUGHPUT) { + // Create a parallel thread and a queue to send the messages + + Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR); + if (msgCbProp != null && msgCbProp) { + LOG.warn( + "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to false since" + + " HIGHER_THROUGHPUT mode was selected"); + } + + props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, false); + + Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE); + if ((ackWindowSize != null && ackWindowSize != BATCHED_PUB_ACK_WINDOW)) { + LOG.warn( + String.format( + "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since" + + " HIGHER_THROUGHPUT mode was selected", + BATCHED_PUB_ACK_WINDOW)); + } + props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, BATCHED_PUB_ACK_WINDOW); + } else { + // Send from the same thread where the produced is being called. This offers the lowest + // latency, but a low throughput too. 
+ Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR); + if (msgCbProp != null && !msgCbProp) { + LOG.warn( + "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to true since" + + " LOWER_LATENCY mode was selected"); + } + + props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, true); + + Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE); + if ((ackWindowSize != null && ackWindowSize != STREAMING_PUB_ACK_WINDOW)) { + LOG.warn( + String.format( + "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since" + + " LOWER_LATENCY mode was selected", + STREAMING_PUB_ACK_WINDOW)); + } + + props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, STREAMING_PUB_ACK_WINDOW); + } + return props; + } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java index 7d1dee7a1187..027de2cff134 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java @@ -26,9 +26,8 @@ * queue property and mandates the implementation of a create() method in concrete subclasses. */ public abstract class SessionServiceFactory implements Serializable { - /** - * A reference to a Queue object. This is set when the pipline is constructed (in the {@link + * A reference to a Queue object. This is set when the pipeline is constructed (in the {@link * org.apache.beam.sdk.io.solace.SolaceIO.Read#expand(org.apache.beam.sdk.values.PBegin)} method). * This could be used to associate the created SessionService with a specific queue for message * handling. diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java index e5f129d3ddfc..d548d2049a5b 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageReceiver.java @@ -69,4 +69,11 @@ public BytesXMLMessage receive() throws IOException { throw new IOException(e); } } + + @Override + public void close() { + if (!isClosed()) { + this.flowReceiver.close(); + } + } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Semp.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Semp.java new file mode 100644 index 000000000000..f6f0fb51d22e --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Semp.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.data; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import com.google.auto.value.AutoValue; + +public class Semp { + + @AutoValue + @JsonSerialize(as = Queue.class) + @JsonDeserialize(builder = AutoValue_Semp_Queue.Builder.class) + public abstract static class Queue { + + public abstract QueueData data(); + + public static Builder builder() { + return new AutoValue_Semp_Queue.Builder(); + } + + public abstract Builder toBuilder(); + + @AutoValue.Builder + @JsonPOJOBuilder(withPrefix = "set") + abstract static class Builder { + + public abstract Builder setData(QueueData queueData); + + public abstract Queue build(); + } + } + + @AutoValue + @JsonDeserialize(builder = AutoValue_Semp_QueueData.Builder.class) + public abstract static class QueueData { + public abstract String accessType(); + + public abstract long msgSpoolUsage(); + + public static Builder builder() { + return new AutoValue_Semp_QueueData.Builder(); + } + + public abstract Builder toBuilder(); + + @AutoValue.Builder + @JsonPOJOBuilder(withPrefix = "set") + abstract static class Builder { + + public abstract Builder setAccessType(String accessType); + + public abstract Builder setMsgSpoolUsage(long msgSpoolUsage); + + public abstract QueueData build(); + } + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java index 18fee9184446..00b94b5b9ea9 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java @@ -24,6 +24,7 @@ import java.nio.ByteBuffer; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldNumber; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -118,6 +119,7 @@ public abstract static class Record { * * @return The message ID, or null if not available. */ + @SchemaFieldNumber("0") public abstract @Nullable String getMessageId(); /** @@ -127,6 +129,7 @@ public abstract static class Record { * * @return The message payload. */ + @SchemaFieldNumber("1") public abstract ByteBuffer getPayload(); /** * Gets the destination (topic or queue) to which the message was sent. @@ -135,6 +138,7 @@ public abstract static class Record { * * @return The destination, or null if not available. */ + @SchemaFieldNumber("2") public abstract @Nullable Destination getDestination(); /** @@ -146,6 +150,7 @@ public abstract static class Record { * * @return The expiration timestamp. */ + @SchemaFieldNumber("3") public abstract long getExpiration(); /** @@ -155,6 +160,7 @@ public abstract static class Record { * * @return The message priority. */ + @SchemaFieldNumber("4") public abstract int getPriority(); /** @@ -164,6 +170,7 @@ public abstract static class Record { * * @return True if redelivered, false otherwise. */ + @SchemaFieldNumber("5") public abstract boolean getRedelivered(); /** @@ -173,6 +180,7 @@ public abstract static class Record { * * @return The reply-to destination, or null if not specified. 
*/ + @SchemaFieldNumber("6") public abstract @Nullable Destination getReplyTo(); /** @@ -183,6 +191,7 @@ public abstract static class Record { * * @return The timestamp. */ + @SchemaFieldNumber("7") public abstract long getReceiveTimestamp(); /** @@ -191,6 +200,7 @@ public abstract static class Record { * * @return The sender timestamp, or null if not available. */ + @SchemaFieldNumber("8") public abstract @Nullable Long getSenderTimestamp(); /** @@ -200,6 +210,7 @@ public abstract static class Record { * * @return The sequence number, or null if not available. */ + @SchemaFieldNumber("9") public abstract @Nullable Long getSequenceNumber(); /** @@ -210,6 +221,7 @@ public abstract static class Record { * * @return The time-to-live value. */ + @SchemaFieldNumber("10") public abstract long getTimeToLive(); /** @@ -225,7 +237,9 @@ public abstract static class Record { * * @return The replication group message ID, or null if not present. */ + @SchemaFieldNumber("11") public abstract @Nullable String getReplicationGroupMessageId(); + /** * Gets the attachment data of the message as a ByteString, if any. This might represent files * or other binary content associated with the message. @@ -234,6 +248,7 @@ public abstract static class Record { * * @return The attachment data, or an empty ByteString if no attachment is present. */ + @SchemaFieldNumber("12") public abstract ByteBuffer getAttachmentBytes(); static Builder builder() { @@ -271,6 +286,90 @@ abstract static class Builder { abstract Record build(); } } + + /** + * The result of writing a message to Solace. This will be returned by the {@link + * com.google.cloud.dataflow.dce.io.solace.SolaceIO.Write} connector. + * + *

This class provides a builder to create instances, but you will probably not need it. The + * write connector will create and return instances of {@link Solace.PublishResult}. + * + *

If the message has been published, {@link Solace.PublishResult#getPublished()} will be true. + * If it is false, it means that the message could not be published, and {@link + * Solace.PublishResult#getError()} will contain more details about why the message could not be + * published. + */ + @AutoValue + @DefaultSchema(AutoValueSchema.class) + public abstract static class PublishResult { + /** The message id of the message that was published. */ + @SchemaFieldNumber("0") + public abstract String getMessageId(); + + /** Whether the message was published or not. */ + @SchemaFieldNumber("1") + public abstract Boolean getPublished(); + + /** + * The publishing latency in milliseconds. This is the difference between the time the message + * was created, and the time the message was published. It is only available if the {@link + * CorrelationKey} class is used as correlation key of the messages. + */ + @SchemaFieldNumber("2") + public abstract @Nullable Long getLatencyMilliseconds(); + + /** The error details if the message could not be published. */ + @SchemaFieldNumber("3") + public abstract @Nullable String getError(); + + public static Builder builder() { + return new AutoValue_Solace_PublishResult.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMessageId(String messageId); + + public abstract Builder setPublished(Boolean published); + + public abstract Builder setLatencyMilliseconds(Long latencyMs); + + public abstract Builder setError(String error); + + public abstract PublishResult build(); + } + } + + /** + * The correlation key is an object that is passed back to the client during the event broker ack + * or nack. + * + *

In the streaming writer is optionally used to calculate publish latencies, by calculating + * the time difference between the creation of the correlation key, and the time of the ack. + */ + @AutoValue + @DefaultSchema(AutoValueSchema.class) + public abstract static class CorrelationKey { + @SchemaFieldNumber("0") + public abstract String getMessageId(); + + @SchemaFieldNumber("1") + public abstract long getPublishMonotonicMillis(); + + public static Builder builder() { + return new AutoValue_Solace_CorrelationKey.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMessageId(String messageId); + + public abstract Builder setPublishMonotonicMillis(long millis); + + public abstract CorrelationKey build(); + } + } + /** * A utility class for mapping {@link BytesXMLMessage} instances to {@link Solace.Record} objects. * This simplifies the process of converting raw Solace messages into a format suitable for use diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java new file mode 100644 index 000000000000..6c37f879ae7f --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * The {@link SolaceIO.Write} transform's output return this type, containing both the successful + * publishes ({@link #getSuccessfulPublish()}) and the failed publishes ({@link + * #getFailedPublish()}). + * + *

The streaming writer with DIRECT messages does not return anything, and the output {@link + * PCollection}s will be equal to null. + */ +public final class SolaceOutput implements POutput { + private final Pipeline pipeline; + private final TupleTag failedPublishTag; + private final TupleTag successfulPublishTag; + private final @Nullable PCollection failedPublish; + private final @Nullable PCollection successfulPublish; + + public @Nullable PCollection getFailedPublish() { + return failedPublish; + } + + public @Nullable PCollection getSuccessfulPublish() { + return successfulPublish; + } + + public static SolaceOutput in( + Pipeline pipeline, + @Nullable PCollection failedPublish, + @Nullable PCollection successfulPublish) { + return new SolaceOutput( + pipeline, + SolaceIO.Write.FAILED_PUBLISH_TAG, + SolaceIO.Write.SUCCESSFUL_PUBLISH_TAG, + failedPublish, + successfulPublish); + } + + private SolaceOutput( + Pipeline pipeline, + TupleTag failedPublishTag, + TupleTag successfulPublishTag, + @Nullable PCollection failedPublish, + @Nullable PCollection successfulPublish) { + this.pipeline = pipeline; + this.failedPublishTag = failedPublishTag; + this.successfulPublishTag = successfulPublishTag; + this.failedPublish = failedPublish; + this.successfulPublish = successfulPublish; + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Map, PValue> expand() { + ImmutableMap.Builder, PValue> builder = ImmutableMap., PValue>builder(); + + if (failedPublish != null) { + builder.put(failedPublishTag, failedPublish); + } + + if (successfulPublish != null) { + builder.put(successfulPublishTag, successfulPublish); + } + + return builder.build(); + } + + @Override + public void finishSpecifyingOutput( + String transformName, PInput input, PTransform transform) {} +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/package-info.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/package-info.java new file mode 100644 index 000000000000..65974b9b29c2 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** SolaceIO Write connector. 
*/ +package org.apache.beam.sdk.io.solace.write; diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java index 285c1cb8a7e8..ec0ae7194686 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java @@ -17,10 +17,11 @@ */ package org.apache.beam.sdk.io.solace; +import com.solacesystems.jcsmp.JCSMPProperties; import org.apache.beam.sdk.io.solace.broker.MessageReceiver; import org.apache.beam.sdk.io.solace.broker.SessionService; -public class MockEmptySessionService implements SessionService { +public class MockEmptySessionService extends SessionService { String exceptionMessage = "This is an empty client, use a MockSessionService instead."; @@ -43,4 +44,9 @@ public MessageReceiver createReceiver() { public void connect() { throw new UnsupportedOperationException(exceptionMessage); } + + @Override + public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties) { + throw new UnsupportedOperationException(exceptionMessage); + } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java index 7b14da138c64..a4d6a42ef302 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java @@ -18,23 +18,35 @@ package org.apache.beam.sdk.io.solace; import com.solacesystems.jcsmp.BytesXMLMessage; +import com.solacesystems.jcsmp.JCSMPProperties; import java.io.IOException; import java.io.Serializable; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; import org.apache.beam.sdk.io.solace.broker.MessageReceiver; import org.apache.beam.sdk.io.solace.broker.SessionService; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.checkerframework.checker.nullness.qual.Nullable; -public class MockSessionService implements SessionService { +public class MockSessionService extends SessionService { private final SerializableFunction getRecordFn; private MessageReceiver messageReceiver = null; private final int minMessagesReceived; + private final @Nullable SubmissionMode mode; public MockSessionService( - SerializableFunction getRecordFn, int minMessagesReceived) { + SerializableFunction getRecordFn, + int minMessagesReceived, + @Nullable SubmissionMode mode) { this.getRecordFn = getRecordFn; this.minMessagesReceived = minMessagesReceived; + this.mode = mode; + } + + public MockSessionService( + SerializableFunction getRecordFn, int minMessagesReceived) { + this(getRecordFn, minMessagesReceived, null); } @Override @@ -80,9 +92,24 @@ public BytesXMLMessage receive() throws IOException { return getRecordFn.apply(counter.getAndIncrement()); } + @Override + public void close() {} + @Override public boolean isEOF() { return counter.get() >= minMessagesReceived; } } + + @Override + public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties) { + // Let's override some properties that will be overriden by the connector + // Opposite of the mode, to test that is overriden + baseProperties.setProperty( + JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, mode == 
SubmissionMode.HIGHER_THROUGHPUT); + + baseProperties.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, 87); + + return baseProperties; + } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/BasicAuthWriterSessionTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/BasicAuthWriterSessionTest.java new file mode 100644 index 000000000000..e33917641e33 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/BasicAuthWriterSessionTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import static org.apache.beam.sdk.io.solace.broker.SessionService.DEFAULT_VPN_NAME; +import static org.junit.Assert.assertEquals; + +import com.solacesystems.jcsmp.JCSMPFactory; +import com.solacesystems.jcsmp.JCSMPProperties; +import com.solacesystems.jcsmp.Queue; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BasicAuthWriterSessionTest { + private final String username = "Some Username"; + private final String password = "Some Password"; + private final String host = "Some Host"; + private final String vpn = "Some non default VPN"; + SessionService withVpn; + SessionService withoutVpn; + + @Before + public void setUp() throws Exception { + Queue q = JCSMPFactory.onlyInstance().createQueue("test-queue"); + + BasicAuthJcsmpSessionServiceFactory factoryWithVpn = + BasicAuthJcsmpSessionServiceFactory.builder() + .username(username) + .password(password) + .host(host) + .vpnName(vpn) + .build(); + factoryWithVpn.setQueue(q); + withVpn = factoryWithVpn.create(); + + BasicAuthJcsmpSessionServiceFactory factoryNoVpn = + BasicAuthJcsmpSessionServiceFactory.builder() + .username(username) + .password(password) + .host(host) + .build(); + factoryNoVpn.setQueue(q); + withoutVpn = factoryNoVpn.create(); + } + + @Test + public void testAuthProperties() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.HIGHER_THROUGHPUT; + JCSMPProperties props = withoutVpn.initializeWriteSessionProperties(mode); + assertEquals(username, props.getStringProperty(JCSMPProperties.USERNAME)); + assertEquals(password, props.getStringProperty(JCSMPProperties.PASSWORD)); + assertEquals(host, props.getStringProperty(JCSMPProperties.HOST)); + assertEquals( + JCSMPProperties.AUTHENTICATION_SCHEME_BASIC, + props.getStringProperty(JCSMPProperties.AUTHENTICATION_SCHEME)); + } + + @Test + public void testVpnNames() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.LOWER_LATENCY; + JCSMPProperties propsWithoutVpn = withoutVpn.initializeWriteSessionProperties(mode); 
+ assertEquals(DEFAULT_VPN_NAME, propsWithoutVpn.getStringProperty(JCSMPProperties.VPN_NAME)); + JCSMPProperties propsWithVpn = withVpn.initializeWriteSessionProperties(mode); + assertEquals(vpn, propsWithVpn.getStringProperty(JCSMPProperties.VPN_NAME)); + } + + @Test + public void testOverrideWithHigherThroughput() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.HIGHER_THROUGHPUT; + JCSMPProperties props = withoutVpn.initializeWriteSessionProperties(mode); + + assertEquals(false, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); + assertEquals( + Long.valueOf(255), + Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); + } + + @Test + public void testOverrideWithLowerLatency() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.LOWER_LATENCY; + JCSMPProperties props = withoutVpn.initializeWriteSessionProperties(mode); + assertEquals(true, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); + assertEquals( + Long.valueOf(50), + Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java new file mode 100644 index 000000000000..0c6f88a7c9d5 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.broker; + +import static org.junit.Assert.assertEquals; + +import com.solacesystems.jcsmp.JCSMPProperties; +import org.apache.beam.sdk.io.solace.MockSessionService; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class OverrideWriterPropertiesTest { + @Test + public void testOverrideForHigherThroughput() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.HIGHER_THROUGHPUT; + MockSessionService service = new MockSessionService(null, 0, mode); + + // Test HIGHER_THROUGHPUT mode + JCSMPProperties props = service.initializeWriteSessionProperties(mode); + assertEquals(false, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); + assertEquals( + Long.valueOf(255), + Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); + } + + @Test + public void testOverrideForLowerLatency() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.LOWER_LATENCY; + MockSessionService service = new MockSessionService(null, 0, mode); + + // Test HIGHER_THROUGHPUT mode + JCSMPProperties props = service.initializeWriteSessionProperties(mode); + assertEquals(true, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); + assertEquals( + Long.valueOf(50), + Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutorTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutorTest.java new file mode 100644 index 000000000000..8cc48ed17ef6 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutorTest.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.broker; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.api.client.http.HttpRequestFactory; +import com.google.api.client.http.HttpResponseException; +import com.google.api.client.http.LowLevelHttpRequest; +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.json.Json; +import com.google.api.client.testing.http.MockHttpTransport; +import com.google.api.client.testing.http.MockLowLevelHttpRequest; +import com.google.api.client.testing.http.MockLowLevelHttpResponse; +import java.io.IOException; +import java.util.List; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.junit.Test; + +public class SempBasicAuthClientExecutorTest { + + @Test + public void testExecuteStatus4xx() { + MockHttpTransport transport = + new MockHttpTransport() { + @Override + public LowLevelHttpRequest buildRequest(String method, String url) { + return new MockLowLevelHttpRequest() { + @Override + public LowLevelHttpResponse execute() { + MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); + response.setStatusCode(404); + response.setContentType(Json.MEDIA_TYPE); + response.setContent( + "{\"meta\":{\"error\":{\"code\":404,\"description\":\"some" + + " error\",\"status\":\"xx\"}}}"); + return response; + } + }; + } + }; + + HttpRequestFactory requestFactory = transport.createRequestFactory(); + SempBasicAuthClientExecutor client = + new SempBasicAuthClientExecutor( + "http://host", "username", "password", "vpnName", requestFactory); + + assertThrows(HttpResponseException.class, () -> client.getQueueResponse("queue")); + } + + @Test + public void testExecuteStatus3xx() { + MockHttpTransport transport = + new MockHttpTransport() { + @Override + public LowLevelHttpRequest buildRequest(String method, String url) { + return new MockLowLevelHttpRequest() { + @Override + public LowLevelHttpResponse execute() { + MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); + response.setStatusCode(301); + response.setContentType(Json.MEDIA_TYPE); + response.setContent( + "{\"meta\":{\"error\":{\"code\":301,\"description\":\"some" + + " error\",\"status\":\"xx\"}}}"); + return response; + } + }; + } + }; + + HttpRequestFactory requestFactory = transport.createRequestFactory(); + SempBasicAuthClientExecutor client = + new SempBasicAuthClientExecutor( + "http://host", "username", "password", "vpnName", requestFactory); + + assertThrows(HttpResponseException.class, () -> client.getQueueResponse("queue")); + } + + /** + * In this test case, we test a situation when a session that we used to authenticate to Semp + * expires. + * + *

To test this scenario, we need to do the following: + * + *

    + *
  1. Send the first request, to initialize a session. This request has to contain the Basic + * Auth header and should not include any cookie headers. The response for this request + * contains a session cookie we can re-use in the following requests. + *
  2. Send the second request - this request should use a cookie from the previous response. + * There should be no Authorization header. To simulate an expired session scenario, we set + * the response of this request to the "401 Unauthorized". This should cause the request + * to be retried, this time with the Authorization header. +
  3. Validate the third request to contain the Basic Auth header and no session cookies. + *
+ */ + @Test + public void testExecuteWithUnauthorized() throws IOException { + // Making it a final array, so that we can reference it from within the MockHttpTransport + // instance + final int[] requestCounter = {0}; + MockHttpTransport transport = + new MockHttpTransport() { + @Override + public LowLevelHttpRequest buildRequest(String method, String url) { + return new MockLowLevelHttpRequest() { + @Override + public LowLevelHttpResponse execute() throws IOException { + MockLowLevelHttpResponse response = new MockLowLevelHttpResponse(); + if (requestCounter[0] == 0) { + // The first request has to include Basic Auth header + assertTrue(this.getHeaders().containsKey("authorization")); + List authorizationHeaders = this.getHeaders().get("authorization"); + assertEquals(1, authorizationHeaders.size()); + assertTrue(authorizationHeaders.get(0).contains("Basic")); + assertFalse(this.getHeaders().containsKey("cookie")); + + // Set the response to include Session cookies + response + .setHeaderNames(ImmutableList.of("Set-Cookie", "Set-Cookie")) + .setHeaderValues( + ImmutableList.of( + "ProxySession=JddSdJaGo6FYYmQk6nt8jXxFtq6n3FCFR14ebzRGQ5w;" + + " HttpOnly; SameSite=Strict;" + + " Path=/proxy; Max-Age=2592000", + "Session=JddSdJaGo6FYYmQk6nt8jXxFtq6n3FCFR14ebzRGQ5w;" + + " HttpOnly; SameSite=Strict;" + + " Path=/SEMP; Max-Age=2592000")); + response.setStatusCode(200); + } else if (requestCounter[0] == 1) { + // The second request does not include Basic Auth header + assertFalse(this.getHeaders().containsKey("authorization")); + // It must include a cookie header + assertTrue(this.getHeaders().containsKey("cookie")); + boolean hasSessionCookie = + this.getHeaders().get("cookie").stream() + .filter( + c -> + c.contains( + "Session=JddSdJaGo6FYYmQk6nt8jXxFtq6n3FCFR14ebzRGQ5w")) + .count() + == 1; + assertTrue(hasSessionCookie); + + // Let's assume the Session expired - we return the 401 + // unauthorized + response.setStatusCode(401); + } else { + // The second request has to be retried with a Basic Auth header + // this time + assertTrue(this.getHeaders().containsKey("authorization")); + List authorizationHeaders = this.getHeaders().get("authorization"); + assertEquals(1, authorizationHeaders.size()); + assertTrue(authorizationHeaders.get(0).contains("Basic")); + assertFalse(this.getHeaders().containsKey("cookie")); + + response.setStatusCode(200); + } + response.setContentType(Json.MEDIA_TYPE); + requestCounter[0]++; + return response; + } + }; + } + }; + + HttpRequestFactory requestFactory = transport.createRequestFactory(); + SempBasicAuthClientExecutor client = + new SempBasicAuthClientExecutor( + "http://host", "username", "password", "vpnName", requestFactory); + + // The first, initial request + client.getQueueResponse("queue"); + // The second request, which will try to authenticate with a cookie, and then with Basic + // Auth when it receives a 401 unauthorized + client.getQueueResponse("queue"); + + // There should be 3 requests executed: + // the first one is the initial one with Basic Auth, + // the second one uses the session cookie, but we simulate it being expired, + // so there should be a third request with Basic Auth to create a new session. 
+ assertEquals(3, requestCounter[0]); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceContainerManager.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceContainerManager.java new file mode 100644 index 000000000000..6d2b3a27ffd0 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceContainerManager.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.it; + +import java.io.IOException; +import java.net.ServerSocket; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.Container.ExecResult; +import org.testcontainers.containers.output.Slf4jLogConsumer; +import org.testcontainers.solace.Service; +import org.testcontainers.solace.SolaceContainer; +import org.testcontainers.utility.DockerImageName; + +public class SolaceContainerManager { + + public static final String VPN_NAME = "default"; + public static final String PASSWORD = "password"; + public static final String USERNAME = "username"; + public static final String TOPIC_NAME = "test_topic"; + private static final Logger LOG = LoggerFactory.getLogger(SolaceContainerManager.class); + private final SolaceContainer container; + int jcsmpPortMapped = findAvailablePort(); + int sempPortMapped = findAvailablePort(); + + public SolaceContainerManager() throws IOException { + this.container = + new SolaceContainer(DockerImageName.parse("solace/solace-pubsub-standard:10.7")) { + { + addFixedExposedPort(jcsmpPortMapped, 55555); + addFixedExposedPort(sempPortMapped, 8080); + } + }.withVpn(VPN_NAME) + .withCredentials(USERNAME, PASSWORD) + .withTopic(TOPIC_NAME, Service.SMF) + .withLogConsumer(new Slf4jLogConsumer(LOG)); + } + + public void start() { + container.start(); + } + + void createQueueWithSubscriptionTopic(String queueName) { + executeCommand( + "curl", + "http://localhost:8080/SEMP/v2/config/msgVpns/" + VPN_NAME + "/topicEndpoints", + "-X", + "GET", + "-u", + "admin:admin"); + executeCommand( + "curl", + "http://localhost:8080/SEMP/v2/config/msgVpns/" + VPN_NAME + "/topicEndpoints", + "-X", + "POST", + "-u", + "admin:admin", + "-H", + "Content-Type:application/json", + "-d", + "{\"topicEndpointName\":\"" + + TOPIC_NAME + + "\",\"accessType\":\"exclusive\",\"permission\":\"modify-topic\",\"ingressEnabled\":true,\"egressEnabled\":true}"); + executeCommand( + "curl", + "http://localhost:8080/SEMP/v2/config/msgVpns/" + VPN_NAME + "/queues", + "-X", + "POST", + "-u", + "admin:admin", + "-H", + "Content-Type:application/json", + "-d", + "{\"queueName\":\"" + + queueName + + 
"\",\"accessType\":\"non-exclusive\",\"maxMsgSpoolUsage\":200,\"permission\":\"consume\",\"ingressEnabled\":true,\"egressEnabled\":true}"); + executeCommand( + "curl", + "http://localhost:8080/SEMP/v2/config/msgVpns/" + + VPN_NAME + + "/queues/" + + queueName + + "/subscriptions", + "-X", + "POST", + "-u", + "admin:admin", + "-H", + "Content-Type:application/json", + "-d", + "{\"subscriptionTopic\":\"" + TOPIC_NAME + "\"}"); + } + + private void executeCommand(String... command) { + try { + ExecResult execResult = container.execInContainer(command); + if (execResult.getExitCode() != 0) { + logCommandError(execResult.getStderr(), command); + } else { + LOG.info(execResult.getStdout()); + } + } catch (IOException | InterruptedException e) { + logCommandError(e.getMessage(), command); + } + } + + private void logCommandError(String error, String... command) { + LOG.error("Could not execute command {}: {}", command, error); + } + + public void stop() { + if (container != null) { + container.stop(); + } + } + + public void getQueueDetails(String queueName) { + executeCommand( + "curl", + "http://localhost:8080/SEMP/v2/monitor/msgVpns/" + + VPN_NAME + + "/queues/" + + queueName + + "/msgs", + "-X", + "GET", + "-u", + "admin:admin"); + } + + public void sendToTopic(String payload, List additionalHeaders) { + // https://docs.solace.com/API/RESTMessagingPrtl/Solace-REST-Message-Encoding.htm + + List command = + new ArrayList<>( + Arrays.asList( + "curl", + "http://localhost:9000/TOPIC/" + TOPIC_NAME, + "-X", + "POST", + "-u", + USERNAME + ":" + PASSWORD, + "--header", + "Content-Type:application/json", + "-d", + payload)); + + for (String additionalHeader : additionalHeaders) { + command.add("--header"); + command.add(additionalHeader); + } + + executeCommand(command.toArray(new String[0])); + } + + private static int findAvailablePort() throws IOException { + ServerSocket s = new ServerSocket(0); + try { + return s.getLocalPort(); + } finally { + s.close(); + try { + // Some systems don't free the port for future use immediately. + Thread.sleep(100); + } catch (InterruptedException exn) { + // ignore + } + } + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java new file mode 100644 index 000000000000..35ee7595352d --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.it; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.broker.BasicAuthJcsmpSessionServiceFactory; +import org.apache.beam.sdk.io.solace.broker.BasicAuthSempClientFactory; +import org.apache.beam.sdk.io.solace.data.Solace.Queue; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.StreamingOptions; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.testutils.metrics.MetricsReader; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.joda.time.Duration; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; + +public class SolaceIOIT { + private static final String NAMESPACE = SolaceIOIT.class.getName(); + private static final String READ_COUNT = "read_count"; + private static SolaceContainerManager solaceContainerManager; + private static final TestPipelineOptions readPipelineOptions; + + static { + readPipelineOptions = PipelineOptionsFactory.create().as(TestPipelineOptions.class); + readPipelineOptions.setBlockOnRun(false); + readPipelineOptions.as(TestPipelineOptions.class).setBlockOnRun(false); + readPipelineOptions.as(StreamingOptions.class).setStreaming(false); + } + + @Rule public final TestPipeline readPipeline = TestPipeline.fromOptions(readPipelineOptions); + + @BeforeClass + public static void setup() throws IOException { + solaceContainerManager = new SolaceContainerManager(); + solaceContainerManager.start(); + } + + @AfterClass + public static void afterClass() { + if (solaceContainerManager != null) { + solaceContainerManager.stop(); + } + } + + @Test + public void testRead() { + String queueName = "test_queue"; + solaceContainerManager.createQueueWithSubscriptionTopic(queueName); + + // todo this is very slow, needs to be replaced with the SolaceIO.write connector. 
+ int publishMessagesCount = 20; + for (int i = 0; i < publishMessagesCount; i++) { + solaceContainerManager.sendToTopic( + "{\"field_str\":\"value\",\"field_int\":123}", + ImmutableList.of("Solace-Message-ID:m" + i)); + } + + readPipeline + .apply( + "Read from Solace", + SolaceIO.read() + .from(Queue.fromName(queueName)) + .withMaxNumConnections(1) + .withSempClientFactory( + BasicAuthSempClientFactory.builder() + .host("http://localhost:" + solaceContainerManager.sempPortMapped) + .username("admin") + .password("admin") + .vpnName(SolaceContainerManager.VPN_NAME) + .build()) + .withSessionServiceFactory( + BasicAuthJcsmpSessionServiceFactory.builder() + .host("localhost:" + solaceContainerManager.jcsmpPortMapped) + .username(SolaceContainerManager.USERNAME) + .password(SolaceContainerManager.PASSWORD) + .vpnName(SolaceContainerManager.VPN_NAME) + .build())) + .apply("Count", ParDo.of(new CountingFn<>(NAMESPACE, READ_COUNT))); + + PipelineResult pipelineResult = readPipeline.run(); + pipelineResult.waitUntilFinish(Duration.standardSeconds(15)); + + MetricsReader metricsReader = new MetricsReader(pipelineResult, NAMESPACE); + long actualRecordsCount = metricsReader.getCounterMetric(READ_COUNT); + assertEquals(publishMessagesCount, actualRecordsCount); + } + + private static class CountingFn extends DoFn { + + private final Counter elementCounter; + + CountingFn(String namespace, String name) { + elementCounter = Metrics.counter(namespace, name); + } + + @ProcessElement + public void processElement(@Element T record, OutputReceiver c) { + elementCounter.inc(1L); + c.output(record); + } + } +} diff --git a/sdks/java/io/synthetic/src/main/java/org/apache/beam/sdk/io/synthetic/SyntheticStep.java b/sdks/java/io/synthetic/src/main/java/org/apache/beam/sdk/io/synthetic/SyntheticStep.java index d32640ffbf7d..98db23c95a38 100644 --- a/sdks/java/io/synthetic/src/main/java/org/apache/beam/sdk/io/synthetic/SyntheticStep.java +++ b/sdks/java/io/synthetic/src/main/java/org/apache/beam/sdk/io/synthetic/SyntheticStep.java @@ -58,7 +58,7 @@ public class SyntheticStep extends DoFn, KV> private final KV idAndThroughput; private final Counter throttlingCounter = - Metrics.counter("dataflow-throttling-metrics", "throttling-msecs"); + Metrics.counter("dataflow-throttling-metrics", Metrics.THROTTLE_TIME_COUNTER_NAME); /** * Static cache to store one worker level rate limiter for a step. 
Value in KV is the desired diff --git a/sdks/python/apache_beam/coders/observable_test.py b/sdks/python/apache_beam/coders/observable_test.py index 46f5186ba533..df4e7ef09408 100644 --- a/sdks/python/apache_beam/coders/observable_test.py +++ b/sdks/python/apache_beam/coders/observable_test.py @@ -29,7 +29,7 @@ class ObservableMixinTest(unittest.TestCase): observed_count = 0 observed_sum = 0 - observed_keys = [] # type: List[Optional[str]] + observed_keys: List[Optional[str]] = [] def observer(self, value, key=None): self.observed_count += 1 diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 7765ccebc26f..e93abbc887fb 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -117,8 +117,7 @@ def from_type_hint(cls, type_hint, registry): return cls(schema) @staticmethod - def from_payload(payload): - # type: (bytes) -> RowCoder + def from_payload(payload: bytes) -> 'RowCoder': return RowCoder(proto_utils.parse_Bytes(payload, schema_pb2.Schema)) def __reduce__(self): diff --git a/sdks/python/apache_beam/coders/slow_stream.py b/sdks/python/apache_beam/coders/slow_stream.py index 71a5b45d7691..b08ad8e9a37f 100644 --- a/sdks/python/apache_beam/coders/slow_stream.py +++ b/sdks/python/apache_beam/coders/slow_stream.py @@ -30,11 +30,10 @@ class OutputStream(object): A pure Python implementation of stream.OutputStream.""" def __init__(self): - self.data = [] # type: List[bytes] + self.data: List[bytes] = [] self.byte_count = 0 - def write(self, b, nested=False): - # type: (bytes, bool) -> None + def write(self, b: bytes, nested: bool = False) -> None: assert isinstance(b, bytes) if nested: self.write_var_int64(len(b)) @@ -45,8 +44,7 @@ def write_byte(self, val): self.data.append(chr(val).encode('latin-1')) self.byte_count += 1 - def write_var_int64(self, v): - # type: (int) -> None + def write_var_int64(self, v: int) -> None: if v < 0: v += 1 << 64 if v <= 0: @@ -78,16 +76,13 @@ def write_bigendian_double(self, v): def write_bigendian_float(self, v): self.write(struct.pack('>f', v)) - def get(self): - # type: () -> bytes + def get(self) -> bytes: return b''.join(self.data) - def size(self): - # type: () -> int + def size(self) -> int: return self.byte_count - def _clear(self): - # type: () -> None + def _clear(self) -> None: self.data = [] self.byte_count = 0 @@ -101,8 +96,7 @@ def __init__(self): super().__init__() self.count = 0 - def write(self, byte_array, nested=False): - # type: (bytes, bool) -> None + def write(self, byte_array: bytes, nested: bool = False) -> None: blen = len(byte_array) if nested: self.write_var_int64(blen) @@ -125,25 +119,21 @@ class InputStream(object): """For internal use only; no backwards-compatibility guarantees. 
A pure Python implementation of stream.InputStream.""" - def __init__(self, data): - # type: (bytes) -> None + def __init__(self, data: bytes) -> None: self.data = data self.pos = 0 def size(self): return len(self.data) - self.pos - def read(self, size): - # type: (int) -> bytes + def read(self, size: int) -> bytes: self.pos += size return self.data[self.pos - size:self.pos] - def read_all(self, nested): - # type: (bool) -> bytes + def read_all(self, nested: bool) -> bytes: return self.read(self.read_var_int64() if nested else self.size()) - def read_byte(self): - # type: () -> int + def read_byte(self) -> int: self.pos += 1 return self.data[self.pos - 1] diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py index b2cbe6e339f7..47df0116f2c6 100644 --- a/sdks/python/apache_beam/coders/standard_coders_test.py +++ b/sdks/python/apache_beam/coders/standard_coders_test.py @@ -300,7 +300,7 @@ def json_value_parser(self, coder_spec): # Used when --fix is passed. fix = False - to_fix = {} # type: Dict[Tuple[int, bytes], bytes] + to_fix: Dict[Tuple[int, bytes], bytes] = {} @classmethod def tearDownClass(cls): diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index e32e4823c48d..1667cb7a916a 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -80,8 +80,8 @@ def MakeXyzs(v): class CoderRegistry(object): """A coder registry for typehint/coder associations.""" def __init__(self, fallback_coder=None): - self._coders = {} # type: Dict[Any, Type[coders.Coder]] - self.custom_types = [] # type: List[Any] + self._coders: Dict[Any, Type[coders.Coder]] = {} + self.custom_types: List[Any] = [] self.register_standard_coders(fallback_coder) def register_standard_coders(self, fallback_coder): @@ -104,12 +104,14 @@ def register_standard_coders(self, fallback_coder): def register_fallback_coder(self, fallback_coder): self._fallback_coder = FirstOf([fallback_coder, self._fallback_coder]) - def _register_coder_internal(self, typehint_type, typehint_coder_class): - # type: (Any, Type[coders.Coder]) -> None + def _register_coder_internal( + self, typehint_type: Any, + typehint_coder_class: Type[coders.Coder]) -> None: self._coders[typehint_type] = typehint_coder_class - def register_coder(self, typehint_type, typehint_coder_class): - # type: (Any, Type[coders.Coder]) -> None + def register_coder( + self, typehint_type: Any, + typehint_coder_class: Type[coders.Coder]) -> None: if not isinstance(typehint_coder_class, type): raise TypeError( 'Coder registration requires a coder class object. ' @@ -122,8 +124,7 @@ def register_coder(self, typehint_type, typehint_coder_class): typehint_type = getattr(typehint_type, '__name__', str(typehint_type)) self._register_coder_internal(typehint_type, typehint_coder_class) - def get_coder(self, typehint): - # type: (Any) -> coders.Coder + def get_coder(self, typehint: Any) -> coders.Coder: if typehint and typehint.__module__ == '__main__': # See https://github.com/apache/beam/issues/21541 # TODO(robertwb): Remove once all runners are portable. @@ -187,8 +188,7 @@ class FirstOf(object): """For internal use only; no backwards-compatibility guarantees. 
A class used to get the first matching coder from a list of coders.""" - def __init__(self, coders): - # type: (Iterable[Type[coders.Coder]]) -> None + def __init__(self, coders: Iterable[Type[coders.Coder]]) -> None: self._coders = coders def from_type_hint(self, typehint, registry): diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py index 96d0c4f8b9f5..e44cc429eac1 100644 --- a/sdks/python/apache_beam/dataframe/convert.py +++ b/sdks/python/apache_beam/dataframe/convert.py @@ -17,10 +17,10 @@ import inspect import warnings import weakref -from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import Iterable +from typing import Optional from typing import Tuple from typing import Union @@ -35,19 +35,13 @@ from apache_beam.dataframe.schemas import generate_proxy from apache_beam.typehints.pandas_type_compatibility import dtype_to_fieldtype -if TYPE_CHECKING: - # pylint: disable=ungrouped-imports - from typing import Optional - # TODO: Or should this be called as_dataframe? def to_dataframe( - pcoll, # type: pvalue.PCollection - proxy=None, # type: Optional[pd.core.generic.NDFrame] - label=None, # type: Optional[str] -): - # type: (...) -> frame_base.DeferredFrame - + pcoll: pvalue.PCollection, + proxy: Optional[pd.core.generic.NDFrame] = None, + label: Optional[str] = None, +) -> frame_base.DeferredFrame: """Converts a PCollection to a deferred dataframe-like object, which can manipulated with pandas methods like `filter` and `groupby`. @@ -93,10 +87,10 @@ def to_dataframe( # Note that the pipeline (indirectly) holds references to the transforms which # keeps both the PCollections and expressions alive. This ensures the # expression's ids are never accidentally re-used. 
-TO_PCOLLECTION_CACHE = weakref.WeakValueDictionary( -) # type: weakref.WeakValueDictionary[str, pvalue.PCollection] -UNBATCHED_CACHE = weakref.WeakValueDictionary( -) # type: weakref.WeakValueDictionary[str, pvalue.PCollection] +TO_PCOLLECTION_CACHE: 'weakref.WeakValueDictionary[str, pvalue.PCollection]' = ( + weakref.WeakValueDictionary()) +UNBATCHED_CACHE: 'weakref.WeakValueDictionary[str, pvalue.PCollection]' = ( + weakref.WeakValueDictionary()) class RowsToDataFrameFn(beam.DoFn): @@ -173,7 +167,7 @@ def infer_output_type(self, input_element_type): def to_pcollection( - *dataframes, # type: Union[frame_base.DeferredFrame, pd.DataFrame, pd.Series] + *dataframes: Union[frame_base.DeferredFrame, pd.DataFrame, pd.Series], label=None, always_return_tuple=False, yield_elements='schemas', @@ -258,12 +252,12 @@ def extract_input(placeholder): df for df in dataframes if df._expr._id not in TO_PCOLLECTION_CACHE ] if len(new_dataframes): - new_results = {p: extract_input(p) - for p in placeholders - } | label >> transforms._DataframeExpressionsTransform({ - ix: df._expr - for (ix, df) in enumerate(new_dataframes) - }) # type: Dict[Any, pvalue.PCollection] + new_results: Dict[Any, pvalue.PCollection] = { + p: extract_input(p) + for p in placeholders + } | label >> transforms._DataframeExpressionsTransform( + {ix: df._expr + for (ix, df) in enumerate(new_dataframes)}) TO_PCOLLECTION_CACHE.update( {new_dataframes[ix]._expr._id: pc diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py index 4e89e473b730..90f34d45dd98 100644 --- a/sdks/python/apache_beam/dataframe/frame_base.py +++ b/sdks/python/apache_beam/dataframe/frame_base.py @@ -38,7 +38,7 @@ class DeferredBase(object): - _pandas_type_map = {} # type: Dict[Union[type, None], type] + _pandas_type_map: Dict[Union[type, None], type] = {} def __init__(self, expr): self._expr = expr @@ -197,8 +197,8 @@ def _proxy_method( inplace=False, base=None, *, - requires_partition_by, # type: partitionings.Partitioning - preserves_partition_by, # type: partitionings.Partitioning + requires_partition_by: partitionings.Partitioning, + preserves_partition_by: partitionings.Partitioning, ): if name is None: name, func = name_and_func(func) @@ -227,14 +227,14 @@ def _elementwise_function( def _proxy_function( - func, # type: Union[Callable, str] - name=None, # type: Optional[str] - restrictions=None, # type: Optional[Dict[str, Union[Any, List[Any]]]] - inplace=False, # type: bool - base=None, # type: Optional[type] + func: Union[Callable, str], + name: Optional[str] = None, + restrictions: Optional[Dict[str, Union[Any, List[Any]]]] = None, + inplace: bool = False, + base: Optional[type] = None, *, - requires_partition_by, # type: partitionings.Partitioning - preserves_partition_by, # type: partitionings.Partitioning + requires_partition_by: partitionings.Partitioning, + preserves_partition_by: partitionings.Partitioning, ): if name is None: diff --git a/sdks/python/apache_beam/dataframe/partitionings.py b/sdks/python/apache_beam/dataframe/partitionings.py index 5513f4bb496e..0ff09e111480 100644 --- a/sdks/python/apache_beam/dataframe/partitionings.py +++ b/sdks/python/apache_beam/dataframe/partitionings.py @@ -32,9 +32,7 @@ class Partitioning(object): def __repr__(self): return self.__class__.__name__ - def is_subpartitioning_of(self, other): - # type: (Partitioning) -> bool - + def is_subpartitioning_of(self, other: 'Partitioning') -> bool: """Returns whether self is a sub-partition of other. 
Specifically, returns whether something partitioned by self is necissarily @@ -48,9 +46,8 @@ def __lt__(self, other): def __le__(self, other): return not self.is_subpartitioning_of(other) - def partition_fn(self, df, num_partitions): - # type: (Frame, int) -> Iterable[Tuple[Any, Frame]] - + def partition_fn(self, df: Frame, + num_partitions: int) -> Iterable[Tuple[Any, Frame]]: """A callable that actually performs the partitioning of a Frame df. This will be invoked via a FlatMap in conjunction with a GroupKey to diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index 6356945e05f9..e70229f21f77 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -85,9 +85,7 @@ def expand(self, pcoll): | beam.Map(converter.produce_batch)) -def generate_proxy(element_type): - # type: (type) -> pd.DataFrame - +def generate_proxy(element_type: type) -> pd.DataFrame: """Generate a proxy pandas object for the given PCollection element_type. Currently only supports generating a DataFrame proxy from a schema-aware @@ -106,9 +104,8 @@ def generate_proxy(element_type): return proxy -def element_type_from_dataframe(proxy, include_indexes=False): - # type: (pd.DataFrame, bool) -> type - +def element_type_from_dataframe( + proxy: pd.DataFrame, include_indexes: bool = False) -> type: """Generate an element_type for an element-wise PCollection from a proxy pandas object. Currently only supports converting the element_type for a schema-aware PCollection to a proxy DataFrame. diff --git a/sdks/python/apache_beam/dataframe/schemas_test.py b/sdks/python/apache_beam/dataframe/schemas_test.py index ed0ba6b342af..4c196e29e712 100644 --- a/sdks/python/apache_beam/dataframe/schemas_test.py +++ b/sdks/python/apache_beam/dataframe/schemas_test.py @@ -64,36 +64,57 @@ def check_df_pcoll_equal(actual): # pd.Series([b'abc'], dtype=bytes).dtype != 'S' # pd.Series([b'abc'], dtype=bytes).astype(bytes).dtype == 'S' # (test data, pandas_type, column_name, beam_type) -COLUMNS = [ - ([375, 24, 0, 10, 16], np.int32, 'i32', np.int32), - ([375, 24, 0, 10, 16], np.int64, 'i64', np.int64), - ([375, 24, None, 10, 16], - pd.Int32Dtype(), - 'i32_nullable', - typing.Optional[np.int32]), - ([375, 24, None, 10, 16], - pd.Int64Dtype(), - 'i64_nullable', - typing.Optional[np.int64]), - ([375., 24., None, 10., 16.], - np.float64, - 'f64', - typing.Optional[np.float64]), - ([375., 24., None, 10., 16.], - np.float32, - 'f32', - typing.Optional[np.float32]), - ([True, False, True, True, False], bool, 'bool', bool), - (['Falcon', 'Ostrich', None, 3.14, 0], object, 'any', typing.Any), - ([True, False, True, None, False], - pd.BooleanDtype(), - 'bool_nullable', - typing.Optional[bool]), - (['Falcon', 'Ostrich', None, 'Aardvark', 'Elephant'], - pd.StringDtype(), - 'strdtype', - typing.Optional[str]), -] # type: typing.List[typing.Tuple[typing.List[typing.Any], typing.Any, str, typing.Any]] +COLUMNS: typing.List[typing.Tuple[typing.List[typing.Any], + typing.Any, + str, + typing.Any]] = [ + ([375, 24, 0, 10, 16], + np.int32, + 'i32', + np.int32), + ([375, 24, 0, 10, 16], + np.int64, + 'i64', + np.int64), + ([375, 24, None, 10, 16], + pd.Int32Dtype(), + 'i32_nullable', + typing.Optional[np.int32]), + ([375, 24, None, 10, 16], + pd.Int64Dtype(), + 'i64_nullable', + typing.Optional[np.int64]), + ([375., 24., None, 10., 16.], + np.float64, + 'f64', + typing.Optional[np.float64]), + ([375., 24., None, 10., 16.], + np.float32, + 'f32', + 
typing.Optional[np.float32]), + ([True, False, True, True, False], + bool, + 'bool', + bool), + (['Falcon', 'Ostrich', None, 3.14, 0], + object, + 'any', + typing.Any), + ([True, False, True, None, False], + pd.BooleanDtype(), + 'bool_nullable', + typing.Optional[bool]), + ([ + 'Falcon', + 'Ostrich', + None, + 'Aardvark', + 'Elephant' + ], + pd.StringDtype(), + 'strdtype', + typing.Optional[str]), + ] NICE_TYPES_DF = pd.DataFrame(columns=[name for _, _, name, _ in COLUMNS]) for arr, dtype, name, _ in COLUMNS: @@ -104,9 +125,9 @@ def check_df_pcoll_equal(actual): SERIES_TESTS = [(pd.Series(arr, dtype=dtype, name=name), arr, beam_type) for (arr, dtype, name, beam_type) in COLUMNS] -_TEST_ARRAYS = [ +_TEST_ARRAYS: typing.List[typing.List[typing.Any]] = [ arr for (arr, _, _, _) in COLUMNS -] # type: typing.List[typing.List[typing.Any]] +] DF_RESULT = list(zip(*_TEST_ARRAYS)) BEAM_SCHEMA = typing.NamedTuple( # type: ignore 'BEAM_SCHEMA', [(name, beam_type) for _, _, name, beam_type in COLUMNS]) diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index 98023fbc624c..6b5573aa4569 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -53,7 +53,7 @@ if TYPE_CHECKING: import google.cloud.bigtable.instance -EXISTING_INSTANCES = [] # type: List[google.cloud.bigtable.instance.Instance] +EXISTING_INSTANCES: List['google.cloud.bigtable.instance.Instance'] = [] LABEL_KEY = 'python-bigtable-beam' label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC) label_stamp_micros = _microseconds_from_datetime(label_stamp) diff --git a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py index 6a4b9e234297..65ea7990a2d8 100644 --- a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py +++ b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py @@ -87,9 +87,7 @@ def __init__(self): self.word_counter = Metrics.counter('main', 'total_words') self.word_lengths_dist = Metrics.distribution('main', 'word_len_dist') - def process(self, element): - # type: (Entity) -> Optional[Iterable[Text]] - + def process(self, element: Entity) -> Optional[Iterable[Text]]: """Extract words from the 'content' property of Cloud Datastore entities. The element is a line of text. If the line is blank, note that, too. diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py index 6063faa0b14c..83cdac4b5f33 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py @@ -46,9 +46,7 @@ RLOCK_TYPE = type(_pickle_lock) -def dumps(o, enable_trace=True, use_zlib=False): - # type: (...) -> bytes - +def dumps(o, enable_trace=True, use_zlib=False) -> bytes: """For internal use only; no backwards-compatibility guarantees.""" with _pickle_lock: with io.BytesIO() as file: diff --git a/sdks/python/apache_beam/internal/dill_pickler.py b/sdks/python/apache_beam/internal/dill_pickler.py index 8a0742642dfb..7f7ac5b214fa 100644 --- a/sdks/python/apache_beam/internal/dill_pickler.py +++ b/sdks/python/apache_beam/internal/dill_pickler.py @@ -309,8 +309,7 @@ def save_module(pickler, obj): # Pickle module dictionaries (commonly found in lambda's globals) # by referencing their module. 
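One detail worth calling out from the `bigtableio_it_test.py` hunk above: when the annotated type is only imported under `TYPE_CHECKING`, the inline annotation has to stay a quoted string so the name is never evaluated at runtime. A small sketch of that pattern, assuming only the standard library `typing` module is available (the Bigtable client is deliberately not imported at runtime):

```python
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Imported for the type checker only; never executed at runtime, so the
    # heavy client library does not need to be installed to run this module.
    import google.cloud.bigtable.instance

# The quoted element type keeps runtime evaluation from touching the import
# above while still giving mypy the precise type of the list contents.
EXISTING_INSTANCES: List['google.cloud.bigtable.instance.Instance'] = []
```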
old_save_module_dict = dill.dill.save_module_dict - known_module_dicts = { - } # type: Dict[int, Tuple[types.ModuleType, Dict[str, Any]]] + known_module_dicts: Dict[int, Tuple[types.ModuleType, Dict[str, Any]]] = {} @dill.dill.register(dict) def new_save_module_dict(pickler, obj): @@ -370,9 +369,7 @@ def new_log_info(msg, *args, **kwargs): logging.getLogger('dill').setLevel(logging.WARN) -def dumps(o, enable_trace=True, use_zlib=False): - # type: (...) -> bytes - +def dumps(o, enable_trace=True, use_zlib=False) -> bytes: """For internal use only; no backwards-compatibility guarantees.""" with _pickle_lock: try: diff --git a/sdks/python/apache_beam/internal/metrics/cells.py b/sdks/python/apache_beam/internal/metrics/cells.py index 3fcaecf8c677..c7b546258a70 100644 --- a/sdks/python/apache_beam/internal/metrics/cells.py +++ b/sdks/python/apache_beam/internal/metrics/cells.py @@ -55,8 +55,7 @@ def __init__(self, bucket_type): def reset(self): self.data = HistogramAggregator(self._bucket_type).identity_element() - def combine(self, other): - # type: (HistogramCell) -> HistogramCell + def combine(self, other: 'HistogramCell') -> 'HistogramCell': result = HistogramCell(self._bucket_type) result.data = self.data.combine(other.data) return result @@ -64,8 +63,7 @@ def combine(self, other): def update(self, value): self.data.histogram.record(value) - def get_cumulative(self): - # type: () -> HistogramData + def get_cumulative(self) -> 'HistogramData': return self.data.get_cumulative() def to_runner_api_monitoring_info(self, name, transform_id): @@ -92,8 +90,7 @@ def __hash__(self): class HistogramResult(object): - def __init__(self, data): - # type: (HistogramData) -> None + def __init__(self, data: 'HistogramData') -> None: self.data = data def __eq__(self, other): @@ -142,12 +139,10 @@ def __hash__(self): def __repr__(self): return 'HistogramData({})'.format(self.histogram.get_percentile_info()) - def get_cumulative(self): - # type: () -> HistogramData + def get_cumulative(self) -> 'HistogramData': return HistogramData(self.histogram) - def combine(self, other): - # type: (Optional[HistogramData]) -> HistogramData + def combine(self, other: Optional['HistogramData']) -> 'HistogramData': if other is None: return self @@ -161,18 +156,14 @@ class HistogramAggregator(MetricAggregator): Values aggregated should be ``HistogramData`` objects. 
""" - def __init__(self, bucket_type): - # type: (BucketType) -> None + def __init__(self, bucket_type: 'BucketType') -> None: self._bucket_type = bucket_type - def identity_element(self): - # type: () -> HistogramData + def identity_element(self) -> HistogramData: return HistogramData(Histogram(self._bucket_type)) - def combine(self, x, y): - # type: (HistogramData, HistogramData) -> HistogramData + def combine(self, x: HistogramData, y: HistogramData) -> HistogramData: return x.combine(y) - def result(self, x): - # type: (HistogramData) -> HistogramResult + def result(self, x: HistogramData) -> HistogramResult: return HistogramResult(x.get_cumulative()) diff --git a/sdks/python/apache_beam/internal/metrics/metric.py b/sdks/python/apache_beam/internal/metrics/metric.py index f892dd2024a1..8acf800ff8c6 100644 --- a/sdks/python/apache_beam/internal/metrics/metric.py +++ b/sdks/python/apache_beam/internal/metrics/metric.py @@ -61,9 +61,10 @@ class Metrics(object): @staticmethod - def counter(urn, labels=None, process_wide=False): - # type: (str, Optional[Dict[str, str]], bool) -> UserMetrics.DelegatingCounter - + def counter( + urn: str, + labels: Optional[Dict[str, str]] = None, + process_wide: bool = False) -> UserMetrics.DelegatingCounter: """Obtains or creates a Counter metric. Args: @@ -82,9 +83,11 @@ def counter(urn, labels=None, process_wide=False): process_wide=process_wide) @staticmethod - def histogram(namespace, name, bucket_type, logger=None): - # type: (Union[Type, str], str, BucketType, Optional[MetricLogger]) -> Metrics.DelegatingHistogram - + def histogram( + namespace: Union[Type, str], + name: str, + bucket_type: 'BucketType', + logger: Optional['MetricLogger'] = None) -> 'Metrics.DelegatingHistogram': """Obtains or creates a Histogram metric. 
Args: @@ -103,16 +106,18 @@ def histogram(namespace, name, bucket_type, logger=None): class DelegatingHistogram(Histogram): """Metrics Histogram that Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name, bucket_type, logger): - # type: (MetricName, BucketType, Optional[MetricLogger]) -> None + def __init__( + self, + metric_name: MetricName, + bucket_type: 'BucketType', + logger: Optional['MetricLogger']) -> None: super().__init__(metric_name) self.metric_name = metric_name self.cell_type = HistogramCellFactory(bucket_type) self.logger = logger self.updater = MetricUpdater(self.cell_type, self.metric_name) - def update(self, value): - # type: (object) -> None + def update(self, value: object) -> None: self.updater(value) if self.logger: self.logger.update(self.cell_type, self.metric_name, value) @@ -120,27 +125,30 @@ def update(self, value): class MetricLogger(object): """Simple object to locally aggregate and log metrics.""" - def __init__(self): - # type: () -> None - self._metric = {} # type: Dict[MetricName, MetricCell] + def __init__(self) -> None: + self._metric: Dict[MetricName, 'MetricCell'] = {} self._lock = threading.Lock() self._last_logging_millis = int(time.time() * 1000) self.minimum_logging_frequency_msec = 180000 - def update(self, cell_type, metric_name, value): - # type: (Union[Type[MetricCell], MetricCellFactory], MetricName, object) -> None + def update( + self, + cell_type: Union[Type['MetricCell'], 'MetricCellFactory'], + metric_name: MetricName, + value: object) -> None: cell = self._get_metric_cell(cell_type, metric_name) cell.update(value) - def _get_metric_cell(self, cell_type, metric_name): - # type: (Union[Type[MetricCell], MetricCellFactory], MetricName) -> MetricCell + def _get_metric_cell( + self, + cell_type: Union[Type['MetricCell'], 'MetricCellFactory'], + metric_name: MetricName) -> 'MetricCell': with self._lock: if metric_name not in self._metric: self._metric[metric_name] = cell_type() return self._metric[metric_name] - def log_metrics(self, reset_after_logging=False): - # type: (bool) -> None + def log_metrics(self, reset_after_logging: bool = False) -> None: if self._lock.acquire(False): try: current_millis = int(time.time() * 1000) @@ -172,14 +180,14 @@ class ServiceCallMetric(object): TODO(ajamato): Add Request latency metric. 
""" - def __init__(self, request_count_urn, base_labels=None): - # type: (str, Optional[Dict[str, str]]) -> None + def __init__( + self, + request_count_urn: str, + base_labels: Optional[Dict[str, str]] = None) -> None: self.base_labels = base_labels if base_labels else {} self.request_count_urn = request_count_urn - def call(self, status): - # type: (Union[int, str, HttpError]) -> None - + def call(self, status: Union[int, str, 'HttpError']) -> None: """Record the status of the call into appropriate metrics.""" canonical_status = self.convert_to_canonical_status_string(status) additional_labels = {monitoring_infos.STATUS_LABEL: canonical_status} @@ -191,9 +199,8 @@ def call(self, status): urn=self.request_count_urn, labels=labels, process_wide=True) request_counter.inc() - def convert_to_canonical_status_string(self, status): - # type: (Union[int, str, HttpError]) -> str - + def convert_to_canonical_status_string( + self, status: Union[int, str, 'HttpError']) -> str: """Converts a status to a canonical GCP status cdoe string.""" http_status_code = None if isinstance(status, int): @@ -222,9 +229,8 @@ def convert_to_canonical_status_string(self, status): return str(http_status_code) @staticmethod - def bigtable_error_code_to_grpc_status_string(grpc_status_code): - # type: (Optional[int]) -> str - + def bigtable_error_code_to_grpc_status_string( + grpc_status_code: Optional[int]) -> str: """ Converts the bigtable error code to a canonical GCP status code string. diff --git a/sdks/python/apache_beam/internal/module_test.py b/sdks/python/apache_beam/internal/module_test.py index eaa1629be8e5..ff0ad0c564e6 100644 --- a/sdks/python/apache_beam/internal/module_test.py +++ b/sdks/python/apache_beam/internal/module_test.py @@ -21,7 +21,7 @@ import re import sys -from typing import Type +from typing import Any class TopClass(object): @@ -64,7 +64,7 @@ def get(self): class RecursiveClass(object): """A class that contains a reference to itself.""" - SELF_TYPE = None # type: Type[RecursiveClass] + SELF_TYPE: Any = None def __init__(self, datum): self.datum = 'RecursiveClass:%s' % datum diff --git a/sdks/python/apache_beam/internal/pickler.py b/sdks/python/apache_beam/internal/pickler.py index 1685ae928167..79ebd16314bf 100644 --- a/sdks/python/apache_beam/internal/pickler.py +++ b/sdks/python/apache_beam/internal/pickler.py @@ -38,8 +38,7 @@ desired_pickle_lib = dill_pickler -def dumps(o, enable_trace=True, use_zlib=False): - # type: (...) -> bytes +def dumps(o, enable_trace=True, use_zlib=False) -> bytes: return desired_pickle_lib.dumps( o, enable_trace=enable_trace, use_zlib=use_zlib) diff --git a/sdks/python/apache_beam/internal/util.py b/sdks/python/apache_beam/internal/util.py index f0a3ad8288b5..85a6e4c43b83 100644 --- a/sdks/python/apache_beam/internal/util.py +++ b/sdks/python/apache_beam/internal/util.py @@ -66,12 +66,11 @@ def __hash__(self): return hash(type(self)) -def remove_objects_from_args(args, # type: Iterable[Any] - kwargs, # type: Dict[str, Any] - pvalue_class # type: Union[Type[T], Tuple[Type[T], ...]] - ): - # type: (...) -> Tuple[List[Any], Dict[str, Any], List[T]] - +def remove_objects_from_args( + args: Iterable[Any], + kwargs: Dict[str, Any], + pvalue_class: Union[Type[T], Tuple[Type[T], ...]] +) -> Tuple[List[Any], Dict[str, Any], List[T]]: """For internal use only; no backwards-compatibility guarantees. Replaces all objects of a given type in args/kwargs with a placeholder. 
diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py index 8bc3dd68281f..c446c17247d7 100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py @@ -151,8 +151,6 @@ def create( path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO # noqa: F821 - """Returns a write channel for the given file path. Args: @@ -169,8 +167,6 @@ def open( path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO # noqa: F821 - """Returns a read channel for the given file path. Args: diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index 240fc65c52b3..91763ced6e69 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -135,8 +135,7 @@ def display_data(self): } @check_accessible(['_pattern']) - def _get_concat_source(self): - # type: () -> concat_source.ConcatSource + def _get_concat_source(self) -> concat_source.ConcatSource: if self._concat_source is None: pattern = self._pattern.get() @@ -369,9 +368,8 @@ def process(self, element: Union[str, FileMetadata], *args, class _ReadRange(DoFn): def __init__( self, - source_from_file, # type: Union[str, iobase.BoundedSource] - with_filename=False # type: bool - ) -> None: + source_from_file: Union[str, iobase.BoundedSource], + with_filename: bool = False) -> None: self._source_from_file = source_from_file self._with_filename = with_filename @@ -402,14 +400,14 @@ class ReadAllFiles(PTransform): PTransform authors who wishes to implement file-based Read transforms that read a PCollection of files. """ - def __init__(self, - splittable, # type: bool - compression_type, - desired_bundle_size, # type: int - min_bundle_size, # type: int - source_from_file, # type: Callable[[str], iobase.BoundedSource] - with_filename=False # type: bool - ): + def __init__( + self, + splittable: bool, + compression_type, + desired_bundle_size: int, + min_bundle_size: int, + source_from_file: Callable[[str], iobase.BoundedSource], + with_filename: bool = False): """ Args: splittable: If False, files won't be split into sub-ranges. If True, diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index a90ba3a50e7e..378ecf71920d 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -90,7 +90,6 @@ import collections import logging -import os import random import uuid from collections import namedtuple @@ -702,9 +701,10 @@ def process(self, element, w=beam.DoFn.WindowParam): move_from = [f.file_name for f in temp_file_results] move_to = [f.file_name for f in final_file_results] + _LOGGER.info( - 'Moving temporary files %s to dir: %s as %s', - map(os.path.basename, move_from), + 'Moving %d temporary files to dir: %s as %s', + len(move_from), self.path.get(), move_to) @@ -745,13 +745,13 @@ def _check_orphaned_files(self, writer_key): orphaned_files = [m.path for m in match_result[0].metadata_list] if len(orphaned_files) > 0: - _LOGGER.info( + _LOGGER.warning( 'Some files may be left orphaned in the temporary folder: %s. 
' - 'This may be a result of insufficient permissions to delete' - 'these temp files.', + 'This may be a result of retried work items or insufficient' + 'permissions to delete these temp files.', orphaned_files) except BeamIOError as e: - _LOGGER.info('Exceptions when checking orphaned files: %s', e) + _LOGGER.warning('Exceptions when checking orphaned files: %s', e) class _WriteShardedRecordsFn(beam.DoFn): diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index 142e04bc295e..550079a482c4 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -145,7 +145,7 @@ class CompressedFile(object): def __init__( self, - fileobj, # type: BinaryIO + fileobj: BinaryIO, compression_type=CompressionTypes.GZIP, read_size=DEFAULT_READ_BUFFER_SIZE): if not fileobj: @@ -167,7 +167,7 @@ def __init__( raise ValueError( 'File object must be at position 0 but was %d' % self._file.tell()) self._uncompressed_position = 0 - self._uncompressed_size = None # type: Optional[int] + self._uncompressed_size: Optional[int] = None if self.readable(): self._read_size = read_size @@ -217,19 +217,15 @@ def _initialize_compressor(self): self._compressor = zlib.compressobj( zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, self._gzip_mask) - def readable(self): - # type: () -> bool + def readable(self) -> bool: mode = self._file.mode return 'r' in mode or 'a' in mode - def writeable(self): - # type: () -> bool + def writeable(self) -> bool: mode = self._file.mode return 'w' in mode or 'a' in mode - def write(self, data): - # type: (bytes) -> None - + def write(self, data: bytes) -> None: """Write data to file.""" if not self._compressor: raise ValueError('compressor not initialized') @@ -303,9 +299,7 @@ def read(self, num_bytes: Optional[int] = None) -> bytes: return self._read_from_internal_buffer( lambda: self._read_buffer.read(num_bytes)) - def readline(self): - # type: () -> bytes - + def readline(self) -> bytes: """Equivalent to standard file.readline(). Same return conventions apply.""" if not self._decompressor: raise ValueError('decompressor not initialized') @@ -345,31 +339,24 @@ def flush(self) -> None: self._file.flush() @property - def seekable(self): - # type: () -> bool + def seekable(self) -> bool: return 'r' in self._file.mode - def _clear_read_buffer(self): - # type: () -> None - + def _clear_read_buffer(self) -> None: """Clears the read buffer by removing all the contents and resetting _read_position to 0""" self._read_position = 0 self._read_buffer.seek(0) self._read_buffer.truncate(0) - def _rewind_file(self): - # type: () -> None - + def _rewind_file(self) -> None: """Seeks to the beginning of the input file. Input file's EOF marker is cleared and _uncompressed_position is reset to zero""" self._file.seek(0, os.SEEK_SET) self._read_eof = False self._uncompressed_position = 0 - def _rewind(self): - # type: () -> None - + def _rewind(self) -> None: """Seeks to the beginning of the input file and resets the internal read buffer. The decompressor object is re-initialized to ensure that no data left in it's buffer.""" @@ -379,9 +366,7 @@ def _rewind(self): # Re-initialize decompressor to clear any data buffered prior to rewind self._initialize_decompressor() - def seek(self, offset, whence=os.SEEK_SET): - # type: (int, int) -> None - + def seek(self, offset: int, whence: int = os.SEEK_SET) -> None: """Set the file's current offset. 
Seeking behavior: @@ -445,9 +430,7 @@ def seek(self, offset, whence=os.SEEK_SET): break bytes_to_skip -= len(data) - def tell(self): - # type: () -> int - + def tell(self) -> int: """Returns current position in uncompressed file.""" return self._uncompressed_position @@ -503,8 +486,7 @@ class MatchResult(object): """Result from the ``FileSystem`` match operation which contains the list of matched ``FileMetadata``. """ - def __init__(self, pattern, metadata_list): - # type: (str, List[FileMetadata]) -> None + def __init__(self, pattern: str, metadata_list: List[FileMetadata]) -> None: self.metadata_list = metadata_list self.pattern = pattern @@ -559,9 +541,7 @@ def scheme(cls): raise NotImplementedError @abc.abstractmethod - def join(self, basepath, *paths): - # type: (str, *str) -> str - + def join(self, basepath: str, *paths: str) -> str: """Join two or more pathname components for the filesystem Args: @@ -573,9 +553,7 @@ def join(self, basepath, *paths): raise NotImplementedError @abc.abstractmethod - def split(self, path): - # type: (str) -> Tuple[str, str] - + def split(self, path: str) -> Tuple[str, str]: """Splits the given path into two parts. Splits the path into a pair (head, tail) such that tail contains the last @@ -648,9 +626,8 @@ def _url_dirname(self, url_or_path): scheme, path = self._split_scheme(url_or_path) return self._combine_scheme(scheme, posixpath.dirname(path)) - def match_files(self, file_metas, pattern): - # type: (List[FileMetadata], str) -> Iterator[FileMetadata] - + def match_files(self, file_metas: List[FileMetadata], + pattern: str) -> Iterator[FileMetadata]: """Filter :class:`FileMetadata` objects by *pattern* Args: @@ -671,9 +648,7 @@ def match_files(self, file_metas, pattern): yield file_metadata @staticmethod - def translate_pattern(pattern): - # type: (str) -> str - + def translate_pattern(pattern: str) -> str: """ Translate a *pattern* to a regular expression. There is no way to quote meta-characters. @@ -809,9 +784,7 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a write channel for the given file path. Args: @@ -828,9 +801,7 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a read channel for the given file path. Args: @@ -870,9 +841,7 @@ def rename(self, source_file_names, destination_file_names): raise NotImplementedError @abc.abstractmethod - def exists(self, path): - # type: (str) -> bool - + def exists(self, path: str) -> bool: """Check if the provided path exists on the FileSystem. Args: @@ -883,9 +852,7 @@ def exists(self, path): raise NotImplementedError @abc.abstractmethod - def size(self, path): - # type: (str) -> int - + def size(self, path: str) -> int: """Get size in bytes of a file on the FileSystem. 
Args: diff --git a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py index d50672ed6be2..91c76b5d54bf 100644 --- a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py +++ b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py @@ -35,7 +35,7 @@ class FlinkStreamingImpulseSource(PTransform): URN = "flink:transform:streaming_impulse:v1" - config = {} # type: Dict[str, Any] + config: Dict[str, Any] = {} def expand(self, pbegin): assert isinstance(pbegin, pvalue.PBegin), ( diff --git a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py index d10a4d8fc2a3..b6c177fc7418 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_avro_tools.py @@ -23,6 +23,9 @@ NOTHING IN THIS FILE HAS BACKWARDS COMPATIBILITY GUARANTEES. """ +from typing import Any +from typing import Dict + # BigQuery types as listed in # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types # with aliases (RECORD, BOOLEAN, FLOAT, INTEGER) as defined in @@ -63,18 +66,20 @@ def get_record_schema_from_dict_table_schema( - schema_name, table_schema, namespace="apache_beam.io.gcp.bigquery"): - # type: (Text, Dict[Text, Any], Text) -> Dict[Text, Any] # noqa: F821 + schema_name: str, + table_schema: Dict[str, Any], + namespace: str = "apache_beam.io.gcp.bigquery") -> Dict[str, Any]: + # noqa: F821 """Convert a table schema into an Avro schema. Args: - schema_name (Text): The name of the record. - table_schema (Dict[Text, Any]): A BigQuery table schema in dict form. - namespace (Text): The namespace of the Avro schema. + schema_name (str): The name of the record. + table_schema (Dict[str, Any]): A BigQuery table schema in dict form. + namespace (str): The namespace of the Avro schema. Returns: - Dict[Text, Any]: The schema as an Avro RecordSchema. + Dict[str, Any]: The schema as an Avro RecordSchema. """ avro_fields = [ table_field_to_avro_field(field, ".".join((namespace, schema_name))) @@ -90,16 +95,17 @@ def get_record_schema_from_dict_table_schema( } -def table_field_to_avro_field(table_field, namespace): - # type: (Dict[Text, Any], str) -> Dict[Text, Any] # noqa: F821 +def table_field_to_avro_field(table_field: Dict[str, Any], + namespace: str) -> Dict[str, Any]: + # noqa: F821 """Convert a BigQuery field to an avro field. Args: - table_field (Dict[Text, Any]): A BigQuery field in dict form. + table_field (Dict[str, Any]): A BigQuery field in dict form. Returns: - Dict[Text, Any]: An equivalent Avro field in dict form. + Dict[str, Any]: An equivalent Avro field in dict form. 
""" assert "type" in table_field, \ "Unable to get type for table field {}".format(table_field) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py b/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py index ce49cd0161df..f3881ed261ae 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py @@ -24,7 +24,7 @@ import decimal import json import logging -import random +import secrets import time import uuid from typing import TYPE_CHECKING @@ -212,7 +212,7 @@ def __init__( self._source_uuid = unique_id self.kms_key = kms_key self.project = project - self.temp_dataset = temp_dataset or 'bq_read_all_%s' % uuid.uuid4().hex + self.temp_dataset = temp_dataset self.query_priority = query_priority self.bq_io_metadata = None @@ -226,22 +226,27 @@ def display_data(self): 'temp_dataset': str(self.temp_dataset) } - def _get_temp_dataset(self): - if isinstance(self.temp_dataset, str): - return DatasetReference( - datasetId=self.temp_dataset, projectId=self._get_project()) - else: + def _get_temp_dataset_id(self): + if self.temp_dataset is None: + return None + elif isinstance(self.temp_dataset, DatasetReference): + return self.temp_dataset.datasetId + elif isinstance(self.temp_dataset, str): return self.temp_dataset + else: + raise ValueError("temp_dataset has to be either str or DatasetReference") - def process(self, - element: 'ReadFromBigQueryRequest') -> Iterable[BoundedSource]: - bq = bigquery_tools.BigQueryWrapper( - temp_dataset_id=self._get_temp_dataset().datasetId, + def start_bundle(self): + self.bq = bigquery_tools.BigQueryWrapper( + temp_dataset_id=self._get_temp_dataset_id(), client=bigquery_tools.BigQueryWrapper._bigquery_client(self.options)) + def process(self, + element: 'ReadFromBigQueryRequest') -> Iterable[BoundedSource]: if element.query is not None: - self._setup_temporary_dataset(bq, element) - table_reference = self._execute_query(bq, element) + if not self.bq.created_temp_dataset: + self._setup_temporary_dataset(self.bq, element) + table_reference = self._execute_query(self.bq, element) else: assert element.table table_reference = bigquery_tools.parse_table_reference( @@ -250,19 +255,21 @@ def process(self, if not table_reference.projectId: table_reference.projectId = self._get_project() - schema, metadata_list = self._export_files(bq, element, table_reference) + schema, metadata_list = self._export_files( + self.bq, element, table_reference) for metadata in metadata_list: yield self._create_source(metadata.path, schema) if element.query is not None: - bq._delete_table( + self.bq._delete_table( table_reference.projectId, table_reference.datasetId, table_reference.tableId) - if bq.created_temp_dataset: - self._clean_temporary_dataset(bq, element) + def finish_bundle(self): + if self.bq.created_temp_dataset: + self.bq.clean_up_temporary_dataset(self._get_project()) def _get_bq_metadata(self): if not self.bq_io_metadata: @@ -288,12 +295,6 @@ def _setup_temporary_dataset( self._get_project(), element.query, not element.use_standard_sql) bq.create_temporary_dataset(self._get_project(), location) - def _clean_temporary_dataset( - self, - bq: bigquery_tools.BigQueryWrapper, - element: 'ReadFromBigQueryRequest'): - bq.clean_up_temporary_dataset(self._get_project()) - def _execute_query( self, bq: bigquery_tools.BigQueryWrapper, @@ -302,7 +303,7 @@ def _execute_query( self._job_name, self._source_uuid, bigquery_tools.BigQueryJobTypes.QUERY, - '%s_%s' % (int(time.time()), random.randint(0, 
1000))) + '%s_%s' % (int(time.time()), secrets.token_hex(3))) job = bq._start_query_job( self._get_project(), element.query, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py index d56a4c764715..913d6e078d89 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_it_test.py @@ -109,11 +109,11 @@ def tearDownClass(cls): request = bigquery.BigqueryDatasetsDeleteRequest( projectId=cls.project, datasetId=cls.dataset_id, deleteContents=True) try: - _LOGGER.info( + _LOGGER.debug( "Deleting dataset %s in project %s", cls.dataset_id, cls.project) cls.bigquery_client.client.datasets.Delete(request) except HttpError: - _LOGGER.debug( + _LOGGER.warning( 'Failed to clean up dataset %s in project %s', cls.dataset_id, cls.project) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py index 7b8a58e96978..beb373a7dea3 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_schema_tools.py @@ -53,9 +53,8 @@ } -def generate_user_type_from_bq_schema(the_table_schema, selected_fields=None): - #type: (bigquery.TableSchema) -> type - +def generate_user_type_from_bq_schema( + the_table_schema, selected_fields: 'bigquery.TableSchema' = None) -> type: """Convert a schema of type TableSchema into a pcollection element. Args: the_table_schema: A BQ schema of type TableSchema diff --git a/sdks/python/apache_beam/io/gcp/bigtableio.py b/sdks/python/apache_beam/io/gcp/bigtableio.py index 0f3944a791bd..3f54e09ee3dd 100644 --- a/sdks/python/apache_beam/io/gcp/bigtableio.py +++ b/sdks/python/apache_beam/io/gcp/bigtableio.py @@ -141,7 +141,10 @@ def start_bundle(self): self.beam_options['instance_id'], self.beam_options['table_id']) self.batcher = MutationsBatcher( - self.table, batch_completed_callback=self.write_mutate_metrics) + self.table, + batch_completed_callback=self.write_mutate_metrics, + flush_count=FLUSH_COUNT, + max_row_bytes=MAX_ROW_BYTES) def process(self, row): self.written.inc() diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py index a6f8ef594695..417a04c3d2b4 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py @@ -70,9 +70,9 @@ def retry_on_rpc_error(exception): def create_entities(count, id_or_name=False): """Creates a list of entities with random keys.""" if id_or_name: - ids_or_names = [ + ids_or_names: List[Union[str, int]] = [ uuid.uuid4().int & ((1 << 63) - 1) for _ in range(count) - ] # type: List[Union[str, int]] + ] else: ids_or_names = [str(uuid.uuid4()) for _ in range(count)] diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py index 137df4235d47..f7ce69099ca0 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py @@ -25,7 +25,6 @@ from typing import Iterable from typing import List from typing import Optional -from typing import Text from typing import Union from google.cloud.datastore import entity @@ -153,12 +152,12 @@ def __repr__(self): class Key(object): - def __init__(self, - path_elements, # type: List[Union[Text, int]] - parent=None, # type: Optional[Key] - project=None, # type: Optional[Text] - namespace=None # 
type: Optional[Text] - ): + def __init__( + self, + path_elements: List[Union[str, int]], + parent: Optional['Key'] = None, + project: Optional[str] = None, + namespace: Optional[str] = None): """ Represents a Datastore key. @@ -229,11 +228,7 @@ def __repr__(self): class Entity(object): - def __init__( - self, - key, # type: Key - exclude_from_indexes=() # type: Iterable[str] - ): + def __init__(self, key: Key, exclude_from_indexes: Iterable[str] = ()): """ Represents a Datastore entity. diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py index 173b21a38f88..47d1997ddc7b 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py @@ -159,9 +159,7 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a write channel for the given file path. Args: @@ -177,9 +175,7 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a read channel for the given file path. Args: diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index cec65bc530f3..32e7fbe5ed58 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -37,6 +37,7 @@ from typing import NamedTuple from typing import Optional from typing import Tuple +from typing import Union from apache_beam import coders from apache_beam.io import iobase @@ -110,9 +111,7 @@ def __repr__(self): return 'PubsubMessage(%s, %s)' % (self.data, self.attributes) @staticmethod - def _from_proto_str(proto_msg): - # type: (bytes) -> PubsubMessage - + def _from_proto_str(proto_msg: bytes) -> 'PubsubMessage': """Construct from serialized form of ``PubsubMessage``. Args: @@ -185,9 +184,7 @@ def _to_proto_str(self, for_publish=False): return serialized @staticmethod - def _from_message(msg): - # type: (Any) -> PubsubMessage - + def _from_message(msg: Any) -> 'PubsubMessage': """Construct from ``google.cloud.pubsub_v1.subscriber.message.Message``. https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html @@ -211,14 +208,11 @@ class ReadFromPubSub(PTransform): def __init__( self, - topic=None, # type: Optional[str] - subscription=None, # type: Optional[str] - id_label=None, # type: Optional[str] - with_attributes=False, # type: bool - timestamp_attribute=None # type: Optional[str] - ): - # type: (...) -> None - + topic: Optional[str] = None, + subscription: Optional[str] = None, + id_label: Optional[str] = None, + with_attributes: bool = False, + timestamp_attribute: Optional[str] = None) -> None: """Initializes ``ReadFromPubSub``. Args: @@ -327,13 +321,10 @@ class WriteToPubSub(PTransform): def __init__( self, - topic, # type: str - with_attributes=False, # type: bool - id_label=None, # type: Optional[str] - timestamp_attribute=None # type: Optional[str] - ): - # type: (...) -> None - + topic: str, + with_attributes: bool = False, + id_label: Optional[str] = None, + timestamp_attribute: Optional[str] = None) -> None: """Initializes ``WriteToPubSub``. 
Args: @@ -359,8 +350,7 @@ def __init__( self._sink = _PubSubSink(topic, id_label, timestamp_attribute) @staticmethod - def message_to_proto_str(element): - # type: (PubsubMessage) -> bytes + def message_to_proto_str(element: PubsubMessage) -> bytes: if not isinstance(element, PubsubMessage): raise TypeError( 'Unexpected element. Type: %s (expected: PubsubMessage), ' @@ -368,16 +358,15 @@ def message_to_proto_str(element): return element._to_proto_str(for_publish=True) @staticmethod - def bytes_to_proto_str(element): - # type: (bytes) -> bytes + def bytes_to_proto_str(element: Union[bytes, str]) -> bytes: msg = PubsubMessage(element, {}) return msg._to_proto_str(for_publish=True) def expand(self, pcoll): if self.with_attributes: - pcoll = pcoll | 'ToProtobuf' >> Map(self.message_to_proto_str) + pcoll = pcoll | 'ToProtobufX' >> Map(self.message_to_proto_str) else: - pcoll = pcoll | 'ToProtobuf' >> Map(self.bytes_to_proto_str) + pcoll = pcoll | 'ToProtobufY' >> Map(self.bytes_to_proto_str) pcoll.element_type = bytes return pcoll | Write(self._sink) @@ -438,12 +427,11 @@ class _PubSubSource(iobase.SourceBase): """ def __init__( self, - topic=None, # type: Optional[str] - subscription=None, # type: Optional[str] - id_label=None, # type: Optional[str] - with_attributes=False, # type: bool - timestamp_attribute=None # type: Optional[str] - ): + topic: Optional[str] = None, + subscription: Optional[str] = None, + id_label: Optional[str] = None, + with_attributes: bool = False, + timestamp_attribute: Optional[str] = None): self.coder = coders.BytesCoder() self.full_topic = topic self.full_subscription = subscription @@ -562,8 +550,8 @@ class MultipleReadFromPubSub(PTransform): """ def __init__( self, - pubsub_source_descriptors, # type: List[PubSubSourceDescriptor] - with_attributes=False, # type: bool + pubsub_source_descriptors: List[PubSubSourceDescriptor], + with_attributes: bool = False, ): """Initializes ``PubSubMultipleReader``. 
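The `pubsub_test.py` changes below exercise the same contract that the relabelled `ToProtobufX`/`ToProtobufY` steps encode: with `with_attributes=True` the input elements must already be `PubsubMessage` objects, while plain `bytes` (or `str`) are expected otherwise. A hedged usage sketch of the two modes, assuming `apache_beam` is installed with its GCP extras; the project and topic names are placeholders, and the pipeline is only constructed, never run:

```python
import apache_beam as beam
from apache_beam.io.gcp.pubsub import PubsubMessage, WriteToPubSub

# Graph construction only; nothing is published.
p = beam.Pipeline()
payloads = p | beam.Create([b'payload-1', b'payload-2'])

# Plain bytes: WriteToPubSub serializes them via bytes_to_proto_str.
_ = payloads | 'WriteBytes' >> WriteToPubSub('projects/my-project/topics/raw')

# with_attributes=True expects PubsubMessage elements, so convert explicitly
# first (the same Map that test_expand adds in the hunk below).
_ = (
    payloads
    | 'ToMessage' >> beam.Map(lambda b: PubsubMessage(b, {'source': 'demo'}))
    | 'WriteMessages' >> WriteToPubSub(
        'projects/my-project/topics/with-attrs', with_attributes=True))
```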
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 7b4a4d5c93b9..f704338626ee 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -391,6 +391,7 @@ def test_expand(self): pcoll = ( p | ReadFromPubSub('projects/fakeprj/topics/baz') + | beam.Map(lambda x: PubsubMessage(x)) | WriteToPubSub( 'projects/fakeprj/topics/a_topic', with_attributes=True) | beam.Map(lambda x: x)) @@ -875,7 +876,7 @@ def test_write_messages_with_attributes_error(self, mock_pubsub): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True - with self.assertRaisesRegex(AttributeError, r'str.*has no attribute.*data'): + with self.assertRaisesRegex(Exception, r'Type hint violation'): with TestPipeline(options=options) as p: _ = ( p @@ -897,7 +898,9 @@ def test_write_messages_unsupported_features(self, mock_pubsub): p | Create(payloads) | WriteToPubSub( - 'projects/fakeprj/topics/a_topic', id_label='a_label')) + 'projects/fakeprj/topics/a_topic', + id_label='a_label', + with_attributes=True)) options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True @@ -909,7 +912,8 @@ def test_write_messages_unsupported_features(self, mock_pubsub): | Create(payloads) | WriteToPubSub( 'projects/fakeprj/topics/a_topic', - timestamp_attribute='timestamp')) + timestamp_attribute='timestamp', + with_attributes=True)) def test_runner_api_transformation(self, unused_mock_pubsub): sink = _PubSubSink( diff --git a/sdks/python/apache_beam/io/hadoopfilesystem.py b/sdks/python/apache_beam/io/hadoopfilesystem.py index c47a66c0f105..cf488c228a28 100644 --- a/sdks/python/apache_beam/io/hadoopfilesystem.py +++ b/sdks/python/apache_beam/io/hadoopfilesystem.py @@ -237,9 +237,7 @@ def create( self, url, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """ Returns: A Python File-like object. @@ -261,9 +259,7 @@ def open( self, url, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """ Returns: A Python File-like object. @@ -356,9 +352,7 @@ def rename(self, source_file_names, destination_file_names): if exceptions: raise BeamIOError('Rename operation failed', exceptions) - def exists(self, url): - # type: (str) -> bool - + def exists(self, url: str) -> bool: """Checks existence of url in HDFS. Args: diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 96f154dbe4b8..53215275e050 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -104,8 +104,7 @@ class SourceBase(HasDisplayData, urns.RunnerApiFn): def default_output_coder(self): raise NotImplementedError - def is_bounded(self): - # type: () -> bool + def is_bounded(self) -> bool: raise NotImplementedError @@ -144,9 +143,7 @@ class BoundedSource(SourceBase): implementations may invoke methods of ``BoundedSource`` objects through multi-threaded and/or reentrant execution modes. """ - def estimate_size(self): - # type: () -> Optional[int] - + def estimate_size(self) -> Optional[int]: """Estimates the size of source in bytes. 
An estimate of the total size (in bytes) of the data that would be read @@ -159,13 +156,12 @@ def estimate_size(self): """ raise NotImplementedError - def split(self, - desired_bundle_size, # type: int - start_position=None, # type: Optional[Any] - stop_position=None, # type: Optional[Any] - ): - # type: (...) -> Iterator[SourceBundle] - + def split( + self, + desired_bundle_size: int, + start_position: Optional[Any] = None, + stop_position: Optional[Any] = None, + ) -> Iterator[SourceBundle]: """Splits the source into a set of bundles. Bundles should be approximately of size ``desired_bundle_size`` bytes. @@ -182,12 +178,11 @@ def split(self, """ raise NotImplementedError - def get_range_tracker(self, - start_position, # type: Optional[Any] - stop_position, # type: Optional[Any] - ): - # type: (...) -> RangeTracker - + def get_range_tracker( + self, + start_position: Optional[Any], + stop_position: Optional[Any], + ) -> 'RangeTracker': """Returns a RangeTracker for a given position range. Framework may invoke ``read()`` method with the RangeTracker object returned @@ -879,9 +874,7 @@ class Read(ptransform.PTransform): # Import runners here to prevent circular imports from apache_beam.runners.pipeline_context import PipelineContext - def __init__(self, source): - # type: (SourceBase) -> None - + def __init__(self, source: SourceBase) -> None: """Initializes a Read transform. Args: @@ -921,12 +914,12 @@ def expand(self, pbegin): return pvalue.PCollection( pbegin.pipeline, is_bounded=self.source.is_bounded()) - def get_windowing(self, unused_inputs): - # type: (...) -> core.Windowing + def get_windowing(self, unused_inputs) -> core.Windowing: return core.Windowing(window.GlobalWindows()) - def _infer_output_coder(self, input_type=None, input_coder=None): - # type: (...) -> Optional[coders.Coder] + def _infer_output_coder(self, + input_type=None, + input_coder=None) -> Optional[coders.Coder]: if isinstance(self.source, SourceBase): return self.source.default_output_coder() else: @@ -1129,8 +1122,7 @@ def from_runner_api_parameter( class WriteImpl(ptransform.PTransform): """Implements the writing of custom sinks.""" - def __init__(self, sink): - # type: (Sink) -> None + def __init__(self, sink: Sink) -> None: super().__init__() self.sink = sink @@ -1289,9 +1281,7 @@ def current_restriction(self): """ raise NotImplementedError - def current_progress(self): - # type: () -> RestrictionProgress - + def current_progress(self) -> 'RestrictionProgress': """Returns a RestrictionProgress object representing the current progress. This API is recommended to be implemented. The runner can do a better job @@ -1416,16 +1406,12 @@ def get_estimator_state(self): """ raise NotImplementedError(type(self)) - def current_watermark(self): - # type: () -> timestamp.Timestamp - + def current_watermark(self) -> timestamp.Timestamp: """Return estimated output_watermark. This function must return monotonically increasing watermarks.""" raise NotImplementedError(type(self)) - def observe_timestamp(self, timestamp): - # type: (timestamp.Timestamp) -> None - + def observe_timestamp(self, timestamp: timestamp.Timestamp) -> None: """Update tracking watermark with latest output timestamp. 
Args: @@ -1450,8 +1436,7 @@ def __repr__(self): self._fraction, self._completed, self._remaining) @property - def completed_work(self): - # type: () -> float + def completed_work(self) -> float: if self._completed is not None: return self._completed elif self._remaining is not None and self._fraction is not None: @@ -1460,8 +1445,7 @@ def completed_work(self): return self._fraction @property - def remaining_work(self): - # type: () -> float + def remaining_work(self) -> float: if self._remaining is not None: return self._remaining elif self._completed is not None and self._fraction: @@ -1470,28 +1454,24 @@ def remaining_work(self): return 1 - self._fraction @property - def total_work(self): - # type: () -> float + def total_work(self) -> float: return self.completed_work + self.remaining_work @property - def fraction_completed(self): - # type: () -> float + def fraction_completed(self) -> float: if self._fraction is not None: return self._fraction else: return float(self._completed) / self.total_work @property - def fraction_remaining(self): - # type: () -> float + def fraction_remaining(self) -> float: if self._fraction is not None: return 1 - self._fraction else: return float(self._remaining) / self.total_work - def with_completed(self, completed): - # type: (int) -> RestrictionProgress + def with_completed(self, completed: int) -> 'RestrictionProgress': return RestrictionProgress( fraction=self._fraction, remaining=self._remaining, completed=completed) @@ -1569,8 +1549,7 @@ def __init__(self, restriction): restriction) self.restriction = restriction - def current_progress(self): - # type: () -> RestrictionProgress + def current_progress(self) -> RestrictionProgress: return RestrictionProgress( fraction=self.restriction.range_tracker().fraction_consumed()) diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index 903b0d1b0fef..3fef1f5fee35 100644 --- a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -373,8 +373,7 @@ def __init__(self, argument=""): pass @classmethod - def representation_type(cls): - # type: () -> type + def representation_type(cls) -> type: return Timestamp @classmethod @@ -385,14 +384,12 @@ def urn(cls): def language_type(cls): return datetime.date - def to_representation_type(self, value): - # type: (datetime.date) -> Timestamp + def to_representation_type(self, value: datetime.date) -> Timestamp: return Timestamp.from_utc_datetime( datetime.datetime.combine( value, datetime.datetime.min.time(), tzinfo=datetime.timezone.utc)) - def to_language_type(self, value): - # type: (Timestamp) -> datetime.date + def to_language_type(self, value: Timestamp) -> datetime.date: return value.to_utc_datetime().date() @@ -420,8 +417,7 @@ def __init__(self, argument=""): pass @classmethod - def representation_type(cls): - # type: () -> type + def representation_type(cls) -> type: return Timestamp @classmethod @@ -432,16 +428,14 @@ def urn(cls): def language_type(cls): return datetime.time - def to_representation_type(self, value): - # type: (datetime.date) -> Timestamp + def to_representation_type(self, value: datetime.date) -> Timestamp: return Timestamp.from_utc_datetime( datetime.datetime.combine( datetime.datetime.utcfromtimestamp(0), value, tzinfo=datetime.timezone.utc)) - def to_language_type(self, value): - # type: (Timestamp) -> datetime.date + def to_language_type(self, value: Timestamp) -> datetime.date: return value.to_utc_datetime().time() diff --git a/sdks/python/apache_beam/io/kafka.py 
b/sdks/python/apache_beam/io/kafka.py index b4fd7d86e688..b19e9c22aa3c 100644 --- a/sdks/python/apache_beam/io/kafka.py +++ b/sdks/python/apache_beam/io/kafka.py @@ -82,19 +82,29 @@ import typing +import numpy as np + from apache_beam.transforms.external import BeamJarExpansionService from apache_beam.transforms.external import ExternalTransform from apache_beam.transforms.external import NamedTupleBasedPayloadBuilder ReadFromKafkaSchema = typing.NamedTuple( 'ReadFromKafkaSchema', - [('consumer_config', typing.Mapping[str, str]), - ('topics', typing.List[str]), ('key_deserializer', str), - ('value_deserializer', str), ('start_read_time', typing.Optional[int]), - ('max_num_records', typing.Optional[int]), - ('max_read_time', typing.Optional[int]), - ('commit_offset_in_finalize', bool), ('timestamp_policy', str), - ('consumer_polling_timeout', typing.Optional[int])]) + [ + ('consumer_config', typing.Mapping[str, str]), + ('topics', typing.List[str]), + ('key_deserializer', str), + ('value_deserializer', str), + ('start_read_time', typing.Optional[int]), + ('max_num_records', typing.Optional[int]), + ('max_read_time', typing.Optional[int]), + ('commit_offset_in_finalize', bool), + ('timestamp_policy', str), + ('consumer_polling_timeout', typing.Optional[int]), + ('redistribute', typing.Optional[bool]), + ('redistribute_num_keys', typing.Optional[np.int32]), + ('allow_duplicates', typing.Optional[bool]), + ]) def default_io_expansion_service(append_args=None): @@ -138,6 +148,9 @@ def __init__( consumer_polling_timeout=2, with_metadata=False, expansion_service=None, + redistribute=False, + redistribute_num_keys=np.int32(0), + allow_duplicates=False, ): """ Initializes a read operation from Kafka. @@ -172,6 +185,12 @@ def __init__( this only works when using default key and value deserializers where Java Kafka Reader reads keys and values as 'byte[]'. :param expansion_service: The address (host:port) of the ExpansionService. + :param redistribute: whether a Redistribute transform should be applied + immediately after the read. + :param redistribute_num_keys: Configures how many keys the Redistribute + spreads the data across. + :param allow_duplicates: whether the Redistribute transform allows for + duplicates (this serves solely as a hint to the underlying runner). """ if timestamp_policy not in [ReadFromKafka.processing_time_policy, ReadFromKafka.create_time_policy, @@ -193,7 +212,10 @@ def __init__( start_read_time=start_read_time, commit_offset_in_finalize=commit_offset_in_finalize, timestamp_policy=timestamp_policy, - consumer_polling_timeout=consumer_polling_timeout)), + consumer_polling_timeout=consumer_polling_timeout, + redistribute=redistribute, + redistribute_num_keys=redistribute_num_keys, + allow_duplicates=allow_duplicates)), expansion_service or default_io_expansion_service()) diff --git a/sdks/python/apache_beam/io/localfilesystem.py b/sdks/python/apache_beam/io/localfilesystem.py index 3580b79ea56f..e9fe7dd4b1c2 100644 --- a/sdks/python/apache_beam/io/localfilesystem.py +++ b/sdks/python/apache_beam/io/localfilesystem.py @@ -147,9 +147,7 @@ def create( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) -> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a write channel for the given file path. Args: @@ -166,9 +164,7 @@ def open( self, path, mime_type='application/octet-stream', - compression_type=CompressionTypes.AUTO): - # type: (...) 
-> BinaryIO - + compression_type=CompressionTypes.AUTO) -> BinaryIO: """Returns a read channel for the given file path. Args: diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py index 06b06fa1ed34..4b819e87a8d6 100644 --- a/sdks/python/apache_beam/io/restriction_trackers.py +++ b/sdks/python/apache_beam/io/restriction_trackers.py @@ -62,8 +62,7 @@ def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1): yield OffsetRange(current_split_start, current_split_stop) current_split_start = current_split_stop - def split_at(self, split_pos): - # type: (...) -> Tuple[OffsetRange, OffsetRange] + def split_at(self, split_pos) -> Tuple['OffsetRange', 'OffsetRange']: return OffsetRange(self.start, split_pos), OffsetRange(split_pos, self.stop) def new_tracker(self): @@ -78,8 +77,7 @@ class OffsetRestrictionTracker(RestrictionTracker): Offset range is represented as OffsetRange. """ - def __init__(self, offset_range): - # type: (OffsetRange) -> None + def __init__(self, offset_range: OffsetRange) -> None: assert isinstance(offset_range, OffsetRange), offset_range self._range = offset_range self._current_position = None @@ -100,8 +98,7 @@ def check_done(self): def current_restriction(self): return self._range - def current_progress(self): - # type: () -> RestrictionProgress + def current_progress(self) -> RestrictionProgress: if self._current_position is None: fraction = 0.0 elif self._range.stop == self._range.start: diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 454fe4d69dea..3de9709d7362 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -102,18 +102,19 @@ def reset(self): self.data = b'' self.position = 0 - def __init__(self, - file_pattern, - min_bundle_size, - compression_type, - strip_trailing_newlines, - coder, # type: coders.Coder - buffer_size=DEFAULT_READ_BUFFER_SIZE, - validate=True, - skip_header_lines=0, - header_processor_fns=(None, None), - delimiter=None, - escapechar=None): + def __init__( + self, + file_pattern, + min_bundle_size, + compression_type, + strip_trailing_newlines, + coder: coders.Coder, + buffer_size=DEFAULT_READ_BUFFER_SIZE, + validate=True, + skip_header_lines=0, + header_processor_fns=(None, None), + delimiter=None, + escapechar=None): """Initialize a _TextSource Args: @@ -433,21 +434,21 @@ def output_type_hint(self): class _TextSink(filebasedsink.FileBasedSink): """A sink to a GCS or local text file or files.""" - - def __init__(self, - file_path_prefix, - file_name_suffix='', - append_trailing_newlines=True, - num_shards=0, - shard_name_template=None, - coder=coders.ToBytesCoder(), # type: coders.Coder - compression_type=CompressionTypes.AUTO, - header=None, - footer=None, - *, - max_records_per_shard=None, - max_bytes_per_shard=None, - skip_if_empty=False): + def __init__( + self, + file_path_prefix, + file_name_suffix='', + append_trailing_newlines=True, + num_shards=0, + shard_name_template=None, + coder: coders.Coder = coders.ToBytesCoder(), + compression_type=CompressionTypes.AUTO, + header=None, + footer=None, + *, + max_records_per_shard=None, + max_bytes_per_shard=None, + skip_if_empty=False): """Initialize a _TextSink. 
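Editor's note: a small sketch of the OffsetRange / OffsetRestrictionTracker methods annotated above; the offsets are arbitrary.

from apache_beam.io.restriction_trackers import OffsetRange, OffsetRestrictionTracker

whole = OffsetRange(0, 100)
left, right = whole.split_at(40)       # OffsetRange(0, 40) and OffsetRange(40, 100)

tracker = OffsetRestrictionTracker(left)
tracker.try_claim(10)                  # claim offset 10 inside [0, 40)
progress = tracker.current_progress()  # fraction = (10 - 0) / (40 - 0)
assert progress.fraction_completed == 0.25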
Args: @@ -591,7 +592,7 @@ def __init__( compression_type=CompressionTypes.AUTO, strip_trailing_newlines=True, validate=False, - coder=coders.StrUtf8Coder(), # type: coders.Coder + coder: coders.Coder = coders.StrUtf8Coder(), skip_header_lines=0, with_filename=False, delimiter=None, @@ -742,7 +743,7 @@ def __init__( min_bundle_size=0, compression_type=CompressionTypes.AUTO, strip_trailing_newlines=True, - coder=coders.StrUtf8Coder(), # type: coders.Coder + coder: coders.Coder = coders.StrUtf8Coder(), validate=True, skip_header_lines=0, delimiter=None, @@ -808,15 +809,14 @@ class ReadFromTextWithFilename(ReadFromText): class WriteToText(PTransform): """A :class:`~apache_beam.transforms.ptransform.PTransform` for writing to text files.""" - def __init__( self, - file_path_prefix, # type: str + file_path_prefix: str, file_name_suffix='', append_trailing_newlines=True, num_shards=0, - shard_name_template=None, # type: Optional[str] - coder=coders.ToBytesCoder(), # type: coders.Coder + shard_name_template: Optional[str] = None, + coder: coders.Coder = coders.ToBytesCoder(), compression_type=CompressionTypes.AUTO, header=None, footer=None, diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd index 0eaa890c02ac..a8f4003d8980 100644 --- a/sdks/python/apache_beam/metrics/cells.pxd +++ b/sdks/python/apache_beam/metrics/cells.pxd @@ -44,6 +44,12 @@ cdef class GaugeCell(MetricCell): cdef readonly object data +cdef class StringSetCell(MetricCell): + cdef readonly set data + + cdef inline bint _update(self, value) except -1 + + cdef class DistributionData(object): cdef readonly libc.stdint.int64_t sum cdef readonly libc.stdint.int64_t count diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index 53b6fc849592..d836d4cee58f 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -268,6 +268,62 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id): ptransform=transform_id) +class StringSetCell(MetricCell): + """For internal use only; no backwards-compatibility guarantees. + + Tracks the current value for a StringSet metric. + + Each cell tracks the state of a metric independently per context per bundle. + Therefore, each metric has a different cell in each bundle, that is later + aggregated. + + This class is thread safe. + """ + def __init__(self, *args): + super().__init__(*args) + self.data = StringSetAggregator.identity_element() + + def add(self, value): + self.update(value) + + def update(self, value): + # type: (str) -> None + if cython.compiled: + # We will hold the GIL throughout the entire _update. 
+ self._update(value) + else: + with self._lock: + self._update(value) + + def _update(self, value): + self.data.add(value) + + def get_cumulative(self): + # type: () -> set + with self._lock: + return set(self.data) + + def combine(self, other): + # type: (StringSetCell) -> StringSetCell + combined = StringSetAggregator().combine(self.data, other.data) + result = StringSetCell() + result.data = combined + return result + + def to_runner_api_monitoring_info_impl(self, name, transform_id): + from apache_beam.metrics import monitoring_infos + + return monitoring_infos.user_set_string( + name.namespace, + name.name, + self.get_cumulative(), + ptransform=transform_id) + + def reset(self): + # type: () -> None + self.data = StringSetAggregator.identity_element() + + class DistributionResult(object): """The result of a Distribution metric.""" def __init__(self, data): @@ -553,3 +609,22 @@ def combine(self, x, y): def result(self, x): # type: (GaugeData) -> GaugeResult return GaugeResult(x.get_cumulative()) + + +class StringSetAggregator(MetricAggregator): + @staticmethod + def identity_element(): + # type: () -> set + return set() + + def combine(self, x, y): + # type: (set, set) -> set + if len(x) == 0: + return y + elif len(y) == 0: + return x + else: + return set.union(x, y) + + def result(self, x): + return x diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py index 3d4d81c3d12b..052ff051bf96 100644 --- a/sdks/python/apache_beam/metrics/cells_test.py +++ b/sdks/python/apache_beam/metrics/cells_test.py @@ -25,6 +25,7 @@ from apache_beam.metrics.cells import DistributionData from apache_beam.metrics.cells import GaugeCell from apache_beam.metrics.cells import GaugeData +from apache_beam.metrics.cells import StringSetCell from apache_beam.metrics.metricbase import MetricName @@ -169,5 +170,28 @@ def test_start_time_set(self): self.assertGreater(mi.start_time.seconds, 0) +class TestStringSetCell(unittest.TestCase): + def test_not_leak_mutable_set(self): + c = StringSetCell() + c.add('test') + c.add('another') + s = c.get_cumulative() + self.assertEqual(s, set(('test', 'another'))) + s.add('yet another') + self.assertEqual(c.get_cumulative(), set(('test', 'another'))) + + def test_combine_appropriately(self): + s1 = StringSetCell() + s1.add('1') + s1.add('2') + + s2 = StringSetCell() + s2.add('1') + s2.add('3') + + result = s2.combine(s1) + self.assertEqual(result.data, set(('1', '2', '3'))) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index 4202f7996c7f..74890b822bcc 100644 --- a/sdks/python/apache_beam/metrics/execution.py +++ b/sdks/python/apache_beam/metrics/execution.py @@ -48,6 +48,7 @@ from apache_beam.metrics.cells import CounterCell from apache_beam.metrics.cells import DistributionCell from apache_beam.metrics.cells import GaugeCell +from apache_beam.metrics.cells import StringSetCell from apache_beam.runners.worker import statesampler from apache_beam.runners.worker.statesampler import get_current_tracker @@ -259,6 +260,12 @@ def get_gauge(self, metric_name): GaugeCell, self.get_metric_cell(_TypedMetricName(GaugeCell, metric_name))) + def get_string_set(self, metric_name): + # type: (MetricName) -> StringSetCell + return cast( + StringSetCell, + self.get_metric_cell(_TypedMetricName(StringSetCell, metric_name))) + def get_metric_cell(self, typed_metric_name): # type: (_TypedMetricName) -> MetricCell cell = 
self.metrics.get(typed_metric_name, None) @@ -292,7 +299,13 @@ def get_cumulative(self): v in self.metrics.items() if k.cell_type == GaugeCell } - return MetricUpdates(counters, distributions, gauges) + string_sets = { + MetricKey(self.step_name, k.metric_name): v.get_cumulative() + for k, + v in self.metrics.items() if k.cell_type == StringSetCell + } + + return MetricUpdates(counters, distributions, gauges, string_sets) def to_runner_api(self): return [ @@ -344,7 +357,8 @@ def __init__( self, counters=None, # type: Optional[Dict[MetricKey, int]] distributions=None, # type: Optional[Dict[MetricKey, DistributionData]] - gauges=None # type: Optional[Dict[MetricKey, GaugeData]] + gauges=None, # type: Optional[Dict[MetricKey, GaugeData]] + string_sets=None, # type: Optional[Dict[MetricKey, set]] ): # type: (...) -> None @@ -354,7 +368,9 @@ def __init__( counters: Dictionary of MetricKey:MetricUpdate updates. distributions: Dictionary of MetricKey:MetricUpdate objects. gauges: Dictionary of MetricKey:MetricUpdate objects. + string_sets: Dictionary of MetricKey:MetricUpdate objects. """ self.counters = counters or {} self.distributions = distributions or {} self.gauges = gauges or {} + self.string_sets = string_sets or {} diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py index a888376e7091..b157aeb20e9e 100644 --- a/sdks/python/apache_beam/metrics/execution_test.py +++ b/sdks/python/apache_beam/metrics/execution_test.py @@ -17,6 +17,7 @@ # pytype: skip-file +import functools import unittest from apache_beam.metrics.execution import MetricKey @@ -88,10 +89,12 @@ def test_get_cumulative_or_updates(self): distribution = mc.get_distribution( MetricName('namespace', 'name{}'.format(i))) gauge = mc.get_gauge(MetricName('namespace', 'name{}'.format(i))) + str_set = mc.get_string_set(MetricName('namespace', 'name{}'.format(i))) counter.inc(i) distribution.update(i) gauge.set(i) + str_set.add(str(i % 7)) all_values.append(i) # Retrieve ALL updates. 
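Editor's note: a sketch of the new string-set plumbing end to end, driving the internal MetricsContainer directly (user code would normally go through Metrics.string_set instead); the step and metric names are placeholders.

from apache_beam.metrics.execution import MetricsContainer
from apache_beam.metrics.metricbase import MetricName

mc = MetricsContainer('my_step')
mc.get_counter(MetricName('ns', 'elements')).inc(3)
mc.get_string_set(MetricName('ns', 'versions')).add('2.58.0')

updates = mc.get_cumulative()
# MetricUpdates now carries a fourth map next to counters/distributions/gauges,
# and each value is a copy of the cell's set, so mutating it is safe.
assert list(updates.string_sets.values()) == [{'2.58.0'}]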
@@ -99,6 +102,7 @@ def test_get_cumulative_or_updates(self): self.assertEqual(len(cumulative.counters), 10) self.assertEqual(len(cumulative.distributions), 10) self.assertEqual(len(cumulative.gauges), 10) + self.assertEqual(len(cumulative.string_sets), 10) self.assertEqual( set(all_values), {v @@ -106,6 +110,11 @@ def test_get_cumulative_or_updates(self): self.assertEqual( set(all_values), {v.value for _, v in cumulative.gauges.items()}) + self.assertEqual({str(i % 7) + for i in all_values}, + functools.reduce( + set.union, + (v for _, v in cumulative.string_sets.items()))) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index 08a359edae90..77cafb8bd64b 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -44,6 +44,7 @@ from apache_beam.metrics.metricbase import Distribution from apache_beam.metrics.metricbase import Gauge from apache_beam.metrics.metricbase import MetricName +from apache_beam.metrics.metricbase import StringSet if TYPE_CHECKING: from apache_beam.metrics.execution import MetricKey @@ -57,8 +58,7 @@ class Metrics(object): """Lets users create/access metric objects during pipeline execution.""" @staticmethod - def get_namespace(namespace): - # type: (Union[Type, str]) -> str + def get_namespace(namespace: Union[Type, str]) -> str: if isinstance(namespace, type): return '{}.{}'.format(namespace.__module__, namespace.__name__) elif isinstance(namespace, str): @@ -67,9 +67,8 @@ def get_namespace(namespace): raise ValueError('Unknown namespace type') @staticmethod - def counter(namespace, name): - # type: (Union[Type, str], str) -> Metrics.DelegatingCounter - + def counter( + namespace: Union[Type, str], name: str) -> 'Metrics.DelegatingCounter': """Obtains or creates a Counter metric. Args: @@ -83,9 +82,9 @@ def counter(namespace, name): return Metrics.DelegatingCounter(MetricName(namespace, name)) @staticmethod - def distribution(namespace, name): - # type: (Union[Type, str], str) -> Metrics.DelegatingDistribution - + def distribution( + namespace: Union[Type, str], + name: str) -> 'Metrics.DelegatingDistribution': """Obtains or creates a Distribution metric. Distribution metrics are restricted to integer-only distributions. @@ -101,9 +100,8 @@ def distribution(namespace, name): return Metrics.DelegatingDistribution(MetricName(namespace, name)) @staticmethod - def gauge(namespace, name): - # type: (Union[Type, str], str) -> Metrics.DelegatingGauge - + def gauge( + namespace: Union[Type, str], name: str) -> 'Metrics.DelegatingGauge': """Obtains or creates a Gauge metric. Gauge metrics are restricted to integer-only values. @@ -118,10 +116,27 @@ def gauge(namespace, name): namespace = Metrics.get_namespace(namespace) return Metrics.DelegatingGauge(MetricName(namespace, name)) + @staticmethod + def string_set( + namespace: Union[Type, str], name: str) -> 'Metrics.DelegatingStringSet': + """Obtains or creates a String set metric. + + String set metrics are restricted to string values. + + Args: + namespace: A class or string that gives the namespace to a metric + name: A string that gives a unique name to a metric + + Returns: + A StringSet object. 
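Editor's note: a short usage sketch for the new user-facing Metrics.string_set entry point; the namespace, metric name, and DoFn are hypothetical.

import apache_beam as beam
from apache_beam.metrics.metric import Metrics

class TrackDistinctValues(beam.DoFn):
    def __init__(self):
        # Mirrors Metrics.counter / Metrics.gauge: returns a delegating metric
        # object that is safe to create at construction time.
        self.seen = Metrics.string_set('my.namespace', 'distinct_values')

    def process(self, element):
        self.seen.add(str(element))
        yield element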
+ """ + namespace = Metrics.get_namespace(namespace) + return Metrics.DelegatingStringSet(MetricName(namespace, name)) + class DelegatingCounter(Counter): """Metrics Counter that Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name, process_wide=False): - # type: (MetricName, bool) -> None + def __init__( + self, metric_name: MetricName, process_wide: bool = False) -> None: super().__init__(metric_name) self.inc = MetricUpdater( # type: ignore[assignment] cells.CounterCell, @@ -131,27 +146,31 @@ def __init__(self, metric_name, process_wide=False): class DelegatingDistribution(Distribution): """Metrics Distribution Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name): - # type: (MetricName) -> None + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[assignment] class DelegatingGauge(Gauge): """Metrics Gauge that Delegates functionality to MetricsEnvironment.""" - def __init__(self, metric_name): - # type: (MetricName) -> None + def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[assignment] + class DelegatingStringSet(StringSet): + """Metrics StringSet that Delegates functionality to MetricsEnvironment.""" + def __init__(self, metric_name: MetricName) -> None: + super().__init__(metric_name) + self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[assignment] + class MetricResults(object): COUNTERS = "counters" DISTRIBUTIONS = "distributions" GAUGES = "gauges" + STRINGSETS = "string_sets" @staticmethod - def _matches_name(filter, metric_key): - # type: (MetricsFilter, MetricKey) -> bool + def _matches_name(filter: 'MetricsFilter', metric_key: 'MetricKey') -> bool: if ((filter.namespaces and metric_key.metric.namespace not in filter.namespaces) or (filter.names and metric_key.metric.name not in filter.names)): @@ -160,9 +179,7 @@ def _matches_name(filter, metric_key): return True @staticmethod - def _is_sub_list(needle, haystack): - # type: (List[str], List[str]) -> bool - + def _is_sub_list(needle: List[str], haystack: List[str]) -> bool: """True iff `needle` is a sub-list of `haystack` (i.e. 
a contiguous slice of `haystack` exactly matches `needle`""" needle_len = len(needle) @@ -174,9 +191,7 @@ def _is_sub_list(needle, haystack): return False @staticmethod - def _matches_sub_path(actual_scope, filter_scope): - # type: (str, str) -> bool - + def _matches_sub_path(actual_scope: str, filter_scope: str) -> bool: """True iff the '/'-delimited pieces of filter_scope exist as a sub-list of the '/'-delimited pieces of actual_scope""" return bool( @@ -184,8 +199,7 @@ def _matches_sub_path(actual_scope, filter_scope): filter_scope.split('/'), actual_scope.split('/'))) @staticmethod - def _matches_scope(filter, metric_key): - # type: (MetricsFilter, MetricKey) -> bool + def _matches_scope(filter: 'MetricsFilter', metric_key: 'MetricKey') -> bool: if not filter.steps: return True @@ -196,8 +210,8 @@ def _matches_scope(filter, metric_key): return False @staticmethod - def matches(filter, metric_key): - # type: (Optional[MetricsFilter], MetricKey) -> bool + def matches( + filter: Optional['MetricsFilter'], metric_key: 'MetricKey') -> bool: if filter is None: return True @@ -206,9 +220,10 @@ def matches(filter, metric_key): return True return False - def query(self, filter=None): - # type: (Optional[MetricsFilter]) -> Dict[str, List[MetricResults]] - + def query( + self, + filter: Optional['MetricsFilter'] = None + ) -> Dict[str, List['MetricResults']]: """Queries the runner for existing user metrics that match the filter. It should return a dictionary, with lists of each kind of metric, and @@ -217,11 +232,13 @@ def query(self, filter=None): { "counters": [MetricResult(counter_key, committed, attempted), ...], "distributions": [MetricResult(dist_key, committed, attempted), ...], - "gauges": [] // Empty list if nothing matched the filter. + "gauges": [], // Empty list if nothing matched the filter. + "string_sets": [] [MetricResult(string_set_key, committed, attempted), + ...] } The committed / attempted values are DistributionResult / GaugeResult / int - objects. + / set objects. """ raise NotImplementedError @@ -236,63 +253,53 @@ class MetricsFilter(object): Note: This class only supports user defined metrics. 
""" - def __init__(self): - # type: () -> None - self._names = set() # type: Set[str] - self._namespaces = set() # type: Set[str] - self._steps = set() # type: Set[str] + def __init__(self) -> None: + self._names: Set[str] = set() + self._namespaces: Set[str] = set() + self._steps: Set[str] = set() @property - def steps(self): - # type: () -> FrozenSet[str] + def steps(self) -> FrozenSet[str]: return frozenset(self._steps) @property - def names(self): - # type: () -> FrozenSet[str] + def names(self) -> FrozenSet[str]: return frozenset(self._names) @property - def namespaces(self): - # type: () -> FrozenSet[str] + def namespaces(self) -> FrozenSet[str]: return frozenset(self._namespaces) - def with_metric(self, metric): - # type: (Metric) -> MetricsFilter + def with_metric(self, metric: 'Metric') -> 'MetricsFilter': name = metric.metric_name.name or '' namespace = metric.metric_name.namespace or '' return self.with_name(name).with_namespace(namespace) - def with_name(self, name): - # type: (str) -> MetricsFilter + def with_name(self, name: str) -> 'MetricsFilter': return self.with_names([name]) - def with_names(self, names): - # type: (Iterable[str]) -> MetricsFilter + def with_names(self, names: Iterable[str]) -> 'MetricsFilter': if isinstance(names, str): raise ValueError('Names must be a collection, not a string') self._names.update(names) return self - def with_namespace(self, namespace): - # type: (Union[Type, str]) -> MetricsFilter + def with_namespace(self, namespace: Union[Type, str]) -> 'MetricsFilter': return self.with_namespaces([namespace]) - def with_namespaces(self, namespaces): - # type: (Iterable[Union[Type, str]]) -> MetricsFilter + def with_namespaces( + self, namespaces: Iterable[Union[Type, str]]) -> 'MetricsFilter': if isinstance(namespaces, str): raise ValueError('Namespaces must be an iterable, not a string') self._namespaces.update([Metrics.get_namespace(ns) for ns in namespaces]) return self - def with_step(self, step): - # type: (str) -> MetricsFilter + def with_step(self, step: str) -> 'MetricsFilter': return self.with_steps([step]) - def with_steps(self, steps): - # type: (Iterable[str]) -> MetricsFilter + def with_steps(self, steps: Iterable[str]) -> 'MetricsFilter': if isinstance(steps, str): raise ValueError('Steps must be an iterable, not a string') diff --git a/sdks/python/apache_beam/metrics/metricbase.py b/sdks/python/apache_beam/metrics/metricbase.py index 12e7881792f9..7819dbb093a5 100644 --- a/sdks/python/apache_beam/metrics/metricbase.py +++ b/sdks/python/apache_beam/metrics/metricbase.py @@ -38,7 +38,13 @@ from typing import Optional __all__ = [ - 'Metric', 'Counter', 'Distribution', 'Gauge', 'Histogram', 'MetricName' + 'Metric', + 'Counter', + 'Distribution', + 'Gauge', + 'StringSet', + 'Histogram', + 'MetricName' ] @@ -49,9 +55,12 @@ class MetricName(object): allows grouping related metrics together and also prevents collisions between multiple metrics of the same name. """ - def __init__(self, namespace, name, urn=None, labels=None): - # type: (Optional[str], Optional[str], Optional[str], Optional[Dict[str, str]]) -> None - + def __init__( + self, + namespace: Optional[str], + name: Optional[str], + urn: Optional[str] = None, + labels: Optional[Dict[str, str]] = None) -> None: """Initializes ``MetricName``. 
Note: namespace and name should be set for user metrics, @@ -103,8 +112,7 @@ def fast_name(self): class Metric(object): """Base interface of a metric object.""" - def __init__(self, metric_name): - # type: (MetricName) -> None + def __init__(self, metric_name: MetricName) -> None: self.metric_name = metric_name @@ -136,6 +144,14 @@ def set(self, value): raise NotImplementedError +class StringSet(Metric): + """StringSet Metric interface. + + Reports set of unique string values during pipeline execution..""" + def add(self, value): + raise NotImplementedError + + class Histogram(Metric): """Histogram Metric interface. diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py index 7bc7cced280c..72640c8f92ac 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos.py @@ -50,8 +50,13 @@ USER_DISTRIBUTION_URN = ( common_urns.monitoring_info_specs.USER_DISTRIBUTION_INT64.spec.urn) USER_GAUGE_URN = common_urns.monitoring_info_specs.USER_LATEST_INT64.spec.urn -USER_METRIC_URNS = set( - [USER_COUNTER_URN, USER_DISTRIBUTION_URN, USER_GAUGE_URN]) +USER_STRING_SET_URN = common_urns.monitoring_info_specs.USER_SET_STRING.spec.urn +USER_METRIC_URNS = set([ + USER_COUNTER_URN, + USER_DISTRIBUTION_URN, + USER_GAUGE_URN, + USER_STRING_SET_URN +]) WORK_REMAINING_URN = common_urns.monitoring_info_specs.WORK_REMAINING.spec.urn WORK_COMPLETED_URN = common_urns.monitoring_info_specs.WORK_COMPLETED.spec.urn DATA_CHANNEL_READ_INDEX = ( @@ -67,10 +72,12 @@ common_urns.monitoring_info_types.DISTRIBUTION_INT64_TYPE.urn) LATEST_INT64_TYPE = common_urns.monitoring_info_types.LATEST_INT64_TYPE.urn PROGRESS_TYPE = common_urns.monitoring_info_types.PROGRESS_TYPE.urn +STRING_SET_TYPE = common_urns.monitoring_info_types.SET_STRING_TYPE.urn COUNTER_TYPES = set([SUM_INT64_TYPE]) DISTRIBUTION_TYPES = set([DISTRIBUTION_INT64_TYPE]) GAUGE_TYPES = set([LATEST_INT64_TYPE]) +STRING_SET_TYPES = set([STRING_SET_TYPE]) # TODO(migryz) extract values from beam_fn_api.proto::MonitoringInfoLabels PCOLLECTION_LABEL = ( @@ -149,6 +156,14 @@ def extract_distribution(monitoring_info_proto): coders.VarIntCoder(), monitoring_info_proto.payload) +def extract_string_set_value(monitoring_info_proto): + if not is_string_set(monitoring_info_proto): + raise ValueError('Unsupported type %s' % monitoring_info_proto.type) + + coder = coders.IterableCoder(coders.StrUtf8Coder()) + return set(coder.decode(monitoring_info_proto.payload)) + + def create_labels(ptransform=None, namespace=None, name=None, pcollection=None): """Create the label dictionary based on the provided values. @@ -243,8 +258,8 @@ def int64_user_gauge(namespace, name, metric, ptransform=None): """Return the gauge monitoring info for the URN, metric and labels. Args: - namespace: User-defined namespace of counter. - name: Name of counter. + namespace: User-defined namespace of gauge metric. + name: Name of gauge metric. metric: The GaugeData containing the metrics. ptransform: The ptransform id used as a label. """ @@ -286,6 +301,24 @@ def int64_gauge(urn, metric, ptransform=None): return create_monitoring_info(urn, LATEST_INT64_TYPE, payload, labels) +def user_set_string(namespace, name, metric, ptransform=None): + """Return the string set monitoring info for the URN, metric and labels. + + Args: + namespace: User-defined namespace of StringSet. + name: Name of StringSet. + metric: The set representing the metrics. + ptransform: The ptransform id used as a label. 
+ """ + labels = create_labels(ptransform=ptransform, namespace=namespace, name=name) + if isinstance(metric, set): + metric = list(metric) + if isinstance(metric, list): + metric = coders.IterableCoder(coders.StrUtf8Coder()).encode(metric) + return create_monitoring_info( + USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels) + + def create_monitoring_info(urn, type_urn, payload, labels=None): # type: (...) -> metrics_pb2.MonitoringInfo @@ -322,15 +355,21 @@ def is_distribution(monitoring_info_proto): return monitoring_info_proto.type in DISTRIBUTION_TYPES +def is_string_set(monitoring_info_proto): + """Returns true if the monitoring info is a StringSet metric.""" + return monitoring_info_proto.type in STRING_SET_TYPES + + def is_user_monitoring_info(monitoring_info_proto): """Returns true if the monitoring info is a user metric.""" return monitoring_info_proto.urn in USER_METRIC_URNS def extract_metric_result_map_value(monitoring_info_proto): - # type: (...) -> Union[None, int, DistributionResult, GaugeResult] + # type: (...) -> Union[None, int, DistributionResult, GaugeResult, set] - """Returns the relevant GaugeResult, DistributionResult or int value. + """Returns the relevant GaugeResult, DistributionResult or int value for + counter metric, set for StringSet metric. These are the proper format for use in the MetricResult.query() result. """ @@ -344,6 +383,8 @@ def extract_metric_result_map_value(monitoring_info_proto): if is_gauge(monitoring_info_proto): (timestamp, value) = extract_gauge_value(monitoring_info_proto) return GaugeResult(GaugeData(value, timestamp)) + if is_string_set(monitoring_info_proto): + return extract_string_set_value(monitoring_info_proto) return None diff --git a/sdks/python/apache_beam/metrics/monitoring_infos_test.py b/sdks/python/apache_beam/metrics/monitoring_infos_test.py index d19e8bc10df1..022943f417c2 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos_test.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos_test.py @@ -21,6 +21,7 @@ from apache_beam.metrics import monitoring_infos from apache_beam.metrics.cells import CounterCell from apache_beam.metrics.cells import GaugeCell +from apache_beam.metrics.cells import StringSetCell class MonitoringInfosTest(unittest.TestCase): @@ -64,6 +65,17 @@ def test_parse_namespace_and_name_for_user_gauge_metric(self): self.assertEqual(namespace, "counternamespace") self.assertEqual(name, "countername") + def test_parse_namespace_and_name_for_user_string_set_metric(self): + urn = monitoring_infos.USER_STRING_SET_URN + labels = {} + labels[monitoring_infos.NAMESPACE_LABEL] = "stringsetnamespace" + labels[monitoring_infos.NAME_LABEL] = "stringsetname" + input = monitoring_infos.create_monitoring_info( + urn, "typeurn", None, labels) + namespace, name = monitoring_infos.parse_namespace_and_name(input) + self.assertEqual(namespace, "stringsetnamespace") + self.assertEqual(name, "stringsetname") + def test_int64_user_gauge(self): metric = GaugeCell().get_cumulative() result = monitoring_infos.int64_user_gauge( @@ -105,6 +117,19 @@ def test_int64_counter(self): self.assertEqual(0, counter_value) self.assertEqual(result.labels, expected_labels) + def test_user_set_string(self): + expected_labels = {} + expected_labels[monitoring_infos.NAMESPACE_LABEL] = "stringsetnamespace" + expected_labels[monitoring_infos.NAME_LABEL] = "stringsetname" + + metric = StringSetCell().get_cumulative() + result = monitoring_infos.user_set_string( + 'stringsetnamespace', 'stringsetname', metric) + string_set_value = 
monitoring_infos.extract_string_set_value(result) + + self.assertEqual(set(), string_set_value) + self.assertEqual(result.labels, expected_labels) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py index 4f63aef68232..f46b8d61639b 100644 --- a/sdks/python/apache_beam/ml/gcp/naturallanguageml.py +++ b/sdks/python/apache_beam/ml/gcp/naturallanguageml.py @@ -52,15 +52,13 @@ class Document(object): from_gcs (bool): Whether the content should be interpret as a Google Cloud Storage URI. The default value is :data:`False`. """ - def __init__( self, - content, # type: str - type='PLAIN_TEXT', # type: Union[str, language_v1.Document.Type] - language_hint=None, # type: Optional[str] - encoding='UTF8', # type: Optional[str] - from_gcs=False # type: bool - ): + content: str, + type: Union[str, language_v1.Document.Type] = 'PLAIN_TEXT', + language_hint: Optional[str] = None, + encoding: Optional[str] = 'UTF8', + from_gcs: bool = False): self.content = content self.type = type self.encoding = encoding @@ -68,8 +66,7 @@ def __init__( self.from_gcs = from_gcs @staticmethod - def to_dict(document): - # type: (Document) -> Mapping[str, Optional[str]] + def to_dict(document: 'Document') -> Mapping[str, Optional[str]]: if document.from_gcs: dict_repr = {'gcs_content_uri': document.content} else: @@ -82,11 +79,11 @@ def to_dict(document): @beam.ptransform_fn def AnnotateText( - pcoll, # type: beam.pvalue.PCollection - features, # type: Union[Mapping[str, bool], language_v1.AnnotateTextRequest.Features] - timeout=None, # type: Optional[float] - metadata=None # type: Optional[Sequence[Tuple[str, str]]] -): + pcoll: beam.pvalue.PCollection, + features: Union[Mapping[str, bool], + language_v1.AnnotateTextRequest.Features], + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None): """A :class:`~apache_beam.transforms.ptransform.PTransform` for annotating text using the Google Cloud Natural Language API: https://cloud.google.com/natural-language/docs. 
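Editor's note: a round-trip sketch for the new monitoring-info helpers, mirroring the unit tests above; the namespace, name, and ptransform id are placeholders.

from apache_beam.metrics import monitoring_infos

mi = monitoring_infos.user_set_string(
    'stringsetnamespace', 'stringsetname', {'a', 'b'}, ptransform='step1')

assert monitoring_infos.is_string_set(mi)
assert monitoring_infos.extract_string_set_value(mi) == {'a', 'b'}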
@@ -113,10 +110,10 @@ def AnnotateText( class _AnnotateTextFn(beam.DoFn): def __init__( self, - features, # type: Union[Mapping[str, bool], language_v1.AnnotateTextRequest.Features] - timeout, # type: Optional[float] - metadata=None # type: Optional[Sequence[Tuple[str, str]]] - ): + features: Union[Mapping[str, bool], + language_v1.AnnotateTextRequest.Features], + timeout: Optional[float], + metadata: Optional[Sequence[Tuple[str, str]]] = None): self.features = features self.timeout = timeout self.metadata = metadata @@ -127,8 +124,7 @@ def setup(self): self.client = self._get_api_client() @staticmethod - def _get_api_client(): - # type: () -> language.LanguageServiceClient + def _get_api_client() -> language.LanguageServiceClient: return language.LanguageServiceClient() def process(self, element): diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index 91efcdd76a27..2934a5362910 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -677,7 +677,7 @@ def _deduplicate_device_value(self, device: Optional[str]): self._load_pipeline_args['device'] = 'cpu' else: if is_gpu_available_torch(): - self._load_pipeline_args['device'] = 'cuda:1' + self._load_pipeline_args['device'] = 'cuda:0' else: _LOGGER.warning( "HuggingFaceModelHandler specified a 'GPU' device, " diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py index 5a5d36370f39..fa1649beed26 100644 --- a/sdks/python/apache_beam/options/value_provider.py +++ b/sdks/python/apache_beam/options/value_provider.py @@ -95,7 +95,7 @@ class RuntimeValueProvider(ValueProvider): at graph construction time. """ runtime_options = None - experiments = set() # type: Set[str] + experiments: Set[str] = set() def __init__(self, option_name, value_type, default_value): self.option_name = option_name diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 0858d628a55c..5a400570cf18 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -80,14 +80,14 @@ class PValue(object): (2) Has a transform that can compute the value if executed. (3) Has a value which is meaningful if the transform was executed. """ - - def __init__(self, - pipeline, # type: Pipeline - tag=None, # type: Optional[str] - element_type=None, # type: Optional[Union[type,typehints.TypeConstraint]] - windowing=None, # type: Optional[Windowing] - is_bounded=True, - ): + def __init__( + self, + pipeline: 'Pipeline', + tag: Optional[str] = None, + element_type: Optional[Union[type, 'typehints.TypeConstraint']] = None, + windowing: Optional['Windowing'] = None, + is_bounded=True, + ): """Initializes a PValue with all arguments hidden behind keyword arguments. Args: @@ -101,7 +101,7 @@ def __init__(self, # The AppliedPTransform instance for the application of the PTransform # generating this PValue. The field gets initialized when a transform # gets applied. 
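Editor's note: a hedged sketch of the AnnotateText transform whose signatures were converted above. It needs Google Cloud credentials and the google-cloud-language dependency to actually run; the input text is a placeholder, and the feature flag shown is just one boolean field accepted by the request.

import apache_beam as beam
from apache_beam.ml.gcp import naturallanguageml as nlp

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([nlp.Document('Beam unifies batch and streaming.')])
        # `features` may be a plain Mapping[str, bool], per the type hint above.
        | nlp.AnnotateText({'extract_document_sentiment': True})
        | beam.Map(print))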
- self.producer = None # type: Optional[AppliedPTransform] + self.producer: Optional[AppliedPTransform] = None self.is_bounded = is_bounded if windowing: self._windowing = windowing @@ -152,8 +152,7 @@ def __hash__(self): return hash((self.tag, self.producer)) @property - def windowing(self): - # type: () -> Windowing + def windowing(self) -> 'Windowing': if not hasattr(self, '_windowing'): assert self.producer is not None and self.producer.transform is not None self._windowing = self.producer.transform.get_windowing( @@ -167,9 +166,7 @@ def __reduce_ex__(self, unused_version): return _InvalidUnpickledPCollection, () @staticmethod - def from_(pcoll, is_bounded=None): - # type: (PValue, Optional[bool]) -> PCollection - + def from_(pcoll: PValue, is_bounded: Optional[bool] = None) -> 'PCollection': """Create a PCollection, using another PCollection as a starting point. Transfers relevant attributes. @@ -178,8 +175,8 @@ def from_(pcoll, is_bounded=None): is_bounded = pcoll.is_bounded return PCollection(pcoll.pipeline, is_bounded=is_bounded) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.PCollection + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.PCollection: return beam_runner_api_pb2.PCollection( unique_name=self._unique_name(), coder_id=context.coder_id_from_element_type( @@ -189,8 +186,7 @@ def to_runner_api(self, context): windowing_strategy_id=context.windowing_strategies.get_id( self.windowing)) - def _unique_name(self): - # type: () -> str + def _unique_name(self) -> str: if self.producer: return '%d%s.%s' % ( len(self.producer.full_label), self.producer.full_label, self.tag) @@ -198,8 +194,9 @@ def _unique_name(self): return 'PCollection%s' % id(self) @staticmethod - def from_runner_api(proto, context): - # type: (beam_runner_api_pb2.PCollection, PipelineContext) -> PCollection + def from_runner_api( + proto: beam_runner_api_pb2.PCollection, + context: 'PipelineContext') -> 'PCollection': # Producer and tag will be filled in later, the key point is that the same # object is returned for the same pcollection id. # We pass None for the PCollection's Pipeline to avoid a cycle during @@ -236,14 +233,14 @@ class PDone(PValue): class DoOutputsTuple(object): """An object grouping the multiple outputs of a ParDo or FlatMap transform.""" - - def __init__(self, - pipeline, # type: Pipeline - transform, # type: ParDo - tags, # type: Sequence[str] - main_tag, # type: Optional[str] - allow_unknown_tags=None, # type: Optional[bool] - ): + def __init__( + self, + pipeline: 'Pipeline', + transform: 'ParDo', + tags: Sequence[str], + main_tag: Optional[str], + allow_unknown_tags: Optional[bool] = None, + ): self._pipeline = pipeline self._tags = tags self._main_tag = main_tag @@ -253,9 +250,9 @@ def __init__(self, # The ApplyPTransform instance for the application of the multi FlatMap # generating this value. The field gets initialized when a transform # gets applied. - self.producer = None # type: Optional[AppliedPTransform] + self.producer: Optional[AppliedPTransform] = None # Dictionary of PCollections already associated with tags. 
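Editor's note: for context, a small sketch of the public surface behind DoOutputsTuple and TaggedOutput; the tag names and elements are arbitrary.

import apache_beam as beam
from apache_beam.pvalue import TaggedOutput

class SplitEvenOdd(beam.DoFn):
    def process(self, n):
        if n % 2 == 0:
            yield n                       # goes to the main output
        else:
            yield TaggedOutput('odd', n)  # goes to the 'odd' output

with beam.Pipeline() as p:
    outputs = (
        p
        | beam.Create([1, 2, 3, 4])
        | beam.ParDo(SplitEvenOdd()).with_outputs('odd', main='even'))
    # `outputs` is a DoOutputsTuple; PCollections are looked up by tag.
    evens = outputs.even   # attribute access, via __getattr__
    odds = outputs['odd']  # item access, via __getitem__
    _ = evens | 'PrintEven' >> beam.Map(print)
    _ = odds | 'PrintOdd' >> beam.Map(print)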
- self._pcolls = {} # type: Dict[Optional[str], PCollection] + self._pcolls: Dict[Optional[str], PCollection] = {} def __str__(self): return '<%s>' % self._str_internal() @@ -267,25 +264,21 @@ def _str_internal(self): return '%s main_tag=%s tags=%s transform=%s' % ( self.__class__.__name__, self._main_tag, self._tags, self._transform) - def __iter__(self): - # type: () -> Iterator[PCollection] - + def __iter__(self) -> Iterator[PCollection]: """Iterates over tags returning for each call a (tag, pcollection) pair.""" if self._main_tag is not None: yield self[self._main_tag] for tag in self._tags: yield self[tag] - def __getattr__(self, tag): - # type: (str) -> PCollection + def __getattr__(self, tag: str) -> PCollection: # Special methods which may be accessed before the object is # fully constructed (e.g. in unpickling). if tag[:2] == tag[-2:] == '__': return object.__getattr__(self, tag) # type: ignore return self[tag] - def __getitem__(self, tag): - # type: (Union[int, str, None]) -> PCollection + def __getitem__(self, tag: Union[int, str, None]) -> PCollection: # Accept int tags so that we can look at Partition tags with the # same ints that we used in the partition function. # TODO(gildea): Consider requiring string-based tags everywhere. @@ -337,8 +330,7 @@ class TaggedOutput(object): if it wants to emit on the main output and TaggedOutput objects if it wants to emit a value on a specific tagged output. """ - def __init__(self, tag, value): - # type: (str, Any) -> None + def __init__(self, tag: str, value: Any) -> None: if not isinstance(tag, str): raise TypeError( 'Attempting to create a TaggedOutput with non-string tag %s' % @@ -357,8 +349,7 @@ class AsSideInput(object): options, and should not be instantiated directly. (See instead AsSingleton, AsIter, etc.) """ - def __init__(self, pcoll): - # type: (PCollection) -> None + def __init__(self, pcoll: PCollection) -> None: from apache_beam.transforms import sideinputs self.pvalue = pcoll self._window_mapping_fn = sideinputs.default_window_mapping_fn( @@ -389,8 +380,7 @@ def _windowed_coder(self): # TODO(robertwb): Get rid of _from_runtime_iterable and _view_options # in favor of _side_input_data(). - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> 'SideInputData': view_options = self._view_options() from_runtime_iterable = type(self)._from_runtime_iterable return SideInputData( @@ -398,15 +388,14 @@ def _side_input_data(self): self._window_mapping_fn, lambda iterable: from_runtime_iterable(iterable, view_options)) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.SideInput + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.SideInput: return self._side_input_data().to_runner_api(context) @staticmethod - def from_runner_api(proto, # type: beam_runner_api_pb2.SideInput - context # type: PipelineContext - ): - # type: (...) 
-> _UnpickledSideInput + def from_runner_api( + proto: beam_runner_api_pb2.SideInput, + context: 'PipelineContext') -> '_UnpickledSideInput': return _UnpickledSideInput(SideInputData.from_runner_api(proto, context)) @staticmethod @@ -418,8 +407,7 @@ def requires_keyed_input(self): class _UnpickledSideInput(AsSideInput): - def __init__(self, side_input_data): - # type: (SideInputData) -> None + def __init__(self, side_input_data: 'SideInputData') -> None: self._data = side_input_data self._window_mapping_fn = side_input_data.window_mapping_fn @@ -450,17 +438,17 @@ def _side_input_data(self): class SideInputData(object): """All of the data about a side input except for the bound PCollection.""" - def __init__(self, - access_pattern, # type: str - window_mapping_fn, # type: sideinputs.WindowMappingFn - view_fn - ): + def __init__( + self, + access_pattern: str, + window_mapping_fn: 'sideinputs.WindowMappingFn', + view_fn): self.access_pattern = access_pattern self.window_mapping_fn = window_mapping_fn self.view_fn = view_fn - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.SideInput + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.SideInput: return beam_runner_api_pb2.SideInput( access_pattern=beam_runner_api_pb2.FunctionSpec( urn=self.access_pattern), @@ -472,8 +460,9 @@ def to_runner_api(self, context): payload=pickler.dumps(self.window_mapping_fn))) @staticmethod - def from_runner_api(proto, unused_context): - # type: (beam_runner_api_pb2.SideInput, PipelineContext) -> SideInputData + def from_runner_api( + proto: beam_runner_api_pb2.SideInput, + unused_context: 'PipelineContext') -> 'SideInputData': assert proto.view_fn.urn == python_urns.PICKLED_VIEWFN assert ( proto.window_mapping_fn.urn == python_urns.PICKLED_WINDOW_MAPPING_FN) @@ -501,8 +490,8 @@ class AsSingleton(AsSideInput): """ _NO_DEFAULT = object() - def __init__(self, pcoll, default_value=_NO_DEFAULT): - # type: (PCollection, Any) -> None + def __init__( + self, pcoll: PCollection, default_value: Any = _NO_DEFAULT) -> None: super().__init__(pcoll) self.default_value = default_value @@ -552,8 +541,7 @@ def __repr__(self): def _from_runtime_iterable(it, options): return it - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, @@ -582,8 +570,7 @@ class AsList(AsSideInput): def _from_runtime_iterable(it, options): return list(it) - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, list) @@ -607,8 +594,7 @@ class AsDict(AsSideInput): def _from_runtime_iterable(it, options): return dict(it) - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, dict) @@ -631,8 +617,7 @@ def _from_runtime_iterable(it, options): result[k].append(v) return result - def _side_input_data(self): - # type: () -> SideInputData + def _side_input_data(self) -> SideInputData: return SideInputData( common_urns.side_inputs.MULTIMAP.urn, self._window_mapping_fn, diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py index 7e6a11c4abf8..78c3b64595b0 100644 --- 
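Editor's note: a short sketch of the side-input wrappers (AsSingleton, AsDict, and friends) whose annotations were converted above; the element values are arbitrary.

import apache_beam as beam
from apache_beam.pvalue import AsSingleton, AsDict

with beam.Pipeline() as p:
    threshold = p | 'Threshold' >> beam.Create([1])
    lookup = p | 'Lookup' >> beam.Create([('a', 1), ('b', 2)])
    keys = p | 'Keys' >> beam.Create(['a', 'b', 'a'])

    _ = (
        keys
        | beam.Map(
            lambda k, limit, table: (k, table[k], table[k] > limit),
            limit=AsSingleton(threshold),  # materialized as a single value
            table=AsDict(lookup))          # materialized as a dict
        | beam.Map(print))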
a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py @@ -90,6 +90,10 @@ def _is_counter(metric_result): def _is_distribution(metric_result): return isinstance(metric_result.attempted, DistributionResult) + @staticmethod + def _is_string_set(metric_result): + return isinstance(metric_result.attempted, set) + def _translate_step_name(self, internal_name): """Translate between internal step names (e.g. "s1") and user step names.""" if not self._job_graph: @@ -233,6 +237,8 @@ def _get_metric_value(self, metric): lambda x: x.key == 'sum').value.double_value) return DistributionResult( DistributionData(dist_sum, dist_count, dist_min, dist_max)) + #TODO(https://github.com/apache/beam/issues/31788) support StringSet after + # re-generate apiclient else: return None @@ -277,8 +283,13 @@ def query(self, filter=None): elm for elm in metric_results if self.matches(filter, elm.key) and DataflowMetrics._is_distribution(elm) ], - self.GAUGES: [] - } # TODO(pabloem): Add Gauge support for dataflow. + # TODO(pabloem): Add Gauge support for dataflow. + self.GAUGES: [], + self.STRINGSETS: [ + elm for elm in metric_results if self.matches(filter, elm.key) and + DataflowMetrics._is_string_set(elm) + ] + } def main(argv): diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_client.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_client.py index cc982098797b..e42b180bbecd 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_client.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_client.py @@ -1,8 +1,5 @@ """Generated client library for dataflow version v1b3.""" # NOTE: This file is autogenerated and should not be edited by hand. - -from __future__ import absolute_import - from apitools.base.py import base_api from . import dataflow_v1b3_messages as messages @@ -17,9 +14,7 @@ class DataflowV1b3(base_api.BaseApiClient): _PACKAGE = 'dataflow' _SCOPES = [ 'https://www.googleapis.com/auth/cloud-platform', - 'https://www.googleapis.com/auth/compute', - 'https://www.googleapis.com/auth/compute.readonly', - 'https://www.googleapis.com/auth/userinfo.email' + 'https://www.googleapis.com/auth/compute' ] _VERSION = 'v1b3' _CLIENT_ID = '1042881264118.apps.googleusercontent.com' @@ -75,7 +70,6 @@ def __init__( self.projects_locations_jobs = self.ProjectsLocationsJobsService(self) self.projects_locations_snapshots = self.ProjectsLocationsSnapshotsService( self) - self.projects_locations_sql = self.ProjectsLocationsSqlService(self) self.projects_locations_templates = self.ProjectsLocationsTemplatesService( self) self.projects_locations = self.ProjectsLocationsService(self) @@ -254,7 +248,7 @@ def __init__(self, client): self._upload_configs = {} def Aggregated(self, request, global_params=None): - r"""List the jobs of a project across all regions. + r"""List the jobs of a project across all regions. **Note:** This method doesn't support filtering the list of jobs by name. 
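Editor's note: reading the new bucket back from a finished job looks the same as for the other metric kinds. A sketch, assuming `result` is a PipelineResult whose runner populates string set metrics (full Dataflow backend support is still tracked by the TODO above); the metric name matches the earlier sketches.

from apache_beam.metrics.metric import MetricsFilter

def print_string_set_metrics(result):
    """result: a PipelineResult, e.g. from pipeline.run() after completion."""
    query_result = result.metrics().query(
        MetricsFilter().with_name('distinct_values'))
    for string_set_result in query_result['string_sets']:
        # Committed / attempted values for this kind are plain Python sets.
        print(string_set_result.key, string_set_result.committed)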
Args: request: (DataflowProjectsJobsAggregatedRequest) input message @@ -270,7 +264,8 @@ def Aggregated(self, request, global_params=None): method_id='dataflow.projects.jobs.aggregated', ordered_params=['projectId'], path_params=['projectId'], - query_params=['filter', 'location', 'pageSize', 'pageToken', 'view'], + query_params= + ['filter', 'location', 'name', 'pageSize', 'pageToken', 'view'], relative_path='v1b3/projects/{projectId}/jobs:aggregated', request_field='', request_type_name='DataflowProjectsJobsAggregatedRequest', @@ -279,7 +274,7 @@ def Aggregated(self, request, global_params=None): ) def Create(self, request, global_params=None): - r"""Creates a Cloud Dataflow job. To create a job, we recommend using `projects.locations.jobs.create` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.jobs.create` is not recommended, as your job will always start in `us-central1`. + r"""Creates a Cloud Dataflow job. To create a job, we recommend using `projects.locations.jobs.create` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.jobs.create` is not recommended, as your job will always start in `us-central1`. Do not enter confidential information when you supply string values using the API. Args: request: (DataflowProjectsJobsCreateRequest) input message @@ -354,7 +349,7 @@ def GetMetrics(self, request, global_params=None): ) def List(self, request, global_params=None): - r"""List the jobs of a project. To list the jobs of a project in a region, we recommend using `projects.locations.jobs.list` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). To list the all jobs across all regions, use `projects.jobs.aggregated`. Using `projects.jobs.list` is not recommended, as you can only get the list of jobs that are running in `us-central1`. + r"""List the jobs of a project. To list the jobs of a project in a region, we recommend using `projects.locations.jobs.list` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). To list the all jobs across all regions, use `projects.jobs.aggregated`. Using `projects.jobs.list` is not recommended, because you can only get the list of jobs that are running in `us-central1`. `projects.locations.jobs.list` and `projects.jobs.list` support filtering the list of jobs by name. Filtering by name isn't supported by `projects.jobs.aggregated`. 
Args: request: (DataflowProjectsJobsListRequest) input message @@ -370,7 +365,8 @@ def List(self, request, global_params=None): method_id='dataflow.projects.jobs.list', ordered_params=['projectId'], path_params=['projectId'], - query_params=['filter', 'location', 'pageSize', 'pageToken', 'view'], + query_params= + ['filter', 'location', 'name', 'pageSize', 'pageToken', 'view'], relative_path='v1b3/projects/{projectId}/jobs', request_field='', request_type_name='DataflowProjectsJobsListRequest', @@ -420,7 +416,7 @@ def Update(self, request, global_params=None): method_id='dataflow.projects.jobs.update', ordered_params=['projectId', 'jobId'], path_params=['jobId', 'projectId'], - query_params=['location'], + query_params=['location', 'updateMask'], relative_path='v1b3/projects/{projectId}/jobs/{jobId}', request_field='job', request_type_name='DataflowProjectsJobsUpdateRequest', @@ -611,7 +607,7 @@ def __init__(self, client): self._upload_configs = {} def GetExecutionDetails(self, request, global_params=None): - r"""Request detailed information about the execution status of a stage of the job. + r"""Request detailed information about the execution status of a stage of the job. EXPERIMENTAL. This API is subject to change or removal without notice. Args: request: (DataflowProjectsLocationsJobsStagesGetExecutionDetailsRequest) input message @@ -710,7 +706,7 @@ def __init__(self, client): self._upload_configs = {} def Create(self, request, global_params=None): - r"""Creates a Cloud Dataflow job. To create a job, we recommend using `projects.locations.jobs.create` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.jobs.create` is not recommended, as your job will always start in `us-central1`. + r"""Creates a Cloud Dataflow job. To create a job, we recommend using `projects.locations.jobs.create` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.jobs.create` is not recommended, as your job will always start in `us-central1`. Do not enter confidential information when you supply string values using the API. Args: request: (DataflowProjectsLocationsJobsCreateRequest) input message @@ -761,7 +757,7 @@ def Get(self, request, global_params=None): ) def GetExecutionDetails(self, request, global_params=None): - r"""Request detailed information about the execution status of the job. + r"""Request detailed information about the execution status of the job. EXPERIMENTAL. This API is subject to change or removal without notice. Args: request: (DataflowProjectsLocationsJobsGetExecutionDetailsRequest) input message @@ -814,7 +810,7 @@ def GetMetrics(self, request, global_params=None): ) def List(self, request, global_params=None): - r"""List the jobs of a project. To list the jobs of a project in a region, we recommend using `projects.locations.jobs.list` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). To list the all jobs across all regions, use `projects.jobs.aggregated`. Using `projects.jobs.list` is not recommended, as you can only get the list of jobs that are running in `us-central1`. + r"""List the jobs of a project. To list the jobs of a project in a region, we recommend using `projects.locations.jobs.list` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). To list the all jobs across all regions, use `projects.jobs.aggregated`. 
Using `projects.jobs.list` is not recommended, because you can only get the list of jobs that are running in `us-central1`. `projects.locations.jobs.list` and `projects.jobs.list` support filtering the list of jobs by name. Filtering by name isn't supported by `projects.jobs.aggregated`. Args: request: (DataflowProjectsLocationsJobsListRequest) input message @@ -830,7 +826,7 @@ def List(self, request, global_params=None): method_id='dataflow.projects.locations.jobs.list', ordered_params=['projectId', 'location'], path_params=['location', 'projectId'], - query_params=['filter', 'pageSize', 'pageToken', 'view'], + query_params=['filter', 'name', 'pageSize', 'pageToken', 'view'], relative_path='v1b3/projects/{projectId}/locations/{location}/jobs', request_field='', request_type_name='DataflowProjectsLocationsJobsListRequest', @@ -881,7 +877,7 @@ def Update(self, request, global_params=None): method_id='dataflow.projects.locations.jobs.update', ordered_params=['projectId', 'location', 'jobId'], path_params=['jobId', 'location', 'projectId'], - query_params=[], + query_params=['updateMask'], relative_path= 'v1b3/projects/{projectId}/locations/{location}/jobs/{jobId}', request_field='job', @@ -978,41 +974,6 @@ def List(self, request, global_params=None): supports_download=False, ) - class ProjectsLocationsSqlService(base_api.BaseApiService): - """Service class for the projects_locations_sql resource.""" - - _NAME = 'projects_locations_sql' - - def __init__(self, client): - super(DataflowV1b3.ProjectsLocationsSqlService, self).__init__(client) - self._upload_configs = {} - - def Validate(self, request, global_params=None): - r"""Validates a GoogleSQL query for Cloud Dataflow syntax. Will always confirm the given query parses correctly, and if able to look up schema information from DataCatalog, will validate that the query analyzes properly as well. - - Args: - request: (DataflowProjectsLocationsSqlValidateRequest) input message - global_params: (StandardQueryParameters, default: None) global arguments - Returns: - (ValidateResponse) The response message. - """ - config = self.GetMethodConfig('Validate') - return self._RunMethod(config, request, global_params=global_params) - - Validate.method_config = lambda: base_api.ApiMethodInfo( - http_method='GET', - method_id='dataflow.projects.locations.sql.validate', - ordered_params=['projectId', 'location'], - path_params=['location', 'projectId'], - query_params=['query'], - relative_path= - 'v1b3/projects/{projectId}/locations/{location}/sql:validate', - request_field='', - request_type_name='DataflowProjectsLocationsSqlValidateRequest', - response_type_name='ValidateResponse', - supports_download=False, - ) - class ProjectsLocationsTemplatesService(base_api.BaseApiService): """Service class for the projects_locations_templates resource.""" @@ -1024,7 +985,7 @@ def __init__(self, client): self._upload_configs = {} def Create(self, request, global_params=None): - r"""Creates a Cloud Dataflow job from a template. + r"""Creates a Cloud Dataflow job from a template. Do not enter confidential information when you supply string values using the API. To create a job, we recommend using `projects.locations.templates.create` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.templates.create` is not recommended, because your job will always start in `us-central1`. 
Args: request: (DataflowProjectsLocationsTemplatesCreateRequest) input message @@ -1050,7 +1011,7 @@ def Create(self, request, global_params=None): ) def Get(self, request, global_params=None): - r"""Get the template associated with a template. + r"""Get the template associated with a template. To get the template, we recommend using `projects.locations.templates.get` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.templates.get` is not recommended, because only templates that are running in `us-central1` are retrieved. Args: request: (DataflowProjectsLocationsTemplatesGetRequest) input message @@ -1076,7 +1037,7 @@ def Get(self, request, global_params=None): ) def Launch(self, request, global_params=None): - r"""Launch a template. + r"""Launches a template. To launch a template, we recommend using `projects.locations.templates.launch` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.templates.launch` is not recommended, because jobs launched from the template will always start in `us-central1`. Args: request: (DataflowProjectsLocationsTemplatesLaunchRequest) input message @@ -1210,7 +1171,7 @@ def __init__(self, client): self._upload_configs = {} def Create(self, request, global_params=None): - r"""Creates a Cloud Dataflow job from a template. + r"""Creates a Cloud Dataflow job from a template. Do not enter confidential information when you supply string values using the API. To create a job, we recommend using `projects.locations.templates.create` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.templates.create` is not recommended, because your job will always start in `us-central1`. Args: request: (DataflowProjectsTemplatesCreateRequest) input message @@ -1235,7 +1196,7 @@ def Create(self, request, global_params=None): ) def Get(self, request, global_params=None): - r"""Get the template associated with a template. + r"""Get the template associated with a template. To get the template, we recommend using `projects.locations.templates.get` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.templates.get` is not recommended, because only templates that are running in `us-central1` are retrieved. Args: request: (DataflowProjectsTemplatesGetRequest) input message @@ -1260,7 +1221,7 @@ def Get(self, request, global_params=None): ) def Launch(self, request, global_params=None): - r"""Launch a template. + r"""Launches a template. To launch a template, we recommend using `projects.locations.templates.launch` with a [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints). Using `projects.templates.launch` is not recommended, because jobs launched from the template will always start in `us-central1`. Args: request: (DataflowProjectsTemplatesLaunchRequest) input message diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_messages.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_messages.py index e7cf625250d2..c0bbfa74ac1e 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_messages.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/dataflow_v1b3_messages.py @@ -4,8 +4,6 @@ """ # NOTE: This file is autogenerated and should not be edited by hand. 
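The client changes above add a `name` query parameter to the jobs `List`/`Aggregated` methods and an `updateMask` parameter to the jobs `Update` methods. A minimal usage sketch, assuming an already-authenticated `DataflowV1b3` client object and placeholder project/region/job values; the `RuntimeUpdatableParams` field name and mask path are assumptions from the public API reference, not from this diff:

    from apache_beam.runners.dataflow.internal.clients.dataflow import (
        dataflow_v1b3_messages as messages)

    def list_jobs_by_name(client, project_id, region, job_name):
      # `client` is assumed to be an authenticated DataflowV1b3 instance.
      # The new `name` query parameter restricts the listing to jobs with a
      # matching name; name filtering is not supported by
      # projects.jobs.aggregated.
      request = messages.DataflowProjectsLocationsJobsListRequest(
          projectId=project_id, location=region, name=job_name)
      return list(client.projects_locations_jobs.List(request).jobs)

    def raise_max_workers(client, project_id, region, job_id, new_max):
      # Sketch of an in-flight update via the new `updateMask` parameter.
      # Per the request docstring, requested_state and update_mask are
      # mutually exclusive, so only the masked field is sent here.
      job = messages.Job(
          runtimeUpdatableParams=messages.RuntimeUpdatableParams(
              maxNumWorkers=new_max))
      request = messages.DataflowProjectsLocationsJobsUpdateRequest(
          projectId=project_id, location=region, jobId=job_id, job=job,
          updateMask='runtime_updatable_params.max_num_workers')
      return client.projects_locations_jobs.Update(request)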
-from __future__ import absolute_import - from apitools.base.protorpclite import messages as _messages from apitools.base.py import encoding from apitools.base.py import extra_types @@ -169,6 +167,22 @@ class AlgorithmValueValuesEnum(_messages.Enum): maxNumWorkers = _messages.IntegerField(2, variant=_messages.Variant.INT32) +class Base2Exponent(_messages.Message): + r"""Exponential buckets where the growth factor between buckets is + `2**(2**-scale)`. e.g. for `scale=1` growth factor is + `2**(2**(-1))=sqrt(2)`. `n` buckets will have the following boundaries. - + 0th: [0, gf) - i in [1, n-1]: [gf^(i), gf^(i+1)) + + Fields: + numberOfBuckets: Must be greater than 0. + scale: Must be between -3 and 3. This forces the growth factor of the + bucket boundaries to be between `2^(1/8)` and `256`. + """ + + numberOfBuckets = _messages.IntegerField(1, variant=_messages.Variant.INT32) + scale = _messages.IntegerField(2, variant=_messages.Variant.INT32) + + class BigQueryIODetails(_messages.Message): r"""Metadata for a BigQuery connector used by the job. @@ -199,6 +213,18 @@ class BigTableIODetails(_messages.Message): tableId = _messages.StringField(3) +class BucketOptions(_messages.Message): + r"""`BucketOptions` describes the bucket boundaries used in the histogram. + + Fields: + exponential: Bucket boundaries grow exponentially. + linear: Bucket boundaries grow linearly. + """ + + exponential = _messages.MessageField('Base2Exponent', 1) + linear = _messages.MessageField('Linear', 2) + + class CPUTime(_messages.Message): r"""Modeled after information exposed by /proc/stat. @@ -288,6 +314,12 @@ class ContainerSpec(_messages.Message): Fields: defaultEnvironment: Default runtime environment for the job. image: Name of the docker container image. E.g., gcr.io/project/some-image + imageRepositoryCertPath: Cloud Storage path to self-signed certificate of + private registry. + imageRepositoryPasswordSecretId: Secret Manager secret id for password to + authenticate to private registry. + imageRepositoryUsernameSecretId: Secret Manager secret id for username to + authenticate to private registry. metadata: Metadata describing a template including description and validation rules. sdkInfo: Required. SDK info of the Flex Template. @@ -296,8 +328,11 @@ class ContainerSpec(_messages.Message): defaultEnvironment = _messages.MessageField( 'FlexTemplateRuntimeEnvironment', 1) image = _messages.StringField(2) - metadata = _messages.MessageField('TemplateMetadata', 3) - sdkInfo = _messages.MessageField('SDKInfo', 4) + imageRepositoryCertPath = _messages.StringField(3) + imageRepositoryPasswordSecretId = _messages.StringField(4) + imageRepositoryUsernameSecretId = _messages.StringField(5) + metadata = _messages.MessageField('TemplateMetadata', 6) + sdkInfo = _messages.MessageField('SDKInfo', 7) class CounterMetadata(_messages.Message): @@ -568,6 +603,94 @@ class DataDiskAssignment(_messages.Message): vmInstance = _messages.StringField(2) +class DataSamplingConfig(_messages.Message): + r"""Configuration options for sampling elements. + + Enums: + BehaviorsValueListEntryValuesEnum: + + Fields: + behaviors: List of given sampling behaviors to enable. For example, + specifying behaviors = [ALWAYS_ON] samples in-flight elements but does + not sample exceptions. Can be used to specify multiple behaviors like, + behaviors = [ALWAYS_ON, EXCEPTIONS] for specifying periodic sampling and + exception sampling. If DISABLED is in the list, then sampling will be + disabled and ignore the other given behaviors. 
Ordering does not matter. + """ + class BehaviorsValueListEntryValuesEnum(_messages.Enum): + r"""BehaviorsValueListEntryValuesEnum enum type. + + Values: + DATA_SAMPLING_BEHAVIOR_UNSPECIFIED: If given, has no effect on sampling + behavior. Used as an unknown or unset sentinel value. + DISABLED: When given, disables element sampling. Has same behavior as + not setting the behavior. + ALWAYS_ON: When given, enables sampling in-flight from all PCollections. + EXCEPTIONS: When given, enables sampling input elements when a user- + defined DoFn causes an exception. + """ + DATA_SAMPLING_BEHAVIOR_UNSPECIFIED = 0 + DISABLED = 1 + ALWAYS_ON = 2 + EXCEPTIONS = 3 + + behaviors = _messages.EnumField( + 'BehaviorsValueListEntryValuesEnum', 1, repeated=True) + + +class DataSamplingReport(_messages.Message): + r"""Contains per-worker telemetry about the data sampling feature. + + Fields: + bytesWrittenDelta: Optional. Delta of bytes written to file from previous + report. + elementsSampledBytes: Optional. Delta of bytes sampled from previous + report. + elementsSampledCount: Optional. Delta of number of elements sampled from + previous report. + exceptionsSampledCount: Optional. Delta of number of samples taken from + user code exceptions from previous report. + pcollectionsSampledCount: Optional. Delta of number of PCollections + sampled from previous report. + persistenceErrorsCount: Optional. Delta of errors counts from persisting + the samples from previous report. + translationErrorsCount: Optional. Delta of errors counts from retrieving, + or translating the samples from previous report. + """ + + bytesWrittenDelta = _messages.IntegerField(1) + elementsSampledBytes = _messages.IntegerField(2) + elementsSampledCount = _messages.IntegerField(3) + exceptionsSampledCount = _messages.IntegerField(4) + pcollectionsSampledCount = _messages.IntegerField(5) + persistenceErrorsCount = _messages.IntegerField(6) + translationErrorsCount = _messages.IntegerField(7) + + +class DataflowHistogramValue(_messages.Message): + r"""Summary statistics for a population of values. HistogramValue contains a + sequence of buckets and gives a count of values that fall into each bucket. + Bucket boundares are defined by a formula and bucket widths are either fixed + or exponentially increasing. + + Fields: + bucketCounts: Optional. The number of values in each bucket of the + histogram, as described in `bucket_options`. `bucket_counts` should + contain N values, where N is the number of buckets specified in + `bucket_options`. If `bucket_counts` has fewer than N values, the + remaining values are assumed to be 0. + bucketOptions: Describes the bucket boundaries used in the histogram. + count: Number of values recorded in this histogram. + outlierStats: Statistics on the values recorded in the histogram that fall + out of the bucket boundaries. + """ + + bucketCounts = _messages.IntegerField(1, repeated=True) + bucketOptions = _messages.MessageField('BucketOptions', 2) + count = _messages.IntegerField(3) + outlierStats = _messages.MessageField('OutlierStats', 4) + + class DataflowProjectsDeleteSnapshotsRequest(_messages.Message): r"""A DataflowProjectsDeleteSnapshotsRequest object. @@ -596,6 +719,7 @@ class DataflowProjectsJobsAggregatedRequest(_messages.Message): location: The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains this job. + name: Optional. The job name. pageSize: If there are many jobs, limit response to at most this many. 
The actual number of jobs returned will be the lesser of max_responses and an unspecified server-defined limit. @@ -635,7 +759,12 @@ class ViewValueValuesEnum(_messages.Enum): JOB_VIEW_SUMMARY: Request summary information only: Project ID, Job ID, job name, job type, job status, start/end time, and Cloud SDK version details. - JOB_VIEW_ALL: Request all information available for this job. + JOB_VIEW_ALL: Request all information available for this job. When the + job is in `JOB_STATE_PENDING`, the job has been created but is not yet + running, and not all job information is available. For complete job + information, wait until the job in is `JOB_STATE_RUNNING`. For more + information, see [JobState](https://cloud.google.com/dataflow/docs/ref + erence/rest/v1b3/projects.jobs#jobstate). JOB_VIEW_DESCRIPTION: Request summary info and limited job description data for steps, labels and environment. """ @@ -646,10 +775,11 @@ class ViewValueValuesEnum(_messages.Enum): filter = _messages.EnumField('FilterValueValuesEnum', 1) location = _messages.StringField(2) - pageSize = _messages.IntegerField(3, variant=_messages.Variant.INT32) - pageToken = _messages.StringField(4) - projectId = _messages.StringField(5, required=True) - view = _messages.EnumField('ViewValueValuesEnum', 6) + name = _messages.StringField(3) + pageSize = _messages.IntegerField(4, variant=_messages.Variant.INT32) + pageToken = _messages.StringField(5) + projectId = _messages.StringField(6, required=True) + view = _messages.EnumField('ViewValueValuesEnum', 7) class DataflowProjectsJobsCreateRequest(_messages.Message): @@ -677,7 +807,12 @@ class ViewValueValuesEnum(_messages.Enum): JOB_VIEW_SUMMARY: Request summary information only: Project ID, Job ID, job name, job type, job status, start/end time, and Cloud SDK version details. - JOB_VIEW_ALL: Request all information available for this job. + JOB_VIEW_ALL: Request all information available for this job. When the + job is in `JOB_STATE_PENDING`, the job has been created but is not yet + running, and not all job information is available. For complete job + information, wait until the job in is `JOB_STATE_RUNNING`. For more + information, see [JobState](https://cloud.google.com/dataflow/docs/ref + erence/rest/v1b3/projects.jobs#jobstate). JOB_VIEW_DESCRIPTION: Request summary info and limited job description data for steps, labels and environment. """ @@ -766,7 +901,12 @@ class ViewValueValuesEnum(_messages.Enum): JOB_VIEW_SUMMARY: Request summary information only: Project ID, Job ID, job name, job type, job status, start/end time, and Cloud SDK version details. - JOB_VIEW_ALL: Request all information available for this job. + JOB_VIEW_ALL: Request all information available for this job. When the + job is in `JOB_STATE_PENDING`, the job has been created but is not yet + running, and not all job information is available. For complete job + information, wait until the job in is `JOB_STATE_RUNNING`. For more + information, see [JobState](https://cloud.google.com/dataflow/docs/ref + erence/rest/v1b3/projects.jobs#jobstate). JOB_VIEW_DESCRIPTION: Request summary info and limited job description data for steps, labels and environment. """ @@ -794,6 +934,7 @@ class DataflowProjectsJobsListRequest(_messages.Message): location: The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains this job. + name: Optional. The job name. pageSize: If there are many jobs, limit response to at most this many. 
The actual number of jobs returned will be the lesser of max_responses and an unspecified server-defined limit. @@ -833,7 +974,12 @@ class ViewValueValuesEnum(_messages.Enum): JOB_VIEW_SUMMARY: Request summary information only: Project ID, Job ID, job name, job type, job status, start/end time, and Cloud SDK version details. - JOB_VIEW_ALL: Request all information available for this job. + JOB_VIEW_ALL: Request all information available for this job. When the + job is in `JOB_STATE_PENDING`, the job has been created but is not yet + running, and not all job information is available. For complete job + information, wait until the job in is `JOB_STATE_RUNNING`. For more + information, see [JobState](https://cloud.google.com/dataflow/docs/ref + erence/rest/v1b3/projects.jobs#jobstate). JOB_VIEW_DESCRIPTION: Request summary info and limited job description data for steps, labels and environment. """ @@ -844,10 +990,11 @@ class ViewValueValuesEnum(_messages.Enum): filter = _messages.EnumField('FilterValueValuesEnum', 1) location = _messages.StringField(2) - pageSize = _messages.IntegerField(3, variant=_messages.Variant.INT32) - pageToken = _messages.StringField(4) - projectId = _messages.StringField(5, required=True) - view = _messages.EnumField('ViewValueValuesEnum', 6) + name = _messages.StringField(3) + pageSize = _messages.IntegerField(4, variant=_messages.Variant.INT32) + pageToken = _messages.StringField(5) + projectId = _messages.StringField(6, required=True) + view = _messages.EnumField('ViewValueValuesEnum', 7) class DataflowProjectsJobsMessagesListRequest(_messages.Message): @@ -947,12 +1094,19 @@ class DataflowProjectsJobsUpdateRequest(_messages.Message): (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains this job. projectId: The ID of the Cloud Platform project that the job belongs to. + updateMask: The list of fields to update relative to Job. If empty, only + RequestedJobState will be considered for update. If the FieldMask is not + empty and RequestedJobState is none/empty, The fields specified in the + update mask will be the only ones considered for update. If both + RequestedJobState and update_mask are specified, an error will be + returned as we cannot update both state and mask. """ job = _messages.MessageField('Job', 1) jobId = _messages.StringField(2, required=True) location = _messages.StringField(3) projectId = _messages.StringField(4, required=True) + updateMask = _messages.StringField(5) class DataflowProjectsJobsWorkItemsLeaseRequest(_messages.Message): @@ -1030,7 +1184,12 @@ class ViewValueValuesEnum(_messages.Enum): JOB_VIEW_SUMMARY: Request summary information only: Project ID, Job ID, job name, job type, job status, start/end time, and Cloud SDK version details. - JOB_VIEW_ALL: Request all information available for this job. + JOB_VIEW_ALL: Request all information available for this job. When the + job is in `JOB_STATE_PENDING`, the job has been created but is not yet + running, and not all job information is available. For complete job + information, wait until the job in is `JOB_STATE_RUNNING`. For more + information, see [JobState](https://cloud.google.com/dataflow/docs/ref + erence/rest/v1b3/projects.jobs#jobstate). JOB_VIEW_DESCRIPTION: Request summary info and limited job description data for steps, labels and environment. 
""" @@ -1152,7 +1311,12 @@ class ViewValueValuesEnum(_messages.Enum): JOB_VIEW_SUMMARY: Request summary information only: Project ID, Job ID, job name, job type, job status, start/end time, and Cloud SDK version details. - JOB_VIEW_ALL: Request all information available for this job. + JOB_VIEW_ALL: Request all information available for this job. When the + job is in `JOB_STATE_PENDING`, the job has been created but is not yet + running, and not all job information is available. For complete job + information, wait until the job in is `JOB_STATE_RUNNING`. For more + information, see [JobState](https://cloud.google.com/dataflow/docs/ref + erence/rest/v1b3/projects.jobs#jobstate). JOB_VIEW_DESCRIPTION: Request summary info and limited job description data for steps, labels and environment. """ @@ -1180,6 +1344,7 @@ class DataflowProjectsLocationsJobsListRequest(_messages.Message): location: The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains this job. + name: Optional. The job name. pageSize: If there are many jobs, limit response to at most this many. The actual number of jobs returned will be the lesser of max_responses and an unspecified server-defined limit. @@ -1219,7 +1384,12 @@ class ViewValueValuesEnum(_messages.Enum): JOB_VIEW_SUMMARY: Request summary information only: Project ID, Job ID, job name, job type, job status, start/end time, and Cloud SDK version details. - JOB_VIEW_ALL: Request all information available for this job. + JOB_VIEW_ALL: Request all information available for this job. When the + job is in `JOB_STATE_PENDING`, the job has been created but is not yet + running, and not all job information is available. For complete job + information, wait until the job in is `JOB_STATE_RUNNING`. For more + information, see [JobState](https://cloud.google.com/dataflow/docs/ref + erence/rest/v1b3/projects.jobs#jobstate). JOB_VIEW_DESCRIPTION: Request summary info and limited job description data for steps, labels and environment. """ @@ -1230,10 +1400,11 @@ class ViewValueValuesEnum(_messages.Enum): filter = _messages.EnumField('FilterValueValuesEnum', 1) location = _messages.StringField(2, required=True) - pageSize = _messages.IntegerField(3, variant=_messages.Variant.INT32) - pageToken = _messages.StringField(4) - projectId = _messages.StringField(5, required=True) - view = _messages.EnumField('ViewValueValuesEnum', 6) + name = _messages.StringField(3) + pageSize = _messages.IntegerField(4, variant=_messages.Variant.INT32) + pageToken = _messages.StringField(5) + projectId = _messages.StringField(6, required=True) + view = _messages.EnumField('ViewValueValuesEnum', 7) class DataflowProjectsLocationsJobsMessagesListRequest(_messages.Message): @@ -1380,12 +1551,19 @@ class DataflowProjectsLocationsJobsUpdateRequest(_messages.Message): (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains this job. projectId: The ID of the Cloud Platform project that the job belongs to. + updateMask: The list of fields to update relative to Job. If empty, only + RequestedJobState will be considered for update. If the FieldMask is not + empty and RequestedJobState is none/empty, The fields specified in the + update mask will be the only ones considered for update. If both + RequestedJobState and update_mask are specified, an error will be + returned as we cannot update both state and mask. 
""" job = _messages.MessageField('Job', 1) jobId = _messages.StringField(2, required=True) location = _messages.StringField(3, required=True) projectId = _messages.StringField(4, required=True) + updateMask = _messages.StringField(5) class DataflowProjectsLocationsJobsWorkItemsLeaseRequest(_messages.Message): @@ -1472,23 +1650,6 @@ class DataflowProjectsLocationsSnapshotsListRequest(_messages.Message): projectId = _messages.StringField(3, required=True) -class DataflowProjectsLocationsSqlValidateRequest(_messages.Message): - r"""A DataflowProjectsLocationsSqlValidateRequest object. - - Fields: - location: The [regional endpoint] - (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) to - which to direct the request. - projectId: Required. The ID of the Cloud Platform project that the job - belongs to. - query: The sql query to validate. - """ - - location = _messages.StringField(1, required=True) - projectId = _messages.StringField(2, required=True) - query = _messages.StringField(3) - - class DataflowProjectsLocationsTemplatesCreateRequest(_messages.Message): r"""A DataflowProjectsLocationsTemplatesCreateRequest object. @@ -1543,13 +1704,13 @@ class DataflowProjectsLocationsTemplatesLaunchRequest(_messages.Message): r"""A DataflowProjectsLocationsTemplatesLaunchRequest object. Fields: - dynamicTemplate_gcsPath: Path to dynamic template spec file on Cloud - Storage. The file must be a Json serialized DynamicTemplateFieSpec - object. + dynamicTemplate_gcsPath: Path to the dynamic template specification file + on Cloud Storage. The file must be a JSON serialized + `DynamicTemplateFileSpec` object. dynamicTemplate_stagingLocation: Cloud Storage path for staging dependencies. Must be a valid Cloud Storage URL, beginning with `gs://`. - gcsPath: A Cloud Storage path to the template from which to create the - job. Must be valid Cloud Storage URL, beginning with 'gs://'. + gcsPath: A Cloud Storage path to the template to use to create the job. + Must be valid Cloud Storage URL, beginning with `gs://`. launchTemplateParameters: A LaunchTemplateParameters resource to be passed as the request body. location: The [regional endpoint] @@ -1668,13 +1829,13 @@ class DataflowProjectsTemplatesLaunchRequest(_messages.Message): r"""A DataflowProjectsTemplatesLaunchRequest object. Fields: - dynamicTemplate_gcsPath: Path to dynamic template spec file on Cloud - Storage. The file must be a Json serialized DynamicTemplateFieSpec - object. + dynamicTemplate_gcsPath: Path to the dynamic template specification file + on Cloud Storage. The file must be a JSON serialized + `DynamicTemplateFileSpec` object. dynamicTemplate_stagingLocation: Cloud Storage path for staging dependencies. Must be a valid Cloud Storage URL, beginning with `gs://`. - gcsPath: A Cloud Storage path to the template from which to create the - job. Must be valid Cloud Storage URL, beginning with 'gs://'. + gcsPath: A Cloud Storage path to the template to use to create the job. + Must be valid Cloud Storage URL, beginning with `gs://`. launchTemplateParameters: A LaunchTemplateParameters resource to be passed as the request body. location: The [regional endpoint] @@ -1726,11 +1887,14 @@ class DebugOptions(_messages.Message): r"""Describes any options that have an effect on the debugging of pipelines. Fields: - enableHotKeyLogging: When true, enables the logging of the literal hot key - to the user's Cloud Logging. + dataSampling: Configuration options for sampling elements from a running + pipeline. + enableHotKeyLogging: Optional. 
When true, enables the logging of the + literal hot key to the user's Cloud Logging. """ - enableHotKeyLogging = _messages.BooleanField(1) + dataSampling = _messages.MessageField('DataSamplingConfig', 1) + enableHotKeyLogging = _messages.BooleanField(2) class DeleteSnapshotResponse(_messages.Message): @@ -1883,10 +2047,17 @@ class Environment(_messages.Message): r"""Describes the environment in which a Dataflow Job runs. Enums: - FlexResourceSchedulingGoalValueValuesEnum: Which Flexible Resource - Scheduling mode to run in. + FlexResourceSchedulingGoalValueValuesEnum: Optional. Which Flexible + Resource Scheduling mode to run in. ShuffleModeValueValuesEnum: Output only. The shuffle mode used for the job. + StreamingModeValueValuesEnum: Optional. Specifies the Streaming Engine + message processing guarantees. Reduces cost and latency but might result + in duplicate messages committed to storage. Designed to run simple + mapping streaming ETL jobs at the lowest cost. For example, Change Data + Capture (CDC) to BigQuery is a canonical use case. For more information, + see [Set the pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). Messages: InternalExperimentsValue: Experimental settings. @@ -1903,31 +2074,38 @@ class Environment(_messages.Message): unknown or unspecified, the service will attempt to choose a reasonable default. This should be in the form of the API service name, e.g. "compute.googleapis.com". - dataset: The dataset for the current project where various workflow - related tables are stored. The supported resource type is: Google - BigQuery: bigquery.googleapis.com/{dataset} - debugOptions: Any debugging options to be supplied to the job. + dataset: Optional. The dataset for the current project where various + workflow related tables are stored. The supported resource type is: + Google BigQuery: bigquery.googleapis.com/{dataset} + debugOptions: Optional. Any debugging options to be supplied to the job. experiments: The list of experiments to enable. This field should be used for SDK related experiments and not for service related experiments. The proper field for service related experiments is service_options. - flexResourceSchedulingGoal: Which Flexible Resource Scheduling mode to run - in. + flexResourceSchedulingGoal: Optional. Which Flexible Resource Scheduling + mode to run in. internalExperiments: Experimental settings. sdkPipelineOptions: The Cloud Dataflow SDK pipeline options specified by the user. These options are passed through the service and are used to recreate the SDK pipeline options on the worker in a language agnostic and platform independent way. - serviceAccountEmail: Identity to run virtual machines as. Defaults to the - default account. - serviceKmsKeyName: If set, contains the Cloud KMS key identifier used to - encrypt data at rest, AKA a Customer Managed Encryption Key (CMEK). - Format: + serviceAccountEmail: Optional. Identity to run virtual machines as. + Defaults to the default account. + serviceKmsKeyName: Optional. If set, contains the Cloud KMS key identifier + used to encrypt data at rest, AKA a Customer Managed Encryption Key + (CMEK). Format: projects/PROJECT_ID/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY - serviceOptions: The list of service options to enable. This field should - be used for service related experiments only. These experiments, when - graduating to GA, should be replaced by dedicated fields or become - default (i.e. always on). + serviceOptions: Optional. 
The list of service options to enable. This + field should be used for service related experiments only. These + experiments, when graduating to GA, should be replaced by dedicated + fields or become default (i.e. always on). shuffleMode: Output only. The shuffle mode used for the job. + streamingMode: Optional. Specifies the Streaming Engine message processing + guarantees. Reduces cost and latency but might result in duplicate + messages committed to storage. Designed to run simple mapping streaming + ETL jobs at the lowest cost. For example, Change Data Capture (CDC) to + BigQuery is a canonical use case. For more information, see [Set the + pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). tempStoragePrefix: The prefix of the resources the system should use for temporary storage. The system will append the suffix "/temp-{JOBNAME} to this resource prefix, where {JOBNAME} is the value of the job_name @@ -1937,17 +2115,19 @@ class Environment(_messages.Message): The supported resource type is: Google Cloud Storage: storage.googleapis.com/{bucket}/{object} bucket.storage.googleapis.com/{object} + useStreamingEngineResourceBasedBilling: Output only. Whether the job uses + the Streaming Engine resource-based billing model. userAgent: A description of the process that generated the request. version: A structure describing which components and their versions of the service are required in order to run the job. workerPools: The worker pools. At least one "harness" worker pool must be specified in order for the job to have workers. - workerRegion: The Compute Engine region + workerRegion: Optional. The Compute Engine region (https://cloud.google.com/compute/docs/regions-zones/regions-zones) in which worker processing should occur, e.g. "us-west1". Mutually exclusive with worker_zone. If neither worker_region nor worker_zone is specified, default to the control plane's region. - workerZone: The Compute Engine zone + workerZone: Optional. The Compute Engine zone (https://cloud.google.com/compute/docs/regions-zones/regions-zones) in which worker processing should occur, e.g. "us-west1-a". Mutually exclusive with worker_region. If neither worker_region nor worker_zone @@ -1955,7 +2135,7 @@ class Environment(_messages.Message): available capacity. """ class FlexResourceSchedulingGoalValueValuesEnum(_messages.Enum): - r"""Which Flexible Resource Scheduling mode to run in. + r"""Optional. Which Flexible Resource Scheduling mode to run in. Values: FLEXRS_UNSPECIFIED: Run in the default mode. @@ -1978,6 +2158,29 @@ class ShuffleModeValueValuesEnum(_messages.Enum): VM_BASED = 1 SERVICE_BASED = 2 + class StreamingModeValueValuesEnum(_messages.Enum): + r"""Optional. Specifies the Streaming Engine message processing + guarantees. Reduces cost and latency but might result in duplicate + messages committed to storage. Designed to run simple mapping streaming + ETL jobs at the lowest cost. For example, Change Data Capture (CDC) to + BigQuery is a canonical use case. For more information, see [Set the + pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). + + Values: + STREAMING_MODE_UNSPECIFIED: Run in the default mode. + STREAMING_MODE_EXACTLY_ONCE: In this mode, message deduplication is + performed against persistent state to make sure each message is + processed and committed to storage exactly once. + STREAMING_MODE_AT_LEAST_ONCE: Message deduplication is not performed. 
+ Messages might be processed multiple times, and the results are + applied multiple times. Note: Setting this value also enables + Streaming Engine and Streaming Engine resource-based billing. + """ + STREAMING_MODE_UNSPECIFIED = 0 + STREAMING_MODE_EXACTLY_ONCE = 1 + STREAMING_MODE_AT_LEAST_ONCE = 2 + @encoding.MapUnrecognizedFields('additionalProperties') class InternalExperimentsValue(_messages.Message): r"""Experimental settings. @@ -2093,12 +2296,14 @@ class AdditionalProperty(_messages.Message): serviceKmsKeyName = _messages.StringField(9) serviceOptions = _messages.StringField(10, repeated=True) shuffleMode = _messages.EnumField('ShuffleModeValueValuesEnum', 11) - tempStoragePrefix = _messages.StringField(12) - userAgent = _messages.MessageField('UserAgentValue', 13) - version = _messages.MessageField('VersionValue', 14) - workerPools = _messages.MessageField('WorkerPool', 15, repeated=True) - workerRegion = _messages.StringField(16) - workerZone = _messages.StringField(17) + streamingMode = _messages.EnumField('StreamingModeValueValuesEnum', 12) + tempStoragePrefix = _messages.StringField(13) + useStreamingEngineResourceBasedBilling = _messages.BooleanField(14) + userAgent = _messages.MessageField('UserAgentValue', 15) + version = _messages.MessageField('VersionValue', 16) + workerPools = _messages.MessageField('WorkerPool', 17, repeated=True) + workerRegion = _messages.StringField(18) + workerZone = _messages.StringField(19) class ExecutionStageState(_messages.Message): @@ -2281,12 +2486,20 @@ class FlattenInstruction(_messages.Message): class FlexTemplateRuntimeEnvironment(_messages.Message): r"""The environment values to be set at runtime for flex template. + LINT.IfChange Enums: AutoscalingAlgorithmValueValuesEnum: The algorithm to use for autoscaling FlexrsGoalValueValuesEnum: Set FlexRS goal for the job. https://cloud.google.com/dataflow/docs/guides/flexrs IpConfigurationValueValuesEnum: Configuration for VM IPs. + StreamingModeValueValuesEnum: Optional. Specifies the Streaming Engine + message processing guarantees. Reduces cost and latency but might result + in duplicate messages committed to storage. Designed to run simple + mapping streaming ETL jobs at the lowest cost. For example, Change Data + Capture (CDC) to BigQuery is a canonical use case. For more information, + see [Set the pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). Messages: AdditionalUserLabelsValue: Additional user labels to be specified for the @@ -2304,10 +2517,15 @@ class FlexTemplateRuntimeEnvironment(_messages.Message): value pairs. Example: { "name": "wrench", "mass": "1kg", "count": "3" }. autoscalingAlgorithm: The algorithm to use for autoscaling diskSizeGb: Worker disk size, in gigabytes. - dumpHeapOnOom: If true, save a heap dump before killing a thread or - process which is GC thrashing or out of memory. The location of the heap - file will either be echoed back to the user, or the user will be given - the opportunity to download the heap file. + dumpHeapOnOom: If true, when processing time is spent almost entirely on + garbage collection (GC), saves a heap dump before ending the thread or + process. If false, ends the thread or process without saving a heap + dump. Does not save a heap dump when the Java Virtual Machine (JVM) has + an out of memory error during processing. The location of the heap file + is either echoed back to the user, or the user is given the opportunity + to download the heap file. 
+ enableLauncherVmSerialPortLogging: If true serial port logging will be + enabled for the launcher VM. enableStreamingEngine: Whether to enable Streaming Engine for the job. flexrsGoal: Set FlexRS goal for the job. https://cloud.google.com/dataflow/docs/guides/flexrs @@ -2325,8 +2543,8 @@ class FlexTemplateRuntimeEnvironment(_messages.Message): numWorkers: The initial number of Google Compute Engine instances for the job. saveHeapDumpsToGcsPath: Cloud Storage bucket (directory) to upload heap - dumps to the given location. Enabling this implies that heap dumps - should be generated on OOM (dump_heap_on_oom is set to true). + dumps to. Enabling this field implies that `dump_heap_on_oom` is set to + true. sdkContainerImage: Docker registry location of container image to use for the 'worker harness. Default is the container for the version of the SDK. Note this field is only valid for portable pipelines. @@ -2334,6 +2552,13 @@ class FlexTemplateRuntimeEnvironment(_messages.Message): job as. stagingLocation: The Cloud Storage path for staging local files. Must be a valid Cloud Storage URL, beginning with `gs://`. + streamingMode: Optional. Specifies the Streaming Engine message processing + guarantees. Reduces cost and latency but might result in duplicate + messages committed to storage. Designed to run simple mapping streaming + ETL jobs at the lowest cost. For example, Change Data Capture (CDC) to + BigQuery is a canonical use case. For more information, see [Set the + pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). subnetwork: Subnetwork to which VMs will be assigned, if desired. You can specify a subnetwork using either a complete URL or an abbreviated path. Expected to be of the form "https://www.googleapis.com/compute/v1/projec @@ -2397,6 +2622,29 @@ class IpConfigurationValueValuesEnum(_messages.Enum): WORKER_IP_PUBLIC = 1 WORKER_IP_PRIVATE = 2 + class StreamingModeValueValuesEnum(_messages.Enum): + r"""Optional. Specifies the Streaming Engine message processing + guarantees. Reduces cost and latency but might result in duplicate + messages committed to storage. Designed to run simple mapping streaming + ETL jobs at the lowest cost. For example, Change Data Capture (CDC) to + BigQuery is a canonical use case. For more information, see [Set the + pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). + + Values: + STREAMING_MODE_UNSPECIFIED: Run in the default mode. + STREAMING_MODE_EXACTLY_ONCE: In this mode, message deduplication is + performed against persistent state to make sure each message is + processed and committed to storage exactly once. + STREAMING_MODE_AT_LEAST_ONCE: Message deduplication is not performed. + Messages might be processed multiple times, and the results are + applied multiple times. Note: Setting this value also enables + Streaming Engine and Streaming Engine resource-based billing. + """ + STREAMING_MODE_UNSPECIFIED = 0 + STREAMING_MODE_EXACTLY_ONCE = 1 + STREAMING_MODE_AT_LEAST_ONCE = 2 + @encoding.MapUnrecognizedFields('additionalProperties') class AdditionalUserLabelsValue(_messages.Message): r"""Additional user labels to be specified for the job. 
Keys and values @@ -2433,24 +2681,26 @@ class AdditionalProperty(_messages.Message): 'AutoscalingAlgorithmValueValuesEnum', 3) diskSizeGb = _messages.IntegerField(4, variant=_messages.Variant.INT32) dumpHeapOnOom = _messages.BooleanField(5) - enableStreamingEngine = _messages.BooleanField(6) - flexrsGoal = _messages.EnumField('FlexrsGoalValueValuesEnum', 7) - ipConfiguration = _messages.EnumField('IpConfigurationValueValuesEnum', 8) - kmsKeyName = _messages.StringField(9) - launcherMachineType = _messages.StringField(10) - machineType = _messages.StringField(11) - maxWorkers = _messages.IntegerField(12, variant=_messages.Variant.INT32) - network = _messages.StringField(13) - numWorkers = _messages.IntegerField(14, variant=_messages.Variant.INT32) - saveHeapDumpsToGcsPath = _messages.StringField(15) - sdkContainerImage = _messages.StringField(16) - serviceAccountEmail = _messages.StringField(17) - stagingLocation = _messages.StringField(18) - subnetwork = _messages.StringField(19) - tempLocation = _messages.StringField(20) - workerRegion = _messages.StringField(21) - workerZone = _messages.StringField(22) - zone = _messages.StringField(23) + enableLauncherVmSerialPortLogging = _messages.BooleanField(6) + enableStreamingEngine = _messages.BooleanField(7) + flexrsGoal = _messages.EnumField('FlexrsGoalValueValuesEnum', 8) + ipConfiguration = _messages.EnumField('IpConfigurationValueValuesEnum', 9) + kmsKeyName = _messages.StringField(10) + launcherMachineType = _messages.StringField(11) + machineType = _messages.StringField(12) + maxWorkers = _messages.IntegerField(13, variant=_messages.Variant.INT32) + network = _messages.StringField(14) + numWorkers = _messages.IntegerField(15, variant=_messages.Variant.INT32) + saveHeapDumpsToGcsPath = _messages.StringField(16) + sdkContainerImage = _messages.StringField(17) + serviceAccountEmail = _messages.StringField(18) + stagingLocation = _messages.StringField(19) + streamingMode = _messages.EnumField('StreamingModeValueValuesEnum', 20) + subnetwork = _messages.StringField(21) + tempLocation = _messages.StringField(22) + workerRegion = _messages.StringField(23) + workerZone = _messages.StringField(24) + zone = _messages.StringField(25) class FloatingPointList(_messages.Message): @@ -2556,6 +2806,46 @@ class Histogram(_messages.Message): firstBucketOffset = _messages.IntegerField(2, variant=_messages.Variant.INT32) +class HotKeyDebuggingInfo(_messages.Message): + r"""Information useful for debugging a hot key detection. + + Messages: + DetectedHotKeysValue: Debugging information for each detected hot key. + Keyed by a hash of the key. + + Fields: + detectedHotKeys: Debugging information for each detected hot key. Keyed by + a hash of the key. + """ + @encoding.MapUnrecognizedFields('additionalProperties') + class DetectedHotKeysValue(_messages.Message): + r"""Debugging information for each detected hot key. Keyed by a hash of + the key. + + Messages: + AdditionalProperty: An additional property for a DetectedHotKeysValue + object. + + Fields: + additionalProperties: Additional properties of type DetectedHotKeysValue + """ + class AdditionalProperty(_messages.Message): + r"""An additional property for a DetectedHotKeysValue object. + + Fields: + key: Name of the additional property. + value: A HotKeyInfo attribute. 
+ """ + + key = _messages.StringField(1) + value = _messages.MessageField('HotKeyInfo', 2) + + additionalProperties = _messages.MessageField( + 'AdditionalProperty', 1, repeated=True) + + detectedHotKeys = _messages.MessageField('DetectedHotKeysValue', 1) + + class HotKeyDetection(_messages.Message): r"""Proto describing a hot key detected on a given WorkItem. @@ -2572,6 +2862,25 @@ class HotKeyDetection(_messages.Message): userStepName = _messages.StringField(3) +class HotKeyInfo(_messages.Message): + r"""Information about a hot key. + + Fields: + hotKeyAge: The age of the hot key measured from when it was first + detected. + key: A detected hot key that is causing limited parallelism. This field + will be populated only if the following flag is set to true: "-- + enable_hot_key_logging". + keyTruncated: If true, then the above key is truncated and cannot be + deserialized. This occurs if the key above is populated and the key size + is >5MB. + """ + + hotKeyAge = _messages.StringField(1) + key = _messages.StringField(2) + keyTruncated = _messages.BooleanField(3) + + class InstructionInput(_messages.Message): r"""An input of an instruction, as a reference to an output of a producer instruction. @@ -2676,22 +2985,25 @@ class IntegerMean(_messages.Message): class Job(_messages.Message): - r"""Defines a job to be run by the Cloud Dataflow service. + r"""Defines a job to be run by the Cloud Dataflow service. Do not enter + confidential information when you supply string values using the API. Enums: CurrentStateValueValuesEnum: The current state of the job. Jobs are created in the `JOB_STATE_STOPPED` state unless otherwise specified. A job in the `JOB_STATE_RUNNING` state may asynchronously enter a terminal state. After a job has reached a terminal state, no further state - updates may be made. This field may be mutated by the Cloud Dataflow + updates may be made. This field might be mutated by the Dataflow service; callers cannot mutate it. - RequestedStateValueValuesEnum: The job's requested state. `UpdateJob` may - be used to switch between the `JOB_STATE_STOPPED` and - `JOB_STATE_RUNNING` states, by setting requested_state. `UpdateJob` may - also be used to directly set a job's requested state to - `JOB_STATE_CANCELLED` or `JOB_STATE_DONE`, irrevocably terminating the - job if it has not already reached a terminal state. - TypeValueValuesEnum: The type of Cloud Dataflow job. + RequestedStateValueValuesEnum: The job's requested state. Applies to + `UpdateJob` requests. Set `requested_state` with `UpdateJob` requests to + switch between the states `JOB_STATE_STOPPED` and `JOB_STATE_RUNNING`. + You can also use `UpdateJob` requests to change a job's state from + `JOB_STATE_RUNNING` to `JOB_STATE_CANCELLED`, `JOB_STATE_DONE`, or + `JOB_STATE_DRAINED`. These states irrevocably terminate the job if it + hasn't already reached a terminal state. This field has no effect on + `CreateJob` requests. + TypeValueValuesEnum: Optional. The type of Dataflow job. Messages: LabelsValue: User-defined labels for this job. The labels map can contain @@ -2700,8 +3012,9 @@ class Job(_messages.Message): \p{Ll}\p{Lo}{0,62} * Values must conform to regexp: [\p{Ll}\p{Lo}\p{N}_-]{0,63} * Both keys and values are additionally constrained to be <= 128 bytes in size. - TransformNameMappingValue: The map of transform name prefixes of the job - to be replaced to the corresponding name prefixes of the new job. + TransformNameMappingValue: Optional. 
The map of transform name prefixes of + the job to be replaced to the corresponding name prefixes of the new + job. Fields: clientRequestId: The client's unique identifier of the job, re-used across @@ -2719,14 +3032,13 @@ class Job(_messages.Message): `JOB_STATE_STOPPED` state unless otherwise specified. A job in the `JOB_STATE_RUNNING` state may asynchronously enter a terminal state. After a job has reached a terminal state, no further state updates may - be made. This field may be mutated by the Cloud Dataflow service; - callers cannot mutate it. + be made. This field might be mutated by the Dataflow service; callers + cannot mutate it. currentStateTime: The timestamp associated with the current state. - environment: The environment for the job. + environment: Optional. The environment for the job. executionInfo: Deprecated. - id: The unique ID of this job. This field is set by the Cloud Dataflow - service when the Job is created, and is immutable for the life of the - job. + id: The unique ID of this job. This field is set by the Dataflow service + when the job is created, and is immutable for the life of the job. jobMetadata: This field is populated by the Dataflow service to support filtering jobs by the metadata values provided here. Populated for ListJobs and all GetJob views SUMMARY and higher. @@ -2736,33 +3048,44 @@ class Job(_messages.Message): \p{Ll}\p{Lo}{0,62} * Values must conform to regexp: [\p{Ll}\p{Lo}\p{N}_-]{0,63} * Both keys and values are additionally constrained to be <= 128 bytes in size. - location: The [regional endpoint] + location: Optional. The [regional endpoint] (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains this job. - name: The user-specified Cloud Dataflow job name. Only one Job with a - given name may exist in a project at any given time. If a caller - attempts to create a Job with the same name as an already-existing Job, - the attempt returns the existing Job. The name must match the regular - expression `[a-z]([-a-z0-9]{0,1022}[a-z0-9])?` + name: Optional. The user-specified Dataflow job name. Only one active job + with a given name can exist in a project within one region at any given + time. Jobs in different regions can have the same name. If a caller + attempts to create a job with the same name as an active job that + already exists, the attempt returns the existing job. The name must + match the regular expression `[a-z]([-a-z0-9]{0,1022}[a-z0-9])?` pipelineDescription: Preliminary field: The format of this data may change at any time. A description of the user pipeline and stages through which it is executed. Created by Cloud Dataflow service. Only retrieved with JOB_VIEW_DESCRIPTION or JOB_VIEW_ALL. - projectId: The ID of the Cloud Platform project that the job belongs to. + projectId: The ID of the Google Cloud project that the job belongs to. replaceJobId: If this job is an update of an existing job, this field is the job ID of the job it replaced. When sending a `CreateJobRequest`, you can update a job by specifying it here. The job named here is stopped, and its intermediate state is transferred to this job. replacedByJobId: If another job is an update of this job (and thus, this job is in `JOB_STATE_UPDATED`), this field contains the ID of that job. - requestedState: The job's requested state. `UpdateJob` may be used to - switch between the `JOB_STATE_STOPPED` and `JOB_STATE_RUNNING` states, - by setting requested_state. 
`UpdateJob` may also be used to directly set - a job's requested state to `JOB_STATE_CANCELLED` or `JOB_STATE_DONE`, - irrevocably terminating the job if it has not already reached a terminal - state. + requestedState: The job's requested state. Applies to `UpdateJob` + requests. Set `requested_state` with `UpdateJob` requests to switch + between the states `JOB_STATE_STOPPED` and `JOB_STATE_RUNNING`. You can + also use `UpdateJob` requests to change a job's state from + `JOB_STATE_RUNNING` to `JOB_STATE_CANCELLED`, `JOB_STATE_DONE`, or + `JOB_STATE_DRAINED`. These states irrevocably terminate the job if it + hasn't already reached a terminal state. This field has no effect on + `CreateJob` requests. + runtimeUpdatableParams: This field may ONLY be modified at runtime using + the projects.jobs.update method to adjust job behavior. This field has + no effect when specified at job creation. + satisfiesPzi: Output only. Reserved for future use. This field is set only + in responses from the server; it is ignored if it is set in any + requests. satisfiesPzs: Reserved for future use. This field is set only in responses from the server; it is ignored if it is set in any requests. + serviceResources: Output only. Resources used by the Dataflow Service to + run the job. stageStates: This field may be mutated by the Cloud Dataflow service; callers cannot mutate it. startTime: The timestamp when the job was started (transitioned to @@ -2781,17 +3104,17 @@ class Job(_messages.Message): The supported files are: Google Cloud Storage: storage.googleapis.com/{bucket}/{object} bucket.storage.googleapis.com/{object} - transformNameMapping: The map of transform name prefixes of the job to be - replaced to the corresponding name prefixes of the new job. - type: The type of Cloud Dataflow job. + transformNameMapping: Optional. The map of transform name prefixes of the + job to be replaced to the corresponding name prefixes of the new job. + type: Optional. The type of Dataflow job. """ class CurrentStateValueValuesEnum(_messages.Enum): r"""The current state of the job. Jobs are created in the `JOB_STATE_STOPPED` state unless otherwise specified. A job in the `JOB_STATE_RUNNING` state may asynchronously enter a terminal state. After a job has reached a terminal state, no further state updates may be made. - This field may be mutated by the Cloud Dataflow service; callers cannot - mutate it. + This field might be mutated by the Dataflow service; callers cannot mutate + it. Values: JOB_STATE_UNKNOWN: The job's run state isn't specified. @@ -2859,11 +3182,13 @@ class CurrentStateValueValuesEnum(_messages.Enum): JOB_STATE_RESOURCE_CLEANING_UP = 12 class RequestedStateValueValuesEnum(_messages.Enum): - r"""The job's requested state. `UpdateJob` may be used to switch between - the `JOB_STATE_STOPPED` and `JOB_STATE_RUNNING` states, by setting - requested_state. `UpdateJob` may also be used to directly set a job's - requested state to `JOB_STATE_CANCELLED` or `JOB_STATE_DONE`, irrevocably - terminating the job if it has not already reached a terminal state. + r"""The job's requested state. Applies to `UpdateJob` requests. Set + `requested_state` with `UpdateJob` requests to switch between the states + `JOB_STATE_STOPPED` and `JOB_STATE_RUNNING`. You can also use `UpdateJob` + requests to change a job's state from `JOB_STATE_RUNNING` to + `JOB_STATE_CANCELLED`, `JOB_STATE_DONE`, or `JOB_STATE_DRAINED`. These + states irrevocably terminate the job if it hasn't already reached a + terminal state. 
This field has no effect on `CreateJob` requests. Values: JOB_STATE_UNKNOWN: The job's run state isn't specified. @@ -2931,7 +3256,7 @@ class RequestedStateValueValuesEnum(_messages.Enum): JOB_STATE_RESOURCE_CLEANING_UP = 12 class TypeValueValuesEnum(_messages.Enum): - r"""The type of Cloud Dataflow job. + r"""Optional. The type of Dataflow job. Values: JOB_TYPE_UNKNOWN: The type of the job is unspecified, or unknown. @@ -2975,8 +3300,8 @@ class AdditionalProperty(_messages.Message): @encoding.MapUnrecognizedFields('additionalProperties') class TransformNameMappingValue(_messages.Message): - r"""The map of transform name prefixes of the job to be replaced to the - corresponding name prefixes of the new job. + r"""Optional. The map of transform name prefixes of the job to be replaced + to the corresponding name prefixes of the new job. Messages: AdditionalProperty: An additional property for a @@ -3017,14 +3342,17 @@ class AdditionalProperty(_messages.Message): replaceJobId = _messages.StringField(15) replacedByJobId = _messages.StringField(16) requestedState = _messages.EnumField('RequestedStateValueValuesEnum', 17) - satisfiesPzs = _messages.BooleanField(18) - stageStates = _messages.MessageField('ExecutionStageState', 19, repeated=True) - startTime = _messages.StringField(20) - steps = _messages.MessageField('Step', 21, repeated=True) - stepsLocation = _messages.StringField(22) - tempFiles = _messages.StringField(23, repeated=True) - transformNameMapping = _messages.MessageField('TransformNameMappingValue', 24) - type = _messages.EnumField('TypeValueValuesEnum', 25) + runtimeUpdatableParams = _messages.MessageField('RuntimeUpdatableParams', 18) + satisfiesPzi = _messages.BooleanField(19) + satisfiesPzs = _messages.BooleanField(20) + serviceResources = _messages.MessageField('ServiceResources', 21) + stageStates = _messages.MessageField('ExecutionStageState', 22, repeated=True) + startTime = _messages.StringField(23) + steps = _messages.MessageField('Step', 24, repeated=True) + stepsLocation = _messages.StringField(25) + tempFiles = _messages.StringField(26, repeated=True) + transformNameMapping = _messages.MessageField('TransformNameMappingValue', 27) + type = _messages.EnumField('TypeValueValuesEnum', 28) class JobExecutionDetails(_messages.Message): @@ -3150,6 +3478,10 @@ class JobMetadata(_messages.Message): r"""Metadata available primarily for filtering jobs. Will be included in the ListJob response and Job SUMMARY view. + Messages: + UserDisplayPropertiesValue: List of display properties to help UI filter + jobs. + Fields: bigTableDetails: Identification of a Cloud Bigtable source used in the Dataflow job. @@ -3163,7 +3495,33 @@ class JobMetadata(_messages.Message): sdkVersion: The SDK version used to run the job. spannerDetails: Identification of a Spanner source used in the Dataflow job. + userDisplayProperties: List of display properties to help UI filter jobs. """ + @encoding.MapUnrecognizedFields('additionalProperties') + class UserDisplayPropertiesValue(_messages.Message): + r"""List of display properties to help UI filter jobs. + + Messages: + AdditionalProperty: An additional property for a + UserDisplayPropertiesValue object. + + Fields: + additionalProperties: Additional properties of type + UserDisplayPropertiesValue + """ + class AdditionalProperty(_messages.Message): + r"""An additional property for a UserDisplayPropertiesValue object. + + Fields: + key: Name of the additional property. + value: A string attribute. 
+ """ + + key = _messages.StringField(1) + value = _messages.StringField(2) + + additionalProperties = _messages.MessageField( + 'AdditionalProperty', 1, repeated=True) bigTableDetails = _messages.MessageField( 'BigTableIODetails', 1, repeated=True) @@ -3175,14 +3533,18 @@ class JobMetadata(_messages.Message): pubsubDetails = _messages.MessageField('PubSubIODetails', 5, repeated=True) sdkVersion = _messages.MessageField('SdkVersion', 6) spannerDetails = _messages.MessageField('SpannerIODetails', 7, repeated=True) + userDisplayProperties = _messages.MessageField( + 'UserDisplayPropertiesValue', 8) class JobMetrics(_messages.Message): r"""JobMetrics contains a collection of metrics describing the detailed progress of a Dataflow job. Metrics correspond to user-defined and system- - defined metrics in the job. This resource captures only the most recent - values of each metric; time-series data can be queried for them (under the - same metric names) from Cloud Monitoring. + defined metrics in the job. For more information, see [Dataflow job metrics] + (https://cloud.google.com/dataflow/docs/guides/using-monitoring-intf). This + resource captures only the most recent values of each metric; time-series + data can be queried for them (under the same metric names) from Cloud + Monitoring. Fields: metricTime: Timestamp as of which metric values are current. @@ -3380,7 +3742,10 @@ class LaunchFlexTemplateResponse(_messages.Message): class LaunchTemplateParameters(_messages.Message): - r"""Parameters to provide to the template being launched. + r"""Parameters to provide to the template being launched. Note that the + [metadata in the pipeline code] + (https://cloud.google.com/dataflow/docs/guides/templates/creating- + templates#metadata) determines which runtime parameters are valid. Messages: ParametersValue: The runtime parameters to pass to the job. @@ -3390,7 +3755,8 @@ class LaunchTemplateParameters(_messages.Message): Fields: environment: The runtime environment for the job. - jobName: Required. The job name to use for the created job. + jobName: Required. The job name to use for the created job. The name must + match the regular expression `[a-z]([-a-z0-9]{0,1022}[a-z0-9])?` parameters: The runtime parameters to pass to the job. transformNameMapping: Only applicable when updating a pipeline. Map of transform name prefixes of the job to be replaced to the corresponding @@ -3567,6 +3933,21 @@ class AdditionalProperty(_messages.Message): workItems = _messages.MessageField('WorkItem', 2, repeated=True) +class Linear(_messages.Message): + r"""Linear buckets with the following boundaries for indices in 0 to n-1. - + i in [0, n-1]: [start + (i)*width, start + (i+1)*width) + + Fields: + numberOfBuckets: Must be greater than 0. + start: Lower bound of the first bucket. + width: Distance between bucket boundaries. Must be greater than 0. + """ + + numberOfBuckets = _messages.IntegerField(1, variant=_messages.Variant.INT32) + start = _messages.FloatField(2) + width = _messages.FloatField(3) + + class ListJobMessagesResponse(_messages.Message): r"""Response to a request to list job messages. @@ -3778,6 +4159,49 @@ class MetricUpdate(_messages.Message): updateTime = _messages.StringField(11) +class MetricValue(_messages.Message): + r"""The value of a metric along with its name and labels. + + Messages: + MetricLabelsValue: Optional. Set of metric labels for this metric. + + Fields: + metric: Base name for this metric. + metricLabels: Optional. Set of metric labels for this metric. 
+ valueHistogram: Histogram value of this metric. + valueInt64: Integer value of this metric. + """ + @encoding.MapUnrecognizedFields('additionalProperties') + class MetricLabelsValue(_messages.Message): + r"""Optional. Set of metric labels for this metric. + + Messages: + AdditionalProperty: An additional property for a MetricLabelsValue + object. + + Fields: + additionalProperties: Additional properties of type MetricLabelsValue + """ + class AdditionalProperty(_messages.Message): + r"""An additional property for a MetricLabelsValue object. + + Fields: + key: Name of the additional property. + value: A string attribute. + """ + + key = _messages.StringField(1) + value = _messages.StringField(2) + + additionalProperties = _messages.MessageField( + 'AdditionalProperty', 1, repeated=True) + + metric = _messages.StringField(1) + metricLabels = _messages.MessageField('MetricLabelsValue', 2) + valueHistogram = _messages.MessageField('DataflowHistogramValue', 3) + valueInt64 = _messages.IntegerField(4) + + class MountedDataDisk(_messages.Message): r"""Describes mounted data disk. @@ -3843,6 +4267,24 @@ class KindValueValuesEnum(_messages.Enum): name = _messages.StringField(2) +class OutlierStats(_messages.Message): + r"""Statistics for the underflow and overflow bucket. + + Fields: + overflowCount: Number of values that are larger than the upper bound of + the largest bucket. + overflowMean: Mean of values in the overflow bucket. + underflowCount: Number of values that are smaller than the lower bound of + the smallest bucket. + underflowMean: Mean of values in the undeflow bucket. + """ + + overflowCount = _messages.IntegerField(1) + overflowMean = _messages.FloatField(2) + underflowCount = _messages.IntegerField(3) + underflowMean = _messages.FloatField(4) + + class Package(_messages.Message): r"""The packages that must be installed in order for a worker to run the steps of the Cloud Dataflow job that will be assigned to its worker pool. @@ -3964,13 +4406,32 @@ class ParameterMetadata(_messages.Message): Fields: customMetadata: Optional. Additional metadata for describing this parameter. + defaultValue: Optional. The default values will pre-populate the parameter + with the given value from the proto. If default_value is left empty, the + parameter will be populated with a default of the relevant type, e.g. + false for a boolean. + enumOptions: Optional. The options shown when ENUM ParameterType is + specified. + groupName: Optional. Specifies a group name for this parameter to be + rendered under. Group header text will be rendered exactly as specified + in this field. Only considered when parent_name is NOT provided. helpText: Required. The help text to display for the parameter. + hiddenUi: Optional. Whether the parameter should be hidden in the UI. isOptional: Optional. Whether the parameter is optional. Defaults to false. label: Required. The label to display for the parameter. name: Required. The name of the parameter. paramType: Optional. The type of the parameter. Used for selecting input picker. + parentName: Optional. Specifies the name of the parent parameter. Used in + conjunction with 'parent_trigger_values' to make this parameter + conditional (will only be rendered conditionally). Should be mappable to + a ParameterMetadata.name field. + parentTriggerValues: Optional. The value(s) of the 'parent_name' parameter + which will trigger this parameter to be shown. If left empty, ANY non- + empty value in parent_name will trigger this parameter to be shown. 
Only + considered when this parameter is conditional (when 'parent_name' has + been provided). regexes: Optional. Regexes that the parameter must match. """ class ParamTypeValueValuesEnum(_messages.Enum): @@ -3993,6 +4454,25 @@ class ParamTypeValueValuesEnum(_messages.Enum): write to. PUBSUB_TOPIC: The parameter specifies a Pub/Sub Topic. PUBSUB_SUBSCRIPTION: The parameter specifies a Pub/Sub Subscription. + BIGQUERY_TABLE: The parameter specifies a BigQuery table. + JAVASCRIPT_UDF_FILE: The parameter specifies a JavaScript UDF in Cloud + Storage. + SERVICE_ACCOUNT: The parameter specifies a Service Account email. + MACHINE_TYPE: The parameter specifies a Machine Type. + KMS_KEY_NAME: The parameter specifies a KMS Key name. + WORKER_REGION: The parameter specifies a Worker Region. + WORKER_ZONE: The parameter specifies a Worker Zone. + BOOLEAN: The parameter specifies a boolean input. + ENUM: The parameter specifies an enum input. + NUMBER: The parameter specifies a number input. + KAFKA_TOPIC: Deprecated. Please use KAFKA_READ_TOPIC instead. + KAFKA_READ_TOPIC: The parameter specifies the fully-qualified name of an + Apache Kafka topic. This can be either a Google Managed Kafka topic or + a non-managed Kafka topic. + KAFKA_WRITE_TOPIC: The parameter specifies the fully-qualified name of + an Apache Kafka topic. This can be an existing Google Managed Kafka + topic, the name for a new Google Managed Kafka topic, or an existing + non-managed Kafka topic. """ DEFAULT = 0 TEXT = 1 @@ -4004,6 +4484,19 @@ class ParamTypeValueValuesEnum(_messages.Enum): GCS_WRITE_FOLDER = 7 PUBSUB_TOPIC = 8 PUBSUB_SUBSCRIPTION = 9 + BIGQUERY_TABLE = 10 + JAVASCRIPT_UDF_FILE = 11 + SERVICE_ACCOUNT = 12 + MACHINE_TYPE = 13 + KMS_KEY_NAME = 14 + WORKER_REGION = 15 + WORKER_ZONE = 16 + BOOLEAN = 17 + ENUM = 18 + NUMBER = 19 + KAFKA_TOPIC = 20 + KAFKA_READ_TOPIC = 21 + KAFKA_WRITE_TOPIC = 22 @encoding.MapUnrecognizedFields('additionalProperties') class CustomMetadataValue(_messages.Message): @@ -4031,12 +4524,33 @@ class AdditionalProperty(_messages.Message): 'AdditionalProperty', 1, repeated=True) customMetadata = _messages.MessageField('CustomMetadataValue', 1) - helpText = _messages.StringField(2) - isOptional = _messages.BooleanField(3) - label = _messages.StringField(4) - name = _messages.StringField(5) - paramType = _messages.EnumField('ParamTypeValueValuesEnum', 6) - regexes = _messages.StringField(7, repeated=True) + defaultValue = _messages.StringField(2) + enumOptions = _messages.MessageField( + 'ParameterMetadataEnumOption', 3, repeated=True) + groupName = _messages.StringField(4) + helpText = _messages.StringField(5) + hiddenUi = _messages.BooleanField(6) + isOptional = _messages.BooleanField(7) + label = _messages.StringField(8) + name = _messages.StringField(9) + paramType = _messages.EnumField('ParamTypeValueValuesEnum', 10) + parentName = _messages.StringField(11) + parentTriggerValues = _messages.StringField(12, repeated=True) + regexes = _messages.StringField(13, repeated=True) + + +class ParameterMetadataEnumOption(_messages.Message): + r"""ParameterMetadataEnumOption specifies the option shown in the enum form. + + Fields: + description: Optional. The description to display for the enum option. + label: Optional. The label to display for the enum option. + value: Required. The value of the enum option. 
+ """ + + description = _messages.StringField(1) + label = _messages.StringField(2) + value = _messages.StringField(3) class PartialGroupByKeyInstruction(_messages.Message): @@ -4119,6 +4633,36 @@ class AdditionalProperty(_messages.Message): valueCombiningFn = _messages.MessageField('ValueCombiningFnValue', 6) +class PerStepNamespaceMetrics(_messages.Message): + r"""Metrics for a particular unfused step and namespace. A metric is + uniquely identified by the `metrics_namespace`, `original_step`, `metric + name` and `metric_labels`. + + Fields: + metricValues: Optional. Metrics that are recorded for this namespace and + unfused step. + metricsNamespace: The namespace of these metrics on the worker. + originalStep: The original system name of the unfused step that these + metrics are reported from. + """ + + metricValues = _messages.MessageField('MetricValue', 1, repeated=True) + metricsNamespace = _messages.StringField(2) + originalStep = _messages.StringField(3) + + +class PerWorkerMetrics(_messages.Message): + r"""Per worker metrics. + + Fields: + perStepNamespaceMetrics: Optional. Metrics for a particular unfused step + and namespace. + """ + + perStepNamespaceMetrics = _messages.MessageField( + 'PerStepNamespaceMetrics', 1, repeated=True) + + class PipelineDescription(_messages.Message): r"""A descriptive representation of submitted pipeline as well as the executed form. This data is provided by the Dataflow service for ease of @@ -4130,6 +4674,8 @@ class PipelineDescription(_messages.Message): pipeline. originalPipelineTransform: Description of each transform in the pipeline and collections between them. + stepNamesHash: A hash value of the submitted pipeline portable graph step + names if exists. """ displayData = _messages.MessageField('DisplayData', 1, repeated=True) @@ -4137,6 +4683,7 @@ class PipelineDescription(_messages.Message): 'ExecutionStageSummary', 2, repeated=True) originalPipelineTransform = _messages.MessageField( 'TransformSummary', 3, repeated=True) + stepNamesHash = _messages.StringField(4) class Point(_messages.Message): @@ -4207,6 +4754,8 @@ class PubsubLocation(_messages.Message): Fields: dropLateData: Indicates whether the pipeline allows late-arriving data. + dynamicDestinations: If true, then this location represents dynamic + topics. idLabel: If set, contains a pubsub label from which to extract record ids. If left empty, record deduplication will be strictly best effort. subscription: A pubsub subscription, in the form of @@ -4222,12 +4771,13 @@ class PubsubLocation(_messages.Message): """ dropLateData = _messages.BooleanField(1) - idLabel = _messages.StringField(2) - subscription = _messages.StringField(3) - timestampLabel = _messages.StringField(4) - topic = _messages.StringField(5) - trackingSubscription = _messages.StringField(6) - withAttributes = _messages.BooleanField(7) + dynamicDestinations = _messages.BooleanField(2) + idLabel = _messages.StringField(3) + subscription = _messages.StringField(4) + timestampLabel = _messages.StringField(5) + topic = _messages.StringField(6) + trackingSubscription = _messages.StringField(7) + withAttributes = _messages.BooleanField(8) class PubsubSnapshotMetadata(_messages.Message): @@ -4244,31 +4794,6 @@ class PubsubSnapshotMetadata(_messages.Message): topicName = _messages.StringField(3) -class QueryInfo(_messages.Message): - r"""Information about a validated query. - - Enums: - QueryPropertyValueListEntryValuesEnum: - - Fields: - queryProperty: Includes an entry for each satisfied QueryProperty. 
- """ - class QueryPropertyValueListEntryValuesEnum(_messages.Enum): - r"""QueryPropertyValueListEntryValuesEnum enum type. - - Values: - QUERY_PROPERTY_UNSPECIFIED: The query property is unknown or - unspecified. - HAS_UNBOUNDED_SOURCE: Indicates this query reads from >= 1 unbounded - source. - """ - QUERY_PROPERTY_UNSPECIFIED = 0 - HAS_UNBOUNDED_SOURCE = 1 - - queryProperty = _messages.EnumField( - 'QueryPropertyValueListEntryValuesEnum', 1, repeated=True) - - class ReadInstruction(_messages.Message): r"""An instruction that reads records. Takes no inputs, produces one output. @@ -4448,69 +4973,88 @@ class ResourceUtilizationReportResponse(_messages.Message): class RuntimeEnvironment(_messages.Message): - r"""The environment values to set at runtime. + r"""The environment values to set at runtime. LINT.IfChange Enums: - IpConfigurationValueValuesEnum: Configuration for VM IPs. + IpConfigurationValueValuesEnum: Optional. Configuration for VM IPs. + StreamingModeValueValuesEnum: Optional. Specifies the Streaming Engine + message processing guarantees. Reduces cost and latency but might result + in duplicate messages committed to storage. Designed to run simple + mapping streaming ETL jobs at the lowest cost. For example, Change Data + Capture (CDC) to BigQuery is a canonical use case. For more information, + see [Set the pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). Messages: - AdditionalUserLabelsValue: Additional user labels to be specified for the - job. Keys and values should follow the restrictions specified in the - [labeling restrictions](https://cloud.google.com/compute/docs/labeling- + AdditionalUserLabelsValue: Optional. Additional user labels to be + specified for the job. Keys and values should follow the restrictions + specified in the [labeling + restrictions](https://cloud.google.com/compute/docs/labeling- resources#restrictions) page. An object containing a list of "key": value pairs. Example: { "name": "wrench", "mass": "1kg", "count": "3" }. Fields: - additionalExperiments: Additional experiment flags for the job, specified - with the `--experiments` option. - additionalUserLabels: Additional user labels to be specified for the job. - Keys and values should follow the restrictions specified in the + additionalExperiments: Optional. Additional experiment flags for the job, + specified with the `--experiments` option. + additionalUserLabels: Optional. Additional user labels to be specified for + the job. Keys and values should follow the restrictions specified in the [labeling restrictions](https://cloud.google.com/compute/docs/labeling- resources#restrictions) page. An object containing a list of "key": value pairs. Example: { "name": "wrench", "mass": "1kg", "count": "3" }. - bypassTempDirValidation: Whether to bypass the safety checks for the job's - temporary directory. Use with caution. - enableStreamingEngine: Whether to enable Streaming Engine for the job. - ipConfiguration: Configuration for VM IPs. - kmsKeyName: Name for the Cloud KMS key for the job. Key format is: - projects//locations//keyRings//cryptoKeys/ - machineType: The machine type to use for the job. Defaults to the value - from the template if not specified. - maxWorkers: The maximum number of Google Compute Engine instances to be - made available to your pipeline during execution, from 1 to 1000. - network: Network to which VMs will be assigned. If empty or unspecified, - the service will use the network "default". 
- numWorkers: The initial number of Google Compute Engine instnaces for the - job. - serviceAccountEmail: The email address of the service account to run the - job as. - subnetwork: Subnetwork to which VMs will be assigned, if desired. You can - specify a subnetwork using either a complete URL or an abbreviated path. - Expected to be of the form "https://www.googleapis.com/compute/v1/projec - ts/HOST_PROJECT_ID/regions/REGION/subnetworks/SUBNETWORK" or - "regions/REGION/subnetworks/SUBNETWORK". If the subnetwork is located in - a Shared VPC network, you must use the complete URL. - tempLocation: The Cloud Storage path to use for temporary files. Must be a - valid Cloud Storage URL, beginning with `gs://`. - workerRegion: The Compute Engine region + bypassTempDirValidation: Optional. Whether to bypass the safety checks for + the job's temporary directory. Use with caution. + diskSizeGb: Optional. The disk size, in gigabytes, to use on each remote + Compute Engine worker instance. + enableStreamingEngine: Optional. Whether to enable Streaming Engine for + the job. + ipConfiguration: Optional. Configuration for VM IPs. + kmsKeyName: Optional. Name for the Cloud KMS key for the job. Key format + is: projects//locations//keyRings//cryptoKeys/ + machineType: Optional. The machine type to use for the job. Defaults to + the value from the template if not specified. + maxWorkers: Optional. The maximum number of Google Compute Engine + instances to be made available to your pipeline during execution, from 1 + to 1000. The default value is 1. + network: Optional. Network to which VMs will be assigned. If empty or + unspecified, the service will use the network "default". + numWorkers: Optional. The initial number of Google Compute Engine + instances for the job. The default value is 11. + serviceAccountEmail: Optional. The email address of the service account to + run the job as. + streamingMode: Optional. Specifies the Streaming Engine message processing + guarantees. Reduces cost and latency but might result in duplicate + messages committed to storage. Designed to run simple mapping streaming + ETL jobs at the lowest cost. For example, Change Data Capture (CDC) to + BigQuery is a canonical use case. For more information, see [Set the + pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). + subnetwork: Optional. Subnetwork to which VMs will be assigned, if + desired. You can specify a subnetwork using either a complete URL or an + abbreviated path. Expected to be of the form "https://www.googleapis.com + /compute/v1/projects/HOST_PROJECT_ID/regions/REGION/subnetworks/SUBNETWO + RK" or "regions/REGION/subnetworks/SUBNETWORK". If the subnetwork is + located in a Shared VPC network, you must use the complete URL. + tempLocation: Required. The Cloud Storage path to use for temporary files. + Must be a valid Cloud Storage URL, beginning with `gs://`. + workerRegion: Required. The Compute Engine region (https://cloud.google.com/compute/docs/regions-zones/regions-zones) in which worker processing should occur, e.g. "us-west1". Mutually exclusive with worker_zone. If neither worker_region nor worker_zone is specified, default to the control plane's region. - workerZone: The Compute Engine zone + workerZone: Optional. The Compute Engine zone (https://cloud.google.com/compute/docs/regions-zones/regions-zones) in which worker processing should occur, e.g. "us-west1-a". Mutually exclusive with worker_region. 
If neither worker_region nor worker_zone is specified, a zone in the control plane's region is chosen based on available capacity. If both `worker_zone` and `zone` are set, `worker_zone` takes precedence. - zone: The Compute Engine [availability + zone: Optional. The Compute Engine [availability zone](https://cloud.google.com/compute/docs/regions-zones/regions-zones) for launching worker instances to run your pipeline. In the future, worker_zone will take precedence. """ class IpConfigurationValueValuesEnum(_messages.Enum): - r"""Configuration for VM IPs. + r"""Optional. Configuration for VM IPs. Values: WORKER_IP_UNSPECIFIED: The configuration is unknown, or unspecified. @@ -4521,10 +5065,33 @@ class IpConfigurationValueValuesEnum(_messages.Enum): WORKER_IP_PUBLIC = 1 WORKER_IP_PRIVATE = 2 + class StreamingModeValueValuesEnum(_messages.Enum): + r"""Optional. Specifies the Streaming Engine message processing + guarantees. Reduces cost and latency but might result in duplicate + messages committed to storage. Designed to run simple mapping streaming + ETL jobs at the lowest cost. For example, Change Data Capture (CDC) to + BigQuery is a canonical use case. For more information, see [Set the + pipeline streaming + mode](https://cloud.google.com/dataflow/docs/guides/streaming-modes). + + Values: + STREAMING_MODE_UNSPECIFIED: Run in the default mode. + STREAMING_MODE_EXACTLY_ONCE: In this mode, message deduplication is + performed against persistent state to make sure each message is + processed and committed to storage exactly once. + STREAMING_MODE_AT_LEAST_ONCE: Message deduplication is not performed. + Messages might be processed multiple times, and the results are + applied multiple times. Note: Setting this value also enables + Streaming Engine and Streaming Engine resource-based billing. + """ + STREAMING_MODE_UNSPECIFIED = 0 + STREAMING_MODE_EXACTLY_ONCE = 1 + STREAMING_MODE_AT_LEAST_ONCE = 2 + @encoding.MapUnrecognizedFields('additionalProperties') class AdditionalUserLabelsValue(_messages.Message): - r"""Additional user labels to be specified for the job. Keys and values - should follow the restrictions specified in the [labeling + r"""Optional. Additional user labels to be specified for the job. Keys and + values should follow the restrictions specified in the [labeling restrictions](https://cloud.google.com/compute/docs/labeling- resources#restrictions) page. An object containing a list of "key": value pairs. Example: { "name": "wrench", "mass": "1kg", "count": "3" }. 
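For readers of this generated client, the new `diskSizeGb` and `streamingMode` fields on `RuntimeEnvironment` are populated like any other apitools message field. A minimal sketch, assuming the generated module is importable under the Beam-internal alias shown below (the import path and field values are illustrative, not part of this change):

    # Illustrative only: import path is an assumption about how the generated
    # dataflow_v1b3_messages module is exposed inside Apache Beam.
    from apache_beam.runners.dataflow.internal.clients.dataflow import (
        dataflow_v1b3_messages as dataflow)

    # Build a template runtime environment that opts into the new
    # at-least-once streaming mode and a 50 GB worker disk.
    env = dataflow.RuntimeEnvironment(
        tempLocation='gs://my-bucket/tmp',  # required; must begin with gs://
        maxWorkers=10,
        diskSizeGb=50,
        streamingMode=(
            dataflow.RuntimeEnvironment.StreamingModeValueValuesEnum
            .STREAMING_MODE_AT_LEAST_ONCE))

The enum is accessed as a class attribute of the nested `StreamingModeValueValuesEnum`, mirroring how the existing `ipConfiguration` enum is used elsewhere in this module.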
@@ -4554,19 +5121,21 @@ class AdditionalProperty(_messages.Message): additionalExperiments = _messages.StringField(1, repeated=True) additionalUserLabels = _messages.MessageField('AdditionalUserLabelsValue', 2) bypassTempDirValidation = _messages.BooleanField(3) - enableStreamingEngine = _messages.BooleanField(4) - ipConfiguration = _messages.EnumField('IpConfigurationValueValuesEnum', 5) - kmsKeyName = _messages.StringField(6) - machineType = _messages.StringField(7) - maxWorkers = _messages.IntegerField(8, variant=_messages.Variant.INT32) - network = _messages.StringField(9) - numWorkers = _messages.IntegerField(10, variant=_messages.Variant.INT32) - serviceAccountEmail = _messages.StringField(11) - subnetwork = _messages.StringField(12) - tempLocation = _messages.StringField(13) - workerRegion = _messages.StringField(14) - workerZone = _messages.StringField(15) - zone = _messages.StringField(16) + diskSizeGb = _messages.IntegerField(4, variant=_messages.Variant.INT32) + enableStreamingEngine = _messages.BooleanField(5) + ipConfiguration = _messages.EnumField('IpConfigurationValueValuesEnum', 6) + kmsKeyName = _messages.StringField(7) + machineType = _messages.StringField(8) + maxWorkers = _messages.IntegerField(9, variant=_messages.Variant.INT32) + network = _messages.StringField(10) + numWorkers = _messages.IntegerField(11, variant=_messages.Variant.INT32) + serviceAccountEmail = _messages.StringField(12) + streamingMode = _messages.EnumField('StreamingModeValueValuesEnum', 13) + subnetwork = _messages.StringField(14) + tempLocation = _messages.StringField(15) + workerRegion = _messages.StringField(16) + workerZone = _messages.StringField(17) + zone = _messages.StringField(18) class RuntimeMetadata(_messages.Message): @@ -4581,6 +5150,29 @@ class RuntimeMetadata(_messages.Message): sdkInfo = _messages.MessageField('SDKInfo', 2) +class RuntimeUpdatableParams(_messages.Message): + r"""Additional job parameters that can only be updated during runtime using + the projects.jobs.update method. These fields have no effect when specified + during job creation. + + Fields: + maxNumWorkers: The maximum number of workers to cap autoscaling at. This + field is currently only supported for Streaming Engine jobs. + minNumWorkers: The minimum number of workers to scale down to. This field + is currently only supported for Streaming Engine jobs. + workerUtilizationHint: Target worker utilization, compared against the + aggregate utilization of the worker pool by autoscaler, to determine + upscaling and downscaling when absent other constraints such as backlog. + For more information, see [Update an existing + pipeline](https://cloud.google.com/dataflow/docs/guides/updating-a- + pipeline). + """ + + maxNumWorkers = _messages.IntegerField(1, variant=_messages.Variant.INT32) + minNumWorkers = _messages.IntegerField(2, variant=_messages.Variant.INT32) + workerUtilizationHint = _messages.FloatField(3) + + class SDKInfo(_messages.Message): r"""SDK Information. @@ -4598,22 +5190,75 @@ class LanguageValueValuesEnum(_messages.Enum): UNKNOWN: UNKNOWN Language. JAVA: Java. PYTHON: Python. + GO: Go. """ UNKNOWN = 0 JAVA = 1 PYTHON = 2 + GO = 3 language = _messages.EnumField('LanguageValueValuesEnum', 1) version = _messages.StringField(2) +class SdkBug(_messages.Message): + r"""A bug found in the Dataflow SDK. + + Enums: + SeverityValueValuesEnum: Output only. How severe the SDK bug is. + TypeValueValuesEnum: Output only. Describes the impact of this SDK bug. + + Fields: + severity: Output only. 
How severe the SDK bug is. + type: Output only. Describes the impact of this SDK bug. + uri: Output only. Link to more information on the bug. + """ + class SeverityValueValuesEnum(_messages.Enum): + r"""Output only. How severe the SDK bug is. + + Values: + SEVERITY_UNSPECIFIED: A bug of unknown severity. + NOTICE: A minor bug that that may reduce reliability or performance for + some jobs. Impact will be minimal or non-existent for most jobs. + WARNING: A bug that has some likelihood of causing performance + degradation, data loss, or job failures. + SEVERE: A bug with extremely significant impact. Jobs may fail + erroneously, performance may be severely degraded, and data loss may + be very likely. + """ + SEVERITY_UNSPECIFIED = 0 + NOTICE = 1 + WARNING = 2 + SEVERE = 3 + + class TypeValueValuesEnum(_messages.Enum): + r"""Output only. Describes the impact of this SDK bug. + + Values: + TYPE_UNSPECIFIED: Unknown issue with this SDK. + GENERAL: Catch-all for SDK bugs that don't fit in the below categories. + PERFORMANCE: Using this version of the SDK may result in degraded + performance. + DATALOSS: Using this version of the SDK may cause data loss. + """ + TYPE_UNSPECIFIED = 0 + GENERAL = 1 + PERFORMANCE = 2 + DATALOSS = 3 + + severity = _messages.EnumField('SeverityValueValuesEnum', 1) + type = _messages.EnumField('TypeValueValuesEnum', 2) + uri = _messages.StringField(3) + + class SdkHarnessContainerImage(_messages.Message): - r"""Defines a SDK harness container for executing Dataflow pipelines. + r"""Defines an SDK harness container for executing Dataflow pipelines. Fields: capabilities: The set of capabilities enumerated in the above Environment - proto. See also https://github.com/apache/beam/blob/master/model/pipelin - e/src/main/proto/beam_runner_api.proto + proto. See also [beam_runner_api.proto](https://github.com/apache/beam/b + lob/master/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/ + v1/beam_runner_api.proto) containerImage: A docker container image that resides in Google Container Registry. environmentId: Environment ID for the Beam runner API proto Environment @@ -4638,6 +5283,7 @@ class SdkVersion(_messages.Message): SdkSupportStatusValueValuesEnum: The support status for this SDK version. Fields: + bugs: Output only. Known bugs found in this SDK version. sdkSupportStatus: The support status for this SDK version. version: The version of the SDK used to run the job. versionDisplayName: A readable string describing the version of the SDK. @@ -4661,9 +5307,10 @@ class SdkSupportStatusValueValuesEnum(_messages.Enum): DEPRECATED = 3 UNSUPPORTED = 4 - sdkSupportStatus = _messages.EnumField('SdkSupportStatusValueValuesEnum', 1) - version = _messages.StringField(2) - versionDisplayName = _messages.StringField(3) + bugs = _messages.MessageField('SdkBug', 1, repeated=True) + sdkSupportStatus = _messages.EnumField('SdkSupportStatusValueValuesEnum', 2) + version = _messages.StringField(3) + versionDisplayName = _messages.StringField(4) class SendDebugCaptureRequest(_messages.Message): @@ -4796,6 +5443,17 @@ class SeqMapTaskOutputInfo(_messages.Message): tag = _messages.StringField(2) +class ServiceResources(_messages.Message): + r"""Resources used by the Dataflow Service to run the job. + + Fields: + zones: Output only. List of Cloud Zones being used by the Dataflow Service + for this job. 
Example: us-central1-c + """ + + zones = _messages.StringField(1, repeated=True) + + class ShellTask(_messages.Message): r"""A task which consists of a shell command for the worker to execute. @@ -5381,6 +6039,7 @@ class StageSummary(_messages.Message): stageId: ID of this stage startTime: Start time of this stage. state: State of this stage. + stragglerSummary: Straggler summary for this stage. """ class StateValueValuesEnum(_messages.Enum): r"""State of this stage. @@ -5406,6 +6065,7 @@ class StateValueValuesEnum(_messages.Enum): stageId = _messages.StringField(4) startTime = _messages.StringField(5) state = _messages.EnumField('StateValueValuesEnum', 6) + stragglerSummary = _messages.MessageField('StragglerSummary', 7) class StandardQueryParameters(_messages.Message): @@ -5536,15 +6196,16 @@ class Step(_messages.Message): r"""Defines a particular step within a Cloud Dataflow job. A job consists of multiple steps, each of which performs some specific operation as part of the overall job. Data is typically passed from one step to another as part - of the job. Here's an example of a sequence of steps which together - implement a Map-Reduce job: * Read a collection of data from some source, - parsing the collection's elements. * Validate the elements. * Apply a user- - defined function to map each element to some value and extract an element- - specific key value. * Group elements with the same key into a single element - with that key, transforming a multiply-keyed collection into a uniquely- - keyed collection. * Write the elements out to some data sink. Note that the - Cloud Dataflow service may be used to run many different types of jobs, not - just Map-Reduce. + of the job. **Note:** The properties of this object are not stable and might + change. Here's an example of a sequence of steps which together implement a + Map-Reduce job: * Read a collection of data from some source, parsing the + collection's elements. * Validate the elements. * Apply a user-defined + function to map each element to some value and extract an element-specific + key value. * Group elements with the same key into a single element with + that key, transforming a multiply-keyed collection into a uniquely-keyed + collection. * Write the elements out to some data sink. Note that the Cloud + Dataflow service may be used to run many different types of jobs, not just + Map-Reduce. Messages: PropertiesValue: Named properties associated with the step. Each kind of @@ -5590,6 +6251,120 @@ class AdditionalProperty(_messages.Message): properties = _messages.MessageField('PropertiesValue', 3) +class Straggler(_messages.Message): + r"""Information for a straggler. + + Fields: + batchStraggler: Batch straggler identification and debugging information. + streamingStraggler: Streaming straggler identification and debugging + information. + """ + + batchStraggler = _messages.MessageField('StragglerInfo', 1) + streamingStraggler = _messages.MessageField('StreamingStragglerInfo', 2) + + +class StragglerDebuggingInfo(_messages.Message): + r"""Information useful for debugging a straggler. Each type will provide + specialized debugging information relevant for a particular cause. The + StragglerDebuggingInfo will be 1:1 mapping to the StragglerCause enum. + + Fields: + hotKey: Hot key debugging details. + """ + + hotKey = _messages.MessageField('HotKeyDebuggingInfo', 1) + + +class StragglerInfo(_messages.Message): + r"""Information useful for straggler identification and debugging. 
+ + Messages: + CausesValue: The straggler causes, keyed by the string representation of + the StragglerCause enum and contains specialized debugging information + for each straggler cause. + + Fields: + causes: The straggler causes, keyed by the string representation of the + StragglerCause enum and contains specialized debugging information for + each straggler cause. + startTime: The time when the work item attempt became a straggler. + """ + @encoding.MapUnrecognizedFields('additionalProperties') + class CausesValue(_messages.Message): + r"""The straggler causes, keyed by the string representation of the + StragglerCause enum and contains specialized debugging information for + each straggler cause. + + Messages: + AdditionalProperty: An additional property for a CausesValue object. + + Fields: + additionalProperties: Additional properties of type CausesValue + """ + class AdditionalProperty(_messages.Message): + r"""An additional property for a CausesValue object. + + Fields: + key: Name of the additional property. + value: A StragglerDebuggingInfo attribute. + """ + + key = _messages.StringField(1) + value = _messages.MessageField('StragglerDebuggingInfo', 2) + + additionalProperties = _messages.MessageField( + 'AdditionalProperty', 1, repeated=True) + + causes = _messages.MessageField('CausesValue', 1) + startTime = _messages.StringField(2) + + +class StragglerSummary(_messages.Message): + r"""Summarized straggler identification details. + + Messages: + StragglerCauseCountValue: Aggregated counts of straggler causes, keyed by + the string representation of the StragglerCause enum. + + Fields: + recentStragglers: The most recent stragglers. + stragglerCauseCount: Aggregated counts of straggler causes, keyed by the + string representation of the StragglerCause enum. + totalStragglerCount: The total count of stragglers. + """ + @encoding.MapUnrecognizedFields('additionalProperties') + class StragglerCauseCountValue(_messages.Message): + r"""Aggregated counts of straggler causes, keyed by the string + representation of the StragglerCause enum. + + Messages: + AdditionalProperty: An additional property for a + StragglerCauseCountValue object. + + Fields: + additionalProperties: Additional properties of type + StragglerCauseCountValue + """ + class AdditionalProperty(_messages.Message): + r"""An additional property for a StragglerCauseCountValue object. + + Fields: + key: Name of the additional property. + value: A string attribute. + """ + + key = _messages.StringField(1) + value = _messages.IntegerField(2) + + additionalProperties = _messages.MessageField( + 'AdditionalProperty', 1, repeated=True) + + recentStragglers = _messages.MessageField('Straggler', 1, repeated=True) + stragglerCauseCount = _messages.MessageField('StragglerCauseCountValue', 2) + totalStragglerCount = _messages.IntegerField(3) + + class StreamLocation(_messages.Message): r"""Describes a stream of data, either as input to be processed or as output of a streaming Dataflow job. @@ -5736,6 +6511,8 @@ class StreamingConfigTask(_messages.Message): harness to windmill. maxWorkItemCommitBytes: Maximum size for work item commit supported windmill storage layer. + operationalLimits: Operational limits for the streaming job. Can be used + by the worker to validate outputs sent to the backend. streamingComputationConfigs: Set of computation configuration information. userStepToStateFamilyNameMap: Map from user step names to state families. 
windmillServiceEndpoint: If present, the worker must use this endpoint to @@ -5775,12 +6552,80 @@ class AdditionalProperty(_messages.Message): commitStreamChunkSizeBytes = _messages.IntegerField(1) getDataStreamChunkSizeBytes = _messages.IntegerField(2) maxWorkItemCommitBytes = _messages.IntegerField(3) + operationalLimits = _messages.MessageField('StreamingOperationalLimits', 4) streamingComputationConfigs = _messages.MessageField( - 'StreamingComputationConfig', 4, repeated=True) + 'StreamingComputationConfig', 5, repeated=True) userStepToStateFamilyNameMap = _messages.MessageField( - 'UserStepToStateFamilyNameMapValue', 5) - windmillServiceEndpoint = _messages.StringField(6) - windmillServicePort = _messages.IntegerField(7) + 'UserStepToStateFamilyNameMapValue', 6) + windmillServiceEndpoint = _messages.StringField(7) + windmillServicePort = _messages.IntegerField(8) + + +class StreamingOperationalLimits(_messages.Message): + r"""Operational limits imposed on streaming jobs by the backend. + + Fields: + maxBagElementBytes: The maximum size for an element in bag state. + maxGlobalDataBytes: The maximum size for an element in global data. + maxKeyBytes: The maximum size allowed for a key. + maxProductionOutputBytes: The maximum size for a single output element. + maxSortedListElementBytes: The maximum size for an element in sorted list + state. + maxSourceStateBytes: The maximum size for a source state update. + maxTagBytes: The maximum size for a state tag. + maxValueBytes: The maximum size for a value state field. + """ + + maxBagElementBytes = _messages.IntegerField(1) + maxGlobalDataBytes = _messages.IntegerField(2) + maxKeyBytes = _messages.IntegerField(3) + maxProductionOutputBytes = _messages.IntegerField(4) + maxSortedListElementBytes = _messages.IntegerField(5) + maxSourceStateBytes = _messages.IntegerField(6) + maxTagBytes = _messages.IntegerField(7) + maxValueBytes = _messages.IntegerField(8) + + +class StreamingScalingReport(_messages.Message): + r"""Contains per-user worker telemetry used in streaming autoscaling. + + Fields: + activeBundleCount: A integer attribute. + activeThreadCount: Current acive thread count. + maximumBundleCount: Maximum bundle count. + maximumBytes: Maximum bytes. + maximumBytesCount: A integer attribute. + maximumThreadCount: Maximum thread count limit. + outstandingBundleCount: Current outstanding bundle count. + outstandingBytes: Current outstanding bytes. + outstandingBytesCount: A integer attribute. + """ + + activeBundleCount = _messages.IntegerField(1, variant=_messages.Variant.INT32) + activeThreadCount = _messages.IntegerField(2, variant=_messages.Variant.INT32) + maximumBundleCount = _messages.IntegerField( + 3, variant=_messages.Variant.INT32) + maximumBytes = _messages.IntegerField(4) + maximumBytesCount = _messages.IntegerField(5, variant=_messages.Variant.INT32) + maximumThreadCount = _messages.IntegerField( + 6, variant=_messages.Variant.INT32) + outstandingBundleCount = _messages.IntegerField( + 7, variant=_messages.Variant.INT32) + outstandingBytes = _messages.IntegerField(8) + outstandingBytesCount = _messages.IntegerField( + 9, variant=_messages.Variant.INT32) + + +class StreamingScalingReportResponse(_messages.Message): + r"""Contains per-user-worker streaming scaling recommendation from the + backend. 
+ + Fields: + maximumThreadCount: Maximum thread count limit; + """ + + maximumThreadCount = _messages.IntegerField( + 1, variant=_messages.Variant.INT32) class StreamingSetupTask(_messages.Message): @@ -5829,6 +6674,26 @@ class StreamingStageLocation(_messages.Message): streamId = _messages.StringField(1) +class StreamingStragglerInfo(_messages.Message): + r"""Information useful for streaming straggler identification and debugging. + + Fields: + dataWatermarkLag: The event-time watermark lag at the time of the + straggler detection. + endTime: End time of this straggler. + startTime: Start time of this straggler. + systemWatermarkLag: The system watermark lag at the time of the straggler + detection. + workerName: Name of the worker where the straggler was detected. + """ + + dataWatermarkLag = _messages.StringField(1) + endTime = _messages.StringField(2) + startTime = _messages.StringField(3) + systemWatermarkLag = _messages.StringField(4) + workerName = _messages.StringField(5) + + class StringList(_messages.Message): r"""A metric value representing a list of strings. @@ -5923,14 +6788,27 @@ class TemplateMetadata(_messages.Message): r"""Metadata describing a template. Fields: + defaultStreamingMode: Optional. Indicates the default streaming mode for a + streaming template. Only valid if both supports_at_least_once and + supports_exactly_once are true. Possible values: UNSPECIFIED, + EXACTLY_ONCE and AT_LEAST_ONCE description: Optional. A description of the template. name: Required. The name of the template. parameters: The parameters for the template. + streaming: Optional. Indicates if the template is streaming or not. + supportsAtLeastOnce: Optional. Indicates if the streaming template + supports at least once mode. + supportsExactlyOnce: Optional. Indicates if the streaming template + supports exactly once mode. """ - description = _messages.StringField(1) - name = _messages.StringField(2) - parameters = _messages.MessageField('ParameterMetadata', 3, repeated=True) + defaultStreamingMode = _messages.StringField(1) + description = _messages.StringField(2) + name = _messages.StringField(3) + parameters = _messages.MessageField('ParameterMetadata', 4, repeated=True) + streaming = _messages.BooleanField(5) + supportsAtLeastOnce = _messages.BooleanField(6) + supportsExactlyOnce = _messages.BooleanField(7) class TopologyConfig(_messages.Message): @@ -6036,19 +6914,6 @@ class KindValueValuesEnum(_messages.Enum): outputCollectionName = _messages.StringField(6, repeated=True) -class ValidateResponse(_messages.Message): - r"""Response to the validation request. - - Fields: - errorMessage: Will be empty if validation succeeds. - queryInfo: Information about the validated query. Not defined if - validation fails. - """ - - errorMessage = _messages.StringField(1) - queryInfo = _messages.MessageField('QueryInfo', 2) - - class WorkItem(_messages.Message): r"""WorkItem represents basic information about a WorkItem to be executed in the cloud. @@ -6110,6 +6975,7 @@ class WorkItemDetails(_messages.Message): progress: Progress of this work item. startTime: Start time of this work item attempt. state: State of this work item. + stragglerInfo: Information about straggler detections for this work item. taskId: Name of this work item. 
""" class StateValueValuesEnum(_messages.Enum): @@ -6136,7 +7002,8 @@ class StateValueValuesEnum(_messages.Enum): progress = _messages.MessageField('ProgressTimeseries', 4) startTime = _messages.StringField(5) state = _messages.EnumField('StateValueValuesEnum', 6) - taskId = _messages.StringField(7) + stragglerInfo = _messages.MessageField('StragglerInfo', 7) + taskId = _messages.StringField(8) class WorkItemServiceState(_messages.Message): @@ -6459,6 +7326,8 @@ class WorkerMessage(_messages.Message): not be used here. Fields: + dataSamplingReport: Optional. Contains metrics related to go/dataflow- + data-sampling-telemetry. labels: Labels are used to group WorkerMessages. For example, a worker_message about a particular container might have the labels: { "JOB_ID": "2015-04-22", "WORKER_ID": "wordcount-vm-2015..." @@ -6466,12 +7335,16 @@ class WorkerMessage(_messages.Message): typically correspond to Label enum values. However, for ease of development other strings can be used as tags. LABEL_UNSPECIFIED should not be used here. + perWorkerMetrics: System defined metrics for this worker. + streamingScalingReport: Contains per-user worker telemetry used in + streaming autoscaling. time: The timestamp of the worker_message. workerHealthReport: The health of a worker. workerLifecycleEvent: Record of worker lifecycle events. workerMessageCode: A worker message code. workerMetrics: Resource metrics reported by workers. workerShutdownNotice: Shutdown notice by workers. + workerThreadScalingReport: Thread scaling information reported by workers. """ @encoding.MapUnrecognizedFields('additionalProperties') class LabelsValue(_messages.Message): @@ -6502,13 +7375,18 @@ class AdditionalProperty(_messages.Message): additionalProperties = _messages.MessageField( 'AdditionalProperty', 1, repeated=True) - labels = _messages.MessageField('LabelsValue', 1) - time = _messages.StringField(2) - workerHealthReport = _messages.MessageField('WorkerHealthReport', 3) - workerLifecycleEvent = _messages.MessageField('WorkerLifecycleEvent', 4) - workerMessageCode = _messages.MessageField('WorkerMessageCode', 5) - workerMetrics = _messages.MessageField('ResourceUtilizationReport', 6) - workerShutdownNotice = _messages.MessageField('WorkerShutdownNotice', 7) + dataSamplingReport = _messages.MessageField('DataSamplingReport', 1) + labels = _messages.MessageField('LabelsValue', 2) + perWorkerMetrics = _messages.MessageField('PerWorkerMetrics', 3) + streamingScalingReport = _messages.MessageField('StreamingScalingReport', 4) + time = _messages.StringField(5) + workerHealthReport = _messages.MessageField('WorkerHealthReport', 6) + workerLifecycleEvent = _messages.MessageField('WorkerLifecycleEvent', 7) + workerMessageCode = _messages.MessageField('WorkerMessageCode', 8) + workerMetrics = _messages.MessageField('ResourceUtilizationReport', 9) + workerShutdownNotice = _messages.MessageField('WorkerShutdownNotice', 10) + workerThreadScalingReport = _messages.MessageField( + 'WorkerThreadScalingReport', 11) class WorkerMessageCode(_messages.Message): @@ -6600,20 +7478,28 @@ class WorkerMessageResponse(_messages.Message): sender. Fields: + streamingScalingReportResponse: Service's streaming scaling response for + workers. workerHealthReportResponse: The service's response to a worker's health report. workerMetricsResponse: Service's response to reporting worker metrics (currently empty). workerShutdownNoticeResponse: Service's response to shutdown notice (currently empty). 
+ workerThreadScalingReportResponse: Service's thread scaling recommendation + for workers. """ + streamingScalingReportResponse = _messages.MessageField( + 'StreamingScalingReportResponse', 1) workerHealthReportResponse = _messages.MessageField( - 'WorkerHealthReportResponse', 1) + 'WorkerHealthReportResponse', 2) workerMetricsResponse = _messages.MessageField( - 'ResourceUtilizationReportResponse', 2) + 'ResourceUtilizationReportResponse', 3) workerShutdownNoticeResponse = _messages.MessageField( - 'WorkerShutdownNoticeResponse', 3) + 'WorkerShutdownNoticeResponse', 4) + workerThreadScalingReportResponse = _messages.MessageField( + 'WorkerThreadScalingReportResponse', 5) class WorkerPool(_messages.Message): @@ -6884,6 +7770,29 @@ class WorkerShutdownNoticeResponse(_messages.Message): r"""Service-side response to WorkerMessage issuing shutdown notice.""" +class WorkerThreadScalingReport(_messages.Message): + r"""Contains information about the thread scaling information of a worker. + + Fields: + currentThreadCount: Current number of active threads in a worker. + """ + + currentThreadCount = _messages.IntegerField( + 1, variant=_messages.Variant.INT32) + + +class WorkerThreadScalingReportResponse(_messages.Message): + r"""Contains the thread scaling recommendation for a worker from the + backend. + + Fields: + recommendedThreadCount: Recommended number of threads for a worker. + """ + + recommendedThreadCount = _messages.IntegerField( + 1, variant=_messages.Variant.INT32) + + class WriteInstruction(_messages.Message): r"""An instruction that writes records. Takes one input, produces no outputs. diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py index e4beefe992c1..95d8c06111a2 100644 --- a/sdks/python/apache_beam/runners/direct/bundle_factory.py +++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py @@ -40,16 +40,17 @@ class BundleFactory(object): in case consecutive ones share the same timestamp and windows. DirectRunnerOptions.direct_runner_use_stacked_bundle controls this option. """ - def __init__(self, stacked): - # type: (bool) -> None + def __init__(self, stacked: bool) -> None: self._stacked = stacked - def create_bundle(self, output_pcollection): - # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle + def create_bundle( + self, output_pcollection: Union[pvalue.PBegin, + pvalue.PCollection]) -> '_Bundle': return _Bundle(output_pcollection, self._stacked) - def create_empty_committed_bundle(self, output_pcollection): - # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle + def create_empty_committed_bundle( + self, output_pcollection: Union[pvalue.PBegin, + pvalue.PCollection]) -> '_Bundle': bundle = self.create_bundle(output_pcollection) bundle.commit(None) return bundle @@ -110,27 +111,27 @@ def pane_info(self): def add_value(self, value): self._appended_values.append(value) - def windowed_values(self): - # type: () -> Iterator[WindowedValue] + def windowed_values(self) -> Iterator[WindowedValue]: # yield first windowed_value as is, then iterate through # _appended_values to yield WindowedValue on the fly. 
yield self._initial_windowed_value for v in self._appended_values: yield self._initial_windowed_value.with_value(v) - def __init__(self, pcollection, stacked=True): - # type: (Union[pvalue.PBegin, pvalue.PCollection], bool) -> None + def __init__( + self, + pcollection: Union[pvalue.PBegin, pvalue.PCollection], + stacked: bool = True) -> None: assert isinstance(pcollection, (pvalue.PBegin, pvalue.PCollection)) self._pcollection = pcollection - self._elements = [ - ] # type: List[Union[WindowedValue, _Bundle._StackedWindowedValues]] + self._elements: List[Union[WindowedValue, + _Bundle._StackedWindowedValues]] = [] self._stacked = stacked self._committed = False self._tag = None # optional tag information for this bundle - def get_elements_iterable(self, make_copy=False): - # type: (bool) -> Iterable[WindowedValue] - + def get_elements_iterable(self, + make_copy: bool = False) -> Iterable[WindowedValue]: """Returns iterable elements. Args: @@ -203,8 +204,7 @@ def add(self, element): def output(self, element): self.add(element) - def receive(self, element): - # type: (WindowedValue) -> None + def receive(self, element: WindowedValue) -> None: self.add(element) def commit(self, synchronized_processing_time): diff --git a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py index 2a6fc3ee6093..91085274f32a 100644 --- a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py +++ b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py @@ -19,16 +19,13 @@ # pytype: skip-file -from typing import TYPE_CHECKING from typing import Dict from typing import Set from apache_beam import pvalue +from apache_beam.pipeline import AppliedPTransform from apache_beam.pipeline import PipelineVisitor -if TYPE_CHECKING: - from apache_beam.pipeline import AppliedPTransform - class ConsumerTrackingPipelineVisitor(PipelineVisitor): """For internal use only; no backwards-compatibility guarantees. @@ -40,10 +37,9 @@ class ConsumerTrackingPipelineVisitor(PipelineVisitor): transform has produced and committed output. 
""" def __init__(self): - self.value_to_consumers = { - } # type: Dict[pvalue.PValue, Set[AppliedPTransform]] - self.root_transforms = set() # type: Set[AppliedPTransform] - self.step_names = {} # type: Dict[AppliedPTransform, str] + self.value_to_consumers: Dict[pvalue.PValue, Set[AppliedPTransform]] = {} + self.root_transforms: Set[AppliedPTransform] = set() + self.step_names: Dict[AppliedPTransform, str] = {} self._num_transforms = 0 self._views = set() @@ -57,8 +53,7 @@ def views(self): """ return list(self._views) - def visit_transform(self, applied_ptransform): - # type: (AppliedPTransform) -> None + def visit_transform(self, applied_ptransform: AppliedPTransform) -> None: inputs = list(applied_ptransform.inputs) if inputs: for input_value in inputs: diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py index e4fd44053119..f715ce3bf521 100644 --- a/sdks/python/apache_beam/runners/direct/direct_metrics.py +++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py @@ -28,6 +28,7 @@ from apache_beam.metrics.cells import CounterAggregator from apache_beam.metrics.cells import DistributionAggregator from apache_beam.metrics.cells import GaugeAggregator +from apache_beam.metrics.cells import StringSetAggregator from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.metric import MetricResults @@ -39,6 +40,7 @@ def __init__(self): self._distributions = defaultdict( lambda: DirectMetric(DistributionAggregator())) self._gauges = defaultdict(lambda: DirectMetric(GaugeAggregator())) + self._string_sets = defaultdict(lambda: DirectMetric(StringSetAggregator())) def _apply_operation(self, bundle, updates, op): for k, v in updates.counters.items(): @@ -50,6 +52,9 @@ def _apply_operation(self, bundle, updates, op): for k, v in updates.gauges.items(): op(self._gauges[k], bundle, v) + for k, v in updates.string_sets.items(): + op(self._string_sets[k], bundle, v) + def commit_logical(self, bundle, updates): op = lambda obj, bundle, update: obj.commit_logical(bundle, update) self._apply_operation(bundle, updates, op) @@ -84,11 +89,19 @@ def query(self, filter=None): v.extract_latest_attempted()) for k, v in self._gauges.items() if self.matches(filter, k) ] + string_sets = [ + MetricResult( + MetricKey(k.step, k.metric), + v.extract_committed(), + v.extract_latest_attempted()) for k, + v in self._string_sets.items() if self.matches(filter, k) + ] return { self.COUNTERS: counters, self.DISTRIBUTIONS: distributions, - self.GAUGES: gauges + self.GAUGES: gauges, + self.STRINGSETS: string_sets } diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index a470ba80d8ee..1cd20550edf3 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -390,8 +390,9 @@ def __init__(self, source): self._source = source def _infer_output_coder( - self, unused_input_type=None, unused_input_coder=None): - # type: (...) 
-> typing.Optional[coders.Coder] + self, + unused_input_type=None, + unused_input_coder=None) -> typing.Optional[coders.Coder]: return coders.BytesCoder() def get_windowing(self, unused_inputs): diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py index 58cec732d3fa..d8f1ea097b88 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py @@ -76,6 +76,8 @@ def process(self, element): count.inc() distro = Metrics.distribution(self.__class__, 'element_dist') distro.update(element) + str_set = Metrics.string_set(self.__class__, 'element_str_set') + str_set.add(str(element % 4)) return [element] p = Pipeline(DirectRunner()) @@ -115,6 +117,13 @@ def process(self, element): hc.assert_that(gauge_result.committed.value, hc.equal_to(5)) hc.assert_that(gauge_result.attempted.value, hc.equal_to(5)) + str_set_result = metrics['string_sets'][0] + hc.assert_that( + str_set_result.key, + hc.equal_to(MetricKey('Do', MetricName(namespace, 'element_str_set')))) + hc.assert_that(len(str_set_result.committed), hc.equal_to(4)) + hc.assert_that(len(str_set_result.attempted), hc.equal_to(4)) + def test_create_runner(self): self.assertTrue(isinstance(create_runner('DirectRunner'), DirectRunner)) self.assertTrue( diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index fbe59b072ae4..c34735499abc 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -31,21 +31,21 @@ from typing import Tuple from typing import Union +from apache_beam import pvalue +from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.direct_metrics import DirectMetrics from apache_beam.runners.direct.executor import TransformExecutor from apache_beam.runners.direct.watermark_manager import WatermarkManager from apache_beam.transforms import sideinputs from apache_beam.transforms.trigger import InMemoryUnmergedState from apache_beam.utils import counters +from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: - from apache_beam import pvalue - from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.bundle_factory import BundleFactory, _Bundle from apache_beam.runners.direct.util import TimerFiring from apache_beam.runners.direct.util import TransformResult from apache_beam.runners.direct.watermark_manager import _TransformWatermarks - from apache_beam.utils.timestamp import Timestamp class _ExecutionContext(object): @@ -53,10 +53,7 @@ class _ExecutionContext(object): It holds the watermarks for that transform, as well as keyed states. """ - def __init__( - self, - watermarks, # type: _TransformWatermarks - keyed_states): + def __init__(self, watermarks: '_TransformWatermarks', keyed_states): self.watermarks = watermarks self.keyed_states = keyed_states @@ -91,13 +88,12 @@ class _SideInputsContainer(object): It provides methods for blocking until a side-input is available and writing to a side input. 
""" - def __init__(self, side_inputs): - # type: (Iterable[pvalue.AsSideInput]) -> None + def __init__(self, side_inputs: Iterable['pvalue.AsSideInput']) -> None: self._lock = threading.Lock() - self._views = {} # type: Dict[pvalue.AsSideInput, _SideInputView] - self._transform_to_side_inputs = collections.defaultdict( - list - ) # type: DefaultDict[Optional[AppliedPTransform], List[pvalue.AsSideInput]] + self._views: Dict[pvalue.AsSideInput, _SideInputView] = {} + self._transform_to_side_inputs: DefaultDict[ + Optional[AppliedPTransform], + List[pvalue.AsSideInput]] = collections.defaultdict(list) # this appears unused: self._side_input_to_blocked_tasks = collections.defaultdict(list) # type: ignore @@ -111,13 +107,8 @@ def __repr__(self): for elm in self._views.values()) if self._views else '[]') return '_SideInputsContainer(_views=%s)' % views_string - def get_value_or_block_until_ready(self, - side_input, - task, # type: TransformExecutor - block_until # type: Timestamp - ): - # type: (...) -> Any - + def get_value_or_block_until_ready( + self, side_input, task: TransformExecutor, block_until: Timestamp) -> Any: """Returns the value of a view whose task is unblocked or blocks its task. It gets the value of a view whose watermark has been updated and @@ -147,9 +138,7 @@ def add_values(self, side_input, values): view.elements.extend(values) def update_watermarks_for_transform_and_unblock_tasks( - self, ptransform, watermark): - # type: (...) -> List[Tuple[TransformExecutor, Timestamp]] - + self, ptransform, watermark) -> List[Tuple[TransformExecutor, Timestamp]]: """Updates _SideInputsContainer after a watermark update and unbloks tasks. It traverses the list of side inputs per PTransform and calls @@ -170,9 +159,7 @@ def update_watermarks_for_transform_and_unblock_tasks( return unblocked_tasks def _update_watermarks_for_side_input_and_unblock_tasks( - self, side_input, watermark): - # type: (...) -> List[Tuple[TransformExecutor, Timestamp]] - + self, side_input, watermark) -> List[Tuple[TransformExecutor, Timestamp]]: """Helps update _SideInputsContainer after a watermark update. For each view of the side input, it updates the value of the watermark @@ -238,24 +225,24 @@ class EvaluationContext(object): appropriately. This includes updating the per-(step,key) state, updating global watermarks, and executing any callbacks that can be executed. 
""" - - def __init__(self, - pipeline_options, - bundle_factory, # type: BundleFactory - root_transforms, - value_to_consumers, - step_names, - views, # type: Iterable[pvalue.AsSideInput] - clock - ): + def __init__( + self, + pipeline_options, + bundle_factory: 'BundleFactory', + root_transforms, + value_to_consumers, + step_names, + views: Iterable[pvalue.AsSideInput], + clock): self.pipeline_options = pipeline_options self._bundle_factory = bundle_factory self._root_transforms = root_transforms self._value_to_consumers = value_to_consumers self._step_names = step_names self.views = views - self._pcollection_to_views = collections.defaultdict( - list) # type: DefaultDict[pvalue.PValue, List[pvalue.AsSideInput]] + self._pcollection_to_views: DefaultDict[ + pvalue.PValue, + List[pvalue.AsSideInput]] = collections.defaultdict(list) for view in views: self._pcollection_to_views[view.pvalue].append(view) self._transform_keyed_states = self._initialize_keyed_states( @@ -266,8 +253,8 @@ def __init__(self, root_transforms, value_to_consumers, self._transform_keyed_states) - self._pending_unblocked_tasks = [ - ] # type: List[Tuple[TransformExecutor, Timestamp]] + self._pending_unblocked_tasks: List[Tuple[TransformExecutor, + Timestamp]] = [] self._counter_factory = counters.CounterFactory() self._metrics = DirectMetrics() @@ -291,15 +278,14 @@ def metrics(self): # TODO. Should this be made a @property? return self._metrics - def is_root_transform(self, applied_ptransform): - # type: (AppliedPTransform) -> bool + def is_root_transform(self, applied_ptransform: AppliedPTransform) -> bool: return applied_ptransform in self._root_transforms - def handle_result(self, - completed_bundle, # type: _Bundle - completed_timers, - result # type: TransformResult - ): + def handle_result( + self, + completed_bundle: '_Bundle', + completed_timers, + result: 'TransformResult'): """Handle the provided result produced after evaluating the input bundle. Handle the provided TransformResult, produced after evaluating @@ -352,10 +338,8 @@ def handle_result(self, existing_keyed_state[k] = v return committed_bundles - def _update_side_inputs_container(self, - committed_bundles, # type: Iterable[_Bundle] - result # type: TransformResult - ): + def _update_side_inputs_container( + self, committed_bundles: Iterable['_Bundle'], result: 'TransformResult'): """Update the side inputs container if we are outputting into a side input. Look at the result, and if it's outputing into a PCollection that we have @@ -381,12 +365,11 @@ def schedule_pending_unblocked_tasks(self, executor_service): executor_service.submit(task) self._pending_unblocked_tasks = [] - def _commit_bundles(self, - uncommitted_bundles, # type: Iterable[_Bundle] - unprocessed_bundles # type: Iterable[_Bundle] - ): - # type: (...) 
-> Tuple[Tuple[_Bundle, ...], Tuple[_Bundle, ...]] - + def _commit_bundles( + self, + uncommitted_bundles: Iterable['_Bundle'], + unprocessed_bundles: Iterable['_Bundle'] + ) -> Tuple[Tuple['_Bundle', ...], Tuple['_Bundle', ...]]: """Commits bundles and returns a immutable set of committed bundles.""" for in_progress_bundle in uncommitted_bundles: producing_applied_ptransform = in_progress_bundle.pcollection.producer @@ -398,32 +381,29 @@ def _commit_bundles(self, unprocessed_bundle.commit(None) return tuple(uncommitted_bundles), tuple(unprocessed_bundles) - def get_execution_context(self, applied_ptransform): - # type: (AppliedPTransform) -> _ExecutionContext + def get_execution_context( + self, applied_ptransform: AppliedPTransform) -> _ExecutionContext: return _ExecutionContext( self._watermark_manager.get_watermarks(applied_ptransform), self._transform_keyed_states[applied_ptransform]) - def create_bundle(self, output_pcollection): - # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle - + def create_bundle( + self, output_pcollection: Union[pvalue.PBegin, + pvalue.PCollection]) -> '_Bundle': """Create an uncommitted bundle for the specified PCollection.""" return self._bundle_factory.create_bundle(output_pcollection) - def create_empty_committed_bundle(self, output_pcollection): - # type: (pvalue.PCollection) -> _Bundle - + def create_empty_committed_bundle( + self, output_pcollection: pvalue.PCollection) -> '_Bundle': """Create empty bundle useful for triggering evaluation.""" return self._bundle_factory.create_empty_committed_bundle( output_pcollection) - def extract_all_timers(self): - # type: () -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool] + def extract_all_timers( + self) -> Tuple[List[Tuple[AppliedPTransform, List['TimerFiring']]], bool]: return self._watermark_manager.extract_all_timers() - def is_done(self, transform=None): - # type: (Optional[AppliedPTransform]) -> bool - + def is_done(self, transform: Optional[AppliedPTransform] = None) -> bool: """Checks completion of a step or the pipeline. Args: @@ -441,8 +421,7 @@ def is_done(self, transform=None): return False return True - def _is_transform_done(self, transform): - # type: (AppliedPTransform) -> bool + def _is_transform_done(self, transform: AppliedPTransform) -> bool: tw = self._watermark_manager.get_watermarks(transform) return tw.output_watermark == WatermarkManager.WATERMARK_POS_INF diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index 0ab3033d68b5..e8be9d64f993 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -64,9 +64,7 @@ class _ExecutorServiceWorker(threading.Thread): TIMEOUT = 5 def __init__( - self, - queue, # type: queue.Queue[_ExecutorService.CallableTask] - index): + self, queue: 'queue.Queue[_ExecutorService.CallableTask]', index): super().__init__() self.queue = queue self._index = index @@ -86,8 +84,7 @@ def _update_name(self, task=None): self.name = 'Thread: %d, %s (%s)' % ( self._index, name, 'executing' if task else 'idle') - def _get_task_or_none(self): - # type: () -> Optional[_ExecutorService.CallableTask] + def _get_task_or_none(self) -> Optional['_ExecutorService.CallableTask']: try: # Do not block indefinitely, otherwise we may not act for a requested # shutdown. 
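The evaluation_context.py and executor.py hunks above all apply the same mechanical change: Python 2 style trailing `# type:` comments are folded into PEP 484 inline annotations, with forward references kept as quoted strings (e.g. `'_Bundle'`, `'_TransformWatermarks'`) wherever the referenced class is only imported under `TYPE_CHECKING` or defined later in the module. Below is a minimal, self-contained sketch of that before/after pattern; the `Registry`/`Widget` names are illustrative only and do not appear in the Beam sources.

from typing import Dict, List, Optional


class Widget:
    """Stand-in for a class that, in the real code, might only be importable
    under TYPE_CHECKING or defined later in the module."""


class Registry:
    # Before the migration, the constructor's types lived in a comment:
    #
    #     def __init__(self, initial=None):
    #         # type: (Optional[List[Widget]]) -> None
    #
    # After it, the same contract is expressed inline; the forward reference
    # stays a quoted string when the name is not yet (or not always) defined.
    def __init__(self, initial: Optional[List['Widget']] = None) -> None:
        # Variable annotation replaces the old "# type: Dict[str, Widget]" comment.
        self._items: Dict[str, 'Widget'] = {}
        for index, widget in enumerate(initial or []):
            self._items[str(index)] = widget

    def get(self, key: str) -> Optional['Widget']:
        return self._items.get(key)


if __name__ == '__main__':
    registry = Registry([Widget()])
    assert isinstance(registry.get('0'), Widget)
    assert registry.get('1') is None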
@@ -114,16 +111,14 @@ def shutdown(self): self.shutdown_requested = True def __init__(self, num_workers): - self.queue = queue.Queue( - ) # type: queue.Queue[_ExecutorService.CallableTask] + self.queue: queue.Queue[_ExecutorService.CallableTask] = queue.Queue() self.workers = [ _ExecutorService._ExecutorServiceWorker(self.queue, i) for i in range(num_workers) ] self.shutdown_requested = False - def submit(self, task): - # type: (_ExecutorService.CallableTask) -> None + def submit(self, task: '_ExecutorService.CallableTask') -> None: assert isinstance(task, _ExecutorService.CallableTask) if not self.shutdown_requested: self.queue.put(task) @@ -150,11 +145,7 @@ def shutdown(self): class _TransformEvaluationState(object): - def __init__( - self, - executor_service, - scheduled # type: Set[TransformExecutor] - ): + def __init__(self, executor_service, scheduled: Set['TransformExecutor']): self.executor_service = executor_service self.scheduled = scheduled @@ -219,21 +210,18 @@ class _TransformExecutorServices(object): Controls the concurrency as appropriate for the applied transform the executor exists for. """ - def __init__(self, executor_service): - # type: (_ExecutorService) -> None + def __init__(self, executor_service: _ExecutorService) -> None: self._executor_service = executor_service - self._scheduled = set() # type: Set[TransformExecutor] + self._scheduled: Set[TransformExecutor] = set() self._parallel = _ParallelEvaluationState( self._executor_service, self._scheduled) - self._serial_cache = WeakValueDictionary( - ) # type: WeakValueDictionary[Any, _SerialEvaluationState] + self._serial_cache: WeakValueDictionary[ + Any, _SerialEvaluationState] = WeakValueDictionary() - def parallel(self): - # type: () -> _ParallelEvaluationState + def parallel(self) -> _ParallelEvaluationState: return self._parallel - def serial(self, step): - # type: (Any) -> _SerialEvaluationState + def serial(self, step: Any) -> _SerialEvaluationState: cached = self._serial_cache.get(step) if not cached: cached = _SerialEvaluationState(self._executor_service, self._scheduled) @@ -241,8 +229,7 @@ def serial(self, step): return cached @property - def executors(self): - # type: () -> FrozenSet[TransformExecutor] + def executors(self) -> FrozenSet['TransformExecutor']: return frozenset(self._scheduled) @@ -253,12 +240,11 @@ class _CompletionCallback(object): that are triggered due to the arrival of elements from an upstream transform, or for a source transform. 
""" - - def __init__(self, - evaluation_context, # type: EvaluationContext - all_updates, - timer_firings=None - ): + def __init__( + self, + evaluation_context: 'EvaluationContext', + all_updates, + timer_firings=None): self._evaluation_context = evaluation_context self._all_updates = all_updates self._timer_firings = timer_firings or [] @@ -295,15 +281,15 @@ class TransformExecutor(_ExecutorService.CallableTask): _MAX_RETRY_PER_BUNDLE = 4 - def __init__(self, - transform_evaluator_registry, # type: TransformEvaluatorRegistry - evaluation_context, # type: EvaluationContext - input_bundle, # type: _Bundle - fired_timers, - applied_ptransform, - completion_callback, - transform_evaluation_state # type: _TransformEvaluationState - ): + def __init__( + self, + transform_evaluator_registry: 'TransformEvaluatorRegistry', + evaluation_context: 'EvaluationContext', + input_bundle: '_Bundle', + fired_timers, + applied_ptransform, + completion_callback, + transform_evaluation_state: _TransformEvaluationState): self._transform_evaluator_registry = transform_evaluator_registry self._evaluation_context = evaluation_context self._input_bundle = input_bundle @@ -319,7 +305,7 @@ def __init__(self, self._applied_ptransform = applied_ptransform self._completion_callback = completion_callback self._transform_evaluation_state = transform_evaluation_state - self._side_input_values = {} # type: Dict[pvalue.AsSideInput, Any] + self._side_input_values: Dict[pvalue.AsSideInput, Any] = {} self.blocked = False self._call_count = 0 self._retry_count = 0 @@ -444,8 +430,7 @@ def __init__( self, value_to_consumers, transform_evaluator_registry, - evaluation_context # type: EvaluationContext - ): + evaluation_context: 'EvaluationContext'): self.executor_service = _ExecutorService( _ExecutorServiceParallelExecutor.NUM_WORKERS) self.transform_executor_services = _TransformExecutorServices( @@ -487,8 +472,7 @@ def request_shutdown(self): self.executor_service.await_completion() self.evaluation_context.shutdown() - def schedule_consumers(self, committed_bundle): - # type: (_Bundle) -> None + def schedule_consumers(self, committed_bundle: '_Bundle') -> None: if committed_bundle.pcollection in self.value_to_consumers: consumers = self.value_to_consumers[committed_bundle.pcollection] for applied_ptransform in consumers: @@ -500,20 +484,20 @@ def schedule_consumers(self, committed_bundle): def schedule_unprocessed_bundle(self, applied_ptransform, unprocessed_bundle): self.node_to_pending_bundles[applied_ptransform].append(unprocessed_bundle) - def schedule_consumption(self, - consumer_applied_ptransform, - committed_bundle, # type: _Bundle - fired_timers, - on_complete - ): + def schedule_consumption( + self, + consumer_applied_ptransform, + committed_bundle: '_Bundle', + fired_timers, + on_complete): """Schedules evaluation of the given bundle with the transform.""" assert consumer_applied_ptransform assert committed_bundle assert on_complete if self.transform_evaluator_registry.should_execute_serially( consumer_applied_ptransform): - transform_executor_service = self.transform_executor_services.serial( - consumer_applied_ptransform) # type: _TransformEvaluationState + transform_executor_service: _TransformEvaluationState = ( + self.transform_executor_services.serial(consumer_applied_ptransform)) else: transform_executor_service = self.transform_executor_services.parallel() @@ -587,8 +571,7 @@ def __init__(self, exception=None): class _MonitorTask(_ExecutorService.CallableTask): """MonitorTask continuously runs to ensure that 
pipeline makes progress.""" - def __init__(self, executor): - # type: (_ExecutorServiceParallelExecutor) -> None + def __init__(self, executor: '_ExecutorServiceParallelExecutor') -> None: self._executor = executor @property @@ -624,9 +607,7 @@ def call(self, state_sampler): if not self._should_shutdown(): self._executor.executor_service.submit(self) - def _should_shutdown(self): - # type: () -> bool - + def _should_shutdown(self) -> bool: """Checks whether the pipeline is completed and should be shut down. If there is anything in the queue of tasks to do or @@ -690,9 +671,7 @@ def _fire_timers(self): timer_completion_callback) return bool(transform_fired_timers) - def _is_executing(self): - # type: () -> bool - + def _is_executing(self) -> bool: """Checks whether the job is still executing. Returns: diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index 528b2d1f576b..e0a58db0ef3e 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -23,7 +23,6 @@ import uuid from threading import Lock from threading import Timer -from typing import TYPE_CHECKING from typing import Any from typing import Iterable from typing import Optional @@ -32,6 +31,7 @@ from apache_beam import TimeDomain from apache_beam import pvalue from apache_beam.coders import typecoders +from apache_beam.io.iobase import WatermarkEstimator from apache_beam.pipeline import AppliedPTransform from apache_beam.pipeline import PTransformOverride from apache_beam.runners.common import DoFnContext @@ -47,9 +47,6 @@ from apache_beam.transforms.trigger import _ReadModifyWriteStateTag from apache_beam.utils.windowed_value import WindowedValue -if TYPE_CHECKING: - from apache_beam.iobase import WatermarkEstimator - class SplittableParDoOverride(PTransformOverride): """A transform override for ParDo transformss of SplittableDoFns. @@ -541,8 +538,10 @@ def __init__(self): self.output_iter = None def handle_process_outputs( - self, windowed_input_element, output_iter, watermark_estimator=None): - # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None + self, + windowed_input_element: WindowedValue, + output_iter: Iterable[Any], + watermark_estimator: Optional[WatermarkEstimator] = None) -> None: self.output_iter = output_iter def reset(self): @@ -551,6 +550,8 @@ def reset(self): class _NoneShallPassOutputHandler(OutputHandler): def handle_process_outputs( - self, windowed_input_element, output_iter, watermark_estimator=None): - # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None + self, + windowed_input_element: WindowedValue, + output_iter: Iterable[Any], + watermark_estimator: Optional[WatermarkEstimator] = None) -> None: raise RuntimeError() diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py index 0842a51d5666..c720418b05ed 100644 --- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py +++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py @@ -30,6 +30,7 @@ from queue import Empty as EmptyException from queue import Queue from threading import Thread +from typing import Union import grpc @@ -309,8 +310,8 @@ def is_alive(): return not (shutdown_requested or evaluation_context.shutdown_requested) # The shared queue that allows the producer and consumer to communicate. 
- channel = Queue( - ) # type: Queue[Union[test_stream.Event, _EndOfStream]] # noqa: F821 + channel: 'Queue[Union[test_stream.Event, _EndOfStream]]' = ( # noqa: F821 + Queue()) event_stream = Thread( target=_TestStream._stream_events_from_rpc, args=(endpoint, output_tags, coder, channel, is_alive)) diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index 37004c7258a7..b0278ba5356c 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -36,6 +36,7 @@ from apache_beam import io from apache_beam import pvalue from apache_beam.internal import pickler +from apache_beam.pipeline import AppliedPTransform from apache_beam.runners import common from apache_beam.runners.common import DoFnRunner from apache_beam.runners.common import DoFnState @@ -77,7 +78,6 @@ if TYPE_CHECKING: from apache_beam.io.gcp.pubsub import _PubSubSource from apache_beam.io.gcp.pubsub import PubsubMessage - from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.evaluation_context import EvaluationContext _LOGGER = logging.getLogger(__name__) @@ -89,14 +89,13 @@ class TransformEvaluatorRegistry(object): Creates instances of TransformEvaluator for the application of a transform. """ - _test_evaluators_overrides = { - } # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]] + _test_evaluators_overrides: Dict[Type[core.PTransform], + Type['_TransformEvaluator']] = {} - def __init__(self, evaluation_context): - # type: (EvaluationContext) -> None + def __init__(self, evaluation_context: 'EvaluationContext') -> None: assert evaluation_context self._evaluation_context = evaluation_context - self._evaluators = { + self._evaluators: Dict[Type[core.PTransform], Type[_TransformEvaluator]] = { io.Read: _BoundedReadEvaluator, _DirectReadFromPubSub: _PubSubReadEvaluator, core.Flatten: _FlattenEvaluator, @@ -109,7 +108,7 @@ def __init__(self, evaluation_context): ProcessElements: _ProcessElementsEvaluator, _WatermarkController: _WatermarkControllerEvaluator, PairWithTiming: _PairWithTimingEvaluator, - } # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]] + } self._evaluators.update(self._test_evaluators_overrides) self._root_bundle_providers = { core.PTransform: DefaultRootBundleProvider, @@ -231,13 +230,12 @@ def get_root_bundles(self): class _TransformEvaluator(object): """An evaluator of a specific application of a transform.""" - - def __init__(self, - evaluation_context, # type: EvaluationContext - applied_ptransform, # type: AppliedPTransform - input_committed_bundle, - side_inputs - ): + def __init__( + self, + evaluation_context: 'EvaluationContext', + applied_ptransform: AppliedPTransform, + input_committed_bundle, + side_inputs): self._evaluation_context = evaluation_context self._applied_ptransform = applied_ptransform self._input_committed_bundle = input_committed_bundle @@ -321,9 +319,7 @@ def process_element(self, element): """Processes a new element as part of the current bundle.""" raise NotImplementedError('%s do not process elements.' % type(self)) - def finish_bundle(self): - # type: () -> TransformResult - + def finish_bundle(self) -> TransformResult: """Finishes the bundle and produces output.""" pass @@ -592,7 +588,7 @@ class _PubSubReadEvaluator(_TransformEvaluator): # A mapping of transform to _PubSubSubscriptionWrapper. 
# TODO(https://github.com/apache/beam/issues/19751): Prevents garbage # collection of pipeline instances. - _subscription_cache = {} # type: Dict[AppliedPTransform, str] + _subscription_cache: Dict[AppliedPTransform, str] = {} def __init__( self, @@ -607,7 +603,7 @@ def __init__( input_committed_bundle, side_inputs) - self.source = self._applied_ptransform.transform._source # type: _PubSubSource + self.source: _PubSubSource = self._applied_ptransform.transform._source if self.source.id_label: raise NotImplementedError( 'DirectRunner: id_label is not supported for PubSub reads') @@ -655,8 +651,8 @@ def start_bundle(self): def process_element(self, element): pass - def _read_from_pubsub(self, timestamp_attribute): - # type: (...) -> List[Tuple[Timestamp, PubsubMessage]] + def _read_from_pubsub( + self, timestamp_attribute) -> List[Tuple[Timestamp, 'PubsubMessage']]: from apache_beam.io.gcp.pubsub import PubsubMessage from google.cloud import pubsub @@ -699,8 +695,7 @@ def _get_element(message): return results - def finish_bundle(self): - # type: () -> TransformResult + def finish_bundle(self) -> TransformResult: data = self._read_from_pubsub(self.source.timestamp_attribute) if data: output_pcollection = list(self._outputs)[0] @@ -777,8 +772,7 @@ def __init__(self, evaluation_context): class NullReceiver(common.Receiver): """Ignores undeclared outputs, default execution mode.""" - def receive(self, element): - # type: (WindowedValue) -> None + def receive(self, element: WindowedValue) -> None: pass class _InMemoryReceiver(common.Receiver): @@ -787,8 +781,7 @@ def __init__(self, target, tag): self._target = target self._tag = tag - def receive(self, element): - # type: (WindowedValue) -> None + def receive(self, element: WindowedValue) -> None: self._target[self._tag].append(element) def __missing__(self, key): @@ -799,14 +792,13 @@ def __missing__(self, key): class _ParDoEvaluator(_TransformEvaluator): """TransformEvaluator for ParDo transform.""" - - def __init__(self, - evaluation_context, # type: EvaluationContext - applied_ptransform, # type: AppliedPTransform - input_committed_bundle, - side_inputs, - perform_dofn_pickle_test=True - ): + def __init__( + self, + evaluation_context: 'EvaluationContext', + applied_ptransform: AppliedPTransform, + input_committed_bundle, + side_inputs, + perform_dofn_pickle_test=True): super().__init__( evaluation_context, applied_ptransform, diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 8f97de508ff5..666ade6cf82d 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -29,15 +29,15 @@ from apache_beam import pipeline from apache_beam import pvalue +from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.util import TimerFiring from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import TIME_GRANULARITY +from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: - from apache_beam.pipeline import AppliedPTransform from apache_beam.runners.direct.bundle_factory import _Bundle - from apache_beam.utils.timestamp import Timestamp class WatermarkManager(object): @@ -55,8 +55,8 @@ def __init__( self._value_to_consumers = value_to_consumers self._transform_keyed_states = transform_keyed_states # AppliedPTransform -> TransformWatermarks - 
self._transform_to_watermarks = { - } # type: Dict[AppliedPTransform, _TransformWatermarks] + self._transform_to_watermarks: Dict[AppliedPTransform, + _TransformWatermarks] = {} for root_transform in root_transforms: self._transform_to_watermarks[root_transform] = _TransformWatermarks( @@ -71,8 +71,8 @@ def __init__( for consumer in consumers: self._update_input_transform_watermarks(consumer) - def _update_input_transform_watermarks(self, applied_ptransform): - # type: (AppliedPTransform) -> None + def _update_input_transform_watermarks( + self, applied_ptransform: AppliedPTransform) -> None: assert isinstance(applied_ptransform, pipeline.AppliedPTransform) input_transform_watermarks = [] for input_pvalue in applied_ptransform.inputs: @@ -84,9 +84,8 @@ def _update_input_transform_watermarks(self, applied_ptransform): applied_ptransform].update_input_transform_watermarks( input_transform_watermarks) - def get_watermarks(self, applied_ptransform): - # type: (AppliedPTransform) -> _TransformWatermarks - + def get_watermarks( + self, applied_ptransform: AppliedPTransform) -> '_TransformWatermarks': """Gets the input and output watermarks for an AppliedPTransform. If the applied_ptransform has not processed any elements, return a @@ -107,15 +106,15 @@ def get_watermarks(self, applied_ptransform): return self._transform_to_watermarks[applied_ptransform] - def update_watermarks(self, - completed_committed_bundle, # type: _Bundle - applied_ptransform, # type: AppliedPTransform - completed_timers, - outputs, - unprocessed_bundles, - keyed_earliest_holds, - side_inputs_container - ): + def update_watermarks( + self, + completed_committed_bundle: '_Bundle', + applied_ptransform: AppliedPTransform, + completed_timers, + outputs, + unprocessed_bundles, + keyed_earliest_holds, + side_inputs_container): assert isinstance(applied_ptransform, pipeline.AppliedPTransform) self._update_pending( completed_committed_bundle, @@ -127,13 +126,13 @@ def update_watermarks(self, tw.hold(keyed_earliest_holds) return self._refresh_watermarks(applied_ptransform, side_inputs_container) - def _update_pending(self, - input_committed_bundle, - applied_ptransform, # type: AppliedPTransform - completed_timers, - output_committed_bundles, # type: Iterable[_Bundle] - unprocessed_bundles # type: Iterable[_Bundle] - ): + def _update_pending( + self, + input_committed_bundle, + applied_ptransform: AppliedPTransform, + completed_timers, + output_committed_bundles: Iterable['_Bundle'], + unprocessed_bundles: Iterable['_Bundle']): """Updated list of pending bundles for the given AppliedPTransform.""" # Update pending elements. Filter out empty bundles. 
They do not impact @@ -179,12 +178,11 @@ def _refresh_watermarks(self, applied_ptransform, side_inputs_container): applied_ptransform, tw)) return unblocked_tasks - def extract_all_timers(self): - # type: () -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool] - + def extract_all_timers( + self) -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool]: """Extracts fired timers for all transforms and reports if there are any timers set.""" - all_timers = [] # type: List[Tuple[AppliedPTransform, List[TimerFiring]]] + all_timers: List[Tuple[AppliedPTransform, List[TimerFiring]]] = [] has_realtime_timer = False for applied_ptransform, tw in self._transform_to_watermarks.items(): fired_timers, had_realtime_timer = tw.extract_transform_timers() @@ -203,19 +201,19 @@ class _TransformWatermarks(object): def __init__(self, clock, keyed_states, transform): self._clock = clock self._keyed_states = keyed_states - self._input_transform_watermarks = [] # type: List[_TransformWatermarks] + self._input_transform_watermarks: List[_TransformWatermarks] = [] self._input_watermark = WatermarkManager.WATERMARK_NEG_INF self._output_watermark = WatermarkManager.WATERMARK_NEG_INF self._keyed_earliest_holds = {} # Scheduled bundles targeted for this transform. - self._pending = set() # type: Set[_Bundle] + self._pending: Set['_Bundle'] = set() self._fired_timers = set() self._lock = threading.Lock() self._label = str(transform) - def update_input_transform_watermarks(self, input_transform_watermarks): - # type: (List[_TransformWatermarks]) -> None + def update_input_transform_watermarks( + self, input_transform_watermarks: List['_TransformWatermarks']) -> None: with self._lock: self._input_transform_watermarks = input_transform_watermarks @@ -225,14 +223,12 @@ def update_timers(self, completed_timers): self._fired_timers.remove(timer_firing) @property - def input_watermark(self): - # type: () -> Timestamp + def input_watermark(self) -> Timestamp: with self._lock: return self._input_watermark @property - def output_watermark(self): - # type: () -> Timestamp + def output_watermark(self) -> Timestamp: with self._lock: return self._output_watermark @@ -244,22 +240,18 @@ def hold(self, keyed_earliest_holds): hold_value == WatermarkManager.WATERMARK_POS_INF): del self._keyed_earliest_holds[key] - def add_pending(self, pending): - # type: (_Bundle) -> None + def add_pending(self, pending: '_Bundle') -> None: with self._lock: self._pending.add(pending) - def remove_pending(self, completed): - # type: (_Bundle) -> None + def remove_pending(self, completed: '_Bundle') -> None: with self._lock: # Ignore repeated removes. This will happen if a transform has a repeated # input. if completed in self._pending: self._pending.remove(completed) - def refresh(self): - # type: () -> bool - + def refresh(self) -> bool: """Refresh the watermark for a given transform. 
This method looks at the watermark coming from all input PTransforms, and @@ -308,9 +300,7 @@ def refresh(self): def synchronized_processing_output_time(self): return self._clock.time() - def extract_transform_timers(self): - # type: () -> Tuple[List[TimerFiring], bool] - + def extract_transform_timers(self) -> Tuple[List[TimerFiring], bool]: """Extracts fired timers and reports of any timers set per transform.""" with self._lock: fired_timers = [] diff --git a/sdks/python/apache_beam/runners/interactive/background_caching_job.py b/sdks/python/apache_beam/runners/interactive/background_caching_job.py index 3802cfa60095..71f7f77ded4e 100644 --- a/sdks/python/apache_beam/runners/interactive/background_caching_job.py +++ b/sdks/python/apache_beam/runners/interactive/background_caching_job.py @@ -193,9 +193,7 @@ def is_background_caching_job_needed(user_pipeline): cache_changed)) -def is_cache_complete(pipeline_id): - # type: (str) -> bool - +def is_cache_complete(pipeline_id: str) -> bool: """Returns True if the backgrond cache for the given pipeline is done. """ user_pipeline = ie.current_env().pipeline_id_to_pipeline(pipeline_id) diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py index b04eb92132a5..ce543796a6bd 100644 --- a/sdks/python/apache_beam/runners/interactive/cache_manager.py +++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py @@ -145,9 +145,7 @@ def cleanup(self): """Cleans up all the PCollection caches.""" raise NotImplementedError - def size(self, *labels): - # type: (*str) -> int - + def size(self, *labels: str) -> int: """Returns the size of the PCollection on disk in bytes.""" raise NotImplementedError diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py index 8dd525978284..a5d38682716c 100644 --- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py +++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py @@ -37,7 +37,7 @@ class FileBasedCacheManagerTest(object): tested with InteractiveRunner as a part of integration tests instead. """ - cache_format = None # type: str + cache_format: str = None def setUp(self): self.cache_manager = cache.FileBasedCacheManager( diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py index 92cb108bc46f..1f1e315fea09 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py @@ -48,7 +48,7 @@ class PipelineGraph(object): """Creates a DOT representing the pipeline. Thread-safe. Runner agnostic.""" def __init__( self, - pipeline, # type: Union[beam_runner_api_pb2.Pipeline, beam.Pipeline] + pipeline: Union[beam_runner_api_pb2.Pipeline, beam.Pipeline], default_vertex_attrs={'shape': 'box'}, default_edge_attrs=None, render_option=None): @@ -71,7 +71,7 @@ def __init__( rendered. See display.pipeline_graph_renderer for available options. 
""" self._lock = threading.Lock() - self._graph = None # type: pydot.Dot + self._graph: pydot.Dot = None self._pipeline_instrument = None if isinstance(pipeline, beam.Pipeline): self._pipeline_instrument = inst.PipelineInstrument( @@ -90,10 +90,9 @@ def __init__( (beam_runner_api_pb2.Pipeline, beam.Pipeline, type(pipeline))) # A dict from PCollection ID to a list of its consuming Transform IDs - self._consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[str]] + self._consumers: DefaultDict[str, List[str]] = collections.defaultdict(list) # A dict from PCollection ID to its producing Transform ID - self._producers = {} # type: Dict[str, str] + self._producers: Dict[str, str] = {} for transform_id, transform_proto in self._top_level_transforms(): for pcoll_id in transform_proto.inputs.values(): @@ -113,8 +112,7 @@ def __init__( self._renderer = pipeline_graph_renderer.get_renderer(render_option) - def get_dot(self): - # type: () -> str + def get_dot(self) -> str: return self._get_graph().to_string() def display_graph(self): @@ -130,9 +128,8 @@ def display_graph(self): 'environment is in a notebook. Cannot display the ' 'pipeline graph.') - def _top_level_transforms(self): - # type: () -> Iterator[Tuple[str, beam_runner_api_pb2.PTransform]] - + def _top_level_transforms( + self) -> Iterator[Tuple[str, beam_runner_api_pb2.PTransform]]: """Yields all top level PTransforms (subtransforms of the root PTransform). Yields: (str, PTransform proto) ID, proto pair of top level PTransforms. diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py index 9e23fc1deeda..ad46f5d65ea3 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py @@ -40,17 +40,13 @@ class PipelineGraphRenderer(BeamPlugin, metaclass=abc.ABCMeta): """ @classmethod @abc.abstractmethod - def option(cls): - # type: () -> str - + def option(cls) -> str: """The corresponding rendering option for the renderer. """ raise NotImplementedError @abc.abstractmethod - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str - + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: """Renders the pipeline graph in HTML-compatible format. Args: @@ -66,12 +62,10 @@ class MuteRenderer(PipelineGraphRenderer): """Use this renderer to mute the pipeline display. """ @classmethod - def option(cls): - # type: () -> str + def option(cls) -> str: return 'mute' - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: return '' @@ -79,12 +73,10 @@ class TextRenderer(PipelineGraphRenderer): """This renderer simply returns the dot representation in text format. """ @classmethod - def option(cls): - # type: () -> str + def option(cls) -> str: return 'text' - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: return pipeline_graph.get_dot() @@ -96,18 +88,14 @@ class PydotRenderer(PipelineGraphRenderer): 2. 
The python module pydot: https://pypi.org/project/pydot/ """ @classmethod - def option(cls): - # type: () -> str + def option(cls) -> str: return 'graph' - def render_pipeline_graph(self, pipeline_graph): - # type: (PipelineGraph) -> str + def render_pipeline_graph(self, pipeline_graph: 'PipelineGraph') -> str: return pipeline_graph._get_graph().create_svg().decode("utf-8") # pylint: disable=protected-access -def get_renderer(option=None): - # type: (Optional[str]) -> Type[PipelineGraphRenderer] - +def get_renderer(option: Optional[str] = None) -> Type[PipelineGraphRenderer]: """Get an instance of PipelineGraphRenderer given rendering option. Args: diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_control.py b/sdks/python/apache_beam/runners/interactive/options/capture_control.py index 86422cd8219d..826b596bbc6d 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_control.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_control.py @@ -25,6 +25,7 @@ import logging from datetime import timedelta +from typing import List from apache_beam.io.gcp.pubsub import ReadFromPubSub from apache_beam.runners.interactive import interactive_environment as ie @@ -45,8 +46,8 @@ def __init__(self): self._capture_size_limit = 1e9 self._test_limiters = None - def limiters(self): - # type: () -> List[capture_limiters.Limiter] # noqa: F821 + def limiters(self) -> List['capture_limiters.Limiter']: + # noqa: F821 if self._test_limiters: return self._test_limiters return [ @@ -54,8 +55,9 @@ def limiters(self): capture_limiters.DurationLimiter(self._capture_duration) ] - def set_limiters_for_test(self, limiters): - # type: (List[capture_limiters.Limiter]) -> None # noqa: F821 + def set_limiters_for_test( + self, limiters: List['capture_limiters.Limiter']) -> None: + # noqa: F821 self._test_limiters = limiters diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py index 9634685e6fb5..497772f94c36 100644 --- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py +++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py @@ -20,7 +20,9 @@ For internal use only; no backwards-compatibility guarantees. """ +import datetime import threading +from typing import Any import pandas as pd @@ -32,9 +34,7 @@ class Limiter: """Limits an aspect of the caching layer.""" - def is_triggered(self): - # type: () -> bool - + def is_triggered(self) -> bool: """Returns True if the limiter has triggered, and caching should stop.""" raise NotImplementedError @@ -43,8 +43,8 @@ class ElementLimiter(Limiter): """A `Limiter` that limits reading from cache based on some property of an element. """ - def update(self, e): - # type: (Any) -> None # noqa: F821 + def update(self, e: Any) -> None: + # noqa: F821 """Update the internal state based on some property of an element. 
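For orientation on the `capture_limiters.py` hunk just above: a `Limiter` only has to answer `is_triggered()`, and an `ElementLimiter` additionally folds each element read from the cache into its state via `update(e)`. The following is a rough standalone sketch of a count-based limiter written against that interface; `MaxElementsLimiter` is a hypothetical name used for illustration, not a class from the Beam SDK, and a real limiter would subclass `apache_beam.runners.interactive.options.capture_limiters.ElementLimiter` rather than re-declare the methods.

from typing import Any


class MaxElementsLimiter:
    """Illustrative limiter mirroring the ElementLimiter contract shown in the
    diff: update() folds each element into internal state, and is_triggered()
    reports when capture should stop."""
    def __init__(self, max_elements: int) -> None:
        self._max_elements = max_elements
        self._seen = 0

    def update(self, e: Any) -> None:
        # Every element read from the capture bumps the counter; the element
        # value itself is irrelevant for this particular limiter.
        self._seen += 1

    def is_triggered(self) -> bool:
        return self._seen >= self._max_elements


if __name__ == '__main__':
    limiter = MaxElementsLimiter(max_elements=3)
    for element in ('a', 'b', 'c'):
        limiter.update(element)
    assert limiter.is_triggered()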
@@ -55,10 +55,7 @@ def update(self, e): class SizeLimiter(Limiter): """Limits the cache size to a specified byte limit.""" - def __init__( - self, - size_limit # type: int - ): + def __init__(self, size_limit: int): self._size_limit = size_limit def is_triggered(self): @@ -75,7 +72,7 @@ class DurationLimiter(Limiter): """Limits the duration of the capture.""" def __init__( self, - duration_limit # type: datetime.timedelta # noqa: F821 + duration_limit: datetime.timedelta # noqa: F821 ): self._duration_limit = duration_limit self._timer = threading.Timer(duration_limit.total_seconds(), self._trigger) diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index bee215717b4d..2e113240c09c 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -19,6 +19,10 @@ import threading import time import warnings +from typing import Any +from typing import Dict +from typing import List +from typing import Union import pandas as pd @@ -40,12 +44,11 @@ class ElementStream: """A stream of elements from a given PCollection.""" def __init__( self, - pcoll, # type: beam.pvalue.PCollection - var, # type: str - cache_key, # type: str - max_n, # type: int - max_duration_secs # type: float - ): + pcoll: beam.pvalue.PCollection, + var: str, + cache_key: str, + max_n: int, + max_duration_secs: float): self._pcoll = pcoll self._cache_key = cache_key self._pipeline = ie.current_env().user_pipeline(pcoll.pipeline) @@ -58,47 +61,37 @@ def __init__( self._done = False @property - def var(self): - # type: () -> str - + def var(self) -> str: """Returns the variable named that defined this PCollection.""" return self._var @property - def pcoll(self): - # type: () -> beam.pvalue.PCollection - + def pcoll(self) -> beam.pvalue.PCollection: """Returns the PCollection that supplies this stream with data.""" return self._pcoll @property - def cache_key(self): - # type: () -> str - + def cache_key(self) -> str: """Returns the cache key for this stream.""" return self._cache_key - def display_id(self, suffix): - # type: (str) -> str - + def display_id(self, suffix: str) -> str: """Returns a unique id able to be displayed in a web browser.""" return utils.obfuscate(self._cache_key, suffix) - def is_computed(self): - # type: () -> boolean # noqa: F821 + def is_computed(self) -> bool: + # noqa: F821 """Returns True if no more elements will be recorded.""" return self._pcoll in ie.current_env().computed_pcollections - def is_done(self): - # type: () -> boolean # noqa: F821 + def is_done(self) -> bool: + # noqa: F821 """Returns True if no more new elements will be yielded.""" return self._done - def read(self, tail=True): - # type: (boolean) -> Any # noqa: F821 - + def read(self, tail: bool = True) -> Any: """Reads the elements currently recorded.""" # Get the cache manager and wait until the file exists. 
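A quieter fix buried in the `recording_manager.py` hunks that follow: several of the old comments declared a non-existent `boolean` type, which is why they carried `# noqa: F821`; the migrated signatures use the real `bool`. The small, self-contained illustration below shows why the inline form surfaces this class of typo while the comment form does not (the function names here are made up for the example).

# With a type comment, a non-existent name such as "boolean" is invisible to
# the interpreter; only a linter (flake8 F821) or a type checker notices it.
def is_done_commented():
    # type: () -> boolean  # noqa: F821
    return True


# With an inline annotation the name must resolve when the function is
# defined (unless `from __future__ import annotations` defers evaluation),
# so the migrated code uses the real built-in `bool`.
def is_done_annotated() -> bool:
    return True


if __name__ == '__main__':
    assert is_done_commented() is True
    assert is_done_annotated() is True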
@@ -154,11 +147,11 @@ class Recording: """A group of PCollections from a given pipeline run.""" def __init__( self, - user_pipeline, # type: beam.Pipeline - pcolls, # type: List[beam.pvalue.PCollection] # noqa: F821 - result, # type: beam.runner.PipelineResult - max_n, # type: int - max_duration_secs, # type: float + user_pipeline: beam.Pipeline, + pcolls: List[beam.pvalue.PCollection], # noqa: F821 + result: 'beam.runner.PipelineResult', + max_n: int, + max_duration_secs: float, ): self._user_pipeline = user_pipeline self._result = result @@ -188,9 +181,7 @@ def __init__( self._mark_computed.daemon = True self._mark_computed.start() - def _mark_all_computed(self): - # type: () -> None - + def _mark_all_computed(self) -> None: """Marks all the PCollections upon a successful pipeline run.""" if not self._result: return @@ -216,40 +207,28 @@ def _mark_all_computed(self): if self._result.state is PipelineState.DONE and self._set_computed: ie.current_env().mark_pcollection_computed(self._pcolls) - def is_computed(self): - # type: () -> boolean # noqa: F821 - + def is_computed(self) -> bool: """Returns True if all PCollections are computed.""" return all(s.is_computed() for s in self._streams.values()) - def stream(self, pcoll): - # type: (beam.pvalue.PCollection) -> ElementStream - + def stream(self, pcoll: beam.pvalue.PCollection) -> ElementStream: """Returns an ElementStream for a given PCollection.""" return self._streams[pcoll] - def computed(self): - # type: () -> None - + def computed(self) -> None: """Returns all computed ElementStreams.""" return {p: s for p, s in self._streams.items() if s.is_computed()} - def uncomputed(self): - # type: () -> None - + def uncomputed(self) -> None: """Returns all uncomputed ElementStreams.""" return {p: s for p, s in self._streams.items() if not s.is_computed()} - def cancel(self): - # type: () -> None - + def cancel(self) -> None: """Cancels the recording.""" with self._result_lock: self._result.cancel() - def wait_until_finish(self): - # type: () -> None - + def wait_until_finish(self) -> None: """Waits until the pipeline is done and returns the final state. 
This also marks any PCollections as computed right away if the pipeline is @@ -261,9 +240,7 @@ def wait_until_finish(self): self._mark_computed.join() return self._result.state - def describe(self): - # type: () -> dict[str, int] - + def describe(self) -> Dict[str, int]: """Returns a dictionary describing the cache and recording.""" cache_manager = ie.current_env().get_cache_manager(self._user_pipeline) @@ -274,18 +251,19 @@ def describe(self): class RecordingManager: """Manages recordings of PCollections for a given pipeline.""" - def __init__(self, user_pipeline, pipeline_var=None, test_limiters=None): - # type: (beam.Pipeline, str, list[Limiter]) -> None # noqa: F821 - - self.user_pipeline = user_pipeline # type: beam.Pipeline - self.pipeline_var = pipeline_var if pipeline_var else '' # type: str - self._recordings = set() # type: set[Recording] - self._start_time_sec = 0 # type: float + def __init__( + self, + user_pipeline: beam.Pipeline, + pipeline_var: str = None, + test_limiters: List['Limiter'] = None) -> None: # noqa: F821 + + self.user_pipeline: beam.Pipeline = user_pipeline + self.pipeline_var: str = pipeline_var if pipeline_var else '' + self._recordings: set[Recording] = set() + self._start_time_sec: float = 0 self._test_limiters = test_limiters if test_limiters else [] - def _watch(self, pcolls): - # type: (List[beam.pvalue.PCollection]) -> None # noqa: F821 - + def _watch(self, pcolls: List[beam.pvalue.PCollection]) -> None: """Watch any pcollections not being watched. This allows for the underlying caching layer to identify the PCollection as @@ -314,9 +292,7 @@ def _watch(self, pcolls): ie.current_env().watch( {'anonymous_pcollection_{}'.format(id(pcoll)): pcoll}) - def _clear(self): - # type: () -> None - + def _clear(self) -> None: """Clears the recording of all non-source PCollections.""" cache_manager = ie.current_env().get_cache_manager(self.user_pipeline) @@ -338,17 +314,13 @@ def _clear_pcolls(self, cache_manager, pcolls): for pc in pcolls: cache_manager.clear('full', pc) - def clear(self): - # type: () -> None - + def clear(self) -> None: """Clears all cached PCollections for this RecordingManager.""" cache_manager = ie.current_env().get_cache_manager(self.user_pipeline) if cache_manager: cache_manager.cleanup() - def cancel(self): - # type: (None) -> None - + def cancel(self: None) -> None: """Cancels the current background recording job.""" bcj.attempt_to_cancel_background_caching_job(self.user_pipeline) @@ -361,9 +333,7 @@ def cancel(self): # evict the BCJ after they complete. 
ie.current_env().evict_background_caching_job(self.user_pipeline) - def describe(self): - # type: () -> dict[str, int] - + def describe(self) -> Dict[str, int]: """Returns a dictionary describing the cache and recording.""" cache_manager = ie.current_env().get_cache_manager(self.user_pipeline) @@ -384,9 +354,7 @@ def describe(self): 'pipeline_var': self.pipeline_var } - def record_pipeline(self): - # type: () -> bool - + def record_pipeline(self) -> bool: """Starts a background caching job for this RecordingManager's pipeline.""" runner = self.user_pipeline.runner @@ -412,8 +380,12 @@ def record_pipeline(self): return True return False - def record(self, pcolls, max_n, max_duration): - # type: (List[beam.pvalue.PCollection], int, Union[int,str]) -> Recording # noqa: F821 + def record( + self, + pcolls: List[beam.pvalue.PCollection], + max_n: int, + max_duration: Union[int, str]) -> Recording: + # noqa: F821 """Records the given PCollections.""" @@ -464,8 +436,13 @@ def record(self, pcolls, max_n, max_duration): return recording - def read(self, pcoll_name, pcoll, max_n, max_duration_secs): - # type: (str, beam.pvalue.PValue, int, float) -> Union[None, ElementStream] # noqa: F821 + def read( + self, + pcoll_name: str, + pcoll: beam.pvalue.PValue, + max_n: int, + max_duration_secs: float) -> Union[None, ElementStream]: + # noqa: F821 """Reads an ElementStream of a computed PCollection. diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py b/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py index 6a80639ee285..808ede64d60d 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py +++ b/sdks/python/apache_beam/runners/interactive/testing/integration/notebook_executor.py @@ -40,8 +40,7 @@ class NotebookExecutor(object): """Executor that reads notebooks, executes it and gathers outputs into static HTML pages that can be served.""" - def __init__(self, path): - # type: (str) -> None + def __init__(self, path: str) -> None: assert _interactive_integration_ready, ( '[interactive_test] dependency is not installed.') diff --git a/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py b/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py index a1c9971b0882..743d5614f9a2 100644 --- a/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py +++ b/sdks/python/apache_beam/runners/interactive/testing/integration/screen_diff.py @@ -52,8 +52,11 @@ class ScreenDiffIntegrationTestEnvironment(object): """A test environment to conduct screen diff integration tests for notebooks. 
""" - def __init__(self, test_notebook_path, golden_dir, cleanup=True): - # type: (str, str, bool) -> None + def __init__( + self, + test_notebook_path: str, + golden_dir: str, + cleanup: bool = True) -> None: assert _interactive_integration_ready, ( '[interactive_test] dependency is not installed.') diff --git a/sdks/python/apache_beam/runners/job/utils.py b/sdks/python/apache_beam/runners/job/utils.py index 205d87941a5a..1e15064ffd70 100644 --- a/sdks/python/apache_beam/runners/job/utils.py +++ b/sdks/python/apache_beam/runners/job/utils.py @@ -27,8 +27,7 @@ from google.protobuf import struct_pb2 -def dict_to_struct(dict_obj): - # type: (dict) -> struct_pb2.Struct +def dict_to_struct(dict_obj: dict) -> struct_pb2.Struct: try: return json_format.ParseDict(dict_obj, struct_pb2.Struct()) except json_format.ParseError: @@ -36,6 +35,5 @@ def dict_to_struct(dict_obj): raise -def struct_to_dict(struct_obj): - # type: (struct_pb2.Struct) -> dict +def struct_to_dict(struct_obj: struct_pb2.Struct) -> dict: return json.loads(json_format.MessageToJson(struct_obj)) diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 102b8b60d69a..0a03c96bc19b 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -23,7 +23,6 @@ # pytype: skip-file # mypy: disallow-untyped-defs -from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import FrozenSet @@ -35,11 +34,14 @@ from typing import TypeVar from typing import Union +from google.protobuf import message from typing_extensions import Protocol from apache_beam import coders from apache_beam import pipeline from apache_beam import pvalue +from apache_beam.coders.coder_impl import IterableStateReader +from apache_beam.coders.coder_impl import IterableStateWriter from apache_beam.internal import pickler from apache_beam.pipeline import ComponentIdMap from apache_beam.portability.api import beam_fn_api_pb2 @@ -49,23 +51,15 @@ from apache_beam.transforms.resources import merge_resource_hints from apache_beam.typehints import native_type_compatibility -if TYPE_CHECKING: - from google.protobuf import message # pylint: disable=ungrouped-imports - from apache_beam.coders.coder_impl import IterableStateReader - from apache_beam.coders.coder_impl import IterableStateWriter - from apache_beam.transforms import ptransform - PortableObjectT = TypeVar('PortableObjectT', bound='PortableObject') class PortableObject(Protocol): - def to_runner_api(self, __context): - # type: (PipelineContext) -> Any + def to_runner_api(self, __context: 'PipelineContext') -> Any: pass @classmethod - def from_runner_api(cls, __proto, __context): - # type: (Any, PipelineContext) -> Any + def from_runner_api(cls, __proto: Any, __context: 'PipelineContext') -> Any: pass @@ -75,27 +69,24 @@ class _PipelineContextMap(Generic[PortableObjectT]): Under the hood it encodes and decodes these objects into runner API representations. """ - def __init__(self, - context, # type: PipelineContext - obj_type, # type: Type[PortableObjectT] - namespace, # type: str - proto_map=None # type: Optional[Mapping[str, message.Message]] - ): - # type: (...) 
-> None + def __init__( + self, + context: 'PipelineContext', + obj_type: Type[PortableObjectT], + namespace: str, + proto_map: Optional[Mapping[str, message.Message]] = None) -> None: self._pipeline_context = context self._obj_type = obj_type self._namespace = namespace - self._obj_to_id = {} # type: Dict[Any, str] - self._id_to_obj = {} # type: Dict[str, Any] + self._obj_to_id: Dict[Any, str] = {} + self._id_to_obj: Dict[str, Any] = {} self._id_to_proto = dict(proto_map) if proto_map else {} - def populate_map(self, proto_map): - # type: (Mapping[str, message.Message]) -> None + def populate_map(self, proto_map: Mapping[str, message.Message]) -> None: for id, proto in self._id_to_proto.items(): proto_map[id].CopyFrom(proto) - def get_id(self, obj, label=None): - # type: (PortableObjectT, Optional[str]) -> str + def get_id(self, obj: PortableObjectT, label: Optional[str] = None) -> str: if obj not in self._obj_to_id: id = self._pipeline_context.component_id_map.get_or_assign( obj, self._obj_type, label) @@ -104,19 +95,23 @@ def get_id(self, obj, label=None): self._id_to_proto[id] = obj.to_runner_api(self._pipeline_context) return self._obj_to_id[obj] - def get_proto(self, obj, label=None): - # type: (PortableObjectT, Optional[str]) -> message.Message + def get_proto( + self, + obj: PortableObjectT, + label: Optional[str] = None) -> message.Message: return self._id_to_proto[self.get_id(obj, label)] - def get_by_id(self, id): - # type: (str) -> PortableObjectT + def get_by_id(self, id: str) -> PortableObjectT: if id not in self._id_to_obj: self._id_to_obj[id] = self._obj_type.from_runner_api( self._id_to_proto[id], self._pipeline_context) return self._id_to_obj[id] - def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False): - # type: (message.Message, Optional[str], bool) -> str + def get_by_proto( + self, + maybe_new_proto: message.Message, + label: Optional[str] = None, + deduplicate: bool = False) -> str: # TODO: this method may not be safe for arbitrary protos due to # xlang concerns, hence limiting usage to the only current use-case it has. # See: https://github.com/apache/beam/pull/14390#discussion_r616062377 @@ -136,16 +131,17 @@ def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False): obj=obj, obj_type=self._obj_type, label=label), maybe_new_proto) - def get_id_to_proto_map(self): - # type: () -> Dict[str, message.Message] + def get_id_to_proto_map(self) -> Dict[str, message.Message]: return self._id_to_proto - def get_proto_from_id(self, id): - # type: (str) -> message.Message + def get_proto_from_id(self, id: str) -> message.Message: return self.get_id_to_proto_map()[id] - def put_proto(self, id, proto, ignore_duplicates=False): - # type: (str, message.Message, bool) -> str + def put_proto( + self, + id: str, + proto: message.Message, + ignore_duplicates: bool = False) -> str: if not ignore_duplicates and id in self._id_to_proto: raise ValueError("Id '%s' is already taken." % id) elif (ignore_duplicates and id in self._id_to_proto and @@ -158,12 +154,10 @@ def put_proto(self, id, proto, ignore_duplicates=False): self._id_to_proto[id] = proto return id - def __getitem__(self, id): - # type: (str) -> Any + def __getitem__(self, id: str) -> Any: return self.get_by_id(id) - def __contains__(self, id): - # type: (str) -> bool + def __contains__(self, id: str) -> bool: return id in self._id_to_proto @@ -172,18 +166,18 @@ class PipelineContext(object): Used for accessing and constructing the referenced objects of a Pipeline. 
""" - - def __init__(self, - proto=None, # type: Optional[Union[beam_runner_api_pb2.Components, beam_fn_api_pb2.ProcessBundleDescriptor]] - component_id_map=None, # type: Optional[pipeline.ComponentIdMap] - default_environment=None, # type: Optional[environments.Environment] - use_fake_coders=False, # type: bool - iterable_state_read=None, # type: Optional[IterableStateReader] - iterable_state_write=None, # type: Optional[IterableStateWriter] - namespace='ref', # type: str - requirements=(), # type: Iterable[str] - ): - # type: (...) -> None + def __init__( + self, + proto: Optional[Union[beam_runner_api_pb2.Components, + beam_fn_api_pb2.ProcessBundleDescriptor]] = None, + component_id_map: Optional[pipeline.ComponentIdMap] = None, + default_environment: Optional[environments.Environment] = None, + use_fake_coders: bool = False, + iterable_state_read: Optional[IterableStateReader] = None, + iterable_state_write: Optional[IterableStateWriter] = None, + namespace: str = 'ref', + requirements: Iterable[str] = (), + ) -> None: if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor): proto = beam_runner_api_pb2.Components( coders=dict(proto.coders.items()), @@ -224,22 +218,19 @@ def __init__(self, if default_environment is None: default_environment = environments.DefaultEnvironment() - self._default_environment_id = self.environments.get_id( - default_environment, label='default_environment') # type: str + self._default_environment_id: str = self.environments.get_id( + default_environment, label='default_environment') self.use_fake_coders = use_fake_coders - self.deterministic_coder_map = { - } # type: Mapping[coders.Coder, coders.Coder] + self.deterministic_coder_map: Mapping[coders.Coder, coders.Coder] = {} self.iterable_state_read = iterable_state_read self.iterable_state_write = iterable_state_write self._requirements = set(requirements) - def add_requirement(self, requirement): - # type: (str) -> None + def add_requirement(self, requirement: str) -> None: self._requirements.add(requirement) - def requirements(self): - # type: () -> FrozenSet[str] + def requirements(self) -> FrozenSet[str]: return frozenset(self._requirements) # If fake coders are requested, return a pickled version of the element type @@ -248,8 +239,9 @@ def requirements(self): # TODO(https://github.com/apache/beam/issues/18490): Remove once this is no # longer needed. 
def coder_id_from_element_type( - self, element_type, requires_deterministic_key_coder=None): - # type: (Any, Optional[str]) -> str + self, + element_type: Any, + requires_deterministic_key_coder: Optional[str] = None) -> str: if self.use_fake_coders: return pickler.dumps(element_type).decode('ascii') else: @@ -262,14 +254,12 @@ def coder_id_from_element_type( ]) return self.coders.get_id(coder) - def deterministic_coder(self, coder, msg): - # type: (coders.Coder, str) -> coders.Coder + def deterministic_coder(self, coder: coders.Coder, msg: str) -> coders.Coder: if coder not in self.deterministic_coder_map: self.deterministic_coder_map[coder] = coder.as_deterministic_coder(msg) # type: ignore return self.deterministic_coder_map[coder] - def element_type_from_coder_id(self, coder_id): - # type: (str) -> Any + def element_type_from_coder_id(self, coder_id: str) -> Any: if self.use_fake_coders or coder_id not in self.coders: return pickler.loads(coder_id) else: @@ -277,12 +267,11 @@ def element_type_from_coder_id(self, coder_id): self.coders[coder_id].to_type_hint()) @staticmethod - def from_runner_api(proto): - # type: (beam_runner_api_pb2.Components) -> PipelineContext + def from_runner_api( + proto: beam_runner_api_pb2.Components) -> 'PipelineContext': return PipelineContext(proto) - def to_runner_api(self): - # type: () -> beam_runner_api_pb2.Components + def to_runner_api(self) -> beam_runner_api_pb2.Components: context_proto = beam_runner_api_pb2.Components() self.transforms.populate_map(context_proto.transforms) @@ -293,20 +282,19 @@ def to_runner_api(self): return context_proto - def default_environment_id(self): - # type: () -> str + def default_environment_id(self) -> str: return self._default_environment_id def get_environment_id_for_resource_hints( - self, hints): # type: (Dict[str, bytes]) -> str + self, hints: Dict[str, bytes]) -> str: """Returns an environment id that has necessary resource hints.""" if not hints: return self.default_environment_id() def get_or_create_environment_with_resource_hints( - template_env_id, - resource_hints, - ): # type: (str, Dict[str, bytes]) -> str + template_env_id: str, + resource_hints: Dict[str, bytes], + ) -> str: """Creates an environment that has necessary hints and returns its id.""" template_env = self.environments.get_proto_from_id(template_env_id) cloned_env = beam_runner_api_pb2.Environment() diff --git a/sdks/python/apache_beam/runners/portability/abstract_job_service.py b/sdks/python/apache_beam/runners/portability/abstract_job_service.py index 1aa841df4c31..87162d5feda5 100644 --- a/sdks/python/apache_beam/runners/portability/abstract_job_service.py +++ b/sdks/python/apache_beam/runners/portability/abstract_job_service.py @@ -25,7 +25,7 @@ import uuid import zipfile from concurrent import futures -from typing import TYPE_CHECKING +from typing import BinaryIO from typing import Dict from typing import Iterator from typing import Optional @@ -34,21 +34,17 @@ import grpc from google.protobuf import json_format +from google.protobuf import struct_pb2 from google.protobuf import timestamp_pb2 from apache_beam.portability.api import beam_artifact_api_pb2_grpc from apache_beam.portability.api import beam_job_api_pb2 from apache_beam.portability.api import beam_job_api_pb2_grpc +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.portability import artifact_service from apache_beam.utils.timestamp import Timestamp -if TYPE_CHECKING: - # pylint: 
disable=ungrouped-imports - from typing import BinaryIO - from google.protobuf import struct_pb2 - from apache_beam.portability.api import beam_runner_api_pb2 - _LOGGER = logging.getLogger(__name__) StateEvent = Tuple[int, Union[timestamp_pb2.Timestamp, Timestamp]] @@ -74,25 +70,22 @@ class AbstractJobServiceServicer(beam_job_api_pb2_grpc.JobServiceServicer): Servicer for the Beam Job API. """ def __init__(self): - self._jobs = {} # type: Dict[str, AbstractBeamJob] + self._jobs: Dict[str, AbstractBeamJob] = {} def create_beam_job(self, preparation_id, # stype: str - job_name, # type: str - pipeline, # type: beam_runner_api_pb2.Pipeline - options # type: struct_pb2.Struct - ): - # type: (...) -> AbstractBeamJob - + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct + ) -> 'AbstractBeamJob': """Returns an instance of AbstractBeamJob specific to this servicer.""" raise NotImplementedError(type(self)) - def Prepare(self, - request, # type: beam_job_api_pb2.PrepareJobRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.PrepareJobResponse + def Prepare( + self, + request: beam_job_api_pb2.PrepareJobRequest, + context=None, + timeout=None) -> beam_job_api_pb2.PrepareJobResponse: _LOGGER.debug('Got Prepare request.') preparation_id = '%s-%s' % (request.job_name, uuid.uuid4()) self._jobs[preparation_id] = self.create_beam_job( @@ -108,56 +101,52 @@ def Prepare(self, artifact_staging_endpoint(), staging_session_token=preparation_id) - def Run(self, - request, # type: beam_job_api_pb2.RunJobRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.RunJobResponse + def Run( + self, + request: beam_job_api_pb2.RunJobRequest, + context=None, + timeout=None) -> beam_job_api_pb2.RunJobResponse: # For now, just use the preparation id as the job id. job_id = request.preparation_id _LOGGER.info("Running job '%s'", job_id) self._jobs[job_id].run() return beam_job_api_pb2.RunJobResponse(job_id=job_id) - def GetJobs(self, - request, # type: beam_job_api_pb2.GetJobsRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.GetJobsResponse + def GetJobs( + self, + request: beam_job_api_pb2.GetJobsRequest, + context=None, + timeout=None) -> beam_job_api_pb2.GetJobsResponse: return beam_job_api_pb2.GetJobsResponse( job_info=[job.to_runner_api() for job in self._jobs.values()]) def GetState( self, - request, # type: beam_job_api_pb2.GetJobStateRequest - context=None): - # type: (...) -> beam_job_api_pb2.JobStateEvent + request: beam_job_api_pb2.GetJobStateRequest, + context=None) -> beam_job_api_pb2.JobStateEvent: return make_state_event(*self._jobs[request.job_id].get_state()) - def GetPipeline(self, - request, # type: beam_job_api_pb2.GetJobPipelineRequest - context=None, - timeout=None - ): - # type: (...) -> beam_job_api_pb2.GetJobPipelineResponse + def GetPipeline( + self, + request: beam_job_api_pb2.GetJobPipelineRequest, + context=None, + timeout=None) -> beam_job_api_pb2.GetJobPipelineResponse: return beam_job_api_pb2.GetJobPipelineResponse( pipeline=self._jobs[request.job_id].get_pipeline()) - def Cancel(self, - request, # type: beam_job_api_pb2.CancelJobRequest - context=None, - timeout=None - ): - # type: (...) 
-> beam_job_api_pb2.CancelJobResponse + def Cancel( + self, + request: beam_job_api_pb2.CancelJobRequest, + context=None, + timeout=None) -> beam_job_api_pb2.CancelJobResponse: self._jobs[request.job_id].cancel() return beam_job_api_pb2.CancelJobResponse( state=self._jobs[request.job_id].get_state()[0]) - def GetStateStream(self, request, context=None, timeout=None): - # type: (...) -> Iterator[beam_job_api_pb2.JobStateEvent] - + def GetStateStream(self, + request, + context=None, + timeout=None) -> Iterator[beam_job_api_pb2.JobStateEvent]: """Yields state transitions since the stream started. """ if request.job_id not in self._jobs: @@ -167,9 +156,11 @@ def GetStateStream(self, request, context=None, timeout=None): for state, timestamp in job.get_state_stream(): yield make_state_event(state, timestamp) - def GetMessageStream(self, request, context=None, timeout=None): - # type: (...) -> Iterator[beam_job_api_pb2.JobMessagesResponse] - + def GetMessageStream( + self, + request, + context=None, + timeout=None) -> Iterator[beam_job_api_pb2.JobMessagesResponse]: """Yields messages since the stream started. """ if request.job_id not in self._jobs: @@ -184,50 +175,48 @@ def GetMessageStream(self, request, context=None, timeout=None): resp = beam_job_api_pb2.JobMessagesResponse(message_response=msg) yield resp - def DescribePipelineOptions(self, request, context=None, timeout=None): - # type: (...) -> beam_job_api_pb2.DescribePipelineOptionsResponse + def DescribePipelineOptions( + self, + request, + context=None, + timeout=None) -> beam_job_api_pb2.DescribePipelineOptionsResponse: return beam_job_api_pb2.DescribePipelineOptionsResponse() class AbstractBeamJob(object): """Abstract baseclass for managing a single Beam job.""" - - def __init__(self, - job_id, # type: str - job_name, # type: str - pipeline, # type: beam_runner_api_pb2.Pipeline - options # type: struct_pb2.Struct - ): + def __init__( + self, + job_id: str, + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct): self._job_id = job_id self._job_name = job_name self._pipeline_proto = pipeline self._pipeline_options = options self._state_history = [(beam_job_api_pb2.JobState.STOPPED, Timestamp.now())] - def prepare(self): - # type: () -> None - + def prepare(self) -> None: """Called immediately after this class is instantiated""" raise NotImplementedError(self) - def run(self): - # type: () -> None + def run(self) -> None: raise NotImplementedError(self) - def cancel(self): - # type: () -> Optional[beam_job_api_pb2.JobState.Enum] + def cancel(self) -> Optional['beam_job_api_pb2.JobState.Enum']: raise NotImplementedError(self) - def artifact_staging_endpoint(self): - # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] + def artifact_staging_endpoint( + self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: raise NotImplementedError(self) - def get_state_stream(self): - # type: () -> Iterator[StateEvent] + def get_state_stream(self) -> Iterator[StateEvent]: raise NotImplementedError(self) - def get_message_stream(self): - # type: () -> Iterator[Union[StateEvent, Optional[beam_job_api_pb2.JobMessage]]] + def get_message_stream( + self + ) -> Iterator[Union[StateEvent, Optional[beam_job_api_pb2.JobMessage]]]: raise NotImplementedError(self) @property @@ -259,8 +248,7 @@ def with_state_history(self, state_stream): """Utility to prepend recorded state history to an active state stream""" return itertools.chain(self._state_history[:], state_stream) - def get_pipeline(self): - # type: () -> 
beam_runner_api_pb2.Pipeline + def get_pipeline(self) -> beam_runner_api_pb2.Pipeline: return self._pipeline_proto @staticmethod @@ -268,8 +256,7 @@ def is_terminal_state(state): from apache_beam.runners.portability import portable_runner return state in portable_runner.TERMINAL_STATES - def to_runner_api(self): - # type: () -> beam_job_api_pb2.JobInfo + def to_runner_api(self) -> beam_job_api_pb2.JobInfo: return beam_job_api_pb2.JobInfo( job_id=self._job_id, job_name=self._job_name, @@ -285,9 +272,7 @@ def __init__(self, jar_path, root): def close(self): self._zipfile_handle.close() - def file_writer(self, path): - # type: (str) -> Tuple[BinaryIO, str] - + def file_writer(self, path: str) -> Tuple[BinaryIO, str]: """Given a relative path, returns an open handle that can be written to and an reference that can later be used to read this file.""" full_path = '%s/%s' % (self._root, path) diff --git a/sdks/python/apache_beam/runners/portability/artifact_service.py b/sdks/python/apache_beam/runners/portability/artifact_service.py index 6dec4031ee07..b9395caeafaf 100644 --- a/sdks/python/apache_beam/runners/portability/artifact_service.py +++ b/sdks/python/apache_beam/runners/portability/artifact_service.py @@ -57,7 +57,7 @@ class ArtifactRetrievalService( def __init__( self, - file_reader, # type: Callable[[str], BinaryIO] + file_reader: Callable[[str], BinaryIO], chunk_size=None, ): self._file_reader = file_reader @@ -97,18 +97,20 @@ class ArtifactStagingService( beam_artifact_api_pb2_grpc.ArtifactStagingServiceServicer): def __init__( self, - file_writer, # type: Callable[[str, Optional[str]], Tuple[BinaryIO, str]] - ): + file_writer: Callable[[str, Optional[str]], Tuple[BinaryIO, str]], + ): self._lock = threading.Lock() - self._jobs_to_stage = { - } # type: Dict[str, Tuple[Dict[Any, List[beam_runner_api_pb2.ArtifactInformation]], threading.Event]] + self._jobs_to_stage: Dict[ + str, + Tuple[Dict[Any, List[beam_runner_api_pb2.ArtifactInformation]], + threading.Event]] = {} self._file_writer = file_writer def register_job( self, - staging_token, # type: str - dependency_sets # type: MutableMapping[Any, List[beam_runner_api_pb2.ArtifactInformation]] - ): + staging_token: str, + dependency_sets: MutableMapping[ + Any, List[beam_runner_api_pb2.ArtifactInformation]]): if staging_token in self._jobs_to_stage: raise ValueError('Already staging %s' % staging_token) with self._lock: diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py index f69ee1c24c4e..4dc2446fdd9d 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py @@ -252,6 +252,9 @@ def test_expand_kafka_read(self): 'LongDeserializer', commit_offset_in_finalize=True, timestamp_policy=ReadFromKafka.create_time_policy, + redistribute=False, + redistribute_num_keys=0, + allow_duplicates=False, expansion_service=self.get_expansion_service())) self.assertTrue( 'No resolvable bootstrap urls given in bootstrap.servers' in str( diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py index 885c96146456..e69e37495f64 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py @@ -48,6 +48,7 @@ from apache_beam import coders from apache_beam.coders.coder_impl import 
CoderImpl +from apache_beam.coders.coder_impl import WindowedValueCoderImpl from apache_beam.coders.coder_impl import create_InputStream from apache_beam.coders.coder_impl import create_OutputStream from apache_beam.coders.coders import WindowedValueCoder @@ -55,6 +56,7 @@ from apache_beam.portability import python_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.common import ENCODED_IMPULSE_VALUE from apache_beam.runners.direct.clock import RealClock @@ -73,6 +75,7 @@ from apache_beam.transforms import core from apache_beam.transforms import trigger from apache_beam.transforms import window +from apache_beam.transforms.window import BoundedWindow from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import GlobalWindows from apache_beam.utils import proto_utils @@ -81,12 +84,8 @@ from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: - from apache_beam.coders.coder_impl import WindowedValueCoderImpl - from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.portability.fn_api_runner import worker_handlers from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput - from apache_beam.runners.portability.fn_api_runner.translations import TimerFamilyId - from apache_beam.transforms.window import BoundedWindow _LOGGER = logging.getLogger(__name__) @@ -95,12 +94,10 @@ class Buffer(Protocol): - def __iter__(self): - # type: () -> Iterator[bytes] + def __iter__(self) -> Iterator[bytes]: pass - def append(self, item): - # type: (bytes) -> None + def append(self, item: bytes) -> None: pass def extend(self, other: 'Buffer') -> None: @@ -111,31 +108,26 @@ class PartitionableBuffer(Buffer, Protocol): def copy(self) -> 'PartitionableBuffer': pass - def partition(self, n): - # type: (int) -> List[List[bytes]] + def partition(self, n: int) -> List[List[bytes]]: pass @property - def cleared(self): - # type: () -> bool + def cleared(self) -> bool: pass - def clear(self): - # type: () -> None + def clear(self) -> None: pass - def reset(self): - # type: () -> None + def reset(self) -> None: pass class ListBuffer: """Used to support parititioning of a list.""" - def __init__(self, coder_impl): - # type: (Optional[CoderImpl]) -> None + def __init__(self, coder_impl: Optional[CoderImpl]) -> None: self._coder_impl = coder_impl or CoderImpl() - self._inputs = [] # type: List[bytes] - self._grouped_output = None # type: Optional[List[List[bytes]]] + self._inputs: List[bytes] = [] + self._grouped_output: Optional[List[List[bytes]]] = None self.cleared = False def copy(self) -> 'ListBuffer': @@ -151,16 +143,14 @@ def extend(self, extra: 'Buffer') -> None: assert isinstance(extra, ListBuffer) self._inputs.extend(extra._inputs) - def append(self, element): - # type: (bytes) -> None + def append(self, element: bytes) -> None: if self.cleared: raise RuntimeError('Trying to append to a cleared ListBuffer.') if self._grouped_output: raise RuntimeError('ListBuffer append after read.') self._inputs.append(element) - def partition(self, n): - # type: (int) -> List[List[bytes]] + def partition(self, n: int) -> List[List[bytes]]: if self.cleared: raise RuntimeError('Trying to partition a cleared ListBuffer.') if len(self._inputs) >= n or len(self._inputs) == 0: @@ -181,21 +171,17 @@ def partition(self, n): for output_stream in 
output_stream_list] return self._grouped_output - def __iter__(self): - # type: () -> Iterator[bytes] + def __iter__(self) -> Iterator[bytes]: if self.cleared: raise RuntimeError('Trying to iterate through a cleared ListBuffer.') return iter(self._inputs) - def clear(self): - # type: () -> None + def clear(self) -> None: self.cleared = True self._inputs = [] self._grouped_output = None - def reset(self): - # type: () -> None - + def reset(self) -> None: """Resets a cleared buffer for reuse.""" if not self.cleared: raise RuntimeError('Trying to reset a non-cleared ListBuffer.') @@ -204,19 +190,17 @@ def reset(self): class GroupingBuffer(object): """Used to accumulate groupded (shuffled) results.""" - def __init__(self, - pre_grouped_coder, # type: coders.Coder - post_grouped_coder, # type: coders.Coder - windowing # type: core.Windowing - ): - # type: (...) -> None + def __init__( + self, + pre_grouped_coder: coders.Coder, + post_grouped_coder: coders.Coder, + windowing: core.Windowing) -> None: self._key_coder = pre_grouped_coder.key_coder() self._pre_grouped_coder = pre_grouped_coder self._post_grouped_coder = post_grouped_coder - self._table = collections.defaultdict( - list) # type: DefaultDict[bytes, List[Any]] + self._table: DefaultDict[bytes, List[Any]] = collections.defaultdict(list) self._windowing = windowing - self._grouped_output = None # type: Optional[List[List[bytes]]] + self._grouped_output: Optional[List[List[bytes]]] = None def copy(self) -> 'GroupingBuffer': # This is a silly temporary optimization. This class must be removed once @@ -224,8 +208,7 @@ def copy(self) -> 'GroupingBuffer': # data grouping instead of GroupingBuffer). return self - def append(self, elements_data): - # type: (bytes) -> None + def append(self, elements_data: bytes) -> None: if self._grouped_output: raise RuntimeError('Grouping table append after read.') input_stream = create_InputStream(elements_data) @@ -241,8 +224,7 @@ def append(self, elements_data): value if is_trivial_windowing else windowed_key_value. with_value(value)) - def extend(self, input_buffer): - # type: (Buffer) -> None + def extend(self, input_buffer: Buffer) -> None: if isinstance(input_buffer, ListBuffer): # TODO(pabloem): GroupingBuffer will be removed once shuffling is done # via state. Remove this workaround along with that. @@ -252,9 +234,7 @@ def extend(self, input_buffer): for key, values in input_buffer._table.items(): self._table[key].extend(values) - def partition(self, n): - # type: (int) -> List[List[bytes]] - + def partition(self, n: int) -> List[List[bytes]]: """ It is used to partition _GroupingBuffer to N parts. Once it is partitioned, it would not be re-partitioned with diff N. Re-partition is not supported now. @@ -292,9 +272,7 @@ def partition(self, n): self._table.clear() return self._grouped_output - def __iter__(self): - # type: () -> Iterator[bytes] - + def __iter__(self) -> Iterator[bytes]: """ Since partition() returns a list of lists, add this __iter__ to return a list to simplify code when we need to iterate through ALL elements of _GroupingBuffer. 
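Context for the Buffer and PartitionableBuffer hunks above: both are typing.Protocol classes, so concrete buffers such as ListBuffer and GroupingBuffer satisfy them structurally rather than by inheritance. A simplified, standalone sketch of that pattern (illustrative only; MiniListBuffer and drain are invented names, not Beam classes):

    from typing import Iterator, List, Protocol

    class Buffer(Protocol):
        def __iter__(self) -> Iterator[bytes]:
            ...

        def append(self, item: bytes) -> None:
            ...

    class MiniListBuffer:
        """Satisfies Buffer structurally; no explicit inheritance needed."""
        def __init__(self) -> None:
            self._items: List[bytes] = []

        def __iter__(self) -> Iterator[bytes]:
            return iter(self._items)

        def append(self, item: bytes) -> None:
            self._items.append(item)

    def drain(buf: Buffer) -> List[bytes]:
        # The type checker accepts MiniListBuffer here because it matches the
        # protocol's method signatures.
        return list(buf)

    buf = MiniListBuffer()
    buf.append(b'element')
    assert drain(buf) == [b'element']
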
@@ -305,12 +283,10 @@ def __iter__(self): # PartionableBuffer protocol cleared = False - def clear(self): - # type: () -> None + def clear(self) -> None: pass - def reset(self): - # type: () -> None + def reset(self) -> None: pass @@ -318,15 +294,13 @@ class WindowGroupingBuffer(object): """Used to partition windowed side inputs.""" def __init__( self, - access_pattern, # type: beam_runner_api_pb2.FunctionSpec - coder # type: WindowedValueCoder - ): - # type: (...) -> None + access_pattern: beam_runner_api_pb2.FunctionSpec, + coder: WindowedValueCoder) -> None: # Here's where we would use a different type of partitioning # (e.g. also by key) for a different access pattern. if access_pattern.urn == common_urns.side_inputs.ITERABLE.urn: self._kv_extractor = lambda value: ('', value) - self._key_coder = coders.SingletonCoder('') # type: coders.Coder + self._key_coder: coders.Coder = coders.SingletonCoder('') self._value_coder = coder.wrapped_value_coder elif access_pattern.urn == common_urns.side_inputs.MULTIMAP.urn: self._kv_extractor = lambda value: value @@ -336,23 +310,22 @@ def __init__( raise ValueError("Unknown access pattern: '%s'" % access_pattern.urn) self._windowed_value_coder = coder self._window_coder = coder.window_coder - self._values_by_window = collections.defaultdict( - list) # type: DefaultDict[Tuple[str, BoundedWindow], List[Any]] + self._values_by_window: DefaultDict[Tuple[str, BoundedWindow], + List[Any]] = collections.defaultdict( + list) - def append(self, elements_data): - # type: (bytes) -> None + def append(self, elements_data: bytes) -> None: input_stream = create_InputStream(elements_data) while input_stream.size() > 0: - windowed_val_coder_impl = self._windowed_value_coder.get_impl( - ) # type: WindowedValueCoderImpl + windowed_val_coder_impl: WindowedValueCoderImpl = ( + self._windowed_value_coder.get_impl()) windowed_value = windowed_val_coder_impl.decode_from_stream( input_stream, True) key, value = self._kv_extractor(windowed_value.value) for window in windowed_value.windows: self._values_by_window[key, window].append(value) - def encoded_items(self): - # type: () -> Iterator[Tuple[bytes, bytes, bytes, int]] + def encoded_items(self) -> Iterator[Tuple[bytes, bytes, bytes, int]]: value_coder_impl = self._value_coder.get_impl() key_coder_impl = self._key_coder.get_impl() for (key, window), values in self._values_by_window.items(): @@ -368,22 +341,21 @@ class GenericNonMergingWindowFn(window.NonMergingWindowFn): URN = 'internal-generic-non-merging' - def __init__(self, coder): - # type: (coders.Coder) -> None + def __init__(self, coder: coders.Coder) -> None: self._coder = coder - def assign(self, assign_context): - # type: (window.WindowFn.AssignContext) -> Iterable[BoundedWindow] + def assign( + self, + assign_context: window.WindowFn.AssignContext) -> Iterable[BoundedWindow]: raise NotImplementedError() - def get_window_coder(self): - # type: () -> coders.Coder + def get_window_coder(self) -> coders.Coder: return self._coder @staticmethod @window.urns.RunnerApiFn.register_urn(URN, bytes) - def from_runner_api_parameter(window_coder_id, context): - # type: (bytes, Any) -> GenericNonMergingWindowFn + def from_runner_api_parameter( + window_coder_id: bytes, context: Any) -> 'GenericNonMergingWindowFn': return GenericNonMergingWindowFn( context.coders[window_coder_id.decode('utf-8')]) @@ -478,11 +450,13 @@ class GenericMergingWindowFn(window.WindowFn): TO_SDK_TRANSFORM = 'read' FROM_SDK_TRANSFORM = 'write' - _HANDLES = {} # type: Dict[str, GenericMergingWindowFn] 
+ _HANDLES: Dict[str, 'GenericMergingWindowFn'] = {} - def __init__(self, execution_context, windowing_strategy_proto): - # type: (FnApiRunnerExecutionContext, beam_runner_api_pb2.WindowingStrategy) -> None - self._worker_handler = None # type: Optional[worker_handlers.WorkerHandler] + def __init__( + self, + execution_context: 'FnApiRunnerExecutionContext', + windowing_strategy_proto: beam_runner_api_pb2.WindowingStrategy) -> None: + self._worker_handler: Optional[worker_handlers.WorkerHandler] = None self._handle_id = handle_id = uuid.uuid4().hex self._HANDLES[handle_id] = self # ExecutionContexts are expensive, we don't want to keep them in the @@ -494,32 +468,30 @@ def __init__(self, execution_context, windowing_strategy_proto): self._counter = 0 # Lazily created in make_process_bundle_descriptor() self._process_bundle_descriptor = None - self._bundle_processor_id = '' # type: str - self.windowed_input_coder_impl = None # type: Optional[CoderImpl] - self.windowed_output_coder_impl = None # type: Optional[CoderImpl] + self._bundle_processor_id: str = '' + self.windowed_input_coder_impl: Optional[CoderImpl] = None + self.windowed_output_coder_impl: Optional[CoderImpl] = None - def _execution_context_ref(self): - # type: () -> FnApiRunnerExecutionContext + def _execution_context_ref(self) -> 'FnApiRunnerExecutionContext': result = self._execution_context_ref_obj() assert result is not None return result - def payload(self): - # type: () -> bytes + def payload(self) -> bytes: return self._handle_id.encode('utf-8') @staticmethod @window.urns.RunnerApiFn.register_urn(URN, bytes) - def from_runner_api_parameter(handle_id, unused_context): - # type: (bytes, Any) -> GenericMergingWindowFn + def from_runner_api_parameter( + handle_id: bytes, unused_context: Any) -> 'GenericMergingWindowFn': return GenericMergingWindowFn._HANDLES[handle_id.decode('utf-8')] - def assign(self, assign_context): - # type: (window.WindowFn.AssignContext) -> Iterable[window.BoundedWindow] + def assign( + self, assign_context: window.WindowFn.AssignContext + ) -> Iterable[window.BoundedWindow]: raise NotImplementedError() - def merge(self, merge_context): - # type: (window.WindowFn.MergeContext) -> None + def merge(self, merge_context: window.WindowFn.MergeContext) -> None: worker_handler = self.worker_handle() assert self.windowed_input_coder_impl is not None @@ -554,13 +526,11 @@ def merge(self, merge_context): raise RuntimeError(result.error) # The result was "returned" via the merge callbacks on merge_context above. 
- def get_window_coder(self): - # type: () -> coders.Coder + def get_window_coder(self) -> coders.Coder: return self._execution_context_ref().pipeline_context.coders[ self._windowing_strategy_proto.window_coder_id] - def worker_handle(self): - # type: () -> worker_handlers.WorkerHandler + def worker_handle(self) -> 'worker_handlers.WorkerHandler': if self._worker_handler is None: worker_handler_manager = self._execution_context_ref( ).worker_handler_manager @@ -574,14 +544,14 @@ def worker_handle(self): return self._worker_handler def make_process_bundle_descriptor( - self, data_api_service_descriptor, state_api_service_descriptor): - # type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor - + self, + data_api_service_descriptor: Optional[endpoints_pb2.ApiServiceDescriptor], + state_api_service_descriptor: Optional[endpoints_pb2.ApiServiceDescriptor] + ) -> beam_fn_api_pb2.ProcessBundleDescriptor: """Creates a ProcessBundleDescriptor for invoking the WindowFn's merge operation. """ - def make_channel_payload(coder_id): - # type: (str) -> bytes + def make_channel_payload(coder_id: str) -> bytes: data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) if data_api_service_descriptor: data_spec.api_service_descriptor.url = (data_api_service_descriptor.url) @@ -593,8 +563,7 @@ def make_channel_payload(coder_id): window.GlobalWindows()).to_runner_api(pipeline_context) coders = dict(pipeline_context.coders.get_id_to_proto_map()) - def make_coder(urn, *components): - # type: (str, str) -> str + def make_coder(urn: str, *components: str) -> str: coder_proto = beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.FunctionSpec(urn=urn), component_coder_ids=components) @@ -681,8 +650,7 @@ def make_coder(urn, *components): state_api_service_descriptor=state_api_service_descriptor, timer_api_service_descriptor=data_api_service_descriptor) - def uid(self, name=''): - # type: (str) -> str + def uid(self, name: str = '') -> str: self._counter += 1 return '%s_%s_%s' % (self._handle_id, name, self._counter) @@ -693,16 +661,18 @@ class FnApiRunnerExecutionContext(object): PCollection IDs to list that functions as buffer for the ``beam.PCollection``. """ - def __init__(self, - stages, # type: List[translations.Stage] - worker_handler_manager, # type: worker_handlers.WorkerHandlerManager - pipeline_components, # type: beam_runner_api_pb2.Components - safe_coders: translations.SafeCoderMapping, - data_channel_coders: Dict[str, str], - num_workers: int, - uses_teststream: bool = False, - split_managers = () # type: Sequence[Tuple[str, Callable[[int], Iterable[float]]]] - ) -> None: + def __init__( + self, + stages: List[translations.Stage], + worker_handler_manager: 'worker_handlers.WorkerHandlerManager', + pipeline_components: beam_runner_api_pb2.Components, + safe_coders: translations.SafeCoderMapping, + data_channel_coders: Dict[str, str], + num_workers: int, + uses_teststream: bool = False, + split_managers: Sequence[Tuple[str, Callable[[int], + Iterable[float]]]] = () + ) -> None: """ :param worker_handler_manager: This class manages the set of worker handlers, and the communication with state / control APIs. 
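The quoted annotations in the hunks above (for example 'worker_handlers.WorkerHandlerManager' and 'FnApiRunnerExecutionContext') are string forward references: they let an annotation name something that is only imported under TYPE_CHECKING or defined later in the module. A minimal, self-contained sketch of the idiom (some_package, heavy_module and Scheduler are invented for illustration):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported for the type checker only; this branch never runs, so it
        # adds no runtime dependency and cannot create an import cycle.
        from some_package import heavy_module

    class Scheduler:
        def handler(self) -> 'heavy_module.Handler':
            # The quoted annotation is never evaluated at runtime, so the
            # TYPE_CHECKING-only import above is enough for tools like mypy.
            raise NotImplementedError

        def clone(self) -> 'Scheduler':
            # Quoting also covers references to the class currently being defined.
            return Scheduler()
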
@@ -714,8 +684,8 @@ def __init__(self, self.stages = {s.name: s for s in stages} self.side_input_descriptors_by_stage = ( self._build_data_side_inputs_map(stages)) - self.pcoll_buffers = {} # type: MutableMapping[bytes, PartitionableBuffer] - self.timer_buffers = {} # type: MutableMapping[bytes, ListBuffer] + self.pcoll_buffers: MutableMapping[bytes, PartitionableBuffer] = {} + self.timer_buffers: MutableMapping[bytes, ListBuffer] = {} self.worker_handler_manager = worker_handler_manager self.pipeline_components = pipeline_components self.safe_coders = safe_coders @@ -806,7 +776,7 @@ def setup(self) -> None: def _enqueue_stage_initial_inputs(self, stage: Stage) -> None: """Sets up IMPULSE inputs for a stage, and the data GRPC API endpoint.""" - data_input = {} # type: MutableMapping[str, PartitionableBuffer] + data_input: MutableMapping[str, PartitionableBuffer] = {} ready_to_schedule = True for transform in stage.transforms: if (transform.spec.urn in {bundle_processor.DATA_INPUT_URN, @@ -854,23 +824,23 @@ def _enqueue_stage_initial_inputs(self, stage: Stage) -> None: ((stage.name, MAX_TIMESTAMP), DataInput(data_input, {}))) @staticmethod - def _build_data_side_inputs_map(stages): - # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput] - + def _build_data_side_inputs_map( + stages: Iterable[translations.Stage] + ) -> MutableMapping[str, 'DataSideInput']: """Builds an index mapping stages to side input descriptors. A side input descriptor is a map of side input IDs to side input access patterns for all of the outputs of a stage that will be consumed as a side input. """ - transform_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]] - stage_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[translations.Stage]] - - def get_all_side_inputs(): - # type: () -> Set[str] - all_side_inputs = set() # type: Set[str] + transform_consumers: DefaultDict[ + str, + List[beam_runner_api_pb2.PTransform]] = collections.defaultdict(list) + stage_consumers: DefaultDict[ + str, List[translations.Stage]] = collections.defaultdict(list) + + def get_all_side_inputs() -> Set[str]: + all_side_inputs: Set[str] = set() for stage in stages: for transform in stage.transforms: for input in transform.inputs.values(): @@ -881,7 +851,7 @@ def get_all_side_inputs(): return all_side_inputs all_side_inputs = frozenset(get_all_side_inputs()) - data_side_inputs_by_producing_stage = {} # type: Dict[str, DataSideInput] + data_side_inputs_by_producing_stage: Dict[str, DataSideInput] = {} producing_stages_by_pcoll = {} @@ -912,8 +882,7 @@ def get_all_side_inputs(): return data_side_inputs_by_producing_stage - def _make_safe_windowing_strategy(self, id): - # type: (str) -> str + def _make_safe_windowing_strategy(self, id: str) -> str: windowing_strategy_proto = self.pipeline_components.windowing_strategies[id] if windowing_strategy_proto.window_fn.urn in SAFE_WINDOW_FNS: return id @@ -940,18 +909,16 @@ def _make_safe_windowing_strategy(self, id): return safe_id @property - def state_servicer(self): - # type: () -> worker_handlers.StateServicer + def state_servicer(self) -> 'worker_handlers.StateServicer': # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer return self.worker_handler_manager.state_servicer - def next_uid(self): - # type: () -> str + def next_uid(self) -> str: self._last_uid += 1 return str(self._last_uid) - def _iterable_state_write(self, values, element_coder_impl): - # type: (Iterable, 
CoderImpl) -> bytes + def _iterable_state_write( + self, values: Iterable, element_coder_impl: CoderImpl) -> bytes: token = unique_name(None, 'iter').encode('ascii') out = create_OutputStream() for element in values: @@ -964,9 +931,8 @@ def _iterable_state_write(self, values, element_coder_impl): def commit_side_inputs_to_state( self, - data_side_input, # type: DataSideInput - ): - # type: (...) -> None + data_side_input: 'DataSideInput', + ) -> None: for (consuming_transform_id, tag), (buffer_id, func_spec) in data_side_input.items(): _, pcoll_id = split_buffer_id(buffer_id) @@ -1024,14 +990,13 @@ def commit_side_inputs_to_state( class BundleContextManager(object): - - def __init__(self, - execution_context, # type: FnApiRunnerExecutionContext - stage, # type: translations.Stage - num_workers, # type: int - split_managers, # type: Sequence[Tuple[str, Callable[[int], Iterable[float]]]] - ): - # type: (...) -> None + def __init__( + self, + execution_context: FnApiRunnerExecutionContext, + stage: translations.Stage, + num_workers: int, + split_managers: Sequence[Tuple[str, Callable[[int], Iterable[float]]]], + ) -> None: self.execution_context = execution_context self.stage = stage self.bundle_uid = self.execution_context.next_uid() @@ -1039,12 +1004,13 @@ def __init__(self, self.split_managers = split_managers # Properties that are lazily initialized - self._process_bundle_descriptor = None # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor] - self._worker_handlers = None # type: Optional[List[worker_handlers.WorkerHandler]] + self._process_bundle_descriptor: Optional[ + beam_fn_api_pb2.ProcessBundleDescriptor] = None + self._worker_handlers: Optional[List[worker_handlers.WorkerHandler]] = None # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map # is built after self._process_bundle_descriptor is initialized. # This field can be used to tell whether current bundle has timers. - self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]] + self._timer_coder_ids: Optional[Dict[Tuple[str, str], str]] = None # A mapping from transform_name to Buffer ID self.stage_data_outputs: DataOutput = {} @@ -1066,36 +1032,35 @@ def _compute_expected_outputs(self) -> None: create_buffer_id(timer_family_id, 'timers'), time_domain) @property - def worker_handlers(self): - # type: () -> List[worker_handlers.WorkerHandler] + def worker_handlers(self) -> List['worker_handlers.WorkerHandler']: if self._worker_handlers is None: self._worker_handlers = ( self.execution_context.worker_handler_manager.get_worker_handlers( self.stage.environment, self.num_workers)) return self._worker_handlers - def data_api_service_descriptor(self): - # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] + def data_api_service_descriptor( + self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: # All worker_handlers share the same grpc server, so we can read grpc server # info from any worker_handler and read from the first worker_handler. return self.worker_handlers[0].data_api_service_descriptor() - def state_api_service_descriptor(self): - # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] + def state_api_service_descriptor( + self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: # All worker_handlers share the same grpc server, so we can read grpc server # info from any worker_handler and read from the first worker_handler. 
return self.worker_handlers[0].state_api_service_descriptor() @property - def process_bundle_descriptor(self): - # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor + def process_bundle_descriptor( + self) -> beam_fn_api_pb2.ProcessBundleDescriptor: if self._process_bundle_descriptor is None: self._process_bundle_descriptor = self._build_process_bundle_descriptor() self._timer_coder_ids = self._build_timer_coders_id_map() return self._process_bundle_descriptor - def _build_process_bundle_descriptor(self): - # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor + def _build_process_bundle_descriptor( + self) -> beam_fn_api_pb2.ProcessBundleDescriptor: # Cannot be invoked until *after* _extract_endpoints is called. # Always populate the timer_api_service_descriptor. return beam_fn_api_pb2.ProcessBundleDescriptor( @@ -1115,16 +1080,14 @@ def _build_process_bundle_descriptor(self): state_api_service_descriptor=self.state_api_service_descriptor(), timer_api_service_descriptor=self.data_api_service_descriptor()) - def get_input_coder_impl(self, transform_id): - # type: (str) -> CoderImpl + def get_input_coder_impl(self, transform_id: str) -> CoderImpl: coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString( self.process_bundle_descriptor.transforms[transform_id].spec.payload ).coder_id assert coder_id return self.get_coder_impl(coder_id) - def _build_timer_coders_id_map(self): - # type: () -> Dict[Tuple[str, str], str] + def _build_timer_coders_id_map(self) -> Dict[Tuple[str, str], str]: assert self._process_bundle_descriptor is not None timer_coder_ids = {} for transform_id, transform_proto in (self._process_bundle_descriptor @@ -1137,23 +1100,21 @@ def _build_timer_coders_id_map(self): timer_family_spec.timer_family_coder_id) return timer_coder_ids - def get_coder_impl(self, coder_id): - # type: (str) -> CoderImpl + def get_coder_impl(self, coder_id: str) -> CoderImpl: if coder_id in self.execution_context.safe_coders: return self.execution_context.pipeline_context.coders[ self.execution_context.safe_coders[coder_id]].get_impl() else: return self.execution_context.pipeline_context.coders[coder_id].get_impl() - def get_timer_coder_impl(self, transform_id, timer_family_id): - # type: (str, str) -> CoderImpl + def get_timer_coder_impl( + self, transform_id: str, timer_family_id: str) -> CoderImpl: assert self._timer_coder_ids is not None return self.get_coder_impl( self._timer_coder_ids[(transform_id, timer_family_id)]) - def get_buffer(self, buffer_id, transform_id): - # type: (bytes, str) -> PartitionableBuffer - + def get_buffer( + self, buffer_id: bytes, transform_id: str) -> PartitionableBuffer: """Returns the buffer for a given (operation_type, PCollection ID). For grouping-typed operations, we produce a ``GroupingBuffer``. For others, we produce a ``ListBuffer``. 
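The execution.py changes above, like the rest of this patch, swap comment-style type hints for inline PEP 484/526 annotations without changing behavior. A before/after sketch of the transformation on an invented helper (get_buffer_size is not part of the patch):

    from typing import List, Mapping

    # Before: comment-style hints, as these modules used previously.
    def get_buffer_size(buffers,  # type: Mapping[str, List[bytes]]
                        buffer_id  # type: str
                        ):
        # type: (...) -> int
        total = 0  # type: int
        for chunk in buffers.get(buffer_id, []):
            total += len(chunk)
        return total

    # After: the equivalent inline annotations in the style this patch adopts.
    def get_buffer_size(buffers: Mapping[str, List[bytes]], buffer_id: str) -> int:
        total: int = 0
        for chunk in buffers.get(buffer_id, []):
            total += len(chunk)
        return total
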
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index 07569fe328d8..1ed21942d28f 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -31,7 +31,6 @@ import sys import threading import time -from typing import TYPE_CHECKING from typing import Callable from typing import Dict from typing import Iterable @@ -55,11 +54,13 @@ from apache_beam.metrics.monitoring_infos import consolidate as consolidate_monitoring_infos from apache_beam.options import pipeline_options from apache_beam.options.value_provider import RuntimeValueProvider +from apache_beam.pipeline import Pipeline from apache_beam.portability import common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_provision_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.portability.api import metrics_pb2 from apache_beam.runners import runner from apache_beam.runners.common import group_by_key_input_visitor from apache_beam.runners.common import merge_common_environments @@ -75,6 +76,7 @@ from apache_beam.runners.portability.fn_api_runner.translations import OutputTimers from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id from apache_beam.runners.portability.fn_api_runner.translations import only_element +from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager from apache_beam.runners.worker import bundle_processor from apache_beam.transforms import environments @@ -83,11 +85,6 @@ from apache_beam.utils import timestamp from apache_beam.utils.profiler import Profile -if TYPE_CHECKING: - from apache_beam.pipeline import Pipeline - from apache_beam.portability.api import metrics_pb2 - from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandler - _LOGGER = logging.getLogger(__name__) _BUNDLE_LOGGER = logging.getLogger(__name__ + ".run_bundle") @@ -102,15 +99,12 @@ class FnApiRunner(runner.PipelineRunner): def __init__( self, - default_environment=None, # type: Optional[environments.Environment] - bundle_repeat=0, # type: int - use_state_iterables=False, # type: bool - provision_info=None, # type: Optional[ExtendedProvisionInfo] - progress_request_frequency=None, # type: Optional[float] - is_drain=False # type: bool - ): - # type: (...) -> None - + default_environment: Optional[environments.Environment] = None, + bundle_repeat: int = 0, + use_state_iterables: bool = False, + provision_info: Optional['ExtendedProvisionInfo'] = None, + progress_request_frequency: Optional[float] = None, + is_drain: bool = False) -> None: """Creates a new Fn API Runner. Args: @@ -138,19 +132,16 @@ def __init__( retrieval_token='unused-retrieval-token')) @staticmethod - def supported_requirements(): - # type: () -> Tuple[str, ...] + def supported_requirements() -> Tuple[str, ...]: return ( common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn, common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn, common_urns.requirements.REQUIRES_SPLITTABLE_DOFN.urn, ) - def run_pipeline(self, - pipeline, # type: Pipeline - options # type: pipeline_options.PipelineOptions - ): - # type: (...) 
-> RunnerResult + def run_pipeline( + self, pipeline: Pipeline, + options: pipeline_options.PipelineOptions) -> 'RunnerResult': RuntimeValueProvider.set_runtime_options({}) # Setup "beam_fn_api" experiment options if lacked. @@ -206,8 +197,10 @@ def run_pipeline(self, options) return self._latest_run_result - def run_via_runner_api(self, pipeline_proto, options): - # type: (beam_runner_api_pb2.Pipeline, pipeline_options.PipelineOptions) -> RunnerResult + def run_via_runner_api( + self, + pipeline_proto: beam_runner_api_pb2.Pipeline, + options: pipeline_options.PipelineOptions) -> 'RunnerResult': validate_pipeline_graph(pipeline_proto) self._validate_requirements(pipeline_proto) self._check_requirements(pipeline_proto) @@ -282,8 +275,7 @@ def resolve_any_environments(self, pipeline_proto): return pipeline_proto @contextlib.contextmanager - def maybe_profile(self): - # type: () -> Iterator[None] + def maybe_profile(self) -> Iterator[None]: if self._profiler_factory: try: profile_id = 'direct-' + subprocess.check_output([ @@ -291,8 +283,8 @@ def maybe_profile(self): ]).decode(errors='ignore').strip() except subprocess.CalledProcessError: profile_id = 'direct-unknown' - profiler = self._profiler_factory( - profile_id, time_prefix='') # type: Optional[Profile] + profiler: Optional[Profile] = self._profiler_factory( + profile_id, time_prefix='') else: profiler = None @@ -328,14 +320,12 @@ def maybe_profile(self): # Empty context. yield - def _validate_requirements(self, pipeline_proto): - # type: (beam_runner_api_pb2.Pipeline) -> None - + def _validate_requirements( + self, pipeline_proto: beam_runner_api_pb2.Pipeline) -> None: """As a test runner, validate requirements were set correctly.""" expected_requirements = set() - def add_requirements(transform_id): - # type: (str) -> None + def add_requirements(transform_id: str) -> None: transform = pipeline_proto.components.transforms[transform_id] if transform.spec.urn in translations.PAR_DO_URNS: payload = proto_utils.parse_Bytes( @@ -366,9 +356,8 @@ def add_requirements(transform_id): 'Missing requirement declaration: %s' % (expected_requirements - set(pipeline_proto.requirements))) - def _check_requirements(self, pipeline_proto): - # type: (beam_runner_api_pb2.Pipeline) -> None - + def _check_requirements( + self, pipeline_proto: beam_runner_api_pb2.Pipeline) -> None: """Check that this runner can satisfy all pipeline requirements.""" supported_requirements = set(self.supported_requirements()) for requirement in pipeline_proto.requirements: @@ -388,10 +377,8 @@ def _check_requirements(self, pipeline_proto): raise NotImplementedError(timer.time_domain) def create_stages( - self, - pipeline_proto # type: beam_runner_api_pb2.Pipeline - ): - # type: (...) -> Tuple[translations.TransformContext, List[translations.Stage]] + self, pipeline_proto: beam_runner_api_pb2.Pipeline + ) -> Tuple[translations.TransformContext, List[translations.Stage]]: return translations.create_and_optimize_stages( copy.deepcopy(pipeline_proto), phases=[ @@ -417,12 +404,10 @@ def create_stages( use_state_iterables=self._use_state_iterables, is_drain=self._is_drain) - def run_stages(self, - stage_context, # type: translations.TransformContext - stages # type: List[translations.Stage] - ): - # type: (...) -> RunnerResult - + def run_stages( + self, + stage_context: translations.TransformContext, + stages: List[translations.Stage]) -> 'RunnerResult': """Run a list of topologically-sorted stages in batch mode. 
Args: @@ -593,11 +578,12 @@ def _schedule_ready_bundles( def _run_bundle_multiple_times_for_testing( self, - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - bundle_manager, # type: BundleManager - data_input, # type: MutableMapping[str, execution.PartitionableBuffer] - data_output, # type: DataOutput - fired_timers, # type: Mapping[translations.TimerFamilyId, execution.PartitionableBuffer] + runner_execution_context: execution.FnApiRunnerExecutionContext, + bundle_manager: 'BundleManager', + data_input: MutableMapping[str, execution.PartitionableBuffer], + data_output: DataOutput, + fired_timers: Mapping[translations.TimerFamilyId, + execution.PartitionableBuffer], expected_output_timers: OutputTimers, ) -> None: """ @@ -679,12 +665,10 @@ def _collect_written_timers( def _add_sdk_delayed_applications_to_deferred_inputs( self, - bundle_context_manager, # type: execution.BundleContextManager - bundle_result, # type: beam_fn_api_pb2.InstructionResponse - deferred_inputs # type: MutableMapping[str, execution.PartitionableBuffer] - ): - # type: (...) -> Set[str] - + bundle_context_manager: execution.BundleContextManager, + bundle_result: beam_fn_api_pb2.InstructionResponse, + deferred_inputs: MutableMapping[str, execution.PartitionableBuffer] + ) -> Set[str]: """Returns a set of PCollection IDs of PColls having delayed applications. This transform inspects the bundle_context_manager, and bundle_result @@ -711,13 +695,11 @@ def _add_sdk_delayed_applications_to_deferred_inputs( def _add_residuals_and_channel_splits_to_deferred_inputs( self, - splits, # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] - bundle_context_manager, # type: execution.BundleContextManager - last_sent, # type: MutableMapping[str, execution.PartitionableBuffer] - deferred_inputs # type: MutableMapping[str, execution.PartitionableBuffer] - ): - # type: (...) -> Tuple[Set[str], Set[str]] - + splits: List[beam_fn_api_pb2.ProcessBundleSplitResponse], + bundle_context_manager: execution.BundleContextManager, + last_sent: MutableMapping[str, execution.PartitionableBuffer], + deferred_inputs: MutableMapping[str, execution.PartitionableBuffer] + ) -> Tuple[Set[str], Set[str]]: """Returns a two sets representing PCollections with watermark holds. The first set represents PCollections with delayed root applications. @@ -726,7 +708,7 @@ def _add_residuals_and_channel_splits_to_deferred_inputs( pcolls_with_delayed_apps = set() transforms_with_channel_splits = set() - prev_stops = {} # type: Dict[str, int] + prev_stops: Dict[str, int] = {} for split in splits: for delayed_application in split.residual_roots: producer_name = bundle_context_manager.input_for( @@ -783,11 +765,11 @@ def _add_residuals_and_channel_splits_to_deferred_inputs( channel_split.transform_id] = channel_split.last_primary_element return pcolls_with_delayed_apps, transforms_with_channel_splits - def _execute_bundle(self, - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - bundle_context_manager, # type: execution.BundleContextManager - bundle_input: DataInput - ) -> beam_fn_api_pb2.InstructionResponse: + def _execute_bundle( + self, + runner_execution_context: execution.FnApiRunnerExecutionContext, + bundle_context_manager: execution.BundleContextManager, + bundle_input: DataInput) -> beam_fn_api_pb2.InstructionResponse: """Execute a bundle end-to-end. 
Args: @@ -943,7 +925,8 @@ def _get_bundle_manager( cache_token_generator = FnApiRunner.get_cache_token_generator(static=False) if bundle_context_manager.num_workers == 1: # Avoid thread/processor pools for increased performance and debugability. - bundle_manager_type = BundleManager # type: Union[Type[BundleManager], Type[ParallelBundleManager]] + bundle_manager_type: Union[Type[BundleManager], + Type[ParallelBundleManager]] = BundleManager elif bundle_context_manager.stage.is_stateful(): # State is keyed, and a single key cannot be processed concurrently. # Alternatively, we could arrange to partition work by key. @@ -958,12 +941,13 @@ def _get_bundle_manager( @staticmethod def _build_watermark_updates( - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - stage_inputs, # type: Iterable[str] - expected_timers, # type: Iterable[translations.TimerFamilyId] - pcolls_with_da, # type: Set[str] - transforms_w_splits, # type: Set[str] - watermarks_by_transform_and_timer_family # type: Dict[translations.TimerFamilyId, timestamp.Timestamp] + runner_execution_context: execution.FnApiRunnerExecutionContext, + stage_inputs: Iterable[str], + expected_timers: Iterable[translations.TimerFamilyId], + pcolls_with_da: Set[str], + transforms_w_splits: Set[str], + watermarks_by_transform_and_timer_family: Dict[translations.TimerFamilyId, + timestamp.Timestamp] ) -> Dict[Union[str, translations.TimerFamilyId], timestamp.Timestamp]: """Builds a dictionary of PCollection (or TimerFamilyId) to timestamp. @@ -979,8 +963,8 @@ def _build_watermark_updates( watermarks_by_transform_and_timer_family: represent the set of watermark holds to be added for each timer family. """ - updates = { - } # type: Dict[Union[str, translations.TimerFamilyId], timestamp.Timestamp] + updates: Dict[Union[str, translations.TimerFamilyId], + timestamp.Timestamp] = {} def get_pcoll_id(transform_id): buffer_id = runner_execution_context.input_transform_to_buffer_id[ @@ -1024,12 +1008,12 @@ def get_pcoll_id(transform_id): def _run_bundle( self, - runner_execution_context, # type: execution.FnApiRunnerExecutionContext - bundle_context_manager, # type: execution.BundleContextManager + runner_execution_context: execution.FnApiRunnerExecutionContext, + bundle_context_manager: execution.BundleContextManager, bundle_input: DataInput, data_output: DataOutput, expected_timer_output: OutputTimers, - bundle_manager # type: BundleManager + bundle_manager: 'BundleManager' ) -> Tuple[beam_fn_api_pb2.InstructionResponse, Dict[str, execution.PartitionableBuffer], OutputTimerData, @@ -1052,7 +1036,7 @@ def _run_bundle( # - timers # - SDK-initiated deferred applications of root elements # - Runner-initiated deferred applications of root elements - deferred_inputs = {} # type: Dict[str, execution.PartitionableBuffer] + deferred_inputs: Dict[str, execution.PartitionableBuffer] = {} watermarks_by_transform_and_timer_family, newly_set_timers = ( self._collect_written_timers(bundle_context_manager)) @@ -1085,48 +1069,42 @@ def _run_bundle( return result, deferred_inputs, newly_set_timers, watermark_updates @staticmethod - def get_cache_token_generator(static=True): - # type: (bool) -> Iterator[beam_fn_api_pb2.ProcessBundleRequest.CacheToken] - + def get_cache_token_generator( + static: bool = True + ) -> Iterator[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]: """A generator for cache tokens. 
:arg static If True, generator always returns the same cache token If False, generator returns a new cache token each time :return A generator which returns a cache token on next(generator) """ - def generate_token(identifier): - # type: (int) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken + def generate_token( + identifier: int) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: return beam_fn_api_pb2.ProcessBundleRequest.CacheToken( user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState( ), token="cache_token_{}".format(identifier).encode("utf-8")) class StaticGenerator(object): - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._token = generate_token(1) - def __iter__(self): - # type: () -> StaticGenerator + def __iter__(self) -> 'StaticGenerator': # pylint: disable=non-iterator-returned return self - def __next__(self): - # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken + def __next__(self) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: return self._token class DynamicGenerator(object): - def __init__(self): - # type: () -> None + def __init__(self) -> None: self._counter = 0 self._lock = threading.Lock() - def __iter__(self): - # type: () -> DynamicGenerator + def __iter__(self) -> 'DynamicGenerator': # pylint: disable=non-iterator-returned return self - def __next__(self): - # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken + def __next__(self) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken: with self._lock: self._counter += 1 return generate_token(self._counter) @@ -1138,19 +1116,18 @@ def __next__(self): class ExtendedProvisionInfo(object): - def __init__(self, - provision_info=None, # type: Optional[beam_provision_api_pb2.ProvisionInfo] - artifact_staging_dir=None, # type: Optional[str] - job_name='', # type: str - ): - # type: (...) -> None + def __init__( + self, + provision_info: Optional[beam_provision_api_pb2.ProvisionInfo] = None, + artifact_staging_dir: Optional[str] = None, + job_name: str = '', + ) -> None: self.provision_info = ( provision_info or beam_provision_api_pb2.ProvisionInfo()) self.artifact_staging_dir = artifact_staging_dir self.job_name = job_name - def for_environment(self, env): - # type: (...) -> ExtendedProvisionInfo + def for_environment(self, env) -> 'ExtendedProvisionInfo': if env.dependencies: provision_info_with_deps = copy.deepcopy(self.provision_info) provision_info_with_deps.dependencies.extend(env.dependencies) @@ -1218,31 +1195,28 @@ class BundleManager(object): _uid_counter = 0 _lock = threading.Lock() - def __init__(self, - bundle_context_manager, # type: execution.BundleContextManager - progress_frequency=None, # type: Optional[float] - cache_token_generator=FnApiRunner.get_cache_token_generator(), - split_managers=() - ): - # type: (...) -> None - + def __init__( + self, + bundle_context_manager: execution.BundleContextManager, + progress_frequency: Optional[float] = None, + cache_token_generator=FnApiRunner.get_cache_token_generator(), + split_managers=() + ) -> None: """Set up a bundle manager. 
Args: progress_frequency """ - self.bundle_context_manager = bundle_context_manager # type: execution.BundleContextManager + self.bundle_context_manager: execution.BundleContextManager = ( + bundle_context_manager) self._progress_frequency = progress_frequency - self._worker_handler = None # type: Optional[WorkerHandler] + self._worker_handler: Optional[WorkerHandler] = None self._cache_token_generator = cache_token_generator self.split_managers = split_managers - def _send_input_to_worker(self, - process_bundle_id, # type: str - read_transform_id, # type: str - byte_streams - ): - # type: (...) -> None + def _send_input_to_worker( + self, process_bundle_id: str, read_transform_id: str, + byte_streams) -> None: assert self._worker_handler is not None data_out = self._worker_handler.data_conn.output_stream( process_bundle_id, read_transform_id) @@ -1251,8 +1225,7 @@ def _send_input_to_worker(self, data_out.close() def _send_timers_to_worker( - self, process_bundle_id, transform_id, timer_family_id, timers): - # type: (...) -> None + self, process_bundle_id, transform_id, timer_family_id, timers) -> None: assert self._worker_handler is not None timer_out = self._worker_handler.data_conn.output_timer_stream( process_bundle_id, transform_id, timer_family_id) @@ -1273,13 +1246,12 @@ def _select_split_manager(self) -> Optional[Callable[[int], Iterable[float]]]: return None - def _generate_splits_for_testing(self, - split_manager, - inputs, # type: Mapping[str, execution.PartitionableBuffer] - process_bundle_id - ): - # type: (...) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse] - split_results = [] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] + def _generate_splits_for_testing( + self, + split_manager, + inputs: Mapping[str, execution.PartitionableBuffer], + process_bundle_id) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse]: + split_results: List[beam_fn_api_pb2.ProcessBundleSplitResponse] = [] read_transform_id, buffer_data = only_element(inputs.items()) byte_stream = b''.join(buffer_data or []) num_elements = len( @@ -1317,8 +1289,8 @@ def _generate_splits_for_testing(self, estimated_input_elements=num_elements) })) logging.info("Requesting split %s", split_request) - split_response = self._worker_handler.control_conn.push( - split_request).get() # type: beam_fn_api_pb2.InstructionResponse + split_response: beam_fn_api_pb2.InstructionResponse = ( + self._worker_handler.control_conn.push(split_request).get()) for t in (0.05, 0.1, 0.2): if ('Unknown process bundle' in split_response.error or split_response.process_bundle_split == @@ -1343,13 +1315,15 @@ def _generate_splits_for_testing(self, break return split_results - def process_bundle(self, - inputs, # type: Mapping[str, execution.PartitionableBuffer] - expected_outputs, # type: DataOutput - fired_timers, # type: Mapping[translations.TimerFamilyId, execution.PartitionableBuffer] - expected_output_timers: OutputTimers, - dry_run=False, # type: bool - ) -> BundleProcessResult: + def process_bundle( + self, + inputs: Mapping[str, execution.PartitionableBuffer], + expected_outputs: DataOutput, + fired_timers: Mapping[translations.TimerFamilyId, + execution.PartitionableBuffer], + expected_output_timers: OutputTimers, + dry_run: bool = False, + ) -> BundleProcessResult: # Unique id for the instruction processing this bundle. 
with BundleManager._lock: BundleManager._uid_counter += 1 @@ -1383,7 +1357,7 @@ def process_bundle(self, cache_tokens=[next(self._cache_token_generator)])) result_future = self._worker_handler.control_conn.push(process_bundle_req) - split_results = [] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] + split_results: List[beam_fn_api_pb2.ProcessBundleSplitResponse] = [] with ProgressRequester(self._worker_handler, process_bundle_id, self._progress_frequency): @@ -1392,8 +1366,9 @@ def process_bundle(self, split_results = self._generate_splits_for_testing( split_manager, inputs, process_bundle_id) - expect_reads = list( - expected_outputs.keys()) # type: List[Union[str, Tuple[str, str]]] + expect_reads: List[Union[str, + Tuple[str, + str]]] = list(expected_outputs.keys()) expect_reads.extend(list(expected_output_timers.keys())) # Gather all output data. @@ -1417,7 +1392,7 @@ def process_bundle(self, expected_outputs[output.transform_id], output.transform_id).append(output.data) - result = result_future.get() # type: beam_fn_api_pb2.InstructionResponse + result: beam_fn_api_pb2.InstructionResponse = result_future.get() if result.error: raise RuntimeError(result.error) @@ -1435,30 +1410,30 @@ def process_bundle(self, class ParallelBundleManager(BundleManager): - def __init__( self, - bundle_context_manager, # type: execution.BundleContextManager - progress_frequency=None, # type: Optional[float] + bundle_context_manager: execution.BundleContextManager, + progress_frequency: Optional[float] = None, cache_token_generator=None, - **kwargs): - # type: (...) -> None + **kwargs) -> None: super().__init__( bundle_context_manager, progress_frequency, cache_token_generator=cache_token_generator) self._num_workers = bundle_context_manager.num_workers - def process_bundle(self, - inputs, # type: Mapping[str, execution.PartitionableBuffer] - expected_outputs, # type: DataOutput - fired_timers, # type: Mapping[translations.TimerFamilyId, execution.PartitionableBuffer] - expected_output_timers, # type: OutputTimers - dry_run=False, # type: bool - ): - # type: (...) -> BundleProcessResult - part_inputs = [{} for _ in range(self._num_workers) - ] # type: List[Dict[str, List[bytes]]] + def process_bundle( + self, + inputs: Mapping[str, execution.PartitionableBuffer], + expected_outputs: DataOutput, + fired_timers: Mapping[translations.TimerFamilyId, + execution.PartitionableBuffer], + expected_output_timers: OutputTimers, + dry_run: bool = False, + ) -> BundleProcessResult: + part_inputs: List[Dict[str, + List[bytes]]] = [{} + for _ in range(self._num_workers)] # Timers are only executed on the first worker # TODO(BEAM-9741): Split timers to multiple workers timer_inputs = [ @@ -1468,12 +1443,10 @@ def process_bundle(self, for ix, part in enumerate(input.partition(self._num_workers)): part_inputs[ix][name] = part - merged_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse] - split_result_list = [ - ] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] + merged_result: Optional[beam_fn_api_pb2.InstructionResponse] = None + split_result_list: List[beam_fn_api_pb2.ProcessBundleSplitResponse] = [] - def execute(part_map_input_timers): - # type: (...) -> BundleProcessResult + def execute(part_map_input_timers) -> BundleProcessResult: part_map, input_timers = part_map_input_timers bundle_manager = BundleManager( self.bundle_context_manager, @@ -1509,20 +1482,19 @@ class ProgressRequester(threading.Thread): A callback can be passed to call with progress updates. 
""" - - def __init__(self, - worker_handler, # type: WorkerHandler - instruction_id, - frequency, - callback=None - ): - # type: (...) -> None + def __init__( + self, + worker_handler: WorkerHandler, + instruction_id, + frequency, + callback=None) -> None: super().__init__() self._worker_handler = worker_handler self._instruction_id = instruction_id self._frequency = frequency self._done = False - self._latest_progress = None # type: Optional[beam_fn_api_pb2.ProcessBundleProgressResponse] + self._latest_progress: Optional[ + beam_fn_api_pb2.ProcessBundleProgressResponse] = None self._callback = callback self.daemon = True @@ -1563,15 +1535,17 @@ def __init__(self, step_monitoring_infos, user_metrics_only=True): self._counters = {} self._distributions = {} self._gauges = {} + self._string_sets = {} self._user_metrics_only = user_metrics_only self._monitoring_infos = step_monitoring_infos for smi in step_monitoring_infos.values(): - counters, distributions, gauges = \ + counters, distributions, gauges, string_sets = \ portable_metrics.from_monitoring_infos(smi, user_metrics_only) self._counters.update(counters) self._distributions.update(distributions) self._gauges.update(gauges) + self._string_sets.update(string_sets) def query(self, filter=None): counters = [ @@ -1586,15 +1560,19 @@ def query(self, filter=None): MetricResult(k, v, v) for k, v in self._gauges.items() if self.matches(filter, k) ] + string_sets = [ + MetricResult(k, v, v) for k, + v in self._string_sets.items() if self.matches(filter, k) + ] return { self.COUNTERS: counters, self.DISTRIBUTIONS: distributions, - self.GAUGES: gauges + self.GAUGES: gauges, + self.STRINGSETS: string_sets } - def monitoring_infos(self): - # type: () -> List[metrics_pb2.MonitoringInfo] + def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: return [ item for sublist in self._monitoring_infos.values() for item in sublist ] diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 4a35da8dd274..4a737feaf288 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -1212,13 +1212,16 @@ def test_metrics(self, check_gauge=True): counter = beam.metrics.Metrics.counter('ns', 'counter') distribution = beam.metrics.Metrics.distribution('ns', 'distribution') gauge = beam.metrics.Metrics.gauge('ns', 'gauge') + string_set = beam.metrics.Metrics.string_set('ns', 'string_set') - pcoll = p | beam.Create(['a', 'zzz']) + elements = ['a', 'zzz'] + pcoll = p | beam.Create(elements) # pylint: disable=expression-not-assigned pcoll | 'count1' >> beam.FlatMap(lambda x: counter.inc()) pcoll | 'count2' >> beam.FlatMap(lambda x: counter.inc(len(x))) pcoll | 'dist' >> beam.FlatMap(lambda x: distribution.update(len(x))) pcoll | 'gauge' >> beam.FlatMap(lambda x: gauge.set(3)) + pcoll | 'string_set' >> beam.FlatMap(lambda x: string_set.add(x)) res = p.run() res.wait_until_finish() @@ -1238,6 +1241,10 @@ def test_metrics(self, check_gauge=True): .with_name('gauge'))['gauges'] self.assertEqual(gaug.committed.value, 3) + str_set, = res.metrics().query(beam.metrics.MetricsFilter() + .with_name('string_set'))['string_sets'] + self.assertEqual(str_set.committed, set(elements)) + def test_callbacks_with_exception(self): elements_list = ['1', '2'] @@ -2224,7 +2231,7 @@ def __reduce__(self): return _unpickle_element_counter, (name, ) 
-_pickled_element_counters = {} # type: Dict[str, ElementCounter] +_pickled_element_counters: Dict[str, ElementCounter] = {} def _unpickle_element_counter(name): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py index 6f926a6284e2..106eca108297 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/watermark_manager.py @@ -102,8 +102,7 @@ def input_watermark(self): w = min(w, min(i._produced_watermark for i in self.side_inputs)) return w - def __init__(self, stages): - # type: (List[translations.Stage]) -> None + def __init__(self, stages: List[translations.Stage]) -> None: self._pcollections_by_name: Dict[Union[str, translations.TimerFamilyId], WatermarkManager.PCollectionNode] = {} self._stages_by_name: Dict[str, WatermarkManager.StageNode] = {} @@ -189,12 +188,12 @@ def _verify(self, stages: List[translations.Stage]): 'Stage %s has no main inputs. ' 'At least one main input is necessary.' % s.name) - def get_stage_node(self, name): - # type: (str) -> StageNode # noqa: F821 + def get_stage_node(self, name: str) -> StageNode: + # noqa: F821 return self._stages_by_name[name] - def get_pcoll_node(self, name): - # type: (str) -> PCollectionNode # noqa: F821 + def get_pcoll_node(self, name: str) -> PCollectionNode: + # noqa: F821 return self._pcollections_by_name[name] def set_pcoll_watermark(self, name, watermark): diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index eda8755e18ab..e44d8ab0ae93 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -48,8 +48,7 @@ def __init__(self, endpoint, timeout=None): self._endpoint = endpoint self._timeout = timeout - def start(self): - # type: () -> beam_job_api_pb2_grpc.JobServiceStub + def start(self) -> beam_job_api_pb2_grpc.JobServiceStub: channel = grpc.insecure_channel(self._endpoint) grpc.channel_ready_future(channel).result(timeout=self._timeout) return beam_job_api_pb2_grpc.JobServiceStub(channel) @@ -59,8 +58,7 @@ def stop(self): class EmbeddedJobServer(JobServer): - def start(self): - # type: () -> local_job_service.LocalJobServicer + def start(self) -> 'local_job_service.LocalJobServicer': return local_job_service.LocalJobServicer() def stop(self): diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 6966e66d2c64..869f013d0d26 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -27,7 +27,6 @@ import threading import time import traceback -from typing import TYPE_CHECKING from typing import Any from typing import List from typing import Mapping @@ -35,6 +34,7 @@ import grpc from google.protobuf import json_format +from google.protobuf import struct_pb2 from google.protobuf import text_format # type: ignore # not in typeshed from apache_beam import pipeline @@ -57,9 +57,6 @@ from apache_beam.transforms import environments from apache_beam.utils import thread_pool_executor -if TYPE_CHECKING: - from google.protobuf import struct_pb2 # pylint: disable=ungrouped-imports - _LOGGER = logging.getLogger(__name__) @@ -87,16 +84,16 @@ def __init__(self, staging_dir=None, 
beam_job_type=None): self._staging_dir = staging_dir or tempfile.mkdtemp() self._artifact_service = artifact_service.ArtifactStagingService( artifact_service.BeamFilesystemHandler(self._staging_dir).file_writer) - self._artifact_staging_endpoint = None # type: Optional[endpoints_pb2.ApiServiceDescriptor] + self._artifact_staging_endpoint: Optional[ + endpoints_pb2.ApiServiceDescriptor] = None self._beam_job_type = beam_job_type or BeamJob def create_beam_job(self, preparation_id, # stype: str - job_name, # type: str - pipeline, # type: beam_runner_api_pb2.Pipeline - options # type: struct_pb2.Struct - ): - # type: (...) -> BeamJob + job_name: str, + pipeline: beam_runner_api_pb2.Pipeline, + options: struct_pb2.Struct + ) -> 'BeamJob': self._artifact_service.register_job( staging_token=preparation_id, dependency_sets=_extract_dependency_sets( @@ -181,7 +178,7 @@ class SubprocessSdkWorker(object): """ def __init__( self, - worker_command_line, # type: bytes + worker_command_line: bytes, control_address, provision_info, worker_id=None): @@ -238,20 +235,20 @@ class BeamJob(abstract_job_service.AbstractBeamJob): The current state of the pipeline is available as self.state. """ - - def __init__(self, - job_id, # type: str - pipeline, - options, - provision_info, # type: fn_runner.ExtendedProvisionInfo - artifact_staging_endpoint, # type: Optional[endpoints_pb2.ApiServiceDescriptor] - artifact_service, # type: artifact_service.ArtifactStagingService - ): + def __init__( + self, + job_id: str, + pipeline, + options, + provision_info: fn_runner.ExtendedProvisionInfo, + artifact_staging_endpoint: Optional[endpoints_pb2.ApiServiceDescriptor], + artifact_service: artifact_service.ArtifactStagingService, + ): super().__init__(job_id, provision_info.job_name, pipeline, options) self._provision_info = provision_info self._artifact_staging_endpoint = artifact_staging_endpoint self._artifact_service = artifact_service - self._state_queues = [] # type: List[queue.Queue] + self._state_queues: List[queue.Queue] = [] self._log_queues = JobLogQueues() self.daemon = True self.result = None @@ -378,7 +375,7 @@ def Logging(self, log_bundles, context=None): class JobLogQueues(object): def __init__(self): - self._queues = [] # type: List[queue.Queue] + self._queues: List[queue.Queue] = [] self._cache = [] self._cache_size = 10 self._lock = threading.Lock() diff --git a/sdks/python/apache_beam/runners/portability/portable_metrics.py b/sdks/python/apache_beam/runners/portability/portable_metrics.py index d7d330dd7e77..5bc3e0539181 100644 --- a/sdks/python/apache_beam/runners/portability/portable_metrics.py +++ b/sdks/python/apache_beam/runners/portability/portable_metrics.py @@ -27,18 +27,21 @@ def from_monitoring_infos(monitoring_info_list, user_metrics_only=False): - """Groups MonitoringInfo objects into counters, distributions and gauges. + """Groups MonitoringInfo objects into counters, distributions, gauges and + string sets. Args: monitoring_info_list: An iterable of MonitoringInfo objects. user_metrics_only: If true, includes user metrics only. Returns: - A tuple containing three dictionaries: counters, distributions and gauges, - respectively. Each dictionary contains (MetricKey, metric result) pairs. + A tuple containing four dictionaries: counters, distributions, gauges and + string sets, respectively. Each dictionary contains (MetricKey, metric + result) pairs. 
""" counters = {} distributions = {} gauges = {} + string_sets = {} for mi in monitoring_info_list: if (user_metrics_only and not monitoring_infos.is_user_monitoring_info(mi)): @@ -57,8 +60,10 @@ def from_monitoring_infos(monitoring_info_list, user_metrics_only=False): distributions[key] = metric_result elif monitoring_infos.is_gauge(mi): gauges[key] = metric_result + elif monitoring_infos.is_string_set(mi): + string_sets[key] = metric_result - return counters, distributions, gauges + return counters, distributions, gauges, string_sets def _create_metric_key(monitoring_info): diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index ab5ee9fff6f9..ba48bbec6d3a 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -25,7 +25,6 @@ import logging import threading import time -from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import Iterator @@ -33,6 +32,7 @@ from typing import Tuple import grpc +from google.protobuf import struct_pb2 from apache_beam.metrics import metric from apache_beam.metrics.execution import MetricResult @@ -41,6 +41,7 @@ from apache_beam.options.pipeline_options import PortableOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.value_provider import ValueProvider +from apache_beam.pipeline import Pipeline from apache_beam.portability import common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_artifact_api_pb2_grpc @@ -56,10 +57,6 @@ from apache_beam.runners.worker import worker_pool_main from apache_beam.transforms import environments -if TYPE_CHECKING: - from google.protobuf import struct_pb2 # pylint: disable=ungrouped-imports - from apache_beam.pipeline import Pipeline - __all__ = ['PortableRunner'] MESSAGE_LOG_LEVELS = { @@ -97,9 +94,11 @@ def __init__(self, job_service, options, retain_unknown_options=False): self.artifact_endpoint = options.view_as(PortableOptions).artifact_endpoint self._retain_unknown_options = retain_unknown_options - def submit(self, proto_pipeline): - # type: (beam_runner_api_pb2.Pipeline) -> Tuple[str, Iterator[beam_job_api_pb2.JobStateEvent], Iterator[beam_job_api_pb2.JobMessagesResponse]] - + def submit( + self, proto_pipeline: beam_runner_api_pb2.Pipeline + ) -> Tuple[str, + Iterator[beam_job_api_pb2.JobStateEvent], + Iterator[beam_job_api_pb2.JobMessagesResponse]]: """ Submit and run the pipeline defined by `proto_pipeline`. 
""" @@ -113,9 +112,7 @@ def submit(self, proto_pipeline): prepare_response.staging_session_token) return self.run(prepare_response.preparation_id) - def get_pipeline_options(self): - # type: () -> struct_pb2.Struct - + def get_pipeline_options(self) -> struct_pb2.Struct: """ Get `self.options` as a protobuf Struct """ @@ -189,9 +186,9 @@ def convert_pipeline_option_value(v): } return job_utils.dict_to_struct(p_options) - def prepare(self, proto_pipeline): - # type: (beam_runner_api_pb2.Pipeline) -> beam_job_api_pb2.PrepareJobResponse - + def prepare( + self, proto_pipeline: beam_runner_api_pb2.Pipeline + ) -> beam_job_api_pb2.PrepareJobResponse: """Prepare the job on the job service""" return self.job_service.Prepare( beam_job_api_pb2.PrepareJobRequest( @@ -200,13 +197,11 @@ def prepare(self, proto_pipeline): pipeline_options=self.get_pipeline_options()), timeout=self.timeout) - def stage(self, - proto_pipeline, # type: beam_runner_api_pb2.Pipeline - artifact_staging_endpoint, - staging_session_token - ): - # type: (...) -> None - + def stage( + self, + proto_pipeline: beam_runner_api_pb2.Pipeline, + artifact_staging_endpoint, + staging_session_token) -> None: """Stage artifacts""" if artifact_staging_endpoint: artifact_service.offer_artifacts( @@ -216,9 +211,11 @@ def stage(self, artifact_service.BeamFilesystemHandler(None).file_reader), staging_session_token) - def run(self, preparation_id): - # type: (str) -> Tuple[str, Iterator[beam_job_api_pb2.JobStateEvent], Iterator[beam_job_api_pb2.JobMessagesResponse]] - + def run( + self, preparation_id: str + ) -> Tuple[str, + Iterator[beam_job_api_pb2.JobStateEvent], + Iterator[beam_job_api_pb2.JobMessagesResponse]]: """Run the job""" try: state_stream = self.job_service.GetStateStream( @@ -260,11 +257,10 @@ class PortableRunner(runner.PipelineRunner): running and managing the job lies with the job service used. """ def __init__(self): - self._dockerized_job_server = None # type: Optional[job_server.JobServer] + self._dockerized_job_server: Optional[job_server.JobServer] = None @staticmethod - def _create_environment(options): - # type: (PipelineOptions) -> environments.Environment + def _create_environment(options: PipelineOptions) -> environments.Environment: return environments.Environment.from_options( options.view_as(PortableOptions)) @@ -274,20 +270,17 @@ def default_job_server(self, options): 'Alternatively, you may specify which portable runner you intend to ' 'use, such as --runner=FlinkRunner or --runner=SparkRunner.') - def create_job_service_handle(self, job_service, options): - # type: (...) 
-> JobServiceHandle + def create_job_service_handle(self, job_service, options) -> JobServiceHandle: return JobServiceHandle(job_service, options) - def create_job_service(self, options): - # type: (PipelineOptions) -> JobServiceHandle - + def create_job_service(self, options: PipelineOptions) -> JobServiceHandle: """ Start the job service and return a `JobServiceHandle` """ job_endpoint = options.view_as(PortableOptions).job_endpoint if job_endpoint: if job_endpoint == 'embed': - server = job_server.EmbeddedJobServer() # type: job_server.JobServer + server: job_server.JobServer = job_server.EmbeddedJobServer() else: job_server_timeout = options.view_as(PortableOptions).job_server_timeout server = job_server.ExternalJobServer(job_endpoint, job_server_timeout) @@ -296,8 +289,9 @@ def create_job_service(self, options): return self.create_job_service_handle(server.start(), options) @staticmethod - def get_proto_pipeline(pipeline, options): - # type: (Pipeline, PipelineOptions) -> beam_runner_api_pb2.Pipeline + def get_proto_pipeline( + pipeline: Pipeline, + options: PipelineOptions) -> beam_runner_api_pb2.Pipeline: proto_pipeline = pipeline.to_runner_api( default_environment=environments.Environment.from_options( options.view_as(PortableOptions))) @@ -443,7 +437,7 @@ def _combine(committed, attempted, filter): ] def query(self, filter=None): - counters, distributions, gauges = [ + counters, distributions, gauges, stringsets = [ self._combine(x, y, filter) for x, y in zip(self.committed, self.attempted) ] @@ -451,7 +445,8 @@ def query(self, filter=None): return { self.COUNTERS: counters, self.DISTRIBUTIONS: distributions, - self.GAUGES: gauges + self.GAUGES: gauges, + self.STRINGSETS: stringsets } @@ -473,8 +468,7 @@ def __init__( self._metrics = None self._runtime_exception = None - def cancel(self): - # type: () -> None + def cancel(self) -> None: try: self._job_service.Cancel( beam_job_api_pb2.CancelJobRequest(job_id=self._job_id)) @@ -513,8 +507,7 @@ def metrics(self): self._metrics = PortableMetrics(job_metrics_response) return self._metrics - def _last_error_message(self): - # type: () -> str + def _last_error_message(self) -> str: # Filter only messages with the "message_response" and error messages. messages = [ m.message_response for m in self._messages @@ -535,8 +528,7 @@ def wait_until_finish(self, duration=None): the execution. If None or zero, will wait until the pipeline finishes. :return: The result of the pipeline, i.e. PipelineResult. """ - def read_messages(): - # type: () -> None + def read_messages() -> None: previous_state = -1 for message in self._message_stream: if message.HasField('message_response'): @@ -595,8 +587,7 @@ def _observe_state(self, message_thread): finally: self._cleanup() - def _cleanup(self, on_exit=False): - # type: (bool) -> None + def _cleanup(self, on_exit: bool = False) -> None: if on_exit and self._cleanup_callbacks: _LOGGER.info( 'Running cleanup on exit. 
If your pipeline should continue running, ' diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index 48dabe18aa36..98c0e3176f75 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -214,7 +214,8 @@ def create_job_resources(options, # type: PipelineOptions os.path.join(tempfile.gettempdir(), 'dataflow-requirements-cache') if (setup_options.requirements_cache is None) else setup_options.requirements_cache) - if not os.path.exists(requirements_cache_path): + if (setup_options.requirements_cache != SKIP_REQUIREMENTS_CACHE and + not os.path.exists(requirements_cache_path)): os.makedirs(requirements_cache_path) # Stage a requirements file if present. diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py index 25fd62b16533..5535989a5786 100644 --- a/sdks/python/apache_beam/runners/portability/stager_test.py +++ b/sdks/python/apache_beam/runners/portability/stager_test.py @@ -75,7 +75,7 @@ def create_temp_file(self, path, contents): def is_remote_path(self, path): return path.startswith('/tmp/remote/') - remote_copied_files = [] # type: List[str] + remote_copied_files: List[str] = [] def file_copy(self, from_path, to_path): if self.is_remote_path(from_path): diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index d037e0d42c0b..78022724226a 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -55,9 +55,7 @@ _LOGGER = logging.getLogger(__name__) -def create_runner(runner_name): - # type: (str) -> PipelineRunner - +def create_runner(runner_name: str) -> 'PipelineRunner': """For internal use only; no backwards-compatibility guarantees. Creates a runner instance from a runner class name. @@ -113,13 +111,10 @@ class PipelineRunner(object): provide a new implementation for clear_pvalue(), which is used to wipe out materialized values in order to reduce footprint. """ - - def run(self, - transform, # type: PTransform - options=None # type: Optional[PipelineOptions] - ): - # type: (...) -> PipelineResult - + def run( + self, + transform: 'PTransform', + options: Optional[PipelineOptions] = None) -> 'PipelineResult': """Run the given transform or callable with this runner. Blocks until the pipeline is complete. See also `PipelineRunner.run_async`. @@ -128,12 +123,10 @@ def run(self, result.wait_until_finish() return result - def run_async(self, - transform, # type: PTransform - options=None # type: Optional[PipelineOptions] - ): - # type: (...) -> PipelineResult - + def run_async( + self, + transform: 'PTransform', + options: Optional[PipelineOptions] = None) -> 'PipelineResult': """Run the given transform or callable with this runner. May return immediately, executing the pipeline in the background. @@ -171,12 +164,7 @@ def default_environment( options.view_as(PortableOptions)) def run_pipeline( - self, - pipeline, # type: Pipeline - options # type: PipelineOptions - ): - # type: (...) -> PipelineResult - + self, pipeline: 'Pipeline', options: PipelineOptions) -> 'PipelineResult': """Execute the entire pipeline or the sub-DAG reachable from a node. 
""" pipeline.visit( @@ -194,11 +182,11 @@ def run_pipeline( default_environment=self.default_environment(options)), options) - def apply(self, - transform, # type: PTransform - input, # type: Optional[pvalue.PValue] - options # type: PipelineOptions - ): + def apply( + self, + transform: 'PTransform', + input: Optional['pvalue.PValue'], + options: PipelineOptions): # TODO(robertwb): Remove indirection once internal references are fixed. return self.apply_PTransform(transform, input, options) diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py index bbb6b2de6e85..01573656b6ac 100644 --- a/sdks/python/apache_beam/runners/sdf_utils.py +++ b/sdks/python/apache_beam/runners/sdf_utils.py @@ -55,8 +55,7 @@ class ThreadsafeRestrictionTracker(object): This wrapper guarantees synchronization of modifying restrictions across multi-thread. """ - def __init__(self, restriction_tracker): - # type: (RestrictionTracker) -> None + def __init__(self, restriction_tracker: 'RestrictionTracker') -> None: from apache_beam.io.iobase import RestrictionTracker if not isinstance(restriction_tracker, RestrictionTracker): raise ValueError( @@ -67,7 +66,7 @@ def __init__(self, restriction_tracker): self._timestamp = None self._lock = threading.RLock() self._deferred_residual = None - self._deferred_timestamp = None # type: Optional[Union[Timestamp, Duration]] + self._deferred_timestamp: Optional[Union[Timestamp, Duration]] = None def current_restriction(self): with self._lock: @@ -110,8 +109,7 @@ def check_done(self): with self._lock: return self._restriction_tracker.check_done() - def current_progress(self): - # type: () -> RestrictionProgress + def current_progress(self) -> 'RestrictionProgress': with self._lock: return self._restriction_tracker.current_progress() @@ -119,9 +117,7 @@ def try_split(self, fraction_of_remainder): with self._lock: return self._restriction_tracker.try_split(fraction_of_remainder) - def deferred_status(self): - # type: () -> Optional[Tuple[Any, Duration]] - + def deferred_status(self) -> Optional[Tuple[Any, Duration]]: """Returns deferred work which is produced by ``defer_remainder()``. When there is a self-checkpoint performed, the system needs to fulfill the @@ -159,8 +155,9 @@ class RestrictionTrackerView(object): time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a restriction_tracker. """ - def __init__(self, threadsafe_restriction_tracker): - # type: (ThreadsafeRestrictionTracker) -> None + def __init__( + self, + threadsafe_restriction_tracker: ThreadsafeRestrictionTracker) -> None: if not isinstance(threadsafe_restriction_tracker, ThreadsafeRestrictionTracker): raise ValueError( @@ -185,8 +182,7 @@ class ThreadsafeWatermarkEstimator(object): """A threadsafe wrapper which wraps a WatermarkEstimator with locking mechanism to guarantee multi-thread safety. 
""" - def __init__(self, watermark_estimator): - # type: (WatermarkEstimator) -> None + def __init__(self, watermark_estimator: 'WatermarkEstimator') -> None: from apache_beam.io.iobase import WatermarkEstimator if not isinstance(watermark_estimator, WatermarkEstimator): raise ValueError('Initializing Threadsafe requires a WatermarkEstimator') @@ -207,13 +203,11 @@ def get_estimator_state(self): with self._lock: return self._watermark_estimator.get_estimator_state() - def current_watermark(self): - # type: () -> Timestamp + def current_watermark(self) -> Timestamp: with self._lock: return self._watermark_estimator.current_watermark() - def observe_timestamp(self, timestamp): - # type: (Timestamp) -> None + def observe_timestamp(self, timestamp: Timestamp) -> None: if not isinstance(timestamp, Timestamp): raise ValueError( 'Input of observe_timestamp should be a Timestamp ' diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index 6a2f612fbee1..88cc3c9791d5 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -27,7 +27,6 @@ import threading import time import traceback -from typing import TYPE_CHECKING from typing import Iterable from typing import Iterator from typing import List @@ -38,14 +37,12 @@ from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_fn_api_pb2_grpc +from apache_beam.portability.api import endpoints_pb2 from apache_beam.runners.worker import statesampler from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor from apache_beam.utils.sentinel import Sentinel -if TYPE_CHECKING: - from apache_beam.portability.api import endpoints_pb2 - # Mapping from logging levels to LogEntry levels. LOG_LEVEL_TO_LOGENTRY_MAP = { logging.FATAL: beam_fn_api_pb2.LogEntry.Severity.CRITICAL, @@ -81,15 +78,15 @@ class FnApiLogRecordHandler(logging.Handler): # dropped. If the average log size is 1KB this may use up to 10MB of memory. 
_QUEUE_SIZE = 10000 - def __init__(self, log_service_descriptor): - # type: (endpoints_pb2.ApiServiceDescriptor) -> None + def __init__( + self, log_service_descriptor: endpoints_pb2.ApiServiceDescriptor) -> None: super().__init__() self._alive = True self._dropped_logs = 0 - self._log_entry_queue = queue.Queue( - maxsize=self._QUEUE_SIZE - ) # type: queue.Queue[Union[beam_fn_api_pb2.LogEntry, Sentinel]] + self._log_entry_queue: queue.Queue[Union[beam_fn_api_pb2.LogEntry, + Sentinel]] = queue.Queue( + maxsize=self._QUEUE_SIZE) ch = GRPCChannelFactory.insecure_channel(log_service_descriptor.url) # Make sure the channel is ready to avoid [BEAM-4649] @@ -101,16 +98,15 @@ def __init__(self, log_service_descriptor): self._reader.daemon = True self._reader.start() - def connect(self): - # type: () -> Iterable + def connect(self) -> Iterable: if hasattr(self, '_logging_stub'): del self._logging_stub # type: ignore[has-type] self._logging_stub = beam_fn_api_pb2_grpc.BeamFnLoggingStub( self._log_channel) return self._logging_stub.Logging(self._write_log_entries()) - def map_log_level(self, level): - # type: (int) -> beam_fn_api_pb2.LogEntry.Severity.Enum.ValueType + def map_log_level( + self, level: int) -> beam_fn_api_pb2.LogEntry.Severity.Enum.ValueType: try: return LOG_LEVEL_TO_LOGENTRY_MAP[level] except KeyError: @@ -119,8 +115,7 @@ def map_log_level(self, level): beam_level in LOG_LEVEL_TO_LOGENTRY_MAP.items() if python_level <= level) - def emit(self, record): - # type: (logging.LogRecord) -> None + def emit(self, record: logging.LogRecord) -> None: log_entry = beam_fn_api_pb2.LogEntry() log_entry.severity = self.map_log_level(record.levelno) try: @@ -154,9 +149,7 @@ def emit(self, record): except queue.Full: self._dropped_logs += 1 - def close(self): - # type: () -> None - + def close(self) -> None: """Flush out all existing log entries and unregister this handler.""" try: self._alive = False @@ -175,8 +168,7 @@ def close(self): # prematurely. logging.error("Error closing the logging channel.", exc_info=True) - def _write_log_entries(self): - # type: () -> Iterator[beam_fn_api_pb2.LogEntry.List] + def _write_log_entries(self) -> Iterator[beam_fn_api_pb2.LogEntry.List]: done = False while not done: log_entries = [self._log_entry_queue.get()] @@ -194,8 +186,7 @@ def _write_log_entries(self): yield beam_fn_api_pb2.LogEntry.List( log_entries=cast(List[beam_fn_api_pb2.LogEntry], log_entries)) - def _read_log_control_messages(self): - # type: () -> None + def _read_log_control_messages(self) -> None: # Only reconnect when we are alive. # We can drop some logs in the unlikely event of logging connection # dropped(not closed) during termination when we still have logs to be sent. diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py index e1c84bc6ded2..1efebeb3c78c 100644 --- a/sdks/python/apache_beam/runners/worker/logger.py +++ b/sdks/python/apache_beam/runners/worker/logger.py @@ -39,15 +39,13 @@ # context information that changes while work items get executed: # work_item_id, step_name, stage_name. class _PerThreadWorkerData(threading.local): - def __init__(self): - # type: () -> None + def __init__(self) -> None: super().__init__() # in the list, as going up and down all the way to zero incurs several # reallocations. 
- self.stack = [] # type: List[Dict[str, Any]] + self.stack: List[Dict[str, Any]] = [] - def get_data(self): - # type: () -> Dict[str, Any] + def get_data(self) -> Dict[str, Any]: all_data = {} for datum in self.stack: all_data.update(datum) @@ -58,9 +56,7 @@ def get_data(self): @contextlib.contextmanager -def PerThreadLoggingContext(**kwargs): - # type: (**Any) -> Iterator[None] - +def PerThreadLoggingContext(**kwargs: Any) -> Iterator[None]: """A context manager to add per thread attributes.""" stack = per_thread_worker_data.stack stack.append(kwargs) @@ -72,15 +68,12 @@ def PerThreadLoggingContext(**kwargs): class JsonLogFormatter(logging.Formatter): """A JSON formatter class as expected by the logging standard module.""" - def __init__(self, job_id, worker_id): - # type: (str, str) -> None + def __init__(self, job_id: str, worker_id: str) -> None: super().__init__() self.job_id = job_id self.worker_id = worker_id - def format(self, record): - # type: (logging.LogRecord) -> str - + def format(self, record: logging.LogRecord) -> str: """Returns a JSON string based on a LogRecord instance. Args: @@ -115,7 +108,7 @@ def format(self, record): Python thread object. Nevertheless having this value can allow to filter log statement from only one specific thread. """ - output = {} # type: Dict[str, Any] + output: Dict[str, Any] = {} output['timestamp'] = { 'seconds': int(record.created), 'nanos': int(record.msecs * 1000000) } @@ -170,9 +163,11 @@ def format(self, record): return json.dumps(output) -def initialize(job_id, worker_id, log_path, log_level=logging.INFO): - # type: (str, str, str, int) -> None - +def initialize( + job_id: str, + worker_id: str, + log_path: str, + log_level: int = logging.INFO) -> None: """Initialize root logger so that we log JSON to a file and text to stdout.""" file_handler = logging.FileHandler(log_path) diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py index dde4243057dd..d4e61cc9297f 100644 --- a/sdks/python/apache_beam/runners/worker/statecache.py +++ b/sdks/python/apache_beam/runners/worker/statecache.py @@ -58,39 +58,31 @@ class WeightedValue(object): :arg weight The associated weight of the value. If unspecified, the objects size will be used. """ - def __init__(self, value, weight): - # type: (Any, int) -> None + def __init__(self, value: Any, weight: int) -> None: self._value = value if weight <= 0: raise ValueError( 'Expected weight to be > 0 for %s but received %d' % (value, weight)) self._weight = weight - def weight(self): - # type: () -> int + def weight(self) -> int: return self._weight - def value(self): - # type: () -> Any + def value(self) -> Any: return self._value class CacheAware(object): """Allows cache users to override what objects are measured.""" - def __init__(self): - # type: () -> None + def __init__(self) -> None: pass - def get_referents_for_cache(self): - # type: () -> List[Any] - + def get_referents_for_cache(self) -> List[Any]: """Returns the list of objects accounted during cache measurement.""" raise NotImplementedError() -def _safe_isinstance(obj, type): - # type: (Any, Union[type, Tuple[type, ...]]) -> bool - +def _safe_isinstance(obj: Any, type: Union[type, Tuple[type, ...]]) -> bool: """ Return whether an object is an instance of a class or of a subclass thereof. See `isinstance()` for more information. 
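For reference, a minimal sketch of the two cache hooks defined above: WeightedValue pins an explicit weight instead of having the cache deep-size the value, and CacheAware lets an object narrow which referents are measured. The class and attribute names below are hypothetical.

from apache_beam.runners.worker.statecache import CacheAware
from apache_beam.runners.worker.statecache import WeightedValue


class PayloadOnly(CacheAware):
  """Hypothetical cached value that only wants its payload measured."""
  def __init__(self, payload, bookkeeping):
    self.payload = payload
    self.bookkeeping = bookkeeping  # deliberately excluded from measurement

  def get_referents_for_cache(self):
    # Only the payload contributes to the measured cache weight.
    return [self.payload]


# Pin an explicit weight (in bytes); WeightedValue rejects weights <= 0.
pinned = WeightedValue({'k': 'v'}, 1024)
assert pinned.weight() == 1024
assert pinned.value() == {'k': 'v'}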
@@ -106,9 +98,7 @@ def _safe_isinstance(obj, type): return False -def _size_func(obj): - # type: (Any) -> int - +def _size_func(obj: Any) -> int: """ Returns the size of the object or a default size if an error occurred during sizing. @@ -136,9 +126,7 @@ def _size_func(obj): _size_func.last_log_time = 0 # type: ignore -def _get_referents_func(*objs): - # type: (List[Any]) -> List[Any] - +def _get_referents_func(*objs: List[Any]) -> List[Any]: """Returns the list of objects accounted during cache measurement. Users can inherit CacheAware to override which referents should be @@ -154,9 +142,7 @@ def _get_referents_func(*objs): return rval -def _filter_func(o): - # type: (Any) -> bool - +def _filter_func(o: Any) -> bool: """ Filter out specific types from being measured. @@ -171,9 +157,7 @@ def _filter_func(o): return not _safe_isinstance(o, _TYPES_TO_NOT_MEASURE) -def get_deep_size(*objs): - # type: (Any) -> int - +def get_deep_size(*objs: Any) -> int: """Calculates the deep size of all the arguments in bytes.""" return objsize.get_deep_size( *objs, @@ -184,13 +168,11 @@ def get_deep_size(*objs): class _LoadingValue(WeightedValue): """Allows concurrent users of the cache to wait for a value to be loaded.""" - def __init__(self): - # type: () -> None + def __init__(self) -> None: super().__init__(None, 1) self._wait_event = threading.Event() - def load(self, key, loading_fn): - # type: (Any, Callable[[Any], Any]) -> None + def load(self, key: Any, loading_fn: Callable[[Any], Any]) -> None: try: self._value = loading_fn(key) except Exception as err: @@ -198,8 +180,7 @@ def load(self, key, loading_fn): finally: self._wait_event.set() - def value(self): - # type: () -> Any + def value(self) -> Any: self._wait_event.wait() err = getattr(self, "_error", None) if err: @@ -229,13 +210,12 @@ class StateCache(object): :arg max_weight The maximum weight of entries to store in the cache in bytes. 
""" - def __init__(self, max_weight): - # type: (int) -> None + def __init__(self, max_weight: int) -> None: _LOGGER.info('Creating state cache with size %s', max_weight) self._max_weight = max_weight self._current_weight = 0 - self._cache = collections.OrderedDict( - ) # type: collections.OrderedDict[Any, WeightedValue] + self._cache: collections.OrderedDict[ + Any, WeightedValue] = collections.OrderedDict() self._hit_count = 0 self._miss_count = 0 self._evict_count = 0 @@ -243,8 +223,7 @@ def __init__(self, max_weight): self._load_count = 0 self._lock = threading.RLock() - def peek(self, key): - # type: (Any) -> Any + def peek(self, key: Any) -> Any: assert self.is_cache_enabled() with self._lock: value = self._cache.get(key, None) @@ -256,8 +235,7 @@ def peek(self, key): self._hit_count += 1 return value.value() - def get(self, key, loading_fn): - # type: (Any, Callable[[Any], Any]) -> Any + def get(self, key: Any, loading_fn: Callable[[Any], Any]) -> Any: assert self.is_cache_enabled() and callable(loading_fn) self._lock.acquire() @@ -333,8 +311,7 @@ def get(self, key, loading_fn): return value.value() - def put(self, key, value): - # type: (Any, Any) -> None + def put(self, key: Any, value: Any) -> None: assert self.is_cache_enabled() if not _safe_isinstance(value, WeightedValue): weight = get_deep_size(value) @@ -356,22 +333,19 @@ def put(self, key, value): self._current_weight -= weighted_value.weight() self._evict_count += 1 - def invalidate(self, key): - # type: (Any) -> None + def invalidate(self, key: Any) -> None: assert self.is_cache_enabled() with self._lock: weighted_value = self._cache.pop(key, None) if weighted_value is not None: self._current_weight -= weighted_value.weight() - def invalidate_all(self): - # type: () -> None + def invalidate_all(self) -> None: with self._lock: self._cache.clear() self._current_weight = 0 - def describe_stats(self): - # type: () -> str + def describe_stats(self) -> str: with self._lock: request_count = self._hit_count + self._miss_count if request_count > 0: @@ -390,11 +364,9 @@ def describe_stats(self): self._load_count, self._evict_count) - def is_cache_enabled(self): - # type: () -> bool + def is_cache_enabled(self) -> bool: return self._max_weight > 0 - def size(self): - # type: () -> int + def size(self) -> int: with self._lock: return len(self._cache) diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py index 4dc7e97c140d..b9c75f4de93d 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler.py +++ b/sdks/python/apache_beam/runners/worker/statesampler.py @@ -89,46 +89,39 @@ def for_test(): class StateSampler(statesampler_impl.StateSampler): - - def __init__(self, - prefix, # type: str - counter_factory, - sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS): + def __init__( + self, + prefix: str, + counter_factory, + sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS): self._prefix = prefix self._counter_factory = counter_factory - self._states_by_name = { - } # type: Dict[CounterName, statesampler_impl.ScopedState] + self._states_by_name: Dict[CounterName, statesampler_impl.ScopedState] = {} self.sampling_period_ms = sampling_period_ms - self.tracked_thread = None # type: Optional[threading.Thread] + self.tracked_thread: Optional[threading.Thread] = None self.finished = False self.started = False super().__init__(sampling_period_ms) @property - def stage_name(self): - # type: () -> str + def stage_name(self) -> str: return self._prefix - def stop(self): - # type: () 
-> None + def stop(self) -> None: set_current_tracker(None) super().stop() - def stop_if_still_running(self): - # type: () -> None + def stop_if_still_running(self) -> None: if self.started and not self.finished: self.stop() - def start(self): - # type: () -> None + def start(self) -> None: self.tracked_thread = threading.current_thread() set_current_tracker(self) super().start() self.started = True - def get_info(self): - # type: () -> StateSamplerInfo - + def get_info(self) -> StateSamplerInfo: """Returns StateSamplerInfo with transition statistics.""" return StateSamplerInfo( self.current_state().name, @@ -136,14 +129,13 @@ def get_info(self): self.time_since_transition, self.tracked_thread) - def scoped_state(self, - name_context, # type: Union[str, common.NameContext] - state_name, # type: str - io_target=None, - metrics_container=None # type: Optional[MetricsContainer] - ): - # type: (...) -> statesampler_impl.ScopedState - + def scoped_state( + self, + name_context: Union[str, 'common.NameContext'], + state_name: str, + io_target=None, + metrics_container: Optional['MetricsContainer'] = None + ) -> statesampler_impl.ScopedState: """Returns a ScopedState object associated to a Step and a State. Args: @@ -173,9 +165,7 @@ def scoped_state(self, counter_name, name_context, output_counter, metrics_container) return self._states_by_name[counter_name] - def commit_counters(self): - # type: () -> None - + def commit_counters(self) -> None: """Updates output counters with latest state statistics.""" for state in self._states_by_name.values(): state_msecs = int(1e-6 * state.nsecs) diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index 4279b4f8d7f3..be801284450a 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -31,21 +31,19 @@ def __init__(self, sampling_period_ms): self.state_transition_count = 0 self.time_since_transition = 0 - def current_state(self): - # type: () -> ScopedState - + def current_state(self) -> 'ScopedState': """Returns the current execution state. This operation is not thread safe, and should only be called from the execution thread.""" return self._state_stack[-1] - def _scoped_state(self, - counter_name, # type: counters.CounterName - name_context, # type: common.NameContext - output_counter, - metrics_container=None): - # type: (...) -> ScopedState + def _scoped_state( + self, + counter_name: counters.CounterName, + name_context: 'common.NameContext', + output_counter, + metrics_container=None) -> 'ScopedState': assert isinstance(name_context, common.NameContext) return ScopedState( self, counter_name, name_context, output_counter, metrics_container) @@ -55,38 +53,33 @@ def update_metric(self, typed_metric_name, value): if metrics_container is not None: metrics_container.get_metric_cell(typed_metric_name).update(value) - def _enter_state(self, state): - # type: (ScopedState) -> None + def _enter_state(self, state: 'ScopedState') -> None: self.state_transition_count += 1 self._state_stack.append(state) - def _exit_state(self): - # type: () -> None + def _exit_state(self) -> None: self.state_transition_count += 1 self._state_stack.pop() - def start(self): - # type: () -> None + def start(self) -> None: # Sampling not yet supported. Only state tracking at the moment. 
pass - def stop(self): - # type: () -> None + def stop(self) -> None: pass - def reset(self): - # type: () -> None + def reset(self) -> None: pass class ScopedState(object): - - def __init__(self, - sampler, # type: StateSampler - name, # type: counters.CounterName - step_name_context, # type: Optional[common.NameContext] - counter=None, - metrics_container=None): + def __init__( + self, + sampler: StateSampler, + name: counters.CounterName, + step_name_context: Optional['common.NameContext'], + counter=None, + metrics_container=None): self.state_sampler = sampler self.name = name self.name_context = step_name_context @@ -94,12 +87,10 @@ def __init__(self, self.nsecs = 0 self.metrics_container = metrics_container - def sampled_seconds(self): - # type: () -> float + def sampled_seconds(self) -> float: return 1e-9 * self.nsecs - def sampled_msecs_int(self): - # type: () -> int + def sampled_msecs_int(self) -> int: return int(1e-6 * self.nsecs) def __repr__(self): diff --git a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py index b913f2c63b63..1db2b5f4a151 100644 --- a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py +++ b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py @@ -40,8 +40,7 @@ class WorkerIdInterceptor(grpc.UnaryUnaryClientInterceptor, # Unique worker Id for this worker. _worker_id = os.environ.get('WORKER_ID') - def __init__(self, worker_id=None): - # type: (Optional[str]) -> None + def __init__(self, worker_id: Optional[str] = None) -> None: if worker_id: self._worker_id = worker_id diff --git a/sdks/python/apache_beam/runners/worker/worker_pool_main.py b/sdks/python/apache_beam/runners/worker/worker_pool_main.py index 7e81b1fa6d72..307261c2d3c3 100644 --- a/sdks/python/apache_beam/runners/worker/worker_pool_main.py +++ b/sdks/python/apache_beam/runners/worker/worker_pool_main.py @@ -73,18 +73,17 @@ def _kill(): class BeamFnExternalWorkerPoolServicer( beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolServicer): - - def __init__(self, - use_process=False, - container_executable=None, # type: Optional[str] - state_cache_size=0, - data_buffer_time_limit_ms=0 - ): + def __init__( + self, + use_process=False, + container_executable: Optional[str] = None, + state_cache_size=0, + data_buffer_time_limit_ms=0): self._use_process = use_process self._container_executable = container_executable self._state_cache_size = state_cache_size self._data_buffer_time_limit_ms = data_buffer_time_limit_ms - self._worker_processes = {} # type: Dict[str, subprocess.Popen] + self._worker_processes: Dict[str, subprocess.Popen] = {} @classmethod def start( @@ -93,9 +92,7 @@ def start( port=0, state_cache_size=0, data_buffer_time_limit_ms=-1, - container_executable=None # type: Optional[str] - ): - # type: (...) -> Tuple[str, grpc.Server] + container_executable: Optional[str] = None) -> Tuple[str, grpc.Server]: options = [("grpc.http2.max_pings_without_data", 0), ("grpc.http2.max_ping_strikes", 0)] worker_server = grpc.server( @@ -121,11 +118,10 @@ def kill_worker_processes(): return worker_address, worker_server - def StartWorker(self, - start_worker_request, # type: beam_fn_api_pb2.StartWorkerRequest - unused_context - ): - # type: (...) 
-> beam_fn_api_pb2.StartWorkerResponse + def StartWorker( + self, + start_worker_request: beam_fn_api_pb2.StartWorkerRequest, + unused_context) -> beam_fn_api_pb2.StartWorkerResponse: try: if self._use_process: command = [ @@ -182,11 +178,10 @@ def StartWorker(self, except Exception: return beam_fn_api_pb2.StartWorkerResponse(error=traceback.format_exc()) - def StopWorker(self, - stop_worker_request, # type: beam_fn_api_pb2.StopWorkerRequest - unused_context - ): - # type: (...) -> beam_fn_api_pb2.StopWorkerResponse + def StopWorker( + self, + stop_worker_request: beam_fn_api_pb2.StopWorkerRequest, + unused_context) -> beam_fn_api_pb2.StopWorkerResponse: # applicable for process mode to ensure process cleanup # thread based workers terminate automatically worker_process = self._worker_processes.pop( diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py index a7f4890344a8..2271b4495d79 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status.py +++ b/sdks/python/apache_beam/runners/worker/worker_status.py @@ -96,9 +96,7 @@ def heap_dump(): return banner + heap + ending -def _state_cache_stats(state_cache): - # type: (StateCache) -> str - +def _state_cache_stats(state_cache: StateCache) -> str: """Gather state cache statistics.""" cache_stats = ['=' * 10 + ' CACHE STATS ' + '=' * 10] if not state_cache.is_cache_enabled(): diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py index 9d363bfeec61..064fbb11da5d 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/monitor.py @@ -32,8 +32,7 @@ class Monitor(object): name_prefix: a prefix for this Monitor's metrics' names, intended to be unique in per-monitor basis in pipeline """ - def __init__(self, namespace, name_prefix): - # type: (str, str) -> None + def __init__(self, namespace: str, name_prefix: str) -> None: self.namespace = namespace self.name_prefix = name_prefix self.doFn = MonitorDoFn(namespace, name_prefix) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py index bdf6f476212d..ec686543c3f9 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_launcher.py @@ -381,8 +381,7 @@ def monitor(self, job, event_monitor, result_monitor): return perf @staticmethod - def log_performance(perf): - # type: (NexmarkPerf) -> None + def log_performance(perf: NexmarkPerf) -> None: logging.info( 'input event count: %d, output event count: %d' % (perf.event_count, perf.result_count)) diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py index e691b312e201..c29825f95f3e 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_perf.py @@ -36,9 +36,7 @@ def __init__( # number of result produced self.result_count = result_count if result_count else -1 - def has_progress(self, previous_perf): - # type: (NexmarkPerf) -> bool - + def has_progress(self, previous_perf: 'NexmarkPerf') -> bool: """ Args: previous_perf: a NexmarkPerf object to be compared to self diff --git a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py 
b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py index 570fcb1e1ec0..ef53156d8be0 100644 --- a/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py +++ b/sdks/python/apache_beam/testing/benchmarks/nexmark/nexmark_util.py @@ -219,15 +219,13 @@ def unnest_to_json(cand): return cand -def millis_to_timestamp(millis): - # type: (int) -> Timestamp +def millis_to_timestamp(millis: int) -> Timestamp: micro_second = millis * 1000 return Timestamp(micros=micro_second) -def get_counter_metric(result, namespace, name): - # type: (PipelineResult, str, str) -> int - +def get_counter_metric( + result: PipelineResult, namespace: str, name: str) -> int: """ get specific counter metric from pipeline result @@ -249,9 +247,8 @@ def get_counter_metric(result, namespace, name): return counters[0].result if len(counters) > 0 else -1 -def get_start_time_metric(result, namespace, name): - # type: (PipelineResult, str, str) -> int - +def get_start_time_metric( + result: PipelineResult, namespace: str, name: str) -> int: """ get the start time out of all times recorded by the specified distribution metric @@ -271,9 +268,8 @@ def get_start_time_metric(result, namespace, name): return min(min_list) if len(min_list) > 0 else -1 -def get_end_time_metric(result, namespace, name): - # type: (PipelineResult, str, str) -> int - +def get_end_time_metric( + result: PipelineResult, namespace: str, name: str) -> int: """ get the end time out of all times recorded by the specified distribution metric diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index d1da4667dcb8..caadbaca1e1e 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -199,7 +199,7 @@ def __init__( bq_table=None, bq_dataset=None, publish_to_bq=False, - influxdb_options=None, # type: Optional[InfluxDBMetricsPublisherOptions] + influxdb_options: Optional['InfluxDBMetricsPublisherOptions'] = None, namespace=None, filters=None): """Initializes :class:`MetricsReader` . 
@@ -524,37 +524,31 @@ def save(self, results): class InfluxDBMetricsPublisherOptions(object): def __init__( self, - measurement, # type: str - db_name, # type: str - hostname, # type: str - user=None, # type: Optional[str] - password=None # type: Optional[str] - ): + measurement: str, + db_name: str, + hostname: str, + user: Optional[str] = None, + password: Optional[str] = None): self.measurement = measurement self.db_name = db_name self.hostname = hostname self.user = user self.password = password - def validate(self): - # type: () -> bool + def validate(self) -> bool: return bool(self.measurement) and bool(self.db_name) - def http_auth_enabled(self): - # type: () -> bool + def http_auth_enabled(self) -> bool: return self.user is not None and self.password is not None class InfluxDBMetricsPublisher(MetricsPublisher): """Publishes collected metrics to InfluxDB database.""" - def __init__( - self, - options # type: InfluxDBMetricsPublisherOptions - ): + def __init__(self, options: InfluxDBMetricsPublisherOptions): self.options = options - def publish(self, results): - # type: (List[Mapping[str, Union[float, str, int]]]) -> None + def publish( + self, results: List[Mapping[str, Union[float, str, int]]]) -> None: url = '{}/write'.format(self.options.hostname) payload = self._build_payload(results) query_str = {'db': self.options.db_name, 'precision': 's'} @@ -575,8 +569,8 @@ def publish(self, results): 'with an error message: %s' % (response.status_code, content['error'])) - def _build_payload(self, results): - # type: (List[Mapping[str, Union[float, str, int]]]) -> str + def _build_payload( + self, results: List[Mapping[str, Union[float, str, int]]]) -> str: def build_kv(mapping, key): return '{}={}'.format(key, mapping[key]) diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py index 51f66b3c1bb0..3cb5f32c3114 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py @@ -35,7 +35,7 @@ @with_input_types(int) @with_output_types(int) class CallSequenceEnforcingCombineFn(beam.CombineFn): - instances = set() # type: Set[CallSequenceEnforcingCombineFn] + instances: Set['CallSequenceEnforcingCombineFn'] = set() def __init__(self): super().__init__() diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 53d5190bf625..8b05e8da1df5 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -380,7 +380,7 @@ def push(hp, e): return False if self._compare or self._key: - heapc = [] # type: List[cy_combiners.ComparableValue] + heapc: List[cy_combiners.ComparableValue] = [] for bundle in bundles: if not heapc: heapc = [ diff --git a/sdks/python/apache_beam/transforms/external_java.py b/sdks/python/apache_beam/transforms/external_java.py index 534b2622c8a0..e3984fa8ef20 100644 --- a/sdks/python/apache_beam/transforms/external_java.py +++ b/sdks/python/apache_beam/transforms/external_java.py @@ -21,6 +21,7 @@ import logging import subprocess import sys +from typing import Optional import grpc from mock import patch @@ -46,8 +47,8 @@ class JavaExternalTransformTest(object): # This will be overwritten if set via a flag. 
- expansion_service_jar = None # type: str - expansion_service_port = None # type: int + expansion_service_jar: Optional[str] = None + expansion_service_port: Optional[int] = None class _RunWithExpansion(object): def __init__(self): diff --git a/sdks/python/apache_beam/transforms/resources.py b/sdks/python/apache_beam/transforms/resources.py index 7c4160df8edd..04f38d368122 100644 --- a/sdks/python/apache_beam/transforms/resources.py +++ b/sdks/python/apache_beam/transforms/resources.py @@ -26,18 +26,15 @@ """ import re -from typing import TYPE_CHECKING from typing import Any from typing import Dict +from typing import Mapping from typing import Optional +from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.portability.common_urns import resource_hints -if TYPE_CHECKING: - from typing import Mapping - from apache_beam.options.pipeline_options import PipelineOptions - __all__ = [ 'ResourceHint', 'AcceleratorHint', @@ -52,13 +49,13 @@ class ResourceHint: """A superclass to define resource hints.""" # A unique URN, one per Resource Hint class. - urn = None # type: Optional[str] + urn: Optional[str] = None - _urn_to_known_hints = {} # type: Dict[str, type] - _name_to_known_hints = {} # type: Dict[str, type] + _urn_to_known_hints: Dict[str, type] = {} + _name_to_known_hints: Dict[str, type] = {} @classmethod - def parse(cls, value): # type: (str) -> Dict[str, bytes] + def parse(cls, value: str) -> Dict[str, bytes]: """Describes how to parse the hint. Override to specify a custom parsing logic.""" assert cls.urn is not None @@ -66,8 +63,7 @@ def parse(cls, value): # type: (str) -> Dict[str, bytes] return {cls.urn: ResourceHint._parse_str(value)} @classmethod - def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + def get_merged_value(cls, outer_value: bytes, inner_value: bytes) -> bytes: """Reconciles values of a hint when the hint specified on a transform is also defined in an outer context, for example on a composite transform, or specified in the transform's execution environment. 
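The registration pattern in this module also works for user-defined hints; a minimal sketch (the class, its URN, and the 'gpuCount' name are hypothetical):

from apache_beam.transforms.resources import ResourceHint

class GpuCountHint(ResourceHint):
  """Hypothetical hint mirroring the structure of CpuCountHint."""
  urn = 'beam:resources:hypothetical:gpu_count:v1'

  @classmethod
  def get_merged_value(cls, outer_value: bytes, inner_value: bytes) -> bytes:
    return ResourceHint._use_max(outer_value, inner_value)

ResourceHint.register_resource_hint('gpuCount', GpuCountHint)
assert ResourceHint.is_registered('gpuCount')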
@@ -89,8 +85,7 @@ def is_registered(name): return name in ResourceHint._name_to_known_hints @staticmethod - def register_resource_hint( - hint_name, hint_class): # type: (str, type) -> None + def register_resource_hint(hint_name: str, hint_class: type) -> None: assert issubclass(hint_class, ResourceHint) assert hint_class.urn is not None ResourceHint._name_to_known_hints[hint_name] = hint_class @@ -164,12 +159,11 @@ class MinRamHint(ResourceHint): urn = resource_hints.MIN_RAM_BYTES.urn @classmethod - def parse(cls, value): # type: (str) -> Dict[str, bytes] + def parse(cls, value: str) -> Dict[str, bytes]: return {cls.urn: ResourceHint._parse_storage_size_str(value)} @classmethod - def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + def get_merged_value(cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) @@ -183,8 +177,7 @@ class CpuCountHint(ResourceHint): urn = resource_hints.CPU_COUNT.urn @classmethod - def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + def get_merged_value(cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) @@ -193,7 +186,7 @@ def get_merged_value( ResourceHint.register_resource_hint('cpuCount', CpuCountHint) -def parse_resource_hints(hints): # type: (Dict[Any, Any]) -> Dict[str, bytes] +def parse_resource_hints(hints: Dict[Any, Any]) -> Dict[str, bytes]: parsed_hints = {} for hint, value in hints.items(): try: @@ -208,8 +201,8 @@ def parse_resource_hints(hints): # type: (Dict[Any, Any]) -> Dict[str, bytes] return parsed_hints -def resource_hints_from_options(options): - # type: (Optional[PipelineOptions]) -> Dict[str, bytes] +def resource_hints_from_options( + options: Optional[PipelineOptions]) -> Dict[str, bytes]: if options is None: return {} hints = {} @@ -225,8 +218,8 @@ def resource_hints_from_options(options): def merge_resource_hints( - outer_hints, inner_hints -): # type: (Mapping[str, bytes], Mapping[str, bytes]) -> Dict[str, bytes] + outer_hints: Mapping[str, bytes], + inner_hints: Mapping[str, bytes]) -> Dict[str, bytes]: merged_hints = dict(inner_hints) for urn, outer_value in outer_hints.items(): if urn in inner_hints: diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py index 5c92eafe5422..0ff2a388b9e1 100644 --- a/sdks/python/apache_beam/transforms/sideinputs.py +++ b/sdks/python/apache_beam/transforms/sideinputs.py @@ -45,21 +45,20 @@ # Top-level function so we can identify it later. -def _global_window_mapping_fn(w, global_window=window.GlobalWindow()): - # type: (...) 
-> window.GlobalWindow +def _global_window_mapping_fn( + w, global_window=window.GlobalWindow()) -> window.GlobalWindow: return global_window -def default_window_mapping_fn(target_window_fn): - # type: (window.WindowFn) -> WindowMappingFn +def default_window_mapping_fn( + target_window_fn: window.WindowFn) -> WindowMappingFn: if target_window_fn == window.GlobalWindows(): return _global_window_mapping_fn if isinstance(target_window_fn, window.Sessions): raise RuntimeError("Sessions is not allowed in side inputs") - def map_via_end(source_window): - # type: (window.BoundedWindow) -> window.BoundedWindow + def map_via_end(source_window: window.BoundedWindow) -> window.BoundedWindow: return list( target_window_fn.assign( window.WindowFn.AssignContext(source_window.max_timestamp())))[-1] @@ -67,8 +66,7 @@ def map_via_end(source_window): return map_via_end -def get_sideinput_index(tag): - # type: (str) -> int +def get_sideinput_index(tag: str) -> int: match = re.match(SIDE_INPUT_REGEX, tag, re.DOTALL) if match: return int(match.group(1)) @@ -78,28 +76,22 @@ def get_sideinput_index(tag): class SideInputMap(object): """Represents a mapping of windows to side input values.""" - def __init__( - self, - view_class, # type: pvalue.AsSideInput - view_options, - iterable): + def __init__(self, view_class: 'pvalue.AsSideInput', view_options, iterable): self._window_mapping_fn = view_options.get( 'window_mapping_fn', _global_window_mapping_fn) self._view_class = view_class self._view_options = view_options self._iterable = iterable - self._cache = {} # type: Dict[window.BoundedWindow, Any] + self._cache: Dict[window.BoundedWindow, Any] = {} - def __getitem__(self, window): - # type: (window.BoundedWindow) -> Any + def __getitem__(self, window: window.BoundedWindow) -> Any: if window not in self._cache: target_window = self._window_mapping_fn(window) self._cache[window] = self._view_class._from_runtime_iterable( _FilteringIterable(self._iterable, target_window), self._view_options) return self._cache[window] - def is_globally_windowed(self): - # type: () -> bool + def is_globally_windowed(self) -> bool: return self._window_mapping_fn == _global_window_mapping_fn diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 6483859adcf8..63895704727f 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -181,8 +181,7 @@ class DataLossReason(Flag): # to `reason & flag == flag` -def _IncludesMayFinish(reason): - # type: (DataLossReason) -> bool +def _IncludesMayFinish(reason: DataLossReason) -> bool: return reason & DataLossReason.MAY_FINISH == DataLossReason.MAY_FINISH @@ -267,9 +266,7 @@ def reset(self, window, context): """Clear any state and timers used by this TriggerFn.""" pass - def may_lose_data(self, unused_windowing): - # type: (core.Windowing) -> DataLossReason - + def may_lose_data(self, unused_windowing: core.Windowing) -> DataLossReason: """Returns whether or not this trigger could cause data loss. 
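A small sketch of what the window-mapping function above produces (window bounds chosen for illustration):

from apache_beam.transforms import window
from apache_beam.transforms.sideinputs import default_window_mapping_fn

map_fn = default_window_mapping_fn(window.FixedWindows(60))
source = window.IntervalWindow(0, 10)
# The source window maps to the target window containing its max timestamp,
# here the fixed window [0, 60).
assert map_fn(source) == window.IntervalWindow(0, 60)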
A trigger can cause data loss in the following scenarios: diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index c266d0685472..ada0b755bd6c 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -38,20 +38,19 @@ from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.transforms.timeutil import TimeDomain +from apache_beam.utils import windowed_value +from apache_beam.utils.timestamp import Timestamp if TYPE_CHECKING: from apache_beam.runners.pipeline_context import PipelineContext - from apache_beam.transforms.core import CombineFn, DoFn - from apache_beam.utils import windowed_value - from apache_beam.utils.timestamp import Timestamp + from apache_beam.transforms.core import DoFn CallableT = TypeVar('CallableT', bound=Callable) class StateSpec(object): """Specification for a user DoFn state cell.""" - def __init__(self, name, coder): - # type: (str, Coder) -> None + def __init__(self, name: str, coder: Coder) -> None: if not isinstance(name, str): raise TypeError("name is not a string") if not isinstance(coder, Coder): @@ -59,19 +58,18 @@ def __init__(self, name, coder): self.name = name self.coder = coder - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return '%s(%s)' % (self.__class__.__name__, self.name) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: raise NotImplementedError class ReadModifyWriteStateSpec(StateSpec): """Specification for a user DoFn value state cell.""" - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec( coder_id=context.coders.get_id(self.coder)), @@ -81,8 +79,8 @@ def to_runner_api(self, context): class BagStateSpec(StateSpec): """Specification for a user DoFn bag state cell.""" - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( bag_spec=beam_runner_api_pb2.BagStateSpec( element_coder_id=context.coders.get_id(self.coder)), @@ -92,8 +90,8 @@ def to_runner_api(self, context): class SetStateSpec(StateSpec): """Specification for a user DoFn Set State cell""" - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( set_spec=beam_runner_api_pb2.SetStateSpec( element_coder_id=context.coders.get_id(self.coder)), @@ -103,9 +101,11 @@ def to_runner_api(self, context): class CombiningValueStateSpec(StateSpec): """Specification for a user DoFn combining value state cell.""" - def __init__(self, name, coder=None, combine_fn=None): - # type: (str, Optional[Coder], Any) -> None - + def __init__( + self, + name: str, + coder: Optional[Coder] = None, + combine_fn: Any = None) -> None: """Initialize the specification for CombiningValue state. 
CombiningValueStateSpec(name, combine_fn) -> Coder-inferred combining value @@ -140,8 +140,8 @@ def __init__(self, name, coder=None, combine_fn=None): super().__init__(name, coder) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: return beam_runner_api_pb2.StateSpec( combining_spec=beam_runner_api_pb2.CombiningStateSpec( combine_fn=self.combine_fn.to_runner_api(context), @@ -169,29 +169,26 @@ class TimerSpec(object): """Specification for a user stateful DoFn timer.""" prefix = "ts-" - def __init__(self, name, time_domain): - # type: (str, str) -> None + def __init__(self, name: str, time_domain: str) -> None: self.name = self.prefix + name if time_domain not in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME): raise ValueError('Unsupported TimeDomain: %r.' % (time_domain, )) self.time_domain = time_domain - self._attached_callback = None # type: Optional[Callable] + self._attached_callback: Optional[Callable] = None - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return '%s(%s)' % (self.__class__.__name__, self.name) - def to_runner_api(self, context, key_coder, window_coder): - # type: (PipelineContext, Coder, Coder) -> beam_runner_api_pb2.TimerFamilySpec + def to_runner_api( + self, context: 'PipelineContext', key_coder: Coder, + window_coder: Coder) -> beam_runner_api_pb2.TimerFamilySpec: return beam_runner_api_pb2.TimerFamilySpec( time_domain=TimeDomain.to_runner_api(self.time_domain), timer_family_coder_id=context.coders.get_id( coders._TimerCoder(key_coder, window_coder))) -def on_timer(timer_spec): - # type: (TimerSpec) -> Callable[[CallableT], CallableT] - +def on_timer(timer_spec: TimerSpec) -> Callable[[CallableT], CallableT]: """Decorator for timer firing DoFn method. This decorator allows a user to specify an on_timer processing method @@ -208,8 +205,7 @@ def my_timer_expiry_callback(self): if not isinstance(timer_spec, TimerSpec): raise ValueError('@on_timer decorator expected TimerSpec.') - def _inner(method): - # type: (CallableT) -> CallableT + def _inner(method: CallableT) -> CallableT: if not callable(method): raise ValueError('@on_timer decorator expected callable.') if timer_spec._attached_callback: @@ -221,9 +217,7 @@ def _inner(method): return _inner -def get_dofn_specs(dofn): - # type: (DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]] - +def get_dofn_specs(dofn: 'DoFn') -> Tuple[Set[StateSpec], Set[TimerSpec]]: """Gets the state and timer specs for a DoFn, if any. Args: @@ -262,9 +256,7 @@ def get_dofn_specs(dofn): return all_state_specs, all_timer_specs -def is_stateful_dofn(dofn): - # type: (DoFn) -> bool - +def is_stateful_dofn(dofn: 'DoFn') -> bool: """Determines whether a given DoFn is a stateful DoFn.""" # A Stateful DoFn is a DoFn that uses user state or timers. @@ -272,9 +264,7 @@ def is_stateful_dofn(dofn): return bool(all_state_specs or all_timer_specs) -def validate_stateful_dofn(dofn): - # type: (DoFn) -> None - +def validate_stateful_dofn(dofn: 'DoFn') -> None: """Validates the proper specification of a stateful DoFn.""" # Get state and timer specs. 
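For context, a minimal stateful DoFn sketch wiring together the specs and decorator annotated above (the names and the flush timestamp are illustrative only):

import apache_beam as beam
from apache_beam.coders import VarIntCoder
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.userstate import BagStateSpec, TimerSpec, on_timer
from apache_beam.utils.timestamp import Timestamp

class BufferThenFlush(beam.DoFn):
  BUFFER = BagStateSpec('buffer', VarIntCoder())
  FLUSH = TimerSpec('flush', TimeDomain.WATERMARK)

  def process(
      self,
      element,  # expects (key, value) pairs
      buffer=beam.DoFn.StateParam(BUFFER),
      flush=beam.DoFn.TimerParam(FLUSH)):
    _, value = element
    buffer.add(value)
    flush.set(Timestamp(seconds=0))  # illustrative firing time

  @on_timer(FLUSH)
  def on_flush(self, buffer=beam.DoFn.StateParam(BUFFER)):
    yield sum(buffer.read())
    buffer.clear()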
@@ -306,12 +296,10 @@ def validate_stateful_dofn(dofn): class BaseTimer(object): - def clear(self, dynamic_timer_tag=''): - # type: (str) -> None + def clear(self, dynamic_timer_tag: str = '') -> None: raise NotImplementedError - def set(self, timestamp, dynamic_timer_tag=''): - # type: (Timestamp, str) -> None + def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: raise NotImplementedError @@ -321,66 +309,54 @@ def set(self, timestamp, dynamic_timer_tag=''): class RuntimeTimer(BaseTimer): """Timer interface object passed to user code.""" def __init__(self) -> None: - self._timer_recordings = {} # type: Dict[str, _TimerTuple] + self._timer_recordings: Dict[str, _TimerTuple] = {} self._cleared = False - self._new_timestamp = None # type: Optional[Timestamp] + self._new_timestamp: Optional[Timestamp] = None - def clear(self, dynamic_timer_tag=''): - # type: (str) -> None + def clear(self, dynamic_timer_tag: str = '') -> None: self._timer_recordings[dynamic_timer_tag] = _TimerTuple( cleared=True, timestamp=None) - def set(self, timestamp, dynamic_timer_tag=''): - # type: (Timestamp, str) -> None + def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: self._timer_recordings[dynamic_timer_tag] = _TimerTuple( cleared=False, timestamp=timestamp) class RuntimeState(object): """State interface object passed to user code.""" - def prefetch(self): - # type: () -> None + def prefetch(self) -> None: # The default implementation here does nothing. pass - def finalize(self): - # type: () -> None + def finalize(self) -> None: pass class ReadModifyWriteRuntimeState(RuntimeState): - def read(self): - # type: () -> Any + def read(self) -> Any: raise NotImplementedError(type(self)) - def write(self, value): - # type: (Any) -> None + def write(self, value: Any) -> None: raise NotImplementedError(type(self)) - def clear(self): - # type: () -> None + def clear(self) -> None: raise NotImplementedError(type(self)) - def commit(self): - # type: () -> None + def commit(self) -> None: raise NotImplementedError(type(self)) class AccumulatingRuntimeState(RuntimeState): - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: raise NotImplementedError(type(self)) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: raise NotImplementedError(type(self)) - def clear(self): - # type: () -> None + def clear(self) -> None: raise NotImplementedError(type(self)) - def commit(self): - # type: () -> None + def commit(self) -> None: raise NotImplementedError(type(self)) @@ -398,24 +374,23 @@ class CombiningValueRuntimeState(AccumulatingRuntimeState): class UserStateContext(object): """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.""" - def get_timer(self, - timer_spec, # type: TimerSpec - key, # type: Any - window, # type: windowed_value.BoundedWindow - timestamp, # type: Timestamp - pane, # type: windowed_value.PaneInfo - ): - # type: (...) -> BaseTimer + def get_timer( + self, + timer_spec: TimerSpec, + key: Any, + window: 'windowed_value.BoundedWindow', + timestamp: Timestamp, + pane: windowed_value.PaneInfo, + ) -> BaseTimer: raise NotImplementedError(type(self)) - def get_state(self, - state_spec, # type: StateSpec - key, # type: Any - window, # type: windowed_value.BoundedWindow - ): - # type: (...) 
-> RuntimeState + def get_state( + self, + state_spec: StateSpec, + key: Any, + window: 'windowed_value.BoundedWindow', + ) -> RuntimeState: raise NotImplementedError(type(self)) - def commit(self): - # type: () -> None + def commit(self) -> None: raise NotImplementedError(type(self)) diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index e17894ccb949..5dd6c61d6add 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -437,7 +437,7 @@ def __repr__(self): class StatefulDoFnOnDirectRunnerTest(unittest.TestCase): # pylint: disable=expression-not-assigned - all_records = None # type: List[Any] + all_records: List[Any] def setUp(self): # Use state on the TestCase class, since other references would be pickled diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index c76b30fb8ff7..592164a5ef49 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -104,8 +104,9 @@ class TimestampCombiner(object): OUTPUT_AT_EARLIEST_TRANSFORMED = 'OUTPUT_AT_EARLIEST_TRANSFORMED' @staticmethod - def get_impl(timestamp_combiner, window_fn): - # type: (beam_runner_api_pb2.OutputTime.Enum, WindowFn) -> timeutil.TimestampCombinerImpl + def get_impl( + timestamp_combiner: beam_runner_api_pb2.OutputTime.Enum, + window_fn: 'WindowFn') -> timeutil.TimestampCombinerImpl: if timestamp_combiner == TimestampCombiner.OUTPUT_AT_EOW: return timeutil.OutputAtEndOfWindowImpl() elif timestamp_combiner == TimestampCombiner.OUTPUT_AT_EARLIEST: @@ -124,18 +125,17 @@ class AssignContext(object): """Context passed to WindowFn.assign().""" def __init__( self, - timestamp, # type: TimestampTypes - element=None, # type: Optional[Any] - window=None # type: Optional[BoundedWindow] - ): - # type: (...) -> None + timestamp: TimestampTypes, + element: Optional[Any] = None, + window: Optional['BoundedWindow'] = None) -> None: self.timestamp = Timestamp.of(timestamp) self.element = element self.window = window @abc.abstractmethod - def assign(self, assign_context): - # type: (AssignContext) -> Iterable[BoundedWindow] # noqa: F821 + def assign(self, + assign_context: 'AssignContext') -> Iterable['BoundedWindow']: + # noqa: F821 """Associates windows to an element. 
@@ -149,35 +149,30 @@ def assign(self, assign_context): class MergeContext(object): """Context passed to WindowFn.merge() to perform merging, if any.""" - def __init__(self, windows): - # type: (Iterable[BoundedWindow]) -> None + def __init__(self, windows: Iterable['BoundedWindow']) -> None: self.windows = list(windows) - def merge(self, to_be_merged, merge_result): - # type: (Iterable[BoundedWindow], BoundedWindow) -> None + def merge( + self, + to_be_merged: Iterable['BoundedWindow'], + merge_result: 'BoundedWindow') -> None: raise NotImplementedError @abc.abstractmethod - def merge(self, merge_context): - # type: (WindowFn.MergeContext) -> None - + def merge(self, merge_context: 'WindowFn.MergeContext') -> None: """Returns a window that is the result of merging a set of windows.""" raise NotImplementedError - def is_merging(self): - # type: () -> bool - + def is_merging(self) -> bool: """Returns whether this WindowFn merges windows.""" return True @abc.abstractmethod - def get_window_coder(self): - # type: () -> coders.Coder + def get_window_coder(self) -> coders.Coder: raise NotImplementedError - def get_transformed_output_time(self, window, input_timestamp): # pylint: disable=unused-argument - # type: (BoundedWindow, Timestamp) -> Timestamp - + def get_transformed_output_time( + self, window: 'BoundedWindow', input_timestamp: Timestamp) -> Timestamp: # pylint: disable=unused-argument """Given input time and output window, returns output time for window. If TimestampCombiner.OUTPUT_AT_EARLIEST_TRANSFORMED is used in the @@ -205,22 +200,18 @@ class BoundedWindow(object): Attributes: end: End of window. """ - def __init__(self, end): - # type: (TimestampTypes) -> None + def __init__(self, end: TimestampTypes) -> None: self._end = Timestamp.of(end) @property - def start(self): - # type: () -> Timestamp + def start(self) -> Timestamp: raise NotImplementedError @property - def end(self): - # type: () -> Timestamp + def end(self) -> Timestamp: return self._end - def max_timestamp(self): - # type: () -> Timestamp + def max_timestamp(self) -> Timestamp: return self.end.predecessor() def __eq__(self, other): @@ -270,12 +261,10 @@ def __lt__(self, other): return self.end < other.end return hash(self) < hash(other) - def intersects(self, other): - # type: (IntervalWindow) -> bool + def intersects(self, other: 'IntervalWindow') -> bool: return other.start < self.end or self.start < other.end - def union(self, other): - # type: (IntervalWindow) -> IntervalWindow + def union(self, other: 'IntervalWindow') -> 'IntervalWindow': return IntervalWindow( min(self.start, other.start), max(self.end, other.end)) @@ -291,8 +280,7 @@ class TimestampedValue(Generic[V]): value: The underlying value. timestamp: Timestamp associated with the value as seconds since Unix epoch. 
""" - def __init__(self, value, timestamp): - # type: (V, TimestampTypes) -> None + def __init__(self, value: V, timestamp: TimestampTypes) -> None: self.value = value self.timestamp = Timestamp.of(timestamp) @@ -314,15 +302,14 @@ def __lt__(self, other): class GlobalWindow(BoundedWindow): """The default window into which all data is placed (via GlobalWindows).""" - _instance = None # type: GlobalWindow + _instance: Optional['GlobalWindow'] = None def __new__(cls): if cls._instance is None: cls._instance = super(GlobalWindow, cls).__new__(cls) return cls._instance - def __init__(self): - # type: () -> None + def __init__(self) -> None: super().__init__(GlobalWindow._getTimestampFromProto()) def __repr__(self): @@ -336,25 +323,21 @@ def __eq__(self, other): return self is other or type(self) is type(other) @property - def start(self): - # type: () -> Timestamp + def start(self) -> Timestamp: return MIN_TIMESTAMP @staticmethod - def _getTimestampFromProto(): - # type: () -> Timestamp + def _getTimestampFromProto() -> Timestamp: ts_millis = int( common_urns.constants.GLOBAL_WINDOW_MAX_TIMESTAMP_MILLIS.constant) return Timestamp(micros=ts_millis * 1000) class NonMergingWindowFn(WindowFn): - def is_merging(self): - # type: () -> bool + def is_merging(self) -> bool: return False - def merge(self, merge_context): - # type: (WindowFn.MergeContext) -> None + def merge(self, merge_context: WindowFn.MergeContext) -> None: pass # No merging. @@ -363,34 +346,31 @@ class GlobalWindows(NonMergingWindowFn): @classmethod def windowed_batch( cls, - batch, # type: Any - timestamp=MIN_TIMESTAMP, # type: Timestamp - pane_info=windowed_value.PANE_INFO_UNKNOWN # type: windowed_value.PaneInfo - ): - # type: (...) -> windowed_value.WindowedBatch + batch: Any, + timestamp: Timestamp = MIN_TIMESTAMP, + pane_info: windowed_value.PaneInfo = windowed_value.PANE_INFO_UNKNOWN + ) -> windowed_value.WindowedBatch: return windowed_value.HomogeneousWindowedBatch.of( batch, timestamp, (GlobalWindow(), ), pane_info) @classmethod def windowed_value( cls, - value, # type: Any - timestamp=MIN_TIMESTAMP, # type: Timestamp - pane_info=windowed_value.PANE_INFO_UNKNOWN # type: windowed_value.PaneInfo - ): - # type: (...) -> WindowedValue + value: Any, + timestamp: Timestamp = MIN_TIMESTAMP, + pane_info: windowed_value.PaneInfo = windowed_value.PANE_INFO_UNKNOWN + ) -> WindowedValue: return WindowedValue(value, timestamp, (GlobalWindow(), ), pane_info) @classmethod def windowed_value_at_end_of_window(cls, value): return cls.windowed_value(value, GlobalWindow().max_timestamp()) - def assign(self, assign_context): - # type: (WindowFn.AssignContext) -> List[GlobalWindow] + def assign(self, + assign_context: WindowFn.AssignContext) -> List[GlobalWindow]: return [GlobalWindow()] - def get_window_coder(self): - # type: () -> coders.GlobalWindowCoder + def get_window_coder(self) -> coders.GlobalWindowCoder: return coders.GlobalWindowCoder() def __hash__(self): @@ -405,8 +385,8 @@ def to_runner_api_parameter(self, context): @staticmethod @urns.RunnerApiFn.register_urn(common_urns.global_windows.urn, None) - def from_runner_api_parameter(unused_fn_parameter, unused_context): - # type: (...) -> GlobalWindows + def from_runner_api_parameter( + unused_fn_parameter, unused_context) -> 'GlobalWindows': return GlobalWindows() @@ -424,11 +404,7 @@ class FixedWindows(NonMergingWindowFn): value in range [0, size). If it is not it will be normalized to this range. 
""" - def __init__( - self, - size, # type: DurationTypes - offset=0 # type: TimestampTypes - ): + def __init__(self, size: DurationTypes, offset: TimestampTypes = 0): """Initialize a ``FixedWindows`` function for a given size and offset. Args: @@ -443,14 +419,12 @@ def __init__( self.size = Duration.of(size) self.offset = Timestamp.of(offset) % self.size - def assign(self, context): - # type: (WindowFn.AssignContext) -> List[IntervalWindow] + def assign(self, context: WindowFn.AssignContext) -> List[IntervalWindow]: timestamp = context.timestamp start = timestamp - (timestamp - self.offset) % self.size return [IntervalWindow(start, start + self.size)] - def get_window_coder(self): - # type: () -> coders.IntervalWindowCoder + def get_window_coder(self) -> coders.IntervalWindowCoder: return coders.IntervalWindowCoder() def __eq__(self, other): @@ -473,8 +447,7 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.fixed_windows.urn, standard_window_fns_pb2.FixedWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context): - # type: (...) -> FixedWindows + def from_runner_api_parameter(fn_parameter, unused_context) -> 'FixedWindows': return FixedWindows( size=Duration(micros=fn_parameter.size.ToMicroseconds()), offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds())) @@ -494,20 +467,19 @@ class SlidingWindows(NonMergingWindowFn): t=N * period + offset where t=0 is the epoch. The offset must be a value in range [0, period). If it is not it will be normalized to this range. """ - - def __init__(self, - size, # type: DurationTypes - period, # type: DurationTypes - offset=0, # type: TimestampTypes - ): + def __init__( + self, + size: DurationTypes, + period: DurationTypes, + offset: TimestampTypes = 0, + ): if size <= 0: raise ValueError('The size parameter must be strictly positive.') self.size = Duration.of(size) self.period = Duration.of(period) self.offset = Timestamp.of(offset) % period - def assign(self, context): - # type: (WindowFn.AssignContext) -> List[IntervalWindow] + def assign(self, context: WindowFn.AssignContext) -> List[IntervalWindow]: timestamp = context.timestamp start = timestamp - ((timestamp - self.offset) % self.period) return [ @@ -520,8 +492,7 @@ def assign(self, context): -self.period.micros) ] - def get_window_coder(self): - # type: () -> coders.IntervalWindowCoder + def get_window_coder(self) -> coders.IntervalWindowCoder: return coders.IntervalWindowCoder() def __eq__(self, other): @@ -548,8 +519,8 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.sliding_windows.urn, standard_window_fns_pb2.SlidingWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context): - # type: (...) -> SlidingWindows + def from_runner_api_parameter( + fn_parameter, unused_context) -> 'SlidingWindows': return SlidingWindows( size=Duration(micros=fn_parameter.size.ToMicroseconds()), offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds()), @@ -565,24 +536,20 @@ class Sessions(WindowFn): Attributes: gap_size: Size of the gap between windows as floating-point seconds. 
""" - def __init__(self, gap_size): - # type: (DurationTypes) -> None + def __init__(self, gap_size: DurationTypes) -> None: if gap_size <= 0: raise ValueError('The size parameter must be strictly positive.') self.gap_size = Duration.of(gap_size) - def assign(self, context): - # type: (WindowFn.AssignContext) -> List[IntervalWindow] + def assign(self, context: WindowFn.AssignContext) -> List[IntervalWindow]: timestamp = context.timestamp return [IntervalWindow(timestamp, timestamp + self.gap_size)] - def get_window_coder(self): - # type: () -> coders.IntervalWindowCoder + def get_window_coder(self) -> coders.IntervalWindowCoder: return coders.IntervalWindowCoder() - def merge(self, merge_context): - # type: (WindowFn.MergeContext) -> None - to_merge = [] # type: List[BoundedWindow] + def merge(self, merge_context: WindowFn.MergeContext) -> None: + to_merge: List[BoundedWindow] = [] end = MIN_TIMESTAMP for w in sorted(merge_context.windows, key=lambda w: w.start): if to_merge: @@ -620,7 +587,6 @@ def to_runner_api_parameter(self, context): @urns.RunnerApiFn.register_urn( common_urns.session_windows.urn, standard_window_fns_pb2.SessionWindowsPayload) - def from_runner_api_parameter(fn_parameter, unused_context): - # type: (...) -> Sessions + def from_runner_api_parameter(fn_parameter, unused_context) -> 'Sessions': return Sessions( gap_size=Duration(micros=fn_parameter.gap_size.ToMicroseconds())) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index c24f2ed8f43c..9c0cc2b8af4e 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -202,8 +202,11 @@ class IOTypeHints(NamedTuple): origin: List[str] @classmethod - def _make_origin(cls, bases, tb=True, msg=()): - # type: (List[IOTypeHints], bool, Iterable[str]) -> List[str] + def _make_origin( + cls, + bases: List['IOTypeHints'], + tb: bool = True, + msg: Iterable[str] = ()) -> List[str]: if msg: res = list(msg) else: @@ -229,16 +232,12 @@ def _make_origin(cls, bases, tb=True, msg=()): return res @classmethod - def empty(cls): - # type: () -> IOTypeHints - + def empty(cls) -> 'IOTypeHints': """Construct a base IOTypeHints object with no hints.""" return IOTypeHints(None, None, []) @classmethod - def from_callable(cls, fn): - # type: (Callable) -> Optional[IOTypeHints] - + def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: """Construct an IOTypeHints object from a callable's signature. Supports Python 3 annotations. For partial annotations, sets unknown types @@ -292,23 +291,19 @@ def from_callable(cls, fn): output_types=(tuple(output_args), {}), origin=cls._make_origin([], tb=False, msg=msg)) - def with_input_types(self, *args, **kwargs): - # type: (...) -> IOTypeHints + def with_input_types(self, *args, **kwargs) -> 'IOTypeHints': return self._replace( input_types=(args, kwargs), origin=self._make_origin([self])) - def with_output_types(self, *args, **kwargs): - # type: (...) 
-> IOTypeHints + def with_output_types(self, *args, **kwargs) -> 'IOTypeHints': return self._replace( output_types=(args, kwargs), origin=self._make_origin([self])) - def with_input_types_from(self, other): - # type: (IOTypeHints) -> IOTypeHints + def with_input_types_from(self, other: 'IOTypeHints') -> 'IOTypeHints': return self._replace( input_types=other.input_types, origin=self._make_origin([self])) - def with_output_types_from(self, other): - # type: (IOTypeHints) -> IOTypeHints + def with_output_types_from(self, other: 'IOTypeHints') -> 'IOTypeHints': return self._replace( output_types=other.output_types, origin=self._make_origin([self])) @@ -355,14 +350,14 @@ def strip_pcoll(self): def strip_pcoll_helper( self, - my_type, # type: any - has_my_type, # type: Callable[[], bool] - my_key, # type: str - special_containers, # type: List[Union[PBegin, PDone, PCollection]] # noqa: F821 - error_str, # type: str - source_str # type: str - ): - # type: (...) -> IOTypeHints + my_type: any, + has_my_type: Callable[[], bool], + my_key: str, + special_containers: List[ + Union['PBegin', 'PDone', 'PCollection']], # noqa: F821 + error_str: str, + source_str: str + ) -> 'IOTypeHints': from apache_beam.pvalue import PCollection if not has_my_type() or not my_type or len(my_type[0]) != 1: @@ -396,9 +391,7 @@ def strip_pcoll_helper( origin=self._make_origin([self], tb=False, msg=[source_str]), **kwarg_dict) - def strip_iterable(self): - # type: () -> IOTypeHints - + def strip_iterable(self) -> 'IOTypeHints': """Removes outer Iterable (or equivalent) from output type. Only affects instances with simple output types, otherwise is a no-op. @@ -437,8 +430,7 @@ def strip_iterable(self): output_types=((yielded_type, ), {}), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) - def with_defaults(self, hints): - # type: (Optional[IOTypeHints]) -> IOTypeHints + def with_defaults(self, hints: Optional['IOTypeHints']) -> 'IOTypeHints': if not hints: return self if not self: @@ -501,8 +493,7 @@ class WithTypeHints(object): def __init__(self, *unused_args, **unused_kwargs): self._type_hints = IOTypeHints.empty() - def _get_or_create_type_hints(self): - # type: () -> IOTypeHints + def _get_or_create_type_hints(self) -> IOTypeHints: # __init__ may have not been called try: # Only return an instance bound to self (see BEAM-8629). 
@@ -524,23 +515,24 @@ def get_type_hints(self): self.default_type_hints()).with_defaults( get_type_hints(self.__class__))) - def _set_type_hints(self, type_hints): - # type: (IOTypeHints) -> None + def _set_type_hints(self, type_hints: IOTypeHints) -> None: self._type_hints = type_hints def default_type_hints(self): return None - def with_input_types(self, *arg_hints, **kwarg_hints): - # type: (WithTypeHintsT, *Any, **Any) -> WithTypeHintsT + def with_input_types( + self: WithTypeHintsT, *arg_hints: Any, + **kwarg_hints: Any) -> WithTypeHintsT: arg_hints = native_type_compatibility.convert_to_beam_types(arg_hints) kwarg_hints = native_type_compatibility.convert_to_beam_types(kwarg_hints) self._type_hints = self._get_or_create_type_hints().with_input_types( *arg_hints, **kwarg_hints) return self - def with_output_types(self, *arg_hints, **kwarg_hints): - # type: (WithTypeHintsT, *Any, **Any) -> WithTypeHintsT + def with_output_types( + self: WithTypeHintsT, *arg_hints: Any, + **kwarg_hints: Any) -> WithTypeHintsT: arg_hints = native_type_compatibility.convert_to_beam_types(arg_hints) kwarg_hints = native_type_compatibility.convert_to_beam_types(kwarg_hints) self._type_hints = self._get_or_create_type_hints().with_output_types( @@ -681,9 +673,7 @@ def getcallargs_forhints(func, *type_args, **type_kwargs): return dict(bound_args) -def get_type_hints(fn): - # type: (Any) -> IOTypeHints - +def get_type_hints(fn: Any) -> IOTypeHints: """Gets the type hint associated with an arbitrary object fn. Always returns a valid IOTypeHints object, creating one if necessary. @@ -704,9 +694,8 @@ def get_type_hints(fn): # pylint: enable=protected-access -def with_input_types(*positional_hints, **keyword_hints): - # type: (*Any, **Any) -> Callable[[T], T] - +def with_input_types(*positional_hints: Any, + **keyword_hints: Any) -> Callable[[T], T]: """A decorator that type-checks defined type-hints with passed func arguments. All type-hinted arguments can be specified using positional arguments, @@ -790,9 +779,8 @@ def annotate_input_types(f): return annotate_input_types -def with_output_types(*return_type_hint, **kwargs): - # type: (*Any, **Any) -> Callable[[T], T] - +def with_output_types(*return_type_hint: Any, + **kwargs: Any) -> Callable[[T], T]: """A decorator that type-checks defined type-hints for return values(s). This decorator will type-check the return value(s) of the decorated function. diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index cd517cd6ac70..621adc44507e 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -184,7 +184,7 @@ def is_forward_ref(typ): # Mapping from typing.TypeVar/typehints.TypeVariable ids to an object of the # other type. Bidirectional mapping preserves typing.TypeVar instances. -_type_var_cache = {} # type: typing.Dict[int, typehints.TypeVariable] +_type_var_cache: typing.Dict[int, typehints.TypeVariable] = {} def convert_builtin_to_typing(typ): diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 70eb78b6ffc6..b368f0abdf3d 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1257,7 +1257,7 @@ def __getitem__(self, type_params): # There is a circular dependency between defining this mapping # and using it in normalize(). Initialize it here and populate # it below. 
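For reference, both spellings of the hints handled above, on a plain function and on a transform (the function itself is a toy example):

import apache_beam as beam
from apache_beam.typehints import with_input_types, with_output_types

@with_input_types(int)
@with_output_types(str)
def describe(x):
  return 'value=%d' % x

# Equivalent fluent form on a PTransform via WithTypeHints:
transform = beam.Map(describe).with_input_types(int).with_output_types(str)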
-_KNOWN_PRIMITIVE_TYPES = {} # type: typing.Dict[type, typing.Any] +_KNOWN_PRIMITIVE_TYPES: typing.Dict[type, typing.Any] = {} def normalize(x, none_as_type=False): diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index d10703c17289..c75fdcc5878d 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -45,18 +45,18 @@ class Profile(object): SORTBY = 'cumulative' - profile_output = None # type: str - stats = None # type: pstats.Stats + profile_output: str + stats: pstats.Stats def __init__( self, - profile_id, # type: str - profile_location=None, # type: Optional[str] - log_results=False, # type: bool - file_copy_fn=None, # type: Optional[Callable[[str, str], None]] - time_prefix='%Y-%m-%d_%H_%M_%S-', # type: str - enable_cpu_profiling=False, # type: bool - enable_memory_profiling=False, # type: bool + profile_id: str, + profile_location: Optional[str] = None, + log_results: bool = False, + file_copy_fn: Optional[Callable[[str, str], None]] = None, + time_prefix: str = '%Y-%m-%d_%H_%M_%S-', + enable_cpu_profiling: bool = False, + enable_memory_profiling: bool = False, ): """Creates a Profile object. @@ -139,8 +139,7 @@ def default_file_copy_fn(src, dest): filesystems.FileSystems.rename([dest + '.tmp'], [dest]) @staticmethod - def factory_from_options(options): - # type: (...) -> Optional[Callable[..., Profile]] + def factory_from_options(options) -> Optional[Callable[..., 'Profile']]: if options.profile_cpu or options.profile_memory: def create_profiler(profile_id, **kwargs): @@ -156,8 +155,7 @@ def create_profiler(profile_id, **kwargs): return None def _upload_profile_data( - self, profile_location, dir, data, write_binary=True): - # type: (...) -> str + self, profile_location, dir, data, write_binary=True) -> str: dump_location = os.path.join( profile_location, dir, diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py index 3a5e020df167..cc637dead477 100644 --- a/sdks/python/apache_beam/utils/proto_utils.py +++ b/sdks/python/apache_beam/utils/proto_utils.py @@ -38,14 +38,12 @@ @overload -def pack_Any(msg): - # type: (message.Message) -> any_pb2.Any +def pack_Any(msg: message.Message) -> any_pb2.Any: pass @overload -def pack_Any(msg): - # type: (None) -> None +def pack_Any(msg: None) -> None: pass @@ -63,14 +61,12 @@ def pack_Any(msg): @overload -def unpack_Any(any_msg, msg_class): - # type: (any_pb2.Any, Type[MessageT]) -> MessageT +def unpack_Any(any_msg: any_pb2.Any, msg_class: Type[MessageT]) -> MessageT: pass @overload -def unpack_Any(any_msg, msg_class): - # type: (any_pb2.Any, None) -> None +def unpack_Any(any_msg: any_pb2.Any, msg_class: None) -> None: pass @@ -87,14 +83,13 @@ def unpack_Any(any_msg, msg_class): @overload -def parse_Bytes(serialized_bytes, msg_class): - # type: (bytes, Type[MessageT]) -> MessageT +def parse_Bytes(serialized_bytes: bytes, msg_class: Type[MessageT]) -> MessageT: pass @overload -def parse_Bytes(serialized_bytes, msg_class): - # type: (bytes, Union[Type[bytes], None]) -> bytes +def parse_Bytes( + serialized_bytes: bytes, msg_class: Union[Type[bytes], None]) -> bytes: pass @@ -109,9 +104,7 @@ def parse_Bytes(serialized_bytes, msg_class): return msg -def pack_Struct(**kwargs): - # type: (...) -> struct_pb2.Struct - +def pack_Struct(**kwargs) -> struct_pb2.Struct: """Returns a struct containing the values indicated by kwargs. 
""" msg = struct_pb2.Struct() @@ -120,16 +113,13 @@ def pack_Struct(**kwargs): return msg -def from_micros(cls, micros): - # type: (Type[TimeMessageT], int) -> TimeMessageT +def from_micros(cls: Type[TimeMessageT], micros: int) -> TimeMessageT: result = cls() result.FromMicroseconds(micros) return result -def to_Timestamp(time): - # type: (Union[int, float]) -> timestamp_pb2.Timestamp - +def to_Timestamp(time: Union[int, float]) -> timestamp_pb2.Timestamp: """Convert a float returned by time.time() to a Timestamp. """ seconds = int(time) @@ -137,9 +127,7 @@ def to_Timestamp(time): return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) -def from_Timestamp(timestamp): - # type: (timestamp_pb2.Timestamp) -> float - +def from_Timestamp(timestamp: timestamp_pb2.Timestamp) -> float: """Convert a Timestamp to a float expressed as seconds since the epoch. """ return timestamp.seconds + float(timestamp.nanos) / 10**9 diff --git a/sdks/python/apache_beam/utils/python_callable.py b/sdks/python/apache_beam/utils/python_callable.py index 70aa7cb39e5c..f6f507300ea8 100644 --- a/sdks/python/apache_beam/utils/python_callable.py +++ b/sdks/python/apache_beam/utils/python_callable.py @@ -43,8 +43,7 @@ class PythonCallableWithSource(object): is a valid chunk of source code. """ - def __init__(self, source): - # type: (str) -> None + def __init__(self, source: str) -> None: self._source = source self._callable = self.load_from_source(source) @@ -120,8 +119,7 @@ def default_label(self): def _argspec_fn(self): return self._callable - def get_source(self): - # type: () -> str + def get_source(self) -> str: return self._source def __call__(self, *args, **kwargs): diff --git a/sdks/python/apache_beam/utils/sharded_key.py b/sdks/python/apache_beam/utils/sharded_key.py index 9a03ab36bfd2..f6492779ef34 100644 --- a/sdks/python/apache_beam/utils/sharded_key.py +++ b/sdks/python/apache_beam/utils/sharded_key.py @@ -30,9 +30,8 @@ class ShardedKey(object): def __init__( self, key, - shard_id, # type: bytes - ): - # type: (...) -> None + shard_id: bytes, + ) -> None: assert shard_id is not None self._key = key self._shard_id = shard_id diff --git a/sdks/python/apache_beam/utils/shared.py b/sdks/python/apache_beam/utils/shared.py index d7eed350b0c1..bb04d1a19fb0 100644 --- a/sdks/python/apache_beam/utils/shared.py +++ b/sdks/python/apache_beam/utils/shared.py @@ -109,13 +109,7 @@ def __init__(self): self._ref = None self._tag = None - def acquire( - self, - constructor_fn, # type: Callable[[], Any] - tag=None # type: Any - ): - # type: (...) -> Any - + def acquire(self, constructor_fn: Callable[[], Any], tag: Any = None) -> Any: """Acquire a reference to the object this shared control block manages. Args: @@ -209,18 +203,14 @@ def __init__(self): # to keep it alive self._keepalive = (None, None) - def make_key(self): - # type: (...) -> Text + def make_key(self) -> Text: return str(uuid.uuid1()) def acquire( self, - key, # type: Text - constructor_fn, # type: Callable[[], Any] - tag=None # type: Any - ): - # type: (...) -> Any - + key: Text, + constructor_fn: Callable[[], Any], + tag: Any = None) -> Any: """Acquire a reference to a Shared object. Args: @@ -280,13 +270,7 @@ class Shared(object): def __init__(self): self._key = _shared_map.make_key() - def acquire( - self, - constructor_fn, # type: Callable[[], Any] - tag=None # type: Any - ): - # type: (...) -> Any - + def acquire(self, constructor_fn: Callable[[], Any], tag: Any = None) -> Any: """Acquire a reference to the object associated with this Shared handle. 
Args: diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index c54b5bf44e5c..3f585eecae08 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -52,8 +52,10 @@ class Timestamp(object): especially after arithmetic operations (for example, 10000000 % 0.1 evaluates to 0.0999999994448885). """ - def __init__(self, seconds=0, micros=0): - # type: (Union[int, float], Union[int, float]) -> None + def __init__( + self, + seconds: Union[int, float] = 0, + micros: Union[int, float] = 0) -> None: if not isinstance(seconds, (int, float)): raise TypeError( 'Cannot interpret %s %s as seconds.' % (seconds, type(seconds))) @@ -63,9 +65,7 @@ def __init__(self, seconds=0, micros=0): self.micros = int(seconds * 1000000) + int(micros) @staticmethod - def of(seconds): - # type: (TimestampTypes) -> Timestamp - + def of(seconds: TimestampTypes) -> 'Timestamp': """Return the Timestamp for the given number of seconds. If the input is already a Timestamp, the input itself will be returned. @@ -88,19 +88,15 @@ def of(seconds): 'Cannot interpret %s %s as Timestamp.' % (seconds, type(seconds))) @staticmethod - def now(): - # type: () -> Timestamp + def now() -> 'Timestamp': return Timestamp(seconds=time.time()) @staticmethod - def _epoch_datetime_utc(): - # type: () -> datetime.datetime + def _epoch_datetime_utc() -> datetime.datetime: return datetime.datetime.fromtimestamp(0, pytz.utc) @classmethod - def from_utc_datetime(cls, dt): - # type: (datetime.datetime) -> Timestamp - + def from_utc_datetime(cls, dt: datetime.datetime) -> 'Timestamp': """Create a ``Timestamp`` instance from a ``datetime.datetime`` object. Args: @@ -117,9 +113,7 @@ def from_utc_datetime(cls, dt): return Timestamp(duration.total_seconds()) @classmethod - def from_rfc3339(cls, rfc3339): - # type: (str) -> Timestamp - + def from_rfc3339(cls, rfc3339: str) -> 'Timestamp': """Create a ``Timestamp`` instance from an RFC 3339 compliant string. .. note:: @@ -140,20 +134,15 @@ def seconds(self) -> int: """Returns the timestamp in seconds.""" return self.micros // 1000000 - def predecessor(self): - # type: () -> Timestamp - + def predecessor(self) -> 'Timestamp': """Returns the largest timestamp smaller than self.""" return Timestamp(micros=self.micros - 1) - def successor(self): - # type: () -> Timestamp - + def successor(self) -> 'Timestamp': """Returns the smallest timestamp larger than self.""" return Timestamp(micros=self.micros + 1) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: micros = self.micros sign = '' if micros < 0: @@ -165,9 +154,7 @@ def __repr__(self): return 'Timestamp(%s%d.%06d)' % (sign, int_part, frac_part) return 'Timestamp(%s%d)' % (sign, int_part) - def to_utc_datetime(self, has_tz=False): - # type: (bool) -> datetime.datetime - + def to_utc_datetime(self, has_tz: bool = False) -> datetime.datetime: """Returns a ``datetime.datetime`` object of UTC for this Timestamp. Note that this method returns a ``datetime.datetime`` object without a @@ -189,23 +176,18 @@ def to_utc_datetime(self, has_tz=False): epoch = epoch.replace(tzinfo=None) return epoch + datetime.timedelta(microseconds=self.micros) - def to_rfc3339(self): - # type: () -> str + def to_rfc3339(self) -> str: # Append 'Z' for UTC timezone. 
return self.to_utc_datetime().isoformat() + 'Z' - def to_proto(self): - # type: () -> timestamp_pb2.Timestamp - + def to_proto(self) -> timestamp_pb2.Timestamp: """Returns the `google.protobuf.timestamp_pb2` representation.""" secs = self.micros // 1000000 nanos = (self.micros % 1000000) * 1000 return timestamp_pb2.Timestamp(seconds=secs, nanos=nanos) @staticmethod - def from_proto(timestamp_proto): - # type: (timestamp_pb2.Timestamp) -> Timestamp - + def from_proto(timestamp_proto: timestamp_pb2.Timestamp) -> 'Timestamp': """Creates a Timestamp from a `google.protobuf.timestamp_pb2`. Note that the google has a sub-second resolution of nanoseconds whereas this @@ -227,18 +209,15 @@ class has a resolution of microsends. This class will truncate the return Timestamp( seconds=timestamp_proto.seconds, micros=timestamp_proto.nanos // 1000) - def __float__(self): - # type: () -> float + def __float__(self) -> float: # Note that the returned value may have lost precision. return self.micros / 1000000 - def __int__(self): - # type: () -> int + def __int__(self) -> int: # Note that the returned value may have lost precision. return self.micros // 1000000 - def __eq__(self, other): - # type: (object) -> bool + def __eq__(self, other: object) -> bool: # Allow comparisons between Duration and Timestamp values. if isinstance(other, (Duration, Timestamp)): return self.micros == other.micros @@ -248,57 +227,48 @@ def __eq__(self, other): # Support equality with other types return NotImplemented - def __lt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __lt__(self, other: TimestampDurationTypes) -> bool: # Allow comparisons between Duration and Timestamp values. if not isinstance(other, Duration): other = Timestamp.of(other) return self.micros < other.micros - def __gt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __gt__(self, other: TimestampDurationTypes) -> bool: return not (self < other or self == other) - def __le__(self, other): - # type: (TimestampDurationTypes) -> bool + def __le__(self, other: TimestampDurationTypes) -> bool: return self < other or self == other - def __ge__(self, other): - # type: (TimestampDurationTypes) -> bool + def __ge__(self, other: TimestampDurationTypes) -> bool: return not self < other - def __hash__(self): - # type: () -> int + def __hash__(self) -> int: return hash(self.micros) - def __add__(self, other): - # type: (DurationTypes) -> Timestamp + def __add__(self, other: DurationTypes) -> 'Timestamp': other = Duration.of(other) return Timestamp(micros=self.micros + other.micros) - def __radd__(self, other): - # type: (DurationTypes) -> Timestamp + def __radd__(self, other: DurationTypes) -> 'Timestamp': return self + other @overload - def __sub__(self, other): - # type: (DurationTypes) -> Timestamp + def __sub__(self, other: DurationTypes) -> 'Timestamp': pass @overload - def __sub__(self, other): - # type: (Timestamp) -> Duration + def __sub__(self, other: 'Timestamp') -> 'Duration': pass - def __sub__(self, other): - # type: (Union[DurationTypes, Timestamp]) -> Union[Timestamp, Duration] + def __sub__( + self, other: Union[DurationTypes, + 'Timestamp']) -> Union['Timestamp', 'Duration']: if isinstance(other, Timestamp): return Duration(micros=self.micros - other.micros) other = Duration.of(other) return Timestamp(micros=self.micros - other.micros) - def __mod__(self, other): - # type: (DurationTypes) -> Duration + def __mod__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros % 
other.micros) @@ -319,14 +289,14 @@ class Duration(object): especially after arithmetic operations (for example, 10000000 % 0.1 evaluates to 0.0999999994448885). """ - def __init__(self, seconds=0, micros=0): - # type: (Union[int, float], Union[int, float]) -> None + def __init__( + self, + seconds: Union[int, float] = 0, + micros: Union[int, float] = 0) -> None: self.micros = int(seconds * 1000000) + int(micros) @staticmethod - def of(seconds): - # type: (DurationTypes) -> Duration - + def of(seconds: DurationTypes) -> 'Duration': """Return the Duration for the given number of seconds since Unix epoch. If the input is already a Duration, the input itself will be returned. @@ -344,18 +314,14 @@ def of(seconds): return seconds return Duration(seconds) - def to_proto(self): - # type: () -> duration_pb2.Duration - + def to_proto(self) -> duration_pb2.Duration: """Returns the `google.protobuf.duration_pb2` representation.""" secs = self.micros // 1000000 nanos = (self.micros % 1000000) * 1000 return duration_pb2.Duration(seconds=secs, nanos=nanos) @staticmethod - def from_proto(duration_proto): - # type: (duration_pb2.Duration) -> Duration - + def from_proto(duration_proto: duration_pb2.Duration) -> 'Duration': """Creates a Duration from a `google.protobuf.duration_pb2`. Note that the google has a sub-second resolution of nanoseconds whereas this @@ -377,8 +343,7 @@ class has a resolution of microsends. This class will truncate the return Duration( seconds=duration_proto.seconds, micros=duration_proto.nanos // 1000) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: micros = self.micros sign = '' if micros < 0: @@ -390,13 +355,11 @@ def __repr__(self): return 'Duration(%s%d.%06d)' % (sign, int_part, frac_part) return 'Duration(%s%d)' % (sign, int_part) - def __float__(self): - # type: () -> float + def __float__(self) -> float: # Note that the returned value may have lost precision. return self.micros / 1000000 - def __eq__(self, other): - # type: (object) -> bool + def __eq__(self, other: object) -> bool: # Allow comparisons between Duration and Timestamp values. if isinstance(other, (Duration, Timestamp)): return self.micros == other.micros @@ -406,65 +369,52 @@ def __eq__(self, other): # Support equality with other types return NotImplemented - def __lt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __lt__(self, other: TimestampDurationTypes) -> bool: # Allow comparisons between Duration and Timestamp values. 
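The arithmetic and comparison overloads above boil down to a few simple identities, for example:

from apache_beam.utils.timestamp import Duration, Timestamp

t = Timestamp(seconds=10)
assert t + Duration(seconds=5) == Timestamp(seconds=15)  # Timestamp + Duration -> Timestamp
assert Timestamp(seconds=15) - t == Duration(seconds=5)  # Timestamp - Timestamp -> Duration
assert Duration(seconds=5) < t                           # mixed comparisons use micros
assert t.to_rfc3339() == '1970-01-01T00:00:10Z'
assert Timestamp.from_rfc3339(t.to_rfc3339()) == t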
if not isinstance(other, Timestamp): other = Duration.of(other) return self.micros < other.micros - def __gt__(self, other): - # type: (TimestampDurationTypes) -> bool + def __gt__(self, other: TimestampDurationTypes) -> bool: return not (self < other or self == other) - def __le__(self, other): - # type: (TimestampDurationTypes) -> bool + def __le__(self, other: TimestampDurationTypes) -> bool: return self < other or self == other - def __ge__(self, other): - # type: (TimestampDurationTypes) -> bool + def __ge__(self, other: TimestampDurationTypes) -> bool: return not self < other - def __hash__(self): - # type: () -> int + def __hash__(self) -> int: return hash(self.micros) - def __neg__(self): - # type: () -> Duration + def __neg__(self) -> 'Duration': return Duration(micros=-self.micros) - def __add__(self, other): - # type: (DurationTypes) -> Duration + def __add__(self, other: DurationTypes) -> 'Duration': if isinstance(other, Timestamp): # defer to Timestamp.__add__ return NotImplemented other = Duration.of(other) return Duration(micros=self.micros + other.micros) - def __radd__(self, other): - # type: (DurationTypes) -> Duration + def __radd__(self, other: DurationTypes) -> 'Duration': return self + other - def __sub__(self, other): - # type: (DurationTypes) -> Duration + def __sub__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros - other.micros) - def __rsub__(self, other): - # type: (DurationTypes) -> Duration + def __rsub__(self, other: DurationTypes) -> 'Duration': return -(self - other) - def __mul__(self, other): - # type: (DurationTypes) -> Duration + def __mul__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros * other.micros // 1000000) - def __rmul__(self, other): - # type: (DurationTypes) -> Duration + def __rmul__(self, other: DurationTypes) -> 'Duration': return self * other - def __mod__(self, other): - # type: (DurationTypes) -> Duration + def __mod__(self, other: DurationTypes) -> 'Duration': other = Duration.of(other) return Duration(micros=self.micros % other.micros) diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index 3f2cb43e9753..2647a0200bde 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -38,10 +38,10 @@ from google.protobuf import wrappers_pb2 from apache_beam.internal import pickler +from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.utils import proto_utils if TYPE_CHECKING: - from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners.pipeline_context import PipelineContext T = TypeVar('T') @@ -65,7 +65,7 @@ class RunnerApiFn(object): # classes + abc metaclass # __metaclass__ = abc.ABCMeta - _known_urns = {} # type: Dict[str, Tuple[Optional[type], ConstructorFn]] + _known_urns: Dict[str, Tuple[Optional[type], ConstructorFn]] = {} # @abc.abstractmethod is disabled here to avoid an error with mypy. mypy # performs abc.abtractmethod/property checks even if a class does @@ -74,9 +74,8 @@ class RunnerApiFn(object): # mypy incorrectly infers that this method has not been overridden with a # concrete implementation. # @abc.abstractmethod - def to_runner_api_parameter(self, unused_context): - # type: (PipelineContext) -> Tuple[str, Any] - + def to_runner_api_parameter( + self, unused_context: 'PipelineContext') -> Tuple[str, Any]: """Returns the urn and payload for this Fn. 
The returned urn(s) should be registered with `register_urn`. @@ -87,40 +86,38 @@ def to_runner_api_parameter(self, unused_context): @overload def register_urn( cls, - urn, # type: str - parameter_type, # type: Type[T] - ): - # type: (...) -> Callable[[Callable[[T, PipelineContext], Any]], Callable[[T, PipelineContext], Any]] + urn: str, + parameter_type: Type[T], + ) -> Callable[[Callable[[T, 'PipelineContext'], Any]], + Callable[[T, 'PipelineContext'], Any]]: pass @classmethod @overload def register_urn( cls, - urn, # type: str - parameter_type, # type: None - ): - # type: (...) -> Callable[[Callable[[bytes, PipelineContext], Any]], Callable[[bytes, PipelineContext], Any]] + urn: str, + parameter_type: None, + ) -> Callable[[Callable[[bytes, 'PipelineContext'], Any]], + Callable[[bytes, 'PipelineContext'], Any]]: pass @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: Type[T] - fn # type: Callable[[T, PipelineContext], Any] - ): - # type: (...) -> None + def register_urn( + cls, + urn: str, + parameter_type: Type[T], + fn: Callable[[T, 'PipelineContext'], Any]) -> None: pass @classmethod @overload - def register_urn(cls, - urn, # type: str - parameter_type, # type: None - fn # type: Callable[[bytes, PipelineContext], Any] - ): - # type: (...) -> None + def register_urn( + cls, + urn: str, + parameter_type: None, + fn: Callable[[bytes, 'PipelineContext'], Any]) -> None: pass @classmethod @@ -161,14 +158,12 @@ def register_pickle_urn(cls, pickle_urn): lambda proto, unused_context: pickler.loads(proto.value)) - def to_runner_api(self, context): - # type: (PipelineContext) -> beam_runner_api_pb2.FunctionSpec - + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.FunctionSpec: """Returns a FunctionSpec encoding this Fn. Prefer overriding self.to_runner_api_parameter. """ - from apache_beam.portability.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.FunctionSpec( urn=urn, @@ -176,9 +171,10 @@ def to_runner_api(self, context): typed_param, message.Message) else typed_param) @classmethod - def from_runner_api(cls, fn_proto, context): - # type: (Type[RunnerApiFnT], beam_runner_api_pb2.FunctionSpec, PipelineContext) -> RunnerApiFnT - + def from_runner_api( + cls: Type[RunnerApiFnT], + fn_proto: beam_runner_api_pb2.FunctionSpec, + context: 'PipelineContext') -> RunnerApiFnT: """Converts from a FunctionSpec to a Fn object. Prefer registering a urn with its parameter type and constructor.
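The `register_urn` overloads above describe two calling conventions: decorate a constructor (urn plus parameter type) or register one directly (urn, parameter type, fn). A minimal sketch of the direct form with `parameter_type=None`, where the registered constructor receives the raw bytes payload and the pipeline context; the class name and URN string are made up for illustration:

```python
# Hypothetical example of the registration pattern the overloads above type-check.
from apache_beam.utils import urns


class MyCustomFn(urns.RunnerApiFn):
  def __init__(self, payload: bytes = b''):
    self.payload = payload

  def to_runner_api_parameter(self, unused_context):
    # Must return a URN registered below, plus an opaque bytes payload.
    return 'example:my_custom_fn:v1', self.payload


# Direct registration form (urn, parameter_type, fn); with parameter_type=None
# the constructor gets the raw bytes back when from_runner_api is called.
MyCustomFn.register_urn(
    'example:my_custom_fn:v1',
    None,
    lambda payload, unused_context: MyCustomFn(payload))
```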
diff --git a/sdks/python/apache_beam/version.py b/sdks/python/apache_beam/version.py index 8a04408c17d0..7bebe118a63d 100644 --- a/sdks/python/apache_beam/version.py +++ b/sdks/python/apache_beam/version.py @@ -17,4 +17,4 @@ """Apache Beam SDK version information and utilities.""" -__version__ = '2.58.0.dev' +__version__ = '2.59.0.dev' diff --git a/sdks/python/apache_beam/yaml/pipeline.schema.yaml b/sdks/python/apache_beam/yaml/pipeline.schema.yaml index f68a7306d941..c3937e611317 100644 --- a/sdks/python/apache_beam/yaml/pipeline.schema.yaml +++ b/sdks/python/apache_beam/yaml/pipeline.schema.yaml @@ -168,8 +168,10 @@ $defs: providerOrProviderInclude: if: - properties: - include {} + allOf: [ + { properties: { include: { type: string }}}, + { required: [ "include" ] } + ] then: $ref: '#/$defs/providerInclude' else: diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index c1c509ebde2c..ffef9bbcd8f0 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -1064,5 +1064,5 @@ def expand_pipeline( return YamlTransform( pipeline_as_composite(pipeline_spec['pipeline']), yaml_provider.merge_providers( - pipeline_spec.get('providers', []), providers or - {})).expand(beam.pvalue.PBegin(pipeline)) + yaml_provider.parse_providers(pipeline_spec.get('providers', [])), + providers or {})).expand(beam.pvalue.PBegin(pipeline)) diff --git a/sdks/python/build.gradle b/sdks/python/build.gradle index 72c696d30bf5..832f567c66fd 100644 --- a/sdks/python/build.gradle +++ b/sdks/python/build.gradle @@ -95,7 +95,7 @@ tasks.register("generateYamlDocs") { dependsOn ":sdks:java:extensions:sql:expansion-service:shadowJar" dependsOn ":sdks:java:io:expansion-service:build" dependsOn ":sdks:java:io:google-cloud-platform:expansion-service:build" - def extraPackages = "pyyaml markdown docstring_parser pandas pygments" + def extraPackages = "pyyaml markdown docstring_parser pandas pygments Jinja2" doLast { exec { diff --git a/sdks/python/container/piputil.go b/sdks/python/container/piputil.go index 113bf4054167..d6250ad2fdcd 100644 --- a/sdks/python/container/piputil.go +++ b/sdks/python/container/piputil.go @@ -32,6 +32,7 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/util/execx" ) +const pipLogFlushInterval time.Duration = 15 * time.Second const unrecoverableURL string = "https://beam.apache.org/documentation/sdks/python-unrecoverable-errors/index.html#pip-dependency-resolution-failures" // pipInstallRequirements installs the given requirement, if present. @@ -40,7 +41,7 @@ func pipInstallRequirements(ctx context.Context, logger *tools.Logger, files []s if err != nil { return err } - bufLogger := tools.NewBufferedLogger(logger) + bufLogger := tools.NewBufferedLoggerWithFlushInterval(ctx, logger, pipLogFlushInterval) for _, file := range files { if file == name { // We run the install process in two rounds in order to avoid as much @@ -48,7 +49,7 @@ func pipInstallRequirements(ctx context.Context, logger *tools.Logger, files []s // option will make sure that only things staged in the worker will be // used without following their dependencies. 
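On the pipeline.schema.yaml hunk above: in JSON Schema, `properties` alone never makes a key mandatory, so the `if` condition needs both the property type and a `required` clause before it matches only mappings that actually carry a string `include`. A standalone sketch of that pattern using the third-party `jsonschema` package (an assumption for illustration; the patch itself does not depend on it), with simplified stand-ins for `providerInclude` and the inline provider schema:

```python
# Illustration of the if/allOf conditional adopted in pipeline.schema.yaml.
from jsonschema import validate

provider_schema = {
    "if": {
        "allOf": [
            {"properties": {"include": {"type": "string"}}},
            {"required": ["include"]},
        ]
    },
    "then": {"required": ["include"]},  # stand-in for providerInclude
    "else": {"required": ["type"]},     # stand-in for the inline provider schema
}

# Carries a string `include`, so the `then` branch applies.
validate({"include": "my_providers.yaml"}, provider_schema)

# No `include` key: with only `properties` in the `if` this object would have
# matched vacuously; with `required` added it falls through to the `else` branch.
validate({"type": "python", "config": {"packages": []}}, provider_schema)
```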
args := []string{"-m", "pip", "install", "-r", filepath.Join(dir, name), "--no-cache-dir", "--disable-pip-version-check", "--no-index", "--no-deps", "--find-links", dir} - if err := execx.Execute(pythonVersion, args...); err != nil { + if err := execx.ExecuteEnvWithIO(nil, os.Stdin, bufLogger, bufLogger, pythonVersion, args...); err != nil { bufLogger.Printf(ctx, "Some packages could not be installed solely from the requirements cache. Installing packages from PyPI.") } // The second install round opens up the search for packages on PyPI and @@ -79,8 +80,6 @@ func isPackageInstalled(pkgName string) bool { return true } -const pipLogFlushInterval time.Duration = 15 * time.Second - // pipInstallPackage installs the given package, if present. func pipInstallPackage(ctx context.Context, logger *tools.Logger, files []string, dir, name string, force, optional bool, extras []string) error { pythonVersion, err := expansionx.GetPythonVersion() @@ -150,7 +149,7 @@ func pipInstallPackage(ctx context.Context, logger *tools.Logger, files []string // installExtraPackages installs all the packages declared in the extra // packages manifest file. func installExtraPackages(ctx context.Context, logger *tools.Logger, files []string, extraPackagesFile, dir string) error { - bufLogger := tools.NewBufferedLogger(logger) + bufLogger := tools.NewBufferedLoggerWithFlushInterval(ctx, logger, pipLogFlushInterval) // First check that extra packages manifest file is present. for _, file := range files { if file != extraPackagesFile { @@ -179,7 +178,7 @@ func installExtraPackages(ctx context.Context, logger *tools.Logger, files []str } func findBeamSdkWhl(ctx context.Context, logger *tools.Logger, files []string, acceptableWhlSpecs []string) string { - bufLogger := tools.NewBufferedLogger(logger) + bufLogger := tools.NewBufferedLoggerWithFlushInterval(ctx, logger, pipLogFlushInterval) for _, file := range files { if strings.HasPrefix(file, "apache_beam") { for _, s := range acceptableWhlSpecs { @@ -200,7 +199,7 @@ func findBeamSdkWhl(ctx context.Context, logger *tools.Logger, files []string, a // SDK from source tarball provided in sdkSrcFile. 
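The piputil.go changes above keep the existing two-round pip strategy while streaming pip's output through the periodically flushed buffered logger. For readers who do not follow Go, here is a rough Python transliteration of just the two-round idea; the function name, paths, and the exact second-round arguments are illustrative assumptions, not Beam APIs:

```python
# Sketch of the two-round install: first satisfy requirements purely from the
# staged --find-links cache, then fall back to a normal install that may reach
# PyPI. Flags in the first round mirror the Go code above.
import subprocess
import sys


def install_requirements(requirements_file: str, staged_dir: str) -> None:
    offline = [
        sys.executable, '-m', 'pip', 'install', '-r', requirements_file,
        '--no-cache-dir', '--disable-pip-version-check',
        '--no-index', '--no-deps', '--find-links', staged_dir,
    ]
    if subprocess.call(offline) != 0:
        # Round two: allow dependency resolution and PyPI downloads.
        online = [
            sys.executable, '-m', 'pip', 'install', '-r', requirements_file,
            '--no-cache-dir', '--disable-pip-version-check',
            '--find-links', staged_dir,
        ]
        subprocess.check_call(online)
```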
func installSdk(ctx context.Context, logger *tools.Logger, files []string, workDir string, sdkSrcFile string, acceptableWhlSpecs []string, required bool) error { sdkWhlFile := findBeamSdkWhl(ctx, logger, files, acceptableWhlSpecs) - bufLogger := tools.NewBufferedLogger(logger) + bufLogger := tools.NewBufferedLoggerWithFlushInterval(ctx, logger, pipLogFlushInterval) if sdkWhlFile != "" { // by default, pip rejects to install wheel if same version already installed isDev := strings.Contains(sdkWhlFile, ".dev") diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt index 9cfd30ba6b3a..f88dba103469 100644 --- a/sdks/python/container/py38/base_image_requirements.txt +++ b/sdks/python/container/py38/base_image_requirements.txt @@ -29,7 +29,7 @@ beautifulsoup4==4.12.3 bs4==0.0.2 build==1.2.1 cachetools==5.3.3 -certifi==2024.2.2 +certifi==2024.7.4 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 8f3402035eb2..e5d301ecbe14 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -523,8 +523,8 @@ project.tasks.register("inferencePostCommitIT") { // Create cross-language tasks for running tests against Java expansion service(s) -def dataflowProject = project.findProperty('dataflowProject') ?: 'apache-beam-testing' -def dataflowRegion = project.findProperty('dataflowRegion') ?: 'us-central1' +def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' +def gcpRegion = project.findProperty('gcpRegion') ?: 'us-central1' project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata -> createCrossLanguageUsingJavaExpansionTask( @@ -533,8 +533,8 @@ project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata -> collectMarker: taskMetadata.collectMarker, pythonPipelineOptions: [ "--runner=TestDataflowRunner", - "--project=${dataflowProject}", - "--region=${dataflowRegion}", + "--project=${gcpProject}", + "--region=${gcpRegion}", "--sdk_container_image=gcr.io/apache-beam-testing/beam-sdk/beam_python${project.ext.pythonVersion}_sdk:latest", "--sdk_harness_container_image_overrides=.*java.*,gcr.io/apache-beam-testing/beam-sdk/beam_java8_sdk:latest" ], diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle index c79c5f66abbc..e290e8003b13 100644 --- a/sdks/python/test-suites/direct/common.gradle +++ b/sdks/python/test-suites/direct/common.gradle @@ -436,7 +436,7 @@ project.tasks.register("inferencePostCommitIT") { } // Create cross-language tasks for running tests against Java expansion service(s) -def gcpProject = project.findProperty('dataflowProject') ?: 'apache-beam-testing' +def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata -> createCrossLanguageUsingJavaExpansionTask( diff --git a/sdks/python/test-suites/direct/xlang/build.gradle b/sdks/python/test-suites/direct/xlang/build.gradle index 289f5c8a0e07..3003329aef59 100644 --- a/sdks/python/test-suites/direct/xlang/build.gradle +++ b/sdks/python/test-suites/direct/xlang/build.gradle @@ -44,7 +44,7 @@ def cleanupTask = project.tasks.register("fnApiJobServerCleanup", Exec) { args '-c', ". 
${envDir}/bin/activate && python -m apache_beam.runners.portability.local_job_service_main --pid_file ${pidFile} --stop" } -def gcpProject = project.findProperty('dataflowProject') ?: 'apache-beam-testing' +def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' createCrossLanguageValidatesRunnerTask( startJobServer: setupTask, diff --git a/sdks/typescript/package.json b/sdks/typescript/package.json index f62c9c6de586..87a1fca269fc 100644 --- a/sdks/typescript/package.json +++ b/sdks/typescript/package.json @@ -1,6 +1,6 @@ { "name": "apache-beam", - "version": "2.58.0-SNAPSHOT", + "version": "2.59.0-SNAPSHOT", "devDependencies": { "@google-cloud/bigquery": "^5.12.0", "@types/mocha": "^9.0.0", diff --git a/settings.gradle.kts b/settings.gradle.kts index 1f9369c0779f..4d4b93908a02 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -24,7 +24,7 @@ pluginManagement { } plugins { - id("com.gradle.develocity") version "3.17.5" + id("com.gradle.develocity") version "3.17.6" id("com.gradle.common-custom-user-data-gradle-plugin") version "2.0.1" } @@ -153,6 +153,7 @@ include(":runners:jet") include(":runners:local-java") include(":runners:portability:java") include(":runners:prism") +include(":runners:prism:java") include(":runners:spark:3") include(":runners:spark:3:job-server") include(":runners:spark:3:job-server:container") diff --git a/website/www/site/content/en/documentation/io/connectors.md b/website/www/site/content/en/documentation/io/connectors.md index d390a9248cd7..313f72ce622a 100644 --- a/website/www/site/content/en/documentation/io/connectors.md +++ b/website/www/site/content/en/documentation/io/connectors.md @@ -1196,5 +1196,21 @@ This table provides a consolidated, at-a-glance overview of the available built- ✔ ✘ + + + Beam PyIO (Collection of Python IO connectors) + + ✔ + ✔ + Not available + + ✔ + native + + Not available + Not available + ✔ + ✔ +