|
|
|
@ -97,11 +97,29 @@ int Resize::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr int kInputRank = 4;
|
|
|
|
|
} // namespace
|
|
|
|
|
template <typename T>
|
|
|
|
|
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
|
|
|
|
|
int input_count = inputs[0]->ElementsNum();
|
|
|
|
|
int index = 0;
|
|
|
|
|
int size = 1;
|
|
|
|
|
for (int i = 0; i < shape_size; i++) {
|
|
|
|
|
if (static_cast<int>(data[i]) == -1) {
|
|
|
|
|
index = i;
|
|
|
|
|
} else {
|
|
|
|
|
size *= data[i];
|
|
|
|
|
}
|
|
|
|
|
out_shape->push_back(data[i]);
|
|
|
|
|
}
|
|
|
|
|
if (static_cast<int>(data[index]) == -1) {
|
|
|
|
|
(*out_shape)[index] = input_count / size;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
|
|
|
|
MS_ASSERT(this->primitive_ != nullptr);
|
|
|
|
|
auto input = inputs_.front();
|
|
|
|
|
if (input == nullptr) {
|
|
|
|
|
return 1;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (input->shape().size() != kInputRank) {
|
|
|
|
|
MS_LOG(ERROR) << "Size of input shape is wrong.";
|
|
|
|
@ -110,20 +128,58 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
|
|
|
|
|
|
|
|
|
|
auto output = outputs_.front();
|
|
|
|
|
if (output == nullptr) {
|
|
|
|
|
return 1;
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
output->set_data_type(input->data_type());
|
|
|
|
|
output->SetFormat(input->GetFormat());
|
|
|
|
|
if (!GetInferFlag()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
auto new_height = GetNewHeight();
|
|
|
|
|
auto new_width = GetNewWidth();
|
|
|
|
|
|
|
|
|
|
std::vector<int> output_shape;
|
|
|
|
|
output_shape.push_back(input->Batch());
|
|
|
|
|
output_shape.push_back(new_height);
|
|
|
|
|
output_shape.push_back(new_width);
|
|
|
|
|
if (inputs_.size() == kDoubleNum) {
|
|
|
|
|
auto shape_tensor = inputs_.at(1);
|
|
|
|
|
if (shape_tensor->data_c() == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "Do infer shape in runtime.";
|
|
|
|
|
return RET_INFER_INVALID;
|
|
|
|
|
}
|
|
|
|
|
size_t shape_size = shape_tensor->ElementsNum();
|
|
|
|
|
switch (shape_tensor->data_type()) {
|
|
|
|
|
case kNumberTypeInt8: {
|
|
|
|
|
auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<int8_t>(data, inputs_, &output_shape, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeInt32: {
|
|
|
|
|
auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<int32_t>(data, inputs_, &output_shape, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeInt64: {
|
|
|
|
|
auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<int64_t>(data, inputs_, &output_shape, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeFloat: {
|
|
|
|
|
auto data = reinterpret_cast<float *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<float>(data, inputs_, &output_shape, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
case kNumberTypeUInt32: {
|
|
|
|
|
auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData());
|
|
|
|
|
CalShape<uint32_t>(data, inputs_, &output_shape, shape_size);
|
|
|
|
|
} break;
|
|
|
|
|
default: {
|
|
|
|
|
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (inputs_.size() == kSingleNum) {
|
|
|
|
|
auto new_height = GetNewHeight();
|
|
|
|
|
auto new_width = GetNewWidth();
|
|
|
|
|
output_shape.push_back(new_height);
|
|
|
|
|
output_shape.push_back(new_width);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "inputs tensor size invalid.";
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
output_shape.push_back(input->Channel());
|
|
|
|
|
output->set_shape(output_shape);
|
|
|
|
|
|
|
|
|
|