!9159 [lite] fix converter bug

From: @xu_anyue
Reviewed-by: @hangangqiang,@HilbertDavid
Signed-off-by: @hangangqiang,@HilbertDavid
pull/9159/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d0d5a8b878

@ -179,6 +179,13 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
std::vector<int> out_shape;
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
if (input->ElementsNum() == 1) {
if (shape_tensor->shape().empty()) {
MS_LOG(DEBUG) << "reshape to a scalar.";
output->set_shape(out_shape);
return RET_OK;
}
}
if (shape_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;

@ -52,8 +52,8 @@ int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->splitDim = GetValue<int32_t>(prim.GetAttr("axis"));
attr->numberSplit = GetValue<int32_t>(prim.GetAttr("output_num"));
attr->splitDim = CastToInt(prim.GetAttr("axis")).front();
attr->numberSplit = CastToInt(prim.GetAttr("output_num")).front();
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";

@ -177,6 +177,15 @@ constexpr size_t kStridedSliceInputNum = 1;
constexpr size_t kStridedSliceMultiInputNumMin = 3;
constexpr size_t kStridedSliceMultiInputNumMax = 5;
} // namespace
bool StridedSlice::CheckInputs(std::vector<lite::Tensor *> inputs_) {
for (size_t i = 1; i < inputs_.size(); ++i) {
if (inputs_[i]->data_c() == nullptr) {
MS_LOG(DEBUG) << "strided_slice has input from other node, which only can be obtained when running.";
return false;
}
}
return true;
}
void StridedSlice::ApplyNewAxisMask() {
for (size_t i = 0; i < new_axis_mask_.size(); i++) {
@ -365,6 +374,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
strides_.emplace_back((GetStride())[i]);
}
}
if (!CheckInputs(inputs)) {
MS_LOG(DEBUG) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
if (inputs.size() == 4) {
// input order: input, begins, ends, strides.
auto begin_tensor = inputs.at(1);

@ -47,6 +47,7 @@ class StridedSlice : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
bool CheckInputs(std::vector<lite::Tensor *> inputs_);
int GetBeginMask() const;
int GetEndMask() const;
int GetEllipsisMask() const;

@ -81,7 +81,6 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
STATUS TfliteModelParser::ConvertOps() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
const auto &tflite_model_buffers = tflite_model_->buffers;
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
STATUS status = RET_OK;
int op_idx = 0;
@ -117,6 +116,9 @@ STATUS TfliteModelParser::ConvertOps() {
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
// parse inputs
for (auto input_idx : op->inputs) {
if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) {
continue;
}
if (input_idx < 0) {
input_idx += tflite_subgraph->tensors.size();
}
@ -126,18 +128,14 @@ STATUS TfliteModelParser::ConvertOps() {
continue;
}
// const tensor
if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) {
auto parameter = func_graph_->add_parameter();
status = ConvertConstTensor(input_tensor.get(), parameter.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
return status;
}
op_inputs.emplace_back(parameter);
nodes_.insert(std::pair(input_idx, parameter));
continue;
auto parameter = func_graph_->add_parameter();
status = ConvertConstTensor(input_tensor.get(), parameter.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
return status;
}
MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor.";
op_inputs.emplace_back(parameter);
nodes_.insert(std::pair(input_idx, parameter));
}
auto new_cnode = func_graph_->NewCNode(op_inputs);
new_cnode->set_fullname_with_scope(op_name);
@ -268,6 +266,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
make_tuple_inputs.emplace_back(make_tuple_prim);
for (auto outputNode : tflite_subgraph->outputs) {
outputNode = outputNode < 0 ? outputNode + tflite_subgraph->tensors.size() : outputNode;
auto cnode = nodes_.at(outputNode);
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
@ -296,9 +295,12 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
int outputNode = tflite_subgraph->outputs.front() < 0
? static_cast<int>(tflite_subgraph->outputs.front() + tflite_subgraph->tensors.size())
: static_cast<int>(tflite_subgraph->outputs.front());
auto valueNode = NewValueNode(returnPrim);
std::vector<AnfNodePtr> op_inputs{valueNode};
auto cnode = nodes_.at(tflite_subgraph->outputs.front());
auto cnode = nodes_.at(outputNode);
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
@ -345,8 +347,8 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para
}
std::memcpy(tensor_data, data.data(), size);
param_value->SetTensorData(tensor_data, size);
parameter->set_default_param(param_value);
}
parameter->set_default_param(param_value);
return RET_OK;
}

@ -50,8 +50,15 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
}
auto data_node = depthwise_cnode->input(kConvInputIndex)->abstract();
if (data_node == nullptr) {
MS_LOG(ERROR) << "the node input is invalid.";
return false;
}
auto data_shape = utils::cast<abstract::ShapePtr>(data_node->GetShapeTrack())->shape();
if (data_shape.empty()) {
MS_LOG(DEBUG) << "the tensor's shape is dynamic.";
return true;
}
auto conv_attr = std::make_unique<schema::Conv2DT>();
if (conv_attr == nullptr) {
MS_LOG(ERROR) << "conv_attr is null";

@ -89,7 +89,7 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
return RET_ERROR;
}
auto ret = memcpy_s(tensor_data, new_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size());
if (ret != EOK) {
if (new_value->tensor_size() != 0 && ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
delete[] tensor_data;
return RET_ERROR;
@ -163,7 +163,7 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
return RET_ERROR;
}
ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size());
if (ret != EOK) {
if (tensor->Size() != 0 && ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}

Loading…
Cancel
Save