[XPU] Support int31 weight dynamic quantization for fc and conv2d (#59981) (#67058)

Co-authored-by: Travis-Lee <lixiang.fr@hotmail.com>
newway and Travis-Lee authored Aug 7, 2024
1 parent 736a253 commit 6267a2b
Showing 6 changed files with 89 additions and 6 deletions.
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
@@ -763,6 +763,19 @@ void Conv2dXPUFusePass::CreateFusionWeightsAndBias(
                                   false,
                                   weight_scale,
                                   true);
+  } else if (quant_post_type.find("conv2d") != quant_post_type.end() &&
+             quant_post_type.find("conv2d")->second == 4) {
+    VLOG(5) << "Use int31 per-tensor weight";
+    PrepareWeight<float, float>(graph,
+                                scope,
+                                block,
+                                conv_filter_replicated_node,
+                                &filter_intx,
+                                &filter_max,
+                                &scale_max,
+                                false,
+                                weight_scale,
+                                false);
   } else if (quant_post_type.find("conv2d") != quant_post_type.end() &&
                  quant_post_type.find("conv2d")->second == 0 ||
              quant_post_type.find("conv2d") != quant_post_type.end() &&
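The branch above keys off a quant_post_type map from op type to a post-quantization mode; mode 4 now selects the int31 per-tensor path, which keeps the weight in fp32 (PrepareWeight<float, float>) and passes per_channel_quant = false. A minimal sketch of that dispatch, with the map contents invented purely for illustration:

#include <iostream>
#include <map>
#include <string>

int main() {
  // Hypothetical configuration: op type -> post-quantization mode.
  std::map<std::string, int> quant_post_type{{"conv2d", 4}, {"fc", 4}};

  auto it = quant_post_type.find("conv2d");
  if (it != quant_post_type.end() && it->second == 4) {
    // Mode 4: int31 per-tensor weights. The pass calls
    // PrepareWeight<float, float> with per_channel_quant = false,
    // so the weight stays fp32 and only a per-tensor max is stored.
    std::cout << "conv2d: int31 per-tensor weight path\n";
  }
  return 0;
}

The fc pass below takes the same branch, differing only in the looked-up key and the weight node it rewrites.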
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -572,6 +572,19 @@ void FcXPUFusePass::CreateFusionWeightsAndBias(
                                   !transpose_w,
                                   weight_scale,
                                   true);
+  } else if (quant_post_type.find("fc") != quant_post_type.end() &&
+             quant_post_type.find("fc")->second == 4) {
+    VLOG(5) << "Use int31 per-tensor weight";
+    PrepareWeight<float, float>(graph,
+                                scope,
+                                block,
+                                mul_w_replicated_node,
+                                &filter_intx,
+                                &filter_max,
+                                &scale_max,
+                                !transpose_w,
+                                weight_scale,
+                                false);
   } else if (quant_post_type.find("fc") != quant_post_type.end() &&
                  quant_post_type.find("fc")->second == 0 ||
              quant_post_type.find("fc") != quant_post_type.end() &&
12 changes: 12 additions & 0 deletions paddle/fluid/framework/ir/xpu/pass_utils.cc
@@ -256,6 +256,18 @@ void PrepareWeight(Graph* graph,
   }
 }
 
+template void PrepareWeight<float, float>(
+    Graph* graph,
+    Scope* scope,
+    BlockDesc* block,
+    Node* weight,
+    Node** dst_weight,
+    Node** dst_weight_max,
+    Node** dst_scale_max,
+    bool transpose,
+    const std::vector<float>& weight_scales,
+    bool per_channel_quant = false);
+
 template void PrepareWeight<float, int16_t>(
     Graph* graph,
     Scope* scope,
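PrepareWeight is a function template whose definition lives in pass_utils.cc, so every <Tcpu, Txpu> combination used by the fuse passes needs an explicit instantiation in that file; this hunk adds the <float, float> pair that the int31 branches above link against. A toy sketch of the pattern (PrepareWeightDemo is a hypothetical stand-in, not the real signature):

// --- header: declaration only, template body not visible ---------
template <typename Tcpu, typename Txpu>
void PrepareWeightDemo(const Tcpu* src, Txpu* dst, int n);

// --- .cc file: definition plus explicit instantiations -----------
template <typename Tcpu, typename Txpu>
void PrepareWeightDemo(const Tcpu* src, Txpu* dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = static_cast<Txpu>(src[i]);
}

// Each combination used elsewhere must be listed here, or the
// linker cannot find the symbol; the commit adds <float, float>.
template void PrepareWeightDemo<float, float>(const float*, float*, int);
template void PrepareWeightDemo<float, short>(const float*, short*, int);

int main() {
  float src[2] = {1.5f, -2.5f};
  float dst[2];
  PrepareWeightDemo<float, float>(src, dst, 2);
  return 0;
}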
53 changes: 47 additions & 6 deletions paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -245,6 +245,16 @@ static void QuantFP32ToIntX(const float* src_ptr,
   LOG(FATAL) << "Not support.";
 }
 
+template <>
+void QuantFP32ToIntX<float>(const float* src_ptr,
+                            float* dst_ptr,
+                            float max_val,
+                            int numel) {
+  for (int i = 0; i < numel; i++) {
+    dst_ptr[i] = static_cast<float>(src_ptr[i]);
+  }
+}
+
 template <>
 void QuantFP32ToIntX<int16_t>(const float* src_ptr,
                               int16_t* dst_ptr,
@@ -364,16 +374,16 @@ void ConvertWithoutQuant(phi::DenseTensor* weight,
                          phi::DenseTensor* scale_max,
                          bool transpose,
                          const std::vector<float>& weight_scales) {
-  PADDLE_ENFORCE_EQ(
-      weight_scales.empty(),
-      false,
-      platform::errors::InvalidArgument(
-          "ConvertWithoutQuant is not allowed weight scales is empty!"));
   if (transpose) {
     Transpose2D(weight);
   }
   bool per_tensor_quant = weight_scales.size() == 1;
   if (std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value) {
+    PADDLE_ENFORCE_EQ(
+        weight_scales.empty(),
+        false,
+        platform::errors::InvalidArgument(
+            "ConvertWithoutQuant is not allowed weight scales is empty!"));
     auto* cpu_ctx = static_cast<phi::CPUContext*>(
         platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
     if (per_tensor_quant) {
@@ -400,8 +410,32 @@
              weight_scales.data(),
              weight_scales.size() * sizeof(float));
     }
+  } else if (std::is_same<T, float>::value) {
+    // Convert fp16 to fp32
+    phi::DenseTensor weight_fp32;
+    CastToFp32(weight, &weight_fp32);
+    // Find max
+    int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
+    int size = weight_fp32.numel();
+    auto* weight_data = weight_fp32.data<float>();
+    float max_val = FindMaxAbs(weight_data, size);
+    std::vector<float> max_vec(max_ptr_size, max_val);
+    weight_max->set_type(phi::DataType::FLOAT32);
+    weight_max->Resize({max_ptr_size});
+    auto* cpu_ctx = static_cast<phi::CPUContext*>(
+        platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+    memcpy(cpu_ctx->Alloc<float>(weight_max),
+           max_vec.data(),
+           max_ptr_size * sizeof(float));
+
+    // Quant
+    weight->set_type(phi::DataType::FLOAT32);
+    weight->Resize(weight_fp32.dims());
+    QuantFP32ToIntX<float>(
+        weight_data, cpu_ctx->Alloc<float>(weight), max_val, size);
   } else {
-    LOG(FATAL) << "Only support int8<->int8 and int16<->int16 convert.";
+    LOG(FATAL)
+        << "Only support float<->int31, int8<->int8 and int16<->int16 convert.";
   }
 }
 
@@ -424,6 +458,13 @@ template void ConvertWithoutQuant<int8_t>(
     bool transpose,
     const std::vector<float>& weight_scales);
 
+template void ConvertWithoutQuant<float>(
+    phi::DenseTensor* weight,
+    phi::DenseTensor* weight_max,
+    phi::DenseTensor* scale_max,
+    bool transpose,
+    const std::vector<float>& weight_scales);
+
 bool IsPerTensorQuant(const std::vector<float>& weight_max) {
   bool per_tensor = true;
   PADDLE_ENFORCE_GT(
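Note that QuantFP32ToIntX<float> above is an identity copy: for int31 dynamic quantization the weight values are left untouched, and the only real work is computing a per-tensor absolute max and replicating it max_ptr_size times for the device's max buffer. A standalone sketch of that flow (FindMaxAbs is defined outside this diff, so the abs-max reduction below is an assumption about its behavior):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Assumed behavior of FindMaxAbs (its body is not in this hunk):
// a plain absolute-max reduction over the weight buffer.
static float FindMaxAbsDemo(const float* data, int numel) {
  float max_val = 0.0f;
  for (int i = 0; i < numel; ++i) {
    max_val = std::max(max_val, std::fabs(data[i]));
  }
  return max_val;
}

int main() {
  std::vector<float> weight{0.5f, -2.25f, 1.0f};
  const int max_ptr_size = 6;  // stand-in for get_xpu_max_ptr_size(-1)

  // "Quantization" for int31 is an identity copy: weights stay fp32
  // and the kernel consumes the recorded max instead of scaled ints.
  std::vector<float> quantized(weight.begin(), weight.end());

  // Per-tensor max, replicated to fill the device max buffer.
  float max_val = FindMaxAbsDemo(weight.data(), static_cast<int>(weight.size()));
  std::vector<float> max_vec(max_ptr_size, max_val);

  std::printf("max = %.2f (buffer of %zu copies)\n", max_val, max_vec.size());
  return 0;
}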
2 changes: 2 additions & 0 deletions paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc
@@ -221,6 +221,8 @@ void Conv2dXPUKernel(const Context& ctx,
           DataTypeToString(filter.dtype()),
           DataTypeToString(out_dtype)));
     }
+  } else if (filter.dtype() == DataType::FLOAT32) {
+    CONV2D_XPU_KERNEL_IMPL(float, float, float, int32_t);
   } else {
     PADDLE_THROW(phi::errors::Unimplemented(
         "Not support x_dtype is %s, filter_dtype is %s and out_dtype is %s.",
2 changes: 2 additions & 0 deletions paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
@@ -165,6 +165,8 @@ void FcXPUKernel(const Context& ctx,
           DataTypeToString(w.dtype()),
           DataTypeToString(out_dtype)));
     }
+  } else if (w.dtype() == DataType::FLOAT32) {
+    FC_XPU_KERNEL_IMPL(float, float, float, int32_t);
   } else {
     PADDLE_THROW(phi::errors::Unimplemented(
         "Not support x_dtype is %s, w_dtype is %s and out_dtype is %s.",
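Both kernels dispatch on the runtime weight dtype, and fp32 weights now route to an instantiation whose fourth (GEMM) type parameter is int32_t, which per the commit title is how the XPU backend requests int31 compute; the KERNEL_IMPL macro bodies are outside this diff. A hedged sketch of the dispatch shape (RunFusedKernelDemo is a made-up stand-in for those macros):

#include <cstdint>
#include <iostream>

enum class DataType { FLOAT32, INT8, FLOAT16 };

// Made-up stand-in for CONV2D_XPU_KERNEL_IMPL / FC_XPU_KERNEL_IMPL;
// the fourth template parameter (TGEMM) picks the compute path.
template <typename TX, typename TW, typename TOut, typename TGEMM>
void RunFusedKernelDemo() {
  std::cout << "dispatched, sizeof(TGEMM) = " << sizeof(TGEMM) << "\n";
}

void Dispatch(DataType w_dtype) {
  if (w_dtype == DataType::FLOAT32) {
    // fp32 weights -> int31 dynamic-quant GEMM (TGEMM = int32_t).
    RunFusedKernelDemo<float, float, float, std::int32_t>();
  } else {
    // Mirrors the PADDLE_THROW(Unimplemented) fallback.
    std::cout << "unsupported dtype combination\n";
  }
}

int main() { Dispatch(DataType::FLOAT32); }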
