Skip to content

Commit

Permalink
Fix semi mask on scan node table (#4050)
Browse files Browse the repository at this point in the history
* enable semi mask on node table

* add todo

* Add fine-grained-semi-mask optimizer

* Run clang-format

---------

Co-authored-by: xiyang <x74feng@uwaterloo.ca>
Co-authored-by: CI Bot <ray6080@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent a1b7279 commit 72944e9
Show file tree
Hide file tree
Showing 17 changed files with 86 additions and 78 deletions.
10 changes: 5 additions & 5 deletions src/function/gds/shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class PathLengths : public GDSFrontier {
explicit PathLengths(
std::vector<std::tuple<common::table_id_t, uint64_t>> nodeTableIDAndNumNodes) {
for (const auto& [tableID, numNodes] : nodeTableIDAndNumNodes) {
masks.insert({tableID, std::make_unique<processor::MaskData>(numNodes, UNVISITED)});
masks.insert({tableID, std::make_unique<MaskData>(numNodes, UNVISITED)});
}
}

Expand Down Expand Up @@ -131,10 +131,10 @@ class PathLengths : public GDSFrontier {

private:
uint8_t curIter = 255;
common::table_id_map_t<std::unique_ptr<processor::MaskData>> masks;
common::table_id_map_t<std::unique_ptr<MaskData>> masks;
common::table_id_t curFrontierFixedTableID;
processor::MaskData* curFrontierFixedMask;
processor::MaskData* nextFrontierFixedMask;
MaskData* curFrontierFixedMask;
MaskData* nextFrontierFixedMask;
};

class PathLengthsFrontiers : public Frontiers {
Expand Down Expand Up @@ -419,7 +419,7 @@ class ShortestPathsAlgorithm final : public GDSAlgorithm {
}
auto mask = sharedState->inputNodeOffsetMasks.at(tableID).get();
for (auto offset = 0u; offset < sharedState->graph->getNumNodes(tableID); ++offset) {
if (!mask->isMasked(offset)) {
if (!mask->isMasked(offset, offset)) {
continue;
}
auto sourceNodeID = nodeID_t{offset, tableID};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "common/types/internal_id_t.h"

namespace kuzu {
namespace processor {
namespace common {

// Note: Classes in this file are NOT thread-safe.
struct MaskUtil {
Expand Down Expand Up @@ -52,7 +52,15 @@ class MaskCollection {
maskData = std::make_unique<MaskData>(maxOffset + 1);
}

bool isMasked(common::offset_t offset) { return maskData->isMasked(offset, numMasks); }
// Return true if any offset between [startOffset, endOffset] is masked. Otherwise return false.
bool isMasked(common::offset_t startOffset, common::offset_t endOffset) const {
auto offset = startOffset;
auto numMasked = 0u;
while (offset <= endOffset) {
numMasked += maskData->isMasked(offset++, numMasks);
}
return numMasked > 0;
}
// Increment mask value for the given nodeOffset if its current mask value is equal to
// the specified `currentMaskValue`.
// Note: blindly update mask does not parallelize well, so we minimize write by first checking
Expand Down Expand Up @@ -87,7 +95,7 @@ class NodeSemiMask {
virtual void init() = 0;

virtual void incrementMaskValue(common::offset_t nodeOffset, uint8_t currentMaskValue) = 0;
virtual bool isMasked(common::offset_t nodeOffset) = 0;
virtual bool isMasked(common::offset_t startNodeOffset, common::offset_t endNodeOffset) = 0;

bool isEnabled() const { return getNumMasks() > 0; }
uint8_t getNumMasks() const { return maskCollection.getNumMasks(); }
Expand Down Expand Up @@ -115,8 +123,8 @@ class NodeOffsetLevelSemiMask final : public NodeSemiMask {
maskCollection.incrementMaskValue(nodeOffset, currentMaskValue);
}

bool isMasked(common::offset_t nodeOffset) override {
return maskCollection.isMasked(nodeOffset);
bool isMasked(common::offset_t startNodeOffset, common::offset_t endNodeOffset) override {
return maskCollection.isMasked(startNodeOffset, endNodeOffset);
}
};

Expand All @@ -136,10 +144,11 @@ class NodeVectorLevelSemiMask final : public NodeSemiMask {
maskCollection.incrementMaskValue(MaskUtil::getVectorIdx(nodeOffset), currentMaskValue);
}

bool isMasked(common::offset_t nodeOffset) override {
return maskCollection.isMasked(MaskUtil::getVectorIdx(nodeOffset));
bool isMasked(common::offset_t startNodeOffset, common::offset_t endNodeOffset) override {
return maskCollection.isMasked(MaskUtil::getVectorIdx(startNodeOffset),
MaskUtil::getVectorIdx(endNodeOffset));
}
};

} // namespace processor
} // namespace common
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/include/planner/operator/sip/side_way_info_passing.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ enum class SemiMaskPosition : uint8_t {
NONE = 0,
ON_BUILD = 1,
ON_PROBE = 2,
PROHIBIT = 3,
PROHIBIT_PROBE_TO_BUILD = 3,
PROHIBIT = 4,
};

/*
Expand Down
9 changes: 5 additions & 4 deletions src/include/processor/operator/gds_call.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "common/mask.h"
#include "function/gds/gds.h"
#include "function/gds/gds_utils.h"
#include "graph/graph.h"
#include "processor/operator/mask.h"
#include "processor/operator/sink.h"

namespace kuzu {
Expand All @@ -13,10 +13,11 @@ struct GDSCallSharedState {
std::mutex mtx;
std::shared_ptr<FactorizedTable> fTable;
std::unique_ptr<graph::Graph> graph;
common::table_id_map_t<std::unique_ptr<NodeOffsetLevelSemiMask>> inputNodeOffsetMasks;
common::table_id_map_t<std::unique_ptr<common::NodeOffsetLevelSemiMask>> inputNodeOffsetMasks;

GDSCallSharedState(std::shared_ptr<FactorizedTable> fTable, std::unique_ptr<graph::Graph> graph,
common::table_id_map_t<std::unique_ptr<NodeOffsetLevelSemiMask>> inputNodeOffsetMasks)
common::table_id_map_t<std::unique_ptr<common::NodeOffsetLevelSemiMask>>
inputNodeOffsetMasks)
: fTable{std::move(fTable)}, graph{std::move(graph)},
inputNodeOffsetMasks{std::move(inputNodeOffsetMasks)} {}
DELETE_COPY_AND_MOVE(GDSCallSharedState);
Expand Down Expand Up @@ -59,7 +60,7 @@ class GDSCall : public Sink {
info{std::move(info)}, sharedState{std::move(sharedState)} {}

bool hasSemiMask() const { return !sharedState->inputNodeOffsetMasks.empty(); }
std::vector<NodeSemiMask*> getSemiMasks() const;
std::vector<common::NodeSemiMask*> getSemiMasks() const;

bool isSource() const override { return true; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
#include "bfs_state.h"
#include "common/enums/extend_direction.h"
#include "common/enums/query_rel_type.h"
#include "common/mask.h"
#include "frontier_scanner.h"
#include "planner/operator/extend/recursive_join_type.h"
#include "processor/operator/mask.h"
#include "processor/operator/physical_operator.h"

namespace kuzu {
Expand All @@ -14,10 +14,10 @@ namespace processor {
class OffsetScanNodeTable;

struct RecursiveJoinSharedState {
std::vector<std::unique_ptr<NodeOffsetLevelSemiMask>> semiMasks;
std::vector<std::unique_ptr<common::NodeOffsetLevelSemiMask>> semiMasks;

explicit RecursiveJoinSharedState(
std::vector<std::unique_ptr<NodeOffsetLevelSemiMask>> semiMasks)
std::vector<std::unique_ptr<common::NodeOffsetLevelSemiMask>> semiMasks)
: semiMasks{std::move(semiMasks)} {}
};

Expand Down Expand Up @@ -116,7 +116,7 @@ class RecursiveJoin : public PhysicalOperator {
info{std::move(info)}, sharedState{std::move(sharedState)},
recursiveRoot{std::move(recursiveRoot)} {}

std::vector<NodeSemiMask*> getSemiMask() const;
std::vector<common::NodeSemiMask*> getSemiMask() const;

void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final;

Expand Down
10 changes: 5 additions & 5 deletions src/include/processor/operator/scan/scan_node_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct ScanNodeTableProgressSharedState {

class ScanNodeTableSharedState {
public:
explicit ScanNodeTableSharedState(std::unique_ptr<NodeVectorLevelSemiMask> semiMask)
explicit ScanNodeTableSharedState(std::unique_ptr<common::NodeVectorLevelSemiMask> semiMask)
: table{nullptr}, currentCommittedGroupIdx{common::INVALID_NODE_GROUP_IDX},
currentUnCommittedGroupIdx{common::INVALID_NODE_GROUP_IDX}, numCommittedNodeGroups{0},
numUnCommittedNodeGroups{0}, semiMask{std::move(semiMask)} {};
Expand All @@ -27,7 +27,7 @@ class ScanNodeTableSharedState {
void nextMorsel(storage::NodeTableScanState& scanState,
std::shared_ptr<ScanNodeTableProgressSharedState> progressSharedState);

NodeSemiMask* getSemiMask() const { return semiMask.get(); }
common::NodeSemiMask* getSemiMask() const { return semiMask.get(); }

private:
std::mutex mtx;
Expand All @@ -36,7 +36,7 @@ class ScanNodeTableSharedState {
common::node_group_idx_t currentUnCommittedGroupIdx;
common::node_group_idx_t numCommittedNodeGroups;
common::node_group_idx_t numUnCommittedNodeGroups;
std::unique_ptr<NodeVectorLevelSemiMask> semiMask;
std::unique_ptr<common::NodeVectorLevelSemiMask> semiMask;
};

struct ScanNodeTableInfo {
Expand All @@ -52,7 +52,7 @@ struct ScanNodeTableInfo {
columnPredicates{std::move(columnPredicates)} {}
EXPLICIT_COPY_DEFAULT_MOVE(ScanNodeTableInfo);

void initScanState(NodeSemiMask* semiMask);
void initScanState(common::NodeSemiMask* semiMask);

private:
ScanNodeTableInfo(const ScanNodeTableInfo& other)
Expand Down Expand Up @@ -93,7 +93,7 @@ class ScanNodeTable final : public ScanTable {
KU_ASSERT(this->nodeInfos.size() == this->sharedStates.size());
}

std::vector<NodeSemiMask*> getSemiMasks() const;
std::vector<common::NodeSemiMask*> getSemiMasks() const;

bool isSource() const override { return true; }

Expand Down
4 changes: 2 additions & 2 deletions src/include/processor/operator/semi_masker.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "processor/operator/mask.h"
#include "common/mask.h"
#include "processor/operator/physical_operator.h"

namespace kuzu {
Expand All @@ -12,7 +12,7 @@ class BaseSemiMasker;
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
// indicate if a value in the mask is masked or not, as each masker will increment the selected
// value in the mask by 1. More details are described in NodeTableSemiMask.
using mask_with_idx = std::pair<NodeSemiMask*, uint8_t>;
using mask_with_idx = std::pair<common::NodeSemiMask*, uint8_t>;

class SemiMaskerInfo {
friend class BaseSemiMasker;
Expand Down
13 changes: 4 additions & 9 deletions src/include/storage/store/node_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

#include <cstdint>

#include "common/exception/not_implemented.h"
#include "common/types/types.h"
#include "processor/operator/mask.h"
#include "storage/index/hash_index.h"
#include "storage/store/node_group_collection.h"
#include "storage/store/table.h"
#include <common/exception/not_implemented.h>

namespace kuzu {
namespace evaluator {
Expand All @@ -26,25 +25,21 @@ class Transaction;
namespace storage {

struct NodeTableScanState final : TableScanState {
processor::NodeSemiMask* semiMask;

// Scan state for un-committed data.
// Ideally we shouldn't need columns to scan un-checkpointed but committed data.
explicit NodeTableScanState(std::vector<common::column_id_t> columnIDs)
: TableScanState{std::move(columnIDs), {}}, semiMask{nullptr} {
: TableScanState{std::move(columnIDs), {}} {
nodeGroupScanState = std::make_unique<NodeGroupScanState>(this->columnIDs.size());
}

NodeTableScanState(std::vector<common::column_id_t> columnIDs, std::vector<Column*> columns)
: TableScanState{std::move(columnIDs), std::move(columns),
std::vector<ColumnPredicateSet>{}},
semiMask{nullptr} {
std::vector<ColumnPredicateSet>{}} {
nodeGroupScanState = std::make_unique<NodeGroupScanState>(this->columnIDs.size());
}
NodeTableScanState(std::vector<common::column_id_t> columnIDs, std::vector<Column*> columns,
std::vector<ColumnPredicateSet> columnPredicateSets)
: TableScanState{std::move(columnIDs), std::move(columns), std::move(columnPredicateSets)},
semiMask{nullptr} {
: TableScanState{std::move(columnIDs), std::move(columns), std::move(columnPredicateSets)} {
nodeGroupScanState = std::make_unique<NodeGroupScanState>(this->columnIDs.size());
}
};
Expand Down
13 changes: 8 additions & 5 deletions src/include/storage/store/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/enums/zone_map_check_result.h"
#include "common/mask.h"
#include "storage/predicate/column_predicate.h"
#include "storage/store/column.h"
#include "storage/store/node_group.h"
Expand All @@ -20,8 +21,9 @@ struct TableScanState {
common::ValueVector* IDVector;
std::vector<common::ValueVector*> outputVectors;
std::vector<common::column_id_t> columnIDs;
common::NodeSemiMask* semiMask;

// Only used when scan from checkpointed data.
// Only used when scan from persistent data.
std::vector<Column*> columns;

TableScanSource source = TableScanSource::NONE;
Expand All @@ -33,18 +35,19 @@ struct TableScanState {
common::ZoneMapCheckResult zoneMapResult = common::ZoneMapCheckResult::ALWAYS_SCAN;

explicit TableScanState(std::vector<common::column_id_t> columnIDs)
: IDVector(nullptr), columnIDs{std::move(columnIDs)} {
: IDVector(nullptr), columnIDs{std::move(columnIDs)}, semiMask{nullptr} {
rowIdxVector = std::make_unique<common::ValueVector>(common::LogicalType::INT64());
}
TableScanState(std::vector<common::column_id_t> columnIDs, std::vector<Column*> columns,
std::vector<ColumnPredicateSet> columnPredicateSets)
: IDVector(nullptr), columnIDs{std::move(columnIDs)}, columns{std::move(columns)},
columnPredicateSets{std::move(columnPredicateSets)} {
: IDVector(nullptr), columnIDs{std::move(columnIDs)}, semiMask{nullptr},
columns{std::move(columns)}, columnPredicateSets{std::move(columnPredicateSets)} {
rowIdxVector = std::make_unique<common::ValueVector>(common::LogicalType::INT64());
}
explicit TableScanState(std::vector<common::column_id_t> columnIDs,
std::vector<Column*> columns)
: IDVector(nullptr), columnIDs{std::move(columnIDs)}, columns{std::move(columns)} {
: IDVector(nullptr), columnIDs{std::move(columnIDs)}, semiMask{nullptr},
columns{std::move(columns)} {
rowIdxVector = std::make_unique<common::ValueVector>(common::LogicalType::INT64());
}
virtual ~TableScanState() = default;
Expand Down
15 changes: 13 additions & 2 deletions src/optimizer/acc_hash_join_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ void HashJoinSIPOptimizer::visitHashJoin(LogicalOperator* op) {
if (tryBuildToProbeHJSIP(op)) { // Try build to probe SIP first.
return;
}
if (hashJoin.getSIPInfo().position == SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD) {
return;
}
tryProbeToBuildHJSIP(op);
}

Expand Down Expand Up @@ -213,8 +216,12 @@ bool HashJoinSIPOptimizer::tryBuildToProbeHJSIP(LogicalOperator* op) {
// TODO(Xiyang): we don't apply SIP from build to probe.
void HashJoinSIPOptimizer::visitIntersect(LogicalOperator* op) {
auto& intersect = op->cast<LogicalIntersect>();
if (intersect.getSIPInfo().position == SemiMaskPosition::PROHIBIT) {
switch (intersect.getSIPInfo().position) {
case SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD:
case SemiMaskPosition::PROHIBIT:
return;
default:
break;
}
if (!isProbeSideQualified(op->getChild(0).get())) {
return;
Expand Down Expand Up @@ -246,8 +253,12 @@ void HashJoinSIPOptimizer::visitIntersect(LogicalOperator* op) {

void HashJoinSIPOptimizer::visitPathPropertyProbe(LogicalOperator* op) {
auto& pathPropertyProbe = op->cast<LogicalPathPropertyProbe>();
if (pathPropertyProbe.getSIPInfo().position == SemiMaskPosition::PROHIBIT) {
switch (pathPropertyProbe.getSIPInfo().position) {
case SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD:
case SemiMaskPosition::PROHIBIT:
return;
default:
break;
}
if (pathPropertyProbe.getJoinType() == RecursiveJoinType::TRACK_NONE) {
return;
Expand Down
2 changes: 1 addition & 1 deletion src/planner/plan/append_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ void Planner::appendRecursiveExtend(const std::shared_ptr<NodeExpression>& bound
// Check for sip
auto ratio = plan.getCardinality() / relScanCardinality;
if (ratio > PlannerKnobs::SIP_RATIO) {
pathPropertyProbe->getSIPInfoUnsafe().position = SemiMaskPosition::PROHIBIT;
pathPropertyProbe->getSIPInfoUnsafe().position = SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD;
}
plan.setLastOperator(std::move(pathPropertyProbe));
// Update cost
Expand Down
2 changes: 1 addition & 1 deletion src/planner/plan/append_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void Planner::appendHashJoin(const expression_vector& joinNodeIDs, JoinType join
// Check for sip
auto ratio = probePlan.getCardinality() / buildPlan.getCardinality();
if (ratio > PlannerKnobs::SIP_RATIO) {
hashJoin->getSIPInfoUnsafe().position = SemiMaskPosition::PROHIBIT;
hashJoin->getSIPInfoUnsafe().position = SemiMaskPosition::PROHIBIT_PROBE_TO_BUILD;
}
// Update cost
resultPlan.setCost(CostModel::computeHashJoinCost(joinNodeIDs, probePlan, buildPlan));
Expand Down
4 changes: 2 additions & 2 deletions src/processor/map/map_recursive_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ namespace processor {

static std::shared_ptr<RecursiveJoinSharedState> createSharedState(
const binder::NodeExpression& nbrNode, const main::ClientContext& context) {
std::vector<std::unique_ptr<NodeOffsetLevelSemiMask>> semiMasks;
std::vector<std::unique_ptr<common::NodeOffsetLevelSemiMask>> semiMasks;
for (auto tableID : nbrNode.getTableIDs()) {
auto table = context.getStorageManager()->getTable(tableID)->ptrCast<storage::NodeTable>();
semiMasks.push_back(
std::make_unique<NodeOffsetLevelSemiMask>(tableID, table->getNumRows()));
std::make_unique<common::NodeOffsetLevelSemiMask>(tableID, table->getNumRows()));
}
return std::make_shared<RecursiveJoinSharedState>(std::move(semiMasks));
}
Expand Down
4 changes: 2 additions & 2 deletions src/processor/operator/gds_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ std::string GDSCallPrintInfo::toString() const {
return "Algorithm: " + funcName;
}

std::vector<NodeSemiMask*> GDSCall::getSemiMasks() const {
std::vector<NodeSemiMask*> masks;
std::vector<common::NodeSemiMask*> GDSCall::getSemiMasks() const {
std::vector<common::NodeSemiMask*> masks;
for (auto& [_, mask] : sharedState->inputNodeOffsetMasks) {
masks.push_back(mask.get());
}
Expand Down
Loading

0 comments on commit 72944e9

Please sign in to comment.