!9223 [MS][LITE] add posttraining model && fix static check error

From: @jianghui58
Reviewed-by: @zhanghaibo5,@HilbertDavid
Signed-off-by: @HilbertDavid
pull/9223/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit a36c0b8b61

@ -1 +1,3 @@
mobilenet.tflite
mobilenet.tflite 0.5
transformer_20200831_encoder_fp32.tflite 68
transformer_20200831_decoder_fp32.tflite 35

File diff suppressed because one or more lines are too long

@ -867,6 +867,9 @@ STATUS PostTrainingQuantizer::QuantNode() {
}
}
if (input_node->isa<mindspore::CNode>()) {
if (op_type == PrimitiveType_Gather) {
continue;
}
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitive_c == nullptr) {
@ -932,7 +935,8 @@ STATUS PostTrainingQuantizer::QuantNode() {
// do weight quant
auto weight = cnode->input(2);
bool perchannel = false;
if (op_type == PrimitiveType_Conv2D || op_type == PrimitiveType_DepthwiseConv2D) {
if (op_type == PrimitiveType_Conv2D || op_type == PrimitiveType_DepthwiseConv2D ||
op_type == PrimitiveType_FullConnection) {
perchannel = true;
}
DoWeightQuant(weight, primitive_c, perchannel);

@ -47,7 +47,10 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) {
} else {
curnode_quant_type = primitive_c->quant_type();
}
auto op_type = (schema::PrimitiveType)primitive_c->Type();
if (op_type == schema::PrimitiveType_Gather) {
continue;
}
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
auto is_graph_input = false;

@ -89,25 +89,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
auto cnode = std::dynamic_pointer_cast<CNode>(node);
auto type = NodePrimitiveType(cnode);
static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Nchw2Nhwc,
schema::PrimitiveType_Nhwc2Nchw,
schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add,
schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat,
schema::PrimitiveType_Split,
schema::PrimitiveType_TupleGetItem,
schema::PrimitiveType_Reshape,
schema::PrimitiveType_FullConnection,
schema::PrimitiveType_MatMul,
schema::PrimitiveType_Crop,
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_Activation,
schema::PrimitiveType_Transpose,
schema::PrimitiveType_Eltwise,
schema::PrimitiveType_LayerNorm,
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Mul,
schema::PrimitiveType_Pooling, schema::PrimitiveType_Concat, schema::PrimitiveType_Split,
schema::PrimitiveType_TupleGetItem, schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection,
schema::PrimitiveType_MatMul, schema::PrimitiveType_Crop, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_DeConv2D, schema::PrimitiveType_Activation, schema::PrimitiveType_Transpose,
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Gather, schema::PrimitiveType_LayerNorm,
};
bool contain = IsContain(int8OpList, type);
if (!contain) {

@ -34,8 +34,8 @@ static const char *DELIM_SLASH = "/";
void SaveFile(std::string path, void *buf, size_t size) {
std::ofstream ofs(path);
assert(true == ofs.good());
assert(true == ofs.is_open());
MS_ASSERT(ofs.good() == true);
MS_ASSERT(ofs.is_open() == true);
ofs.seekp(0, std::ios::beg);
ofs.write((const char *)buf, size);
@ -521,7 +521,7 @@ int NetTrain::RunNetTrain() {
}
void NetTrainFlags::InitInputDataList() {
char *saveptr1;
char *saveptr1 = nullptr;
char *input_list = new char[this->in_data_file_.length() + 1];
snprintf(input_list, this->in_data_file_.length() + 1, "%s", this->in_data_file_.c_str());
char *cur_input;

Loading…
Cancel
Save