|
|
|
@ -228,6 +228,62 @@ void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const {
|
|
|
|
|
found_fc_dequant_squash_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::MultipleQuantize multiple_quantize_pattern{gpd.mutable_pattern(),
|
|
|
|
|
"multiple_quantize"};
|
|
|
|
|
multiple_quantize_pattern();
|
|
|
|
|
|
|
|
|
|
int found_multiple_quantize_squash_count = 0;
|
|
|
|
|
int removed_quantize = 0;
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(4) << "fuse multiple quantize ops";
|
|
|
|
|
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, multiple_quantize_pattern);
|
|
|
|
|
|
|
|
|
|
auto* first_quant_op = *(std::find_if(
|
|
|
|
|
prev_out->outputs.begin(), prev_out->outputs.end(), [&](Node* node) {
|
|
|
|
|
return (node->IsOp() && node->Op()->Type() == "quantize");
|
|
|
|
|
}));
|
|
|
|
|
auto* first_quant_out = first_quant_op->outputs[0];
|
|
|
|
|
float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(scale, 0, platform::errors::InvalidArgument(
|
|
|
|
|
"Quantize scale should not be equal 0"));
|
|
|
|
|
|
|
|
|
|
for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) {
|
|
|
|
|
auto quant_op = prev_out->outputs[iter];
|
|
|
|
|
if (quant_op->IsOp() && quant_op->Op()->Type() == "quantize" &&
|
|
|
|
|
quant_op->id() != first_quant_op->id() &&
|
|
|
|
|
quant_op->Op()->GetAttrIfExists<float>("Scale") == scale) {
|
|
|
|
|
auto quant_out = quant_op->outputs[0];
|
|
|
|
|
auto last_op = quant_out->outputs[0];
|
|
|
|
|
|
|
|
|
|
std::string last_op_input_name;
|
|
|
|
|
for (auto name : last_op->Op()->InputNames())
|
|
|
|
|
for (auto input_name : last_op->Op()->Input(name))
|
|
|
|
|
if (input_name == quant_out->Name()) last_op_input_name = name;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
last_op_input_name.empty(), true,
|
|
|
|
|
platform::errors::NotFound("Operator after quantize operator "
|
|
|
|
|
"should has quantize output as input"));
|
|
|
|
|
last_op->Op()->SetInput(
|
|
|
|
|
last_op_input_name,
|
|
|
|
|
std::vector<std::string>({first_quant_out->Name()}));
|
|
|
|
|
IR_NODE_LINK_TO(first_quant_out, last_op);
|
|
|
|
|
GraphSafeRemoveNodes(graph, {quant_op, quant_out});
|
|
|
|
|
removed_quantize++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
found_multiple_quantize_squash_count++;
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
AddStatis(found_multiple_quantize_squash_count);
|
|
|
|
|
PrettyLogDetail("--- squashed %d quantize op", removed_quantize);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
graph,
|
|
|
|
@ -240,6 +296,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
DequantQuantSquash(graph, &nodes_keep_counter);
|
|
|
|
|
ConvDequantSquash(graph);
|
|
|
|
|
FcDequantSquash(graph);
|
|
|
|
|
MultipleQuantizeSquash(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|