[INT8] Add requant-op squash (#24143)

revert-24314-dev/fix_err_msg
joanna.wozna.intel 5 years ago committed by GitHub
parent e8869a907b
commit b43b46e619
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1508,6 +1508,27 @@ PDNode *patterns::OpRequant::operator()() {
return requant_out;
}
PDNode *patterns::RequantOp::operator()() {
auto requant_in = pattern->NewNode(requant_in_repr())
->assert_is_op_input("requantize", "Input");
auto requant_op =
pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
auto requant_out = pattern->NewNode(requant_out_repr())
->AsOutput()
->assert_is_op_output("requantize", "Output");
auto any_op = pattern->NewNode(any_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return (node->Op()->HasAttr("Scale_in") ||
node->Op()->HasAttr("Scale_x") ||
node->Op()->HasAttr("Scale_y"));
});
requant_op->LinksFrom({requant_in}).LinksTo({requant_out});
any_op->LinksFrom({requant_out});
return any_op;
}
PDNode *patterns::ConvDequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");

@ -913,6 +913,22 @@ struct OpRequant : public PatternBase {
PATTERN_DECL_NODE(requant_out);
};
// Requant + Op
// named nodes:
// requant_in, requant_op,
// requant_out, any_op
struct RequantOp : public PatternBase {
RequantOp(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "requant_op") {}
PDNode* operator()();
PATTERN_DECL_NODE(any_op);
PATTERN_DECL_NODE(requant_in);
PATTERN_DECL_NODE(requant_op);
PATTERN_DECL_NODE(requant_out);
};
// Conv + Dequant
// named nodes:
// conv_op, conv_out

@ -152,7 +152,7 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
PADDLE_ENFORCE_NE(
any_op_output_name.empty(), true,
platform::errors::NotFound("Operator before requantize operator "
"should has requantize input as output"));
"should have requantize input as output"));
float requant_scale_out =
boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
@ -170,6 +170,59 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
found_requant_squash_count);
}
// requant-op squash if op has Scale_in, Scale_x, Scale_y attr
// conv2d, fc, matmul
void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::RequantOp requant_op_pattern{gpd.mutable_pattern(), "requant_op"};
requant_op_pattern();
int found_requant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash requantize-op ops pair";
GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, requant_op_pattern);
GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, requant_op_pattern);
if (requant_out->outputs.size() == 1) {
std::string any_op_input_name;
for (auto name : any_op->Op()->InputNames())
for (auto input_name : any_op->Op()->Input(name))
if (input_name == requant_out->Name()) any_op_input_name = name;
PADDLE_ENFORCE_NE(
any_op_input_name.empty(), true,
platform::errors::NotFound("The operator after requantize operator "
"should have requantize output as input"));
float requant_scale_in =
boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
auto scale_name = "Scale_in";
if (any_op->Op()->Type() == "matmul")
scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
any_op->Op()->GetAttrIfExists<float>(scale_name),
platform::errors::InvalidArgument(
"The operator after requantize should have input "
"scale equal to requantize output scale"));
any_op->Op()->SetAttr(scale_name, requant_scale_in);
any_op->Op()->SetInput(any_op_input_name,
std::vector<std::string>({requant_in->Name()}));
IR_NODE_LINK_TO(requant_in, any_op);
GraphSafeRemoveNodes(graph, {requant_op, requant_out});
found_requant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_requant_squash_count);
PrettyLogDetail("--- squashed %d requantize ops",
found_requant_squash_count);
}
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
@ -379,6 +432,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FindNodesToKeep(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
OpRequantSquash(graph);
RequantOpSquash(graph);
ConvDequantSquash(graph);
FcDequantSquash(graph);
MultipleQuantizeSquash(graph);

@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void OpRequantSquash(Graph* graph) const;
/*
* Squash requantize op if the next operator's input scale can be updated
*/
void RequantOpSquash(Graph* graph) const;
/*
* Squash conv2d with dequant when dequant is the only op after conv2d
*/

Loading…
Cancel
Save