diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc
index 6070a28874..ac06eec07e 100644
--- a/mindspore/lite/src/ops/resize.cc
+++ b/mindspore/lite/src/ops/resize.cc
@@ -97,11 +97,29 @@ int Resize::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
 namespace {
 constexpr int kInputRank = 4;
 }  // namespace
+template <typename T>
+void CalShape(const T *data, const std::vector<lite::Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
+  int input_count = inputs[0]->ElementsNum();
+  int index = 0;
+  int size = 1;
+  for (int i = 0; i < shape_size; i++) {
+    if (static_cast<int>(data[i]) == -1) {
+      index = i;
+    } else {
+      size *= data[i];
+    }
+    out_shape->push_back(data[i]);
+  }
+  if (static_cast<int>(data[index]) == -1) {
+    (*out_shape)[index] = input_count / size;
+  }
+}
+
 int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
   MS_ASSERT(this->primitive_ != nullptr);
   auto input = inputs_.front();
   if (input == nullptr) {
-    return 1;
+    return RET_ERROR;
   }
   if (input->shape().size() != kInputRank) {
     MS_LOG(ERROR) << "Size of input shape is wrong.";
@@ -110,20 +128,58 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
   auto output = outputs_.front();
   if (output == nullptr) {
-    return 1;
+    return RET_NULL_PTR;
   }
   output->set_data_type(input->data_type());
   output->SetFormat(input->GetFormat());
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto new_height = GetNewHeight();
-  auto new_width = GetNewWidth();
 
   std::vector<int> output_shape;
   output_shape.push_back(input->Batch());
-  output_shape.push_back(new_height);
-  output_shape.push_back(new_width);
+  if (inputs_.size() == kDoubleNum) {
+    auto shape_tensor = inputs_.at(1);
+    if (shape_tensor->data_c() == nullptr) {
+      MS_LOG(INFO) << "Do infer shape in runtime.";
+      return RET_INFER_INVALID;
+    }
+    size_t shape_size = shape_tensor->ElementsNum();
+    switch (shape_tensor->data_type()) {
+      case kNumberTypeInt8: {
+        auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData());
+        CalShape(data, inputs_, &output_shape, shape_size);
+      } break;
+      case kNumberTypeInt32: {
+        auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData());
+        CalShape(data, inputs_, &output_shape, shape_size);
+      } break;
+      case kNumberTypeInt64: {
+        auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData());
+        CalShape(data, inputs_, &output_shape, shape_size);
+      } break;
+      case kNumberTypeFloat: {
+        auto data = reinterpret_cast<float *>(shape_tensor->MutableData());
+        CalShape(data, inputs_, &output_shape, shape_size);
+      } break;
+      case kNumberTypeUInt32: {
+        auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData());
+        CalShape(data, inputs_, &output_shape, shape_size);
+      } break;
+      default: {
+        MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
+        return RET_INFER_ERR;
+      }
+    }
+  } else if (inputs_.size() == kSingleNum) {
+    auto new_height = GetNewHeight();
+    auto new_width = GetNewWidth();
+    output_shape.push_back(new_height);
+    output_shape.push_back(new_width);
+  } else {
+    MS_LOG(ERROR) << "inputs tensor size invalid.";
+    return RET_INFER_ERR;
+  }
   output_shape.push_back(input->Channel());
   output->set_shape(output_shape);
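
Review note: CalShape appears to be adapted from the Reshape infer-shape helper (the leftover "Reshape weight tensor" error message in the default branch points the same way). It copies the requested dimensions into out_shape and resolves at most one -1 wildcard from the first input's element count. Below is a standalone sketch of that resolution logic, assuming a plain element count in place of the lite::Tensor inputs; CalShapeSketch and main are illustrative, not part of the patch.

// Standalone sketch of CalShape's wildcard resolution (illustrative only).
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
void CalShapeSketch(const T *data, int element_count, std::vector<int> *out_shape, int shape_size) {
  int index = 0;
  int size = 1;
  for (int i = 0; i < shape_size; i++) {
    if (static_cast<int>(data[i]) == -1) {
      index = i;  // remember where the wildcard sits
    } else {
      size *= data[i];  // accumulate the product of the known dimensions
    }
    out_shape->push_back(data[i]);
  }
  if (static_cast<int>(data[index]) == -1) {
    (*out_shape)[index] = element_count / size;  // infer the missing dimension
  }
}

int main() {
  std::vector<int> shape;
  const int32_t dims[] = {-1, 4};
  CalShapeSketch(dims, 24, &shape, 2);           // 24 elements with shape {-1, 4}
  std::printf("%d x %d\n", shape[0], shape[1]);  // prints "6 x 4"
  return 0;
}

For Resize the second input normally carries just {new_height, new_width}, so the wildcard path is inherited rather than exercised here. Note the helper assumes at most one -1: if several appear, only the last one is resolved.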
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc
index cbbb3d0c63..6ba82a7b05 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc
@@ -85,16 +85,22 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
     return RET_NULL_PTR;
   }
   auto buffData = reinterpret_cast<int32_t *>(buff->data.data());
-  auto height = buffData[0];
-  auto width = buffData[1];
-  attr->newWidth = width;
-  attr->newHeight = height;
+  if (buffData != nullptr) {
+    auto height = buffData[0];
+    auto width = buffData[1];
+    attr->newWidth = width;
+    attr->newHeight = height;
+  }
 
   op->primitive->value.type = schema::PrimitiveType_Resize;
   op->primitive->value.value = attr.release();
 
   AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
              tflite_tensors.size(), schema::Format::Format_NHWC);
+  if (buffData == nullptr) {
+    AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
+               tflite_tensors.size(), schema::Format::Format_NHWC);
+  }
   AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
               tflite_tensors.size(), schema::Format::Format_NHWC);
   return RET_OK;
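
Review note: the parser now distinguishes a constant size tensor (buffData non-null, so the new height/width are baked into the attribute) from a dynamic one (null buffer, so the size tensor is forwarded as a second op input and resolved by the runtime InferShape path above). Below is a minimal sketch of that dispatch under simplified types; ResizeAttr, ParseResizeSize, and the addInput callback are stand-ins, not the converter's real API.

// Minimal sketch of the constant-vs-dynamic size dispatch (illustrative only).
#include <cstdint>
#include <functional>

struct ResizeAttr {
  int64_t newHeight = 0;
  int64_t newWidth = 0;
};

// buffData points at the TFLite size tensor's constant data, or is null when
// the size is only known at runtime.
void ParseResizeSize(const int32_t *buffData, ResizeAttr *attr,
                     const std::function<void(int)> &addInput) {
  addInput(0);  // the image tensor is always wired as input 0
  if (buffData != nullptr) {
    attr->newHeight = buffData[0];  // constant layout: {new_height, new_width}
    attr->newWidth = buffData[1];
  } else {
    addInput(1);  // dynamic size: forward the size tensor as a second input,
                  // matching the new runtime InferShape path in resize.cc
  }
}

int main() {
  ResizeAttr attr;
  auto record = [](int input_index) { (void)input_index; };
  const int32_t size_buf[] = {224, 224};
  ParseResizeSize(size_buf, &attr, record);  // constant: attr gets 224x224
  ParseResizeSize(nullptr, &attr, record);   // dynamic: inputs 0 and 1 wired
  return 0;
}

One thing worth double-checking in review: when buffData is non-null but buff->data holds fewer than two int32 values, buffData[1] reads out of bounds; guarding on buff->data.size() would be safer.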