@@ -195,32 +195,73 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     auto* weight_tensor =
         scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
     auto w_dims = weight_tensor->dims();
+    float* quantized_weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
     // If quantized op is fc, weight scale size = 1;
     // If quantized op is conv2d, weight scale size = weight dims[0]
     // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
-    bool valid_scale_size =
-        (weight_scale.size() == 1 ||
-         weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
-         weight_scale.size() == static_cast<size_t>(w_dims[1]));
-    PADDLE_ENFORCE_EQ(
-        valid_scale_size, true,
-        platform::errors::InvalidArgument(
-            "TRT int8 quant: invalid scale size(%d).", weight_scale.size()));
-    float* quantized_weight_data =
-        weight_tensor->mutable_data<float>(platform::CPUPlace());
-    for (int j = 0; j < weight_tensor->numel(); j++) {
-      if (weight_scale.size() == 1) {
-        quantized_weight_data[j] *= weight_scale[0];
-      } else {
-        if (quantized_op_type == "conv2d_transpose") {
-          int inner_size = w_dims[2] * w_dims[3];
-          quantized_weight_data[j] *=
-              weight_scale[(j / inner_size) % w_dims[1]];
-        } else {
-          int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-          quantized_weight_data[j] *= weight_scale[j / inner_size];
-        }
-      }
-    }
+    if (quantized_op_type == "mul" || quantized_op_type == "fc") {
+      if (dequant_type == "fake_dequantize_max_abs") {
+        PADDLE_ENFORCE_EQ(
+            weight_scale.size(), 1,
+            platform::errors::InvalidArgument(
+                "mul op weight dequantized by [fake_dequantize_max_abs] "
+                "requires weight scale size = 1, but got %d.",
+                weight_scale.size()));
+        for (int j = 0; j < weight_tensor->numel(); j++) {
+          quantized_weight_data[j] *= weight_scale[0];
+        }
+      }
+      if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
+        PADDLE_ENFORCE_EQ(
+            weight_scale.size(), static_cast<size_t>(w_dims[1]),
+            platform::errors::InvalidArgument(
+                "mul op weight dequantized by "
+                "[fake_channel_wise_dequantize_max_abs] requires weight scale "
+                "size = 2nd dim of mul's weight, which is %d, but got %d.",
+                static_cast<size_t>(w_dims[1]), weight_scale.size()));
+        for (int j = 0; j < weight_tensor->numel(); j++) {
+          quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
+        }
+      }
+    } else if (quantized_op_type == "conv2d" ||
+               quantized_op_type == "depthwise_conv2d") {
+      PADDLE_ENFORCE_EQ(
+          dequant_type, "fake_channel_wise_dequantize_max_abs",
+          platform::errors::InvalidArgument("conv2d op must be dequantized by "
+                                            "[fake_channel_wise_dequantize_max_"
+                                            "abs], but got %s",
+                                            dequant_type));
+      PADDLE_ENFORCE_EQ(
+          weight_scale.size(), static_cast<size_t>(w_dims[0]),
+          platform::errors::InvalidArgument(
+              "conv2d op requires weight scale size = channel size of the "
+              "weight, which is %d, but got %d.",
+              static_cast<size_t>(w_dims[0]), weight_scale.size()));
+      for (int j = 0; j < weight_tensor->numel(); j++) {
+        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+        quantized_weight_data[j] *= weight_scale[j / inner_size];
+      }
+    } else if (quantized_op_type == "conv2d_transpose") {
+      PADDLE_ENFORCE_EQ(
+          dequant_type, "fake_channel_wise_dequantize_max_abs",
+          platform::errors::InvalidArgument(
+              "conv2d_transpose must be dequantized by "
+              "[fake_channel_wise_dequantize_max_abs], but got %s",
+              dequant_type));
+      PADDLE_ENFORCE_EQ(
+          weight_scale.size(), static_cast<size_t>(w_dims[1]),
+          platform::errors::InvalidArgument(
+              "conv2d_transpose op requires weight scale size = channel size "
+              "of the weight, which is %d, but got %d.",
+              static_cast<size_t>(w_dims[1]), weight_scale.size()));
+      for (int j = 0; j < weight_tensor->numel(); j++) {
+        int inner_size = w_dims[2] * w_dims[3];
+        quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]];
+      }
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported quantized op type: %s", quantized_op_type));
+    }
 
     // create new op_desc
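
For illustration, a minimal standalone C++ sketch of the per-channel scale
indexing used by the three branches above, assuming the weight layouts that
the dimension checks imply: [in, out] for mul/fc, [out, in, h, w] for conv2d
and depthwise_conv2d, and [in, out, h, w] for conv2d_transpose. The helper
names are illustrative only and not part of the pass.

#include <cassert>
#include <cstdio>
#include <vector>

// mul/fc weight is [in, out]: one scale per output column.
int MulScaleIndex(int j, int out) { return j % out; }

// conv2d / depthwise_conv2d weight is [out, in, h, w]: one scale per output
// channel, i.e. per contiguous block of in * h * w elements.
int Conv2dScaleIndex(int j, int in, int h, int w) { return j / (in * h * w); }

// conv2d_transpose weight is [in, out, h, w]: the output channel advances
// every h * w elements and wraps around after `out` blocks, hence the modulo.
int Conv2dTransposeScaleIndex(int j, int out, int h, int w) {
  return (j / (h * w)) % out;
}

int main() {
  // Toy conv2d_transpose weight: in = 2, out = 3, h = w = 2 (24 elements).
  const int in = 2, out = 3, h = 2, w = 2;
  std::vector<float> weight(in * out * h * w, 1.0f);
  std::vector<float> scale = {0.5f, 2.0f, 4.0f};  // one scale per out channel
  assert(scale.size() == static_cast<size_t>(out));
  for (int j = 0; j < static_cast<int>(weight.size()); j++) {
    weight[j] *= scale[Conv2dTransposeScaleIndex(j, out, h, w)];
  }
  // Elements 0-3 map to output channel 0, 4-7 to channel 1, 8-11 to
  // channel 2, then the pattern repeats for the second input channel:
  std::printf("%g %g %g %g\n", weight[0], weight[4], weight[8], weight[12]);
  return 0;
}

The modulo is what distinguishes the conv2d_transpose case from conv2d: the
scale is still per output channel, but in an [in, out, h, w] layout the output
channel is not the outermost dimension, so the index must wrap around.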