Support custom JdbcReadWithPartitionsHelper (#31733)
Amar3tto authored Jul 2, 2024
1 parent 8e11d0a commit 209d50e
Showing 5 changed files with 189 additions and 61 deletions.
sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -209,8 +209,10 @@
* <h4>Parallel reading from a JDBC datasource</h4>
*
* <p>Beam supports partitioned reading of all data from a table. Automatic partitioning is
* supported for a few data types: {@link Long}, {@link org.joda.time.DateTime}, {@link String}. To
* enable this, use {@link JdbcIO#readWithPartitions(TypeDescriptor)}.
 * supported for a few data types: {@link Long}, {@link org.joda.time.DateTime}. To enable this, use
 * {@link JdbcIO#readWithPartitions(TypeDescriptor)}. For other types, use {@link
 * JdbcIO#readWithPartitions(JdbcReadWithPartitionsHelper)} with a custom {@link
 * JdbcReadWithPartitionsHelper}.
*
* <p>The partitioning scheme depends on these parameters, which can be user-provided, or
* automatically inferred by Beam (for the supported types):
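For illustration, a minimal usage sketch of automatic partitioning over a supported column type (a sketch only; the pipeline, data source configuration, row mapper, and table names are hypothetical):

PCollection<TestRow> rows =
    pipeline.apply(
        JdbcIO.<TestRow, Long>readWithPartitions(TypeDescriptors.longs())
            .withDataSourceConfiguration(dataSourceConfiguration) // hypothetical DataSourceConfiguration
            .withRowMapper(rowMapper) // hypothetical RowMapper<TestRow>
            .withTable("my_table") // hypothetical table
            .withPartitionColumn("id") // a Long-typed column
            .withNumPartitions(10));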
@@ -361,6 +363,7 @@ public static <ParameterT, OutputT> ReadAll<ParameterT, OutputT> readAll() {
* Like {@link #readAll}, but executes multiple instances of the query on the same table
* (subquery) using ranges.
*
* @param partitioningColumnType Type descriptor for the partition column.
* @param <T> Type of the data to be read.
*/
public static <T, PartitionColumnT> ReadWithPartitions<T, PartitionColumnT> readWithPartitions(
@@ -373,6 +376,23 @@ public static <T, PartitionColumnT> ReadWithPartitions<T, PartitionColumnT> read
.build();
}

/**
* Like {@link #readAll}, but executes multiple instances of the query on the same table
* (subquery) using ranges.
*
* @param partitionsHelper Custom helper for defining partitions.
* @param <T> Type of the data to be read.
*/
public static <T, PartitionColumnT> ReadWithPartitions<T, PartitionColumnT> readWithPartitions(
JdbcReadWithPartitionsHelper<PartitionColumnT> partitionsHelper) {
return new AutoValue_JdbcIO_ReadWithPartitions.Builder<T, PartitionColumnT>()
.setPartitionsHelper(partitionsHelper)
.setNumPartitions(DEFAULT_NUM_PARTITIONS)
.setFetchSize(DEFAULT_FETCH_SIZE)
.setUseBeamSchema(false)
.build();
}

public static <T> ReadWithPartitions<T, Long> readWithPartitions() {
return JdbcIO.<T, Long>readWithPartitions(TypeDescriptors.longs());
}
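A hedged sketch of the new overload with a user-supplied helper (illustrative names only; see the test added in JdbcIOTest.java below for a complete helper implementation):

PCollection<TestRow> rows =
    pipeline.apply(
        JdbcIO.<TestRow, String>readWithPartitions(myStringHelper) // hypothetical JdbcReadWithPartitionsHelper<String>
            .withDataSourceConfiguration(dataSourceConfiguration) // hypothetical DataSourceConfiguration
            .withRowMapper(rowMapper) // hypothetical RowMapper<TestRow>
            .withTable("my_table") // hypothetical table
            .withPartitionColumn("name") // a String-typed column
            .withNumPartitions(5));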
@@ -1229,7 +1249,10 @@ public abstract static class ReadWithPartitions<T, PartitionColumnT>
abstract @Nullable String getTable();

@Pure
abstract TypeDescriptor<PartitionColumnT> getPartitionColumnType();
abstract @Nullable TypeDescriptor<PartitionColumnT> getPartitionColumnType();

@Pure
abstract @Nullable JdbcReadWithPartitionsHelper<PartitionColumnT> getPartitionsHelper();

@Pure
abstract Builder<T, PartitionColumnT> toBuilder();
@@ -1261,6 +1284,9 @@ abstract Builder<T, PartitionColumnT> setDataSourceProviderFn(
abstract Builder<T, PartitionColumnT> setPartitionColumnType(
TypeDescriptor<PartitionColumnT> partitionColumnType);

abstract Builder<T, PartitionColumnT> setPartitionsHelper(
JdbcReadWithPartitionsHelper<PartitionColumnT> partitionsHelper);

abstract ReadWithPartitions<T, PartitionColumnT> build();
}

@@ -1360,10 +1386,19 @@ && getLowerBound() instanceof Comparable<?>) {
((Comparable<PartitionColumnT>) getLowerBound()).compareTo(getUpperBound()) < EQUAL,
"The lower bound of partitioning column is larger or equal than the upper bound");
}
checkNotNull(
JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper(getPartitionColumnType()),
"readWithPartitions only supports the following types: %s",
JdbcUtil.PRESET_HELPERS.keySet());

JdbcReadWithPartitionsHelper<PartitionColumnT> partitionsHelper = getPartitionsHelper();
if (partitionsHelper == null) {
partitionsHelper =
JdbcUtil.getPartitionsHelper(
checkStateNotNull(
getPartitionColumnType(),
"Provide partitionColumnType or partitionsHelper for JdbcIO.readWithPartitions()"));
checkNotNull(
partitionsHelper,
"readWithPartitions only supports the following types: %s",
JdbcUtil.PRESET_HELPERS.keySet());
}

PCollection<KV<Long, KV<PartitionColumnT, PartitionColumnT>>> params;

@@ -1383,10 +1418,7 @@ && getLowerBound() instanceof Comparable<?>) {
JdbcIO.<KV<Long, KV<PartitionColumnT, PartitionColumnT>>>read()
.withQuery(query)
.withDataSourceProviderFn(dataSourceProviderFn)
.withRowMapper(
checkStateNotNull(
JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper(
getPartitionColumnType())))
.withRowMapper(checkStateNotNull(partitionsHelper))
.withFetchSize(getFetchSize()))
.apply(
MapElements.via(
@@ -1441,7 +1473,9 @@ public KV<Long, KV<PartitionColumnT, PartitionColumnT>> apply(

PCollection<KV<PartitionColumnT, PartitionColumnT>> ranges =
params
.apply("Partitioning", ParDo.of(new PartitioningFn<>(getPartitionColumnType())))
.apply(
"Partitioning",
ParDo.of(new PartitioningFn<>(checkStateNotNull(partitionsHelper))))
.apply("Reshuffle partitions", Reshuffle.viaRandomKey());

JdbcIO.ReadAll<KV<PartitionColumnT, PartitionColumnT>, T> readAll =
@@ -1452,11 +1486,7 @@ public KV<Long, KV<PartitionColumnT, PartitionColumnT>> apply(
"select * from %1$s where %2$s >= ? and %2$s < ?", table, partitionColumn))
.withRowMapper(rowMapper)
.withFetchSize(getFetchSize())
.withParameterSetter(
checkStateNotNull(
JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper(
getPartitionColumnType()))
::setParameters)
.withParameterSetter(checkStateNotNull(partitionsHelper))
.withOutputParallelization(false);

if (getUseBeamSchema()) {
sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadWithPartitionsHelper.java
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.jdbc;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter;
import org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper;
import org.apache.beam.sdk.values.KV;

/**
* A helper for {@link JdbcIO.ReadWithPartitions} that handles range calculations.
*
 * @param <PartitionT> Element type of the column used for partitioning.
*/
public interface JdbcReadWithPartitionsHelper<PartitionT>
extends PreparedStatementSetter<KV<PartitionT, PartitionT>>,
RowMapper<KV<Long, KV<PartitionT, PartitionT>>> {

/**
 * Calculates the range of each partition from the lower and upper bounds and the number of
 * partitions.
 *
 * <p>Returns a list of lower/upper bound pairs, one pair per partition.
*/
Iterable<KV<PartitionT, PartitionT>> calculateRanges(
PartitionT lowerBound, PartitionT upperBound, Long partitions);

@Override
void setParameters(KV<PartitionT, PartitionT> element, PreparedStatement preparedStatement);

@Override
KV<Long, KV<PartitionT, PartitionT>> mapRow(ResultSet resultSet) throws Exception;
}
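To illustrate the contract, a minimal sketch of a helper for an Integer partition column (hypothetical, not part of this commit; it mirrors the shape of the preset helpers and assumes the bounds query yields the lower bound, the upper bound, and optionally a row count in a third column):

// Sketch only: a hypothetical JdbcReadWithPartitionsHelper<Integer>.
// Assumed imports: java.sql.PreparedStatement, java.sql.ResultSet, java.sql.SQLException,
// java.util.ArrayList, java.util.List, org.apache.beam.sdk.values.KV.
JdbcReadWithPartitionsHelper<Integer> intHelper =
    new JdbcReadWithPartitionsHelper<Integer>() {
      @Override
      public Iterable<KV<Integer, Integer>> calculateRanges(
          Integer lowerBound, Integer upperBound, Long partitions) {
        // Evenly sized, half-open ranges [start, start + stride) covering the full span.
        List<KV<Integer, Integer>> ranges = new ArrayList<>();
        int stride = Math.max(1, (upperBound - lowerBound) / partitions.intValue() + 1);
        for (int start = lowerBound; start <= upperBound; start += stride) {
          ranges.add(KV.of(start, start + stride));
        }
        return ranges;
      }

      @Override
      public void setParameters(KV<Integer, Integer> element, PreparedStatement preparedStatement) {
        try {
          // Bind the range bounds to the "?" placeholders of the partition subquery.
          preparedStatement.setInt(1, element.getKey());
          preparedStatement.setInt(2, element.getValue());
        } catch (SQLException e) {
          throw new RuntimeException(e);
        }
      }

      @Override
      public KV<Long, KV<Integer, Integer>> mapRow(ResultSet resultSet) throws Exception {
        // A third column, when present, carries the row count computed by the bounds query.
        if (resultSet.getMetaData().getColumnCount() == 3) {
          return KV.of(resultSet.getLong(3), KV.of(resultSet.getInt(1), resultSet.getInt(2)));
        }
        return KV.of(0L, KV.of(resultSet.getInt(1), resultSet.getInt(2)));
      }
    };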
sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
@@ -18,7 +18,6 @@
package org.apache.beam.sdk.io.jdbc;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;

import java.io.File;
import java.io.IOException;
@@ -47,9 +46,6 @@
import java.util.stream.IntStream;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter;
import org.apache.beam.sdk.io.jdbc.JdbcIO.ReadWithPartitions;
import org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.logicaltypes.FixedPrecisionNumeric;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
@@ -438,50 +434,30 @@ private static Calendar withTimestampAndTimezone(DateTime dateTime) {
return calendar;
}

/**
* A helper for {@link ReadWithPartitions} that handles range calculations.
*
* @param <PartitionT>
*/
interface JdbcReadWithPartitionsHelper<PartitionT>
extends PreparedStatementSetter<KV<PartitionT, PartitionT>>,
RowMapper<KV<Long, KV<PartitionT, PartitionT>>> {
static <T> @Nullable JdbcReadWithPartitionsHelper<T> getPartitionsHelper(
TypeDescriptor<T> type) {
// This cast is unchecked, thus this is a small type-checking risk. We just need
// to make sure that all preset helpers in `JdbcUtil.PRESET_HELPERS` are matched
// in type from their Key and their Value.
return (JdbcReadWithPartitionsHelper<T>) PRESET_HELPERS.get(type.getRawType());
}

Iterable<KV<PartitionT, PartitionT>> calculateRanges(
PartitionT lowerBound, PartitionT upperBound, Long partitions);

@Override
void setParameters(KV<PartitionT, PartitionT> element, PreparedStatement preparedStatement);

@Override
KV<Long, KV<PartitionT, PartitionT>> mapRow(ResultSet resultSet) throws Exception;
/** Returns the {@code JdbcReadWithPartitionsHelper} associated with the given {@code type}, if any. */
static <T> @Nullable JdbcReadWithPartitionsHelper<T> getPartitionsHelper(TypeDescriptor<T> type) {
// This cast is unchecked, so there is a small type-safety risk: we just need to make sure
// that every preset helper in `JdbcUtil.PRESET_HELPERS` maps a key type to a helper with the
// matching value type.
return (JdbcReadWithPartitionsHelper<T>) PRESET_HELPERS.get(type.getRawType());
}

/** Create partitions on a table. */
static class PartitioningFn<T> extends DoFn<KV<Long, KV<T, T>>, KV<T, T>> {
private static final Logger LOG = LoggerFactory.getLogger(PartitioningFn.class);
final TypeDescriptor<T> partitioningColumnType;
final JdbcReadWithPartitionsHelper<T> partitionsHelper;

PartitioningFn(TypeDescriptor<T> partitioningColumnType) {
this.partitioningColumnType = partitioningColumnType;
PartitioningFn(JdbcReadWithPartitionsHelper<T> partitionsHelper) {
this.partitionsHelper = partitionsHelper;
}

@ProcessElement
public void processElement(ProcessContext c) {
T lowerBound = c.element().getValue().getKey();
T upperBound = c.element().getValue().getValue();
JdbcReadWithPartitionsHelper<T> helper =
checkStateNotNull(
JdbcReadWithPartitionsHelper.getPartitionsHelper(partitioningColumnType));
List<KV<T, T>> ranges =
Lists.newArrayList(helper.calculateRanges(lowerBound, upperBound, c.element().getKey()));
Lists.newArrayList(
partitionsHelper.calculateRanges(lowerBound, upperBound, c.element().getKey()));
LOG.warn("Total of {} ranges: {}", ranges.size(), ranges);
for (KV<T, T> e : ranges) {
c.output(e);
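For context, a short sketch of how code in the same package could obtain and exercise a preset helper (illustrative only; JdbcUtil is package-private, so this compiles only inside org.apache.beam.sdk.io.jdbc):

// Look up the preset helper for Long partition columns; returns null for unsupported types.
JdbcReadWithPartitionsHelper<Long> longHelper =
    JdbcUtil.getPartitionsHelper(TypeDescriptors.longs());
if (longHelper != null) {
  // Split the span 0..1000 into roughly 10 partitions (illustrative values).
  Iterable<KV<Long, Long>> ranges = longHelper.calculateRanges(0L, 1000L, 10L);
}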
sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -45,6 +45,7 @@
import java.sql.Date;
import java.sql.JDBCType;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Time;
@@ -91,6 +92,7 @@
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.commons.dbcp2.PoolingDataSource;
import org.apache.commons.lang3.StringUtils;
import org.hamcrest.Description;
import org.hamcrest.TypeSafeMatcher;
import org.joda.time.DateTime;
@@ -104,11 +106,14 @@
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Test on the JdbcIO. */
@RunWith(JUnit4.class)
public class JdbcIOTest implements Serializable {

private static final Logger LOG = LoggerFactory.getLogger(JdbcIOTest.class);
private static final DataSourceConfiguration DATA_SOURCE_CONFIGURATION =
DataSourceConfiguration.create(
"org.apache.derby.jdbc.EmbeddedDriver", "jdbc:derby:memory:testDB;create=true");
@@ -1326,7 +1331,10 @@ public void testPartitioningDateTime() {
PCollection<KV<DateTime, DateTime>> ranges =
pipeline
.apply(Create.of(KV.of(10L, KV.of(new DateTime(0), DateTime.now()))))
.apply(ParDo.of(new PartitioningFn<>(TypeDescriptor.of(DateTime.class))));
.apply(
ParDo.of(
new PartitioningFn<>(
JdbcUtil.getPartitionsHelper(TypeDescriptor.of(DateTime.class)))));

PAssert.that(ranges.apply(Count.globally()))
.satisfies(
@@ -1407,9 +1415,78 @@ public void testPartitioningLongs() {
PCollection<KV<Long, Long>> ranges =
pipeline
.apply(Create.of(KV.of(10L, KV.of(0L, 12346789L))))
.apply(ParDo.of(new PartitioningFn<>(TypeDescriptors.longs())));
.apply(
ParDo.of(
new PartitioningFn<>(JdbcUtil.getPartitionsHelper(TypeDescriptors.longs()))));

PAssert.that(ranges.apply(Count.globally())).containsInAnyOrder(10L);
pipeline.run().waitUntilFinish();
}

@Test
public void testPartitioningStringsWithCustomPartitionsHelper() {
JdbcReadWithPartitionsHelper<String> helper =
new JdbcReadWithPartitionsHelper<String>() {
@Override
public Iterable<KV<String, String>> calculateRanges(
String lowerBound, String upperBound, Long partitions) {
// We expect the elements in the test case to follow the format <common prefix><index>.
String prefix = StringUtils.getCommonPrefix(lowerBound, upperBound);
int minChar = lowerBound.charAt(prefix.length());
int maxChar = upperBound.charAt(prefix.length());
int numPartition;
if (maxChar - minChar < partitions) {
LOG.warn(
"Partition large than possible! Adjust to {} partition instead",
maxChar - minChar);
numPartition = maxChar - minChar;
} else {
numPartition = Math.toIntExact(partitions);
}
List<KV<String, String>> ranges = new ArrayList<>();
int stride = (maxChar - minChar) / numPartition + 1;
int highest = minChar;
for (int i = minChar; i < maxChar - stride; i += stride) {
ranges.add(KV.of(prefix + (char) i, prefix + (char) (i + stride)));
highest = i + stride;
}
if (highest <= maxChar) {
ranges.add(KV.of(prefix + (char) highest, prefix + (char) (highest + stride)));
}
return ranges;
}

@Override
public void setParameters(
KV<String, String> element, PreparedStatement preparedStatement) {
try {
preparedStatement.setString(1, element.getKey());
preparedStatement.setString(2, element.getValue());
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

@Override
public KV<Long, KV<String, String>> mapRow(ResultSet resultSet) throws Exception {
if (resultSet.getMetaData().getColumnCount() == 3) {
return KV.of(
resultSet.getLong(3), KV.of(resultSet.getString(1), resultSet.getString(2)));
} else {
return KV.of(0L, KV.of(resultSet.getString(1), resultSet.getString(2)));
}
}
};

PCollection<TestRow> rows =
pipeline.apply(
JdbcIO.<TestRow, String>readWithPartitions(helper)
.withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
.withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId())
.withTable(READ_TABLE_NAME)
.withNumPartitions(5)
.withPartitionColumn("name"));
PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo(1000L);
pipeline.run();
}
}