diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index 6d75abc3d3b5..4b751ad2231e 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -2319,6 +2319,18 @@ folly::dynamic PlanNode::serialize() const { return obj; } +const std::vector& QueryTraceScanNode::sources() const { + return kEmptySources; +} + +std::string QueryTraceScanNode::traceDir() const { + return traceDir_; +} + +void QueryTraceScanNode::addDetails(std::stringstream& stream) const { + stream << "Trace dir: " << traceDir_; +} + folly::dynamic FilterNode::serialize() const { auto obj = PlanNode::serialize(); obj["filter"] = filter_->serialize(); diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index 31b7b65bfdd2..1b168b56304f 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -17,6 +17,8 @@ #include +#include + #include "velox/connectors/Connector.h" #include "velox/core/Expressions.h" #include "velox/core/QueryConfig.h" @@ -312,6 +314,39 @@ class ArrowStreamNode : public PlanNode { std::shared_ptr arrowStream_; }; +class QueryTraceScanNode final : public PlanNode { + public: + QueryTraceScanNode( + const PlanNodeId& id, + const std::string& traceDir, + const RowTypePtr& outputType) + : PlanNode(id), traceDir_(traceDir), outputType_(outputType) {} + + const RowTypePtr& outputType() const override { + return outputType_; + } + + const std::vector& sources() const override; + + std::string_view name() const override { + return "QueryReplayScan"; + } + + folly::dynamic serialize() const override { + VELOX_UNSUPPORTED("QueryReplayScanNode is not serializable"); + return nullptr; + } + + std::string traceDir() const; + + private: + void addDetails(std::stringstream& stream) const override; + + // Directory of traced data, which is $traceRoot/$taskId/$nodeId. + const std::string traceDir_; + const RowTypePtr outputType_; +}; + class FilterNode : public PlanNode { public: FilterNode(const PlanNodeId& id, TypedExprPtr filter, PlanNodePtr source) diff --git a/velox/exec/LocalPlanner.cpp b/velox/exec/LocalPlanner.cpp index bf99d78e4cf0..283bb007ae54 100644 --- a/velox/exec/LocalPlanner.cpp +++ b/velox/exec/LocalPlanner.cpp @@ -45,6 +45,7 @@ #include "velox/exec/Unnest.h" #include "velox/exec/Values.h" #include "velox/exec/Window.h" +#include "velox/exec/trace/QueryTraceScan.h" namespace facebook::velox::exec { @@ -587,6 +588,12 @@ std::shared_ptr DriverFactory::createDriver( assignUniqueIdNode, assignUniqueIdNode->taskUniqueId(), assignUniqueIdNode->uniqueIdCounter())); + } else if ( + const auto queryReplayScanNode = + std::dynamic_pointer_cast( + planNode)) { + operators.push_back(std::make_unique( + id, ctx.get(), queryReplayScanNode)); } else { std::unique_ptr extended; if (planNode->requiresExchangeClient()) { diff --git a/velox/exec/tests/utils/AssertQueryBuilder.cpp b/velox/exec/tests/utils/AssertQueryBuilder.cpp index b7792d42a43e..554ee2bb3483 100644 --- a/velox/exec/tests/utils/AssertQueryBuilder.cpp +++ b/velox/exec/tests/utils/AssertQueryBuilder.cpp @@ -96,6 +96,18 @@ AssertQueryBuilder& AssertQueryBuilder::connectorSessionProperty( return *this; } +AssertQueryBuilder& AssertQueryBuilder::connectorSessionProperties( + const std::unordered_map< + std::string, + std::unordered_map>& properties) { + for (const auto& [connectorId, values] : properties) { + for (const auto& [key, value] : values) { + connectorSessionProperty(connectorId, key, value); + } + } + return *this; +} + AssertQueryBuilder& AssertQueryBuilder::split(Split split) { this->split(getOnlyLeafPlanNodeId(params_.planNode), std::move(split)); return *this; diff --git a/velox/exec/tests/utils/AssertQueryBuilder.h b/velox/exec/tests/utils/AssertQueryBuilder.h index 257fe624a7f8..4d5d4299d177 100644 --- a/velox/exec/tests/utils/AssertQueryBuilder.h +++ b/velox/exec/tests/utils/AssertQueryBuilder.h @@ -65,6 +65,11 @@ class AssertQueryBuilder { const std::string& key, const std::string& value); + AssertQueryBuilder& connectorSessionProperties( + const std::unordered_map< + std::string, + std::unordered_map>& properties); + // Methods to add splits. /// Add a single split for the specified plan node. diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index ae5cd8b16a22..31c3a5d39cba 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -220,6 +220,14 @@ PlanBuilder& PlanBuilder::values( return *this; } +PlanBuilder& PlanBuilder::traceScan( + const std::string& traceNodeDir, + const RowTypePtr& outputType) { + planNode_ = std::make_shared( + nextPlanNodeId(), traceNodeDir, outputType); + return *this; +} + PlanBuilder& PlanBuilder::exchange(const RowTypePtr& outputType) { VELOX_CHECK_NULL(planNode_, "Exchange must be the source node"); planNode_ = diff --git a/velox/exec/tests/utils/PlanBuilder.h b/velox/exec/tests/utils/PlanBuilder.h index 0b0f2eb180de..f344899e4639 100644 --- a/velox/exec/tests/utils/PlanBuilder.h +++ b/velox/exec/tests/utils/PlanBuilder.h @@ -308,6 +308,14 @@ class PlanBuilder { bool parallelizable = false, size_t repeatTimes = 1); + /// Adds a QueryReplayNode for query tracing. + /// + /// @param traceNodeDir The trace directory for a given plan node. + /// @param outputType The type of the tracing data. + PlanBuilder& traceScan( + const std::string& traceNodeDir, + const RowTypePtr& outputType); + /// Add an ExchangeNode. /// /// Use capturePlanNodeId method to capture the node ID needed for adding diff --git a/velox/exec/trace/CMakeLists.txt b/velox/exec/trace/CMakeLists.txt index 555c0e270dd1..13404d36f3c1 100644 --- a/velox/exec/trace/CMakeLists.txt +++ b/velox/exec/trace/CMakeLists.txt @@ -16,7 +16,9 @@ velox_add_library( velox_query_trace_exec QueryMetadataWriter.cpp QueryTraceConfig.cpp + QueryDataReader.cpp QueryDataWriter.cpp + QueryTraceScan.cpp QueryTraceUtil.cpp) velox_link_libraries( diff --git a/velox/exec/trace/QueryDataReader.cpp b/velox/exec/trace/QueryDataReader.cpp index f0175fe11b36..4585c0dd05b8 100644 --- a/velox/exec/trace/QueryDataReader.cpp +++ b/velox/exec/trace/QueryDataReader.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include + #include "velox/exec/trace/QueryDataReader.h" #include "velox/common/file/File.h" @@ -21,11 +23,14 @@ namespace facebook::velox::exec::trace { -QueryDataReader::QueryDataReader(std::string path, memory::MemoryPool* pool) - : path_(std::move(path)), - fs_(filesystems::getFileSystem(path_, nullptr)), +QueryDataReader::QueryDataReader( + std::string traceDir, + RowTypePtr dataType, + memory::MemoryPool* pool) + : traceDir_(std::move(traceDir)), + fs_(filesystems::getFileSystem(traceDir_, nullptr)), + dataType_(std::move(dataType)), pool_(pool), - dataType_(getTraceDataType()), dataStream_(getDataInputStream()) { VELOX_CHECK_NOT_NULL(dataType_); VELOX_CHECK_NOT_NULL(dataStream_); @@ -42,19 +47,10 @@ bool QueryDataReader::read(RowVectorPtr& batch) const { return true; } -RowTypePtr QueryDataReader::getTraceDataType() const { - const auto summaryFile = fs_->openFileForRead( - fmt::format("{}/{}", path_, QueryTraceTraits::kDataSummaryFileName)); - const auto summary = summaryFile->pread(0, summaryFile->size()); - VELOX_USER_CHECK(!summary.empty()); - folly::dynamic obj = folly::parseJson(summary); - return ISerializable::deserialize(obj["rowType"]); -} - std::unique_ptr QueryDataReader::getDataInputStream() const { auto dataFile = fs_->openFileForRead( - fmt::format("{}/{}", path_, QueryTraceTraits::kDataFileName)); + fmt::format("{}/{}", traceDir_, QueryTraceTraits::kDataFileName)); // TODO: Make the buffer size configurable. return std::make_unique( std::move(dataFile), 1 << 20, pool_); diff --git a/velox/exec/trace/QueryDataReader.h b/velox/exec/trace/QueryDataReader.h index ad61bde5886f..b5e6d24e011a 100644 --- a/velox/exec/trace/QueryDataReader.h +++ b/velox/exec/trace/QueryDataReader.h @@ -27,25 +27,26 @@ namespace facebook::velox::exec::trace { class QueryDataReader { public: - explicit QueryDataReader(std::string path, memory::MemoryPool* pool); + explicit QueryDataReader( + std::string traceDir, + RowTypePtr dataType, + memory::MemoryPool* pool); /// Reads from 'dataStream_' and deserializes to 'batch'. Returns false if /// reaches to end of the stream and 'batch' is set to nullptr. bool read(RowVectorPtr& batch) const; private: - RowTypePtr getTraceDataType() const; - std::unique_ptr getDataInputStream() const; - const std::string path_; + const std::string traceDir_; const serializer::presto::PrestoVectorSerde::PrestoOptions readOptions_{ true, common::CompressionKind_ZSTD, // TODO: Use trace config. /*nullsFirst=*/true}; const std::shared_ptr fs_; - memory::MemoryPool* const pool_; const RowTypePtr dataType_; + memory::MemoryPool* const pool_; const std::unique_ptr dataStream_; }; } // namespace facebook::velox::exec::trace diff --git a/velox/exec/trace/QueryTraceScan.cpp b/velox/exec/trace/QueryTraceScan.cpp new file mode 100644 index 000000000000..1718cb3f7fd2 --- /dev/null +++ b/velox/exec/trace/QueryTraceScan.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/trace/QueryTraceScan.h" + +#include "QueryTraceUtil.h" + +namespace facebook::velox::exec::trace { + +QueryTraceScan::QueryTraceScan( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& queryTraceScanNode) + : SourceOperator( + driverCtx, + queryTraceScanNode->outputType(), + operatorId, + queryTraceScanNode->id(), + "QueryReplayScan") { + const auto dataDir = getDataDir( + queryTraceScanNode->traceDir(), + driverCtx->pipelineId, + driverCtx->driverId); + traceReader_ = std::make_unique( + dataDir, + queryTraceScanNode->outputType(), + memory::MemoryManager::getInstance()->tracePool()); +} + +RowVectorPtr QueryTraceScan::getOutput() { + RowVectorPtr batch; + if (traceReader_->read(batch)) { + return batch; + } + finished_ = true; + return nullptr; +} + +bool QueryTraceScan::isFinished() { + return finished_; +} + +} // namespace facebook::velox::exec::trace diff --git a/velox/exec/trace/QueryTraceScan.h b/velox/exec/trace/QueryTraceScan.h new file mode 100644 index 000000000000..6c0f25c8d314 --- /dev/null +++ b/velox/exec/trace/QueryTraceScan.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/exec/Operator.h" +#include "velox/exec/trace/QueryDataReader.h" + +namespace facebook::velox::exec::trace { +/// This is a scan operator for query replay. It uses traced data from a +/// specific directory path, which is +/// $traceRoot/$taskId/$nodeId/$pipelineId/$driverId. +/// +/// A plan node can be split into multiple pipelines, and each pipeline can be +/// divided into multiple operators. Each operator corresponds to a driver, +/// which is a thread of execution. Pipeline IDs and driver IDs are sequential +/// numbers starting from zero. +/// +/// For a single plan node, there can be multiple traced data files. To find the +/// right input data file for replaying, we need to use both the pipeline ID and +/// driver ID. +/// +/// The trace data directory up to the $nodeId, which is $root/$taskId/$nodeId. +/// It can be found from the QueryReplayScanNode. However the pipeline ID and +/// driver ID are only known during operator creation, so we need to figure out +/// the input traced data file and the output type dynamically. +class QueryTraceScan final : public SourceOperator { + public: + QueryTraceScan( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& + queryTraceScanNode); + + RowVectorPtr getOutput() override; + + BlockingReason isBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + bool isFinished() override; + + private: + std::unique_ptr traceReader_; + bool finished_{false}; +}; + +} // namespace facebook::velox::exec::trace diff --git a/velox/exec/trace/QueryTraceUtil.cpp b/velox/exec/trace/QueryTraceUtil.cpp index a0ed04fe03c9..19bec09fd6e2 100644 --- a/velox/exec/trace/QueryTraceUtil.cpp +++ b/velox/exec/trace/QueryTraceUtil.cpp @@ -16,10 +16,13 @@ #include "velox/exec/trace/QueryTraceUtil.h" +#include + +#include "QueryTraceTraits.h" + #include #include "velox/common/base/Exceptions.h" -#include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" namespace facebook::velox::exec::trace { @@ -75,4 +78,36 @@ folly::dynamic getMetadata( } } +RowTypePtr getDataType( + const core::PlanNodePtr& tracedPlan, + const std::string& tracedNodeId, + size_t sourceIndex) { + const auto* traceNode = core::PlanNode::findFirstNode( + tracedPlan.get(), [&tracedNodeId](const core::PlanNode* node) { + return node->id() == tracedNodeId; + }); + VELOX_CHECK_NOT_NULL( + traceNode, + "traced node id {} not found in the traced plan", + tracedNodeId); + return traceNode->sources().at(sourceIndex)->outputType(); +} + +uint8_t getNumDrivers( + const std::string& rootDir, + const std::string& taskId, + const std::string& nodeId, + int32_t pipelineId, + const std::shared_ptr& fs) { + const auto traceDir = + fmt::format("{}/{}/{}/{}", rootDir, taskId, nodeId, pipelineId); + const auto driverDirs = fs->list(traceDir); + return driverDirs.size(); +} + +std::string +getDataDir(const std::string& traceDir, int pipelineId, int driverId) { + return fmt::format("{}/{}/{}/data", traceDir, pipelineId, driverId); +} + } // namespace facebook::velox::exec::trace diff --git a/velox/exec/trace/QueryTraceUtil.h b/velox/exec/trace/QueryTraceUtil.h index e006b2fdb18f..83ffc3d7f719 100644 --- a/velox/exec/trace/QueryTraceUtil.h +++ b/velox/exec/trace/QueryTraceUtil.h @@ -19,7 +19,9 @@ #include #include #include "velox/common/file/FileSystems.h" +#include "velox/type/Type.h" +#include #include namespace facebook::velox::exec::trace { @@ -27,14 +29,47 @@ namespace facebook::velox::exec::trace { /// Creates a directory to store the query trace metdata and data. void createTraceDirectory(const std::string& traceDir); +/// Extracts the input data type for the trace scan operator. The function first +/// uses the traced node id to find traced operator's plan node from the traced +/// plan fragment. Then it uses the specified source node index to find the +/// output data type from its source node plans as the input data type of the +/// traced plan node. +/// +/// For hash join plan node, there are two source nodes, the output data type +/// of the first node is the input data type of the 'HashProbe' operator, and +/// the output data type of the second one is the input data type of the +/// 'HashBuild' operator. +/// +/// @param tracedPlan The root node of the trace plan fragment. +/// @param tracedNodeId The node id of the trace node. +/// @param sourceIndex The source index of the specific traced operator. +RowTypePtr getDataType( + const core::PlanNodePtr& tracedPlan, + const std::string& tracedNodeId, + size_t sourceIndex = 0); + +/// Extracts the number of drivers by listing the number of sub-directors under +/// the trace directory for a given pipeline. +uint8_t getNumDrivers( + const std::string& rootDir, + const std::string& taskId, + const std::string& nodeId, + int32_t pipelineId, + const std::shared_ptr& fs); + /// Extracts task ids of the query tracing by listing the trace directory. std::vector getTaskIds( const std::string& traceDir, const std::shared_ptr& fs); -/// Gets the metadata from the given task directory which includes query plan, +/// Gets the metadata from a given task metadata file which includes query plan, /// configs and connector properties. folly::dynamic getMetadata( - const std::string& traceTaskDir, + const std::string& metadataFile, const std::shared_ptr& fs); + +/// Gets the traced data directory. 'traceaDir' is the trace directory for a +/// given plan node, which is $traceRoot/$taskId/$nodeId. +std::string +getDataDir(const std::string& traceDir, int pipelineId, int driverId); } // namespace facebook::velox::exec::trace diff --git a/velox/exec/trace/test/QueryTraceTest.cpp b/velox/exec/trace/test/QueryTraceTest.cpp index 5c3122aefc06..4b905ab3f21c 100644 --- a/velox/exec/trace/test/QueryTraceTest.cpp +++ b/velox/exec/trace/test/QueryTraceTest.cpp @@ -158,7 +158,7 @@ TEST_F(QueryTracerTest, traceData) { continue; } - const auto reader = QueryDataReader(outputDir->getPath(), pool()); + const auto reader = QueryDataReader(outputDir->getPath(), rowType, pool()); RowVectorPtr actual; size_t numOutputVectors{0}; while (reader.read(actual)) { @@ -502,7 +502,7 @@ TEST_F(QueryTracerTest, traceTableWriter) { obj[QueryTraceTraits::kTraceLimitExceededKey].asBool(), testData.limitExceeded); - const auto reader = trace::QueryDataReader(dataDir, pool()); + const auto reader = trace::QueryDataReader(dataDir, rowType, pool()); RowVectorPtr actual; size_t numOutputVectors{0}; while (reader.read(actual)) { diff --git a/velox/tool/trace/CMakeLists.txt b/velox/tool/trace/CMakeLists.txt index 643376c7c942..aad2d45bc9b8 100644 --- a/velox/tool/trace/CMakeLists.txt +++ b/velox/tool/trace/CMakeLists.txt @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(query_trace_replayer_base QueryTraceReplayer.cpp) +velox_add_library(velox_query_trace_replayer_base OperatorReplayerBase.cpp + TableWriterReplayer.cpp) velox_link_libraries( - query_trace_replayer_base + velox_query_trace_replayer_base velox_query_trace_retrieve velox_aggregates velox_type @@ -28,7 +29,13 @@ velox_link_libraries( glog::glog gflags::gflags) -add_executable(query_replayer QueryReplayer.cpp) +add_executable(velox_query_replayer QueryReplayer.cpp) target_link_libraries( - query_replayer query_trace_replayer_base) + velox_query_replayer + velox_query_trace_replayer_base + velox_exec + velox_exec_test_lib + velox_tpch_connector) + +add_subdirectory(test) diff --git a/velox/tool/trace/QueryTraceReplayer.cpp b/velox/tool/trace/OperatorReplayerBase.cpp similarity index 63% rename from velox/tool/trace/QueryTraceReplayer.cpp rename to velox/tool/trace/OperatorReplayerBase.cpp index 862f9a3505d0..a999110131db 100644 --- a/velox/tool/trace/QueryTraceReplayer.cpp +++ b/velox/tool/trace/OperatorReplayerBase.cpp @@ -18,32 +18,51 @@ #include "velox/exec/trace/QueryTraceTraits.h" #include "velox/exec/trace/QueryTraceUtil.h" -#include "velox/tool/trace/QueryTraceReplayer.h" +#include "velox/tool/trace/OperatorReplayerBase.h" + +#include #include "velox/common/serialization/Serializable.h" #include "velox/core/PlanNode.h" -DEFINE_bool(usage, false, "Show the usage"); -DEFINE_string(root, "", "Root dir of the query tracing"); -DEFINE_bool(summary, false, "Show the summary of the tracing"); -DEFINE_bool(short_summary, false, "Only show number of tasks and task ids"); -DEFINE_string( - task_id, - "", - "Specify the target task id, if empty, show the summary of all the traced query task."); - using namespace facebook::velox; namespace facebook::velox::tool::trace { -QueryTraceReplayer::QueryTraceReplayer() - : rootDir_(FLAGS_root), taskId_(FLAGS_task_id) {} +OperatorReplayerBase::OperatorReplayerBase( + std::string rootDir, + std::string taskId, + std::string nodeId, + int32_t pipelineId, + std::string operatorType) + : rootDir_(std::move(rootDir)), + taskId_(std::move(taskId)), + nodeId_(std::move(nodeId)), + pipelineId_(pipelineId), + operatorType_(std::move(operatorType)) { + VELOX_USER_CHECK(!rootDir_.empty()); + VELOX_USER_CHECK(!taskId_.empty()); + VELOX_USER_CHECK(!nodeId_.empty()); + VELOX_USER_CHECK_GE(pipelineId_, 0); + VELOX_USER_CHECK(!operatorType_.empty()); + const auto traceTaskDir = fmt::format("{}/{}", rootDir_, taskId_); + const auto metadataReader = exec::trace::QueryMetadataReader( + traceTaskDir, memory::MemoryManager::getInstance()->tracePool()); + metadataReader.read(queryConfigs_, connectorConfigs_, planFragment_); + queryConfigs_[core::QueryConfig::kQueryTraceEnabled] = "false"; + fs_ = filesystems::getFileSystem(rootDir_, nullptr); + maxDrivers_ = + exec::trace::getNumDrivers(rootDir_, taskId_, nodeId_, pipelineId_, fs_); +} -void QueryTraceReplayer::printSummary() const { - const auto fs = filesystems::getFileSystem(rootDir_, nullptr); - const auto taskIds = exec::trace::getTaskIds(rootDir_, fs); +void OperatorReplayerBase::printSummary( + const std::string& rootDir, + const std::string& taskId, + bool shortSummary) { + const auto fs = filesystems::getFileSystem(rootDir, nullptr); + const auto taskIds = exec::trace::getTaskIds(rootDir, fs); if (taskIds.empty()) { - LOG(ERROR) << "No traced query task under " << rootDir_; + LOG(ERROR) << "No traced query task under " << rootDir; return; } @@ -52,17 +71,17 @@ void QueryTraceReplayer::printSummary() const { summary << "Number of tasks: " << taskIds.size() << "\n"; summary << "Task ids: " << folly::join(",", taskIds); - if (FLAGS_short_summary) { + if (shortSummary) { LOG(INFO) << summary.str(); return; } const auto summaryTaskIds = - taskId_.empty() ? taskIds : std::vector{taskId_}; + taskId.empty() ? taskIds : std::vector{taskId}; for (const auto& taskId : summaryTaskIds) { summary << "\n++++++Query configs and plan of task " << taskId << ":++++++\n"; - const auto traceTaskDir = fmt::format("{}/{}", rootDir_, taskId); + const auto traceTaskDir = fmt::format("{}/{}", rootDir, taskId); const auto queryMetaFile = fmt::format( "{}/{}", traceTaskDir, @@ -81,7 +100,7 @@ void QueryTraceReplayer::printSummary() const { LOG(INFO) << summary.str(); } -std::string QueryTraceReplayer::usage() { +std::string OperatorReplayerBase::usage() { std::ostringstream usage; usage << "++++++Query Trace Tool Usage++++++\n" diff --git a/velox/tool/trace/OperatorReplayerBase.h b/velox/tool/trace/OperatorReplayerBase.h new file mode 100644 index 000000000000..f29beff6c047 --- /dev/null +++ b/velox/tool/trace/OperatorReplayerBase.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/common/file/FileSystems.h" +#include "velox/core/PlanNode.h" + +namespace facebook::velox::tool::trace { +class OperatorReplayerBase { + public: + OperatorReplayerBase( + std::string rootDir, + std::string taskId, + std::string nodeId, + int32_t pipelineId, + std::string operatorType); + virtual ~OperatorReplayerBase() = default; + + OperatorReplayerBase(const OperatorReplayerBase& other) = delete; + OperatorReplayerBase& operator=(const OperatorReplayerBase& other) = delete; + OperatorReplayerBase(OperatorReplayerBase&& other) noexcept = delete; + OperatorReplayerBase& operator=(OperatorReplayerBase&& other) noexcept = + delete; + + static void printSummary( + const std::string& rootDir, + const std::string& taskId, + bool shortSummary); + + virtual RowVectorPtr run() const = 0; + + static std::string usage(); + + protected: + virtual core::PlanNodePtr createPlan() const = 0; + + const std::string rootDir_; + const std::string taskId_; + const std::string nodeId_; + const int32_t pipelineId_; + const std::string operatorType_; + + std::unordered_map queryConfigs_; + std::unordered_map> + connectorConfigs_; + core::PlanNodePtr planFragment_; + std::shared_ptr fs_; + int32_t maxDrivers_{1}; +}; + +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/QueryReplayer.cpp b/velox/tool/trace/QueryReplayer.cpp index 053ddd417ab5..9ffccb4ddef9 100644 --- a/velox/tool/trace/QueryReplayer.cpp +++ b/velox/tool/trace/QueryReplayer.cpp @@ -14,6 +14,9 @@ * limitations under the License. */ +#include +#include +#include #include #include "velox/common/memory/Memory.h" #include "velox/core/PlanNode.h" @@ -25,7 +28,28 @@ #include "velox/connectors/hive/storage_adapters/gcs/RegisterGCSFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" -#include "velox/tool/trace/QueryTraceReplayer.h" +#include "velox/tool/trace/OperatorReplayerBase.h" +#include "velox/tool/trace/TableWriterReplayer.h" + +DEFINE_bool(usage, false, "Show the usage"); +DEFINE_string(root, "", "Root dir of the query tracing"); +DEFINE_bool(summary, false, "Show the summary of the tracing"); +DEFINE_bool(short_summary, false, "Only show number of tasks and task ids"); +DEFINE_string( + task_id, + "", + "Specify the target task id, if empty, show the summary of all the traced query task."); +DEFINE_string(node_id, "", "Specify the target node id."); +DEFINE_int32(pipeline_id, 0, "Specify the target pipeline id."); +DEFINE_string(operator_type, "", "Specify the target operator type."); +DEFINE_string( + table_writer_output_dir, + "", + "Specify output directory of TableWriter."); +DEFINE_double( + hiveConnectorExecutorHwMultiplier, + 2.0, + "Hardware multipler for hive connector."); using namespace facebook::velox; @@ -41,38 +65,75 @@ void init() { core::PlanNode::registerSerDe(); core::ITypedExpr::registerSerDe(); exec::registerPartitionFunctionSerDe(); + connector::hive::HiveTableHandle::registerSerDe(); + connector::hive::LocationHandle::registerSerDe(); + connector::hive::HiveColumnHandle::registerSerDe(); + connector::hive::HiveInsertTableHandle::registerSerDe(); + if (!isRegisteredVectorSerde()) { + serializer::presto::PrestoVectorSerde::registerVectorSerde(); + } + // TODO: make it configurable. + const auto ioExecutor = std::make_unique( + std::thread::hardware_concurrency() * + FLAGS_hiveConnectorExecutorHwMultiplier); + const auto hiveConnector = + connector::getConnectorFactory("hive")->newConnector( + "test-hive", + std::make_shared( + std::unordered_map()), + ioExecutor.get()); + connector::registerConnector(hiveConnector); +} + +std::unique_ptr createReplayer( + const std::string& operatorType) { + std::unique_ptr replayer = nullptr; + if (operatorType == "TableWriter") { + replayer = std::make_unique( + FLAGS_root, + FLAGS_task_id, + FLAGS_node_id, + FLAGS_pipeline_id, + FLAGS_operator_type, + FLAGS_table_writer_output_dir); + } else { + VELOX_FAIL("Unsupported opeartor type: {}", FLAGS_operator_type); + } + VELOX_USER_CHECK_NOT_NULL(replayer); + return replayer; } } // namespace int main(int argc, char** argv) { if (argc == 1) { - LOG(ERROR) << "\n" << tool::trace::QueryTraceReplayer::usage(); + LOG(ERROR) << "\n" << tool::trace::OperatorReplayerBase::usage(); return 1; } gflags::ParseCommandLineFlags(&argc, &argv, true); if (FLAGS_usage) { - LOG(INFO) << "\n" << tool::trace::QueryTraceReplayer::usage(); + LOG(INFO) << "\n" << tool::trace::OperatorReplayerBase::usage(); return 0; } if (FLAGS_root.empty()) { LOG(ERROR) << "Root dir is not provided!\n" - << tool::trace::QueryTraceReplayer::usage(); - return 1; - } - - if (!FLAGS_summary && !FLAGS_short_summary) { - LOG(ERROR) << "Only support to print traced query metadata for now"; + << tool::trace::OperatorReplayerBase::usage(); return 1; } init(); - const auto tool = std::make_unique(); if (FLAGS_summary || FLAGS_short_summary) { - tool->printSummary(); + tool::trace::OperatorReplayerBase::printSummary( + FLAGS_root, FLAGS_task_id, FLAGS_short_summary); return 0; } - VELOX_UNREACHABLE(tool::trace::QueryTraceReplayer::usage()); + const auto replayer = createReplayer(FLAGS_operator_type); + VELOX_USER_CHECK_NOT_NULL( + replayer, "Unsupported opeartor type: {}", FLAGS_operator_type); + + replayer->run(); + + return 0; } diff --git a/velox/tool/trace/TableWriterReplayer.cpp b/velox/tool/trace/TableWriterReplayer.cpp new file mode 100644 index 000000000000..dc7ea81bf01d --- /dev/null +++ b/velox/tool/trace/TableWriterReplayer.cpp @@ -0,0 +1,129 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/tool/trace/TableWriterReplayer.h" + +#include +#include "velox/common/memory/Memory.h" +#include "velox/exec/TableWriter.h" +#include "velox/exec/Task.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/trace/QueryDataReader.h" +#include "velox/exec/trace/QueryTraceUtil.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::tool::trace { + +namespace { + +std::shared_ptr +makeHiveInsertTableHandle( + const core::TableWriteNode* node, + std::string targetDir) { + const auto tracedHandle = + std::dynamic_pointer_cast( + node->insertTableHandle()->connectorInsertTableHandle()); + const auto inputColumns = tracedHandle->inputColumns(); + const auto compressionKind = + tracedHandle->compressionKind().value_or(common::CompressionKind_NONE); + const auto storageFormat = tracedHandle->tableStorageFormat(); + const auto serdeParameters = tracedHandle->serdeParameters(); + const auto writerOptions = tracedHandle->writerOptions(); + return std::make_shared( + inputColumns, + std::make_shared( + targetDir, + targetDir, + connector::hive::LocationHandle::TableType::kNew), + storageFormat, + tracedHandle->bucketProperty() == nullptr + ? nullptr + : std::make_shared( + *tracedHandle->bucketProperty()), + compressionKind, + std::unordered_map{}, + writerOptions); +} + +std::shared_ptr createInsertTableHanlde( + const std::string& connectorId, + const core::TableWriteNode* node, + std::string targetDir) { + return std::make_shared( + connectorId, makeHiveInsertTableHandle(node, std::move(targetDir))); +} + +} // namespace + +RowVectorPtr TableWriterReplayer::run() const { + const auto restoredPlanNode = createPlan(); + + return AssertQueryBuilder(restoredPlanNode) + .maxDrivers(maxDrivers_) + .configs(queryConfigs_) + .connectorSessionProperties(connectorConfigs_) + .copyResults(memory::MemoryManager::getInstance()->tracePool()); +} + +core::PlanNodePtr TableWriterReplayer::createPlan() const { + const auto* tableWriterNode = core::PlanNode::findFirstNode( + planFragment_.get(), + [this](const core::PlanNode* node) { return node->id() == nodeId_; }); + const auto traceRoot = fmt::format("{}/{}", rootDir_, taskId_); + return PlanBuilder() + .traceScan( + fmt::format("{}/{}", traceRoot, nodeId_), + exec::trace::getDataType(planFragment_, nodeId_)) + .addNode(addTableWriter( + dynamic_cast(tableWriterNode), + replayOutputDir_)) + .planNode(); +} + +core::PlanNodePtr TableWriterReplayer::createTableWrtierNode( + const core::TableWriteNode* node, + const std::string& targetDir, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) { + const auto insertTableHandle = + createInsertTableHanlde("test-hive", node, targetDir); + return std::make_shared( + nodeId, + node->columns(), + node->columnNames(), + node->aggregationNode(), + insertTableHandle, + node->hasPartitioningScheme(), + TableWriteTraits::outputType(node->aggregationNode()), + node->commitStrategy(), + source); +} + +std::function +TableWriterReplayer::addTableWriter( + const core::TableWriteNode* node, + const std::string& targetDir) { + return [=](const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) -> core::PlanNodePtr { + return createTableWrtierNode(node, targetDir, nodeId, source); + }; +} + +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/TableWriterReplayer.h b/velox/tool/trace/TableWriterReplayer.h new file mode 100644 index 000000000000..499ec628d301 --- /dev/null +++ b/velox/tool/trace/TableWriterReplayer.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/tool/trace/OperatorReplayerBase.h" + +#include "velox/core/PlanNode.h" + +namespace facebook::velox::tool::trace { +/// The replayer to replay the traced 'TableWriter' operator. +class TableWriterReplayer final : public OperatorReplayerBase { + public: + TableWriterReplayer( + const std::string& rootDir, + const std::string& taskId, + const std::string& nodeId, + const int32_t pipelineId, + const std::string& operatorType, + const std::string& replayOutputDir) + : OperatorReplayerBase(rootDir, taskId, nodeId, pipelineId, operatorType), + replayOutputDir_(replayOutputDir) { + VELOX_CHECK(!replayOutputDir_.empty()); + } + + RowVectorPtr run() const override; + + protected: + core::PlanNodePtr createPlan() const override; + + private: + static core::PlanNodePtr createTableWrtierNode( + const core::TableWriteNode* node, + const std::string& targetDir, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source); + + static std::function + addTableWriter( + const core::TableWriteNode* node, + const std::string& targetDir); + + const std::string replayOutputDir_; +}; + +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/test/CMakeLists.txt b/velox/tool/trace/test/CMakeLists.txt new file mode 100644 index 000000000000..e24088fc0ae9 --- /dev/null +++ b/velox/tool/trace/test/CMakeLists.txt @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(velox_tool_trace_test TableWriterReplayerTest.cpp Main.cpp) + +add_test( + NAME velox_tool_trace_test + COMMAND velox_tool_trace_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +set_tests_properties(velox_tool_trace_test PROPERTIES TIMEOUT 3000) + +target_link_libraries( + velox_tool_trace_test + velox_exec + velox_exec_test_lib + velox_memory + velox_query_trace_exec + velox_query_trace_retrieve + velox_query_trace_replayer_base + velox_vector_fuzzer + GTest::gtest_main + GTest::gmock + Folly::folly + gflags::gflags + glog::glog + fmt::fmt + ${FILESYSTEM}) diff --git a/velox/tool/trace/QueryTraceReplayer.h b/velox/tool/trace/test/Main.cpp similarity index 53% rename from velox/tool/trace/QueryTraceReplayer.h rename to velox/tool/trace/test/Main.cpp index 29f99f91594f..814b389a1b1b 100644 --- a/velox/tool/trace/QueryTraceReplayer.h +++ b/velox/tool/trace/test/Main.cpp @@ -14,29 +14,16 @@ * limitations under the License. */ -#pragma once +#include "velox/common/process/ThreadDebugInfo.h" -#include +#include +#include +#include -DECLARE_bool(usage); -DECLARE_string(root); -DECLARE_bool(summary); -DECLARE_bool(short_summary); -DECLARE_bool(pretty); -DECLARE_string(task_id); - -namespace facebook::velox::tool::trace { -/// The tool used to print or replay the traced query metadata and operations. -class QueryTraceReplayer { - public: - QueryTraceReplayer(); - - void printSummary() const; - static std::string usage(); - - private: - const std::string rootDir_; - const std::string taskId_; -}; - -} // namespace facebook::velox::tool::trace +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Signal handler required for ThreadDebugInfoTest + facebook::velox::process::addDefaultFatalSignalHandler(); + folly::Init init(&argc, &argv, false); + return RUN_ALL_TESTS(); +} diff --git a/velox/tool/trace/test/TableWriterReplayerTest.cpp b/velox/tool/trace/test/TableWriterReplayerTest.cpp new file mode 100644 index 000000000000..56876bc9a311 --- /dev/null +++ b/velox/tool/trace/test/TableWriterReplayerTest.cpp @@ -0,0 +1,413 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "folly/dynamic.h" +#include "velox/common/base/Fs.h" +#include "velox/common/file/FileSystems.h" +#include "velox/common/hyperloglog/SparseHll.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/exec/PartitionFunction.h" +#include "velox/exec/TableWriter.h" +#include "velox/exec/tests/utils/ArbitratorTestUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/exec/trace/QueryDataReader.h" +#include "velox/exec/trace/QueryTraceUtil.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +#include "velox/tool/trace/TableWriterReplayer.h" + +#include +#include "folly/experimental/EventCount.h" +#include "velox/dwio/dwrf/writer/Writer.h" + +using namespace facebook::velox; +using namespace facebook::velox::core; +using namespace facebook::velox::common; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::connector; +using namespace facebook::velox::connector::hive; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::common::testutil; +using namespace facebook::velox::common::hll; + +namespace facebook::velox::tool::trace::test { +class TableWriterReplayerTest : public HiveConnectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + HiveConnectorTestBase::SetUpTestCase(); + filesystems::registerLocalFileSystem(); + if (!isRegisteredVectorSerde()) { + serializer::presto::PrestoVectorSerde::registerVectorSerde(); + } + Type::registerSerDe(); + common::Filter::registerSerDe(); + connector::hive::HiveTableHandle::registerSerDe(); + connector::hive::LocationHandle::registerSerDe(); + connector::hive::HiveColumnHandle::registerSerDe(); + connector::hive::HiveInsertTableHandle::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + registerPartitionFunctionSerDe(); + } + + std::vector makeBatches( + vector_size_t numBatches, + std::function makeVector) { + std::vector batches; + batches.reserve(numBatches); + for (int32_t i = 0; i < numBatches; ++i) { + batches.push_back(makeVector(i)); + } + return batches; + } + + std::set getLeafSubdirectories( + const std::string& directoryPath) { + std::set subdirectories; + for (auto& path : fs::recursive_directory_iterator(directoryPath)) { + if (path.is_regular_file()) { + subdirectories.emplace(path.path().parent_path().string()); + } + } + return subdirectories; + } + + // Helper method to return InsertTableHandle. + std::shared_ptr createInsertTableHandle( + const RowTypePtr& outputRowType, + const connector::hive::LocationHandle::TableType& outputTableType, + const std::string& outputDirectoryPath, + const std::vector& partitionedBy, + const std::shared_ptr bucketProperty, + const std::optional compressionKind = {}) { + return std::make_shared( + kHiveConnectorId, + makeHiveInsertTableHandle( + outputRowType->names(), + outputRowType->children(), + partitionedBy, + bucketProperty, + makeLocationHandle( + outputDirectoryPath, std::nullopt, outputTableType), + fileFormat_, + compressionKind)); + } + + // Returns a table insert plan node. + PlanNodePtr createInsertPlan( + PlanBuilder& inputPlan, + const RowTypePtr& inputRowType, + const RowTypePtr& tableRowType, + const std::string& outputDirectoryPath, + const std::vector& partitionedBy = {}, + std::shared_ptr bucketProperty = nullptr, + const std::optional compressionKind = {}, + const connector::hive::LocationHandle::TableType& outputTableType = + connector::hive::LocationHandle::TableType::kNew, + const CommitStrategy& outputCommitStrategy = CommitStrategy::kNoCommit, + bool aggregateResult = true, + std::shared_ptr aggregationNode = nullptr) { + auto insertPlan = inputPlan + .addNode(addTableWriter( + inputRowType, + tableRowType->names(), + aggregationNode, + createInsertTableHandle( + tableRowType, + outputTableType, + outputDirectoryPath, + partitionedBy, + bucketProperty, + compressionKind), + !partitionedBy.empty(), + outputCommitStrategy)) + .capturePlanNodeId(tableWriteNodeId_); + if (aggregateResult) { + insertPlan.project({TableWriteTraits::rowCountColumnName()}) + .singleAggregation( + {}, + {fmt::format("sum({})", TableWriteTraits::rowCountColumnName())}); + } + return insertPlan.planNode(); + } + + std::function addTableWriter( + const RowTypePtr& inputColumns, + const std::vector& tableColumnNames, + const std::shared_ptr& aggregationNode, + const std::shared_ptr& insertHandle, + bool hasPartitioningScheme, + connector::CommitStrategy commitStrategy = + connector::CommitStrategy::kNoCommit) { + return [=](core::PlanNodeId nodeId, + core::PlanNodePtr source) -> core::PlanNodePtr { + return std::make_shared( + nodeId, + inputColumns, + tableColumnNames, + aggregationNode, + insertHandle, + hasPartitioningScheme, + TableWriteTraits::outputType(aggregationNode), + commitStrategy, + std::move(source)); + }; + } + + RowTypePtr getNonPartitionsColumns( + const std::vector& partitionedKeys, + const RowTypePtr& rowType) { + std::vector dataColumnNames; + std::vector dataColumnTypes; + for (auto i = 0; i < rowType->size(); i++) { + auto name = rowType->names()[i]; + if (std::find(partitionedKeys.cbegin(), partitionedKeys.cend(), name) == + partitionedKeys.cend()) { + dataColumnNames.emplace_back(name); + dataColumnTypes.emplace_back(rowType->findChild(name)); + } + } + + return ROW(std::move(dataColumnNames), std::move(dataColumnTypes)); + } + + std::vector> + makeHiveSplitsFromDirectory(const std::string& directoryPath) { + std::vector> splits; + + for (auto& path : fs::recursive_directory_iterator(directoryPath)) { + if (path.is_regular_file()) { + splits.push_back(HiveConnectorTestBase::makeHiveConnectorSplits( + path.path().string(), 1, fileFormat_)[0]); + } + } + + return splits; + } + + void checkWriteResults( + const std::set& actualDirs, + const std::set& expectedDirs, + const std::vector& partitionKeys, + const RowTypePtr& rowType) { + ASSERT_EQ(actualDirs.size(), expectedDirs.size()); + auto actualDirIt = actualDirs.begin(); + auto expectedDirIt = expectedDirs.begin(); + const auto newOutputType = getNonPartitionsColumns(partitionKeys, rowType); + while (actualDirIt != actualDirs.end()) { + const auto actualWrites = + AssertQueryBuilder(PlanBuilder().tableScan(newOutputType).planNode()) + .splits(makeHiveSplitsFromDirectory(*actualDirIt)) + .copyResults(pool()); + const auto expectedWrites = + AssertQueryBuilder(PlanBuilder().tableScan(newOutputType).planNode()) + .splits(makeHiveSplitsFromDirectory(*expectedDirIt)) + .copyResults(pool()); + assertEqualResults({actualWrites}, {expectedWrites}); + ++actualDirIt; + ++expectedDirIt; + } + } + + std::string tableWriteNodeId_; + FileFormat fileFormat_{FileFormat::DWRF}; +}; + +TEST_F(TableWriterReplayerTest, basic) { + vector_size_t size = 1'000; + auto data = makeRowVector({ + makeFlatVector(size, [](auto row) { return row; }), + makeFlatVector( + size, [](auto row) { return row * 2; }, nullEvery(7)), + }); + auto sourceFilePath = TempFilePath::create(); + writeToFile(sourceFilePath->getPath(), data); + + std::string planNodeId; + auto targetDirectoryPath = TempDirectoryPath::create(); + auto rowType = asRowType(data->type()); + auto plan = PlanBuilder() + .tableScan(rowType) + .tableWrite(targetDirectoryPath->getPath()) + .capturePlanNodeId(planNodeId) + .planNode(); + const auto testDir = TempDirectoryPath::create(); + const auto traceRoot = fmt::format("{}/{}", testDir->getPath(), "traceRoot"); + std::shared_ptr task; + auto results = + AssertQueryBuilder(plan) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, planNodeId) + .split(makeHiveConnectorSplit(sourceFilePath->getPath())) + .copyResults(pool(), task); + const auto traceOutputDir = TempDirectoryPath::create(); + const auto tableWriterReplayer = TableWriterReplayer( + traceRoot, + task->taskId(), + "1", + 0, + "TableWriter", + traceOutputDir->getPath()); + const auto result = tableWriterReplayer.run(); + // Second column contains details about written files. + const auto details = results->childAt(TableWriteTraits::kFragmentChannel) + ->as>(); + const folly::dynamic obj = folly::parseJson(details->valueAt(1)); + const auto fileWriteInfos = obj["fileWriteInfos"]; + ASSERT_EQ(1, fileWriteInfos.size()); + + const auto writeFileName = fileWriteInfos[0]["writeFileName"].asString(); + // Read from 'writeFileName' and verify the data matches the original. + plan = PlanBuilder().tableScan(rowType).planNode(); + + const auto copy = + AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(fmt::format( + "{}/{}", targetDirectoryPath->getPath(), writeFileName))) + .copyResults(pool()); + assertEqualResults({data}, {copy}); +} + +TEST_F(TableWriterReplayerTest, partitionWrite) { + const int32_t numPartitions = 4; + const int32_t numBatches = 2; + const auto rowType = + ROW({"c0", "p0", "p1", "c1", "c3", "c5"}, + {INTEGER(), INTEGER(), VARCHAR(), BIGINT(), REAL(), VARCHAR()}); + const std::vector partitionKeys = {"p0", "p1"}; + const std::vector partitionTypes = {INTEGER(), VARCHAR()}; + const std::vector vectors = makeBatches(numBatches, [&](auto) { + return makeRowVector( + rowType->names(), + { + makeFlatVector( + numPartitions, [&](auto row) { return row + 100; }), + makeFlatVector( + numPartitions, [&](auto row) { return row; }), + makeFlatVector( + numPartitions, + [&](auto row) { + return StringView::makeInline(fmt::format("str_{}", row)); + }), + makeFlatVector( + numPartitions, [&](auto row) { return row + 1000; }), + makeFlatVector( + numPartitions, [&](auto row) { return row + 33.23; }), + makeFlatVector( + numPartitions, + [&](auto row) { + return StringView::makeInline( + fmt::format("bucket_{}", row * 3)); + }), + }); + }); + const auto inputFilePaths = makeFilePaths(numBatches); + for (int i = 0; i < numBatches; i++) { + writeToFile(inputFilePaths[i]->getPath(), vectors[i]); + } + + const auto outputDirectory = TempDirectoryPath::create(); + auto inputPlan = PlanBuilder().tableScan(rowType); + auto plan = createInsertPlan( + inputPlan, + inputPlan.planNode()->outputType(), + rowType, + outputDirectory->getPath(), + partitionKeys, + nullptr, + CompressionKind::CompressionKind_ZSTD); + AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits(inputFilePaths)) + .copyResults(pool()); + // Verify that there is one partition directory for each partition. + std::set actualPartitionDirectories = + getLeafSubdirectories(outputDirectory->getPath()); + std::set expectedPartitionDirectories; + std::set partitionNames; + for (auto i = 0; i < numPartitions; i++) { + auto partitionName = fmt::format("p0={}/p1=str_{}", i, i); + partitionNames.emplace(partitionName); + expectedPartitionDirectories.emplace( + fs::path(outputDirectory->getPath()) / partitionName); + } + EXPECT_EQ(actualPartitionDirectories, expectedPartitionDirectories); + + const auto outputDirWithTracing = TempDirectoryPath::create(); + auto inputPlanWithTracing = PlanBuilder().tableScan(rowType); + auto planWithTracing = createInsertPlan( + inputPlanWithTracing, + inputPlanWithTracing.planNode()->outputType(), + rowType, + outputDirWithTracing->getPath(), + partitionKeys, + nullptr, + CompressionKind::CompressionKind_ZSTD); + const auto testDir = TempDirectoryPath::create(); + const auto traceRoot = fmt::format("{}/{}", testDir->getPath(), "traceRoot"); + const auto tableWriteNodeId = std::move(tableWriteNodeId_); + std::shared_ptr task; + AssertQueryBuilder(planWithTracing) + .config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, tableWriteNodeId) + .splits(makeHiveConnectorSplits(inputFilePaths)) + .copyResults(pool(), task); + actualPartitionDirectories = + getLeafSubdirectories(outputDirWithTracing->getPath()); + ASSERT_EQ( + actualPartitionDirectories.size(), expectedPartitionDirectories.size()); + checkWriteResults( + actualPartitionDirectories, + expectedPartitionDirectories, + partitionKeys, + rowType); + + const auto traceOutputDir = TempDirectoryPath::create(); + const auto tableWriterReplayer = TableWriterReplayer( + traceRoot, + task->taskId(), + tableWriteNodeId, + 0, + "TableWriter", + traceOutputDir->getPath()); + tableWriterReplayer.run(); + actualPartitionDirectories = getLeafSubdirectories(traceOutputDir->getPath()); + checkWriteResults( + actualPartitionDirectories, + expectedPartitionDirectories, + partitionKeys, + rowType); +} + +} // namespace facebook::velox::tool::trace::test