diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index 2558a4c634..c5242cf534 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -307,6 +312,12 @@ kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &i return kernel; } +void CopyTensorQuantParam(lite::Tensor *dst, lite::Tensor *src) { + for (size_t i = 0; i < src->quant_params().size(); i++) { + dst->AddQuantParam(src->quant_params().at(i)); + } +} + kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, @@ -359,6 +365,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> & MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); kernel::LiteKernel *kernel = nullptr; + if (primitive != nullptr && primitive->infer_flag()) { + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->input_channel_ = inputs.front()->Channel(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + conv_param->output_channel_ = outputs.front()->Channel(); + conv_param->op_parameter_.thread_num_ = ctx->thread_num_; + } if (conv_param->group_ == 1) { kernel = CpuConvInt8KernelSelect(inputs, outputs, opParameter, ctx, primitive); } else { diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index ce7db58afa..50ea66a130 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -8,6 +8,8 @@ mobilenetv2-7.onnx shufflenet-v2-10.onnx squeezenet1.1-7.onnx densenet-9.onnx +ml_table_detection_fp32.onnx +ml_table_segment.onnx googlenet-9.onnx inception-v1-9.onnx 
inception-v2-9.onnx diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg index 318037a9ae..01c32090f0 100644 --- a/mindspore/lite/test/models_onnx_fp16.cfg +++ b/mindspore/lite/test/models_onnx_fp16.cfg @@ -8,6 +8,8 @@ mobilenetv2-7.onnx 8 shufflenet-v2-10.onnx 5 squeezenet1.1-7.onnx 1 densenet-9.onnx 6 +ml_table_detection_fp32.onnx 2 +ml_table_segment.onnx 2 googlenet-9.onnx 3 inception-v1-9.onnx 3 inception-v2-9.onnx 4 diff --git a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc index 9b63d74a1e..656124f018 100644 --- a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc @@ -59,18 +59,27 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "the tensor's shape is dynamic."; return true; } - auto conv_attr = std::make_unique<schema::Conv2DT>(); - if (conv_attr == nullptr) { - MS_LOG(ERROR) << "conv_attr is null"; + auto weight_data_node = depthwise_cnode->input(kConvWeightIndex)->abstract(); + if (weight_data_node == nullptr) { + MS_LOG(ERROR) << "the weight node input is invalid."; return false; } - - if (data_shape[3] == 1) { + auto weight_shape = utils::cast<abstract::ShapePtr>(weight_data_node->GetShapeTrack())->shape(); + if (weight_shape.empty()) { + MS_LOG(DEBUG) << "the weight's shape is dynamic."; + return true; + } + if ((data_shape[3] == 1) || (data_shape[3] != weight_shape[3])) { + auto conv_attr = std::make_unique<schema::Conv2DT>(); + if (conv_attr == nullptr) { + MS_LOG(ERROR) << "conv_attr is null"; + return false; + } conv_attr->channelIn = data_shape[3]; - conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; + conv_attr->channelOut = weight_shape[3]; // update attr - conv_attr->group = 1; + conv_attr->group = data_shape[3]; conv_attr->format = attr->format; conv_attr->kernelH = attr->kernelH; conv_attr->kernelW = 
attr->kernelW;