!6835 MSLITE adjust tflite parser

Merge pull request !6835 from 徐安越/master
pull/6835/MERGE
mindspore-ci-bot committed 4 years ago, via Gitee
commit 981f2ee82f

@ -20,7 +20,6 @@
#include <stdlib.h>
#include <string.h>
#include <stddef.h>
#include <initializer_list>
#define DEFAULT_CAPACITY 4
struct MSTensor;

@ -31,7 +31,7 @@ template <typename T>
Vector<T>::Vector(size_t size) {
size_ = size;
elem_size_ = sizeof(T);
capacity_ = size;
capacity_ = (size == 0 ? DEFAULT_CAPACITY : size);
data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_));
if (data_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
@ -43,7 +43,7 @@ template <typename T>
Vector<T>::Vector(size_t size, const T &value) {
size_ = size;
elem_size_ = sizeof(T);
capacity_ = size;
capacity_ = (size == 0 ? DEFAULT_CAPACITY : size);
data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_));
if (data_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
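Both constructor hunks above apply the same guard: when size is 0, the old code called malloc(0), whose result is implementation-defined and may be nullptr, which would trip the MS_C_EXCEPTION path for a perfectly valid empty vector. Reserving DEFAULT_CAPACITY elements instead keeps the allocation well-defined and gives push_back room to grow. A minimal standalone sketch of the guarded constructor (Vec and the fprintf/abort error handling are illustrative, not the MSLite Vector):

#include <cstdio>
#include <cstdlib>

constexpr size_t DEFAULT_CAP = 4;  // mirrors DEFAULT_CAPACITY above

template <typename T>
struct Vec {
  size_t size_;
  size_t capacity_;
  T *data_;

  explicit Vec(size_t size) : size_(size) {
    capacity_ = (size == 0 ? DEFAULT_CAP : size);  // the fix: never malloc(0)
    data_ = static_cast<T *>(malloc(capacity_ * sizeof(T)));
    if (data_ == nullptr) {
      fprintf(stderr, "malloc data failed\n");
      abort();
    }
  }
  ~Vec() { free(data_); }
};

int main() {
  Vec<int> v(0);                            // old code would malloc(0) here
  printf("capacity = %zu\n", v.capacity_);  // prints 4
  return 0;
}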
@ -115,7 +115,7 @@ void Vector<T>::push_back(const T &elem) {
template <typename T>
void Vector<T>::push_back(T &&elem) {
if (data_ == nullptr) {
data_ = reinterpret_cast<T *>(malloc(elem_size_));
data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_));
if (data_ == nullptr) {
MS_C_EXCEPTION("malloc data failed");
}
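The push_back hunk above fixes a sizing bug in the lazy-allocation path: a first push into a vector whose data_ is still null reserved space for exactly one element (malloc(elem_size_)) even though capacity_ could claim more, so subsequent pushes wrote past the buffer. A self-contained illustration of the corrected arithmetic (the capacity and element type are invented):

#include <cstdio>
#include <cstdlib>

int main() {
  const size_t capacity = 4;  // what capacity_ advertises after construction
  const size_t elem_size = sizeof(double);

  // before: malloc(elem_size)            ->  8 bytes backing 4 doubles
  // after:  malloc(capacity * elem_size) -> 32 bytes, matching capacity_
  double *data = static_cast<double *>(malloc(capacity * elem_size));
  if (data == nullptr) {
    return 1;
  }
  for (size_t i = 0; i < capacity; ++i) {
    data[i] = static_cast<double>(i);  // in bounds only with the fixed size
  }
  printf("%.1f\n", data[capacity - 1]);
  free(data);
  return 0;
}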

@ -102,9 +102,13 @@ int WriteToBin(const std::string &file_path, void *data, size_t size) {
return 0;
}
int CompareOutputData(float *output_data, float *correct_data, int data_size) {
int CompareOutputData(float *output_data, size_t output_size, float *correct_data, size_t data_size) {
if (output_size != data_size) {
printf("compare failed, output_size %zu isn't equal to data_size %zu.\n", output_size, data_size);
return 0;
}
float error = 0;
for (int i = 0; i < data_size; i++) {
for (size_t i = 0; i < data_size; i++) {
float abs = fabs(output_data[i] - correct_data[i]);
if (abs > 0.00001) {
error += abs;
@ -120,12 +124,12 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) {
return 0;
}
int CompareOutput(float *output_data, std::string file_path) {
size_t output_size;
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
size_t output_num = output_size / sizeof(float);
printf("output num : %zu\n", output_num);
int res = CompareOutputData(output_data, ground_truth, output_num);
int CompareOutput(float *output_data, size_t output_num, std::string file_path) {
size_t ground_truth_size;
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &ground_truth_size));
size_t ground_truth_num = ground_truth_size / sizeof(float);
printf("ground truth num : %zu\n", ground_truth_num);
int res = CompareOutputData(output_data, output_num, ground_truth, ground_truth_num);
delete[] ground_truth;
return res;
}
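The comparison helpers now take the element count of the produced output, so a length mismatch against the ground-truth file is reported up front instead of reading past either buffer. A sketch of a call site under the new signatures, in the style of the test updates below (the buffers and the .bin path are hypothetical):

// Hypothetical call site; 'out' and 'correct' stand in for real test data.
float out[20] = {0};
float correct[20] = {0};
EXPECT_EQ(0, lite::CompareOutputData(out, 20, correct, 20));
int res = lite::CompareOutput(out, 20, "./test_data/hypothetical.bin");
EXPECT_EQ(res, 0);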

@ -50,8 +50,8 @@ void WriteToTxt(const std::string &file_path, void *data, size_t element_size) {
int WriteToBin(const std::string &file_path, void *data, size_t size);
int CompareOutputData(float *output_data, float *correct_data, int data_size);
int CompareOutput(float *output_data, std::string file_path);
int CompareOutputData(float *output_data, size_t output_num, float *correct_data, size_t data_size);
int CompareOutput(float *output_data, size_t output_num, std::string file_path);
std::string GetAndroidPackageName();
std::string GetAndroidPackagePath();

@ -169,6 +169,7 @@ class PrimitiveC {
}
auto ret = primc->UnPackSchemaPrimitive(primitive);
if (ret != RET_OK) {
delete primc;
MS_LOG(ERROR) << "UnPackSchemaPrimitive failed";
return nullptr;
}
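This hunk plugs a leak: when UnPackSchemaPrimitive fails, the freshly allocated primc was previously abandoned by the early return nullptr. For comparison, a hedged RAII variant of the same fix (a sketch only; the patch itself frees manually, and PrimitiveC is assumed constructible here):

// unique_ptr takes ownership, so neither early return can leak.
std::unique_ptr<PrimitiveC> primc(new (std::nothrow) PrimitiveC());
if (primc == nullptr) {
  return nullptr;
}
if (primc->UnPackSchemaPrimitive(primitive) != RET_OK) {
  MS_LOG(ERROR) << "UnPackSchemaPrimitive failed";
  return nullptr;  // primc is destroyed here automatically
}
return primc.release();  // transfer ownership to the caller on success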

@ -144,6 +144,8 @@ void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<in
for (int i = 0; i < shape_size; i++) {
if (static_cast<int>(data[i]) == -1) {
index = i;
} else if (static_cast<int>(data[i]) == 0) {
size *= inputs[0]->shape()[i];
} else {
size *= data[i];
}
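CalShape now treats a 0 in the requested shape as "inherit this dimension from the first input", matching TFLite/TensorFlow reshape semantics, rather than letting the 0 zero out the running element count used to infer a -1 dimension. A standalone, simplified sketch of the rule with a worked example:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> input_shape = {2, 3, 4};  // 24 elements in input[0]
  std::vector<int> requested = {0, -1};      // shape operand of a reshape
  int known = 1;
  int infer = -1;
  for (size_t i = 0; i < requested.size(); ++i) {
    if (requested[i] == -1) {
      infer = static_cast<int>(i);     // dimension to be inferred
    } else if (requested[i] == 0) {
      requested[i] = input_shape[i];   // the new rule: inherit dim i
      known *= requested[i];
    } else {
      known *= requested[i];
    }
  }
  const int total = 2 * 3 * 4;
  if (infer >= 0) {
    requested[infer] = total / known;  // 24 / 2 = 12
  }
  printf("{%d, %d}\n", requested[0], requested[1]);  // {2, 12}
  return 0;
}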

@ -64,6 +64,10 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::StridedSliceT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new StridedSlice failed";
return RET_ERROR;
}
attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask"));
attr->endMask = GetValue<int>(prim.GetAttr("end_mask"));
attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask"));
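This hunk, and the Transpose hunk that follows, add the guard that was missing after new (std::nothrow): unlike plain new, the nothrow form reports allocation failure by returning nullptr instead of throwing std::bad_alloc, so the result must be checked before the first dereference. A minimal sketch (AttrT is a local stand-in for schema::StridedSliceT):

#include <cstdio>
#include <new>

struct AttrT {  // stand-in for schema::StridedSliceT
  int beginMask = 0;
};

int main() {
  auto *attr = new (std::nothrow) AttrT();  // returns nullptr on failure
  if (attr == nullptr) {                    // the guard added by this hunk
    fprintf(stderr, "new StridedSlice failed\n");
    return 1;
  }
  attr->beginMask = 0x1;  // safe: only reached when allocation succeeded
  delete attr;
  return 0;
}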

@ -43,6 +43,10 @@ int Transpose::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::TransposeT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new TransposeT failed";
return RET_ERROR;
}
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
auto inputNode = inputs[kAnfPopulaterOne];
if (inputNode->isa<ValueNode>()) {

@ -54,7 +54,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack1) {
float out[20] = {0};
Conv1x1InputPack(in, out, conv_param, sizeof(float));
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 20));
EXPECT_EQ(0, lite::CompareOutputData(out, 20, correct, 20));
delete conv_param;
}
@ -114,7 +114,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack3) {
-5.052577, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
Conv1x1InputPack(in, out, conv_param, sizeof(float));
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 18));
EXPECT_EQ(0, lite::CompareOutputData(out, 18, correct, 18));
delete conv_param;
}
@ -136,7 +136,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack4) {
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
float out[54] = {0};
Conv1x1InputPack(in, out, conv_param, sizeof(float));
EXPECT_EQ(0, lite::CompareOutputData(out, correct, 54));
EXPECT_EQ(0, lite::CompareOutputData(out, 54, correct, 54));
delete conv_param;
}
@ -166,7 +166,7 @@ TEST_F(TestConv1x1Fp32, Conv1x1WeightTest1) {
conv_param->output_channel_ = 7;
float out[96] = {0};
Pack1x1WeightFp32(in, out, conv_param);
EXPECT_EQ(0, lite::CompareOutputData(out, co, 96));
EXPECT_EQ(0, lite::CompareOutputData(out, 96, co, 96));
delete conv_param;
}

@ -75,7 +75,7 @@ TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack1) {
0.000, 0.000, 0.000, 0.00};
float dst[256] = {0};
PackDeConvWeightFp32(in, dst, 5, 6, 2 * 2);
EXPECT_EQ(0, lite::CompareOutputData(dst, co, 256));
EXPECT_EQ(0, lite::CompareOutputData(dst, 256, co, 256));
}
TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) {
@ -90,7 +90,7 @@ TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) {
-0.293, 18.686, 0.0873, 0, 0, 0, 0, 0, 0, 0, 0, 0};
float dst[64] = {0};
PackDeConvWeightFp32(in, dst, 6, 3, 2 * 1);
EXPECT_EQ(0, lite::CompareOutputData(dst, co, 64));
EXPECT_EQ(0, lite::CompareOutputData(dst, 64, co, 64));
}
TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) {

@ -212,7 +212,7 @@ TEST_F(TestActGradFp32, SigmoidGradFp32) {
int res = lite::CompareRelativeOutput(output_data, output_path);
EXPECT_EQ(res, 0);
// lite::CompareOutput(output_data, output_path);
// lite::CompareOutput(output_data, output_data_size, output_path);
delete[] input_data;
delete[] output_data;

@ -58,7 +58,7 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) {
}
std::cout << std::endl;
std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin";
lite::CompareOutput(output_data, output_path);
lite::CompareOutput(output_data, 7, output_path);
delete[] input_data;
delete[] output_data;

@ -96,7 +96,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) {
}
std::cout << std::endl;
std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin";
auto res = lite::CompareOutput(output_data, output_path);
auto res = lite::CompareOutput(output_data, output_data_size, output_path);
EXPECT_EQ(res, 0);
delete[] input_data;
@ -152,7 +152,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) {
}
std::cout << std::endl;
std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin";
auto res = lite::CompareOutput(output_data, output_path);
auto res = lite::CompareOutput(output_data, output_data_size, output_path);
EXPECT_EQ(res, 0);
delete[] input_data;
@ -213,7 +213,8 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) {
}
std::cout << std::endl;
std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin";
auto res = lite::CompareOutput(output_data, output_path);
size_t output_data_size = dx_tensor.ElementsNum();
auto res = lite::CompareOutput(output_data, output_data_size, output_path);
EXPECT_EQ(res, 0);
delete[] input_data;
@ -388,7 +389,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) {
}
std::cout << std::endl;
std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin";
auto res = lite::CompareOutput(output_data, output_path);
auto res = lite::CompareOutput(output_data, output_data_size, output_path);
EXPECT_EQ(res, 0);
free(pooling_param);

@ -70,7 +70,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {
printf("==================Testing Grad===============\n");
std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin";
lite::CompareOutput(loss, output_path);
lite::CompareOutput(loss, 1, output_path);
((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train();
kernel_obj->Run();
@ -81,7 +81,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {
}
std::cout << std::endl;
std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin";
lite::CompareOutput(grad, grad_path);
lite::CompareOutput(grad, 24, grad_path);
delete[] ll_labels;
delete[] labels;

@ -20,10 +20,8 @@
namespace mindspore {
namespace lite {
STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto,
const caffe::LayerParameter &weight,
schema::CNodeT *op,
std::vector<schema::TensorT *> *weightVec) {
STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight,
schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) {
MS_LOG(DEBUG) << "parse CaffeReduceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -67,6 +65,11 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto,
} else {
attr->axes = std::vector(1, 0);
}
if (reduce_param.has_coeff()) {
attr->coeff = reduce_param.coeff();
} else {
attr->coeff = 1.0;
}
attr->reduceToEnd = true;
attr->keepDims = false;
op->name = proto.name();
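The reduce parser now also honors the layer's optional coeff field (Caffe's output scale factor), falling back to 1.0 when it is absent; 1.0 is likewise the field's declared default in caffe.proto's ReductionParameter. The guarded read added above reduces to a one-liner (fragment, same names as the hunk):

attr->coeff = reduce_param.has_coeff() ? reduce_param.coeff() : 1.0f;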
@ -78,4 +81,3 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto,
CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser());
} // namespace lite
} // namespace mindspore

@ -22,11 +22,9 @@
namespace mindspore {
namespace lite {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -74,10 +72,10 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
return RET_OK;
}
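The remaining hunks are the bulk of the PR: every tflite node parser's Parse override now receives the whole tflite::ModelT plus a single TfliteTensorsInfo bookkeeping struct, replacing the five parallel parameters of the old signature (tensor list, buffer list, id vector, format vector, id map). The struct's definition is not among the shown hunks; a hypothetical shape inferred from these call sites (STATUS, schema::Format, and the tflite types are the repo's; the field names here are guesses):

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

// Hypothetical consolidation of the three id/format containers the old
// signature threaded through every parser; the real definition may differ.
struct TfliteTensorsInfo {
  std::vector<int32_t> tensorsId;             // was the tensors_id param
  std::vector<schema::Format> tensorsFormat;  // was tensors_format
  std::map<int, int> tensorsIdMap;            // was tensors_id_map
};

// The narrowed override: the full model gives each parser access to
// subgraphs[0]->tensors and the model-level buffers table on demand.
STATUS Parse(TfliteTensorsInfo *tensors_info,
             const std::unique_ptr<tflite::OperatorT> &tflite_op,
             const std::unique_ptr<tflite::ModelT> &tflite_model,
             schema::CNodeT *op) override;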

@ -29,11 +29,8 @@ class TfliteActivationParser : public TfliteNodeParser {
public:
TfliteActivationParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
};
class TfliteReluParser : public TfliteActivationParser {

@ -22,11 +22,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteAddNParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -44,16 +41,16 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_
return RET_NULL_PTR;
}
attr->N = tflite_tensors.size() - 1;
attr->N = tflite_model->subgraphs[0]->tensors.size() - 1;
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
return RET_OK;
}

@ -29,11 +29,8 @@ class TfliteAddNParser : public TfliteNodeParser {
public:
TfliteAddNParser() : TfliteNodeParser("AddN") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -21,11 +21,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -50,8 +47,8 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
// get axis attr
auto axis_idx = tflite_op->inputs[1];
std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha) {});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR;
@ -66,10 +63,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
return RET_OK;
}
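With only the ModelT in hand, constant inputs, here the axis tensor of ArgMax (and of ArgMin in the hunks below), are resolved in two hops: the tensor's buffer field indexes the model-level buffers table. As a fragment, with the same variable names as the hunk above:

const auto &tensors = tflite_model->subgraphs[0]->tensors;
auto buffer_idx = tensors[axis_idx]->buffer;               // hop 1: per-tensor buffer index
const auto &buf_data = tflite_model->buffers[buffer_idx];  // hop 2: model-wide table
// In the flatbuffers object API, buf_data->data is a std::vector<uint8_t>
// holding the raw bytes of the axis constant.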

@ -29,11 +29,8 @@ class TfliteArgmaxParser : public TfliteNodeParser {
public:
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -21,11 +21,8 @@
namespace mindspore {
namespace lite {
STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteArgminParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
@ -50,8 +47,8 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
// get axis attr
auto axis_idx = tflite_op->inputs[1];
std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha) {});
auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer];
auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return RET_NULL_PTR;
@ -66,10 +63,10 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
return RET_OK;
}

@ -29,11 +29,8 @@ class TfliteArgminParser : public TfliteNodeParser {
public:
TfliteArgminParser() : TfliteNodeParser("Argmin") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore

@ -22,12 +22,9 @@
namespace mindspore {
namespace lite {
STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -171,20 +168,17 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
// set input
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
return RET_OK;
}
STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -210,13 +204,13 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
} else if (std::strcmp(node_name, "Exp") == 0) {
MS_LOG(DEBUG) << "parse TfliteExpParser";
auto attr = std::make_unique<schema::ExpT>();
attr->base = -1; // -1 represent base = e
attr->scale = 1;
attr->shift = 0;
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->base = -1; // -1 represent base = e
attr->scale = 1;
attr->shift = 0;
op->primitive->value.type = schema::PrimitiveType_Exp;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Sqrt") == 0) {
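The Exp hunk above moves the attribute writes below the null check, so the pointer is validated before its first dereference. With std::make_unique the guard is defensive rather than reachable, because make_unique reports allocation failure by throwing std::bad_alloc instead of returning nullptr, but the reordering still removes a dereference that preceded the check. A sketch of the corrected ordering (ExpT here is a local stand-in for schema::ExpT):

#include <memory>

struct ExpT {  // stand-in for schema::ExpT
  int base = 0;
  int scale = 1;
  int shift = 0;
};

std::unique_ptr<ExpT> MakeExpAttr() {
  auto attr = std::make_unique<ExpT>();
  if (attr == nullptr) {  // defensive only: make_unique throws on failure
    return nullptr;
  }
  attr->base = -1;  // -1 selects base e, as in the hunk above
  attr->scale = 1;
  attr->shift = 0;
  return attr;
}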
@ -300,7 +294,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
}
op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "NEG") == 0) {
} else if (std::strcmp(node_name, "Neg") == 0) {
MS_LOG(DEBUG) << "parse TfliteNegParser";
auto attr = std::make_unique<schema::NegT>();
if (attr == nullptr) {
@ -311,18 +305,16 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
return RET_OK;
}
STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
@ -393,11 +385,11 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
}
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
}
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
return RET_OK;
}
@ -424,7 +416,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser());
TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser());
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser());
TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser());
TfliteNodeRegister g_tfliteNegParser("NEG", new TfliteNegParser());
TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser());
TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser());
TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser());
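The "NEG" to "Neg" rename in the strcmp above and in the registrar entry keeps the registration key and the lookup name in exact agreement; the "flOOR" key on the preceding line is left untouched by this patch. String-keyed registries match exactly, so mismatched casing silently fails to resolve, as this minimal illustration shows (the std::map registry is illustrative, not the repo's TfliteNodeRegister):

#include <cstdio>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> registry;  // stand-in for the parser registry
  registry["NEG"] = 1;                  // key used at registration time
  // Lookup uses a differently-cased key and misses: prints 0.
  printf("%zu\n", registry.count("Neg"));
  return 0;
}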

@ -29,11 +29,8 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
public:
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
};
class TfliteAddParser : public TfliteDoubleInputOpParser {
@ -95,11 +92,8 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
public:
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
};
class TfliteAbsParser : public TfliteSingleInputOpParser {
@ -166,11 +160,8 @@ class TfliteCompareOpParser : public TfliteNodeParser {
public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
};
class TfliteEqualParser : public TfliteCompareOpParser {

Some files were not shown because too many files have changed in this diff.