Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize: refactor ColumnUtils and EscapeHandler #5456

Merged
merged 11 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ protected TableRecords buildTableRecords(Map<String, List<Object>> pkValuesMap)
List<String> insertColumns = recognizer.getInsertColumns();
if (ONLY_CARE_UPDATE_COLUMNS && CollectionUtils.isNotEmpty(insertColumns)) {
Set<String> columns = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
columns.addAll(recognizer.getInsertColumnsIsSimplified());
columns.addAll(recognizer.getInsertColumnsUnEscape());
columns.addAll(pkColumnNameList);
for (String columnName : columns) {
selectSQLJoin.add(columnName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ protected TableRecords beforeImage() throws SQLException {
throw new NotSupportYetException("Multi update SQL with orderBy condition is not support yet !");
}

List<String> updateColumns = sqlUpdateRecognizer.getUpdateColumnsIsSimplified();
List<String> updateColumns = sqlUpdateRecognizer.getUpdateColumnsUnEscape();
updateColumnsSet.addAll(updateColumns);
if (noWhereCondition) {
continue;
Expand Down Expand Up @@ -155,7 +155,7 @@ private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage)
for (SQLRecognizer recognizer : sqlRecognizers) {
sqlRecognizer = recognizer;
SQLUpdateRecognizer sqlUpdateRecognizer = (SQLUpdateRecognizer) sqlRecognizer;
updateColumnsSet.addAll(sqlUpdateRecognizer.getUpdateColumnsIsSimplified());
updateColumnsSet.addAll(sqlUpdateRecognizer.getUpdateColumnsUnEscape());
}
StringBuilder prefix = new StringBuilder("SELECT ");
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ private String buildBeforeImageSQL(TableMeta tableMeta, ArrayList<List<Object>>
}
suffix.append(" FOR UPDATE");
StringJoiner selectSQLJoin = new StringJoiner(", ", prefix.toString(), suffix.toString());
List<String> needUpdateColumns = getNeedUpdateColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsIsSimplified());
List<String> needUpdateColumns = getNeedUpdateColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsUnEscape());
needUpdateColumns.forEach(selectSQLJoin::add);
return selectSQLJoin.toString();
}
Expand Down Expand Up @@ -119,7 +119,7 @@ private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage)
String suffix = " FROM " + getFromTableInSQL() + " WHERE " + whereSql;
StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix);
SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer;
List<String> needUpdateColumns = getNeedUpdateColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsIsSimplified());
List<String> needUpdateColumns = getNeedUpdateColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsUnEscape());
needUpdateColumns.forEach(selectSQLJoiner::add);
return selectSQLJoiner.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seata.sqlparser;
package io.seata.rm.datasource.sql.handler.mariadb;

import io.seata.common.loader.LoadLevel;
import io.seata.rm.datasource.sql.handler.mysql.MySQLEscapeHandler;
import io.seata.sqlparser.util.JdbcConstants;

