Add information from written files to Iceberg conflict detection
pajaks committed Dec 17, 2024
1 parent 76de6df commit 0bc5d55
Showing 2 changed files with 172 additions and 1 deletion.
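The gist of the change: each written file's partition values are converted into single-value domains, unioned per column across all commit tasks, and intersected into the predicate used as Iceberg's conflict-detection filter; if any task is unpartitioned, uses a partition spec other than the table's current one, or carries a null partition value, the extra filter falls back to TupleDomain.all(). A simplified sketch of that construction, assuming bigint identity-partition columns and made-up inputs (not the connector code itself):

import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static io.trino.spi.predicate.TupleDomain.withColumnDomains;
import static io.trino.spi.type.BigintType.BIGINT;

// Sketch: one Map<column, value> per written file. Two tasks writing part = 20 and
// part = 40 produce the domain {20, 40} for "part", so the conflict-detection filter
// only covers the partitions this write actually touched.
static <C> TupleDomain<C> domainsFromWrittenFiles(List<Map<C, Long>> partitionValuesPerTask)
{
    Map<C, Domain> domains = new HashMap<>();
    for (Map<C, Long> task : partitionValuesPerTask) {
        task.forEach((column, value) ->
                domains.merge(column, Domain.singleValue(BIGINT, value), Domain::union));
    }
    return withColumnDomains(domains);
}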
@@ -294,6 +294,7 @@
import static io.trino.plugin.iceberg.IcebergUtil.getFileFormat;
import static io.trino.plugin.iceberg.IcebergUtil.getIcebergTableProperties;
import static io.trino.plugin.iceberg.IcebergUtil.getPartitionKeys;
import static io.trino.plugin.iceberg.IcebergUtil.getPartitionValues;
import static io.trino.plugin.iceberg.IcebergUtil.getProjectedColumns;
import static io.trino.plugin.iceberg.IcebergUtil.getSnapshotIdAsOfTime;
import static io.trino.plugin.iceberg.IcebergUtil.getTableComment;
@@ -336,6 +337,7 @@
import static io.trino.spi.connector.MaterializedViewFreshness.Freshness.UNKNOWN;
import static io.trino.spi.connector.RetryMode.NO_RETRIES;
import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW;
import static io.trino.spi.predicate.TupleDomain.withColumnDomains;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.trino.spi.type.DateType.DATE;
@@ -2868,7 +2870,9 @@ private void finishWrite(ConnectorSession session, IcebergTableHandle table, Col
table.getSnapshotId().map(icebergTable::snapshot).ifPresent(s -> rowDelta.validateFromSnapshot(s.snapshotId()));
TupleDomain<IcebergColumnHandle> dataColumnPredicate = table.getEnforcedPredicate().filter((column, domain) -> !isMetadataColumnId(column.getId()));
TupleDomain<IcebergColumnHandle> convertibleUnenforcedPredicate = table.getUnenforcedPredicate().filter((_, domain) -> isConvertibleToIcebergExpression(domain));
- TupleDomain<IcebergColumnHandle> effectivePredicate = dataColumnPredicate.intersect(convertibleUnenforcedPredicate);
+ TupleDomain<IcebergColumnHandle> commitTasksDomains = extractTupleDomainsFromCommitTasks(table, icebergTable, commitTasks).filter((_, domain) -> isConvertibleToIcebergExpression(domain));
+ TupleDomain<IcebergColumnHandle> effectivePredicate = dataColumnPredicate.intersect(convertibleUnenforcedPredicate).intersect(commitTasksDomains);

if (!effectivePredicate.isAll()) {
rowDelta.conflictDetectionFilter(toIcebergExpression(effectivePredicate));
}
@@ -2933,6 +2937,39 @@ private void finishWrite(ConnectorSession session, IcebergTableHandle table, Col
commitUpdateAndTransaction(rowDelta, session, transaction, "write");
}

private TupleDomain<IcebergColumnHandle> extractTupleDomainsFromCommitTasks(IcebergTableHandle table, Table icebergTable, List<CommitTaskData> commitTasks)
{
List<IcebergColumnHandle> partitionColumns = getProjectedColumns(icebergTable.schema(), typeManager, identityPartitionColumnsInAllSpecs(icebergTable));
PartitionSpec partitionSpec = icebergTable.spec();
Type[] partitionColumnTypes = partitionSpec.fields().stream()
.map(field -> field.transform().getResultType(
icebergTable.schema().findType(field.sourceId())))
.toArray(Type[]::new);
Schema schema = SchemaParser.fromJson(table.getTableSchemaJson());
Map<IcebergColumnHandle, Domain> domainsFromTasks = new HashMap<>();
for (CommitTaskData commitTask : commitTasks) {
PartitionSpec taskPartitionSpec = PartitionSpecParser.fromJson(schema, commitTask.partitionSpecJson());
if (commitTask.partitionDataJson().isEmpty() || taskPartitionSpec.isUnpartitioned() || !taskPartitionSpec.equals(partitionSpec)) {
return TupleDomain.all(); // Do not produce any specific domains if a task has no partition data, is unpartitioned, or its partition spec does not match the table's current spec
}

PartitionData partitionData = PartitionData.fromJson(commitTask.partitionDataJson().get(), partitionColumnTypes);
Map<Integer, Optional<String>> partitionKeys = getPartitionKeys(partitionData, partitionSpec);
Map<ColumnHandle, NullableValue> partitionValues = getPartitionValues(new HashSet<>(partitionColumns), partitionKeys);

for (Map.Entry<ColumnHandle, NullableValue> entry : partitionValues.entrySet()) {
IcebergColumnHandle columnHandle = (IcebergColumnHandle) entry.getKey();
NullableValue value = entry.getValue();
if (value.isNull()) {
return TupleDomain.all(); // Do not produce any specific domains if any partition value is null
}
Domain newDomain = Domain.singleValue(columnHandle.getType(), value.getValue());
domainsFromTasks.merge(columnHandle, newDomain, Domain::union);
}
}
return withColumnDomains(domainsFromTasks);
}

@Override
public void createView(ConnectorSession session, SchemaTableName viewName, ConnectorViewDefinition definition, Map<String, Object> viewProperties, boolean replace)
{
@@ -657,6 +657,64 @@ void testConcurrentUpdateAndInserts()
}
}

@Test
void testConcurrentMerge()
throws Exception
{
int threads = 3;
CyclicBarrier barrier = new CyclicBarrier(threads);
ExecutorService executor = newFixedThreadPool(threads);
String tableName = "test_concurrent_merges_table_" + randomNameSuffix();

assertUpdate("CREATE TABLE " + tableName + " (a, part) WITH (partitioning = ARRAY['part']) AS VALUES (1, 10), (11, 20), (21, 30), (31, 40)", 4);
// Add one more file to partition 30
assertUpdate("INSERT INTO " + tableName + " VALUES (22, 30)", 1);

try {
// merge data concurrently using non-overlapping partition predicates
executor.invokeAll(ImmutableList.<Callable<Void>>builder()
.add(() -> {
barrier.await(10, SECONDS);
getQueryRunner().execute(
"""
MERGE INTO %s t USING (VALUES (12, 20)) AS s(a, part)
ON (FALSE)
WHEN NOT MATCHED THEN INSERT (a, part) VALUES(s.a, s.part)
""".formatted(tableName));
return null;
})
.add(() -> {
barrier.await(10, SECONDS);
getQueryRunner().execute(
"""
MERGE INTO %s t USING (VALUES (21, 30)) AS s(a, part)
ON (t.part = s.part)
WHEN MATCHED THEN DELETE
""".formatted(tableName));
return null;
})
.add(() -> {
barrier.await(10, SECONDS);
getQueryRunner().execute(
"""
MERGE INTO %s t USING (VALUES (32, 40)) AS s(a, part)
ON (t.part = s.part)
WHEN MATCHED THEN UPDATE SET a = s.a
""".formatted(tableName));
return null;
})
.build())
.forEach(MoreFutures::getDone);

assertThat(query("SELECT * FROM " + tableName)).matches("VALUES (1, 10), (11, 20), (12, 20), (32, 40)");
}
finally {
assertUpdate("DROP TABLE " + tableName);
executor.shutdownNow();
assertThat(executor.awaitTermination(10, SECONDS)).isTrue();
}
}

@Test
void testConcurrentMergeAndInserts()
throws Exception
@@ -781,6 +839,82 @@ void testConcurrentDeleteAndDeletePushdownAndInsert()
}
}

@Test
void testConcurrentMergeWithTwoConditions()
throws Exception
{
int threads = 3;
CyclicBarrier barrier = new CyclicBarrier(threads);
ExecutorService executor = newFixedThreadPool(threads);
String targetTableName = "test_concurrent_merges_target_table_" + randomNameSuffix();
String sourceTableName = "test_concurrent_merges_source_table_" + randomNameSuffix();

assertUpdate("CREATE TABLE " + sourceTableName + " (bigint_source BIGINT, timestamp_source TIMESTAMP)");
assertUpdate("INSERT INTO " + sourceTableName + " VALUES (10, TIMESTAMP '2024-01-01 00:00:01'), (10, TIMESTAMP '2024-01-02 00:00:01'), (10, TIMESTAMP '2024-01-03 00:00:01')", 3);
assertUpdate("CREATE TABLE " + targetTableName + " (a BIGINT, bigint_target BIGINT, date_target DATE) WITH (partitioning = ARRAY['bigint_target', 'date_target'])");
assertUpdate("INSERT INTO " + targetTableName + " VALUES (1, 10, DATE '2024-01-01'), (2, 10, DATE '2024-01-02'), (3, 10, DATE '2024-01-03')", 3);

try {
// merge data concurrently using non-overlapping partition predicates (each MERGE targets a different day)
executor.invokeAll(ImmutableList.<Callable<Void>>builder()
.add(() -> {
barrier.await(10, SECONDS);
getQueryRunner().execute(
"""
MERGE INTO %s t USING (SELECT bigint_source, DATE(timestamp_source) AS date_source FROM %s
WHERE timestamp_source BETWEEN TIMESTAMP '2024-01-01 00:00:00' AND TIMESTAMP '2024-01-01 23:59:59.999999'
AND bigint_source IN (10)) AS s
ON (t.bigint_target = s.bigint_source AND t.date_target = s.date_source)
WHEN MATCHED THEN UPDATE SET
a = a + 1
""".formatted(targetTableName, sourceTableName));
return null;
})
.add(() -> {
barrier.await(10, SECONDS);
getQueryRunner().execute(
"""
MERGE INTO %s t USING (SELECT bigint_source, DATE(timestamp_source) AS date_source FROM %s
WHERE timestamp_source BETWEEN TIMESTAMP '2024-01-02 00:00:00' AND TIMESTAMP '2024-01-02 23:59:59.999999'
AND bigint_source IN (10)) AS s
ON (t.bigint_target = s.bigint_source AND t.date_target = s.date_source)
WHEN MATCHED THEN UPDATE SET
a = a + 10
""".formatted(targetTableName, sourceTableName));
return null;
})
.add(() -> {
barrier.await(10, SECONDS);
getQueryRunner().execute(
"""
MERGE INTO %s t USING (SELECT bigint_source, DATE(timestamp_source) AS date_source FROM %s
WHERE timestamp_source BETWEEN TIMESTAMP '2024-01-03 00:00:00' AND TIMESTAMP '2024-01-03 23:59:59.999999'
AND bigint_source IN (10)) AS s
ON (t.bigint_target = s.bigint_source AND t.date_target = s.date_source)
WHEN MATCHED THEN UPDATE SET
a = a + 100
""".formatted(targetTableName, sourceTableName));
return null;
})
.build())
.forEach(MoreFutures::getDone);

assertThat(query("SELECT * FROM " + targetTableName))
.matches("""
VALUES
(CAST(2 AS BIGINT), CAST(10 AS BIGINT), DATE '2024-01-01'),
(CAST(12 AS BIGINT), CAST(10 AS BIGINT), DATE '2024-01-02'),
(CAST(103 AS BIGINT), CAST(10 AS BIGINT), DATE '2024-01-03')
""");
}
finally {
assertUpdate("DROP TABLE " + sourceTableName);
assertUpdate("DROP TABLE " + targetTableName);
executor.shutdownNow();
assertThat(executor.awaitTermination(10, SECONDS)).isTrue();
}
}

@Test
void testConcurrentUpdateWithPartitionTransformation()
throws Exception
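For reference, the narrowed predicate ultimately feeds Iceberg's RowDelta conflict detection. A rough sketch of that API along the commit path described above (illustrative column name and literal; the connector's actual call sequence also adds the written data and delete files and extra validations):

import org.apache.iceberg.RowDelta;
import org.apache.iceberg.Table;
import org.apache.iceberg.expressions.Expressions;

// Sketch: the conflict-detection filter limits validation of concurrently committed
// files to those matching the expression, so writes that touch disjoint partitions
// (as in the concurrent MERGE tests above) can all commit.
static void commitWithNarrowedConflictDetection(Table icebergTable, long baseSnapshotId)
{
    RowDelta rowDelta = icebergTable.newRowDelta();
    rowDelta.validateFromSnapshot(baseSnapshotId);
    rowDelta.conflictDetectionFilter(Expressions.equal("part", 20)); // derived from the written files
    rowDelta.validateNoConflictingDataFiles();
    rowDelta.validateNoConflictingDeleteFiles();
    // data and delete files from the commit tasks are added here via addRows(...) / addDeletes(...)
    rowDelta.commit();
}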
