|
|
|
@ -223,71 +223,44 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
|
|
|
|
|
found_requant_squash_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
|
|
|
|
|
"conv_dequant"};
|
|
|
|
|
conv_dequant_pattern();
|
|
|
|
|
|
|
|
|
|
int found_conv_dequant_squash_count = 0;
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(4) << "squash conv-dequant ops pair";
|
|
|
|
|
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);
|
|
|
|
|
|
|
|
|
|
// if conv2d has one output
|
|
|
|
|
// and there is no fuse residual connection
|
|
|
|
|
// because residual fusion does not support force output with fp32
|
|
|
|
|
if (conv_out->outputs.size() == 1 &&
|
|
|
|
|
!(conv_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))) {
|
|
|
|
|
conv_op->Op()->SetAttr("force_fp32_output", true);
|
|
|
|
|
conv_op->Op()->SetOutput("Output",
|
|
|
|
|
std::vector<std::string>({dequant_out->Name()}));
|
|
|
|
|
IR_NODE_LINK_TO(conv_op, dequant_out);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
|
|
|
|
|
found_conv_dequant_squash_count++;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
AddStatis(found_conv_dequant_squash_count);
|
|
|
|
|
PrettyLogDetail("--- squashed %d dequant with convs",
|
|
|
|
|
found_conv_dequant_squash_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// squash fc with dequant
|
|
|
|
|
void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const {
|
|
|
|
|
// squash dequant with previous op if that op has force_fp32_output attr
|
|
|
|
|
// conv2d, fc, matmul
|
|
|
|
|
void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::FcDequant fc_dequant_pattern{gpd.mutable_pattern(), "fc_dequant"};
|
|
|
|
|
fc_dequant_pattern();
|
|
|
|
|
patterns::OpDequant op_dequant_pattern{gpd.mutable_pattern(), "op_dequant"};
|
|
|
|
|
op_dequant_pattern();
|
|
|
|
|
|
|
|
|
|
int found_fc_dequant_squash_count = 0;
|
|
|
|
|
int found_op_dequant_squash_count = 0;
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(4) << "squash fc-dequant ops pair";
|
|
|
|
|
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, fc_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, fc_dequant_pattern);
|
|
|
|
|
|
|
|
|
|
// if fc has force_fp32_output attribute
|
|
|
|
|
if (fc_out->outputs.size() == 1) {
|
|
|
|
|
fc_op->Op()->SetAttr("force_fp32_output", true);
|
|
|
|
|
fc_op->Op()->SetOutput("Out",
|
|
|
|
|
std::vector<std::string>({dequant_out->Name()}));
|
|
|
|
|
IR_NODE_LINK_TO(fc_op, dequant_out);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {fc_out, dequant_op});
|
|
|
|
|
found_fc_dequant_squash_count++;
|
|
|
|
|
VLOG(4) << "squash op-dequant ops pair";
|
|
|
|
|
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, op_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, op_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern);
|
|
|
|
|
|
|
|
|
|
if (dequant_in->outputs.size() == 1) {
|
|
|
|
|
auto output_name = "Out";
|
|
|
|
|
if (any_op->Op()->Type() == "conv2d") {
|
|
|
|
|
// do not squash if fuse residual connection is true
|
|
|
|
|
// because residual fusion does not support force output with fp32
|
|
|
|
|
if (any_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))
|
|
|
|
|
return;
|
|
|
|
|
output_name = "Output";
|
|
|
|
|
}
|
|
|
|
|
any_op->Op()->SetAttr("force_fp32_output", true);
|
|
|
|
|
any_op->Op()->SetOutput(output_name,
|
|
|
|
|
std::vector<std::string>({dequant_out->Name()}));
|
|
|
|
|
IR_NODE_LINK_TO(any_op, dequant_out);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
|
|
|
|
|
found_op_dequant_squash_count++;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
AddStatis(found_fc_dequant_squash_count);
|
|
|
|
|
PrettyLogDetail("--- squashed %d dequant with fcs",
|
|
|
|
|
found_fc_dequant_squash_count);
|
|
|
|
|
AddStatis(found_op_dequant_squash_count);
|
|
|
|
|
PrettyLogDetail("--- squashed %d dequant with ops",
|
|
|
|
|
found_op_dequant_squash_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
|
|
|
|
@ -389,38 +362,6 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
|
|
|
|
|
found_dequant_scale_squash_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// squash dequant with dequant
|
|
|
|
|
void CPUQuantizeSquashPass::MatmulDequantSquash(Graph* graph) const {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::MatmulDequant matmul_dequant_pattern{gpd.mutable_pattern(),
|
|
|
|
|
"matmul_dequant"};
|
|
|
|
|
matmul_dequant_pattern();
|
|
|
|
|
|
|
|
|
|
int found_matmul_dequant_squash_count = 0;
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(4) << "squash matmul-dequant ops pair";
|
|
|
|
|
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, matmul_dequant_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, matmul_dequant_pattern);
|
|
|
|
|
|
|
|
|
|
if (matmul_out->outputs.size() == 1) {
|
|
|
|
|
matmul_op->Op()->SetAttr("force_fp32_output", true);
|
|
|
|
|
matmul_op->Op()->SetOutput(
|
|
|
|
|
"Out", std::vector<std::string>({dequant_out->Name()}));
|
|
|
|
|
IR_NODE_LINK_TO(matmul_op, dequant_out);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {matmul_out, dequant_op});
|
|
|
|
|
found_matmul_dequant_squash_count++;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
AddStatis(found_matmul_dequant_squash_count);
|
|
|
|
|
PrettyLogDetail("--- squashed %d dequant with matmul",
|
|
|
|
|
found_matmul_dequant_squash_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
graph,
|
|
|
|
@ -433,11 +374,9 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
DequantQuantSquash(graph, &nodes_keep_counter);
|
|
|
|
|
OpRequantSquash(graph);
|
|
|
|
|
RequantOpSquash(graph);
|
|
|
|
|
ConvDequantSquash(graph);
|
|
|
|
|
FcDequantSquash(graph);
|
|
|
|
|
OpDequantSquash(graph);
|
|
|
|
|
MultipleQuantizeSquash(graph);
|
|
|
|
|
DequantScaleSquash(graph);
|
|
|
|
|
MatmulDequantSquash(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|