Add conv reqantize squash (#18754)

* Add requantize squash

test=develop

* Add more precise tests
test=develop

* REname and REfactor tester

test=develop
padding_in_crf
joanna.wozna.intel 6 years ago committed by Tao Luo
parent c548e370f1
commit 492a00f53e

@ -1275,6 +1275,23 @@ PDNode *patterns::ConvConcatReLU::operator()() {
return relu_out;
}
PDNode *patterns::ConvRequant::operator()() {
// Create Operators
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto requant_op =
pattern->NewNode(requant_op_repr())->assert_is_op("requantize");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto requant_out = pattern->NewNode(requant_out_repr())
->AsOutput()
->assert_is_op_output("requantize", "Output");
conv_op->LinksTo({conv_out});
requant_op->LinksFrom({conv_out}).LinksTo({requant_out});
return requant_out;
}
PDNode *patterns::PriorBox::operator()() {
auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");

@ -796,6 +796,23 @@ struct ConvConcatReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out);
};
// Conv + Requant
// named nodes:
// conv_op, conv_out
// requant_op, requant_out
struct ConvRequant : public PatternBase {
ConvRequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_requant") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(requant_op);
PATTERN_DECL_NODE(requant_out);
};
// PriorBox operator
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image

@ -49,14 +49,14 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
AddStatis(found_count);
}
void CPUQuantizeSquashPass::Squash(
void CPUQuantizeSquashPass::DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
GraphPatternDetector gpd;
patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
squash_pattern();
int found_squash_count = 0;
int found_dequant_quant_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash requantize-quantize ops pair";
@ -96,7 +96,7 @@ void CPUQuantizeSquashPass::Squash(
IR_NODE_LINK_TO(dequant_in, next_op);
found_squash_count++;
found_dequant_quant_count++;
} else {
// squash dequantize-quantize to requantize op
OpDesc desc;
@ -116,13 +116,48 @@ void CPUQuantizeSquashPass::Squash(
IR_NODE_LINK_TO(dequant_in, requant_op);
IR_NODE_LINK_TO(requant_op, quant_out);
found_squash_count++;
found_dequant_quant_count++;
}
};
gpd(graph, handler);
AddStatis(found_squash_count);
AddStatis(found_dequant_quant_count);
PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
found_squash_count);
found_dequant_quant_count);
}
void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
GraphPatternDetector gpd;
patterns::ConvRequant conv_requant_pattern{gpd.mutable_pattern(),
"conv_requant"};
conv_requant_pattern();
int found_requant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash conv-requantize ops pair";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern);
GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern);
// if conv2d has one output squash
if (conv_out->outputs.size() == 1) {
float requant_scale_out =
boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
conv_op->Op()->SetAttr("Scale_out", requant_scale_out);
conv_op->Op()->SetOutput("Output",
std::vector<std::string>({requant_out->Name()}));
IR_NODE_LINK_TO(conv_op, requant_out);
GraphSafeRemoveNodes(graph, {conv_out, requant_op});
found_requant_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_requant_squash_count);
PrettyLogDetail("--- squashed %d requantize with convs",
found_requant_squash_count);
}
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
@ -131,7 +166,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
std::unordered_map<const Node*, int> nodes_keep_counter;
FindNodesToKeep(graph, &nodes_keep_counter);
Squash(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
ConvRequantSquash(graph);
}
} // namespace ir

@ -46,8 +46,14 @@ class CPUQuantizeSquashPass : public FusePassBase {
/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
void Squash(Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
void DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
/*
* Squash requantize op into conv with scale_out like requantize scale_out
*/
void ConvRequantSquash(Graph* graph) const;
const std::string name_scope_{"squash"};
};

Loading…
Cancel
Save