!9780 [MS][LITE]Add DepthwiseConv multiplier support, fix group conv int8

From: @gongdaguo
Reviewed-by: @zhang_xue_tong
Signed-off-by: @zhang_xue_tong
pull/9780/MERGE
Committed by mindspore-ci-bot via Gitee (4 years ago)
commit 98d178cd76

@ -307,6 +307,12 @@ kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &i
return kernel;
}
// Propagates every quantization parameter from src onto dst, so that the
// per-group sub-tensors created for group convolution retain the original
// tensor's quantization info. Appends in order; dst's existing params (if
// any) are kept.
void CopyTensorQuantParam(lite::Tensor *dst, lite::Tensor *src) {
  for (const auto &quant_param : src->quant_params()) {
    dst->AddQuantParam(quant_param);
  }
}
kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive,
@ -359,6 +365,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "create input tensor failed.";
return nullptr;
}
CopyTensorQuantParam(in_tensor, inputs[kInputIndex]);
new_inputs.emplace_back(in_tensor);
// create new weight
@ -371,6 +378,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "create filter tensor failed.";
return nullptr;
}
CopyTensorQuantParam(filter_tensor, inputs[kWeightIndex]);
new_inputs.emplace_back(filter_tensor);
// if has bias, create new bias
@ -383,6 +391,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "create bias_tensor failed.";
return nullptr;
}
CopyTensorQuantParam(bias_tensor, inputs[kBiasIndex]);
new_inputs.emplace_back(bias_tensor);
}
@ -395,6 +404,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "new out_tensor failed.";
return nullptr;
}
CopyTensorQuantParam(out_tensor, outputs[j]);
new_outputs.emplace_back(out_tensor);
}
group_convs.emplace_back(CpuConvInt8KernelSelect(
@ -412,6 +422,15 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel = nullptr;
if (primitive != nullptr && primitive->infer_flag()) {
conv_param->input_h_ = inputs.front()->Height();
conv_param->input_w_ = inputs.front()->Width();
conv_param->input_channel_ = inputs.front()->Channel();
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
conv_param->output_channel_ = outputs.front()->Channel();
conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
}
if (conv_param->group_ == 1) {
kernel = CpuConvInt8KernelSelect(inputs, outputs, opParameter, ctx, primitive);
} else {

@ -8,6 +8,8 @@ mobilenetv2-7.onnx
shufflenet-v2-10.onnx
squeezenet1.1-7.onnx
densenet-9.onnx
ml_table_detection_fp32.onnx
ml_table_segment.onnx
googlenet-9.onnx
inception-v1-9.onnx
inception-v2-9.onnx

@ -8,6 +8,8 @@ mobilenetv2-7.onnx 8
shufflenet-v2-10.onnx 5
squeezenet1.1-7.onnx 1
densenet-9.onnx 6
ml_table_detection_fp32.onnx 2
ml_table_segment.onnx 2
googlenet-9.onnx 3
inception-v1-9.onnx 3
inception-v2-9.onnx 4

@ -59,18 +59,27 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
MS_LOG(DEBUG) << "the tensor's shape is dynamic.";
return true;
}
auto conv_attr = std::make_unique<schema::Conv2DT>();
if (conv_attr == nullptr) {
MS_LOG(ERROR) << "conv_attr is null";
auto weight_data_node = depthwise_cnode->input(kConvWeightIndex)->abstract();
if (weight_data_node == nullptr) {
MS_LOG(ERROR) << "the weight node input is invalid.";
return false;
}
if (data_shape[3] == 1) {
auto weight_shape = utils::cast<abstract::ShapePtr>(weight_data_node->GetShapeTrack())->shape();
if (weight_shape.empty()) {
MS_LOG(DEBUG) << "the weight's shape is dynamic.";
return true;
}
if ((data_shape[3] == 1) || (data_shape[3] != weight_shape[3])) {
auto conv_attr = std::make_unique<schema::Conv2DT>();
if (conv_attr == nullptr) {
MS_LOG(ERROR) << "conv_attr is null";
return false;
}
conv_attr->channelIn = data_shape[3];
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
conv_attr->channelOut = weight_shape[3];
// update attr
conv_attr->group = 1;
conv_attr->group = data_shape[3];
conv_attr->format = attr->format;
conv_attr->kernelH = attr->kernelH;
conv_attr->kernelW = attr->kernelW;

Loading…
Cancel
Save