Skip to content

Commit

Permalink
set backtick by DbType
Browse files Browse the repository at this point in the history
  • Loading branch information
whhe committed Dec 26, 2021
1 parent b54bea5 commit f24ab0a
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.slf4j.LoggerFactory;

import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.otter.canal.client.adapter.OuterAdapter;
import com.alibaba.otter.canal.client.adapter.rdb.config.ConfigLoader;
import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
Expand Down Expand Up @@ -73,6 +74,11 @@ public Map<String, MirrorDbConfig> getMirrorDbConfigCache() {
@Override
public void init(OuterAdapterConfig configuration, Properties envProperties) {
this.envProperties = envProperties;

// 从jdbc url获取db类型
Map<String, String> properties = configuration.getProperties();
String dbType = JdbcUtils.getDbType(properties.get("jdbc.url"), null);

Map<String, MappingConfig> rdbMappingTmp = ConfigLoader.load(envProperties);
// 过滤不匹配的key的配置
rdbMappingTmp.forEach((key, mappingConfig) -> {
Expand Down Expand Up @@ -112,7 +118,6 @@ public void init(OuterAdapterConfig configuration, Properties envProperties) {
}

// 初始化连接池
Map<String, String> properties = configuration.getProperties();
dataSource = new DruidDataSource();
dataSource.setDriverClassName(properties.get("jdbc.driverClassName"));
dataSource.setUrl(properties.get("jdbc.url"));
Expand All @@ -125,6 +130,8 @@ public void init(OuterAdapterConfig configuration, Properties envProperties) {
dataSource.setTimeBetweenEvictionRunsMillis(60000);
dataSource.setMinEvictableIdleTimeMillis(300000);
dataSource.setUseUnfairLock(true);
dataSource.setDbType(dbType);

// List<String> array = new ArrayList<>();
// array.add("set names utf8mb4;");
// dataSource.setConnectionInitSqls(array);
Expand Down Expand Up @@ -226,7 +233,7 @@ public EtlResult etl(String task, List<String> params) {
public Map<String, Object> count(String task) {
MappingConfig config = rdbMapping.get(task);
MappingConfig.DbMapping dbMapping = config.getDbMapping();
String sql = "SELECT COUNT(1) AS cnt FROM " + SyncUtil.getDbTableName(dbMapping);
String sql = "SELECT COUNT(1) AS cnt FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType());
Connection conn = null;
Map<String, Object> res = new LinkedHashMap<>();
try {
Expand All @@ -252,7 +259,7 @@ public Map<String, Object> count(String task) {
}
}
}
res.put("targetTable", SyncUtil.getDbTableName(dbMapping));
res.put("targetTable", SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()));

return res;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import javax.sql.DataSource;

import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig.DbMapping;
import com.alibaba.otter.canal.client.adapter.rdb.support.SyncUtil;
Expand Down Expand Up @@ -56,8 +57,11 @@ protected boolean executeSqlImport(DataSource srcDS, String sql, List<Object> va
DbMapping dbMapping = (DbMapping) mapping;
Map<String, String> columnsMap = new LinkedHashMap<>();
Map<String, Integer> columnType = new LinkedHashMap<>();
DruidDataSource dataSource = (DruidDataSource) srcDS;

Util.sqlRS(targetDS, "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping) + " LIMIT 1 ", rs -> {
Util.sqlRS(targetDS,
"SELECT * FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()) + " LIMIT 1 ",
rs -> {
try {

ResultSetMetaData rsd = rs.getMetaData();
Expand All @@ -83,7 +87,9 @@ protected boolean executeSqlImport(DataSource srcDS, String sql, List<Object> va
boolean completed = false;

StringBuilder insertSql = new StringBuilder();
insertSql.append("INSERT INTO ").append(SyncUtil.getDbTableName(dbMapping)).append(" (");
insertSql.append("INSERT INTO ")
.append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()))
.append(" (");
columnsMap
.forEach((targetColumnName, srcColumnName) -> insertSql.append(targetColumnName).append(","));

Expand All @@ -107,7 +113,7 @@ protected boolean executeSqlImport(DataSource srcDS, String sql, List<Object> va
// 删除数据
Map<String, Object> pkVal = new LinkedHashMap<>();
StringBuilder deleteSql = new StringBuilder(
"DELETE FROM " + SyncUtil.getDbTableName(dbMapping) + " WHERE ");
"DELETE FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()) + " WHERE ");
appendCondition(dbMapping, deleteSql, pkVal, rs);
try (PreparedStatement pstmt2 = connTarget.prepareStatement(deleteSql.toString())) {
int k = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
import java.util.List;
import java.util.Map;

import javax.sql.DataSource;

import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
import com.alibaba.otter.canal.client.adapter.rdb.config.MirrorDbConfig;
import com.alibaba.otter.canal.client.adapter.rdb.support.SingleDml;
import com.alibaba.otter.canal.client.adapter.rdb.support.SyncUtil;
import com.alibaba.otter.canal.client.adapter.support.Dml;

/**
Expand All @@ -31,10 +31,10 @@ public class RdbMirrorDbSyncService {
private static final Logger logger = LoggerFactory.getLogger(RdbMirrorDbSyncService.class);

private Map<String, MirrorDbConfig> mirrorDbConfigCache; // 镜像库配置
private DataSource dataSource;
private DruidDataSource dataSource;
private RdbSyncService rdbSyncService; // rdbSyncService代理

public RdbMirrorDbSyncService(Map<String, MirrorDbConfig> mirrorDbConfigCache, DataSource dataSource,
public RdbMirrorDbSyncService(Map<String, MirrorDbConfig> mirrorDbConfigCache, DruidDataSource dataSource,
Integer threads, Map<String, Map<String, Integer>> columnsTypeCache,
boolean skipDupException){
this.mirrorDbConfigCache = mirrorDbConfigCache;
Expand Down Expand Up @@ -153,7 +153,13 @@ private void initMappingConfig(String key, MappingConfig baseConfigMap, MirrorDb
*/
private void executeDdl(MirrorDbConfig mirrorDbConfig, Dml ddl) {
try (Connection conn = dataSource.getConnection(); Statement statement = conn.createStatement()) {
statement.execute(ddl.getSql());
// 替换反引号
String sql = ddl.getSql();
String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());
if (!"`".equals(backtick)) {
sql = sql.replaceAll("`", backtick);
}
statement.execute(sql);
// 移除对应配置
mirrorDbConfig.getTableConfig().remove(ddl.getTable());
if (logger.isTraceEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
import java.util.concurrent.Future;
import java.util.function.Function;

import javax.sql.DataSource;

import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
Expand All @@ -41,6 +40,7 @@ public class RdbSyncService {

private static final Logger logger = LoggerFactory.getLogger(RdbSyncService.class);

private DruidDataSource dataSource;
// 源库表字段类型缓存: instance.schema.table -> <columnName, jdbcType>
private Map<String, Map<String, Integer>> columnsTypeCache;

Expand All @@ -59,13 +59,14 @@ public Map<String, Map<String, Integer>> getColumnsTypeCache() {
return columnsTypeCache;
}

public RdbSyncService(DataSource dataSource, Integer threads, boolean skipDupException){
/**
 * Creates a sync service with a fresh (empty) column-type cache.
 *
 * @param dataSource       target Druid data source; its dbType is read elsewhere
 *                         in this class to pick the identifier quote character
 * @param threads          number of parallel sync threads
 * @param skipDupException whether duplicate-key errors should be skipped
 */
public RdbSyncService(DruidDataSource dataSource, Integer threads, boolean skipDupException){
this(dataSource, threads, new ConcurrentHashMap<>(), skipDupException);
}

@SuppressWarnings("unchecked")
public RdbSyncService(DataSource dataSource, Integer threads, Map<String, Map<String, Integer>> columnsTypeCache,
public RdbSyncService(DruidDataSource dataSource, Integer threads, Map<String, Map<String, Integer>> columnsTypeCache,
boolean skipDupException){
this.dataSource = dataSource;
this.columnsTypeCache = columnsTypeCache;
this.skipDupException = skipDupException;
try {
Expand Down Expand Up @@ -241,15 +242,15 @@ private void insert(BatchExecutor batchExecutor, MappingConfig config, SingleDml
}

DbMapping dbMapping = config.getDbMapping();

String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());
Map<String, String> columnsMap = SyncUtil.getColumnsMap(dbMapping, data);

StringBuilder insertSql = new StringBuilder();
insertSql.append("INSERT INTO ").append(SyncUtil.getDbTableName(dbMapping)).append(" (");
insertSql.append("INSERT INTO ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType())).append(" (");

columnsMap.forEach((targetColumnName, srcColumnName) -> insertSql.append("`")
columnsMap.forEach((targetColumnName, srcColumnName) -> insertSql.append(backtick)
.append(targetColumnName)
.append("`")
.append(backtick)
.append(","));
int len = insertSql.length();
insertSql.delete(len - 1, len).append(") VALUES (");
Expand Down Expand Up @@ -313,13 +314,13 @@ private void update(BatchExecutor batchExecutor, MappingConfig config, SingleDml
}

DbMapping dbMapping = config.getDbMapping();

String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());
Map<String, String> columnsMap = SyncUtil.getColumnsMap(dbMapping, data);

Map<String, Integer> ctype = getTargetColumnType(batchExecutor.getConn(), config);

StringBuilder updateSql = new StringBuilder();
updateSql.append("UPDATE ").append(SyncUtil.getDbTableName(dbMapping)).append(" SET ");
updateSql.append("UPDATE ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType())).append(" SET ");
List<Map<String, ?>> values = new ArrayList<>();
boolean hasMatched = false;
for (String srcColumnName : old.keySet()) {
Expand All @@ -332,7 +333,7 @@ private void update(BatchExecutor batchExecutor, MappingConfig config, SingleDml
if (!targetColumnNames.isEmpty()) {
hasMatched = true;
for (String targetColumnName : targetColumnNames) {
updateSql.append("`").append(targetColumnName).append("`").append("=?, ");
updateSql.append(backtick).append(targetColumnName).append(backtick).append("=?, ");
Integer type = ctype.get(Util.cleanColumn(targetColumnName).toLowerCase());
if (type == null) {
throw new RuntimeException("Target column: " + targetColumnName + " not matched");
Expand Down Expand Up @@ -369,11 +370,10 @@ private void delete(BatchExecutor batchExecutor, MappingConfig config, SingleDml
}

DbMapping dbMapping = config.getDbMapping();

Map<String, Integer> ctype = getTargetColumnType(batchExecutor.getConn(), config);

StringBuilder sql = new StringBuilder();
sql.append("DELETE FROM ").append(SyncUtil.getDbTableName(dbMapping)).append(" WHERE ");
sql.append("DELETE FROM ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType())).append(" WHERE ");

List<Map<String, ?>> values = new ArrayList<>();
// 拼接主键
Expand All @@ -392,7 +392,7 @@ private void delete(BatchExecutor batchExecutor, MappingConfig config, SingleDml
private void truncate(BatchExecutor batchExecutor, MappingConfig config) throws SQLException {
DbMapping dbMapping = config.getDbMapping();
StringBuilder sql = new StringBuilder();
sql.append("TRUNCATE TABLE ").append(SyncUtil.getDbTableName(dbMapping));
sql.append("TRUNCATE TABLE ").append(SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()));
batchExecutor.execute(sql.toString(), new ArrayList<>());
if (logger.isTraceEnabled()) {
logger.trace("Truncate target table, sql: {}", sql);
Expand All @@ -416,7 +416,7 @@ private Map<String, Integer> getTargetColumnType(Connection conn, MappingConfig
if (columnType == null) {
columnType = new LinkedHashMap<>();
final Map<String, Integer> columnTypeTmp = columnType;
String sql = "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping) + " WHERE 1=2";
String sql = "SELECT * FROM " + SyncUtil.getDbTableName(dbMapping, dataSource.getDbType()) + " WHERE 1=2";
Util.sqlRS(conn, sql, rs -> {
try {
ResultSetMetaData rsd = rs.getMetaData();
Expand Down Expand Up @@ -445,14 +445,16 @@ private void appendCondition(MappingConfig.DbMapping dbMapping, StringBuilder sq

private void appendCondition(MappingConfig.DbMapping dbMapping, StringBuilder sql, Map<String, Integer> ctype,
List<Map<String, ?>> values, Map<String, Object> d, Map<String, Object> o) {
String backtick = SyncUtil.getBacktickByDbType(dataSource.getDbType());

// 拼接主键
for (Map.Entry<String, String> entry : dbMapping.getTargetPk().entrySet()) {
String targetColumnName = entry.getKey();
String srcColumnName = entry.getValue();
if (srcColumnName == null) {
srcColumnName = Util.cleanColumn(targetColumnName);
}
sql.append("`").append(targetColumnName).append("`").append("=? AND ");
sql.append(backtick).append(targetColumnName).append(backtick).append("=? AND ");
Integer type = ctype.get(Util.cleanColumn(targetColumnName).toLowerCase());
if (type == null) {
throw new RuntimeException("Target column: " + targetColumnName + " not matched");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.alibaba.otter.canal.client.adapter.rdb.support;

import com.alibaba.druid.DbType;
import com.alibaba.otter.canal.client.adapter.rdb.config.MappingConfig;
import com.alibaba.otter.canal.client.adapter.support.Util;
import org.apache.commons.lang.StringUtils;
Expand Down Expand Up @@ -255,12 +256,36 @@ public static void setPStmt(int type, PreparedStatement pstmt, Object value, int
}
}

public static String getDbTableName(MappingConfig.DbMapping dbMapping) {
public static String getDbTableName(MappingConfig.DbMapping dbMapping, String dbType) {
String result = "";
String backtick = getBacktickByDbType(dbType);
if (StringUtils.isNotEmpty(dbMapping.getTargetDb())) {
result += ("`" + dbMapping.getTargetDb() + "`.");
result += (backtick + dbMapping.getTargetDb() + backtick + ".");
}
result += ("`" + dbMapping.getTargetTable() + "`");
result += (backtick + dbMapping.getTargetTable() + backtick);
return result;
}

/**
 * Returns the identifier quote character for the given database type.
 *
 * @param dbTypeName database type name (may be {@code null} or unrecognized)
 * @return {@code "`"} for MySQL, MariaDB and OceanBase; an empty string for
 *         every other (or unknown) database type
 */
public static String getBacktickByDbType(String dbTypeName) {
    DbType dbType = DbType.of(dbTypeName);
    // Only the MySQL family quotes identifiers with backticks; an
    // unrecognized name (DbType.of -> null) gets no quoting at all.
    boolean useBacktick = dbType == DbType.mysql || dbType == DbType.mariadb || dbType == DbType.oceanbase;
    return useBacktick ? "`" : "";
}
}

0 comments on commit f24ab0a

Please sign in to comment.