diff --git a/03Compiler/03Frontend/03OPFusion.md b/03Compiler/03Frontend/03OPFusion.md index 37f78d33..3f528fe7 100644 --- a/03Compiler/03Frontend/03OPFusion.md +++ b/03Compiler/03Frontend/03OPFusion.md @@ -68,7 +68,8 @@ if __name__ == "__main__": # kernel fusion # 将 conv2 的卷积核权重由(1, 1)扩展到(3, 3) - weight_expanded = F.interpolate(conv2_weight, size=(3, 3), mode='bilinear', align_corners=False) + weight_expanded = torch.zeros(16, 3, 3, 3) + weight_expanded[:, :, 1, 1] = conv2_weight[:, :, 0, 0] # conv1 卷积核与 conv2 卷积核融合 weight_fusion = torch.concatenate([conv1_weight, weight_expanded], dim=0) # conv1 偏置与 conv2 偏置融合