diff --git a/mindspore/lite/src/ops/deconvolution.cc b/mindspore/lite/src/ops/deconvolution.cc index 12aa1dcab1..ba840a49fa 100644 --- a/mindspore/lite/src/ops/deconvolution.cc +++ b/mindspore/lite/src/ops/deconvolution.cc @@ -35,7 +35,7 @@ int DeConv2D::InferShape(std::vector inputs_, std::vectorBatch(); int32_t output_h = 0; int32_t output_w = 0; - int32_t output_c = weight->Batch(); + int32_t output_c = weight->Channel(); auto deconv = GetAttribute(); int kernel_w = deconv->kernelW(); diff --git a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc index 273eb63c76..359341aa13 100644 --- a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc @@ -154,7 +154,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { weightTensor->format = schema::Format_CHWK; } else if (opType == schema::PrimitiveType_DeConv2D) { - weightTensor->format = schema::Format_KHWC; + weightTensor->format = schema::Format_CHWK; } else { MS_LOG(ERROR) << "unsupport format"; return -1; @@ -367,8 +367,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_KHWC) { // from tf - status = RET_OK; + } else if (weightTensor->format == schema::Format_CHWK) { // from tf + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; return -1; @@ -390,7 +390,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { return -1; } if (status == 0) { - node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; weightTensor->format = schema::Format_CKHW; } else { MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str(); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index c8e970ea2a..50ca889d87 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -49,10 +49,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit return RET_ERROR; } auto weight_shape = weight_tensor->shape; - attr->channelIn = weight_shape[KHWC_C]; - attr->channelOut = weight_shape[KHWC_K]; - attr->kernelW = weight_shape[KHWC_W]; - attr->kernelH = weight_shape[KHWC_H]; + attr->channelIn = weight_shape[CHWK_K]; + attr->channelOut = weight_shape[CHWK_C]; + attr->kernelW = weight_shape[CHWK_W]; + attr->kernelH = weight_shape[CHWK_H]; if (op != nullptr) { op->primitive = std::make_unique();