!11818 [MS][LITE]fix bug when run with zero shape

From: @mengyuanli
Reviewed-by: 
Signed-off-by:
pull/11818/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit cd1e341e2e

@ -406,8 +406,8 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
this->ConvInferShape(input_h, input_w, &output_h, &output_w); this->ConvInferShape(input_h, input_w, &output_h, &output_w);
std::vector<int> out_shape{input_tensor->shape()}; std::vector<int> out_shape{input_tensor->shape()};
out_shape.at(1) = output_h > 0 ? output_h : 1; out_shape.at(1) = output_h >= 0 ? output_h : 1;
out_shape.at(2) = output_w > 0 ? output_w : 1; out_shape.at(2) = output_w >= 0 ? output_w : 1;
out_shape.at(3) = weight_tensor->shape()[0]; out_shape.at(3) = weight_tensor->shape()[0];
out_tensor->set_shape(out_shape); out_tensor->set_shape(out_shape);

@ -66,14 +66,25 @@ PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC
Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator); Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
#endif #endif
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { InferStatus Merge::AbleToInfer(const std::vector<lite::Tensor *> &inputs) {
MS_ASSERT(inputs_.size() == 2 * outputs_.size()); for (auto &input : inputs) {
if (!infer_flag()) { if (input->shape().empty()) {
return RET_INFER_INVALID; return HasZeroShape;
}
if (input->root_tensor() != nullptr && input->root_tensor()->data_c() != nullptr) {
continue;
}
if (input->data_c() == nullptr) {
return NotAble;
}
} }
for (size_t i = 0; i < inputs_.size() / 2; i++) { return Able;
auto *input = inputs_[i]; }
auto *output = outputs_[i];
int Merge::Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) {
for (size_t i = 0; i < inputs.size(); i++) {
auto *input = inputs[i];
auto *output = outputs[i];
if (input == nullptr) { if (input == nullptr) {
MS_LOG(ERROR) << "input tensor is nullptr"; MS_LOG(ERROR) << "input tensor is nullptr";
return RET_ERROR; return RET_ERROR;
@ -98,5 +109,35 @@ int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
} }
return RET_OK; return RET_OK;
} }
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(inputs_.size() == 2 * outputs_.size());
for (size_t i = 0; i < outputs_.size(); ++i) {
outputs_[i]->set_data_type(inputs_[i]->data_type());
}
if (!infer_flag()) {
return RET_INFER_INVALID;
}
std::vector<Tensor *> left_part_inputs{};
left_part_inputs.assign(inputs_.begin(), inputs_.begin() + inputs_.size() / 2);
std::vector<Tensor *> right_part_inputs{};
right_part_inputs.assign(inputs_.begin() + inputs_.size() / 2, inputs_.end());
if (AbleToInfer(left_part_inputs) == Able) {
return Infer(left_part_inputs, outputs_);
}
if (AbleToInfer(right_part_inputs) == Able) {
return Infer(right_part_inputs, outputs_);
}
if (AbleToInfer(left_part_inputs) == HasZeroShape && AbleToInfer(right_part_inputs) == HasZeroShape) {
return Infer(left_part_inputs, outputs_);
}
return RET_INFER_INVALID;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -24,6 +24,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
enum InferStatus { Able, NotAble, HasZeroShape };
class Merge : public PrimitiveC { class Merge : public PrimitiveC {
public: public:
@ -37,6 +38,10 @@ class Merge : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
private:
static InferStatus AbleToInfer(const std::vector<lite::Tensor *> &inputs);
static int Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -116,12 +116,12 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) c
for (size_t i = 0; i < in_tensor->shape().size(); i++) { for (size_t i = 0; i < in_tensor->shape().size(); i++) {
in_shape_size *= in_tensor->shape().at(i); in_shape_size *= in_tensor->shape().at(i);
} }
int64_t inferIndex = -1; int64_t infer_index = -1;
size_t out_shapeSize = 1; size_t out_shape_size = 1;
for (size_t i = 0; i < out_shape->size(); i++) { for (size_t i = 0; i < out_shape->size(); i++) {
if (out_shape->at(i) == -1) { if (out_shape->at(i) == -1) {
if (inferIndex == -1) { if (infer_index == -1) {
inferIndex = i; infer_index = i;
} else { } else {
MS_LOG(ERROR) << "output shape should has no more than one dim which need infer"; MS_LOG(ERROR) << "output shape should has no more than one dim which need infer";
return RET_INFER_ERR; return RET_INFER_ERR;
@ -130,18 +130,23 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) c
MS_LOG(ERROR) << "output shape dim should be non-negative"; MS_LOG(ERROR) << "output shape dim should be non-negative";
return RET_INFER_ERR; return RET_INFER_ERR;
} else if (out_shape->at(i) == 0) { } else if (out_shape->at(i) == 0) {
out_shape->at(i) = in_tensor->shape().at(i); if (in_tensor->ElementsNum() != 0) {
out_shapeSize *= out_shape->at(i); out_shape->at(i) = in_tensor->shape().at(i);
out_shape_size *= out_shape->at(i);
} else {
out_shape_size = 0;
break;
}
} else { } else {
out_shapeSize *= out_shape->at(i); out_shape_size *= out_shape->at(i);
} }
} }
if (inferIndex == -1 && out_shapeSize != in_shape_size) { if (infer_index == -1 && out_shape_size != in_shape_size) {
MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size; MS_LOG(ERROR) << "output shapeSize: " << out_shape_size << " should be equal to input shapeSize: " << in_shape_size;
return RET_INFER_ERR; return RET_INFER_ERR;
} }
if (inferIndex != -1) { if (infer_index != -1) {
out_shape->at(inferIndex) = in_shape_size / out_shapeSize; out_shape->at(infer_index) = in_shape_size / out_shape_size;
} }
return RET_OK; return RET_OK;
} }

@ -39,7 +39,7 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
} }
lite::STATUS ret; lite::STATUS ret;
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) { if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
ret = MoveTensorLiteData(reinterpret_cast<lite::TensorList *>(dst_tensor), ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
reinterpret_cast<lite::TensorList *>(src_tensor)); reinterpret_cast<lite::TensorList *>(src_tensor));
} else { } else {
ret = MoveTensorData(dst_tensor, src_tensor); ret = MoveTensorData(dst_tensor, src_tensor);
@ -55,7 +55,13 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) { int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() || if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
!(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) { !(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
MS_LOG(ERROR) << "input tensor and output tensor is incompatible"; MS_LOG(ERROR) << "input tensor and output tensor is incompatible.";
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
<< "output tensor data_type: " << dst_tensor->data_type()
<< "input tensor format: " << src_tensor->format() << " vs "
<< "output tensor format: " << dst_tensor->format() << "input tensor shape: " << src_tensor->shape()
<< " vs "
<< "output tensor shape: " << dst_tensor->shape();
return RET_ERROR; return RET_ERROR;
} }
if (src_tensor->root_tensor() == nullptr) { if (src_tensor->root_tensor() == nullptr) {
@ -83,18 +89,19 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_
return RET_OK; return RET_OK;
} }
int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) { int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
// shape may change, because tensors.size() can be change in RunGraph // shape may change, because tensors.size() can be change in RunGraph
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) { if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible"; MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
<< "output tensor data_type: " << dst_tensor->data_type()
<< "input tensor format: " << src_tensor->format() << " vs "
<< "output tensor format: " << dst_tensor->format();
return RET_ERROR; return RET_ERROR;
} }
if (dst_tensor->element_shape().empty()) { // when tensorlist malloc is done. this need to check element_shape compatibility
dst_tensor->set_element_shape(src_tensor->element_shape()); dst_tensor->set_element_shape(src_tensor->element_shape());
} else if (dst_tensor->element_shape() != src_tensor->element_shape()) {
MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible";
return RET_ERROR;
}
auto update_data_type = kTypeUnknown; auto update_data_type = kTypeUnknown;
auto dst_tensor_data_type = dst_tensor->tensors_data_type(); auto dst_tensor_data_type = dst_tensor->tensors_data_type();
auto src_tensor_data_type = src_tensor->tensors_data_type(); auto src_tensor_data_type = src_tensor->tensors_data_type();

