|
|
|
@ -18,23 +18,19 @@
|
|
|
|
|
#include "tools/optimizer/common/gllo_utils.h"
|
|
|
|
|
|
|
|
|
|
using mindspore::lite::converter::FmkType_CAFFE;
|
|
|
|
|
using mindspore::lite::converter::FmkType_TFLITE;
|
|
|
|
|
using mindspore::lite::converter::FmkType_ONNX;
|
|
|
|
|
using mindspore::lite::converter::FmkType_MS;
|
|
|
|
|
using mindspore::schema::QuantType_WeightQuant;
|
|
|
|
|
using mindspore::schema::QuantType_QUANT_NONE;
|
|
|
|
|
using mindspore::lite::converter::FmkType_ONNX;
|
|
|
|
|
using mindspore::lite::converter::FmkType_TFLITE;
|
|
|
|
|
using mindspore::schema::QuantType_AwareTraining;
|
|
|
|
|
using mindspore::schema::QuantType_PostTraining;
|
|
|
|
|
using mindspore::schema::QuantType_QUANT_NONE;
|
|
|
|
|
using mindspore::schema::QuantType_WeightQuant;
|
|
|
|
|
namespace mindspore::opt {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr size_t kConvWeightIndex = 2;
|
|
|
|
|
} // namespace
|
|
|
|
|
void WeightFormatHardCodePass::SetQuantType(QuantType type) {
|
|
|
|
|
this->quant_type = type;
|
|
|
|
|
}
|
|
|
|
|
void WeightFormatHardCodePass::SetFmkType(FmkType type) {
|
|
|
|
|
this->fmk_type = type;
|
|
|
|
|
}
|
|
|
|
|
void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; }
|
|
|
|
|
void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; }
|
|
|
|
|
lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node,
|
|
|
|
|
const ParamValueLitePtr ¶m_value) const {
|
|
|
|
|
MS_ASSERT(conv_cnode != nullptr);
|
|
|
|
@ -42,11 +38,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node
|
|
|
|
|
switch (quant_type) {
|
|
|
|
|
case schema::QuantType_PostTraining:
|
|
|
|
|
case QuantType_WeightQuant:
|
|
|
|
|
case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW);
|
|
|
|
|
case QuantType_QUANT_NONE:
|
|
|
|
|
param_value->set_format(schema::Format::Format_KCHW);
|
|
|
|
|
break;
|
|
|
|
|
default: {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -68,12 +65,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
|
|
|
|
|
} else if (op_type == schema::PrimitiveType_DeConv2D) {
|
|
|
|
|
param_value->set_format(schema::Format::Format_KCHW);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
} break;
|
|
|
|
|
case QuantType_PostTraining:
|
|
|
|
|
case QuantType_WeightQuant:
|
|
|
|
|
case QuantType_QUANT_NONE: {
|
|
|
|
@ -81,19 +77,18 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
|
|
|
|
|
// depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
|
|
|
|
|
// deconv (C x K/group x kH x kW) group = 1
|
|
|
|
|
// dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
|
|
|
|
|
if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D
|
|
|
|
|
|| op_type == schema::PrimitiveType_DeConv2D) {
|
|
|
|
|
if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D ||
|
|
|
|
|
op_type == schema::PrimitiveType_DeConv2D) {
|
|
|
|
|
param_value->set_format(schema::Format::Format_KCHW);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
} break;
|
|
|
|
|
default: {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -114,8 +109,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
|
|
|
|
|
} else {
|
|
|
|
|
param_value->set_format(schema::Format::Format_KCHW);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
} break;
|
|
|
|
|
case QuantType_PostTraining:
|
|
|
|
|
case QuantType_WeightQuant:
|
|
|
|
|
case QuantType_QUANT_NONE: {
|
|
|
|
@ -124,18 +118,19 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
|
|
|
|
|
param_value->set_format(schema::Format::Format_KCHW);
|
|
|
|
|
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
|
|
|
|
|
param_value->set_format(schema::Format::Format_CKHW);
|
|
|
|
|
} else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
|
|
|
|
|
param_value->set_format(schema::Format::Format_CKHW);
|
|
|
|
|
} else if (op_type == schema::PrimitiveType_DeConv2D) {
|
|
|
|
|
param_value->set_format(schema::Format::Format_KCHW);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
} break;
|
|
|
|
|
default: {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -159,15 +154,14 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_nod
|
|
|
|
|
} else if (op_type == schema::PrimitiveType_DeConv2D) {
|
|
|
|
|
param_value->set_format(schema::Format::Format_CHWK);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
} break;
|
|
|
|
|
default: {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
|
|
|
|
|
<< conv_node->fullname_with_scope();
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
|
|
|
|
|
<< ", node: " << conv_node->fullname_with_scope();
|
|
|
|
|
return lite::RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -183,8 +177,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
|
|
|
|
|
}
|
|
|
|
|
auto conv_cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto type = opt::GetCNodeType(node);
|
|
|
|
|
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
|
|
|
|
|
&& type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
|
|
|
|
|
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
|
|
|
|
|
type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
|
|
|
|
@ -197,15 +191,20 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
|
|
|
|
|
}
|
|
|
|
|
lite::STATUS status;
|
|
|
|
|
switch (fmk_type) {
|
|
|
|
|
case FmkType_CAFFE:status = HardCodeCAFFE(node, param_value);
|
|
|
|
|
case FmkType_CAFFE:
|
|
|
|
|
status = HardCodeCAFFE(node, param_value);
|
|
|
|
|
break;
|
|
|
|
|
case FmkType_TFLITE:status = HardCodeTFLITE(node, param_value);
|
|
|
|
|
case FmkType_TFLITE:
|
|
|
|
|
status = HardCodeTFLITE(node, param_value);
|
|
|
|
|
break;
|
|
|
|
|
case FmkType_ONNX:status = HardCodeONNX(node, param_value);
|
|
|
|
|
case FmkType_ONNX:
|
|
|
|
|
status = HardCodeONNX(node, param_value);
|
|
|
|
|
break;
|
|
|
|
|
case FmkType_MS:status = HardCodeMS(node, param_value);
|
|
|
|
|
case FmkType_MS:
|
|
|
|
|
status = HardCodeMS(node, param_value);
|
|
|
|
|
break;
|
|
|
|
|
default:MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (status != lite::RET_OK) {
|
|
|
|
|