Fix bug of resize parser.

pull/6475/head
wsc 5 years ago
parent 669491237c
commit 4b40757e79

@ -97,11 +97,29 @@ int Resize::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
namespace {
constexpr int kInputRank = 4;
} // namespace
template <typename T>
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
int input_count = inputs[0]->ElementsNum();
int index = 0;
int size = 1;
for (int i = 0; i < shape_size; i++) {
if (static_cast<int>(data[i]) == -1) {
index = i;
} else {
size *= data[i];
}
out_shape->push_back(data[i]);
}
if (static_cast<int>(data[index]) == -1) {
(*out_shape)[index] = input_count / size;
}
}
int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
if (input == nullptr) {
return 1;
return RET_ERROR;
}
if (input->shape().size() != kInputRank) {
MS_LOG(ERROR) << "Size of input shape is wrong.";
@ -110,20 +128,58 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
auto output = outputs_.front();
if (output == nullptr) {
return 1;
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto new_height = GetNewHeight();
auto new_width = GetNewWidth();
std::vector<int> output_shape;
output_shape.push_back(input->Batch());
output_shape.push_back(new_height);
output_shape.push_back(new_width);
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
if (shape_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
size_t shape_size = shape_tensor->ElementsNum();
switch (shape_tensor->data_type()) {
case kNumberTypeInt8: {
auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData());
CalShape<int8_t>(data, inputs_, &output_shape, shape_size);
} break;
case kNumberTypeInt32: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData());
CalShape<int32_t>(data, inputs_, &output_shape, shape_size);
} break;
case kNumberTypeInt64: {
auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData());
CalShape<int64_t>(data, inputs_, &output_shape, shape_size);
} break;
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->MutableData());
CalShape<float>(data, inputs_, &output_shape, shape_size);
} break;
case kNumberTypeUInt32: {
auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData());
CalShape<uint32_t>(data, inputs_, &output_shape, shape_size);
} break;
default: {
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
return RET_INFER_ERR;
}
}
} else if (inputs_.size() == kSingleNum) {
auto new_height = GetNewHeight();
auto new_width = GetNewWidth();
output_shape.push_back(new_height);
output_shape.push_back(new_width);
} else {
MS_LOG(ERROR) << "inputs tensor size invalid.";
return RET_INFER_ERR;
}
output_shape.push_back(input->Channel());
output->set_shape(output_shape);

@ -85,16 +85,22 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
auto buffData = reinterpret_cast<int32_t *>(buff->data.data());
auto height = buffData[0];
auto width = buffData[1];
attr->newWidth = width;
attr->newHeight = height;
if (buffData != nullptr) {
auto height = buffData[0];
auto width = buffData[1];
attr->newWidth = width;
attr->newHeight = height;
}
op->primitive->value.type = schema::PrimitiveType_Resize;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
if (buffData == nullptr) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
return RET_OK;

Loading…
Cancel
Save