From 65488f4c10daacf360c1b5f3a740e00fb5acb03b Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Thu, 28 Jan 2021 16:28:26 +0800 Subject: [PATCH] fix bug when zero shape --- mindspore/lite/src/ops/conv2d.cc | 4 +- mindspore/lite/src/ops/merge.cc | 55 ++++++++++++++++--- mindspore/lite/src/ops/merge.h | 5 ++ mindspore/lite/src/ops/reshape.cc | 27 +++++---- .../src/runtime/kernel/arm/base/carry_data.cc | 25 ++++++--- .../src/runtime/kernel/arm/base/carry_data.h | 2 +- .../kernel/arm/fp32/tensorlist_stack_fp32.cc | 13 ++++- mindspore/lite/test/models_onnx.cfg | 1 + .../parser/onnx/onnx_nonzero_parser.cc | 4 +- 9 files changed, 103 insertions(+), 33 deletions(-) diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 66656db10f..5135891cea 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -406,8 +406,8 @@ int Conv2D::InferShape(std::vector inputs_, std::vector outp this->ConvInferShape(input_h, input_w, &output_h, &output_w); std::vector out_shape{input_tensor->shape()}; - out_shape.at(1) = output_h > 0 ? output_h : 1; - out_shape.at(2) = output_w > 0 ? output_w : 1; + out_shape.at(1) = output_h >= 0 ? output_h : 1; + out_shape.at(2) = output_w >= 0 ? output_w : 1; out_shape.at(3) = weight_tensor->shape()[0]; out_tensor->set_shape(out_shape); diff --git a/mindspore/lite/src/ops/merge.cc b/mindspore/lite/src/ops/merge.cc index c93a9e0442..fcf8b505b7 100644 --- a/mindspore/lite/src/ops/merge.cc +++ b/mindspore/lite/src/ops/merge.cc @@ -66,14 +66,25 @@ PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator); #endif -int Merge::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(inputs_.size() == 2 * outputs_.size()); - if (!infer_flag()) { - return RET_INFER_INVALID; +InferStatus Merge::AbleToInfer(const std::vector &inputs) { + for (auto &input : inputs) { + if (input->shape().empty()) { + return HasZeroShape; + } + if (input->root_tensor() != nullptr && input->root_tensor()->data_c() != nullptr) { + continue; + } + if (input->data_c() == nullptr) { + return NotAble; + } } - for (size_t i = 0; i < inputs_.size() / 2; i++) { - auto *input = inputs_[i]; - auto *output = outputs_[i]; + return Able; +} + +int Merge::Infer(const std::vector &inputs, const std::vector &outputs) { + for (size_t i = 0; i < inputs.size(); i++) { + auto *input = inputs[i]; + auto *output = outputs[i]; if (input == nullptr) { MS_LOG(ERROR) << "input tensor is nullptr"; return RET_ERROR; @@ -98,5 +109,35 @@ int Merge::InferShape(std::vector inputs_, std::vector outpu } return RET_OK; } + +int Merge::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(inputs_.size() == 2 * outputs_.size()); + for (size_t i = 0; i < outputs_.size(); ++i) { + outputs_[i]->set_data_type(inputs_[i]->data_type()); + } + if (!infer_flag()) { + return RET_INFER_INVALID; + } + + std::vector left_part_inputs{}; + left_part_inputs.assign(inputs_.begin(), inputs_.begin() + inputs_.size() / 2); + + std::vector right_part_inputs{}; + right_part_inputs.assign(inputs_.begin() + inputs_.size() / 2, inputs_.end()); + + if (AbleToInfer(left_part_inputs) == Able) { + return Infer(left_part_inputs, outputs_); + } + + if (AbleToInfer(right_part_inputs) == Able) { + return Infer(right_part_inputs, outputs_); + } + + if (AbleToInfer(left_part_inputs) == HasZeroShape && AbleToInfer(right_part_inputs) == HasZeroShape) { + return Infer(left_part_inputs, outputs_); + } + + return RET_INFER_INVALID; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/merge.h b/mindspore/lite/src/ops/merge.h index 446fc76e09..fa177913ec 100644 --- a/mindspore/lite/src/ops/merge.h +++ b/mindspore/lite/src/ops/merge.h @@ -24,6 +24,7 @@ namespace mindspore { namespace lite { +enum InferStatus { Able, NotAble, HasZeroShape }; class Merge : public PrimitiveC { public: @@ -37,6 +38,10 @@ class Merge : public PrimitiveC { int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; + + private: + static InferStatus AbleToInfer(const std::vector &inputs); + static int Infer(const std::vector &inputs, const std::vector &outputs); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 67956011bd..6e502a2fb6 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -116,12 +116,12 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector *out_shape) c for (size_t i = 0; i < in_tensor->shape().size(); i++) { in_shape_size *= in_tensor->shape().at(i); } - int64_t inferIndex = -1; - size_t out_shapeSize = 1; + int64_t infer_index = -1; + size_t out_shape_size = 1; for (size_t i = 0; i < out_shape->size(); i++) { if (out_shape->at(i) == -1) { - if (inferIndex == -1) { - inferIndex = i; + if (infer_index == -1) { + infer_index = i; } else { MS_LOG(ERROR) << "output shape should has no more than one dim which need infer"; return RET_INFER_ERR; @@ -130,18 +130,23 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector *out_shape) c MS_LOG(ERROR) << "output shape dim should be non-negative"; return RET_INFER_ERR; } else if (out_shape->at(i) == 0) { - out_shape->at(i) = in_tensor->shape().at(i); - out_shapeSize *= out_shape->at(i); + if (in_tensor->ElementsNum() != 0) { + out_shape->at(i) = in_tensor->shape().at(i); + out_shape_size *= out_shape->at(i); + } else { + out_shape_size = 0; + break; + } } else { - out_shapeSize *= out_shape->at(i); + out_shape_size *= out_shape->at(i); } } - if (inferIndex == -1 && out_shapeSize != in_shape_size) { - MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size; + if (infer_index == -1 && out_shape_size != in_shape_size) { + MS_LOG(ERROR) << "output shapeSize: " << out_shape_size << " should be equal to input shapeSize: " << in_shape_size; return RET_INFER_ERR; } - if (inferIndex != -1) { - out_shape->at(inferIndex) = in_shape_size / out_shapeSize; + if (infer_index != -1) { + out_shape->at(infer_index) = in_shape_size / out_shape_size; } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc index 9adc022b97..4d5b4f89a3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc @@ -39,7 +39,7 @@ int CarryDataKernel::MoveData(std::vector::iterator dst_begin, } lite::STATUS ret; if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) { - ret = MoveTensorLiteData(reinterpret_cast(dst_tensor), + ret = MoveTensorListData(reinterpret_cast(dst_tensor), reinterpret_cast(src_tensor)); } else { ret = MoveTensorData(dst_tensor, src_tensor); @@ -55,7 +55,13 @@ int CarryDataKernel::MoveData(std::vector::iterator dst_begin, int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) { if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() || !(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) { - MS_LOG(ERROR) << "input tensor and output tensor is incompatible"; + MS_LOG(ERROR) << "input tensor and output tensor is incompatible."; + MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs " + << "output tensor data_type: " << dst_tensor->data_type() + << "input tensor format: " << src_tensor->format() << " vs " + << "output tensor format: " << dst_tensor->format() << "input tensor shape: " << src_tensor->shape() + << " vs " + << "output tensor shape: " << dst_tensor->shape(); return RET_ERROR; } if (src_tensor->root_tensor() == nullptr) { @@ -83,18 +89,19 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_ return RET_OK; } -int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) { +int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) { // shape may change, because tensors.size() can be change in RunGraph if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) { MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible"; + MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs " + << "output tensor data_type: " << dst_tensor->data_type() + << "input tensor format: " << src_tensor->format() << " vs " + << "output tensor format: " << dst_tensor->format(); return RET_ERROR; } - if (dst_tensor->element_shape().empty()) { - dst_tensor->set_element_shape(src_tensor->element_shape()); - } else if (dst_tensor->element_shape() != src_tensor->element_shape()) { - MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible"; - return RET_ERROR; - } + // when tensorlist malloc is done. this need to check element_shape compatibility + dst_tensor->set_element_shape(src_tensor->element_shape()); + auto update_data_type = kTypeUnknown; auto dst_tensor_data_type = dst_tensor->tensors_data_type(); auto src_tensor_data_type = src_tensor->tensors_data_type(); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h index ba960b772b..d122d2562d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h @@ -34,7 +34,7 @@ class CarryDataKernel : public LiteKernel { int MoveData(std::vector::iterator dst_begin, std::vector::iterator dst_end, std::vector::iterator src_begin, std::vector::iterator src_limit); static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor); - static int MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor); + static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor); }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc index fe673e3dae..eb8a662789 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc @@ -146,6 +146,14 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector &shape) { } int TensorListStackCPUKernel::Run() { + if (dtype_ == kTypeUnknown) { + dtype_ = input0_->tensors_data_type(); +#ifdef ENABLE_FP16 + if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { + dtype_ = kNumberTypeFloat16; + } +#endif + } if (CheckParam() != RET_OK) { MS_LOG(ERROR) << "CheckParam failed!"; return RET_ERROR; @@ -169,7 +177,10 @@ int TensorListStackCPUKernel::Run() { MS_ASSERT(out_data != nullptr); for (int i = 0; i < num_element_; ++i) { auto in_ptr = input0_->GetTensor(i); - MS_ASSERT(in_ptr != nullptr); + if (in_ptr == nullptr) { + MS_LOG(DEBUG) << "no need to stack."; + continue; + } if (in_ptr->data_type() != kTypeUnknown) { int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_); auto in_data = in_ptr->data_c(); diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index c153f5e54e..9439d0a423 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -44,3 +44,4 @@ ml_video_edit_style_transfer_gongnongbing.onnx ml_video_edit_style_transfer_starry.onnx ml_video_edit_judge.onnx ml_video_edit_vignet.onnx +ssd_mobilenet_v1_10.onnx;1,383,640,3 diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc index 556a798ab1..ce9fc1194d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc @@ -23,7 +23,7 @@ namespace lite { lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { MS_LOG(DEBUG) << "onnx NonZeroParser"; - auto attr = std::make_unique(); + auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return nullptr; @@ -33,7 +33,7 @@ lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto & MS_LOG(ERROR) << "new primitive failed"; return nullptr; } - primitive->value.type = schema::PrimitiveType_NonZero; + primitive->value.type = schema::PrimitiveType_Where; primitive->value.value = attr.release(); return PrimitiveC::Create(primitive.release()); }