diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index f05ca65a3c5..e4974f6b7c8 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1450,7 +1450,6 @@ void Graph::flops_memops_sum() { for (auto op : ops) { - fprintf(stderr, "op->type: %s\n", op->type.c_str()); if (op->type[0] == 'F') { std::string sub_type = op->type.substr(2); @@ -1880,6 +1879,50 @@ void Graph::flops_memops_sum() memops += memops_qkv + memops_attention + memops_output; } } + + else if (op->type.substr(0, 5) == "torch") + { + std::string sub_type = op->type.substr(6); + if(sub_type == "matmul" + || sub_type == "mm" + || sub_type == "bmm") + { + std::vector input_shape_1 = op->inputs[0]->shape; + std::vector input_shape_2 = op->inputs[1]->shape; + int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies()); + int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size_1 * input_shape_2.back(); + memops += input_size_1 + input_size_2 + output_size; + } + else if (sub_type == "addmm" + || sub_type == "baddbmm") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector mat_shape_1 = op->inputs[1]->shape; + std::vector mat_shape_2 = op->inputs[2]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int mat_size_1 = std::accumulate(mat_shape_1.begin(), mat_shape_1.end(), 1, std::multiplies()); + int mat_size_2 = std::accumulate(mat_shape_2.begin(), mat_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size + mat_size_1 * mat_shape_2.back(); + memops += input_size + mat_size_1 + mat_size_2 + output_size; + } + else if (sub_type == "mul" + || sub_type == "add") + { + std::vector input_shape_1 = op->inputs[0]->shape; + std::vector input_shape_2 = op->inputs[1]->shape; + int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies()); + int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += output_size; + memops += input_size_1 + input_size_2 + output_size; + } + } } }