diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index f51b0ef062080..9113fb350f637 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -181,6 +181,27 @@ def test_insert_into(self):
         df.write.mode("overwrite").insertInto("test_table", False)
         self.assertEqual(6, self.spark.sql("select * from test_table").count())
 
+    def test_cached_table(self):
+        with self.table("test_cached_table_1"):
+            self.spark.range(10).withColumn(
+                "value_1",
+                lit(1),
+            ).write.saveAsTable("test_cached_table_1")
+
+            with self.table("test_cached_table_2"):
+                self.spark.range(10).withColumnRenamed("id", "index").withColumn(
+                    "value_2", lit(2)
+                ).write.saveAsTable("test_cached_table_2")
+
+                df1 = self.spark.read.table("test_cached_table_1")
+                df2 = self.spark.read.table("test_cached_table_2")
+                df3 = self.spark.read.table("test_cached_table_1")
+
+                join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
+                join2 = df3.join(join1, how="left", on=join1.index == df3.id)
+
+                self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"])
+
 
 class ReadwriterV2TestsMixin:
     def test_api(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 74061f2b8f214..f3e82f6c71e34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1257,16 +1257,29 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       expandIdentifier(u.multipartIdentifier) match {
         case CatalogAndIdentifier(catalog, ident) =>
           val key = ((catalog.name +: ident.namespace :+ ident.name).toSeq, timeTravelSpec)
-          AnalysisContext.get.relationCache.get(key).map(_.transform {
-            case multi: MultiInstanceRelation =>
-              val newRelation = multi.newInstance()
-              newRelation.copyTagsFrom(multi)
-              newRelation
-          }).orElse {
+          AnalysisContext.get.relationCache.get(key).map { cache =>
+            val cachedRelation = cache.transform {
+              case multi: MultiInstanceRelation =>
+                val newRelation = multi.newInstance()
+                newRelation.copyTagsFrom(multi)
+                newRelation
+            }
+            u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+              val cachedConnectRelation = cachedRelation.clone()
+              cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+              cachedConnectRelation
+            }.getOrElse(cachedRelation)
+          }.orElse {
             val table = CatalogV2Util.loadTable(catalog, ident, timeTravelSpec)
             val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
             loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
-            loaded
+            u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+              loaded.map { loadedRelation =>
+                val loadedConnectRelation = loadedRelation.clone()
+                loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+                loadedConnectRelation
+              }
+            }.getOrElse(loaded)
           }
         case _ => None
       }
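
For context, a minimal Spark Connect client sketch of the scenario the new test exercises; the remote address and table names (sc://localhost:15002, t1, t2) are illustrative placeholders, not part of this change:

# Minimal sketch, assuming a Spark Connect server is reachable at the
# placeholder address below; table names t1/t2 are illustrative.
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

spark.range(10).withColumn("value_1", lit(1)).write.saveAsTable("t1")
spark.range(10).withColumnRenamed("id", "index").withColumn(
    "value_2", lit(2)
).write.saveAsTable("t2")

df1 = spark.read.table("t1")
df2 = spark.read.table("t2")
df3 = spark.read.table("t1")  # same table read twice: while analyzing one
                              # query, the second lookup is served from the
                              # analyzer's per-query relation cache

# Spark Connect resolves column references such as df2.index via the plan id
# tagged on each client-side plan. Before this change, a relation served from
# the relation cache lost that PLAN_ID_TAG, so references against the cached
# read (df3.id, join1.index) could fail to resolve correctly.
join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
join2 = df3.join(join1, how="left", on=join1.index == df3.id)

assert join2.columns == ["id", "value_1", "index", "value_2"]

The Analyzer side clones the cached (or freshly loaded) relation before tagging it with the UnresolvedRelation's PLAN_ID_TAG, so the plan id of one lookup does not leak onto the shared cached plan used by other lookups.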