diff --git a/core/src/test/scala/com/pingcap/tispark/accumulator/AccumulatorSuite.scala b/core/src/test/scala/com/pingcap/tispark/accumulator/AccumulatorSuite.scala new file mode 100644 index 0000000000..83cc63441f --- /dev/null +++ b/core/src/test/scala/com/pingcap/tispark/accumulator/AccumulatorSuite.scala @@ -0,0 +1,72 @@ +/* + * + * Copyright 2022 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.pingcap.tispark.accumulator; + +import org.apache.log4j.spi.LoggingEvent +import org.apache.log4j.{AppenderSkeleton, Logger} +import org.apache.spark.sql.BaseTiSparkTest + +import java.util +import java.util.stream.Collectors; + +class AccumulatorSuite extends BaseTiSparkTest { + test("cacheInvalidateCallback does not work") { + val listLogAppender = new ListLogAppender + val logger = Logger.getRootLogger + logger.addAppender(listLogAppender) + try { + tidbStmt.execute("DROP TABLE IF EXISTS `t1`") + tidbStmt.execute(""" + |CREATE TABLE `t1` ( + |`a` BIGINT(20) NOT NULL, + |`b` varchar(255) NOT NULL, + |`c` varchar(255) DEFAULT NULL + |) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin + """.stripMargin) + spark.sql("SELECT * FROM t1").show() + tidbStmt.execute( + "SPLIT TABLE t1 BETWEEN (-9223372036854775808) AND (-8223372036854775808) REGIONS 300") + spark.sql("SELECT * FROM t1").show() + } finally { + logger.removeAppender(listLogAppender) + } + val cacheInvalidateListenerLog = listLogAppender.getLog + .stream() + .filter(e => + e.getMessage.toString.contains( + "Failed to send notification back to driver since CacheInvalidateCallBack is null in executor node")) + .collect(Collectors.toList[LoggingEvent]) + assert(cacheInvalidateListenerLog.size() == 0) + } + + class ListLogAppender extends AppenderSkeleton { + + final private val log = new util.ArrayList[LoggingEvent]() + + override def requiresLayout = false + + override protected def append(loggingEvent: LoggingEvent): Unit = { + log.add(loggingEvent) + } + + override def close(): Unit = {} + + def getLog = new util.ArrayList[LoggingEvent](log) + } +} diff --git a/core/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanTestSuite.scala index fc2c0aa75c..a6b61aee87 100644 --- a/core/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanTestSuite.scala @@ -61,37 +61,37 @@ class LogicalPlanTestSuite extends BasePlanTest { refreshConnections() val df = spark.sql(""" - |select t1.*, ( - | select count(*) - | from test2 - | where id > 1 - |), t1.c1, t2.c1, t3.*, t4.c3 - |from ( - | select id, c1, c2 - | from test1) t1 - |left join ( - | select id, c1, c2, c1 + coalesce(c2 % 2) as c3 - | from test2 where c1 + c2 > 3) t2 - |on t1.id = t2.id - |left join ( - | select max(id) as id, min(c1) + c2 as c1, c2, count(*) as c3 - | from test3 - | where c2 <= 3 and exists ( - | select * from ( - | select id as c1 from test3) - | where ( - | select max(id) from test1) = 4) - | group by c2) t3 - |on t1.id = t3.id - |left join ( - | select max(id) as id, min(c1) as c1, max(c1) as c1, count(*) as c2, c2 as c3 - | from test3 - | where id not in ( - | select id - | from test1 - | where c2 > 2) - | group by c2) t4 - |on t1.id = t4.id + |select t1.*, ( + | select count(*) + | from test2 + | where id > 1 + |), t1.c1, t2.c1, t3.*, t4.c3 + |from ( + | select id, c1, c2 + | from test1) t1 + |left join ( + | select id, c1, c2, c1 + coalesce(c2 % 2) as c3 + | from test2 where c1 + c2 > 3) t2 + |on t1.id = t2.id + |left join ( + | select max(id) as id, min(c1) + c2 as c1, c2, count(*) as c3 + | from test3 + | where c2 <= 3 and exists ( + | select * from ( + | select id as c1 from test3) + | where ( + | select max(id) from test1) = 4) + | group by c2) t3 + |on t1.id = t3.id + |left join ( + | select max(id) as id, min(c1) as c1, max(c1) as c1, count(*) as c2, c2 as c3 + | from test3 + | where id not in ( + | select id + | from test1 + | where c2 > 2) + | group by c2) t4 + |on t1.id = t4.id """.stripMargin) var v: TiTimestamp = null diff --git a/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java b/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java index 6582151469..d9b550608a 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java @@ -296,6 +296,8 @@ public ExecutorService getThreadPoolForDeleteRange() { */ public void injectCallBackFunc(Function callBackFunc) { this.cacheInvalidateCallback = callBackFunc; + RegionManager manager = this.getRegionManager(); + manager.setCacheInvalidateCallback(callBackFunc); } /** diff --git a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java index e24f5e06a7..086e9e448f 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java @@ -49,7 +49,7 @@ public class RegionManager { // https://github.com/pingcap/tispark/issues/1170 private final RegionCache cache; - private final Function cacheInvalidateCallback; + private Function cacheInvalidateCallback; private AtomicInteger tiflashStoreIndex = new AtomicInteger(0); @@ -66,6 +66,11 @@ public RegionManager(ReadOnlyPDClient pdClient) { this.cacheInvalidateCallback = null; } + public synchronized void setCacheInvalidateCallback( + Function cacheInvalidateCallback) { + this.cacheInvalidateCallback = cacheInvalidateCallback; + } + public Function getCacheInvalidateCallback() { return cacheInvalidateCallback; }