Skip to content

Commit

Permalink
Remove duplicating embedding tables when one operand isn't embedding …
Browse files Browse the repository at this point in the history
…op, because this is the spatial feature that was solving blocking issues.
  • Loading branch information
dgolubovicTT committed Sep 30, 2024
1 parent d2514d2 commit e5757cc
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 75 deletions.
3 changes: 0 additions & 3 deletions forge/csrc/forge_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,6 @@ graphlib::Graph* run_pre_lowering_passes(
// Bypass embedding input nops
bypass_embedding_input_nops(graph);

// If there are any non-embedding users of the embedding table, it needs to be duplicated
duplicate_embedding_table_if_needed(graph);

//
// Data formats
//
Expand Down
71 changes: 0 additions & 71 deletions forge/csrc/passes/pre_lowering_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,77 +200,6 @@ void bypass_embedding_input_nops(Graph *graph)
}
}

// Splits up input nodes that are shared between "embedding" ops and other consumers.
//
// Every kInput node is inspected. Inputs with exactly one user data edge, and inputs
// that have no "embedding" consumer at all, are left untouched. For the remaining inputs
// the user edges (snapshotted before any rewiring) are walked in order:
//   - an embedding consumer reading the input on port 0 triggers a clone of the input
//     whenever the input's current data-user count is not one;
//   - the first non-embedding consumer triggers one clone under the same user-count
//     condition, and every later non-embedding consumer is rewired to read from that clone.
// Finally the original input and each clone are passed to
// try_consteval_input_no_operand_forks repeatedly until it returns false.
void duplicate_embedding_table_if_needed(Graph *graph)
{
    // True when the consumer end of `user_edge` is a PyOpNode whose op name is "embedding".
    const auto consumer_is_embedding = [graph](const graphlib::Edge &user_edge)
    {
        return graph->node_by_id(user_edge.consumer_node_id)->as<graphlib::PyOpNode>()->op_type().op == "embedding";
    };

    for (Node *table : graph->nodes_by_type(NodeType::kInput))
    {
        // Each embedding user needs its own embedding table; non-embedding users may share one.
        std::vector<graphlib::Edge> consumers = graph->user_data_edges(table);
        if (consumers.size() == 1)
            continue;

        // Deliberately scans every edge (no early exit), exactly like a plain flag loop.
        bool feeds_embedding = false;
        for (graphlib::Edge consumer : consumers)
        {
            if (consumer_is_embedding(consumer))
                feeds_embedding = true;
        }
        if (not feeds_embedding)
            continue;

        // Forks off the *first* current user data edge of `source` via clone_input_forking_edge,
        // then promotes `source` into the consteval graph of the node producing the new edge.
        // NOTE(review): presumably the producer of the returned edge is the fresh copy of the
        // input -- clone_input_forking_edge is defined elsewhere, confirm.
        const auto fork_table = [graph](Node *source)
        {
            auto forked_edge = clone_input_forking_edge(graph, graph->user_data_edges(source)[0]);
            auto forked_input = graph->node_by_id(forked_edge.producer_node_id);
            auto *forked_consteval = forked_input->as<graphlib::InputNode>()->get_consteval_graph(graph, true);
            forked_consteval->promote_node(graph, source);
            return forked_input;
        };

        // Nodes to run consteval on afterwards: the original input plus every clone made below.
        std::vector<graphlib::Node *> tables_to_consteval;
        tables_to_consteval.push_back(table);

        // Clone shared by all non-embedding consumers; created lazily on first need.
        graphlib::Node *shared_table = nullptr;

        for (graphlib::Edge consumer : consumers)
        {
            graphlib::Node *producer = graph->node_by_id(consumer.producer_node_id);

            if (consumer_is_embedding(consumer))
            {
                // Duplicate the table iff the embedding weight (port 0, not the index operand)
                // is still shared with other users.
                if (graph->data_users(producer).size() != 1 and consumer.consumer_input_port_id == 0)
                {
                    tables_to_consteval.push_back(fork_table(producer));
                }
            }
            else if (shared_table)
            {
                // A shared clone already exists: point this consumer at it instead.
                graph->add_edge(Edge(
                    shared_table->id(), 0, consumer.consumer_node_id, consumer.consumer_input_port_id, consumer.edge_type));
                graph->remove_edge(consumer);
            }
            else if (graph->data_users(producer).size() != 1)
            {
                // Earlier clones may already have reduced the user count, so it is re-checked
                // here before making the clone for non-embedding users.
                shared_table = fork_table(producer);
                tables_to_consteval.push_back(shared_table);
            }
        }

        // Keep consteval'ing each table until the helper reports nothing more was done.
        for (graphlib::Node *candidate : tables_to_consteval)
        {
            while (try_consteval_input_no_operand_forks(graph, candidate->as<graphlib::InputNode>(), true))
            {
            }
        }
    }
}


bool safe_to_hoist_past(const Graph *graph, const Node *operand)
{
Expand Down
1 change: 0 additions & 1 deletion forge/csrc/passes/pre_lowering_passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ using Node = graphlib::Node;
void convert_broadcast_ops_to_tms(Graph *graph);

void bypass_embedding_input_nops(Graph *graph);
void duplicate_embedding_table_if_needed(Graph *graph);

bool safe_to_hoist_past(const Graph *graph, const Node *operand);

Expand Down

0 comments on commit e5757cc

Please sign in to comment.