Skip to content

Commit

Permalink
all finished
Browse files Browse the repository at this point in the history
  • Loading branch information
SZUwishion committed Sep 20, 2024
1 parent 4adf254 commit 296954d
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,6 @@ void Graph::flops_memops_sum()
{
for (auto op : ops)
{
fprintf(stderr, "op->type: %s\n", op->type.c_str());
if (op->type[0] == 'F')
{
std::string sub_type = op->type.substr(2);
Expand Down Expand Up @@ -1880,6 +1879,50 @@ void Graph::flops_memops_sum()
memops += memops_qkv + memops_attention + memops_output;
}
}

else if (op->type.substr(0, 5) == "torch")
{
std::string sub_type = op->type.substr(6);
if(sub_type == "matmul"
|| sub_type == "mm"
|| sub_type == "bmm")
{
std::vector<int> input_shape_1 = op->inputs[0]->shape;
std::vector<int> input_shape_2 = op->inputs[1]->shape;
int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies<int>());
int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies<int>());
std::vector<int> output_shape = op->outputs[0]->shape;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
flops += input_size_1 * input_shape_2.back();
memops += input_size_1 + input_size_2 + output_size;
}
else if (sub_type == "addmm"
|| sub_type == "baddbmm")
{
std::vector<int> input_shape = op->inputs[0]->shape;
std::vector<int> mat_shape_1 = op->inputs[1]->shape;
std::vector<int> mat_shape_2 = op->inputs[2]->shape;
int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
int mat_size_1 = std::accumulate(mat_shape_1.begin(), mat_shape_1.end(), 1, std::multiplies<int>());
int mat_size_2 = std::accumulate(mat_shape_2.begin(), mat_shape_2.end(), 1, std::multiplies<int>());
std::vector<int> output_shape = op->outputs[0]->shape;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
flops += input_size + mat_size_1 * mat_shape_2.back();
memops += input_size + mat_size_1 + mat_size_2 + output_size;
}
else if (sub_type == "mul"
|| sub_type == "add")
{
std::vector<int> input_shape_1 = op->inputs[0]->shape;
std::vector<int> input_shape_2 = op->inputs[1]->shape;
int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies<int>());
int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies<int>());
std::vector<int> output_shape = op->outputs[0]->shape;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
flops += output_size;
memops += input_size_1 + input_size_2 + output_size;
}
}
}
}

Expand Down

0 comments on commit 296954d

Please sign in to comment.