1. compute threshold only once
2. fix anf_exporter bug: pool and concat ops may not be set into the metagraph
3. fix weight format pass returning an error during post-training quantization
4. make anf_exporter reentrant: do not set the PrimitiveT * to nullptr
pull/3815/head
xutianchun 5 years ago
parent 6c4ee3f3d1
commit fae78e11a7

@@ -98,7 +98,6 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
     }
     node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
-    primitiveT_value->SetPrimitiveT(nullptr);
     std::vector<schema::TensorT *> outputs;
     SetOpInputNode(cnode, metaGraphT.get(), node.get());
     SetOpOutputNode(outputs, metaGraphT.get(), node.get());
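
Note on the hunk above: a minimal standalone sketch (stand-in types, not the real schema classes) of why clearing the pointer broke reentrancy. The first Export() consumed the PrimitiveT, so a second export over the same FuncGraph saw nullptr.

// reentrancy_sketch.cc -- PrimitiveT/PrimitiveTValue here are simplified
// stand-ins for the real MindSpore Lite classes.
#include <cassert>

struct PrimitiveT { int op_type = 0; };

class PrimitiveTValue {
 public:
  explicit PrimitiveTValue(PrimitiveT *primt) : primitive_(primt) {}
  PrimitiveT *GetPrimitiveT() const { return primitive_; }
  void SetPrimitiveT(PrimitiveT *primt) { primitive_ = primt; }

 private:
  PrimitiveT *primitive_;  // non-owning; see the destructor change below
};

int main() {
  PrimitiveT prim;
  PrimitiveTValue value(&prim);

  // First export reads the primitive.
  assert(value.GetPrimitiveT() != nullptr);
  // The old code then called value.SetPrimitiveT(nullptr), so a second
  // export over the same graph found nothing to emit. With that line
  // removed, a second read still succeeds:
  assert(value.GetPrimitiveT() != nullptr);
  return 0;
}
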
@@ -113,24 +112,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       auto input_quant_params = primitiveT_value->GetInputQuantParams();
       if (input_quant_params.empty()) {
         MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty";
-        continue;
+      } else {
+        std::unique_ptr<schema::QuantParamT> input_quant_param =
+          std::make_unique<schema::QuantParamT>(input_quant_params[0]);
+        tensor_input->quantParams.emplace_back(std::move(input_quant_param));
       }
-      std::unique_ptr<schema::QuantParamT> input_quant_param =
-        std::make_unique<schema::QuantParamT>(input_quant_params[0]);
-      tensor_input->quantParams.emplace_back(std::move(input_quant_param));
       // output
       auto output_index = node->outputIndex[0];
       auto tensor_output = metaGraphT->allTensors[output_index].get();
       auto output_quant_params = primitiveT_value->GetOutputQuantParams();
       if (output_quant_params.empty()) {
         MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
-        continue;
+      } else {
+        std::unique_ptr<schema::QuantParamT> output_quant_param =
+          std::make_unique<schema::QuantParamT>(output_quant_params[0]);
+        tensor_output->quantParams.emplace_back(std::move(output_quant_param));
       }
-      std::unique_ptr<schema::QuantParamT> output_quant_param =
-        std::make_unique<schema::QuantParamT>(output_quant_params[0]);
-      tensor_output->quantParams.emplace_back(std::move(output_quant_param));
       //    // TensorType
       //    valuePtr = primitive->GetAttr(kInputTensorDataType);
       //    if (valuePtr != nullptr) {
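
The `continue` here was the actual bug: when a node's *input* quant param list was empty, it skipped the rest of the loop body, so the node's *output* quant param (and everything after it) was never written into the metagraph. A standalone sketch of the before/after control flow, using hypothetical stand-in types:

// continue_vs_else_sketch.cc -- hypothetical Node/Export stand-ins, C++17.
#include <iostream>
#include <optional>
#include <vector>

struct Node {
  std::vector<float> input_quant_params;   // possibly empty
  std::vector<float> output_quant_params;  // possibly empty
  std::optional<float> tensor_in, tensor_out;
};

void Export(std::vector<Node> &nodes) {
  for (auto &node : nodes) {
    if (node.input_quant_params.empty()) {
      std::cerr << "input quant params is empty\n";
      // old code: continue;  -> the output params below were silently dropped
    } else {
      node.tensor_in = node.input_quant_params[0];
    }
    if (node.output_quant_params.empty()) {
      std::cerr << "output quant params is empty\n";
    } else {
      node.tensor_out = node.output_quant_params[0];
    }
  }
}

int main() {
  // Empty input params, but a valid output param.
  std::vector<Node> nodes{{{}, {0.5f}, {}, {}}};
  Export(nodes);
  // With the fix the output param survives even though the input one is empty.
  std::cout << nodes[0].tensor_out.value_or(-1.0f) << "\n";  // prints 0.5
  return 0;
}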

@@ -26,8 +26,8 @@ namespace mindspore::lite {
 class PrimitiveTValue : public Value {
  public:
   explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {}
-  ~PrimitiveTValue() override { delete this->primitive; }
+  // Not responsible for freeing primitive; whoever created the dynamic memory must free it.
+  ~PrimitiveTValue() override = default;

   MS_DECLARE_PARENT(PrimitiveTValue, Value)
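
With the destructor defaulted, PrimitiveTValue becomes a non-owning view over the PrimitiveT. A sketch of the intended ownership convention (stand-in types again): the allocation site keeps ownership, and the wrapper only observes.

// ownership_sketch.cc -- stand-ins; not the real schema::PrimitiveT.
#include <memory>

struct PrimitiveT { int op_type = 0; };

class PrimitiveTValue {
 public:
  explicit PrimitiveTValue(PrimitiveT *primt) : primitive_(primt) {}
  ~PrimitiveTValue() = default;  // was: delete this->primitive;

  PrimitiveT *GetPrimitiveT() const { return primitive_; }

 private:
  PrimitiveT *primitive_;  // observer only; creator frees the memory
};

int main() {
  auto owner = std::make_unique<PrimitiveT>();  // creator owns the allocation
  PrimitiveTValue value(owner.get());           // wrapper only observes it
  (void)value.GetPrimitiveT();
  return 0;  // owner's unique_ptr frees the PrimitiveT exactly once
}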

@@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
     MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status;
     return status;
   }
-  if (this->quantType == QuantType_AwareTrainning) {
+  if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) {
     status = QuantDataFormatTrans(graphNode);
     if (status != 0) {
       MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
@@ -147,7 +147,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
   } else if (fmkType == converter::FmkType_TFLITE) {
     switch (node->quantType) {
       case QuantType_QUANT_NONE:
-      case QuantType_AwareTrainning: {
+      case QuantType_AwareTrainning:
+      case QuantType_PostTraining: {
         if (opType == schema::PrimitiveType_Conv2D) {
           weightTensor->format = schema::Format_KHWC;
         } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
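
Both weight_format_pass hunks treat QuantType_PostTraining the same as QuantType_AwareTrainning: Run() now also invokes QuantDataFormatTrans for post-training models, and the TFLite switch falls through to the shared format-selection body instead of hitting the error/default path. A condensed sketch of the fall-through pattern, with hypothetical enums:

// fallthrough_sketch.cc -- hypothetical enums mirroring the converter's names.
#include <iostream>

enum QuantType { QUANT_NONE, AwareTrainning, PostTraining, WeightQuant };
enum Format { KCHW, KHWC };

Format TfliteConvWeightFormat(QuantType quant_type) {
  switch (quant_type) {
    case QUANT_NONE:
    case AwareTrainning:
    case PostTraining:  // newly added label falls through to the shared body
      return KHWC;
    default:
      return KCHW;  // previously PostTraining landed here and the pass errored
  }
}

int main() {
  // Post-training quantized TFLite Conv2D weights now pick the same format
  // as aware-training ones instead of failing the pass.
  std::cout << (TfliteConvWeightFormat(PostTraining) == KHWC) << "\n";  // 1
  return 0;
}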

@@ -292,13 +292,32 @@ STATUS Calibrator::RecordMaxValue(std::string opName, vector<float> data,
 }

 STATUS Calibrator::ComputeThreshold() {
-  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
+  for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) {
     DivergInfo *info = iter->second.get();
     info->ComputeThreshold();
   }
-  for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) {
+  // node A's input may be node B's output; no need to re-compute node A's input quant param,
+  // which is the same as node B's output quant param
+  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
     DivergInfo *info = iter->second.get();
-    info->ComputeThreshold();
+    auto cnode = info->cnode;
+    bool already_computed = false;
+    auto input = cnode->input(1);
+    if (input->isa<mindspore::CNode>()) {
+      auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input);
+      for (const auto &output_diverg_info : output_diverg_info_) {
+        auto output_diverg_cnode = output_diverg_info.second->cnode;
+        if (output_diverg_cnode == input_cnode) {
+          *info = *(output_diverg_info.second);
+          info->cnode = cnode;
+          already_computed = true;
+          break;
+        }
+      }
+    }
+    if (!already_computed) {
+      info->ComputeThreshold();
+    }
   }
   return RET_OK;
 }
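
The reshuffled loops implement commit point 1, "compute threshold only once": output thresholds are computed first, and an input tensor that is simply another node's output copies that already-computed result instead of re-running the divergence search. A standalone sketch of the idea (hypothetical types; the real code matches CNode pointers rather than names):

// threshold_reuse_sketch.cc -- hypothetical DivergInfo stand-in, C++17.
#include <iostream>
#include <map>
#include <string>

struct DivergInfo {
  std::string producer;    // node that produced this tensor; "" for graph inputs
  double threshold = 0.0;
  int computations = 0;
  void ComputeThreshold() { threshold = 1.0; ++computations; }  // stand-in for the KL search
};

int main() {
  // Node B's output feeds node A's input.
  std::map<std::string, DivergInfo> output_info{{"B", {"B"}}};
  std::map<std::string, DivergInfo> input_info{{"A", {"B"}}};

  // Pass 1: compute every output threshold.
  for (auto &kv : output_info) kv.second.ComputeThreshold();

  // Pass 2: reuse an existing output threshold where possible.
  for (auto &kv : input_info) {
    DivergInfo &info = kv.second;
    auto it = output_info.find(info.producer);
    if (it != output_info.end()) {
      info.threshold = it->second.threshold;  // copy, don't recompute
    } else {
      info.ComputeThreshold();
    }
  }
  std::cout << "B computed " << output_info["B"].computations << " time(s)\n";  // 1
  return 0;
}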
