@@ -81,7 +81,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
       if (quantized_op_type == "conv2d" ||
           quantized_op_type == "conv2d_fusion" ||
           quantized_op_type == "depthwise_conv2d" ||
-          quantized_op_type == "fc") {
+          quantized_op_type == "fc" ||
+          quantized_op_type == "conv2d_transpose") {
         op_desc->SetAttr("Input_scale", scale_value);
       } else if (quantized_op_type == "mul") {
         op_desc->SetAttr("X_scale", scale_value);
@@ -111,7 +112,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
   std::string input_name = "";
   if (quantized_op_type == "conv2d" ||
       quantized_op_type == "depthwise_conv2d" ||
-      quantized_op_type == "conv2d_fusion") {
+      quantized_op_type == "conv2d_fusion" ||
+      quantized_op_type == "conv2d_transpose") {
     weight_name = "Filter";
     input_name = "Input";
   } else if (quantized_op_type == "mul") {
@@ -122,7 +124,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     input_name = "Input";
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
-        "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
+        "QuantDequantFuse: We only support conv2d, conv2d_fusion, "
+        "conv2d_transpose, fc, mul for "
         "now."));
   }
   const std::string pattern_name = "dequant_fuse";
@@ -192,10 +195,12 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
         scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
     auto w_dims = weight_tensor->dims();
     // If quantized op is fc, weight scale size = 1;
-    // If quantized op is conv, weight scale size = weight dims[0]
+    // If quantized op is conv2d, weight scale size = weight dims[0]
+    // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
     bool valid_scale_size =
         (weight_scale.size() == 1 ||
-         weight_scale.size() == static_cast<size_t>(w_dims[0]));
+         weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
+         weight_scale.size() == static_cast<size_t>(w_dims[1]));
     PADDLE_ENFORCE_EQ(
         valid_scale_size, true,
         platform::errors::InvalidArgument(
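As a concrete illustration of the layout comments in the hunk above (all shapes hypothetical, not taken from a real model), the expected scale-vector sizes per layout are:

// Hypothetical shapes illustrating how many per-channel scales each layout expects.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // conv2d / depthwise_conv2d / conv2d_fusion: [out_c, in_c, kh, kw] -> dims[0] scales
  std::vector<int64_t> conv_dims = {64, 3, 3, 3};
  std::vector<float> conv_scales(64, 0.05f);
  assert(conv_scales.size() == static_cast<size_t>(conv_dims[0]));

  // conv2d_transpose: [in_c, out_c, kh, kw] -> dims[1] scales
  std::vector<int64_t> deconv_dims = {3, 64, 3, 3};
  std::vector<float> deconv_scales(64, 0.05f);
  assert(deconv_scales.size() == static_cast<size_t>(deconv_dims[1]));

  // fc: a single per-tensor scale
  std::vector<float> fc_scales = {0.1f};
  assert(fc_scales.size() == 1u);
  return 0;
}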
@@ -206,8 +211,14 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
       if (weight_scale.size() == 1) {
         quantized_weight_data[j] *= weight_scale[0];
       } else {
-        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-        quantized_weight_data[j] *= weight_scale[j / inner_size];
+        if (quantized_op_type == "conv2d_transpose") {
+          int inner_size = w_dims[2] * w_dims[3];
+          quantized_weight_data[j] *=
+              weight_scale[(j / inner_size) % w_dims[1]];
+        } else {
+          int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+          quantized_weight_data[j] *= weight_scale[j / inner_size];
+        }
       }
     }
 
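The indexing in the hunk above relies on conv-style filters being laid out as [out_channels, in_channels, kh, kw] (output channel = j / (dims[1]*dims[2]*dims[3])) while conv2d_transpose filters are laid out as [in_channels, out_channels, kh, kw] (output channel = (j / (dims[2]*dims[3])) % dims[1]). A minimal standalone sketch of that per-channel dequantization, using hypothetical shapes and scales rather than the pass's real tensors:

// Illustrative sketch only: mirrors the index arithmetic of the pass above,
// but operates on made-up vectors instead of the real LoDTensor data.
#include <cstdint>
#include <vector>

void DequantWeights(bool is_conv2d_transpose, const std::vector<int64_t>& w_dims,
                    const std::vector<float>& weight_scale,
                    std::vector<float>* weight_data) {
  for (size_t j = 0; j < weight_data->size(); ++j) {
    if (weight_scale.size() == 1) {
      // fc: a single per-tensor scale
      (*weight_data)[j] *= weight_scale[0];
    } else if (is_conv2d_transpose) {
      // [in_c, out_c, kh, kw]: the output channel is the second dimension
      int64_t inner_size = w_dims[2] * w_dims[3];
      (*weight_data)[j] *= weight_scale[(j / inner_size) % w_dims[1]];
    } else {
      // [out_c, in_c, kh, kw]: the output channel is the leading dimension
      int64_t inner_size = w_dims[1] * w_dims[2] * w_dims[3];
      (*weight_data)[j] *= weight_scale[j / inner_size];
    }
  }
}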
@@ -220,7 +231,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
   new_op_desc.SetType(quantized_op_type);
   new_op_desc.SetAttr("enable_int8", true);
   if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
-      quantized_op_type == "depthwise_conv2d") {
+      quantized_op_type == "depthwise_conv2d" ||
+      quantized_op_type == "conv2d_transpose") {
     new_op_desc.SetInput("Input", {new_input});
     new_op_desc.SetOutput("Output", {new_output});
   } else if (quantized_op_type == "fc") {
@@ -253,7 +265,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
   std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "depthwise_conv2d", "fc"};
+      "conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"};
   auto* scope = param_scope();
 
   for (auto& quant_type : quant_types) {