|
|
|
@ -32,22 +32,24 @@ using std::vector;
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
namespace quant {
|
|
|
|
|
const std::array<std::string, 4> QuantStrategy::mConvTypes = {
|
|
|
|
|
{"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}};
|
|
|
|
|
const std::array<std::string, 4> QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}};
|
|
|
|
|
|
|
|
|
|
const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
|
|
|
|
|
schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
|
|
|
|
|
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
|
|
|
|
|
const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {
|
|
|
|
|
schema::PrimitiveType_Mul, schema::PrimitiveType_MatMul, schema::PrimitiveType_FullConnection};
|
|
|
|
|
QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
|
|
|
|
|
: mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}
|
|
|
|
|
|
|
|
|
|
bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
for (i = 0; i < mConvTypes.size(); i++) {
|
|
|
|
|
if (node->fullname_with_scope().find(mConvTypes[i]) == 0) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
|
|
|
|
|
if (primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive_c is nullptr";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((i == mConvTypes.size()) || (node->size() < 3)) {
|
|
|
|
|
if (!IsContain(conv_types, (schema::PrimitiveType)primitive_c->Type())) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (node->size() < 3) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -107,13 +109,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
for (i = 0; i < mMulTypes.size(); i++) {
|
|
|
|
|
if (node->fullname_with_scope().find(mMulTypes[i]) == 0) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
|
|
|
|
|
if (primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive_c is nullptr";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (i == mMulTypes.size()) {
|
|
|
|
|
|
|
|
|
|
if (!IsContain(mul_types, (schema::PrimitiveType)primitive_c->Type())) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|