Skip to content

Commit

Permalink
Support partial update in Phoenix connector
Browse files Browse the repository at this point in the history
  • Loading branch information
chenjian2664 authored and electrum committed Nov 23, 2024
1 parent ae2448a commit f17fdac
Show file tree
Hide file tree
Showing 5 changed files with 394 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
*/
package io.trino.plugin.phoenix5;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slice;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcOutputTableHandle;
Expand All @@ -32,19 +34,27 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import org.apache.phoenix.util.SchemaUtil;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.trino.plugin.phoenix5.PhoenixClient.ROWKEY;
import static io.trino.plugin.phoenix5.PhoenixClient.ROWKEY_COLUMN_HANDLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.TinyintType.TINYINT;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.apache.phoenix.util.SchemaUtil.getEscapedArgument;
Expand All @@ -56,9 +66,11 @@ public class PhoenixMergeSink
private final int columnCount;

private final ConnectorPageSink insertSink;
private final ConnectorPageSink updateSink;
private final Map<Integer, Supplier<ConnectorPageSink>> updateSinkSuppliers;
private final ConnectorPageSink deleteSink;

private final Map<Integer, Set<Integer>> updateCaseChannels;

public PhoenixMergeSink(
ConnectorSession session,
ConnectorMergeTableHandle mergeHandle,
Expand All @@ -73,7 +85,6 @@ public PhoenixMergeSink(
this.columnCount = phoenixOutputTableHandle.getColumnNames().size();

this.insertSink = new JdbcPageSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier, JdbcClient::buildInsertSql);
this.updateSink = createUpdateSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier);

ImmutableList.Builder<String> mergeRowIdFieldNamesBuilder = ImmutableList.builder();
ImmutableList.Builder<Type> mergeRowIdFieldTypesBuilder = ImmutableList.builder();
Expand All @@ -84,6 +95,31 @@ public PhoenixMergeSink(
mergeRowIdFieldTypesBuilder.add(field.getType());
}
List<String> mergeRowIdFieldNames = mergeRowIdFieldNamesBuilder.build();
List<String> dataColumnNames = phoenixOutputTableHandle.getColumnNames().stream()
.map(SchemaUtil::getEscapedArgument)
.collect(toImmutableList());
Set<Integer> mergeRowIdChannels = mergeRowIdFieldNames.stream()
.map(dataColumnNames::indexOf)
.collect(toImmutableSet());

Map<Integer, Set<Integer>> updateCaseChannels = new HashMap<>();
for (Map.Entry<Integer, Set<Integer>> entry : phoenixMergeTableHandle.updateCaseColumns().entrySet()) {
updateCaseChannels.put(entry.getKey(), entry.getValue());
if (!hasRowKey) {
checkArgument(!mergeRowIdChannels.isEmpty() && !mergeRowIdChannels.contains(-1), "No primary keys found");
updateCaseChannels.get(entry.getKey()).addAll(mergeRowIdChannels);
}
}
this.updateCaseChannels = ImmutableMap.copyOf(updateCaseChannels);

ImmutableMap.Builder<Integer, Supplier<ConnectorPageSink>> updateSinksBuilder = ImmutableMap.builder();
for (Map.Entry<Integer, Set<Integer>> entry : this.updateCaseChannels.entrySet()) {
int caseNumber = entry.getKey();
Supplier<ConnectorPageSink> updateSupplier = Suppliers.memoize(() -> createUpdateSink(session, phoenixOutputTableHandle, phoenixClient, pageSinkId, remoteQueryModifier, entry.getValue()));
updateSinksBuilder.put(caseNumber, updateSupplier);
}
this.updateSinkSuppliers = updateSinksBuilder.buildOrThrow();

this.deleteSink = createDeleteSink(session, mergeRowIdFieldTypesBuilder.build(), phoenixClient, phoenixMergeTableHandle, mergeRowIdFieldNames, pageSinkId, remoteQueryModifier, queryBuilder);
}

