diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
index 7e51d89d68f..a63fae7339d 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
@@ -294,6 +294,7 @@
 import static io.trino.plugin.iceberg.IcebergUtil.getFileFormat;
 import static io.trino.plugin.iceberg.IcebergUtil.getIcebergTableProperties;
 import static io.trino.plugin.iceberg.IcebergUtil.getPartitionKeys;
+import static io.trino.plugin.iceberg.IcebergUtil.getPartitionValues;
 import static io.trino.plugin.iceberg.IcebergUtil.getProjectedColumns;
 import static io.trino.plugin.iceberg.IcebergUtil.getSnapshotIdAsOfTime;
 import static io.trino.plugin.iceberg.IcebergUtil.getTableComment;
@@ -336,6 +337,7 @@
 import static io.trino.spi.connector.MaterializedViewFreshness.Freshness.UNKNOWN;
 import static io.trino.spi.connector.RetryMode.NO_RETRIES;
 import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW;
+import static io.trino.spi.predicate.TupleDomain.withColumnDomains;
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
 import static io.trino.spi.type.DateType.DATE;
@@ -2868,7 +2870,9 @@ private void finishWrite(ConnectorSession session, IcebergTableHandle table, Col
         table.getSnapshotId().map(icebergTable::snapshot).ifPresent(s -> rowDelta.validateFromSnapshot(s.snapshotId()));
         TupleDomain<IcebergColumnHandle> dataColumnPredicate = table.getEnforcedPredicate().filter((column, domain) -> !isMetadataColumnId(column.getId()));
         TupleDomain<IcebergColumnHandle> convertibleUnenforcedPredicate = table.getUnenforcedPredicate().filter((_, domain) -> isConvertibleToIcebergExpression(domain));
-        TupleDomain<IcebergColumnHandle> effectivePredicate = dataColumnPredicate.intersect(convertibleUnenforcedPredicate);
+        TupleDomain<IcebergColumnHandle> commitTasksDomains = extractTupleDomainsFromCommitTasks(table, icebergTable, commitTasks).filter((_, domain) -> isConvertibleToIcebergExpression(domain));
+        TupleDomain<IcebergColumnHandle> effectivePredicate = dataColumnPredicate.intersect(convertibleUnenforcedPredicate).intersect(commitTasksDomains);
+
         if (!effectivePredicate.isAll()) {
             rowDelta.conflictDetectionFilter(toIcebergExpression(effectivePredicate));
         }
@@ -2933,6 +2937,39 @@ private void finishWrite(ConnectorSession session, IcebergTableHandle table, Col
         commitUpdateAndTransaction(rowDelta, session, transaction, "write");
     }
 
+    private TupleDomain<IcebergColumnHandle> extractTupleDomainsFromCommitTasks(IcebergTableHandle table, Table icebergTable, List<CommitTaskData> commitTasks)
+    {
+        List<IcebergColumnHandle> partitionColumns = getProjectedColumns(icebergTable.schema(), typeManager, identityPartitionColumnsInAllSpecs(icebergTable));
+        PartitionSpec partitionSpec = icebergTable.spec();
+        Type[] partitionColumnTypes = partitionSpec.fields().stream()
+                .map(field -> field.transform().getResultType(
+                        icebergTable.schema().findType(field.sourceId())))
+                .toArray(Type[]::new);
+        Schema schema = SchemaParser.fromJson(table.getTableSchemaJson());
+        Map<IcebergColumnHandle, Domain> domainsFromTasks = new HashMap<>();
+        for (CommitTaskData commitTask : commitTasks) {
+            PartitionSpec taskPartitionSpec = PartitionSpecParser.fromJson(schema, commitTask.partitionSpecJson());
+            if (commitTask.partitionDataJson().isEmpty() || taskPartitionSpec.isUnpartitioned() || !taskPartitionSpec.equals(partitionSpec)) {
+                return TupleDomain.all(); // We should not produce any specific domains if a task has no partition data or its partition spec does not match the table's current partition spec
+            }
+
+            PartitionData partitionData = PartitionData.fromJson(commitTask.partitionDataJson().get(), partitionColumnTypes);
+            Map<Integer, Optional<String>> partitionKeys = getPartitionKeys(partitionData, partitionSpec);
+            Map<ColumnHandle, NullableValue> partitionValues = getPartitionValues(new HashSet<>(partitionColumns), partitionKeys);
+
+            for (Map.Entry<ColumnHandle, NullableValue> entry : partitionValues.entrySet()) {
+                IcebergColumnHandle columnHandle = (IcebergColumnHandle) entry.getKey();
+                NullableValue value = entry.getValue();
+                if (value.isNull()) {
+                    return TupleDomain.all(); // We should not produce any specific domains if any partition value is null
+                }
+                Domain newDomain = Domain.singleValue(columnHandle.getType(), value.getValue());
+                domainsFromTasks.merge(columnHandle, newDomain, Domain::union);
+            }
+        }
+        return withColumnDomains(domainsFromTasks);
+    }
+
     @Override
     public void createView(ConnectorSession session, SchemaTableName viewName, ConnectorViewDefinition definition, Map<String, Object> viewProperties, boolean replace)
     {
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergLocalConcurrentWrites.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergLocalConcurrentWrites.java
index 903fff33da5..38f6a48ebe3 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergLocalConcurrentWrites.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergLocalConcurrentWrites.java
@@ -657,6 +657,64 @@ void testConcurrentUpdateAndInserts()
         }
     }
 
