!7396 fix upsample opeartor of onnx to resize opeartor

Merge pull request !7396 from yankai10/merge_1016
pull/7396/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b00b9d2450

@ -127,9 +127,60 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
return RET_INFER_INVALID;
}
size_t shape_size = shape_tensor->ElementsNum();
auto data = reinterpret_cast<int32_t *>(shape_tensor->data_c());
for (size_t i = 0; i < shape_size; i++) {
output_shape.push_back(data[i]);
switch (shape_size) {
case kDimension_4d: {
if (shape_tensor->data_type() == kNumberTypeInt32) {
auto data = reinterpret_cast<int32_t *>(shape_tensor->data_c());
if (data == nullptr) {
MS_LOG(INFO) << "Resize op size can't cast int.";
return RET_INFER_INVALID;
}
switch (shape_tensor->GetFormat()) {
case schema::Format_NCHW:
output_shape.push_back(data[2] * input->Height());
output_shape.push_back(data[3] * input->Width());
break;
case schema::Format_NHWC:
output_shape.push_back(data[1] * input->Height());
output_shape.push_back(data[2] * input->Width());
break;
default:
MS_LOG(INFO) << "Resize don't support tensor format.";
return RET_INFER_INVALID;
}
} else if (shape_tensor->data_type() == kNumberTypeFloat32) {
auto data = reinterpret_cast<float *>(shape_tensor->data_c());
if (data == nullptr) {
MS_LOG(INFO) << "Resize op size can't cast float.";
return RET_INFER_INVALID;
}
switch (shape_tensor->GetFormat()) {
case schema::Format_NCHW:
output_shape.push_back(data[2] * input->Height());
output_shape.push_back(data[3] * input->Width());
break;
case schema::Format_NHWC:
output_shape.push_back(data[1] * input->Height());
output_shape.push_back(data[2] * input->Width());
break;
default:
MS_LOG(INFO) << "Resize don't support tensor format.";
return RET_INFER_INVALID;
}
}
break;
}
default: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->data_c());
if (data == nullptr) {
MS_LOG(INFO) << "Resize op size can't cast float.";
return RET_INFER_INVALID;
}
for (size_t i = 0; i < shape_size; i++) {
output_shape.push_back(data[i]);
}
break;
}
}
} else if (inputs_.size() == kSingleNum) {
auto new_height = GetNewHeight();

@ -32,7 +32,7 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
return RET_NULL_PTR;
}
std::unique_ptr<schema::UpsampleT> attr = std::make_unique<schema::UpsampleT>();
std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
@ -41,14 +41,19 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "mode") {
attr->mode = onnx_node_attr.s();
} else if (attribute_name == "scales") {
for (int i = 0; i < onnx_node_attr.floats_size(); ++i) {
attr->scales[i] = onnx_node_attr.floats(i);
if ("nearest" == onnx_node_attr.s()) {
attr->method = schema::ResizeMethod_NEAREST_NEIGHBOR;
} else if ("bilinear" == onnx_node_attr.s()) {
attr->method = schema::ResizeMethod_BILINEAR;
} else {
MS_LOG(ERROR) << "Resize do not support upsample mode";
return RET_ERROR;
}
}
}
attr->newWidth = 1;
attr->newHeight = 1;
attr->alignCorners = false;
op->primitive->value.type = schema::PrimitiveType_Upsample;
op->primitive->value.value = attr.release();
return RET_OK;

Loading…
Cancel
Save