fix anf_exporter bug: set graph input tensor format and op InferShape output format

pull/3667/head
cjh9368 5 years ago
parent 1b69923472
commit 78c9122897

@@ -150,6 +150,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto tensor = metaGraphT->allTensors[input].get();
if (tensor->data.empty()) {
tensor->nodeType = schema::NodeType_ValueNode;
tensor->format = schema::Format_NHWC;
// tensor->refCount = lite::MSCONST_WEIGHT_REFCOUNT;
metaGraphT->inputIndex.emplace_back(input);
}

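A rough standalone sketch of what the exporter hunk above does, using simplified stand-ins for schema::MetaGraphT and schema::TensorT rather than the real flatbuffer-generated types: any tensor referenced as a graph input that carries no constant data is tagged as a value node, now also given a default NHWC format, and recorded in inputIndex.

#include <cstdint>
#include <memory>
#include <vector>

enum Format : int { Format_NHWC = 0, Format_NCHW = 1 };   // stand-in for schema::Format
enum NodeType : int { NodeType_ValueNode = 0 };           // stand-in for schema::NodeType

struct TensorT {                     // simplified stand-in for schema::TensorT
  NodeType nodeType = NodeType_ValueNode;
  Format format = Format_NCHW;
  std::vector<uint8_t> data;         // empty => no constant payload, i.e. a graph input
};

struct MetaGraphT {                  // simplified stand-in for schema::MetaGraphT
  std::vector<std::unique_ptr<TensorT>> allTensors;
  std::vector<uint32_t> inputIndex;
};

// Mirrors the exporter change: data-less tensors become graph inputs in NHWC.
void MarkGraphInputs(MetaGraphT *metaGraphT, const std::vector<uint32_t> &inputs) {
  for (auto input : inputs) {
    auto *tensor = metaGraphT->allTensors[input].get();
    if (tensor->data.empty()) {
      tensor->nodeType = NodeType_ValueNode;
      tensor->format = Format_NHWC;  // the line this commit adds
      metaGraphT->inputIndex.emplace_back(input);
    }
  }
}
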
@@ -36,6 +36,7 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
return RET_INPUT_TENSOR_ERROR;
}
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;

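The remaining hunks apply the same one-line change in each op's InferShape: forward the input tensor's format to the output alongside the existing shape and data-type propagation. A minimal sketch of that pattern, using a simplified Tensor stand-in rather than the real tensor::Tensor class (enumerator values are placeholders):

#include <vector>

enum Format : int { Format_NHWC = 0, Format_NCHW = 1 };  // stand-in for schema::Format
enum TypeId : int { kNumberTypeFloat32 = 0 };            // stand-in, placeholder value

class Tensor {  // simplified stand-in for tensor::Tensor
 public:
  Format GetFormat() const { return format_; }
  void SetFormat(Format f) { format_ = f; }
  const std::vector<int> &shape() const { return shape_; }
  void set_shape(const std::vector<int> &s) { shape_ = s; }
  TypeId data_type() const { return type_; }
  void set_data_type(TypeId t) { type_ = t; }

 private:
  Format format_ = Format_NHWC;
  std::vector<int> shape_;
  TypeId type_ = kNumberTypeFloat32;
};

// Shape-preserving ops (AddN, Cast, ArithmeticSelf, ...) copy the shape exactly as
// below; ops such as ArgMax, Pad, or Concat differ only in how the output shape is
// computed before set_shape(). The new piece in every case is the SetFormat call.
int InferShapeLike(const Tensor &input, Tensor *output) {
  output->SetFormat(input.GetFormat());  // the line this commit adds per op
  output->set_shape(input.shape());
  output->set_data_type(input.data_type());
  return 0;  // RET_OK in the real code
}
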
@@ -40,6 +40,7 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output_shape.erase(output_shape.begin() + axis);
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;

@@ -39,9 +39,9 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
std::vector<int> output_shape(input->shape());
output_shape.erase(output_shape.begin() + axis);
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@@ -39,7 +39,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto input_shape0 = input0->shape();
auto input_shape1 = input1->shape();
auto format = input0->GetFormat();
in_shape0_.resize(5);
in_shape1_.resize(5);
out_shape_.resize(5);
@@ -57,6 +57,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
in_shape1_[i] = input_shape1[i];
}
format = input0->GetFormat();
} else if (input_shape0.size() > input_shape1.size()) {
ndim_ = input_shape0.size();
auto fill_dim_num = input_shape0.size() - input_shape1.size();
@@ -93,7 +94,7 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
output_shape.push_back(out_shape_[i]);
}
output->SetFormat(format);
output->set_shape(output_shape);
output->set_data_type(input0->data_type());
return RET_OK;

@@ -26,9 +26,11 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@@ -85,9 +85,10 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_shape->Get(0) - crops->Get(0) - crops->Get(1);
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@@ -58,9 +58,9 @@ int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<te
shape[i] = dst_shape[i];
--input_shape_index;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(shape);
outputs[0]->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@@ -44,9 +44,9 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT();
return RET_INPUT_TENSOR_ERROR;
}
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@@ -70,7 +70,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output_shape[axis] = output_axis_dim;
outputs_[0]->set_shape(output_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -32,7 +32,8 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
return RET_PARAM_INVALID;
}
outputs[0]->set_shape(inputs[1]->shape());
outputs[0]->SetFormat(inputs[1]->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -23,7 +23,7 @@ namespace mindspore::lite {
namespace {
constexpr int kDepthToSpaceOutputNum = 1;
constexpr int kDepthToSpaceInputNum = 1;
}
} // namespace
int DepthToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
@@ -56,7 +56,8 @@ int DepthToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index] / (block_size * block_size);
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -45,7 +45,8 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
out_shape.insert(out_shape.begin() + dim, 1, 1);
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -42,7 +42,8 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
(void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -43,7 +43,8 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
output->set_shape(output_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -56,7 +56,8 @@ int FullConnection::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
out_shape[fc_prim->axis()] = input1->shape()[0];
output->set_shape(out_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -71,7 +71,8 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -59,7 +59,8 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -57,7 +57,8 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
y_shape[y_shape_size - 1] = w_shape[w_shape.size() - 1];
output->set_shape(y_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -67,6 +67,8 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
return RET_NULL_PTR;
}
output->set_data_type(on_value->data_type());
output->SetFormat(on_value->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -138,7 +138,8 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -55,9 +55,9 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
if (output == nullptr) {
return RET_NULL_PTR;
}
output->SetFormat(input->GetFormat());
output->set_shape(output_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

@@ -74,6 +74,7 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
input_shape.at(2) = output_w;
output->set_shape(input_shape);
output->set_data_type(input->data_type());
// todo: temp fix
output->SetFormat(schema::Format_NHWC);
return RET_OK;
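
Note: unlike the other InferShape changes in this commit, Pooling does not forward the input's format; the output format is pinned to schema::Format_NHWC, and the inline todo marks that as a temporary fix.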

@@ -34,7 +34,8 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
in_shape.push_back(shape_size);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

@@ -29,7 +29,8 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
std::vector<int> in_shape(1, 1);
output->set_shape(in_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

Some files were not shown because too many files have changed in this diff.