/**
* The interface Keyword checker.
* The type Mariadb escape handler.
*
* @author Wu
* @author slievrly
*/
public interface KeywordChecker {
/**
* check whether given field name and table name use keywords
*
* @param fieldOrTableName the field or table name
* @return boolean
*/
boolean check(String fieldOrTableName);


/**
* check whether given field or table name use keywords. the method has database special logic.
* @param fieldOrTableName the field or table name
* @return true: need to escape. false: no need to escape.
*/
boolean checkEscape(String fieldOrTableName);

@LoadLevel(name = JdbcConstants.MARIADB)
public class MariadbEscapeHandler extends MySQLEscapeHandler {
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seata.rm.datasource.undo.mysql.keyword;
package io.seata.rm.datasource.sql.handler.mysql;

import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;

import io.seata.common.loader.LoadLevel;
import io.seata.sqlparser.KeywordChecker;
import io.seata.common.util.StringUtils;
import io.seata.sqlparser.EscapeHandler;
import io.seata.sqlparser.util.JdbcConstants;

/**
Expand All @@ -29,7 +30,7 @@
* @author xingfudeshi@gmail.com
*/
@LoadLevel(name = JdbcConstants.MYSQL)
public class MySQLKeywordChecker implements KeywordChecker {
public class MySQLEscapeHandler implements EscapeHandler {

private Set<String> keywordSet = Arrays.stream(MySQLKeyword.values()).map(MySQLKeyword::name).collect(Collectors.toSet());

Expand Down Expand Up @@ -1101,7 +1102,7 @@ private enum MySQLKeyword {


@Override
public boolean check(String fieldOrTableName) {
public boolean checkIfKeyWords(String fieldOrTableName) {
if (keywordSet.contains(fieldOrTableName)) {
return true;
}
Expand All @@ -1113,8 +1114,19 @@ public boolean check(String fieldOrTableName) {
}

@Override
public boolean checkEscape(String fieldOrTableName) {
return check(fieldOrTableName);
public boolean checkIfNeedEscape(String fieldOrTableName) {
if (StringUtils.isBlank(fieldOrTableName)) {
return false;
}
fieldOrTableName = fieldOrTableName.trim();
if (containsEscape(fieldOrTableName)) {
return false;
}
return checkIfKeyWords(fieldOrTableName);
}

@Override
public char getEscapeSymbol() {
return '`';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seata.rm.datasource.undo.oracle.keyword;
package io.seata.rm.datasource.sql.handler.oracle;

import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;

import io.seata.common.loader.LoadLevel;
import io.seata.sqlparser.KeywordChecker;
import io.seata.common.util.StringUtils;
import io.seata.sqlparser.EscapeHandler;
import io.seata.sqlparser.util.JdbcConstants;

/**
Expand All @@ -29,7 +30,7 @@
* @author ccg
*/
@LoadLevel(name = JdbcConstants.ORACLE)
public class OracleKeywordChecker implements KeywordChecker {
public class OracleEscapeHandler implements EscapeHandler {

private Set<String> keywordSet = Arrays.stream(OracleKeyword.values()).map(OracleKeyword::name).collect(Collectors.toSet());

Expand Down Expand Up @@ -488,7 +489,7 @@ private enum OracleKeyword {
}

@Override
public boolean check(String fieldOrTableName) {
public boolean checkIfKeyWords(String fieldOrTableName) {
if (keywordSet.contains(fieldOrTableName)) {
return true;
}
Expand All @@ -499,13 +500,33 @@ public boolean check(String fieldOrTableName) {

}


@Override
public boolean checkEscape(String fieldOrTableName) {
boolean check = check(fieldOrTableName);
public boolean checkIfNeedEscape(String fieldOrTableName) {
if (StringUtils.isBlank(fieldOrTableName)) {
return false;
}
fieldOrTableName = fieldOrTableName.trim();
if (containsEscape(fieldOrTableName)) {
return false;
}
boolean isKeyWord = checkIfKeyWords(fieldOrTableName);
if (isKeyWord) {
return true;
}
// oracle
// we are recommend table name and column name must uppercase.
// if exists full uppercase, the table name or column name does't bundle escape symbol.
if (!check && isUppercase(fieldOrTableName)) {
//create\read table TABLE "table" "TABLE"
slievrly marked this conversation as resolved.
Show resolved Hide resolved
//
//table √ √ × √
//
//TABLE √ √ × √
//
//"table" × × √ ×
//
//"TABLE" √ √ × √
if (isUppercase(fieldOrTableName)) {
return false;
}
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seata.rm.datasource.undo.postgresql.keyword;
package io.seata.rm.datasource.sql.handler.postgresql;

import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;

import io.seata.common.loader.LoadLevel;
import io.seata.sqlparser.KeywordChecker;
import io.seata.common.util.StringUtils;
import io.seata.sqlparser.EscapeHandler;
import io.seata.sqlparser.util.JdbcConstants;

/**
Expand All @@ -29,10 +30,10 @@
* @author japsercloud
*/
@LoadLevel(name = JdbcConstants.POSTGRESQL)
public class PostgresqlKeywordChecker implements KeywordChecker {
public class PostgresqlEscapeHandler implements EscapeHandler {

private Set<String> keywordSet = Arrays.stream(PostgresqlKeywordChecker.PostgresqlKeyword.values())
.map(PostgresqlKeywordChecker.PostgresqlKeyword::name).collect(Collectors.toSet());
private Set<String> keywordSet = Arrays.stream(PostgresqlEscapeHandler.PostgresqlKeyword.values())
.map(PostgresqlEscapeHandler.PostgresqlKeyword::name).collect(Collectors.toSet());

/**
* postgresql keyword
Expand Down Expand Up @@ -357,7 +358,7 @@ private enum PostgresqlKeyword {
}

@Override
public boolean check(String fieldOrTableName) {
public boolean checkIfKeyWords(String fieldOrTableName) {
if (keywordSet.contains(fieldOrTableName)) {
return true;
}
Expand All @@ -369,8 +370,15 @@ public boolean check(String fieldOrTableName) {
}

@Override
public boolean checkEscape(String fieldOrTableName) {
boolean check = check(fieldOrTableName);
public boolean checkIfNeedEscape(String fieldOrTableName) {
if (StringUtils.isBlank(fieldOrTableName)) {
return false;
}
fieldOrTableName = fieldOrTableName.trim();
if (containsEscape(fieldOrTableName)) {
return false;
}
boolean check = checkIfKeyWords(fieldOrTableName);
if (!check && !containsUppercase(fieldOrTableName)) {
// postgresql
// we are recommend table name and column name must lowercase.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
io.seata.rm.datasource.sql.handler.oracle.OracleEscapeHandler
io.seata.rm.datasource.sql.handler.mysql.MySQLEscapeHandler
io.seata.rm.datasource.sql.handler.postgresql.PostgresqlEscapeHandler
io.seata.rm.datasource.sql.handler.mariadb.MariadbEscapeHandler

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -33,51 +33,51 @@ public void test_delEscape_byEscape() throws Exception {
List<String> cols = new ArrayList<>();
cols.add("`id`");
cols.add("name");
cols = ColumnUtils.delEscape(cols, ColumnUtils.Escape.MYSQL);
cols = ColumnUtils.delEscape(cols, JdbcConstants.MYSQL);
Assertions.assertEquals("id", cols.get(0));
Assertions.assertEquals("name", cols.get(1));

List<String> cols2 = new ArrayList<>();
cols2.add("\"id\"");
cols2 = ColumnUtils.delEscape(cols2, ColumnUtils.Escape.STANDARD);
cols2 = ColumnUtils.delEscape(cols2, JdbcConstants.ORACLE);
Assertions.assertEquals("id", cols2.get(0));

List<String> cols3 = new ArrayList<>();
cols3.add("\"scheme\".\"id\"");
cols3 = ColumnUtils.delEscape(cols3, ColumnUtils.Escape.STANDARD);
cols3 = ColumnUtils.delEscape(cols3, JdbcConstants.ORACLE);
Assertions.assertEquals("scheme.id", cols3.get(0));

List<String> cols4 = new ArrayList<>();
cols4.add("`scheme`.`id`");
cols4 = ColumnUtils.delEscape(cols4, ColumnUtils.Escape.MYSQL);
cols4 = ColumnUtils.delEscape(cols4, JdbcConstants.MYSQL);
Assertions.assertEquals("scheme.id", cols4.get(0));

List<String> cols5 = new ArrayList<>();
cols5.add("\"scheme\".id");
cols5 = ColumnUtils.delEscape(cols5, ColumnUtils.Escape.STANDARD);
cols5 = ColumnUtils.delEscape(cols5, JdbcConstants.ORACLE);
Assertions.assertEquals("scheme.id", cols5.get(0));

List<String> cols6 = new ArrayList<>();
cols6.add("\"tab\"\"le\"");
cols6 = ColumnUtils.delEscape(cols6, ColumnUtils.Escape.STANDARD);
cols6 = ColumnUtils.delEscape(cols6, JdbcConstants.ORACLE);
Assertions.assertEquals("tab\"\"le", cols6.get(0));

List<String> cols7 = new ArrayList<>();
cols7.add("scheme.\"id\"");
cols7 = ColumnUtils.delEscape(cols7, ColumnUtils.Escape.STANDARD);
cols7 = ColumnUtils.delEscape(cols7, JdbcConstants.ORACLE);
Assertions.assertEquals("scheme.id", cols7.get(0));

List<String> cols8 = new ArrayList<>();
cols8.add("`scheme`.id");
cols8 = ColumnUtils.delEscape(cols8, ColumnUtils.Escape.MYSQL);
Assertions.assertEquals("scheme.id", cols8.get(0));
cols8 = ColumnUtils.delEscape(cols8, JdbcConstants.ORACLE);
Assertions.assertEquals("`scheme`.id", cols8.get(0));

List<String> cols9 = new ArrayList<>();
cols9.add("scheme.`id`");
cols9 = ColumnUtils.delEscape(cols9, ColumnUtils.Escape.MYSQL);
cols9 = ColumnUtils.delEscape(cols9, JdbcConstants.MYSQL);
Assertions.assertEquals("scheme.id", cols9.get(0));

Assertions.assertNull(ColumnUtils.delEscape((String) null, ColumnUtils.Escape.MYSQL));
Assertions.assertNull(ColumnUtils.delEscape((String) null, JdbcConstants.MYSQL));
}

@Test
Expand Down
Loading