|
|
|
@ -49,14 +49,14 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
|
|
|
|
|
AddStatis(found_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CPUQuantizeSquashPass::Squash(
|
|
|
|
|
void CPUQuantizeSquashPass::DequantQuantSquash(
|
|
|
|
|
Graph* graph,
|
|
|
|
|
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
|
|
|
|
|
squash_pattern();
|
|
|
|
|
|
|
|
|
|
int found_squash_count = 0;
|
|
|
|
|
int found_dequant_quant_count = 0;
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(4) << "squash requantize-quantize ops pair";
|
|
|
|
@ -96,7 +96,7 @@ void CPUQuantizeSquashPass::Squash(
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(dequant_in, next_op);
|
|
|
|
|
|
|
|
|
|
found_squash_count++;
|
|
|
|
|
found_dequant_quant_count++;
|
|
|
|
|
} else {
|
|
|
|
|
// squash dequantize-quantize to requantize op
|
|
|
|
|
OpDesc desc;
|
|
|
|
@ -116,13 +116,48 @@ void CPUQuantizeSquashPass::Squash(
|
|
|
|
|
IR_NODE_LINK_TO(dequant_in, requant_op);
|
|
|
|
|
IR_NODE_LINK_TO(requant_op, quant_out);
|
|
|
|
|
|
|
|
|
|
found_squash_count++;
|
|
|
|
|
found_dequant_quant_count++;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
AddStatis(found_squash_count);
|
|
|
|
|
AddStatis(found_dequant_quant_count);
|
|
|
|
|
PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
|
|
|
|
|
found_squash_count);
|
|
|
|
|
found_dequant_quant_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Merges a requantize op into the conv op that feeds it.
//
// For each match of the ConvRequant pattern in `graph` where the conv output
// node has exactly one output, the conv op takes over the requantize op's
// "Scale_out" attribute and is relinked to write directly to the requantize
// output. The intermediate conv output node and the requantize op are then
// handed to GraphSafeRemoveNodes. The number of squashed pairs is reported
// through AddStatis and PrettyLogDetail.
void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
  GraphPatternDetector detector;
  patterns::ConvRequant conv_requant_pattern{detector.mutable_pattern(),
                                             "conv_requant"};
  conv_requant_pattern();

  int squashed_pairs = 0;
  auto on_match = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
    VLOG(4) << "squash conv-requantize ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern);

    // Squash only when the conv output node has a single output; any other
    // match is left untouched.
    if (conv_out->outputs.size() != 1) {
      return;
    }

    // The conv op adopts the output scale of the requantize op it absorbs.
    const float absorbed_scale =
        boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
    conv_op->Op()->SetAttr("Scale_out", absorbed_scale);

    // Point the conv op's "Output" at the requantize output variable.
    const std::vector<std::string> rewired_output{requant_out->Name()};
    conv_op->Op()->SetOutput("Output", rewired_output);
    IR_NODE_LINK_TO(conv_op, requant_out);

    // The old conv output node and the requantize op are no longer needed.
    GraphSafeRemoveNodes(graph, {conv_out, requant_op});

    ++squashed_pairs;
  };

  detector(graph, on_match);
  AddStatis(squashed_pairs);
  PrettyLogDetail("--- squashed %d requantize with convs", squashed_pairs);
}
|
|
|
|
|
|
|
|
|
|
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
@ -131,7 +166,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
|
|
|
|
|
std::unordered_map<const Node*, int> nodes_keep_counter;
|
|
|
|
|
FindNodesToKeep(graph, &nodes_keep_counter);
|
|
|
|
|
Squash(graph, &nodes_keep_counter);
|
|
|
|
|
DequantQuantSquash(graph, &nodes_keep_counter);
|
|
|
|
|
ConvRequantSquash(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|