@ -34,7 +34,7 @@ class CarryDataKernel : public LiteKernel {
int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end, int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit); std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor); static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
static int MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor); static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel

@ -146,6 +146,14 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) {
} }
int TensorListStackCPUKernel::Run() { int TensorListStackCPUKernel::Run() {
if (dtype_ == kTypeUnknown) {
dtype_ = input0_->tensors_data_type();
#ifdef ENABLE_FP16
if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) {
dtype_ = kNumberTypeFloat16;
}
#endif
}
if (CheckParam() != RET_OK) { if (CheckParam() != RET_OK) {
MS_LOG(ERROR) << "CheckParam failed!"; MS_LOG(ERROR) << "CheckParam failed!";
return RET_ERROR; return RET_ERROR;
@ -169,7 +177,10 @@ int TensorListStackCPUKernel::Run() {
MS_ASSERT(out_data != nullptr); MS_ASSERT(out_data != nullptr);
for (int i = 0; i < num_element_; ++i) { for (int i = 0; i < num_element_; ++i) {
auto in_ptr = input0_->GetTensor(i); auto in_ptr = input0_->GetTensor(i);
MS_ASSERT(in_ptr != nullptr); if (in_ptr == nullptr) {
MS_LOG(DEBUG) << "no need to stack.";
continue;
}
if (in_ptr->data_type() != kTypeUnknown) { if (in_ptr->data_type() != kTypeUnknown) {
int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_); int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_);
auto in_data = in_ptr->data_c(); auto in_data = in_ptr->data_c();

@ -44,3 +44,4 @@ ml_video_edit_style_transfer_gongnongbing.onnx
ml_video_edit_style_transfer_starry.onnx ml_video_edit_style_transfer_starry.onnx
ml_video_edit_judge.onnx ml_video_edit_judge.onnx
ml_video_edit_vignet.onnx ml_video_edit_vignet.onnx
ssd_mobilenet_v1_10.onnx;1,383,640,3

@ -23,7 +23,7 @@ namespace lite {
lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) { const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx NonZeroParser"; MS_LOG(DEBUG) << "onnx NonZeroParser";
auto attr = std::make_unique<schema::NonZeroT>(); auto attr = std::make_unique<schema::WhereT>();
if (attr == nullptr) { if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed"; MS_LOG(ERROR) << "new op failed";
return nullptr; return nullptr;
@ -33,7 +33,7 @@ lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &
MS_LOG(ERROR) << "new primitive failed"; MS_LOG(ERROR) << "new primitive failed";
return nullptr; return nullptr;
} }
primitive->value.type = schema::PrimitiveType_NonZero; primitive->value.type = schema::PrimitiveType_Where;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release()); return PrimitiveC::Create(primitive.release());
} }

Loading…
Cancel
Save