|
|
|
@ -26,27 +26,33 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
|
|
|
|
|
const auto& op_types_list =
|
|
|
|
|
Get<std::unordered_set<std::string>>("quantize_enabled_op_types");
|
|
|
|
|
for (const Node* n : graph->Nodes()) {
|
|
|
|
|
if (n->IsOp()) {
|
|
|
|
|
if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
|
|
|
|
|
n->id()) != excluded_ids_list.end())
|
|
|
|
|
continue;
|
|
|
|
|
auto* op = n->Op();
|
|
|
|
|
if (op->HasAttr("mkldnn_data_type") ||
|
|
|
|
|
op->HasProtoAttr("mkldnn_data_type")) {
|
|
|
|
|
// use_quantizer is no longer used
|
|
|
|
|
// assign value for compatibility
|
|
|
|
|
if (op->GetAttrIfExists<bool>("use_quantizer")) {
|
|
|
|
|
op->SetAttr("mkldnn_data_type", std::string("int8"));
|
|
|
|
|
}
|
|
|
|
|
if (std::find(op_types_list.begin(), op_types_list.end(), op->Type()) !=
|
|
|
|
|
op_types_list.end()) {
|
|
|
|
|
op->SetAttr("mkldnn_data_type", std::string("int8"));
|
|
|
|
|
op->SetAttr("use_quantizer", true);
|
|
|
|
|
}
|
|
|
|
|
Init(name_scope_, graph);
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
patterns::QuantizePlacement quantize_placement_pattern{gpd.mutable_pattern(),
|
|
|
|
|
"quantize_placement"};
|
|
|
|
|
quantize_placement_pattern(op_types_list);
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(op, op, quantize_placement_pattern);
|
|
|
|
|
|
|
|
|
|
if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
|
|
|
|
|
op->id()) != excluded_ids_list.end()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (op->Op()->HasAttr("mkldnn_data_type") ||
|
|
|
|
|
op->Op()->HasProtoAttr("mkldnn_data_type")) {
|
|
|
|
|
// use_quantizer is no longer used
|
|
|
|
|
// assign value for compatibility
|
|
|
|
|
if (op->Op()->GetAttrIfExists<bool>("use_quantizer")) {
|
|
|
|
|
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
|
|
|
|
|
}
|
|
|
|
|
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
|
|
|
|
|
op->Op()->SetAttr("use_quantizer", true);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
gpd(graph, handler);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
@ -58,10 +64,7 @@ REGISTER_PASS(cpu_quantize_placement_pass,
|
|
|
|
|
// a vector of operator type names to be quantized ("conv2d" etc.)
|
|
|
|
|
// the second param is the default value for this vector
|
|
|
|
|
.DefaultPassAttr("quantize_enabled_op_types",
|
|
|
|
|
new std::unordered_set<std::string>(
|
|
|
|
|
{"concat", "conv2d", "elementwise_add", "fc", "matmul",
|
|
|
|
|
"pool2d", "prior_box", "relu", "reshape2",
|
|
|
|
|
"transpose2"}))
|
|
|
|
|
new std::unordered_set<std::string>())
|
|
|
|
|
// a vector of operator ids that are to be excluded from quantization
|
|
|
|
|
// the second param is the default value for this vector
|
|
|
|
|
.DefaultPassAttr("quantize_excluded_op_ids", new std::unordered_set<int>());
|
|
|
|
|