!7757 [MS][LITE][DEVELOP] add while op parser

Merge pull request !7757 from mengyuanli/while
pull/7757/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d378044f37

@ -226,6 +226,7 @@ union PrimitiveType {
InstanceNorm,
Identity,
LayerNorm,
While,
}
enum QuantType: int {

@ -1103,3 +1103,8 @@ table LayerNorm {
elementwiseAffine : bool;
}
table While {
condSubgraphIndex : int;
bodySubgraphIndex : int;
}

@ -0,0 +1,45 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/while.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
// Parameter struct handed to the While kernel.
// NOTE(review): "Paremeter" is a typo for "Parameter"; the name is kept
// unchanged here so existing references/registrations stay valid.
typedef struct WhileParemeter {
  // Common header; must be the first member so a WhileParemeter* can be
  // used wherever an OpParameter* is expected.
  OpParameter op_parameter_;
  int body_subgraph_index;  // index of the loop-body subgraph in the model
  int cond_subgraph_index;  // index of the loop-condition subgraph in the model
} WhileParemeter;

// Builds a WhileParemeter from a lite While primitive.
// Returns a heap-allocated parameter (caller owns it and releases it with
// free()), or nullptr on null input / allocation failure.
OpParameter *PopulateWhileParemeter(const mindspore::lite::PrimitiveC *primitive) {
  // Guard against a null primitive before dereferencing it below.
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr.";
    return nullptr;
  }
  WhileParemeter *while_paremeter = reinterpret_cast<WhileParemeter *>(malloc(sizeof(WhileParemeter)));
  if (while_paremeter == nullptr) {
    MS_LOG(ERROR) << "malloc WhileParemeter failed.";
    return nullptr;
  }
  memset(while_paremeter, 0, sizeof(WhileParemeter));
  auto param = reinterpret_cast<mindspore::lite::While *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  while_paremeter->op_parameter_.type_ = primitive->Type();
  while_paremeter->body_subgraph_index = param->GetBodySubgraphIndex();
  while_paremeter->cond_subgraph_index = param->GetCondSubgraphIndex();
  return reinterpret_cast<OpParameter *>(while_paremeter);
}

// Registers the populate function for schema::PrimitiveType_While.
Registry WhileParemeterRegistry(schema::PrimitiveType_While, PopulateWhileParemeter);
} // namespace lite
} // namespace mindspore

@ -144,6 +144,7 @@
#include "src/ops/mfcc.h"
#include "src/ops/identity.h"
#include "src/ops/instance_norm.h"
#include "src/ops/while.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@ -499,6 +500,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Maximum>(prim, inputs, quantType);
} else if (op_type == "Split") {
return NewPrimitiveC<Split>(prim, inputs, quantType);
} else if (op_type == "While") {
return NewPrimitiveC<While>(prim, inputs, quantType);
} else if (op_type == "OneHot") {
return NewPrimitiveC<OneHot>(prim, inputs, quantType);
@ -793,6 +796,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new Mfcc(primitive);
case schema::PrimitiveType_InstanceNorm:
return new InstanceNorm(primitive);
case schema::PrimitiveType_While:
return new While(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

@ -0,0 +1,107 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/while.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Stores the condition-subgraph index into the writeable primitive.
void While::SetCondSubgraphIndex(const int cond_subgraph_index) {
  auto attr = this->primitive_->value.AsWhile();
  attr->condSubgraphIndex = cond_subgraph_index;
}
// Stores the body-subgraph index into the writeable primitive.
void While::SetBodySubgraphIndex(const int body_subgraph_index) {
  auto attr = this->primitive_->value.AsWhile();
  attr->bodySubgraphIndex = body_subgraph_index;
}
// Returns the condition-subgraph index stored in the writeable primitive.
int While::GetCondSubgraphIndex() const { return this->primitive_->value.AsWhile()->condSubgraphIndex; }
// Returns the body-subgraph index stored in the writeable primitive.
int While::GetBodySubgraphIndex() const { return this->primitive_->value.AsWhile()->bodySubgraphIndex; }
int While::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_While;
}
if (this->primitive_->value.type != schema::PrimitiveType_While) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::WhileT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->bodySubgraphIndex = GetValue<bool>(prim.GetAttr("body_subgraph_index"));
attr->condSubgraphIndex = GetValue<bool>(prim.GetAttr("cond_subgraph_index"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
// Returns the condition-subgraph index read from the flatbuffer primitive.
int While::GetCondSubgraphIndex() const { return this->primitive_->value_as_While()->condSubgraphIndex(); }
// Returns the body-subgraph index read from the flatbuffer primitive.
int While::GetBodySubgraphIndex() const { return this->primitive_->value_as_While()->bodySubgraphIndex(); }
// Re-serializes a While flatbuffer primitive into fbb.
// Returns RET_OK on success, RET_ERROR if the primitive carries no While value.
int While::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_While();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_While return nullptr";
    return RET_ERROR;
  }
  auto cond_subgraph_index = attr->condSubgraphIndex();
  auto body_subgraph_index = attr->bodySubgraphIndex();
  // BUG FIX: the schema declares condSubgraphIndex before bodySubgraphIndex,
  // so the generated CreateWhile takes (fbb, cond, body). The previous code
  // passed (body, cond), swapping the two indices on round-trip.
  auto val_offset = schema::CreateWhile(*fbb, cond_subgraph_index, body_subgraph_index);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_While, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
// Factory used by the registry: wraps a flatbuffer primitive in a While op.
PrimitiveC *WhileCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<While>(primitive); }
// Registers the factory for schema::PrimitiveType_While.
Registry WhileRegistry(schema::PrimitiveType_While, WhileCreator);
#endif
// While passes its operands straight through: each output mirrors the input
// at the same position, so copy data type, format, and shape one-to-one.
// Returns RET_ERROR if the input and output counts differ.
int While::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  if (inputs_.size() != outputs_.size()) {
    MS_LOG(ERROR) << "The number of inputs and outputs varies";
    return RET_ERROR;
  }
  size_t idx = 0;
  for (auto *out_tensor : outputs_) {
    auto *in_tensor = inputs_[idx++];
    out_tensor->set_data_type(in_tensor->data_type());
    out_tensor->SetFormat(in_tensor->GetFormat());
    out_tensor->set_shape(in_tensor->shape());
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,51 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_WHILE_H_
#define LITE_MINDSPORE_LITE_C_OPS_WHILE_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// While control-flow op: pairs a condition subgraph with a body subgraph,
// both referenced by index into the model's subgraph list.
class While : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(While, PrimitiveC);
  While() = default;
  // Wraps an existing mutable primitive (ownership semantics per PrimitiveC).
  explicit While(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Fills the primitive from the ANF Primitive's "cond_subgraph_index" /
  // "body_subgraph_index" attributes.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  void SetCondSubgraphIndex(const int cond_subgraph_index);
  void SetBodySubgraphIndex(const int body_subgraph_index);
#else
  While() = default;
  // Re-serializes the read-only flatbuffer primitive into fbb.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // Outputs mirror inputs one-to-one: data type, format, and shape are copied.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  int GetCondSubgraphIndex() const;
  int GetBodySubgraphIndex() const;
};
} // namespace lite
} // namespace mindspore
#endif  // LITE_MINDSPORE_LITE_C_OPS_WHILE_H_

