!7631 fix post training quant with multi-output op

Merge pull request !7631 from xutianchun/mulit
pull/7631/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 5532de75ef

@ -132,7 +132,7 @@ struct DivergInfo {
std::vector<float> max_datas;
std::pair<float, float> percent_result{0.0, 0.0};
float scale_tmp = 0;
DivergInfo() = default;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
this->method_x = method_x;
this->cnode = cnode;
@ -187,13 +187,14 @@ class Calibrator {
STATUS AddQuantizedOp(CNodePtr node);
STATUS RecordMaxValue(const std::string &op_name, const std::vector<float> &data,
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
STATUS RecordMaxValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
STATUS UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
STATUS UpdateDataFrequency(const std::string &op_name, const std::vector<float> &data,
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
STATUS UpdateOutputDivergInverval(
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
void Dump();
STATUS ComputeThreshold();
@ -208,7 +209,7 @@ class Calibrator {
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *GetInputDivergInfo();
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *GetOutputDivergInfo();
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
private:
std::vector<std::vector<std::string>> images_; // multi_input, echo input has multi input data
@ -219,7 +220,7 @@ class Calibrator {
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> input_diverg_info_;
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> output_diverg_info_;
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;
size_t bit_num_;
int quant_max_;

@ -85,7 +85,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node =
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front());
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams()[i - 1]);
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,

@ -90,22 +90,32 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
return false;
}
auto cnode = std::dynamic_pointer_cast<CNode>(node);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
return false;
}
auto type = (schema::PrimitiveType)primitive_c->Type();
MS_LOG(INFO) << "Primitive type: " << type;
static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
/*schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,*/
schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMul,
schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type);
auto type = NodePrimitiveType(cnode);
static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Nchw2Nhwc,
schema::PrimitiveType_Nhwc2Nchw,
schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add,
schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat,
schema::PrimitiveType_Split,
schema::PrimitiveType_TupleGetItem,
schema::PrimitiveType_Reshape,
schema::PrimitiveType_FullConnection,
schema::PrimitiveType_MatMul,
schema::PrimitiveType_Crop,
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_Activation,
schema::PrimitiveType_TupleGetItem,
};
bool contain = IsContain(int8OpList, type);
if (!contain) {
MS_LOG(INFO) << "not quant, " << cnode->fullname_with_scope()
<< " of type: " << schema::EnumNamePrimitiveType(type);
}
return contain;
}
bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
@ -431,6 +441,19 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
quantParam->clusters = clusters;
return clusters_index;
}
schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "cnode is null";
return schema::PrimitiveType_NONE;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is null";
return schema::PrimitiveType_NONE;
}
return (schema::PrimitiveType)primitive_c->Type();
}
} // namespace quant
} // namespace lite
} // namespace mindspore

@ -287,6 +287,8 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
}
STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
schema::PrimitiveType NodePrimitiveType(CNodePtr cnode);
} // namespace quant
} // namespace lite
} // namespace mindspore

Loading…
Cancel
Save