From a5d3deb6c16e51a1ff3c05ab182ad281bb547db5 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Mon, 21 Feb 2022 16:14:14 +0100 Subject: [PATCH] DSL for JDBC expression rewrites Introduce DSL for building function pushdown rules in JDBC. As we are building up a registry of allowed function pushdowns, the declaration verbosity matters. This commit adds support for defining rewrite rules tersely. As a usage example, a previously hand-written `RewriteLike` rule is now replaced with a mere one-liner: ``` .map("$like_pattern(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") ``` --- plugin/trino-base-jdbc/pom.xml | 15 +++ .../expression/ConnectorExpressionPattern.g4 | 81 ++++++++++++ .../plugin/jdbc/expression/CallPattern.java | 107 ++++++++++++++++ .../jdbc/expression/ExpressionCapture.java | 83 ++++++++++++ .../expression/ExpressionMappingParser.java | 83 ++++++++++++ .../jdbc/expression/ExpressionPattern.java | 34 +++++ .../expression/ExpressionPatternBuilder.java | 121 ++++++++++++++++++ .../jdbc/expression/GenericRewrite.java | 95 ++++++++++++++ ...dbcConnectorExpressionRewriterBuilder.java | 18 +++ .../jdbc/expression/LongTypeParameter.java | 79 ++++++++++++ .../plugin/jdbc/expression/MatchContext.java | 68 ++++++++++ .../plugin/jdbc/expression/RewriteLike.java | 72 ----------- .../expression/RewriteLikeWithEscape.java | 77 ----------- .../jdbc/expression/TypeParameterCapture.java | 93 ++++++++++++++ .../jdbc/expression/TypeParameterPattern.java | 34 +++++ .../plugin/jdbc/expression/TypePattern.java | 102 +++++++++++++++ .../TestExpressionMappingParser.java | 109 ++++++++++++++++ .../expression/TestExpressionMatching.java | 100 +++++++++++++++ .../jdbc/expression/TestGenericRewrite.java | 82 ++++++++++++ .../plugin/postgresql/PostgreSqlClient.java | 6 +- .../postgresql/TestPostgreSqlClient.java | 4 +- 21 files changed, 1308 insertions(+), 155 deletions(-) create mode 100644 plugin/trino-base-jdbc/src/main/antlr4/io/trino/plugin/jdbc/expression/ConnectorExpressionPattern.g4 create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/CallPattern.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionCapture.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionMappingParser.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPattern.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPatternBuilder.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/LongTypeParameter.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/MatchContext.java delete mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLike.java delete mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLikeWithEscape.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterCapture.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterPattern.java create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypePattern.java create mode 100644 plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java create mode 100644 plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMatching.java create mode 100644 plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java diff --git a/plugin/trino-base-jdbc/pom.xml b/plugin/trino-base-jdbc/pom.xml index bef0c6049fe4..a9918f3498ea 100644 --- a/plugin/trino-base-jdbc/pom.xml +++ b/plugin/trino-base-jdbc/pom.xml @@ -124,6 +124,12 @@ failsafe + + org.antlr + antlr4-runtime + ${dep.antlr.version} + + org.weakref jmxutils @@ -238,4 +244,13 @@ test + + + + + org.antlr + antlr4-maven-plugin + + + diff --git a/plugin/trino-base-jdbc/src/main/antlr4/io/trino/plugin/jdbc/expression/ConnectorExpressionPattern.g4 b/plugin/trino-base-jdbc/src/main/antlr4/io/trino/plugin/jdbc/expression/ConnectorExpressionPattern.g4 new file mode 100644 index 000000000000..ca8a960fb2fc --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/antlr4/io/trino/plugin/jdbc/expression/ConnectorExpressionPattern.g4 @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar ConnectorExpressionPattern; + +tokens { + DELIMITER +} + +standaloneExpression + : expression EOF + ; + +standaloneType + : type EOF + ; + +expression + : call + | expressionCapture + ; + +call + : identifier '(' expression (',' expression)* ')' (':' type)? + ; + +expressionCapture + : identifier ':' type + ; + +type + : identifier + | identifier '(' typeParameter (',' typeParameter)* ')' + ; + +typeParameter + : number + | identifier + ; + +identifier + : IDENTIFIER + ; + +number + : INTEGER_VALUE + ; + +IDENTIFIER + : (LETTER | '_' | '$') (LETTER | DIGIT | '_' | '$')* + ; + +INTEGER_VALUE + : DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Za-z] + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +UNRECOGNIZED: .; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/CallPattern.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/CallPattern.java new file mode 100644 index 000000000000..cabe97a42521 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/CallPattern.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionUnqualifiedName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class CallPattern + extends ExpressionPattern +{ + private final String functionName; + private final List parameters; + private final Optional resultType; + private final Pattern pattern; + + public CallPattern(String functionName, List parameters, Optional resultType) + { + this.functionName = requireNonNull(functionName, "functionName is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + this.resultType = requireNonNull(resultType, "resultType is null"); + + Pattern pattern = call().with(functionUnqualifiedName().equalTo(functionName)); + if (resultType.isPresent()) { + pattern = pattern.with(type().matching(resultType.get().getPattern())); + } + pattern = pattern.with(argumentCount().equalTo(parameters.size())); + for (int i = 0; i < parameters.size(); i++) { + pattern = pattern.with(argument(i).matching(parameters.get(i).getPattern())); + } + this.pattern = pattern; + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public void resolve(Captures captures, MatchContext matchContext) + { + for (ExpressionPattern parameter : parameters) { + parameter.resolve(captures, matchContext); + } + resultType.ifPresent(resultType -> resultType.resolve(captures, matchContext)); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CallPattern that = (CallPattern) o; + return Objects.equals(functionName, that.functionName) && + Objects.equals(parameters, that.parameters) && + Objects.equals(resultType, that.resultType); + } + + @Override + public int hashCode() + { + return Objects.hash(functionName, parameters, resultType); + } + + @Override + public String toString() + { + return format( + "%s(%s)%s", + functionName, + parameters.stream() + .map(Object::toString) + .collect(joining(", ")), + resultType.map(resultType -> ": " + resultType).orElse("")); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionCapture.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionCapture.java new file mode 100644 index 000000000000..83bd1b2d1f7e --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionCapture.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.Objects; + +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ExpressionCapture + extends ExpressionPattern +{ + private final String name; + private final TypePattern type; + + private final Capture capture = newCapture(); + private final Pattern pattern; + + public ExpressionCapture(String name, TypePattern type) + { + this.name = requireNonNull(name, "name is null"); + this.type = requireNonNull(type, "type is null"); + this.pattern = Pattern.typeOf(ConnectorExpression.class).capturedAs(capture) + .with(type().matching(type.getPattern())); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public void resolve(Captures captures, MatchContext matchContext) + { + matchContext.record(name, captures.get(capture)); + type.resolve(captures, matchContext); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ExpressionCapture that = (ExpressionCapture) o; + return Objects.equals(name, that.name) && + Objects.equals(type, that.type); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type); + } + + @Override + public String toString() + { + return format("%s: %s", name, type); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionMappingParser.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionMappingParser.java new file mode 100644 index 000000000000..094c6a513f46 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionMappingParser.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import org.antlr.v4.runtime.BaseErrorListener; +import org.antlr.v4.runtime.CharStreams; +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.RecognitionException; +import org.antlr.v4.runtime.Recognizer; +import org.antlr.v4.runtime.atn.PredictionMode; +import org.antlr.v4.runtime.misc.ParseCancellationException; + +import java.util.function.Function; + +import static java.lang.String.format; + +public class ExpressionMappingParser +{ + private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() + { + @Override + public void syntaxError(Recognizer recognizer, Object offendingSymbol, int line, int charPositionInLine, String message, RecognitionException e) + { + throw new IllegalArgumentException(format("Error at %s:%s: %s", line, charPositionInLine, message), e); + } + }; + + public ExpressionPattern createExpressionPattern(String expressionPattern) + { + return (ExpressionPattern) invokeParser(expressionPattern, ConnectorExpressionPatternParser::standaloneExpression); + } + + public TypePattern createTypePattern(String typePattern) + { + return (TypePattern) invokeParser(typePattern, ConnectorExpressionPatternParser::standaloneType); + } + + public Object invokeParser(String input, Function parseFunction) + { + try { + ConnectorExpressionPatternLexer lexer = new ConnectorExpressionPatternLexer(CharStreams.fromString(input)); + CommonTokenStream tokenStream = new CommonTokenStream(lexer); + ConnectorExpressionPatternParser parser = new ConnectorExpressionPatternParser(tokenStream); + + lexer.removeErrorListeners(); + lexer.addErrorListener(ERROR_LISTENER); + + parser.removeErrorListeners(); + parser.addErrorListener(ERROR_LISTENER); + + ParserRuleContext tree; + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter().setPredictionMode(PredictionMode.SLL); + tree = parseFunction.apply(parser); + } + catch (ParseCancellationException ex) { + // if we fail, parse with LL mode + tokenStream.seek(0); // rewind input stream + parser.reset(); + + parser.getInterpreter().setPredictionMode(PredictionMode.LL); + tree = parseFunction.apply(parser); + } + return new ExpressionPatternBuilder().visit(tree); + } + catch (StackOverflowError e) { + throw new IllegalArgumentException("expression pattern is too large (stack overflow while parsing)"); + } + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPattern.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPattern.java new file mode 100644 index 000000000000..5b9d4ae87ff5 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPattern.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.spi.expression.ConnectorExpression; + +public abstract class ExpressionPattern +{ + public abstract Pattern getPattern(); + + public abstract void resolve(Captures captures, MatchContext matchContext); + + @Override + public abstract boolean equals(Object o); + + @Override + public abstract int hashCode(); + + @Override + public abstract String toString(); +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPatternBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPatternBuilder.java new file mode 100644 index 000000000000..a6d02989960a --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ExpressionPatternBuilder.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import org.antlr.v4.runtime.ParserRuleContext; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; + +public class ExpressionPatternBuilder + extends ConnectorExpressionPatternBaseVisitor +{ + @Override + public Object visitStandaloneExpression(ConnectorExpressionPatternParser.StandaloneExpressionContext context) + { + return visit(context.expression()); + } + + @Override + public Object visitStandaloneType(ConnectorExpressionPatternParser.StandaloneTypeContext context) + { + return visit(context.type()); + } + + @Override + public Object visitCall(ConnectorExpressionPatternParser.CallContext context) + { + return new CallPattern( + visit(context.identifier(), String.class), + visit(context.expression(), ExpressionPattern.class), + visitIfPresent(context.type(), TypePattern.class)); + } + + @Override + public ExpressionPattern visitExpressionCapture(ConnectorExpressionPatternParser.ExpressionCaptureContext context) + { + return new ExpressionCapture( + visit(context.identifier(), String.class), + visit(context.type(), TypePattern.class)); + } + + @Override + public Object visitType(ConnectorExpressionPatternParser.TypeContext context) + { + return new TypePattern( + visit(context.identifier(), String.class), + context.typeParameter().stream() + .map(parameter -> { + Object result = visit(parameter, Object.class); + if (result instanceof String) { + return new TypeParameterCapture((String) result); + } + if (result instanceof Long) { + return new LongTypeParameter((Long) result); + } + throw new UnsupportedOperationException(format("Unsupported parameter %s (%s) from %s", result, result.getClass(), parameter)); + }) + .collect(toImmutableList())); + } + + @Override + public Object visitNumber(ConnectorExpressionPatternParser.NumberContext context) + { + return Long.parseLong(context.INTEGER_VALUE().getText()); + } + + @Override + public Object visitIdentifier(ConnectorExpressionPatternParser.IdentifierContext context) + { + return context.getText(); + } + + private List visit(List contexts, Class expected) + { + return contexts.stream() + .map(context -> this.visit(context, expected)) + .collect(toImmutableList()); + } + + private Optional visitIfPresent(@Nullable ParserRuleContext context, Class expected) + { + if (context == null) { + return Optional.empty(); + } + return Optional.of(visit(context, expected)); + } + + private T visit(ParserRuleContext context, Class expected) + { + return expected.cast(super.visit(context)); + } + + // default implementation is error-prone + @Override + protected Object aggregateResult(Object aggregate, Object nextResult) + { + if (nextResult == null) { + throw new UnsupportedOperationException("not yet implemented"); + } + if (aggregate == null) { + return nextResult; + } + throw new UnsupportedOperationException(format("Cannot combine %s and %s", aggregate, nextResult)); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java new file mode 100644 index 000000000000..843e455e448d --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/GenericRewrite.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Captures; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.regex.Matcher.quoteReplacement; + +public class GenericRewrite + implements ConnectorExpressionRule +{ + // Matches words in the `rewritePattern` + private static final Pattern REWRITE_TOKENS = Pattern.compile("(? getPattern() + { + // TODO make ConnectorExpressionRule.getPattern result type flexible + //noinspection unchecked + return (io.trino.matching.Pattern) expressionPattern.getPattern(); + } + + @Override + public Optional rewrite(ConnectorExpression expression, Captures captures, RewriteContext context) + { + MatchContext matchContext = new MatchContext(); + expressionPattern.resolve(captures, matchContext); + + StringBuilder rewritten = new StringBuilder(); + Matcher matcher = REWRITE_TOKENS.matcher(rewritePattern); + while (matcher.find()) { + String identifier = matcher.group(0); + Optional capture = matchContext.getIfPresent(identifier); + String replacement; + if (capture.isPresent()) { + Object value = capture.get(); + if (value instanceof Long) { + replacement = Long.toString((Long) value); + } + else if (value instanceof ConnectorExpression) { + Optional rewrittenExpression = context.defaultRewrite((ConnectorExpression) value); + if (rewrittenExpression.isEmpty()) { + return Optional.empty(); + } + replacement = format("(%s)", rewrittenExpression.get()); + } + else { + throw new UnsupportedOperationException(format("Unsupported value: %s (%s)", value, value.getClass())); + } + } + else { + replacement = identifier; + } + matcher.appendReplacement(rewritten, quoteReplacement(replacement)); + } + matcher.appendTail(rewritten); + + return Optional.of(rewritten.toString()); + } + + @Override + public String toString() + { + return format("%s(%s -> %s)", GenericRewrite.class.getSimpleName(), expressionPattern, rewritePattern); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java index 31853ae753a8..ea257365d152 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java @@ -55,8 +55,26 @@ public JdbcConnectorExpressionRewriterBuilder add(ConnectorExpressionRule map(String expressionPattern) + { + return new ExpressionMapping<>() + { + @Override + public JdbcConnectorExpressionRewriterBuilder to(String rewritePattern) + { + rules.add(new GenericRewrite(expressionPattern, rewritePattern)); + return JdbcConnectorExpressionRewriterBuilder.this; + } + }; + } + public ConnectorExpressionRewriter build() { return new ConnectorExpressionRewriter<>(this.identifierQuote, rules.build()); } + + public interface ExpressionMapping + { + Continuation to(String rewritePattern); + } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/LongTypeParameter.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/LongTypeParameter.java new file mode 100644 index 000000000000..202a6319cfa5 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/LongTypeParameter.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.matching.Property; +import io.trino.spi.type.ParameterKind; +import io.trino.spi.type.TypeSignatureParameter; + +import java.util.Optional; + +public class LongTypeParameter + extends TypeParameterPattern +{ + private final long value; + private final Pattern pattern; + + public LongTypeParameter(long value) + { + this.value = value; + this.pattern = Pattern.typeOf(TypeSignatureParameter.class).with(value().equalTo(value)); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public void resolve(Captures captures, MatchContext matchContext) {} + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LongTypeParameter that = (LongTypeParameter) o; + return value == that.value; + } + + @Override + public int hashCode() + { + return Long.hashCode(value); + } + + @Override + public String toString() + { + return Long.toString(value); + } + + public static Property value() + { + return Property.optionalProperty("value", parameter -> { + if (parameter.getKind() != ParameterKind.LONG) { + return Optional.empty(); + } + return Optional.of(parameter.getLongLiteral()); + }); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/MatchContext.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/MatchContext.java new file mode 100644 index 000000000000..d67d210c4c7d --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/MatchContext.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableSet; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class MatchContext +{ + private final Map resolved = new HashMap<>(); + + public void record(String name, Object value) + { + requireNonNull(name, "name is null"); + requireNonNull(value, "value is null"); + resolved.merge(name, value, checkEqual(name)); + } + + public Object get(String name) + { + Object value = resolved.get(name); + if (value == null) { + throw new IllegalStateException("No value recorded for: " + name); + } + return value; + } + + public Optional getIfPresent(String name) + { + return Optional.ofNullable(resolved.get(name)); + } + + @VisibleForTesting + Set keys() + { + return ImmutableSet.copyOf(resolved.keySet()); + } + + private static BiFunction checkEqual(String name) + { + return (first, second) -> { + if (first.equals(second)) { + return first; + } + throw new IllegalStateException(format("%s is already mapped to %s, cannot remap to %s", name, first, second)); + }; + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLike.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLike.java deleted file mode 100644 index 5630df1922b2..000000000000 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLike.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.jdbc.expression; - -import io.trino.matching.Capture; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.ConnectorExpressionRule; -import io.trino.spi.expression.Call; -import io.trino.spi.expression.ConnectorExpression; -import io.trino.spi.type.VarcharType; - -import java.util.Optional; - -import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; -import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static java.lang.String.format; - -public class RewriteLike - implements ConnectorExpressionRule -{ - private static final Capture LIKE_VALUE = newCapture(); - private static final Capture LIKE_PATTERN = newCapture(); - - private static final Pattern PATTERN = call() - .with(functionName().equalTo(LIKE_PATTERN_FUNCTION_NAME)) - .with(type().equalTo(BOOLEAN)) - // TODO support ESCAPE. Currently, LIKE with ESCAPE is not pushed down. - .with(argumentCount().equalTo(2)) - // TODO support LIKE on char(n) - .with(argument(0).matching(expression().capturedAs(LIKE_VALUE).with(type().matching(VarcharType.class::isInstance)))) - // Currently, LIKE's pattern must be a varchar. - .with(argument(1).matching(expression().capturedAs(LIKE_PATTERN).with(type().matching(VarcharType.class::isInstance)))); - - @Override - public Pattern getPattern() - { - return PATTERN; - } - - @Override - public Optional rewrite(Call call, Captures captures, RewriteContext context) - { - Optional value = context.defaultRewrite(captures.get(LIKE_VALUE)); - if (value.isEmpty()) { - return Optional.empty(); - } - Optional pattern = context.defaultRewrite(captures.get(LIKE_PATTERN)); - if (pattern.isEmpty()) { - return Optional.empty(); - } - return Optional.of(format("%s LIKE %s", value.get(), pattern.get())); - } -} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLikeWithEscape.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLikeWithEscape.java deleted file mode 100644 index 197bc14425fd..000000000000 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteLikeWithEscape.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.jdbc.expression; - -import io.trino.matching.Capture; -import io.trino.matching.Captures; -import io.trino.matching.Pattern; -import io.trino.plugin.base.expression.ConnectorExpressionRule; -import io.trino.spi.expression.Call; -import io.trino.spi.expression.ConnectorExpression; -import io.trino.spi.type.VarcharType; - -import java.util.Optional; - -import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; -import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; -import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static java.lang.String.format; - -public class RewriteLikeWithEscape - implements ConnectorExpressionRule -{ - private static final Capture LIKE_VALUE = newCapture(); - private static final Capture LIKE_PATTERN = newCapture(); - private static final Capture ESCAPE_VALUE = newCapture(); - - private static final Pattern PATTERN = call() - .with(functionName().equalTo(LIKE_PATTERN_FUNCTION_NAME)) - .with(type().equalTo(BOOLEAN)) - .with(argumentCount().equalTo(3)) - // TODO support LIKE on char(n) - .with(argument(0).matching(expression().capturedAs(LIKE_VALUE).with(type().matching(VarcharType.class::isInstance)))) - // Currently, LIKE's pattern must be a varchar. - .with(argument(1).matching(expression().capturedAs(LIKE_PATTERN).with(type().matching(VarcharType.class::isInstance)))) - .with(argument(2).matching(expression().capturedAs(ESCAPE_VALUE).with(type().matching(VarcharType.class::isInstance)))); - - @Override - public Pattern getPattern() - { - return PATTERN; - } - - @Override - public Optional rewrite(Call call, Captures captures, RewriteContext context) - { - Optional value = context.defaultRewrite(captures.get(LIKE_VALUE)); - if (value.isEmpty()) { - return Optional.empty(); - } - Optional pattern = context.defaultRewrite(captures.get(LIKE_PATTERN)); - if (pattern.isEmpty()) { - return Optional.empty(); - } - Optional escape = context.defaultRewrite(captures.get(ESCAPE_VALUE)); - if (escape.isEmpty()) { - return Optional.empty(); - } - return Optional.of(format("%s LIKE %s ESCAPE %s", value.get(), pattern.get(), escape.get())); - } -} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterCapture.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterCapture.java new file mode 100644 index 000000000000..658c977853e1 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterCapture.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.matching.Property; +import io.trino.spi.type.TypeSignatureParameter; + +import java.util.Objects; + +import static io.trino.matching.Capture.newCapture; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class TypeParameterCapture + extends TypeParameterPattern +{ + private final String name; + + private final Capture capture = newCapture(); + private final Pattern pattern; + + public TypeParameterCapture(String name) + { + this.name = requireNonNull(name, "name is null"); + this.pattern = Pattern.typeOf(TypeSignatureParameter.class).with(self().capturedAs(capture)); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public void resolve(Captures captures, MatchContext matchContext) + { + TypeSignatureParameter parameter = captures.get(capture); + switch (parameter.getKind()) { + case TYPE: + matchContext.record(name, parameter.getTypeSignature()); + break; + case LONG: + matchContext.record(name, parameter.getLongLiteral()); + break; + default: + throw new UnsupportedOperationException("Unsupported parameter: " + parameter); + } + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TypeParameterCapture that = (TypeParameterCapture) o; + return Objects.equals(name, that.name); + } + + @Override + public int hashCode() + { + return name.hashCode(); + } + + @Override + public String toString() + { + return name; + } + + public static Property self() + { + return Property.property("self", identity()); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterPattern.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterPattern.java new file mode 100644 index 000000000000..23a4fcedf174 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypeParameterPattern.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.spi.type.TypeSignatureParameter; + +public abstract class TypeParameterPattern +{ + public abstract Pattern getPattern(); + + public abstract void resolve(Captures captures, MatchContext matchContext); + + @Override + public abstract boolean equals(Object obj); + + @Override + public abstract int hashCode(); + + @Override + public abstract String toString(); +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypePattern.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypePattern.java new file mode 100644 index 000000000000..4a313730adc6 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/TypePattern.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.matching.Property; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeSignatureParameter; + +import java.util.List; +import java.util.Objects; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class TypePattern +{ + private final String baseName; + private final List parameters; + private final Pattern pattern; + + public TypePattern(String baseName, List parameters) + { + this.baseName = requireNonNull(baseName, "baseName is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + Pattern pattern = Pattern.typeOf(Type.class).with(baseName().equalTo(baseName)); + for (int i = 0; i < parameters.size(); i++) { + pattern = pattern.with(parameter(i).matching(parameters.get(i).getPattern())); + } + this.pattern = pattern; + } + + public Pattern getPattern() + { + return pattern; + } + + public void resolve(Captures captures, MatchContext matchContext) + { + for (TypeParameterPattern parameter : parameters) { + parameter.resolve(captures, matchContext); + } + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TypePattern that = (TypePattern) o; + return Objects.equals(baseName, that.baseName) && + Objects.equals(parameters, that.parameters); + } + + @Override + public int hashCode() + { + return Objects.hash(baseName, parameters); + } + + @Override + public String toString() + { + if (parameters.isEmpty()) { + return baseName; + } + return format( + "%s(%s)", + baseName, + parameters.stream() + .map(Object::toString) + .collect(joining(", "))); + } + + public static Property baseName() + { + return Property.property("baseName", Type::getBaseName); + } + + public static Property parameter(int i) + { + return Property.property(format("parameter(%s)", i), type -> type.getTypeSignature().getParameters().get(i)); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java new file mode 100644 index 000000000000..f5db6bf71347 --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMappingParser.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; + +import static org.testng.Assert.assertEquals; + +public class TestExpressionMappingParser +{ + @Test + public void testCapture() + { + assertExpressionPattern( + "b: bigint", + new ExpressionCapture( + "b", + type("bigint"))); + + assertExpressionPattern( + "bar: varchar(n)", + new ExpressionCapture( + "bar", + type("varchar", parameter("n")))); + } + + @Test + public void testParameterizedType() + { + assertExpressionPattern( + "bar: varchar(3)", + new ExpressionCapture( + "bar", + type("varchar", parameter(3L)))); + } + + @Test + public void testCallPattern() + { + assertExpressionPattern( + "$like_pattern(a: varchar(n), b: varchar(m))", + new CallPattern( + "$like_pattern", + List.of( + new ExpressionCapture("a", type("varchar", parameter("n"))), + new ExpressionCapture("b", type("varchar", parameter("m")))), + Optional.empty())); + + assertExpressionPattern( + "$like_pattern(a: varchar(n), b: varchar(m)): boolean", + new CallPattern( + "$like_pattern", + List.of( + new ExpressionCapture("a", type("varchar", parameter("n"))), + new ExpressionCapture("b", type("varchar", parameter("m")))), + Optional.of(type("boolean")))); + } + + private static void assertExpressionPattern(String expressionPattern, ExpressionPattern expected) + { + assertExpressionPattern(expressionPattern, expressionPattern, expected); + } + + private static void assertExpressionPattern(String expressionPattern, String canonical, ExpressionPattern expected) + { + assertEquals(expressionPattern(expressionPattern), expected); + assertEquals(expected.toString(), canonical); + } + + private static ExpressionPattern expressionPattern(String expressionPattern) + { + return new ExpressionMappingParser().createExpressionPattern(expressionPattern); + } + + private static TypePattern type(String baseName) + { + return new TypePattern(baseName, ImmutableList.of()); + } + + private static TypePattern type(String baseName, TypeParameterPattern... parameter) + { + return new TypePattern(baseName, ImmutableList.copyOf(parameter)); + } + + private static TypeParameterPattern parameter(long value) + { + return new LongTypeParameter(value); + } + + private static TypeParameterPattern parameter(String name) + { + return new TypeParameterCapture(name); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMatching.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMatching.java new file mode 100644 index 000000000000..7398f458cff7 --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestExpressionMatching.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Match; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.Type; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestExpressionMatching +{ + @Test + public void testMatchType() + { + Type type = createDecimalType(10, 2); + TypePattern pattern = typePattern("decimal(p, s)"); + + Match match = pattern.getPattern().match(type).collect(onlyElement()); + MatchContext matchContext = new MatchContext(); + pattern.resolve(match.captures(), matchContext); + + assertThat(matchContext.keys()).containsExactlyInAnyOrder("p", "s"); + assertThat(matchContext.get("p")).isEqualTo(10L); + assertThat(matchContext.get("s")).isEqualTo(2L); + } + + @Test + public void testExpressionCapture() + { + ConnectorExpression expression = new Call( + createDecimalType(21, 2), + new FunctionName("add"), + List.of( + new Variable("first", createDecimalType(10, 2)), + new Variable("second", BIGINT))); + ExpressionPattern pattern = expressionPattern("foo: decimal(p, s)"); + + Match match = pattern.getPattern().match(expression).collect(onlyElement()); + MatchContext matchContext = new MatchContext(); + pattern.resolve(match.captures(), matchContext); + + assertThat(matchContext.keys()).containsExactlyInAnyOrder("p", "s", "foo"); + assertThat(matchContext.get("p")).isEqualTo(21L); + assertThat(matchContext.get("s")).isEqualTo(2L); + assertThat(matchContext.get("foo")).isSameAs(expression); + } + + @Test + public void testMatchCall() + { + ConnectorExpression expression = new Call( + createDecimalType(21, 2), + new FunctionName("add"), + List.of( + new Variable("first", createDecimalType(10, 2)), + new Variable("second", BIGINT))); + ExpressionPattern pattern = expressionPattern("add(foo: decimal(p, s), bar: bigint)"); + + Match match = pattern.getPattern().match(expression).collect(onlyElement()); + MatchContext matchContext = new MatchContext(); + pattern.resolve(match.captures(), matchContext); + + assertThat(matchContext.keys()).containsExactlyInAnyOrder("p", "s", "foo", "bar"); + assertThat(matchContext.get("p")).isEqualTo(10L); + assertThat(matchContext.get("s")).isEqualTo(2L); + assertThat(matchContext.get("foo")).isEqualTo(new Variable("first", createDecimalType(10, 2))); + assertThat(matchContext.get("bar")).isEqualTo(new Variable("second", BIGINT)); + } + + private static ExpressionPattern expressionPattern(String expressionPattern) + { + return new ExpressionMappingParser().createExpressionPattern(expressionPattern); + } + + private static TypePattern typePattern(String typePattern) + { + return new ExpressionMappingParser().createTypePattern(typePattern); + } +} diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java new file mode 100644 index 000000000000..a1ac5ddccecf --- /dev/null +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestGenericRewrite.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Match; +import io.trino.plugin.base.expression.ConnectorExpressionRule.RewriteContext; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestGenericRewrite +{ + @Test + public void testRewriteCall() + { + GenericRewrite rewrite = new GenericRewrite("add(foo: decimal(p, s), bar: bigint): decimal(rp, rs)", "foo + bar::decimal(rp,rs)"); + ConnectorExpression expression = new Call( + createDecimalType(21, 2), + new FunctionName("add"), + List.of( + new Variable("first", createDecimalType(10, 2)), + new Variable("second", BIGINT))); + + Match match = rewrite.getPattern().match(expression).collect(onlyElement()); + Optional rewritten = rewrite.rewrite(expression, match.captures(), new RewriteContext<>() + { + @Override + public Map getAssignments() + { + throw new UnsupportedOperationException(); + } + + @Override + public Function getIdentifierQuote() + { + throw new UnsupportedOperationException(); + } + + @Override + public ConnectorSession getSession() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional defaultRewrite(ConnectorExpression expression) + { + if (expression instanceof Variable) { + return Optional.of("\"" + ((Variable) expression).getName().replace("\"", "\"\"") + "\""); + } + return Optional.empty(); + } + }); + + assertThat(rewritten).hasValue("(\"first\") + (\"second\")::decimal(21,2)"); + } +} diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 4da60c5743b7..8a9cef58c4df 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -64,8 +64,6 @@ import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.RewriteComparison; -import io.trino.plugin.jdbc.expression.RewriteLike; -import io.trino.plugin.jdbc.expression.RewriteLikeWithEscape; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping; import io.trino.spi.TrinoException; @@ -310,8 +308,8 @@ public PostgreSqlClient( .addStandardRules() // TODO allow all comparison operators for numeric types .add(new RewriteComparison(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL)) - .add(new RewriteLike()) - .add(new RewriteLikeWithEscape()) + .map("$like_pattern(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") + .map("$like_pattern(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape") .build(); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index cedba2a4fdb7..ebfb2352decd 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -262,7 +262,7 @@ public void testConvertLike() Optional.empty()), Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), Map.of("c_varchar_symbol", VARCHAR_COLUMN))) - .hasValue("\"c_varchar\" LIKE '%pattern%'"); + .hasValue("(\"c_varchar\") LIKE ('%pattern%')"); // c_varchar LIKE '%pattern\%' ESCAPE '\' assertThat(JDBC_CLIENT.convertPredicate(SESSION, @@ -273,7 +273,7 @@ public void testConvertLike() new StringLiteral("\\")), Map.of("c_varchar", VARCHAR_COLUMN.getColumnType())), Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN))) - .hasValue("\"c_varchar\" LIKE '%pattern\\%' ESCAPE '\\'"); + .hasValue("(\"c_varchar\") LIKE ('%pattern\\%') ESCAPE ('\\')"); } private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes)