From c9b6f6d0d87b072105227f102c207d4f620692e1 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Thu, 3 Dec 2020 21:51:19 +0800 Subject: [PATCH] [MSLITE][Develop] fix bug of lite infershape --- mindspore/lite/src/scheduler.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 48f8a4602f..324378faee 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -111,15 +111,15 @@ int Scheduler::InferShape(const lite::Model *model, std::vector *tenso MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!"; return RET_ERROR; } - STATUS ret = RET_INFER_INVALID; - bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](Tensor *tensor) { + bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { auto shape = tensor->shape(); - return std::all_of(shape.begin(), shape.end(), [](int dim) { return dim != -1; }); + return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); }); - if (infer_valid) { - primitive->set_infer_flag(!infer_shape_interrupt); - ret = primitive->InferShape(inputs, outputs); + if (!infer_valid) { + infer_shape_interrupt = true; } + primitive->set_infer_flag(!infer_shape_interrupt); + auto ret = primitive->InferShape(inputs, outputs); if (ret == RET_INFER_INVALID) { MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(primitive->Type()))