From b84b395d5886740f7eadbe4cac7178160cadb921 Mon Sep 17 00:00:00 2001
From: nihuini <nihuini@tencent.com>
Date: Thu, 10 Aug 2023 11:35:50 +0800
Subject: [PATCH 1/2] pnnx fuse expression for scalar-like attribute and
 unbind chain

---
 .../pnnx/src/pass_level3/fuse_expression.cpp  | 259 ++++++++++++------
 1 file changed, 182 insertions(+), 77 deletions(-)

diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp
index d8fe46111a4..6b543f5ea52 100644
--- a/tools/pnnx/src/pass_level3/fuse_expression.cpp
+++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp
@@ -47,6 +47,11 @@ static bool operand_maybe_tensor(const Operand* operand)
         return false;
     }
 
+    if (op->type == "torch.unbind" && op->inputs[0]->shape.size() == 1)
+    {
+        return false;
+    }
+
     if (op->type == "aten::size")
     {
         return false;
@@ -131,25 +136,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
     {
         if (op->outputs.size() > 1 || op->outputs[0]->consumers.size() > 1)
         {
-            auto it = std::find(inputs.begin(), inputs.end(), operand);
-            if (it == inputs.end())
-            {
-                // tensor
-                char tmp[32];
-                sprintf(tmp, "@%d", (int)inputs.size());
-                expr += tmp;
-
-                inputs.push_back(operand);
-            }
-            else
-            {
-                // tensor
-                char tmp[32];
-                sprintf(tmp, "@%d", (int)(it - inputs.begin()));
-                expr += tmp;
-            }
-
-            return;
+            goto DEFAULT;
         }
     }
 
@@ -189,25 +176,170 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
         }
         else
         {
-            auto it = std::find(inputs.begin(), inputs.end(), operand);
-            if (it == inputs.end())
-            {
-                // tensor
-                char tmp[32];
-                sprintf(tmp, "@%d", (int)inputs.size());
-                expr += tmp;
-
-                inputs.push_back(operand);
-            }
-            else
-            {
-                // tensor
-                char tmp[32];
-                sprintf(tmp, "@%d", (int)(it - inputs.begin()));
-                expr += tmp;
-            }
+            goto DEFAULT;
         }
     }
