|
|
|
|
@ -90,22 +90,32 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = std::dynamic_pointer_cast<CNode>(node);
|
|
|
|
|
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
|
|
|
|
if (primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto type = (schema::PrimitiveType)primitive_c->Type();
|
|
|
|
|
MS_LOG(INFO) << "Primitive type: " << type;
|
|
|
|
|
static const std::vector<schema::PrimitiveType> uint8OpList = {
|
|
|
|
|
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
|
|
|
|
|
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
|
|
|
|
|
/*schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,*/
|
|
|
|
|
schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMul,
|
|
|
|
|
schema::PrimitiveType_Activation};
|
|
|
|
|
return IsContain(uint8OpList, type);
|
|
|
|
|
auto type = NodePrimitiveType(cnode);
|
|
|
|
|
static const std::vector<schema::PrimitiveType> int8OpList = {
|
|
|
|
|
schema::PrimitiveType_Nchw2Nhwc,
|
|
|
|
|
schema::PrimitiveType_Nhwc2Nchw,
|
|
|
|
|
schema::PrimitiveType_Conv2D,
|
|
|
|
|
schema::PrimitiveType_DepthwiseConv2D,
|
|
|
|
|
schema::PrimitiveType_Add,
|
|
|
|
|
schema::PrimitiveType_Pooling,
|
|
|
|
|
schema::PrimitiveType_Concat,
|
|
|
|
|
schema::PrimitiveType_Split,
|
|
|
|
|
schema::PrimitiveType_TupleGetItem,
|
|
|
|
|
schema::PrimitiveType_Reshape,
|
|
|
|
|
schema::PrimitiveType_FullConnection,
|
|
|
|
|
schema::PrimitiveType_MatMul,
|
|
|
|
|
schema::PrimitiveType_Crop,
|
|
|
|
|
schema::PrimitiveType_DeDepthwiseConv2D,
|
|
|
|
|
schema::PrimitiveType_DeConv2D,
|
|
|
|
|
schema::PrimitiveType_Activation,
|
|
|
|
|
schema::PrimitiveType_TupleGetItem,
|
|
|
|
|
};
|
|
|
|
|
bool contain = IsContain(int8OpList, type);
|
|
|
|
|
if (!contain) {
|
|
|
|
|
MS_LOG(INFO) << "not quant, " << cnode->fullname_with_scope()
|
|
|
|
|
<< " of type: " << schema::EnumNamePrimitiveType(type);
|
|
|
|
|
}
|
|
|
|
|
return contain;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
|
|
|
|
|
@ -431,6 +441,19 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
|
|
|
|
|
quantParam->clusters = clusters;
|
|
|
|
|
return clusters_index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) {
|
|
|
|
|
if (cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "cnode is null";
|
|
|
|
|
return schema::PrimitiveType_NONE;
|
|
|
|
|
}
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
|
|
|
|
if (primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive_c is null";
|
|
|
|
|
return schema::PrimitiveType_NONE;
|
|
|
|
|
}
|
|
|
|
|
return (schema::PrimitiveType)primitive_c->Type();
|
|
|
|
|
}
|
|
|
|
|
} // namespace quant
|
|
|
|
|
} // namespace lite
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|
|