|
|
|
@ -111,15 +111,15 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tenso
|
|
|
|
|
MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
STATUS ret = RET_INFER_INVALID;
|
|
|
|
|
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](Tensor *tensor) {
|
|
|
|
|
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) {
|
|
|
|
|
auto shape = tensor->shape();
|
|
|
|
|
return std::all_of(shape.begin(), shape.end(), [](int dim) { return dim != -1; });
|
|
|
|
|
return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; });
|
|
|
|
|
});
|
|
|
|
|
if (infer_valid) {
|
|
|
|
|
primitive->set_infer_flag(!infer_shape_interrupt);
|
|
|
|
|
ret = primitive->InferShape(inputs, outputs);
|
|
|
|
|
if (!infer_valid) {
|
|
|
|
|
infer_shape_interrupt = true;
|
|
|
|
|
}
|
|
|
|
|
primitive->set_infer_flag(!infer_shape_interrupt);
|
|
|
|
|
auto ret = primitive->InferShape(inputs, outputs);
|
|
|
|
|
if (ret == RET_INFER_INVALID) {
|
|
|
|
|
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_
|
|
|
|
|
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()))
|
|
|
|
|