+    @Test
+    public void testConcurrentMerge()
+            throws Exception
+    {
+        int threads = 3;
+        CyclicBarrier barrier = new CyclicBarrier(threads);
+        ExecutorService executor = newFixedThreadPool(threads);
+        String tableName = "test_concurrent_merges_table_" + randomNameSuffix();
+
+        assertUpdate("CREATE TABLE " + tableName + " (a, part) WITH (partitioning = ARRAY['part']) AS VALUES (1, 10), (11, 20), (21, 30), (31, 40)", 4);
+        // Add more files in partition 30
+        assertUpdate("INSERT INTO " + tableName + " VALUES (22, 30)", 1);
+
+        try {
+            // merge data concurrently by using non-overlapping partition predicates
+            executor.invokeAll(ImmutableList.<Callable<Void>>builder()
+                    .add(() -> {
+                        barrier.await(10, SECONDS);
+                        getQueryRunner().execute(
+                                """
+                                MERGE INTO %s t USING (VALUES (12, 20)) AS s(a, part)
+                                ON (FALSE)
+                                WHEN NOT MATCHED THEN INSERT (a, part) VALUES(s.a, s.part)
+                                """.formatted(tableName));
+                        return null;
+                    })
+                    .add(() -> {
+                        barrier.await(10, SECONDS);
+                        getQueryRunner().execute(
+                                """
+                                MERGE INTO %s t USING (VALUES (21, 30)) AS s(a, part)
+                                ON (t.part = s.part)
+                                WHEN MATCHED THEN DELETE
+                                """.formatted(tableName));
+                        return null;
+                    })
+                    .add(() -> {
+                        barrier.await(10, SECONDS);
+                        getQueryRunner().execute(
+                                """
+                                MERGE INTO %s t USING (VALUES (32, 40)) AS s(a, part)
+                                ON (t.part = s.part)
+                                WHEN MATCHED THEN UPDATE SET a = s.a
+                                """.formatted(tableName));
+                        return null;
+                    })
+                    .build())
+                    .forEach(MoreFutures::getDone);
+
+            assertThat(query("SELECT * FROM " + tableName)).matches("VALUES (1, 10), (11, 20), (12, 20), (32, 40)");
+        }
+        finally {
+            assertUpdate("DROP TABLE " + tableName);
+            executor.shutdownNow();
+            assertThat(executor.awaitTermination(10, SECONDS)).isTrue();
+        }
+    }
+
     @Test
     void testConcurrentMergeAndInserts()
             throws Exception
@@ -781,6 +839,82 @@ void testConcurrentDeleteAndDeletePushdownAndInsert()
         }
     }
 
