Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal] fix Conv2d_transpose MPS #8439

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include "lite/backends/metal/metal_context.h"
#include "lite/backends/metal/metal_debug.h"
#include "lite/backends/metal/mps_conv_datasource.h"
#include "lite/kernels/metal/image_op/metal_params.h"

namespace paddle {
namespace lite {
Expand All @@ -39,6 +41,7 @@ class Conv2dTransposeImageCompute

public:
void PrepareForRun() override;
void ReInitWhenNeeded() override;
void Run() override;
void SaveOutput() override {
MetalDebug::SaveOutput(
Expand All @@ -48,6 +51,7 @@ class Conv2dTransposeImageCompute

private:
bool use_mps_{false};
void* mps_conv_trans_op_{nullptr};
void* mps_conv_op_{nullptr};
void* mps_input_image_{nullptr};
void* mps_output_image_{nullptr};
Expand Down
138 changes: 138 additions & 0 deletions lite/kernels/metal/image_op/conv2d_transpose_image_compute.mm
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,38 @@
init_for_run();
}

// Rebuilds shape-dependent resources when the input tensor's dims differ from
// the dims stored in last_input_dims_.
//
// On a dims change the kernel memory is released and re-created via
// release_memory() / init_memory(). When the MPS path is active, the MPSImage
// wrappers around the input/output textures are also dropped and re-created so
// they reference the textures that exist after init_memory().
//
// NOTE(review): last_input_dims_ is not written in this function; presumably
// init_memory() refreshes it -- confirm, otherwise this re-initializes on
// every call.
void Conv2dTransposeImageCompute::ReInitWhenNeeded() {
    const auto& param = this->Param<param_t>();
    auto input_dims = param.x->dims();

    if (last_input_dims_ != input_dims) {
        release_memory();
        init_memory();

        if (use_mps_) {
            if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
                // The handles were stored with __bridge_retained, so they carry a
                // +1 retain that has to be balanced manually before overwriting.
                if (mps_input_image_) {
                    CFRelease(mps_input_image_);
                    mps_input_image_ = nullptr;
                }
                if (mps_output_image_) {
                    CFRelease(mps_output_image_);
                    mps_output_image_ = nullptr;
                }

                // Use the same "at least 4 feature channels" clamp as
                // setup_with_mps(), which configures the convolution descriptor
                // with the clamped counts. Without the clamp the re-created
                // images would disagree with the kernel for tensors that have
                // fewer than 4 channels.
                auto input_c = MAX(4, static_cast<int>(input_buffer_->tensor_dim_[1]));
                auto output_c = MAX(4, static_cast<int>(output_buffer_->tensor_dim_[1]));

                // MPS input and output
                mps_input_image_ = (__bridge_retained void*)[[MPSImage alloc]
                    initWithTexture:input_buffer_->image()
                    featureChannels:input_c];
                mps_output_image_ = (__bridge_retained void*)[[MPSImage alloc]
                    initWithTexture:output_buffer_->image()
                    featureChannels:output_c];
            }
        }
    }
}

// Attention: the filter layout is converted from CNHW to NCHW here.
void Conv2dTransposeImageCompute::init_attention() {
const auto& param = this->Param<param_t>();
Expand Down Expand Up @@ -73,6 +105,13 @@
const auto& param = this->Param<param_t>();

function_name_ = KernelFunctionName(param);
bool should_use_mps = false;
if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
if (metal_context_->use_mps()) {
should_use_mps = true;
}
}
use_mps_ = should_use_mps;
if (use_mps_) {
setup_with_mps();
} else {
Expand Down Expand Up @@ -427,9 +466,108 @@
#pragma mark - MPS

// Encodes the MPS transposed-convolution kernel into a fresh command buffer
// and commits it through the Metal backend.
//
// The command buffer is obtained and committed unconditionally; the kernel is
// only encoded when the OS availability check passes and an MPS op has been
// created.
void Conv2dTransposeImageCompute::run_with_mps() {
    auto context_impl = (__bridge MetalContextImp*)metal_context_->backend();
    auto command_buffer = [context_impl commandBuffer];

    if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
        if (mps_conv_trans_op_) {
            // __bridge casts: borrow the objects without changing ownership.
            MPSCNNConvolutionTranspose* transpose_op =
                (__bridge MPSCNNConvolutionTranspose*)mps_conv_trans_op_;
            MPSImage* source_image = (__bridge MPSImage*)mps_input_image_;
            MPSImage* destination_image = (__bridge MPSImage*)mps_output_image_;

            [transpose_op encodeToCommandBuffer:command_buffer
                                    sourceImage:source_image
                               destinationImage:destination_image];
        }
    }

    [context_impl commit:command_buffer];
}

// Prepares the MPS (Metal Performance Shaders) path for conv2d_transpose.
//
// Steps:
//   1. Derive the kernel offsets from the filter size and top/left padding.
//   2. Re-order the float filter weights into a temporary scratch buffer,
//      flipping the kernel along both spatial axes.
//   3. Upload the re-ordered weights into filter_buffer_ (half precision) and
//      build an MPSCNNConvolutionTranspose from a MPSConvDataSource.
//   4. Wrap the input/output textures in MPSImage objects.
//
// The MPS objects are stored as void* with __bridge_retained, i.e. each holds
// a +1 retain that must be balanced with CFRelease by whoever replaces or
// tears them down.
void Conv2dTransposeImageCompute::setup_with_mps() {
    const auto& param = this->Param<param_t>();
    auto backend = (__bridge MetalContextImp*)metal_context_->backend();
    auto padding_top = (*param.paddings)[0];
    auto padding_left = (*param.paddings)[2];

    // Kernel offsets: k / 2 - k + 1 + padding, per spatial axis.
    int offsetX =
        static_cast<int>(param.filter->dims()[3] / 2 - param.filter->dims()[3] + 1 + padding_left);
    int offsetY =
        static_cast<int>(param.filter->dims()[2] / 2 - param.filter->dims()[2] + 1 + padding_top);

    auto rawdata = param.filter->data<float>();
    auto dims = filter_metal_dims_;
    auto tensorDim = DDimLite({dims[0], dims[1], dims[2], dims[3]});
    auto count = tensorDim.production();

    // Scratch buffer for the re-ordered weights; released at the end of this
    // function once its contents have been handed to filter_buffer_.
    void* convertedPointer = TargetWrapperMetal::Malloc(count * sizeof(float));
    TargetWrapperMetal::MemsetSync(convertedPointer, 0, count * sizeof(float));
    auto weightsPointer = (float*)rawdata;
    auto transposed = (float*)convertedPointer;

    int length_nhw = dims[0] * dims[2] * dims[3];
    int length_chw = dims[1] * dims[2] * dims[3];
    int length_hw = dims[2] * dims[3];

    // Source element (c, n, h, w) is read with c as the slowest-varying index.
    // It is written to destination position (n, H-1-h, W-1-w, c), i.e. with c
    // fastest-varying and the kernel rotated by 180 degrees.
    for (int n = 0; n < dims[0]; n++) {
        for (int c = 0; c < dims[1]; c++) {
            for (int h = 0; h < dims[2]; h++) {
                for (int w = 0; w < dims[3]; w++) {
                    int tIndex = h * dims[3] + w + length_nhw * c + length_hw * n;
                    int index = length_chw * n + (dims[2] - 1 - h) * dims[3] * dims[1] +
                                (dims[3] - 1 - w) * dims[1] + c;
                    transposed[index] = weightsPointer[tIndex];
                }
            }
        }
    }

    // mps-Convolution
    if (@available(iOS 10.0, macOS 10.13, macCatalyst 13.0, *)) {
        output_buffer_->use_mps_ = true;
        const_cast<MetalImage*>(input_buffer_)->use_mps_ = true;
        auto filter_h = static_cast<int>(param.filter->dims()[2]);
        auto filter_w = static_cast<int>(param.filter->dims()[3]);
        // Feature-channel counts are clamped to a minimum of 4 for MPS.
        auto input_c = MAX(4, static_cast<int>(input_buffer_->tensor_dim_[1]));
        auto output_c = MAX(4, static_cast<int>(output_buffer_->tensor_dim_[1]));

        MPSCNNConvolutionDescriptor* description =
            [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:filter_w
                                                                    kernelHeight:filter_h
                                                            inputFeatureChannels:input_c
                                                           outputFeatureChannels:output_c];

        // NOTE(review): index 0 is mapped to X and index 1 to Y for both
        // strides and dilations. If these attributes are stored as
        // {height, width} the axes are swapped for non-square values -- confirm.
        description.strideInPixelsX = param.strides[0];
        description.strideInPixelsY = param.strides[1];
        description.dilationRateX = (*param.dilations)[0];
        description.dilationRateY = (*param.dilations)[1];
        description.groups = 1;

        MPSConvDataSource* source = [[MPSConvDataSource alloc] init];
        source.descriptor = description;
        filter_buffer_ = std::make_shared<MetalBuffer>(
            metal_context_, filter_metal_dims_, METAL_PRECISION_TYPE::HALF);
        // The weights are already in their final order, so the buffer must not
        // re-layout them again.
        filter_buffer_->convert_to_nhwc_ = false;
        filter_buffer_->CopyFromNCHW<float>(transposed);
        source.weights = filter_buffer_->rawdata();
        if (param.bias && canMPSAddByChannel()) {
            if (bias_buffer_->src_tensor_) {
                lite::Tensor* y = (lite::Tensor*)(bias_buffer_->src_tensor_);
                auto bias = y->data<float>();
                source.biasTerms = const_cast<float*>(bias);
            }
        }
        mps_conv_trans_op_ = (__bridge_retained void*)[[MPSCNNConvolutionTranspose alloc]
            initWithDevice:backend.device
                   weights:source];
        ((__bridge MPSCNNConvolutionTranspose*)mps_conv_trans_op_).offset =
            MPSOffset{.x = 0, .y = 0, .z = 0};
        ((__bridge MPSCNNConvolutionTranspose*)mps_conv_trans_op_).edgeMode = MPSImageEdgeModeZero;
        ((__bridge MPSCNNConvolutionTranspose*)mps_conv_trans_op_).kernelOffsetX = offsetX;
        ((__bridge MPSCNNConvolutionTranspose*)mps_conv_trans_op_).kernelOffsetY = offsetY;

        // MPS input and output
        mps_input_image_ =
            (__bridge_retained void*)[[MPSImage alloc] initWithTexture:input_buffer_->image()
                                                       featureChannels:input_c];
        mps_output_image_ =
            (__bridge_retained void*)[[MPSImage alloc] initWithTexture:output_buffer_->image()
                                                       featureChannels:output_c];
    }

    // Release the scratch buffer on every path (including the one where the
    // availability check fails); previously it was never freed. The data source
    // reads its weights from filter_buffer_->rawdata(), not from this buffer.
    // NOTE(review): assumes CopyFromNCHW copies the data rather than retaining
    // the pointer -- confirm against MetalBuffer.
    TargetWrapperMetal::Free(convertedPointer);
}

#pragma mark - internal
Expand Down