[MSLITE][Develop] fix bug of lite infershape

pull/9464/head
yangruoqi713 4 years ago
parent 415539655d
commit c9b6f6d0d8

@ -111,15 +111,15 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tenso
MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
return RET_ERROR;
}
STATUS ret = RET_INFER_INVALID;
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](Tensor *tensor) {
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) {
auto shape = tensor->shape();
return std::all_of(shape.begin(), shape.end(), [](int dim) { return dim != -1; });
return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; });
});
if (infer_valid) {
primitive->set_infer_flag(!infer_shape_interrupt);
ret = primitive->InferShape(inputs, outputs);
if (!infer_valid) {
infer_shape_interrupt = true;
}
primitive->set_infer_flag(!infer_shape_interrupt);
auto ret = primitive->InferShape(inputs, outputs);
if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()))

Loading…
Cancel
Save