[training] splitting forward and backward graphs (#334)
A pass is added which takes the Forge graph and creates a
`ForgeGraphModule` by splitting the graph into its forward and backward
parts.
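
For orientation, here is a hypothetical sketch of what the new pass's header might declare. Only the include path `passes/split_graph.hpp` and the `m.def("split_graph", &passes::split_graph)` binding are confirmed by this diff; the signature, return type, and the `forge_graph_module.hpp` include are assumptions for illustration.

// Hypothetical sketch of forge/csrc/passes/split_graph.hpp (assumed, not the actual header).
#pragma once

#include "forge_graph_module.hpp"  // assumed location of ForgeGraphModule
#include "graph_lib/graph.hpp"

namespace tt::passes
{

// Assumed signature: take the combined autograd graph and return a module
// holding the forward graph and, when training, the backward graph.
ForgeGraphModule split_graph(graphlib::Graph *graph);

}  // namespace tt::passes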

Since we currently execute the forward and backward passes as separate
programs (functions) within a binary, splitting the graphs is the
natural approach.

The forward graph is created by cloning the forward nodes of the
original graph and adding output nodes for all the intermediate results
the backward pass will need.

The backward graph is constructed from the backward nodes of the
original graph, consuming the intermediate results produced by the
forward graph. An output node is created for each gradient.
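
To make the split concrete, here is a small self-contained toy example. It is not the real graphlib API: every type and helper below is invented for illustration. The idea it demonstrates matches the description above: nodes are partitioned by epoch, and any forward value that a backward node reads becomes an intermediate, i.e. an extra output of the forward graph and an input of the backward graph.

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Toy model only; the real pass operates on graphlib::Graph/Node.
enum class Epoch { Forward, Backward };

struct ToyNode {
    std::string name;
    Epoch epoch;
    std::vector<ToyNode*> operands;
};

struct SplitResult {
    std::vector<ToyNode*> forward;        // nodes that go into the forward graph
    std::vector<ToyNode*> backward;       // nodes that go into the backward graph
    std::vector<ToyNode*> intermediates;  // forward values the backward graph reads
};

SplitResult split(const std::vector<ToyNode*>& nodes) {
    SplitResult result;
    std::unordered_set<ToyNode*> intermediates;
    for (ToyNode* node : nodes) {
        (node->epoch == Epoch::Forward ? result.forward : result.backward).push_back(node);
        if (node->epoch == Epoch::Backward) {
            // A forward operand of a backward node must be exported from the
            // forward graph (as an intermediate output node) and imported into
            // the backward graph (as an input).
            for (ToyNode* operand : node->operands)
                if (operand->epoch == Epoch::Forward)
                    intermediates.insert(operand);
        }
    }
    result.intermediates.assign(intermediates.begin(), intermediates.end());
    return result;
}

int main() {
    // y = matmul(x, w); the weight gradient needs the forward activation x.
    ToyNode x{"x", Epoch::Forward, {}};
    ToyNode w{"w", Epoch::Forward, {}};
    ToyNode y{"matmul", Epoch::Forward, {&x, &w}};
    ToyNode dy{"grad_out", Epoch::Backward, {}};
    ToyNode dw{"bw_matmul_w", Epoch::Backward, {&dy, &x}};

    SplitResult s = split({&x, &w, &y, &dy, &dw});
    assert(s.forward.size() == 3 && s.backward.size() == 2);
    assert(s.intermediates.size() == 1 && s.intermediates[0]->name == "x");
    return 0;
}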

`CompiledModel` is updated to hold the state from both graphs.
Additionally, `CompiledGraphState` is cleaned up a little by removing
its unused parts.

In `lower_to_mlir()` the previous hacks are removed: we now simply emit
an MLIR function for each graph in the `ForgeGraphModule`, so no special
handling is needed to collect the inputs/params/outputs of each graph.
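
Concretely, the emission loop in lower_to_mlir.cpp (see the hunk below) now reduces to one function per graph, named after the graph:

// One MLIR function per graph in the module; the function takes its name from
// the graph (e.g. "forward" / "backward"), so no GraphTraversalContext hack or
// training special-case is needed anymore.
for (auto graph : module.graphs())
{
    emit_mlir_function(graph, graph->name());
}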

This change enables MNIST Linear training to work.

Closes #100 #163
pilkicTT authored Oct 2, 2024
1 parent e5757cc commit 7e2fe7f
Showing 25 changed files with 572 additions and 293 deletions.
5 changes: 4 additions & 1 deletion forge/csrc/forge_bindings.cpp
@@ -29,6 +29,7 @@ namespace py = pybind11;
 #include "passes/passes_utils.hpp"
 #include "passes/python_bindings.hpp"
 #include "passes/mlir_compiler.hpp"
+#include "passes/split_graph.hpp"
 #include "python_bindings_common.hpp"
 #include "reportify/reportify.hpp"
 #include "runtime/python_bindings.hpp"
@@ -120,7 +121,8 @@ PYBIND11_MODULE(_C, m) {
 
     py::class_<tt::ForgeGraphModule>(m, "ForgeGraphModule")
         .def(py::init<std::string, tt::graphlib::Graph *>(), py::arg("name"), py::arg("forward_graph"))
-        .def("set_graph", &tt::ForgeGraphModule::set_graph);
+        .def("set_graph", &tt::ForgeGraphModule::set_graph)
+        .def("get_graph", &tt::ForgeGraphModule::get_graph);
 
     py::module_ m_autograd = m.def_submodule("autograd", "Submodule defining autograd_engine.");
     AutogradModule(m_autograd);
@@ -202,6 +204,7 @@ PYBIND11_MODULE(_C, m) {
         py::arg("graph"),
         py::arg("default_df_override") = std::optional<DataFormat>{});
     m.def("run_mlir_compiler", &passes::run_mlir_compiler);
+    m.def("split_graph", &passes::split_graph);
 
     m.def(
         "dump_graph",
36 changes: 21 additions & 15 deletions forge/csrc/graph_lib/graph.cpp
@@ -674,21 +674,13 @@ void Graph::copy_module_targets(Graph *old_graph, const std::unordered_map<Node
     }
 }
 
-
-
-void Graph::register_module_outputs(const std::vector<NodeId>& module_outputs, std::vector<bool> requires_grad, bool append) {
-    TT_ASSERT(module_outputs.size() == requires_grad.size());
+void Graph::register_module_outputs(const std::vector<NodeId>& module_outputs, bool append) {
     if (!append) {
         this->ordered_module_output_node_ids_.clear();
     }
     for (NodeId module_output : module_outputs) {
         this->ordered_module_output_node_ids_.push_back(module_output);
     }
-    for (std::size_t i=0; i < module_outputs.size(); i++)
-    {
-        OutputNode *out = node_by_id(module_outputs[i])->as<OutputNode>();
-        out->set_requires_grad(requires_grad[i]);
-    }
 }
 
 void Graph::copy_module_outputs(Graph *old_graph, const std::unordered_map<Node *, Node *> &old_to_new) {
@@ -797,19 +789,33 @@ std::vector<std::string> Graph::get_ordered_input_gradient_names() const {
 
 }
 
+std::vector<Node*> Graph::ordered_intermediates() const
+{
+    std::vector<Node*> ordered_intermediates;
+    for (auto node : this->nodes())
+    {
+        if (node->node_type() == NodeType::kOutput
+            && node->as<OutputNode>()->is_intermediate())
+        {
+            ordered_intermediates.push_back(node);
+        }
+    }
+
+    return ordered_intermediates;
+}
+
 std::vector<std::string> Graph::get_ordered_intermediate_names() const
 {
     std::vector<std::string> ordered_intermediate_names;
-    for (Node* node : this->nodes())
+    for (auto node : this->nodes())
     {
-        if (node->node_type() == NodeType::kOutput)
+        if (node->node_type() == NodeType::kOutput
+            && node->as<OutputNode>()->is_intermediate())
         {
-            if (node->as<OutputNode>()->is_saved_intermediate())
-            {
-                ordered_intermediate_names.push_back(node->name());
-            }
+            ordered_intermediate_names.push_back(node->name());
         }
     }
 
     return ordered_intermediate_names;
 }
6 changes: 4 additions & 2 deletions forge/csrc/graph_lib/graph.hpp
@@ -226,7 +226,7 @@ class Graph
     void update_node_name(Node *node, const std::string &new_name);
 
     void register_module_inputs(const std::vector<NodeId> &module_inputs, bool append = false);
-    void register_module_outputs(const std::vector<NodeId> &module_outputs, std::vector<bool> requires_grad, bool append = false);
+    void register_module_outputs(const std::vector<NodeId> &module_outputs, bool append = false);
     void register_module_targets(const std::vector<NodeId> &module_targets);
     void copy_module_inputs(Graph *old_graph, const std::unordered_map<Node *, Node *> &old_to_new);
     void copy_module_outputs(Graph *old_graph, const std::unordered_map<Node *, Node *> &old_to_new);
@@ -248,6 +248,7 @@
     std::vector<Node *> ordered_module_inputs() const;
     std::vector<Node *> ordered_module_outputs() const;
     std::vector<Node *> ordered_partial_datacopy_outputs() const;
+    std::vector<Node *> ordered_intermediates() const;
     std::vector<Node *> get_constant_nodes(bool recurse = false) const;
     std::vector<Node *> get_parameter_nodes() const;
     std::vector<std::string> get_constant_names() const;
@@ -314,6 +315,7 @@ class Graph
     std::vector<NodeId> ordered_module_input_node_ids_;
     std::vector<NodeId> ordered_module_output_node_ids_;
     std::vector<NodeId> ordered_module_target_node_ids_;
+    std::vector<NodeId> ordered_module_intermediate_node_ids_;
 
     // ordered by insertion order
     std::vector<Node *> nodes_;
@@ -343,7 +345,7 @@ NodeClassType *Graph::add_node(std::unique_ptr<NodeClassType> node, unsigned int
     if (this->has_node_with_name(node->name()))
     {
         throw std::runtime_error(
-            "In graph " + std::to_string(this->id()) +
+            "In graph " + this->name() +
             ": trying to add a node with a name that already exists: " + node->name() + "\n");
     }
 
8 changes: 4 additions & 4 deletions forge/csrc/graph_lib/node_types.hpp
@@ -260,7 +260,7 @@ class OutputNode : public QueueNode {
    protected:
     bool requires_grad_;
     bool is_loss_output_;
-    bool is_saved_intermediate_;
+    bool is_intermediate_;
     bool untilize_;
     RuntimeTensorTransform runtime_tensor_transform;
     // The golden info is needed if we fractured the output and need to reconstruct it for golden comparison
@@ -272,17 +272,17 @@
         QueueNode(name, QueueNodeType::Output, NodeType::kOutput),
         requires_grad_(false),
         is_loss_output_(false),
-        is_saved_intermediate_(false),
+        is_intermediate_(false),
         untilize_(true)
     {
     }
     bool requires_grad() const { return requires_grad_; }
     bool is_loss_output() const { return is_loss_output_; }
-    bool is_saved_intermediate() const { return is_saved_intermediate_; }
+    bool is_intermediate() const { return is_intermediate_; }
     bool untilize() const { return untilize_; }
     void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
     void set_loss_output() { is_loss_output_ = true; }
-    void set_saved_intermediate(bool saved_intermediate) { is_saved_intermediate_ = saved_intermediate; }
+    void set_intermediate(bool intermediate) { is_intermediate_ = intermediate; }
     void set_untilize(bool should_untilize) { untilize_ = should_untilize; }
     virtual std::unique_ptr<Node> clone(std::string const& name = "") const override;
 
2 changes: 1 addition & 1 deletion forge/csrc/graph_lib/python_bindings.cpp
@@ -86,6 +86,7 @@ void GraphModule(py::module &m_graph)
         .def("get_ordered_intermediate_names", &Graph::get_ordered_intermediate_names)
         .def("get_ordered_output_names", &Graph::get_ordered_output_names)
         .def("get_ordered_target_names", &Graph::get_ordered_target_names)
+        .def("get_ordered_intermediate_names", &Graph::get_ordered_intermediate_names)
         .def("get_ordered_input_gradient_names", &Graph::get_ordered_input_gradient_names)
         .def("get_ordered_output_gradient_names", &Graph::get_ordered_output_gradient_names)
         .def("get_ordered_input_requires_grad", &Graph::get_ordered_input_requires_grad)
@@ -110,7 +111,6 @@
             "register_module_outputs",
             &Graph::register_module_outputs,
             py::arg("module_outputs"),
-            py::arg("requires_grad"),
             py::arg("append") = false)
         .def("register_module_targets", &Graph::register_module_targets)
         .def("get_ordered_input_shapes", &Graph::get_ordered_input_shapes)
30 changes: 0 additions & 30 deletions forge/csrc/graph_lib/utils.hpp
@@ -446,36 +446,6 @@ bool is_input_host_queue(bool input_queues_on_host, const Graph *graph, const No
 
 bool is_output_host_queue(bool output_queues_on_host, const Graph *graph, const Node *node);
 
-// This method is used as a workaround when generating mlir for models in training.
-// Remove once we modify autograd to create separate graphs for backward and forward.
-template<SubgraphType subgraph_type>
-GraphTraversalContext get_subgraph_traversal_context(Graph *graph)
-{
-    static_assert(subgraph_type == SubgraphType::Forward || subgraph_type == SubgraphType::Backward,
-        "Subgraph traversal context is implemented only for forward and backward subgraphs.");
-
-    if constexpr (subgraph_type == SubgraphType::Forward)
-    {
-        auto fwd_nodes = graph->nodes([](graphlib::Node *node){
-            return node->get_epoch_type() == graphlib::NodeEpochType::Forward;
-        });
-
-        auto fwd_nodes_set = std::make_unique<const std::unordered_set<const graphlib::Node *>>(fwd_nodes.begin(), fwd_nodes.end());
-        return GraphTraversalContext(graph, std::move(fwd_nodes_set));
-    }
-    else if constexpr (subgraph_type == SubgraphType::Backward)
-    {
-        auto bwd_nodes = graph->nodes([](graphlib::Node *node){
-            return node->get_epoch_type() == graphlib::NodeEpochType::Backward
-                || (node->node_type() == graphlib::NodeType::kInput);
-        });
-
-        auto bwd_nodes_set = std::make_unique<const std::unordered_set<const graphlib::Node *>>(bwd_nodes.begin(), bwd_nodes.end());
-
-        return GraphTraversalContext(graph, std::move(bwd_nodes_set));
-    }
-}
-
 // Wrapper graph management utility class for Node.
 // If remove_from_graph is set to true on destruction of NodeGraphContainer
 // graph->remove_node(node) will be invoked.
37 changes: 4 additions & 33 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -67,20 +67,7 @@ class MLIRGenerator
         // Emit MLIR functions for each graph in the module.
         for (auto graph : module.graphs())
         {
-            // Currently there is only one graph in the ForgeGraphModule. This will change after completion of issue #100.
-            // For now, we keep the hack for splitting the single graph into forward and backward subgraphs.
-            TT_ASSERT(module.graphs().size() == 1, "Expected only one graph in ForgeGraphModule");
-
-            {
-                auto traversal_context = graphlib::get_subgraph_traversal_context<graphlib::SubgraphType::Forward>(graph);
-                emit_mlir_function(graph);
-            }
-
-            if (graph->training())
-            {
-                auto traversal_context = graphlib::get_subgraph_traversal_context<graphlib::SubgraphType::Backward>(graph);
-                emit_mlir_function(graph, "backward");
-            }
+            emit_mlir_function(graph, graph->name());
         }
 
         /// Verify the module after we have finished constructing it, this will check
@@ -205,14 +192,6 @@
         // Add the graph parameters to the argument list.
         for(auto *parameter: graph->get_parameter_nodes())
        {
-            // Check whether the parameter is actually used in the current graph context,
-            // for example when compiling model for training we will emit separate mlirs
-            // for forward and backward subgraphs (via GraphTraversalContext).
-            if (graph->data_users(parameter).empty())
-            {
-                log_trace(LogMLIRCompiler, "Skipping parameter {} as it is not used in the current graph context.", parameter->name());
-                continue;
-            }
             log_trace(LogMLIRCompiler, "Adding parameter {} to the argument list.", parameter->name());
 
             argument_nodes.push_back(parameter);
@@ -221,11 +200,7 @@
 
         // Assemble the function return values (outputs)
         llvm::SmallVector<mlir::Type> returns;
-        auto output_nodes = graph->nodes([](const graphlib::Node *node) {
-            return node->node_type() == tt::graphlib::NodeType::kOutput
-                || (node->node_type() == tt::graphlib::NodeType::kQueue && node->as<graphlib::QueueNode>()->is_grad_accumulator());
-        });
-
+        auto output_nodes = graph->ordered_module_outputs();
         for (auto *output : output_nodes)
         {
             log_trace(LogMLIRCompiler, "Adding output {} to the return list.", output->name());
@@ -457,14 +432,10 @@
         // Assemble the function return values (outputs)
         llvm::SmallVector<mlir::Value> returnValues;
 
-        auto output_nodes = graph->nodes([](const graphlib::Node *node) {
-            return node->node_type() == tt::graphlib::NodeType::kOutput
-                || (node->node_type() == tt::graphlib::NodeType::kQueue && node->as<graphlib::QueueNode>()->is_grad_accumulator());
-        });
-
+        auto output_nodes = graph->ordered_module_outputs();
         for (auto *output : output_nodes)
         {
-            TT_ASSERT(graph->data_operands(output).size() == 1, "Output node must have exactly one operand.");
+            TT_ASSERT(graph->data_operands(output).size() == 1, "Output node " + output->name() + " must have exactly one operand.");
             auto output_operand = graph->data_operands(output)[0];
             auto outputValue = symbolTable_[output_operand->name()].first;
             returnValues.push_back(outputValue);
2 changes: 2 additions & 0 deletions forge/csrc/passes/mlir_compiler.cpp
@@ -64,6 +64,8 @@ namespace tt::passes
     run_mlir_passes(mlir_module);
     tt::log_info(LogMLIRCompiler, "MLIR passes run successfully.");
 
+    mlir_module->dump();
+
     // Generate binary from the MLIR module.
     auto binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
     tt::log_info(LogMLIRCompiler, "Flatbuffer binary generated successfully.");
2 changes: 1 addition & 1 deletion forge/csrc/passes/pre_placer_forge_passes.cpp
@@ -371,7 +371,7 @@ void insert_queues_for_op_intermediates(graphlib::Graph *graph, const std::vecto
         graph->add_edge(node, intermediate_output);
         intermediate_output->set_shape(Shape::create(node->shape().as_vector()));
         intermediate_output->set_output_df(node->output_df());
-        intermediate_output->set_saved_intermediate(true);
+        intermediate_output->set_intermediate(true);
         intermediate_output->set_epoch_type(node->get_epoch_type());
     }
 }
3 changes: 1 addition & 2 deletions forge/csrc/passes/remove_nops.cpp
@@ -17,10 +17,9 @@ void remove_nops(graphlib::Graph *graph) {
             continue;
 
         if (op->op_name() == "nop") {
-            log_warning("Removing nop: {}", op->name());
             graphlib::bypass_node(graph, node, true);
         }
     }
 }
 
 }
 }