diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/RollbackToSnapshotProcedure.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/RollbackToSnapshotProcedure.java index db7b34d34b87..588be5299d78 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/RollbackToSnapshotProcedure.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/RollbackToSnapshotProcedure.java @@ -26,6 +26,7 @@ import java.lang.invoke.MethodHandle; +import static io.trino.plugin.base.util.Procedures.checkProcedureArgument; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static java.lang.invoke.MethodHandles.lookup; @@ -68,6 +69,10 @@ public Procedure get() public void rollbackToSnapshot(ConnectorSession clientSession, String schema, String table, Long snapshotId) { + checkProcedureArgument(schema != null, "schema cannot be null"); + checkProcedureArgument(table != null, "table cannot be null"); + checkProcedureArgument(snapshotId != null, "snapshot_id cannot be null"); + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { SchemaTableName schemaTableName = new SchemaTableName(schema, table); Table icebergTable = catalogFactory.create(clientSession.getIdentity()).loadTable(clientSession, schemaTableName); diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergProcedureCalls.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergProcedureCalls.java index ab22a503147f..c25bf836e89f 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergProcedureCalls.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergProcedureCalls.java @@ -17,6 +17,7 @@ import org.testng.annotations.Test; import static io.trino.tempto.assertions.QueryAssert.Row.row; +import static io.trino.tempto.assertions.QueryAssert.assertQueryFailure; import static io.trino.tempto.assertions.QueryAssert.assertThat; import static io.trino.testing.TestingNames.randomNameSuffix; import static io.trino.tests.product.TestGroups.ICEBERG; @@ -47,6 +48,18 @@ public void testRollbackToSnapshot() onTrino().executeQuery(format("DROP TABLE IF EXISTS %s", tableName)); } + @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}) + public void testRollbackToSnapshotWithNullArgument() + { + onTrino().executeQuery("USE iceberg.default"); + assertQueryFailure(() -> onTrino().executeQuery("CALL system.rollback_to_snapshot(NULL, 'customer_orders', 8954597067493422955)")) + .hasMessageMatching(".*schema cannot be null.*"); + assertQueryFailure(() -> onTrino().executeQuery("CALL system.rollback_to_snapshot('testdb', NULL, 8954597067493422955)")) + .hasMessageMatching(".*table cannot be null.*"); + assertQueryFailure(() -> onTrino().executeQuery("CALL system.rollback_to_snapshot('testdb', 'customer_orders', NULL)")) + .hasMessageMatching(".*snapshot_id cannot be null.*"); + } + private long getSecondOldestTableSnapshot(String tableName) { return (Long) onTrino().executeQuery(