Expand All @@ -92,12 +128,17 @@ private static ConnectorPageSink createUpdateSink(
PhoenixOutputTableHandle phoenixOutputTableHandle,
PhoenixClient phoenixClient,
ConnectorPageSinkId pageSinkId,
RemoteQueryModifier remoteQueryModifier)
RemoteQueryModifier remoteQueryModifier,
Set<Integer> updateChannels)
{
ImmutableList.Builder<String> columnNamesBuilder = ImmutableList.builder();
ImmutableList.Builder<Type> columnTypesBuilder = ImmutableList.builder();
columnNamesBuilder.addAll(phoenixOutputTableHandle.getColumnNames());
columnTypesBuilder.addAll(phoenixOutputTableHandle.getColumnTypes());
for (int channel = 0; channel < phoenixOutputTableHandle.getColumnNames().size(); channel++) {
if (updateChannels.contains(channel)) {
columnNamesBuilder.add(phoenixOutputTableHandle.getColumnNames().get(channel));
columnTypesBuilder.add(phoenixOutputTableHandle.getColumnTypes().get(channel));
}
}
if (phoenixOutputTableHandle.rowkeyColumn().isPresent()) {
columnNamesBuilder.add(ROWKEY);
columnTypesBuilder.add(ROWKEY_COLUMN_HANDLE.getColumnType());
Expand Down Expand Up @@ -168,8 +209,10 @@ public void storeMergedRows(Page page)
int insertPositionCount = 0;
int[] deletePositions = new int[positionCount];
int deletePositionCount = 0;
int[] updatePositions = new int[positionCount];
int updatePositionCount = 0;

Block updateCaseBlock = page.getBlock(columnCount + 1);
Map<Integer, int[]> updatePositions = new HashMap<>();
Map<Integer, Integer> updatePositionCounts = new HashMap<>();

for (int position = 0; position < positionCount; position++) {
int operation = TINYINT.getByte(operationBlock, position);
Expand All @@ -183,8 +226,10 @@ public void storeMergedRows(Page page)
deletePositionCount++;
}
case UPDATE_OPERATION_NUMBER -> {
updatePositions[updatePositionCount] = position;
updatePositionCount++;
int caseNumber = INTEGER.getInt(updateCaseBlock, position);
int updatePositionCount = updatePositionCounts.getOrDefault(caseNumber, 0);
updatePositions.computeIfAbsent(caseNumber, _ -> new int[positionCount])[updatePositionCount] = position;
updatePositionCounts.put(caseNumber, updatePositionCount + 1);
}
default -> throw new IllegalStateException("Unexpected value: " + operation);
}
Expand All @@ -203,13 +248,21 @@ public void storeMergedRows(Page page)
deleteSink.appendPage(new Page(deletePositionCount, deleteBlocks));
}

if (updatePositionCount > 0) {
Page updatePage = dataPage.getPositions(updatePositions, 0, updatePositionCount);
if (hasRowKey) {
updatePage = updatePage.appendColumn(rowIdFields.get(0).getPositions(updatePositions, 0, updatePositionCount));
}
for (Map.Entry<Integer, Integer> entry : updatePositionCounts.entrySet()) {
int caseNumber = entry.getKey();
int updatePositionCount = entry.getValue();
if (updatePositionCount > 0) {
checkArgument(updatePositions.containsKey(caseNumber), "Unexpected case number %s", caseNumber);

updateSink.appendPage(updatePage);
Page updatePage = dataPage
.getColumns(updateCaseChannels.get(caseNumber).stream().mapToInt(Integer::intValue).sorted().toArray())
.getPositions(updatePositions.get(caseNumber), 0, updatePositionCount);
if (hasRowKey) {
updatePage = updatePage.appendColumn(rowIdFields.get(0).getPositions(updatePositions.get(caseNumber), 0, updatePositionCount));
}

updateSinkSuppliers.get(caseNumber).get().appendPage(updatePage);
}
}
}

Expand All @@ -218,7 +271,7 @@ public CompletableFuture<Collection<Slice>> finish()
{
insertSink.finish();
deleteSink.finish();
updateSink.finish();
updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::finish);
return completedFuture(ImmutableList.of());
}

