From ac1ba79565d281900a48b01decdeb0c5f6c0247c Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 12 Jul 2023 12:11:33 +0200 Subject: [PATCH] Reject invalid URIs in URI extract functions It's not by (any document) design to silently ignore invalid input in URI extracting functions (url_extract_fragment, url_extract_host, url_extract_parameter, url_extract_path, url_extract_port, url_extract_protocol, url_extract_query). This commit follows least surprise principle: invalid input is explicitly rejected. Users wanting to ignore invalid input should use `try(...)` expression around the invocations. This commit does not change docs, since now the functions behave as documented. --- .../trino/operator/scalar/UrlFunctions.java | 33 ++++++++------- .../operator/scalar/TestUrlFunctions.java | 42 ++++++++++++++++++- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/UrlFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/UrlFunctions.java index 6c84b6c42425..823916b4a516 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/UrlFunctions.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/UrlFunctions.java @@ -34,6 +34,7 @@ import java.net.URLDecoder; import java.util.Iterator; +import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Strings.nullToEmpty; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -53,8 +54,8 @@ private UrlFunctions() {} @SqlType("varchar(x)") public static Slice urlExtractProtocol(@SqlType("varchar(x)") Slice url) { - URI uri = parseUrl(url); - return (uri == null) ? null : slice(uri.getScheme()); + URI uri = parseUriArgument(url); + return slice(uri.getScheme()); } @SqlNullable @@ -64,8 +65,8 @@ public static Slice urlExtractProtocol(@SqlType("varchar(x)") Slice url) @SqlType("varchar(x)") public static Slice urlExtractHost(@SqlType("varchar(x)") Slice url) { - URI uri = parseUrl(url); - return (uri == null) ? null : slice(uri.getHost()); + URI uri = parseUriArgument(url); + return slice(uri.getHost()); } @SqlNullable @@ -75,8 +76,8 @@ public static Slice urlExtractHost(@SqlType("varchar(x)") Slice url) @SqlType(StandardTypes.BIGINT) public static Long urlExtractPort(@SqlType("varchar(x)") Slice url) { - URI uri = parseUrl(url); - if ((uri == null) || (uri.getPort() < 0)) { + URI uri = parseUriArgument(url); + if (uri.getPort() < 0) { return null; } return (long) uri.getPort(); @@ -89,8 +90,8 @@ public static Long urlExtractPort(@SqlType("varchar(x)") Slice url) @SqlType("varchar(x)") public static Slice urlExtractPath(@SqlType("varchar(x)") Slice url) { - URI uri = parseUrl(url); - return (uri == null) ? null : slice(uri.getPath()); + URI uri = parseUriArgument(url); + return slice(uri.getPath()); } @SqlNullable @@ -100,8 +101,8 @@ public static Slice urlExtractPath(@SqlType("varchar(x)") Slice url) @SqlType("varchar(x)") public static Slice urlExtractQuery(@SqlType("varchar(x)") Slice url) { - URI uri = parseUrl(url); - return (uri == null) ? null : slice(uri.getQuery()); + URI uri = parseUriArgument(url); + return slice(uri.getQuery()); } @SqlNullable @@ -111,8 +112,8 @@ public static Slice urlExtractQuery(@SqlType("varchar(x)") Slice url) @SqlType("varchar(x)") public static Slice urlExtractFragment(@SqlType("varchar(x)") Slice url) { - URI uri = parseUrl(url); - return (uri == null) ? null : slice(uri.getFragment()); + URI uri = parseUriArgument(url); + return slice(uri.getFragment()); } @SqlNullable @@ -122,8 +123,8 @@ public static Slice urlExtractFragment(@SqlType("varchar(x)") Slice url) @SqlType("varchar(x)") public static Slice urlExtractParameter(@SqlType("varchar(x)") Slice url, @SqlType("varchar(y)") Slice parameterName) { - URI uri = parseUrl(url); - if ((uri == null) || (uri.getRawQuery() == null)) { + URI uri = parseUriArgument(url); + if (uri.getRawQuery() == null) { return null; } @@ -184,13 +185,13 @@ private static Slice slice(@Nullable String s) } @Nullable - private static URI parseUrl(Slice url) + private static URI parseUriArgument(Slice url) { try { return new URI(url.toStringUtf8()); } catch (URISyntaxException e) { - return null; + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Cannot parse as URI value '%s': %s".formatted(url.toStringUtf8(), firstNonNull(e.getMessage(), e)), e); } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestUrlFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestUrlFunctions.java index f5aaa6af182a..88efd14a7b85 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestUrlFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestUrlFunctions.java @@ -19,8 +19,11 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @@ -53,7 +56,9 @@ public void testUrlExtract() validateUrlExtract("https://username:password@example.com", "https", "example.com", null, "", "", ""); validateUrlExtract("mailto:test@example.com", "mailto", "", null, "", "", ""); validateUrlExtract("foo", "", "", null, "foo", "", ""); - validateUrlExtract("http://example.com/^", null, null, null, null, null, null); + + invalidUrlExtract("http://example.com/^"); + invalidUrlExtract("http://not uri/cannot contain whitespace"); } @Test @@ -96,6 +101,10 @@ public void testUrlExtractParameter() assertThat(assertions.expression("url_extract_parameter('foo', 'k1')")) .isNull(createVarcharType(3)); + + assertTrinoExceptionThrownBy(() -> assertions.expression("url_extract_parameter('http://not uri/cannot contain whitespace?foo=bar', 'foo')").evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessage("Cannot parse as URI value 'http://not uri/cannot contain whitespace?foo=bar': Illegal character in authority at index 7: http://not uri/cannot contain whitespace?foo=bar"); } @Test @@ -144,6 +153,8 @@ public void testUrlDecode() private void validateUrlExtract(String url, String protocol, String host, Long port, String path, String query, String fragment) { + checkArgument(!url.contains("'")); // Would require escaping in literals + assertThat(assertions.function("url_extract_protocol", "'" + url + "'")) .hasType(createVarcharType(url.length())) .isEqualTo(protocol); @@ -168,4 +179,33 @@ private void validateUrlExtract(String url, String protocol, String host, Long p .hasType(createVarcharType(url.length())) .isEqualTo(fragment); } + + private void invalidUrlExtract(String url) + { + checkArgument(!url.contains("'")); // Would require escaping in literals + + assertTrinoExceptionThrownBy(() -> assertions.function("url_extract_protocol", "'" + url + "'").evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessageStartingWith("Cannot parse as URI value '%s': ".formatted(url)); + + assertTrinoExceptionThrownBy(() -> assertions.function("url_extract_host", "'" + url + "'").evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessageStartingWith("Cannot parse as URI value '%s': ".formatted(url)); + + assertTrinoExceptionThrownBy(() -> assertions.function("url_extract_port", "'" + url + "'").evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessageStartingWith("Cannot parse as URI value '%s': ".formatted(url)); + + assertTrinoExceptionThrownBy(() -> assertions.function("url_extract_path", "'" + url + "'").evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessageStartingWith("Cannot parse as URI value '%s': ".formatted(url)); + + assertTrinoExceptionThrownBy(() -> assertions.function("url_extract_query", "'" + url + "'").evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessageStartingWith("Cannot parse as URI value '%s': ".formatted(url)); + + assertTrinoExceptionThrownBy(() -> assertions.function("url_extract_fragment", "'" + url + "'").evaluate()) + .hasErrorCode(INVALID_FUNCTION_ARGUMENT) + .hasMessageStartingWith("Cannot parse as URI value '%s': ".formatted(url)); + } }