Skip to content

Commit

Permalink
Reject invalid URIs in URI extract functions
Browse files Browse the repository at this point in the history
It's not by (any documented) design to silently ignore invalid input in
URI extracting functions (url_extract_fragment, url_extract_host,
url_extract_parameter, url_extract_path, url_extract_port,
url_extract_protocol, url_extract_query).

This commit follows the principle of least surprise: invalid input is
explicitly rejected. Users wanting to ignore invalid input should use
`try(...)` expression around the invocations.

This commit does not change docs, since now the functions behave as
documented.
  • Loading branch information
findepi committed Jul 14, 2023
1 parent d4fbc02 commit 3befdbb
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.net.URLDecoder;
import java.util.Iterator;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Strings.nullToEmpty;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
Expand All @@ -53,8 +54,8 @@ private UrlFunctions() {}
@SqlType("varchar(x)")
public static Slice urlExtractProtocol(@SqlType("varchar(x)") Slice url)
{
URI uri = parseUrl(url);
return (uri == null) ? null : slice(uri.getScheme());
URI uri = parseUriArgument(url);
return slice(uri.getScheme());
}

@SqlNullable
Expand All @@ -64,8 +65,8 @@ public static Slice urlExtractProtocol(@SqlType("varchar(x)") Slice url)
@SqlType("varchar(x)")
public static Slice urlExtractHost(@SqlType("varchar(x)") Slice url)
{
URI uri = parseUrl(url);
return (uri == null) ? null : slice(uri.getHost());
URI uri = parseUriArgument(url);
return slice(uri.getHost());
}

@SqlNullable
Expand All @@ -75,8 +76,8 @@ public static Slice urlExtractHost(@SqlType("varchar(x)") Slice url)
@SqlType(StandardTypes.BIGINT)
public static Long urlExtractPort(@SqlType("varchar(x)") Slice url)
{
URI uri = parseUrl(url);
if ((uri == null) || (uri.getPort() < 0)) {
URI uri = parseUriArgument(url);
if (uri.getPort() < 0) {
return null;
}
return (long) uri.getPort();
Expand All @@ -89,8 +90,8 @@ public static Long urlExtractPort(@SqlType("varchar(x)") Slice url)
@SqlType("varchar(x)")
public static Slice urlExtractPath(@SqlType("varchar(x)") Slice url)
{
URI uri = parseUrl(url);
return (uri == null) ? null : slice(uri.getPath());
URI uri = parseUriArgument(url);
return slice(uri.getPath());
}

@SqlNullable
Expand All @@ -100,8 +101,8 @@ public static Slice urlExtractPath(@SqlType("varchar(x)") Slice url)
@SqlType("varchar(x)")
public static Slice urlExtractQuery(@SqlType("varchar(x)") Slice url)
{
URI uri = parseUrl(url);
return (uri == null) ? null : slice(uri.getQuery());
URI uri = parseUriArgument(url);
return slice(uri.getQuery());
}

@SqlNullable
Expand All @@ -111,8 +112,8 @@ public static Slice urlExtractQuery(@SqlType("varchar(x)") Slice url)
@SqlType("varchar(x)")
public static Slice urlExtractFragment(@SqlType("varchar(x)") Slice url)
{
URI uri = parseUrl(url);
return (uri == null) ? null : slice(uri.getFragment());
URI uri = parseUriArgument(url);
return slice(uri.getFragment());
}

@SqlNullable
Expand All @@ -122,8 +123,8 @@ public static Slice urlExtractFragment(@SqlType("varchar(x)") Slice url)
@SqlType("varchar(x)")
public static Slice urlExtractParameter(@SqlType("varchar(x)") Slice url, @SqlType("varchar(y)") Slice parameterName)
{
URI uri = parseUrl(url);
if ((uri == null) || (uri.getRawQuery() == null)) {
URI uri = parseUriArgument(url);
if (uri.getRawQuery() == null) {
return null;
}

Expand Down Expand Up @@ -184,13 +185,13 @@ private static Slice slice(@Nullable String s)
}

@Nullable
private static URI parseUrl(Slice url)
private static URI parseUriArgument(Slice url)
{
try {
return new URI(url.toStringUtf8());
}
catch (URISyntaxException e) {
return null;
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Cannot parse as URI value '%s': %s".formatted(url.toStringUtf8(), firstNonNull(e.getMessage(), e)), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;

Expand Down Expand Up @@ -53,7 +56,9 @@ public void testUrlExtract()
validateUrlExtract("https://username:password@example.com", "https", "example.com", null, "", "", "");
validateUrlExtract("mailto:test@example.com", "mailto", "", null, "", "", "");
validateUrlExtract("foo", "", "", null, "foo", "", "");
validateUrlExtract("http://example.com/^", null, null, null, null, null, null);

invalidUrlExtract("http://example.com/^");
invalidUrlExtract("http://not uri/cannot contain whitespace");
}

@Test
Expand Down Expand Up @@ -96,6 +101,10 @@ public void testUrlExtractParameter()

assertThat(assertions.expression("url_extract_parameter('foo', 'k1')"))
.isNull(createVarcharType(3));

assertTrinoExceptionThrownBy(() -> assertions.expression("url_extract_parameter('http://not uri/cannot contain whitespace?foo=bar', 'foo')").evaluate())
.hasErrorCode(INVALID_FUNCTION_ARGUMENT)
.hasMessage("Cannot parse as URI value 'http://not uri/cannot contain whitespace?foo=bar': Illegal character in authority at index 7: http://not uri/cannot contain whitespace?foo=bar");
}

@Test
Expand Down Expand Up @@ -144,6 +153,8 @@ public void testUrlDecode()

private void validateUrlExtract(String url, String protocol, String host, Long port, String path, String query, String fragment)
{
checkArgument(!url.contains("'")); // Would require escaping in literals

assertThat(assertions.function("url_extract_protocol", "'" + url + "'"))
.hasType(createVarcharType(url.length()))
.isEqualTo(protocol);
Expand All @@ -168,4 +179,33 @@ private void validateUrlExtract(String url, String protocol, String host, Long p
.hasType(createVarcharType(url.length()))
.isEqualTo(fragment);
}

/**
 * Asserts that each single-argument {@code url_extract_*} function rejects {@code url}.
 * <p>
 * Every function is invoked with {@code url} as a quoted literal and is expected to fail with
 * {@code INVALID_FUNCTION_ARGUMENT} and a message that starts with
 * {@code "Cannot parse as URI value '<url>': "}.
 *
 * @param url the invalid URI text; must not contain a single quote, because it is embedded
 *         verbatim in a quoted literal
 */
private void invalidUrlExtract(String url)
{
    checkArgument(!url.contains("'")); // Would require escaping in literals

    String literal = "'" + url + "'";
    String expectedMessagePrefix = "Cannot parse as URI value '%s': ".formatted(url);

    // Checked in the same order as the original hand-written assertions.
    String[] extractFunctions = {
            "url_extract_protocol",
            "url_extract_host",
            "url_extract_port",
            "url_extract_path",
            "url_extract_query",
            "url_extract_fragment",
    };

    for (String extractFunction : extractFunctions) {
        assertTrinoExceptionThrownBy(() -> assertions.function(extractFunction, literal).evaluate())
                .hasErrorCode(INVALID_FUNCTION_ARGUMENT)
                .hasMessageStartingWith(expectedMessagePrefix);
    }
}
}

0 comments on commit 3befdbb

Please sign in to comment.