!4665 [MS][LITE][Develop]support inferring data type and format when infer-shape fails

Merge pull request !4665 from chenjianping/lite_dev3
pull/4665/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 030af09f60

@ -43,6 +43,11 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
for (int i = 1; i < inputs.size(); ++i) {
if (inputs.at(i)->shape() != inputs.at(0)->shape()) {
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
@ -53,9 +58,8 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
return RET_INPUT_TENSOR_ERROR;
}
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite

@ -55,6 +55,12 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto argmax_prim = this->primitive->value_as_ArgMax();
std::vector<int> output_shape(input->shape());
auto input_shape_size = input->shape().size();
@ -68,9 +74,8 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} else {
output_shape[axis] = argmax_prim->topK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite

@ -55,6 +55,11 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto argmin_prim = this->primitive->value_as_ArgMin();
auto input_shape_size = input->shape().size();
int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis();
@ -68,9 +73,8 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} else {
output_shape[axis] = argmin_prim->topK();
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite

@ -46,6 +46,11 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
return 1;
}
auto input = inputs.at(0);
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
this->primitive->value_as_BroadcastTo()->dst_shape()->end());
auto input_shape = input->shape();
@ -72,10 +77,8 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
shape[i] = dst_shape[i];
--input_shape_index;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(shape);
outputs[0]->set_data_type(input->data_type());
return 0;
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -44,8 +44,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "tensor number is error.";
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
auto cast_prim = this->primitive->value_as_Cast();
MS_ASSERT(cast_prim != nullptr);
output->set_data_type(static_cast<TypeId>(cast_prim->dstT()));
if (!GetInferFlag()) {
return RET_OK;
}
if (input->data_type() != cast_prim->srcT()) {
MS_LOG(ERROR) << "input dataType is error";
return RET_INPUT_TENSOR_ERROR;
@ -54,13 +60,8 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
return RET_INPUT_TENSOR_ERROR;
}
if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT();
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeFloat32);
return RET_OK;
}
} // namespace lite

@ -50,16 +50,19 @@ int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
return RET_ERROR;
}
auto in_tensor = inputs_.front();
auto in_data = reinterpret_cast<int *>(in_tensor->Data());
auto out_tensor = outputs_.front();
out_tensor->set_data_type(kNumberTypeFloat32);
out_tensor->SetFormat(in_tensor->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_data = reinterpret_cast<int *>(in_tensor->Data());
int size = in_tensor->ElementsNum();
std::vector<int> out_shape(size);
for (int i = 0; i < size; ++i) {
out_shape[i] = in_data[i];
}
out_tensor->set_shape(out_shape);
out_tensor->set_data_type(kNumberTypeFloat32);
out_tensor->SetFormat(in_tensor->GetFormat());
return RET_OK;
}

@ -46,9 +46,12 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return RET_PARAM_INVALID;
}
outputs[0]->set_shape(inputs[1]->shape());
outputs[0]->SetFormat(inputs[0]->GetFormat());
outputs[0]->set_data_type(inputs[0]->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
outputs[0]->set_shape(inputs[1]->shape());
return RET_OK;
}
} // namespace lite

@ -103,7 +103,11 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
int32_t input_h = input->Height();
int32_t input_w = input->Width();
@ -138,8 +142,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
output->set_shape(out_shape);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
return 0;
}
} // namespace lite

@ -126,7 +126,11 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = input->shape();
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
@ -155,8 +159,6 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
output->set_shape(out_shape);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
return 0;
}
} // namespace lite

@ -50,6 +50,11 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return 1;
}
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
@ -68,10 +73,7 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
output_shape[NHWC_W] = input_shape[NHWC_W] * block_size;
output_shape[NHWC_C] = input_shape[NHWC_C] / (block_size * block_size);
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return 0;
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -120,7 +120,11 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT(weight != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = input->shape();
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
@ -158,8 +162,6 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
output->set_shape(out_shape);
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
return 0;
}
} // namespace lite

@ -46,6 +46,12 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
MS_ASSERT(ids != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(params_->GetFormat());
output->set_data_type(params_->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto embedding_shape = params_->shape();
embedding_shape.erase(embedding_shape.begin());
std::vector<int> output_shape(ids->shape());
@ -61,7 +67,6 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
}
}
output->set_shape(output_shape);
output->set_data_type(params_->data_type());
return RET_OK;
}
} // namespace lite

