@@ -152,7 +152,7 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
       PADDLE_ENFORCE_NE(
           any_op_output_name.empty(), true,
           platform::errors::NotFound("Operator before requantize operator "
-                                     "should has requantize input as output"));
+                                     "should have requantize input as output"));
 
       float requant_scale_out =
           boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
@@ -170,6 +170,59 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
                   found_requant_squash_count);
 }
 
+// requant-op squash if op has Scale_in, Scale_x, Scale_y attr
+// conv2d, fc, matmul
+void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::RequantOp requant_op_pattern{gpd.mutable_pattern(), "requant_op"};
+  requant_op_pattern();
+
+  int found_requant_squash_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "squash requantize-op ops pair";
+
+    GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, requant_op_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, requant_op_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, requant_op_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern);
+
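+    // Squash only when the requantize output is consumed by exactly one op.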
+    if (requant_out->outputs.size() == 1) {
+      std::string any_op_input_name;
+      for (auto name : any_op->Op()->InputNames())
+        for (auto input_name : any_op->Op()->Input(name))
+          if (input_name == requant_out->Name()) any_op_input_name = name;
+
+      PADDLE_ENFORCE_NE(
+          any_op_input_name.empty(), true,
+          platform::errors::NotFound("The operator after requantize operator "
+                                     "should have requantize output as input"));
+      float requant_scale_in =
+          boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
+
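+      // matmul keeps per-input scales in Scale_x/Scale_y rather than Scale_in.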
+      auto scale_name = "Scale_in";
+      if (any_op->Op()->Type() == "matmul")
+        scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
+
+      PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
+                        any_op->Op()->GetAttrIfExists<float>(scale_name),
+                        platform::errors::InvalidArgument(
+                            "The operator after requantize should have input "
+                            "scale equal to requantize output scale"));
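+      // Rewire any_op to read requant_in directly with the requantize input
+      // scale, then remove the redundant requantize op and its output.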
+      any_op->Op()->SetAttr(scale_name, requant_scale_in);
+      any_op->Op()->SetInput(any_op_input_name,
+                             std::vector<std::string>({requant_in->Name()}));
+      IR_NODE_LINK_TO(requant_in, any_op);
+      GraphSafeRemoveNodes(graph, {requant_op, requant_out});
+      found_requant_squash_count++;
+    }
+  };
+  gpd(graph, handler);
+  AddStatis(found_requant_squash_count);
+  PrettyLogDetail("--- squashed %d requantize ops",
+                  found_requant_squash_count);
+}
+
 void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
   GraphPatternDetector gpd;
   patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
@@ -379,6 +432,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   FindNodesToKeep(graph, &nodes_keep_counter);
   DequantQuantSquash(graph, &nodes_keep_counter);
   OpRequantSquash(graph);
+  RequantOpSquash(graph);
   ConvDequantSquash(graph);
   FcDequantSquash(graph);
   MultipleQuantizeSquash(graph);