+    @Test
+    public void testConcurrentMergeWithTwoConditions()
+            throws Exception
+    {
+        int threads = 3;
+        CyclicBarrier barrier = new CyclicBarrier(threads);
+        ExecutorService executor = newFixedThreadPool(threads);
+        String targetTableName = "test_concurrent_merges_target_table_" + randomNameSuffix();
+        String sourceTableName = "test_concurrent_merges_source_table_" + randomNameSuffix();
+
+        assertUpdate("CREATE TABLE " + sourceTableName + " (bigint_source BIGINT, timestamp_source TIMESTAMP)");
+        assertUpdate("INSERT INTO " + sourceTableName + " VALUES (10, TIMESTAMP '2024-01-01 00:00:01'), (10, TIMESTAMP '2024-01-02 00:00:01'), (10, TIMESTAMP '2024-01-03 00:00:01')", 3);
+        assertUpdate("CREATE TABLE " + targetTableName + " (a BIGINT, bigint_target BIGINT, date_target DATE) WITH (partitioning = ARRAY['bigint_target', 'date_target'])");
+        assertUpdate("INSERT INTO " + targetTableName + " VALUES (1, 10, DATE '2024-01-01'), (2, 10, DATE '2024-01-02'), (3, 10, DATE '2024-01-03')", 3);
+
+        try {
+            // merge data concurrently by using non-overlapping partition predicates (different day in timestamps)
+            executor.invokeAll(ImmutableList.<Callable<Void>>builder()
+                    .add(() -> {
+                        barrier.await(10, SECONDS);
+                        getQueryRunner().execute(
+                                """
+                                MERGE INTO %s t USING (SELECT bigint_source, DATE(timestamp_source) AS date_source from %s
+                                WHERE timestamp_source BETWEEN TIMESTAMP '2024-01-01 00:00:00' AND TIMESTAMP '2024-01-01 23:59:59.999999'
+                                AND bigint_source IN (10)) AS s
+                                ON (t.bigint_target = s.bigint_source AND t.date_target = s.date_source)
+                                WHEN MATCHED THEN UPDATE SET
+                                a = a + 1
+                                """.formatted(targetTableName, sourceTableName));
+                        return null;
+                    })
+                    .add(() -> {
+                        barrier.await(10, SECONDS);
+                        getQueryRunner().execute(
+                                """
+                                MERGE INTO %s t USING (SELECT bigint_source, DATE(timestamp_source) AS date_source from %s
+                                WHERE timestamp_source BETWEEN TIMESTAMP '2024-01-02 00:00:00' AND TIMESTAMP '2024-01-02 23:59:59.999999'
+                                AND bigint_source IN (10)) AS s
+                                ON (t.bigint_target = s.bigint_source AND t.date_target = s.date_source)
+                                WHEN MATCHED THEN UPDATE SET
+                                a = a + 10
+                                """.formatted(targetTableName, sourceTableName));
+                        return null;
+                    })
+                    .add(() -> {
+                        barrier.await(10, SECONDS);
+                        getQueryRunner().execute(
+                                """
+                                MERGE INTO %s t USING (SELECT bigint_source, DATE(timestamp_source) AS date_source from %s
+                                WHERE timestamp_source BETWEEN TIMESTAMP '2024-01-03 00:00:00' AND TIMESTAMP '2024-01-03 23:59:59.999999'
+                                AND bigint_source IN (10)) AS s
+                                ON (t.bigint_target = s.bigint_source AND t.date_target = s.date_source)
+                                WHEN MATCHED THEN UPDATE SET
+                                a = a + 100
+                                """.formatted(targetTableName, sourceTableName));
+                        return null;
+                    })
+                    .build())
+                    .forEach(MoreFutures::getDone);
+
+            assertThat(query("SELECT * FROM " + targetTableName))
+                    .matches("""
+                            VALUES
+                            (CAST(2 AS BIGINT), CAST(10 AS BIGINT), DATE '2024-01-01'),
+                            (CAST(12 AS BIGINT), CAST(10 AS BIGINT), DATE '2024-01-02'),
+                            (CAST(103 AS BIGINT), CAST(10 AS BIGINT), DATE '2024-01-03')
+                            """);
+        }
+        finally {
+            assertUpdate("DROP TABLE " + sourceTableName);
+            assertUpdate("DROP TABLE " + targetTableName);
+            executor.shutdownNow();
+            assertThat(executor.awaitTermination(10, SECONDS)).isTrue();
+        }
+    }
+
     @Test
     void testConcurrentUpdateWithPartitionTransformation()
             throws Exception