@ -42,6 +42,11 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "output size is invalid";
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto expand_dims_prim = this->primitive->value_as_ExpandDims();
int dim = expand_dims_prim->dim();
if (dim < 0) {
@ -54,8 +59,6 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto out_shape = input->shape();
out_shape.insert(out_shape.begin() + dim, 1, 1);
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite

@ -45,6 +45,11 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto fill_prim = this->primitive->value_as_Fill();
if (fill_prim == nullptr) {
MS_LOG(ERROR) << "Fill primitive is null!";
@ -53,8 +58,6 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
std::vector<int> output_shape;
(void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite

@ -31,6 +31,13 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
return RET_INPUT_TENSOR_ERROR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
std::vector<int> output_shape(2);
output_shape[0] = input_shape[0];
@ -39,8 +46,6 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
output_shape[1] *= input_shape[i];
}
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite

@ -51,7 +51,11 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
MS_LOG(ERROR) << "Input tensors num error";
return 1;
@ -78,8 +82,6 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
out_shape.resize(GetAxis() + 1);
out_shape[GetAxis()] = input1->shape()[0];
output->set_shape(out_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return 0;
}

@ -46,6 +46,12 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
MS_ASSERT(indices != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = input->shape();
int in_rank = in_shape.size();
auto indices_shape = indices->shape();
@ -63,8 +69,6 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
out_shape.emplace_back(in_shape[i]);
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite

@ -44,6 +44,14 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_ASSERT(input0 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
for (int i = 0; i < kLstmOutputNum; i++) {
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> in_shape = input->shape();
std::vector<int> w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size
if (in_shape.size() != 3 || w_shape.size() != 3) {
@ -65,10 +73,7 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
state_shape[2] = hidden_size;
outputs_[1]->set_shape(state_shape);
outputs_[2]->set_shape(state_shape);
for (int i = 0; i < kLstmOutputNum; i++) {
outputs_[i]->set_data_type(input->data_type());
outputs_[i]->SetFormat(input->GetFormat());
}
return RET_OK;
}
} // namespace lite

@ -43,6 +43,13 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> a_shape = input0->shape();
std::vector<int> b_shape = input1->shape();
if (a_shape.size() < 2 || b_shape.size() < 2) {
@ -65,8 +72,6 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std::vector<int> c_shape(a_shape);
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
output->set_shape(c_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace lite

@ -50,6 +50,11 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if (input == nullptr || output == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
@ -88,8 +93,6 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite

@ -25,6 +25,11 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(schema::Format_NHWC);
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> nchw_shape = input->shape();
if (nchw_shape.size() != 4) {
output->set_shape(nchw_shape);
@ -36,8 +41,6 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nhwc_shape[NHWC_C] = nchw_shape[NCHW_C];
output->set_shape(nhwc_shape);
}
output->SetFormat(schema::Format_NHWC);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite

@ -25,6 +25,11 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(schema::Format_NCHW);
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> nhwc_shape = input->shape();
if (nhwc_shape.size() != 4) {
output->set_shape(nhwc_shape);
@ -36,8 +41,6 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
nchw_shape[NCHW_W] = nhwc_shape[NHWC_W];
output->set_shape(nchw_shape);
}
output->SetFormat(schema::Format_NCHW);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite

@ -56,6 +56,19 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
if (input == nullptr) {
return RET_NULL_PTR;
}
auto on_value = inputs.at(2);
if (on_value == nullptr) {
return RET_NULL_PTR;
}
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
const auto input_shape = input->shape();
int input_rank = static_cast<int>(input_shape.size());
if (axis < 0) {
@ -63,17 +76,7 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
}
std::vector<int> output_shape(input_shape);
output_shape.insert(output_shape.cbegin() + axis, *depth);
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->set_shape(output_shape);
auto on_value = inputs.at(2);
if (on_value == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
return RET_OK;
}
} // namespace lite

@ -61,6 +61,15 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if (input == nullptr) {
return RET_NULL_PTR;
}
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->SetFormat(input->GetFormat());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
std::vector<int> output_shape;
MS_ASSERT(input->shape().size() <= kInputRank);
@ -69,13 +78,8 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1];
output_shape.push_back(shape);
}
auto output = outputs.front();
if (output == nullptr) {
return RET_NULL_PTR;
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace lite

@ -95,6 +95,11 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(schema::Format_NHWC);
if (!GetInferFlag()) {
return RET_OK;
}
int input_h = input->shape().at(1);
int input_w = input->shape().at(2);
auto pooling_prim = this->primitive->value_as_Pooling();
@ -137,9 +142,6 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
input_shape.at(1) = output_h;
input_shape.at(2) = output_w;
output->set_shape(input_shape);
output->set_data_type(input->data_type());
// todo: temp fix
output->SetFormat(schema::Format_NHWC);
return RET_OK;
}
} // namespace lite

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save