Expand All @@ -227,6 +280,6 @@ public void abort()
{
insertSink.abort();
deleteSink.abort();
updateSink.abort();
updateSinkSuppliers.values().stream().map(Supplier::get).forEach(ConnectorPageSink::abort);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,32 @@
import io.trino.spi.connector.ConnectorMergeTableHandle;
import io.trino.spi.predicate.TupleDomain;

import java.util.Map;
import java.util.Set;

import static java.util.Objects.requireNonNull;

public record PhoenixMergeTableHandle(
JdbcTableHandle tableHandle,
PhoenixOutputTableHandle phoenixOutputTableHandle,
JdbcColumnHandle mergeRowIdColumnHandle,
TupleDomain<ColumnHandle> primaryKeysDomain)
TupleDomain<ColumnHandle> primaryKeysDomain,
Map<Integer, Set<Integer>> updateCaseColumns)
implements ConnectorMergeTableHandle
{
@JsonCreator
public PhoenixMergeTableHandle(
@JsonProperty("tableHandle") JdbcTableHandle tableHandle,
@JsonProperty("phoenixOutputTableHandle") PhoenixOutputTableHandle phoenixOutputTableHandle,
@JsonProperty("mergeRowIdColumnHandle") JdbcColumnHandle mergeRowIdColumnHandle,
@JsonProperty("primaryKeysDomain") TupleDomain<ColumnHandle> primaryKeysDomain)
@JsonProperty("primaryKeysDomain") TupleDomain<ColumnHandle> primaryKeysDomain,
@JsonProperty("updateCaseColumns") Map<Integer, Set<Integer>> updateCaseColumns)
{
this.tableHandle = requireNonNull(tableHandle, "tableHandle is null");
this.phoenixOutputTableHandle = requireNonNull(phoenixOutputTableHandle, "phoenixOutputTableHandle is null");
this.mergeRowIdColumnHandle = requireNonNull(mergeRowIdColumnHandle, "mergeRowIdColumnHandle is null");
this.primaryKeysDomain = requireNonNull(primaryKeysDomain, "primaryKeysDomain is null");
this.updateCaseColumns = requireNonNull(updateCaseColumns, "updateCaseColumns is null");
}

@JsonProperty
Expand Down Expand Up @@ -70,4 +76,11 @@ public TupleDomain<ColumnHandle> primaryKeysDomain()
{
return primaryKeysDomain;
}

@Override
@JsonProperty
public Map<Integer, Set<Integer>> updateCaseColumns()
{
return updateCaseColumns;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.plugin.jdbc.JdbcMetadata.getColumns;
import static io.trino.plugin.phoenix5.MetadataUtil.getEscapedTableName;
import static io.trino.plugin.phoenix5.MetadataUtil.toTrinoSchemaName;
Expand Down Expand Up @@ -350,11 +351,23 @@ public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorT
primaryKeysDomainBuilder.put(columnHandle, dummy);
}

ImmutableMap.Builder<Integer, Set<Integer>> updateColumnChannelsBuilder = ImmutableMap.builder();
for (Map.Entry<Integer, Collection<ColumnHandle>> entry : updateColumnHandles.entrySet()) {
int caseNumber = entry.getKey();
Set<Integer> updateColumnChannels = entry.getValue().stream()
.map(JdbcColumnHandle.class::cast)
.peek(column -> checkArgument(columns.contains(column), "update column %s not found in the target table", column))
.map(columns::indexOf)
.collect(toImmutableSet());
updateColumnChannelsBuilder.put(caseNumber, updateColumnChannels);
}

return new PhoenixMergeTableHandle(
phoenixClient.updatedScanColumnTable(session, handle, handle.getColumns(), mergeRowIdColumnHandle),
phoenixOutputTableHandle,
mergeRowIdColumnHandle,
TupleDomain.withColumnDomains(primaryKeysDomainBuilder.buildOrThrow()));
TupleDomain.withColumnDomains(primaryKeysDomainBuilder.buildOrThrow()),
updateColumnChannelsBuilder.buildOrThrow());
}

@Override
Expand Down
Loading

0 comments on commit f17fdac

Please sign in to comment.