@ -24,7 +24,8 @@ namespace mindspore {
namespace lite {
STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -75,10 +76,8 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteActivationParser : public TfliteNodeParser {
TfliteActivationParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
class TfliteReluParser : public TfliteActivationParser {

@ -23,7 +23,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteAddNParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -41,16 +42,14 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
return RET_NULL_PTR;
}
attr->N = tflite_model->subgraphs[0]->tensors.size() - 1;
attr->N = tflite_subgraph->tensors.size() - 1;
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteAddNParser : public TfliteNodeParser {
TfliteAddNParser() : TfliteNodeParser("AddN") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -22,7 +22,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -47,7 +48,7 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer;
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
@ -63,10 +64,8 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteArgmaxParser : public TfliteNodeParser {
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -22,7 +22,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteArgminParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -47,7 +48,7 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer;
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
@ -63,10 +64,8 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteArgminParser : public TfliteNodeParser {
TfliteArgminParser() : TfliteNodeParser("Argmin") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -24,7 +24,8 @@ namespace mindspore {
namespace lite {
STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -168,17 +169,16 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
// set input
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -305,16 +305,15 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
op->primitive->value.value = attr.release();
}
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -385,11 +384,9 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
}
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
class TfliteAddParser : public TfliteDoubleInputOpParser {
@ -93,7 +94,8 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
class TfliteAbsParser : public TfliteSingleInputOpParser {
@ -161,7 +163,8 @@ class TfliteCompareOpParser : public TfliteNodeParser {
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
class TfliteEqualParser : public TfliteCompareOpParser {

@ -25,7 +25,8 @@ namespace mindspore {
namespace lite {
STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -51,12 +52,11 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers,
attr->blockShape)) {
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) {
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
return RET_ERROR;
}
if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->crops)) {
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) {
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
return RET_ERROR;
}
@ -64,10 +64,8 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {

@ -24,7 +24,8 @@ namespace mindspore {
namespace lite {
STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -42,8 +43,7 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers,
attr->dst_shape)) {
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) {
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
return RET_ERROR;
}
@ -51,10 +51,8 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -22,7 +22,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteCastParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -40,13 +41,13 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
return RET_NULL_PTR;
}
const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]];
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]];
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR;
@ -56,10 +57,8 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteCastParser : public TfliteNodeParser {
TfliteCastParser() : TfliteNodeParser("Cast") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -22,7 +22,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteConcatParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -52,11 +53,9 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
op->primitive->value.value = attr.release();
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

@ -30,7 +30,8 @@ class TfliteConcatParser : public TfliteNodeParser {
TfliteConcatParser() : TfliteNodeParser("Concat") {}
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -22,7 +22,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteConvParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -57,7 +58,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
// get the conv op weight tensor
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index];
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return RET_NULL_PTR;
@ -70,7 +71,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
// calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
@ -87,14 +88,10 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_KHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save