Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Snippets] Registers' assignment optimization #27391

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/common/snippets/include/snippets/emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,28 @@ namespace snippets {
* @interface RegType
* @brief Register type of input and output operations
*/
enum class RegType { gpr, vec, undefined };
enum class RegType { gpr, vec, mask, undefined };
/**
* @interface Reg
* @brief Register representation: type of register and index
*/
struct Reg {
enum {UNDEFINED_IDX = std::numeric_limits<size_t>::max()};
Reg() = default;
Reg(RegType type_, size_t idx_) : type(type_), idx(idx_) {}

RegType type = RegType::gpr;
size_t idx = 0;
bool is_defined() const { return type != RegType::undefined && idx != UNDEFINED_IDX; }
RegType type = RegType::undefined;
size_t idx = UNDEFINED_IDX;

friend bool operator==(const Reg& lhs, const Reg& rhs);
friend bool operator<(const Reg& lhs, const Reg& rhs);
friend bool operator>(const Reg& lhs, const Reg& rhs);
friend bool operator!=(const Reg& lhs, const Reg& rhs);
friend std::ostream& operator<<(std::ostream& s, const Reg& r);
};
using RegInfo = std::pair<std::vector<Reg>, std::vector<Reg>>;

std::string regTypeToStr(const RegType& type);

/**
* @interface Emitter
* @brief Base class for all target specific code emitters used by generator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "pass.hpp"
#include "snippets/generator.hpp"
#include "snippets/lowered/reg_manager.hpp"

namespace ov {
namespace snippets {
Expand All @@ -21,15 +22,14 @@ namespace pass {
class AssignRegisters : public Pass {
public:
OPENVINO_RTTI("AssignRegisters", "Pass")
explicit AssignRegisters(const std::function<RegType(const ov::Output<Node>& out)>& mapper, const size_t reg_cnt)
: m_reg_type_mapper(mapper), reg_count(reg_cnt) {}
explicit AssignRegisters(RegManager& reg_manager) : m_reg_manager(reg_manager) {}
bool run(LinearIR& linear_ir) override;

private:
void set_reg_types(LinearIR& linear_ir);
using RegMap = std::map<Reg, Reg>;
static RegMap assign_regs_manually(const LinearIR& linear_ir, std::set<Reg>& gpr_pool, std::set<Reg>& vec_pool);

std::function<RegType(const ov::Output<Node>& out)> m_reg_type_mapper;
size_t reg_count;
RegManager& m_reg_manager;
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"
#include "snippets/generator.hpp"
#include "snippets/lowered/reg_manager.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface InitLiveRanges
* @brief Calculates live ranges of registers. This information will be used to assign registers and optimize ABI reg spills.
* @ingroup snippets
*/
class InitLiveRanges : public Pass {
public:
OPENVINO_RTTI("InitLiveRanges", "Pass")
explicit InitLiveRanges(RegManager& reg_manager) : m_reg_manager(reg_manager) {}
bool run(LinearIR& linear_ir) override;
private:
RegManager& m_reg_manager;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"
#include "snippets/lowered/reg_manager.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface InsertRegSpills
* @brief Insert RegSpill and RegRestore operations for binary call emitters to comply with ABI conventions.
* @ingroup snippets
*/
class InsertRegSpills : public Pass {
public:
OPENVINO_RTTI("InsertRegSpills", "Pass")
explicit InsertRegSpills(RegManager& reg_manager) : m_reg_manager(reg_manager) {}
bool run(LinearIR& linear_ir) override;

RegManager& m_reg_manager;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
62 changes: 62 additions & 0 deletions src/common/snippets/include/snippets/lowered/reg_manager.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "openvino/core/node.hpp"
#include "snippets/emitter.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/generator.hpp"

/**
* @interface RegManager
* @brief The class holds supplementary info about assigned registers and live ranges
* @ingroup snippets
*/
namespace ov {
namespace snippets {
namespace lowered {

using RegTypeMapper = std::function<RegType(const ov::Output<Node>& out)>;
using LiveInterval = std::pair<double, double>;
class RegManager {
public:
RegManager() = delete;
RegManager(const std::shared_ptr<Generator>& generator) : m_generator(generator) {}
inline RegType get_reg_type(const ov::Output<Node>& out) const { return m_generator->get_op_out_reg_type(out); }
inline size_t get_gp_reg_count() const { return m_generator->get_target_machine()->get_gp_reg_count(); }
inline size_t get_vec_reg_count() const { return m_generator->get_target_machine()->get_vec_reg_count(); }
// inline bool need_abi_reg_spill() const {m_generator->}
inline void set_live_regs(const ExpressionPtr& expr, std::set<Reg>&& live, bool force = false) {
OPENVINO_ASSERT(force || m_live_reg.count(expr) == 0, "Live regs for this expression already registered");
m_live_reg.insert({expr, live});
}
inline const std::set<Reg>& get_live_regs(const ExpressionPtr& expr) const {
OPENVINO_ASSERT(m_live_reg.count(expr), "Live regs for this expression were not registered");
return m_live_reg.at(expr);
}

inline void set_live_range(const Reg& reg, LiveInterval&& interval, bool force = false) {
OPENVINO_ASSERT(force || m_reg_live_range.count(reg) == 0, "Live range for this reg is already set");
m_reg_live_range[reg] = interval;
}

inline const LiveInterval& get_live_range(const Reg& reg) {
OPENVINO_ASSERT(m_reg_live_range.count(reg), "Live range for this reg was not set");
return m_reg_live_range[reg];
}
inline std::map<Reg, LiveInterval> get_live_range_map() const {
return m_reg_live_range;
}

private:
// Maps Register to {Start, Stop} pairs
std::map<Reg, LiveInterval> m_reg_live_range;
// Regs that are live on input of the key expression
std::unordered_map<ExpressionPtr , std::set<Reg>> m_live_reg;
const std::shared_ptr<const Generator> m_generator;
};

} // namespace lowered
} // namespace snippets
} // namespace ov
74 changes: 74 additions & 0 deletions src/common/snippets/include/snippets/op/reg_spill.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/emitter.hpp"

#include "openvino/op/op.hpp"

namespace ov {
namespace snippets {
namespace op {

/**
* @interface RegSpillBase
* @brief Base class for RegSpillBegin and RegSpillEnd ops
* @ingroup snippets
*/
class RegSpillBase : public ov::op::Op {
public:
OPENVINO_OP("RegSpillBaseBase", "SnippetsOpset");
RegSpillBase(const std::vector<Output<Node>>& args);
RegSpillBase() = default;
virtual std::set<Reg> get_regs_to_spill() const = 0;
protected:
};
class RegSpillBegin;
/**
* @interface RegSpillEnd
* @brief Marks the end of the register spill region.
* @ingroup snippets
*/
class RegSpillEnd : public RegSpillBase {
public:
OPENVINO_OP("RegSpillEnd", "SnippetsOpset", RegSpillBase);
RegSpillEnd() = default;
RegSpillEnd(const Output<Node>& reg_spill_begin, std::set<Reg> regs_to_spill);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
std::shared_ptr<RegSpillBegin> get_reg_spill_begin();
std::set<Reg> get_regs_to_spill() const override { return m_regs_to_spill; }

protected:
std::set<Reg> m_regs_to_spill = {};
};

/**
* @interface RegSpillBegin
* @brief Marks the start of the register spill region.
* @ingroup snippets
*/
class RegSpillBegin : public RegSpillBase {
public:
OPENVINO_OP("RegSpillBegin", "SnippetsOpset", RegSpillBase);
RegSpillBegin();

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
std::shared_ptr<RegSpillEnd> get_reg_spill_end() const;
std::set<Reg> get_regs_to_spill() const override {
return get_reg_spill_end()->get_regs_to_spill();
}

protected:
void validate_and_infer_types_except_RegSpillEnd();
};

} // namespace op
} // namespace snippets
} // namespace ov
4 changes: 3 additions & 1 deletion src/common/snippets/include/snippets/op/subgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
namespace ov {
namespace snippets {
namespace op {

// Accessor class to test private interface
class SubgarphTestAccessor;
/**
* @interface Subgraph
* @brief An operation that is implemented by a model
* @ingroup snippets
*/
class Subgraph : public ov::op::util::SubGraphOp {
friend class SubgarphTestAccessor;
public:
OPENVINO_OP("Subgraph", "SnippetsOpset", ov::op::util::SubGraphOp);
// < 1, 42, 17, 15, 16> < 0, 1, 2, 3, 1>
Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "op/rank_normalization.hpp"
#include "op/perf_count.hpp"
#include "op/reduce.hpp"
#include "op/reg_spill.hpp"

namespace ov {
namespace snippets {
Expand Down
9 changes: 7 additions & 2 deletions src/common/snippets/include/snippets/target_machine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ class TargetMachine {
virtual size_t get_lanes() const = 0;

/**
* @brief gets number of registers for a target machine
* @brief returns the number of available general-purpose registers.
* @return number of registers
*/
virtual size_t get_reg_count() const = 0;
virtual size_t get_gp_reg_count() const = 0;
/**
* @brief returns the number of available vector registers.
* @return number of registers
*/
virtual size_t get_vec_reg_count() const = 0;

/**
* @brief called by generator to all the emitter for a target machine
Expand Down
36 changes: 27 additions & 9 deletions src/common/snippets/src/emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,34 @@ bool operator==(const Reg& lhs, const Reg& rhs) {
bool operator!=(const Reg& lhs, const Reg& rhs) {
return !(lhs == rhs);
}
bool operator<(const Reg& lhs, const Reg& rhs) {
return lhs.type < rhs.type ||
(lhs.type == rhs.type && lhs.idx < rhs.idx);
}
bool operator>(const Reg& lhs, const Reg& rhs) {
return lhs.type > rhs.type ||
(lhs.type == rhs.type && lhs.idx > rhs.idx);
}

std::string regTypeToStr(const RegType& type) {
switch (type) {
case RegType::vec:
return "vec";
case RegType::gpr:
return "gpr";
default:
OPENVINO_THROW("Unexpected RegType");
}
std::ostream& operator<<(std::ostream& s, const Reg& r) {
auto regTypeToStr = [](const RegType& type) {
switch (type) {
case RegType::vec:
return "vec";
case RegType::gpr:
return "gpr";
case RegType::mask:
return "mask";
case RegType::undefined:
return "undefined";
default:
OPENVINO_THROW("Unexpected RegType");
}
};
s << regTypeToStr(r.type) << "[" <<
(r.idx == Reg::UNDEFINED_IDX ? "undefined" : std::to_string(r.idx))
<< "]";
return s;
}

} // namespace snippets
Expand Down
Loading
Loading