pnnx fuse expression for scalar-like attribute and unbind chain #4928

Merged
2 commits merged on Aug 10, 2023
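
For context, a minimal sketch of the pattern this PR lets pnnx fold, mirroring the updated test at the bottom of this page (the class name, values, and input shape here are illustrative assumptions): a 1-element parameter used as a scalar, and a rank-1 parameter whose torch.unbind outputs feed plain arithmetic. The pass can now emit such values as literal constants inside the fused expression instead of turning them into extra tensor inputs.

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # scalar-like attribute: a 1-element tensor usable as a constant in expressions
        self.c0 = nn.Parameter(torch.ones(1))
        # rank-1 attribute whose torch.unbind outputs are scalar-like
        self.c1 = nn.Parameter(torch.ones(3) + 0.2)

    def forward(self, x):
        # unbind chain: each output is a 0-d tensor holding one element of c1
        c10, c11, _ = torch.unbind(self.c1)
        return x * 10 + self.c0 - c11

y = Model()(torch.rand(12, 15))  # input shape chosen arbitrarily for illustration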
259 changes: 182 additions & 77 deletions tools/pnnx/src/pass_level3/fuse_expression.cpp
@@ -47,6 +47,11 @@ static bool operand_maybe_tensor(const Operand* operand)
return false;
}

if (op->type == "torch.unbind" && op->inputs[0]->shape.size() == 1)
{
return false;
}

if (op->type == "aten::size")
{
return false;
@@ -131,25 +136,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
{
if (op->outputs.size() > 1 || op->outputs[0]->consumers.size() > 1)
{
goto DEFAULT;
}
}

@@ -189,25 +176,170 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
}
else
{
goto DEFAULT;
}
}
else if (op->type == "pnnx.Attribute")
{
// fprintf(stderr, "operand pnnx.Attribute %s\n", operand->name.c_str());

const Attribute& data = op->attrs["data"];
if (data.shape.size() == 1 && data.shape[0] == 1 && data.type != -1)
{
if (data.type == 0)
{
expr += "None";
}
else if (data.type == 1)
{
char tmp[32];
sprintf(tmp, "%e", ((const float*)data.data.data())[0]);
expr += tmp;
}
else if (data.type == 2)
{
char tmp[32];
sprintf(tmp, "%e", ((const double*)data.data.data())[0]);
expr += tmp;
}
else if (data.type == 4)
{
char tmp[32];
sprintf(tmp, "%d", ((const int*)data.data.data())[0]);
expr += tmp;
}
else if (data.type == 5)
{
int64_t v = ((const int64_t*)data.data.data())[0];
if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX;
if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN;

char tmp[32];
sprintf(tmp, "%d", (int)v);
expr += tmp;
}
else if (data.type == 6)
{
char tmp[32];
sprintf(tmp, "%d", ((const short*)data.data.data())[0]);
expr += tmp;
}
else if (data.type == 7)
{
char tmp[32];
sprintf(tmp, "%d", ((const signed char*)data.data.data())[0]);
expr += tmp;
}
else if (data.type == 8)
{
char tmp[32];
sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[0]);
expr += tmp;
}
else if (data.type == 9)
{
expr += ((const char*)data.data.data())[0] ? "True" : "False";
}
else
{
// unsupported type
fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type);
}
}
else
{
goto DEFAULT;
}
}
else if (op->type == "torch.unbind")
{
// track chain
// pnnx.Attribute/foldable with 1-rank
// torch.unbind to constant scalar
Operand* operand2 = op->inputs[0];
if (operand2->producer->type == "pnnx.Attribute")
{
const Attribute& data = operand2->producer->attrs["data"];

if (data.shape.size() == 1 && data.type != -1)
{
// resolve scalar i
int si = 0;
for (size_t i = 0; i < op->outputs.size(); i++)
{
if (op->outputs[i] == operand)
{
si = (int)i;
break;
}
}

if (data.type == 0)
{
expr += "None";
}
else if (data.type == 1)
{
char tmp[32];
sprintf(tmp, "%e", ((const float*)data.data.data())[si]);
expr += tmp;
}
else if (data.type == 2)
{
char tmp[32];
sprintf(tmp, "%e", ((const double*)data.data.data())[si]);
expr += tmp;
}
else if (data.type == 4)
{
char tmp[32];
sprintf(tmp, "%d", ((const int*)data.data.data())[si]);
expr += tmp;
}
else if (data.type == 5)
{
int64_t v = ((const int64_t*)data.data.data())[si];
if (v == std::numeric_limits<int64_t>::max()) v = INT_MAX;
if (v == std::numeric_limits<int64_t>::min()) v = INT_MIN;

char tmp[32];
sprintf(tmp, "%d", (int)v);
expr += tmp;
}
else if (data.type == 6)
{
char tmp[32];
sprintf(tmp, "%d", ((const short*)data.data.data())[si]);
expr += tmp;
}
else if (data.type == 7)
{
char tmp[32];
sprintf(tmp, "%d", ((const signed char*)data.data.data())[si]);
expr += tmp;
}
else if (data.type == 8)
{
char tmp[32];
sprintf(tmp, "%u", ((const unsigned char*)data.data.data())[si]);
expr += tmp;
}
else if (data.type == 9)
{
expr += ((const char*)data.data.data())[si] ? "True" : "False";
}
else
{
// unsupported type
fprintf(stderr, "fuse expression got unsupported scalar type %d\n", data.type);
goto DEFAULT;
}
return;
}
}

goto DEFAULT;
}
else if (checksubgraph && operand_maybe_tensor(operand) && foldable_constants.find(operand->name) != foldable_constants.end())
{
// fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str());
@@ -316,23 +448,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
}
else
{
goto DEFAULT;
}
}
else if (op->type == "prim::NumToTensor")
@@ -376,23 +492,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
}
else
{
goto DEFAULT;
}
}
else if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
@@ -539,23 +639,28 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
}
else
{
goto DEFAULT;
}

return;

DEFAULT:
auto it = std::find(inputs.begin(), inputs.end(), operand);
if (it == inputs.end())
{
// tensor
char tmp[32];
sprintf(tmp, "@%d", (int)inputs.size());
expr += tmp;

inputs.push_back(operand);
}
else
{
// tensor
char tmp[32];
sprintf(tmp, "@%d", (int)(it - inputs.begin()));
expr += tmp;
}
}
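
A side note on the operand_maybe_tensor change in the first hunk: when the input of torch.unbind has rank 1, each unbind output is a 0-dimensional tensor, i.e. scalar-like, which is why the pass can treat it as a non-tensor value. A quick PyTorch illustration (not part of this diff; the values are arbitrary):

import torch

t = torch.tensor([1.2, 3.4, 5.6])  # rank-1 input
a, b, c = torch.unbind(t)          # unbind along dim 0
print(a.dim(), b.item())           # 0 3.4000... -> each output is a 0-d, scalar-like tensor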

5 changes: 4 additions & 1 deletion tools/pnnx/tests/test_pnnx_expression.py
@@ -27,9 +27,12 @@ def __init__(self):
self.w3 = nn.Parameter(torch.rand(12, 15))
self.w4 = nn.Parameter(torch.rand(12, 15))
self.w5 = nn.Parameter(torch.rand(12, 15))
self.c0 = nn.Parameter(torch.ones(1))
self.c1 = nn.Parameter(torch.ones(3) + 0.2)

def forward(self, x):
c10, c11, _ = torch.unbind(self.c1)
x0 = x * 10 + self.c0 - c11
x = x + self.w0 + x0
x = x - self.w1 + x0.float()
x = x * self.w2 + x0
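
As a rough sanity check on the values involved in the updated test (purely illustrative, not part of the test file): torch.ones(3) + 0.2 gives [1.2, 1.2, 1.2], so c11 is a value the pass can fold into the fused expression as a literal rather than as an extra @-indexed tensor input.

import torch

c1 = torch.ones(3) + 0.2
c10, c11, _ = torch.unbind(c1)
print(c11.item())  # ~1.2, emitted as a constant in the fused expression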