diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index ba635eba09..ae22dbb2da 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -182,6 +182,16 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT if (input_node->isa()) { auto param_node = input_node->cast(); ConvertConvWeight(param_node); + auto abstractBase = param_node->abstract(); + MS_ASSERT(abstractBase != nullptr); + if (utils::isa(abstractBase)) { + auto abstractTensor = utils::cast(abstractBase); + MS_ASSERT(abstractTensor != nullptr); + if (utils::isa(abstractTensor->BuildShape())) { + auto dims = utils::cast(abstractTensor->BuildShape())->shape(); + attr->channelIn = dims[kAnfPopulaterOne]; + } + } } primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;