[MSLITE][Develop] return RET_INFER_INVALID when infer_flag is false for op infershape

branch: pull/9506/head
author: yangruoqi713, 4 years ago
parent c9b6f6d0d8
commit 4fc003bfae
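
Across every hunk in this diff, InferShape now returns RET_INFER_INVALID instead of RET_OK when infer_flag() is false, i.e. when input shapes are not yet known at graph-build time. Before the change a caller could not distinguish "shapes inferred" from "inference skipped". The snippet below is a minimal sketch of how a caller might consume the new return value; the enum values, the simplified PrimitiveC interface, and the ScheduleNode helper are illustrative assumptions, not MindSpore Lite code.

```cpp
#include <vector>

// Return codes: values are assumptions for illustration, not the exact
// constants defined in the MindSpore Lite error-code header.
enum Ret : int { RET_OK = 0, RET_ERROR = -1, RET_INFER_INVALID = -2 };

struct Tensor;  // stands in for lite::Tensor

// Hypothetical primitive interface mirroring the InferShape signature in the hunks below.
struct PrimitiveC {
  virtual int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) = 0;
  virtual ~PrimitiveC() = default;
};

// Hypothetical caller (ScheduleNode is not a real MindSpore Lite function):
// RET_INFER_INVALID is treated as "shape unknown for now, re-infer at runtime"
// rather than as success or as a hard error.
int ScheduleNode(PrimitiveC *prim, std::vector<Tensor *> inputs, std::vector<Tensor *> outputs,
                 bool *need_runtime_infer) {
  int ret = prim->InferShape(inputs, outputs);
  if (ret == RET_INFER_INVALID) {
    *need_runtime_infer = true;  // defer: shapes become known once infer_flag() is true
    return RET_OK;               // scheduling itself succeeded, inference is just postponed
  }
  return ret;  // RET_OK with shapes set, or a genuine error code
}
```

Note that the hunks still set data type and format before the infer_flag() check, so the distinct return code only defers shape propagation; type and format information flows downstream either way.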

@@ -86,7 +86,7 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
   output->set_format(input->format());
   output->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   output->set_shape(input->shape());

@@ -74,7 +74,7 @@ int ArgMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
   output->set_format(input->format());
   output->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   std::vector<int> output_shape(input->shape());
   auto input_shape_size = input->shape().size();

@@ -72,7 +72,7 @@ int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
   output->set_format(input->format());
   output->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto input_shape_size = input->shape().size();
   auto axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();

@@ -45,7 +45,7 @@ int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite
   output->set_format(format);
   output->set_data_type(input0->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   if (input_shape0.size() > 10 || input_shape1.size() > 10) {
     int wrong_dim = input_shape0.size() > input_shape1.size() ? input_shape0.size() : input_shape1.size();

@@ -33,7 +33,7 @@ int ArithmeticSelf::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor
   output->set_format(input->format());
   output->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   output->set_shape(input->shape());
   return RET_OK;

@@ -77,7 +77,7 @@ int AudioSpectrogram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tens
   output->set_data_type(input->data_type());
   output->set_format(input->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto input_shape = input->shape();
   if (input_shape.size() != 2) {

@@ -98,7 +98,7 @@ int BatchToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
   outputs[0]->set_format(input->format());
   outputs[0]->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto input_shape = input->shape();
   if (input_shape.size() != kDimension_4d) {

@@ -80,7 +80,7 @@ int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *>
   outputs[0]->set_format(input->format());
   outputs[0]->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   std::vector<int32_t> dst_shape(GetDstShape());
   for (size_t i = 0; i < dst_shape.size(); ++i) {

@@ -93,7 +93,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
   output->set_data_type(static_cast<TypeId>(GetDstT()));
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   if (GetSrcT() != 0 && input->data_type() != GetSrcT()) {

@@ -98,7 +98,7 @@ int Concat::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
   output->set_data_type(input0->data_type());
   output->set_format(input0->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto input0_shape = inputs_.at(0)->shape();

@@ -83,7 +83,7 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
   out_tensor->set_data_type(static_cast<TypeId>(GetDataType()));
   out_tensor->set_format(in_tensor->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto in_data = reinterpret_cast<int *>(in_tensor->data_c());
   if (in_data == nullptr) {

@@ -391,7 +391,7 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
   pad_r_ = GetPadRight();
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto in_shape = input_tensor->shape();
   int input_h = in_shape.at(1);

@@ -71,7 +71,7 @@ int Crop::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
   outputs[0]->set_format(inputs[0]->format());
   outputs[0]->set_data_type(inputs[0]->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   outputs[0]->set_shape(inputs[1]->shape());
   return RET_OK;

@@ -317,7 +317,7 @@ int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::
   output->set_format(input->format());
   output->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   int32_t input_h = input->Height();
   int32_t input_w = input->Width();

@@ -138,7 +138,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
   output->set_format(input->format());
   output->set_data_type(input->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto in_shape = input->shape();
   int input_h = in_shape.at(1);

@@ -73,7 +73,7 @@ int DepthToSpace::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
   outputs[0]->set_data_type(input->data_type());
   outputs[0]->set_format(input->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto input_shape = input->shape();
   if (input_shape.size() != kDimension_4d) {

@@ -219,7 +219,7 @@ int DepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
   pad_r_ = GetPadRight();
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto in_shape = input->shape();
   int input_h = in_shape.at(1);

@@ -190,7 +190,7 @@ int DetectionPostProcess::InferShape(std::vector<lite::Tensor *> inputs_, std::v
   num_det->set_format(boxes->format());
   num_det->set_data_type(kNumberTypeFloat32);
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   const auto max_detections = GetMaxDetections();
   const auto max_classes_per_detection = GetMaxClassesPerDetection();

@@ -84,7 +84,7 @@ int Dropout::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
   auto output0 = outputs_.front();
   MS_ASSERT(output0 != nullptr);
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   output0->set_shape(input->shape());
   output0->set_data_type(input->data_type());

@@ -87,7 +87,7 @@ int DropoutGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   output->set_shape(input->shape());
   output->set_data_type(input->data_type());

@@ -70,7 +70,7 @@ int EmbeddingLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
   output->set_format(params_->format());
   output->set_data_type(params_->data_type());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto embedding_shape = params_->shape();

@@ -103,7 +103,7 @@ int ExpandDims::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
   output->set_data_type(input->data_type());
   output->set_format(input->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   int dim = GetDim();
   if (dim < 0) {

@@ -43,7 +43,7 @@ int FftImag::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
   output->set_data_type(TypeId::kNumberTypeFloat32);
   output->set_format(input->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   auto input_shape = input->shape();
   input_shape.pop_back();

@@ -43,7 +43,7 @@ int FftReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
   output->set_data_type(TypeId::kNumberTypeFloat32);
   output->set_format(input->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
  }
   auto input_shape = input->shape();
   input_shape.pop_back();

@@ -71,7 +71,7 @@ int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
   output->set_data_type(input->data_type());
   output->set_format(input->format());
   if (!infer_flag()) {
-    return RET_OK;
+    return RET_INFER_INVALID;
   }
   std::vector<int> output_shape;
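
All of the hunks above apply the same one-line change. For reference, here is a consolidated, self-contained sketch of the resulting InferShape pattern; SomeOp, the trimmed-down Tensor, and the enum values are illustrative assumptions rather than repository code.

```cpp
#include <vector>

namespace sketch {

// Return codes: the RET_INFER_INVALID value is assumed for illustration.
enum Ret : int { RET_OK = 0, RET_INFER_INVALID = -2 };

// Simplified stand-in for lite::Tensor with only the members used below.
struct Tensor {
  std::vector<int> shape_;
  int data_type_ = 0;
  int format_ = 0;
  const std::vector<int> &shape() const { return shape_; }
  void set_shape(const std::vector<int> &s) { shape_ = s; }
  int data_type() const { return data_type_; }
  void set_data_type(int t) { data_type_ = t; }
  int format() const { return format_; }
  void set_format(int f) { format_ = f; }
};

// Hypothetical identity-shape op (like AddN or ArithmeticSelf above) following the
// post-change pattern: type and format are propagated unconditionally, and only the
// shape computation is gated on infer_flag().
struct SomeOp {
  bool infer_flag_ = false;
  bool infer_flag() const { return infer_flag_; }

  int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
    auto *input = inputs.front();
    auto *output = outputs.front();
    output->set_format(input->format());
    output->set_data_type(input->data_type());
    if (!infer_flag()) {
      return RET_INFER_INVALID;  // shape deferred, no longer reported as RET_OK
    }
    output->set_shape(input->shape());
    return RET_OK;
  }
};

}  // namespace sketch
```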

Some files were not shown because too many files have changed in this diff.