+    else if (op->type == "pnnx.Attribute")
+    {
+        // fprintf(stderr, "operand pnnx.Attribute %s\n", operand->name.c_str());
+
+        const Attribute& data = op->attrs["data"];
+        if (data.shape.size() == 1 && data.shape[0] == 1 && data.type != -1)
+        {
+            if (data.type == 0)
+            {
+                expr += "None";
+            }
+            else if (data.type == 1)
+            {
+                char tmp[32];
+                sprintf(tmp, "%e", ((const float*)data.data.data())[0]);
+                expr += tmp;
+            }
+            else if (data.type == 2)
+            {
+                char tmp[32];
+                sprintf(tmp, "%e", ((const double*)data.data.data())[0]);
+                expr += tmp;
+            }
+            else if (data.type == 4)
+            {
+                char tmp[32];
+                sprintf(tmp, "%d", ((const int*)data.data.data())[0]);
+                expr += tmp;
+            }
+            else if (data.type == 5)
+            {
+                int64_t v = ((const int64_t*)data.data.data())[0];
+                if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX;
+                if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN;
+
+                char tmp[32];
+                sprintf(tmp, "%d", (int)v);
+                expr += tmp;
+            }
+            else if (data.type == 6)
+            {
+                char tmp[32];
+                sprintf(tmp, "%d", ((const short*)data.data.data())[0]);
+                expr += tmp;
+            }
+            else if (data.type == 7)
+            {
+                char tmp[32];
+                sprintf(tmp, "%d", ((const signed char*)data.data.data())[0]);
+                expr += tmp;
+            }
+            else if (data.type == 8)
+            {
+                char tmp[32];
+                sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[0]);
+                expr += tmp;
+            }
+            else if (data.type == 9)
+            {
+                expr += ((const char*)data.data.data())[0] ? "True" : "False";
+            }
+            else
+            {
+                // unsupported type
+                fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type);
+            }
+        }
+        else
+        {
+            goto DEFAULT;
+        }
+    }
+    else if (op->type == "torch.unbind")
+    {
+        // track chain
+        // pnnx.Attribute/foldable with 1-rank
+        // torch.unbind to constant scalar
+        Operand* operand2 = op->inputs[0];
+        if (operand2->producer->type == "pnnx.Attribute")
+        {
+            const Attribute& data = operand2->producer->attrs["data"];
+
+            if (data.shape.size() == 1 && data.type != -1)
+            {
+                // resolve scalar i
+                int si = 0;
+                for (size_t i = 0; i < op->outputs.size(); i++)
+                {
+                    if (op->outputs[i] == operand)
+                    {
+                        si = (int)i;
+                        break;
+                    }
+                }
+
+                if (data.type == 0)
+                {
+                    expr += "None";
+                }
+                else if (data.type == 1)
+                {
+                    char tmp[32];
+                    sprintf(tmp, "%e", ((const float*)data.data.data())[si]);
+                    expr += tmp;
+                }
+                else if (data.type == 2)
+                {
+                    char tmp[32];
+                    sprintf(tmp, "%e", ((const double*)data.data.data())[si]);
+                    expr += tmp;
+                }
+                else if (data.type == 4)
+                {
+                    char tmp[32];
+                    sprintf(tmp, "%d", ((const int*)data.data.data())[si]);
+                    expr += tmp;
+                }
+                else if (data.type == 5)
+                {
+                    int64_t v = ((const int64_t*)data.data.data())[si];
+                    if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX;
+                    if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN;
+
+                    char tmp[32];
+                    sprintf(tmp, "%d", (int)v);
+                    expr += tmp;
+                }
+                else if (data.type == 6)
+                {
+                    char tmp[32];
+                    sprintf(tmp, "%d", ((const short*)data.data.data())[si]);
+                    expr += tmp;
+                }
+                else if (data.type == 7)
+                {
+                    char tmp[32];
+                    sprintf(tmp, "%d", ((const signed char*)data.data.data())[si]);
+                    expr += tmp;
+                }
+                else if (data.type == 8)
+                {
+                    char tmp[32];
+                    sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[si]);
+                    expr += tmp;
+                }
+                else if (data.type == 9)
+                {
+                    expr += ((const char*)data.data.data())[si] ? "True" : "False";
+                }
+                else
+                {
+                    // unsupported type
+                    fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type);
+                    goto DEFAULT;
+                }
+
+                return;
+            }
+        }
+
+        goto DEFAULT;
+    }
     else if (checksubgraph && operand_maybe_tensor(operand) && foldable_constants.find(operand->name) != foldable_constants.end())
     {
         // fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str());
"True" : "False"; + } + else + { + // unsupported type + fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type); + } + } + else + { + goto DEFAULT; } } + else if (op->type == "torch.unbind") + { + // track chain + // pnnx.Attribute/foldable with 1-rank + // torch.unbind to constant scalar + Operand* operand2 = op->inputs[0]; + if (operand2->producer->type == "pnnx.Attribute") + { + const Attribute& data = operand2->producer->attrs["data"]; + + if (data.shape.size() == 1 && data.type != -1) + { + // resolve scalar i + int si = 0; + for (size_t i = 0; i < op->outputs.size(); i++) + { + if (op->outputs[i] == operand) + { + si = (int)i; + break; + } + } + + if (data.type == 0) + { + expr += "None"; + } + else if (data.type == 1) + { + char tmp[32]; + sprintf(tmp, "%e", ((const float*)data.data.data())[si]); + expr += tmp; + } + else if (data.type == 2) + { + char tmp[32]; + sprintf(tmp, "%e", ((const double*)data.data.data())[si]); + expr += tmp; + } + else if (data.type == 4) + { + char tmp[32]; + sprintf(tmp, "%d", ((const int*)data.data.data())[si]); + expr += tmp; + } + else if (data.type == 5) + { + int64_t v = ((const int64_t*)data.data.data())[si]; + if (v == std::numeric_limits::max()) v = INT_MAX; + if (v == std::numeric_limits::min()) v = INT_MIN; + + char tmp[32]; + sprintf(tmp, "%d", (int)v); + expr += tmp; + } + else if (data.type == 6) + { + char tmp[32]; + sprintf(tmp, "%d", ((const short*)data.data.data())[si]); + expr += tmp; + } + else if (data.type == 7) + { + char tmp[32]; + sprintf(tmp, "%d", ((const signed char*)data.data.data())[si]); + expr += tmp; + } + else if (data.type == 8) + { + char tmp[32]; + sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[si]); + expr += tmp; + } + else if (data.type == 9) + { + expr += ((const char*)data.data.data())[si] ? 
"True" : "False"; + } + else + { + // unsupported type + fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type); + goto DEFAULT; + } + return; + } + } + + goto DEFAULT; + } else if (checksubgraph && operand_maybe_tensor(operand) && foldable_constants.find(operand->name) != foldable_constants.end()) { // fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str()); @@ -316,23 +448,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s } else { - auto it = std::find(inputs.begin(), inputs.end(), operand); - if (it == inputs.end()) - { - // tensor - char tmp[32]; - sprintf(tmp, "@%d", (int)inputs.size()); - expr += tmp; - - inputs.push_back(operand); - } - else - { - // tensor - char tmp[32]; - sprintf(tmp, "@%d", (int)(it - inputs.begin())); - expr += tmp; - } + goto DEFAULT; } } else if (op->type == "prim::NumToTensor") @@ -376,23 +492,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s } else { - auto it = std::find(inputs.begin(), inputs.end(), operand); - if (it == inputs.end()) - { - // tensor - char tmp[32]; - sprintf(tmp, "@%d", (int)inputs.size()); - expr += tmp; - - inputs.push_back(operand); - } - else - { - // tensor - char tmp[32]; - sprintf(tmp, "@%d", (int)(it - inputs.begin())); - expr += tmp; - } + goto DEFAULT; } } else if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit") @@ -539,23 +639,28 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s } else { - auto it = std::find(inputs.begin(), inputs.end(), operand); - if (it == inputs.end()) - { - // tensor - char tmp[32]; - sprintf(tmp, "@%d", (int)inputs.size()); - expr += tmp; + goto DEFAULT; + } - inputs.push_back(operand); - } - else - { - // tensor - char tmp[32]; - sprintf(tmp, "@%d", (int)(it - inputs.begin())); - expr += tmp; - } + return; + +DEFAULT: + auto it = std::find(inputs.begin(), inputs.end(), operand); + if (it == inputs.end()) + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)inputs.size()); + expr += tmp; + + inputs.push_back(operand); + } + else + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)(it - inputs.begin())); + expr += tmp; } } From 109ebe3e6f54f26848e580d09ad069512150bd42 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 10 Aug 2023 11:50:04 +0800 Subject: [PATCH 2/2] test++ --- tools/pnnx/tests/test_pnnx_expression.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/pnnx/tests/test_pnnx_expression.py b/tools/pnnx/tests/test_pnnx_expression.py index 7a7dc9967bd..4834b13cd31 100644 --- a/tools/pnnx/tests/test_pnnx_expression.py +++ b/tools/pnnx/tests/test_pnnx_expression.py @@ -27,9 +27,12 @@ def __init__(self): self.w3 = nn.Parameter(torch.rand(12, 15)) self.w4 = nn.Parameter(torch.rand(12, 15)) self.w5 = nn.Parameter(torch.rand(12, 15)) + self.c0 = nn.Parameter(torch.ones(1)) + self.c1 = nn.Parameter(torch.ones(3) + 0.2) def forward(self, x): - x0 = x * 10 + c10, c11, _ = torch.unbind(self.c1) + x0 = x * 10 + self.c0 - c11 x = x + self.w0 + x0 x = x - self.w1 + x0.float() x = x * self.w2 + x0