fix bug of quant export

pull/13760/head
chengxianbin 5 years ago
parent 9ae264f32a
commit 85b3389caa

@ -383,41 +383,23 @@ ExecutorPy::~ExecutorPy() {
ConfigManager::GetInstance().ResetConfig(); ConfigManager::GetInstance().ResetConfig();
} }
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport( void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
const std::string &phase_s) { std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table) {
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); std::string weight_name;
MS_EXCEPTION_IF_NULL(func_graph); auto x = root_node->input(1);
MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table; weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
auto filter = [](const AnfNodePtr &node) { } else {
return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) || weight_name = weight_node->cast<ParameterPtr>()->name();
IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
};
std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
auto is_quant_cnode = [](const AnfNodePtr &node) {
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
};
for (const auto &node : nodes) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() != 3) {
continue;
}
auto x = cnode->input(1);
auto weight = cnode->input(2);
if (!is_quant_cnode(weight)) {
continue;
}
// get parameter weight's name
cnode = weight->cast<CNodePtr>();
auto weight_node = cnode->input(2);
if (!weight_node->isa<Parameter>()) {
continue;
} }
auto weight_name = weight_node->cast<ParameterPtr>()->name();
// find the fakequant from input // find the fakequant from input
int64_t count = 0; int64_t count = 0;
const int64_t max_depth = 5; const int64_t max_depth = 5;
CNodePtr cnode = nullptr;
auto is_quant_cnode = [](const AnfNodePtr &node) {
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
};
while (!is_quant_cnode(x)) { while (!is_quant_cnode(x)) {
if (count >= max_depth) { if (count >= max_depth) {
break; break;
@ -429,28 +411,66 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
x = cnode->input(1); x = cnode->input(1);
count += 1; count += 1;
} }
if (x->isa<Parameter>()) { if (x->isa<Parameter>() || IsPrimitiveCNode(x, prim::kPrimLoad)) {
fake_quant_table[weight_name] = std::make_pair(nullptr, "input"); (*fake_quant_table)[weight_name] = std::make_pair(nullptr, "input");
} }
// get the fakequant parameter minq's name // get the fakequant parameter minq's name
if (!is_quant_cnode(x)) { if (!is_quant_cnode(x)) {
continue; return;
} }
cnode = x->cast<CNodePtr>(); cnode = x->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() != 4) { if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != 4) {
continue; return;
} }
auto fakequant_min_node = cnode->input(2); auto fakequant_min_node = cnode->input(2);
if (!fakequant_min_node->isa<Parameter>()) { if (!fakequant_min_node->isa<Parameter>() && !IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
continue; return;
}
std::string fakequant_min_node_name;
if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
} else {
fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
} }
auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value(); auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
if (!quant_op_value->isa<PrimitivePy>()) { if (!quant_op_value->isa<PrimitivePy>()) {
continue; return;
} }
auto quant_op = quant_op_value->cast<PrimitivePyPtr>(); auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); (*fake_quant_table)[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
}
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
const std::string &phase_s) {
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
auto filter = [](const AnfNodePtr &node) {
return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
};
std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
auto is_quant_cnode = [](const AnfNodePtr &node) {
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
};
for (const auto &node : nodes) {
auto root_node = node->cast<CNodePtr>();
if (root_node == nullptr || root_node->size() != 3) {
continue;
}
auto weight = root_node->input(2);
if (!is_quant_cnode(weight)) {
continue;
}
// get parameter weight's name
auto cnode = weight->cast<CNodePtr>();
auto weight_node = cnode->input(2);
if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
continue;
}
GetWeightInfo(root_node, weight_node, &fake_quant_table);
} }
return fake_quant_table; return fake_quant_table;

@ -110,6 +110,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
private: private:
ExecutorPy(); ExecutorPy();
void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors); void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
void GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table);
void GetGeBackendPolicy() const; void GetGeBackendPolicy() const;
// filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after
// 'validate' stage // 'validate' stage

@ -76,7 +76,7 @@ class ExportToQuantInferNetwork:
return network return network
def _get_quant_block(self, cell_core, activation, fake_quant_a_out): def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
"""convet network's quant subcell to deploy subcell""" """convert network's quant subcell to deploy subcell"""
# Calculate the scale and zero point # Calculate the scale and zero point
w_minq_name = cell_core.fake_quant_weight.minq.name w_minq_name = cell_core.fake_quant_weight.minq.name
np_type = mstype.dtype_to_nptype(self.data_type) np_type = mstype.dtype_to_nptype(self.data_type)
@ -129,7 +129,7 @@ class ExportToQuantInferNetwork:
if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)): if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
if cell_core.has_bias: if cell_core.has_bias:
bias = cell_core.bias.data.asnumpy() bias = cell_core.bias.data.asnumpy()
elif isinstance(cell_core, quant.Conv2dBnFoldQuant): elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv)):
weight, bias = quant_utils.fold_batchnorm(weight, cell_core) weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant): elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core) weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)

@ -1381,11 +1381,14 @@ class QuantBlock(Cell):
self.activation = activation self.activation = activation
self.has_act = activation is not None self.has_act = activation is not None
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
self.sub = P.Sub()
self.weight_offset = Parameter(np.zeros(shape=weight.shape, dtype=np.int8), name='weight_offset')
def construct(self, x): def construct(self, x):
x = self.quant(x) x = self.quant(x)
if self.has_bias: if self.has_bias:
x = self.core_op(x, self.weight) weight = self.sub(self.weight, self.weight_offset)
x = self.core_op(x, weight)
x = self.bias_add(x, self.bias) x = self.bias_add(x, self.bias)
else: else:
x = self.core_op(x, self.weight) x = self.core_op(x, self.weight)

@ -26,6 +26,7 @@ from src.config import config_quant
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument('--device_target', type=str, default=None, help='Run device target') parser.add_argument('--device_target', type=str, default=None, help='Run device target')
args_opt = parser.parse_args() args_opt = parser.parse_args()
@ -46,5 +47,9 @@ if __name__ == '__main__':
# export network # export network
print("============== Starting export ==============") print("============== Starting export ==============")
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32) inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
if args_opt.file_format == 'MINDIR':
export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO') export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO')
else:
export(network, inputs, file_name="mobilenet_quant", file_format='AIR',
quant_mode='AUTO', mean=0., std_dev=48.106)
print("============== End export ==============") print("============== End export ==============")

Loading…
Cancel
Save