diff --git a/mindspore/lite/internal/include/vector.h b/mindspore/lite/internal/include/vector.h index 17fd40cb48..62856ce2a0 100644 --- a/mindspore/lite/internal/include/vector.h +++ b/mindspore/lite/internal/include/vector.h @@ -20,7 +20,6 @@ #include #include #include -#include #define DEFAULT_CAPACITY 4 struct MSTensor; diff --git a/mindspore/lite/internal/src/common/vector.cc b/mindspore/lite/internal/src/common/vector.cc index 6d9201c6aa..b7758ad494 100644 --- a/mindspore/lite/internal/src/common/vector.cc +++ b/mindspore/lite/internal/src/common/vector.cc @@ -31,7 +31,7 @@ template Vector::Vector(size_t size) { size_ = size; elem_size_ = sizeof(T); - capacity_ = size; + capacity_ = (size == 0 ? DEFAULT_CAPACITY : size); data_ = reinterpret_cast(malloc(capacity_ * elem_size_)); if (data_ == nullptr) { MS_C_EXCEPTION("malloc data failed"); @@ -43,7 +43,7 @@ template Vector::Vector(size_t size, const T &value) { size_ = size; elem_size_ = sizeof(T); - capacity_ = size; + capacity_ = (size == 0 ? DEFAULT_CAPACITY : size); data_ = reinterpret_cast(malloc(capacity_ * elem_size_)); if (data_ == nullptr) { MS_C_EXCEPTION("malloc data failed"); @@ -115,7 +115,7 @@ void Vector::push_back(const T &elem) { template void Vector::push_back(T &&elem) { if (data_ == nullptr) { - data_ = reinterpret_cast(malloc(elem_size_)); + data_ = reinterpret_cast(malloc(capacity_ * elem_size_)); if (data_ == nullptr) { MS_C_EXCEPTION("malloc data failed"); } diff --git a/mindspore/lite/src/common/file_utils.cc b/mindspore/lite/src/common/file_utils.cc index 1ef54dc49a..3f8c565e3b 100644 --- a/mindspore/lite/src/common/file_utils.cc +++ b/mindspore/lite/src/common/file_utils.cc @@ -102,9 +102,13 @@ int WriteToBin(const std::string &file_path, void *data, size_t size) { return 0; } -int CompareOutputData(float *output_data, float *correct_data, int data_size) { +int CompareOutputData(float *output_data, size_t output_size, float *correct_data, size_t data_size) { + if (output_size != data_size) { + printf("compare failed, output_size %zu isn't equal to data_size %zu.\n", output_size, data_size); + return 0; + } float error = 0; - for (int i = 0; i < data_size; i++) { + for (size_t i = 0; i < data_size; i++) { float abs = fabs(output_data[i] - correct_data[i]); if (abs > 0.00001) { error += abs; @@ -120,12 +124,12 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) { return 0; } -int CompareOutput(float *output_data, std::string file_path) { - size_t output_size; - auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); - size_t output_num = output_size / sizeof(float); - printf("output num : %zu\n", output_num); - int res = CompareOutputData(output_data, ground_truth, output_num); +int CompareOutput(float *output_data, size_t output_num, std::string file_path) { + size_t ground_truth_size; + auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &ground_truth_size)); + size_t ground_truth_num = ground_truth_size / sizeof(float); + printf("ground truth num : %zu\n", ground_truth_num); + int res = CompareOutputData(output_data, output_num, ground_truth, ground_truth_num); delete[] ground_truth; return res; } diff --git a/mindspore/lite/src/common/file_utils.h b/mindspore/lite/src/common/file_utils.h index c1751d9a28..3adacdb6a9 100644 --- a/mindspore/lite/src/common/file_utils.h +++ b/mindspore/lite/src/common/file_utils.h @@ -50,8 +50,8 @@ void WriteToTxt(const std::string &file_path, void *data, size_t element_size) { int WriteToBin(const std::string &file_path, void *data, size_t size); -int CompareOutputData(float *output_data, float *correct_data, int data_size); -int CompareOutput(float *output_data, std::string file_path); +int CompareOutputData(float *output_data, size_t output_num, float *correct_data, size_t data_size); +int CompareOutput(float *output_data, size_t output_num, std::string file_path); std::string GetAndroidPackageName(); std::string GetAndroidPackagePath(); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index b4209b8282..2f22aaad30 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -169,6 +169,7 @@ class PrimitiveC { } auto ret = primc->UnPackSchemaPrimitive(primitive); if (ret != RET_OK) { + delete primc; MS_LOG(ERROR) << "UnPackSchemaPrimitive failed"; return nullptr; } diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index c005647ae5..78c4254f10 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -144,6 +144,8 @@ void CalShape(const T *data, const std::vector &inputs, std::vector(data[i]) == -1) { index = i; + } else if (static_cast(data[i]) == 0) { + size *= inputs[0]->shape()[i]; } else { size *= data[i]; } diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 6cb0356c9b..b7905e0bef 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -64,6 +64,10 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vectorprimitive_->value.value == nullptr) { auto attr = new (std::nothrow) schema::StridedSliceT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new StridedSlice failed"; + return RET_ERROR; + } attr->beginMask = GetValue(prim.GetAttr("begin_mask")); attr->endMask = GetValue(prim.GetAttr("end_mask")); attr->ellipsisMask = GetValue(prim.GetAttr("ellipsis_mask")); diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc index 95c0b02d3d..279a6953c7 100644 --- a/mindspore/lite/src/ops/transpose.cc +++ b/mindspore/lite/src/ops/transpose.cc @@ -43,6 +43,10 @@ int Transpose::UnPackAttr(const Primitive &prim, const std::vector & } if (this->primitive_->value.value == nullptr) { auto attr = new (std::nothrow) schema::TransposeT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new TransposeT failed"; + return RET_ERROR; + } MS_ASSERT(inputs.size() == kAnfPopulaterTwo); auto inputNode = inputs[kAnfPopulaterOne]; if (inputNode->isa()) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc index caca06fe7a..d44a67eeb2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -54,7 +54,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack1) { float out[20] = {0}; Conv1x1InputPack(in, out, conv_param, sizeof(float)); - EXPECT_EQ(0, lite::CompareOutputData(out, correct, 20)); + EXPECT_EQ(0, lite::CompareOutputData(out, 20, correct, 20)); delete conv_param; } @@ -114,7 +114,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack3) { -5.052577, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; Conv1x1InputPack(in, out, conv_param, sizeof(float)); - EXPECT_EQ(0, lite::CompareOutputData(out, correct, 18)); + EXPECT_EQ(0, lite::CompareOutputData(out, 18, correct, 18)); delete conv_param; } @@ -136,7 +136,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack4) { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; float out[54] = {0}; Conv1x1InputPack(in, out, conv_param, sizeof(float)); - EXPECT_EQ(0, lite::CompareOutputData(out, correct, 54)); + EXPECT_EQ(0, lite::CompareOutputData(out, 54, correct, 54)); delete conv_param; } @@ -166,7 +166,7 @@ TEST_F(TestConv1x1Fp32, Conv1x1WeightTest1) { conv_param->output_channel_ = 7; float out[96] = {0}; Pack1x1WeightFp32(in, out, conv_param); - EXPECT_EQ(0, lite::CompareOutputData(out, co, 96)); + EXPECT_EQ(0, lite::CompareOutputData(out, 96, co, 96)); delete conv_param; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc index 4681e26fa0..289d8fd01e 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc @@ -75,7 +75,7 @@ TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack1) { 0.000, 0.000, 0.000, 0.00}; float dst[256] = {0}; PackDeConvWeightFp32(in, dst, 5, 6, 2 * 2); - EXPECT_EQ(0, lite::CompareOutputData(dst, co, 256)); + EXPECT_EQ(0, lite::CompareOutputData(dst, 256, co, 256)); } TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) { @@ -90,7 +90,7 @@ TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) { -0.293, 18.686, 0.0873, 0, 0, 0, 0, 0, 0, 0, 0, 0}; float dst[64] = {0}; PackDeConvWeightFp32(in, dst, 6, 3, 2 * 1); - EXPECT_EQ(0, lite::CompareOutputData(dst, co, 64)); + EXPECT_EQ(0, lite::CompareOutputData(dst, 64, co, 64)); } TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc index 3563d2d76b..c7b7e0a212 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc @@ -212,7 +212,7 @@ TEST_F(TestActGradFp32, SigmoidGradFp32) { int res = lite::CompareRelativeOutput(output_data, output_path); EXPECT_EQ(res, 0); - // lite::CompareOutput(output_data, output_path); + // lite::CompareOutput(output_data, output_data_size, output_path); delete[] input_data; delete[] output_data; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc index 3efc68212b..71c01b7dc4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bias_grad_fp32_tests.cc @@ -58,7 +58,7 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) { } std::cout << std::endl; std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin"; - lite::CompareOutput(output_data, output_path); + lite::CompareOutput(output_data, 7, output_path); delete[] input_data; delete[] output_data; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc index f0646e343f..d895dc8122 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc @@ -96,7 +96,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { } std::cout << std::endl; std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; - auto res = lite::CompareOutput(output_data, output_path); + auto res = lite::CompareOutput(output_data, output_data_size, output_path); EXPECT_EQ(res, 0); delete[] input_data; @@ -152,7 +152,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { } std::cout << std::endl; std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; - auto res = lite::CompareOutput(output_data, output_path); + auto res = lite::CompareOutput(output_data, output_data_size, output_path); EXPECT_EQ(res, 0); delete[] input_data; @@ -213,7 +213,8 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) { } std::cout << std::endl; std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin"; - auto res = lite::CompareOutput(output_data, output_path); + size_t output_data_size = dx_tensor.ElementsNum(); + auto res = lite::CompareOutput(output_data, output_data_size, output_path); EXPECT_EQ(res, 0); delete[] input_data; @@ -388,7 +389,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { } std::cout << std::endl; std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin"; - auto res = lite::CompareOutput(output_data, output_path); + auto res = lite::CompareOutput(output_data, output_data_size, output_path); EXPECT_EQ(res, 0); free(pooling_param); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc index 0599fff9b5..26b2abf277 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc @@ -70,7 +70,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { printf("==================Testing Grad===============\n"); std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin"; - lite::CompareOutput(loss, output_path); + lite::CompareOutput(loss, 1, output_path); ((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train(); kernel_obj->Run(); @@ -81,7 +81,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { } std::cout << std::endl; std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin"; - lite::CompareOutput(grad, grad_path); + lite::CompareOutput(grad, 24, grad_path); delete[] ll_labels; delete[] labels; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc index 4c41fd193b..3d2854975f 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_reduce_parser.cc @@ -20,10 +20,8 @@ namespace mindspore { namespace lite { -STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, - const caffe::LayerParameter &weight, - schema::CNodeT *op, - std::vector *weightVec) { +STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { MS_LOG(DEBUG) << "parse CaffeReduceParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -67,6 +65,11 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, } else { attr->axes = std::vector(1, 0); } + if (reduce_param.has_coeff()) { + attr->coeff = reduce_param.coeff(); + } else { + attr->coeff = 1.0; + } attr->reduceToEnd = true; attr->keepDims = false; op->name = proto.name(); @@ -78,4 +81,3 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser()); } // namespace lite } // namespace mindspore - diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc index c8671ceb14..b4c43fc635 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc @@ -22,11 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteActivationParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -74,10 +72,10 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr &t op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h index c3b9db23b9..d6d8a13fb6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h @@ -29,11 +29,8 @@ class TfliteActivationParser : public TfliteNodeParser { public: TfliteActivationParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteReluParser : public TfliteActivationParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc index 72bccef988..1865260194 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteAddNParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteAddNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -44,16 +41,16 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr &tflite_ return RET_NULL_PTR; } - attr->N = tflite_tensors.size() - 1; + attr->N = tflite_model->subgraphs[0]->tensors.size() - 1; op->primitive->value.type = schema::PrimitiveType_AddN; op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h index fdc2fe0553..edaf56e305 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h @@ -29,11 +29,8 @@ class TfliteAddNParser : public TfliteNodeParser { public: TfliteAddNParser() : TfliteNodeParser("AddN") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc index 56d3efea6c..e0638addb9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -50,8 +47,8 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit // get axis attr auto axis_idx = tflite_op->inputs[1]; - std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha) {}); - auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; + auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; + auto &buf_data = tflite_model->buffers[buffer_idx]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; return RET_NULL_PTR; @@ -66,10 +63,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_ArgMax; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h index 61d1ac385d..013cc5ad2b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h @@ -29,11 +29,8 @@ class TfliteArgmaxParser : public TfliteNodeParser { public: TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc index e1b97dac8b..d7017517eb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteArgminParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -50,8 +47,8 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit // get axis attr auto axis_idx = tflite_op->inputs[1]; - std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha) {}); - auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; + auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; + auto &buf_data = tflite_model->buffers[buffer_idx]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; return RET_NULL_PTR; @@ -66,10 +63,10 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_ArgMin; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h index 58e90a6775..ad4ed1c3c8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h @@ -29,11 +29,8 @@ class TfliteArgminParser : public TfliteNodeParser { public: TfliteArgminParser() : TfliteNodeParser("Argmin") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 8fe2a794ab..62cf917175 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -22,12 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -171,20 +168,17 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr // set input for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } -STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -210,13 +204,13 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr } else if (std::strcmp(node_name, "Exp") == 0) { MS_LOG(DEBUG) << "parse TfliteExpParser"; auto attr = std::make_unique(); - attr->base = -1; // -1 represent base = e - attr->scale = 1; - attr->shift = 0; if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } + attr->base = -1; // -1 represent base = e + attr->scale = 1; + attr->shift = 0; op->primitive->value.type = schema::PrimitiveType_Exp; op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "Sqrt") == 0) { @@ -300,7 +294,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr } op->primitive->value.type = schema::PrimitiveType_Floor; op->primitive->value.value = attr.release(); - } else if (std::strcmp(node_name, "NEG") == 0) { + } else if (std::strcmp(node_name, "Neg") == 0) { MS_LOG(DEBUG) << "parse TfliteNegParser"; auto attr = std::make_unique(); if (attr == nullptr) { @@ -311,18 +305,16 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); } - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } -STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -393,11 +385,11 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tf } for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } @@ -424,7 +416,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); -TfliteNodeRegister g_tfliteNegParser("NEG", new TfliteNegParser()); +TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser()); TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h index 53676f121f..c52b0b98f1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h @@ -29,11 +29,8 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { public: TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteAddParser : public TfliteDoubleInputOpParser { @@ -95,11 +92,8 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { public: TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteAbsParser : public TfliteSingleInputOpParser { @@ -166,11 +160,8 @@ class TfliteCompareOpParser : public TfliteNodeParser { public: TfliteCompareOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteEqualParser : public TfliteCompareOpParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc index 2ae3bcbe29..a1c6797d3d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc @@ -23,12 +23,9 @@ namespace mindspore { namespace lite { -STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -54,11 +51,12 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->blockShape)) { MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->crops)) { MS_LOG(ERROR) << "get batchToSpace -> crops failed"; return RET_ERROR; } @@ -66,10 +64,10 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_BatchToSpace; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h index 398707bcd1..50fd8edb29 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h @@ -29,11 +29,8 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { public: TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc index afbf73c9ed..7c2b9c8fc9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc @@ -22,11 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -44,7 +42,8 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr & return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->dst_shape)) { MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; return RET_ERROR; } @@ -52,10 +51,10 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr & op->primitive->value.type = schema::PrimitiveType_BroadcastTo; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h index 50363b19ae..364709ed35 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h @@ -29,11 +29,8 @@ class TfliteBroadcastToParser : public TfliteNodeParser { public: TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc index ab631f64cb..383474316c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteCastParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteCastParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -43,7 +40,7 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr &tflite_ return RET_NULL_PTR; } - const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; + const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; @@ -52,7 +49,7 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr &tflite_ if (attr->srcT == TypeId::kNumberTypeBool) { attr->srcT = TypeId::kNumberTypeUInt8; } - const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; + const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; @@ -62,10 +59,10 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr &tflite_ op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h index 2570f43e94..17cf60ef05 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h @@ -29,11 +29,8 @@ class TfliteCastParser : public TfliteNodeParser { public: TfliteCastParser() : TfliteNodeParser("Cast") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc index d749dfc58f..a2741c8365 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteConcatParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteConcatParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -55,11 +52,11 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr &tflit op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h index b50b6eb03e..4074d41f3a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h @@ -29,11 +29,8 @@ class TfliteConcatParser : public TfliteNodeParser { public: TfliteConcatParser() : TfliteNodeParser("Concat") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index 9c4a90e6df..e05f1cd75a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteConvParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -60,7 +57,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_ // get the conv op weight tensor auto weight_index = tflite_op->inputs[1]; - const auto &weight_tensor = tflite_tensors[weight_index]; + const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; if (weight_tensor == nullptr) { MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; @@ -73,7 +70,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_ // calculate pad params auto data_index = tflite_op->inputs[0]; - const auto &data_tensor = tflite_tensors[data_index]; + const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; std::vector params; if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { @@ -89,14 +86,14 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_ op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_KHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_KHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h index 6a21bce5c6..4e226c0b98 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h @@ -29,11 +29,8 @@ class TfliteConvParser : public TfliteNodeParser { public: TfliteConvParser() : TfliteNodeParser("Conv2D") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc index 0c2ec85d94..de61798b34 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc @@ -23,11 +23,8 @@ namespace mindspore { namespace lite { -STATUS TfliteCustomParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteCustomParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -78,12 +75,12 @@ STATUS TfliteCustomParser::Parse(const std::unique_ptr &tflit op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h index 88e7bdd3a1..b39188ae4f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h @@ -29,11 +29,8 @@ class TfliteCustomParser : public TfliteNodeParser { public: TfliteCustomParser() : TfliteNodeParser("Custom") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 2ea3c3be69..70e5145253 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -61,7 +58,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit // get the conv op weight tensor auto weight_index = tflite_op->inputs[1]; - const auto &weight_tensor = tflite_tensors[weight_index]; + const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; if (weight_tensor == nullptr) { MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; @@ -74,7 +71,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit // calculate pad params auto data_index = tflite_op->inputs[2]; - const auto &data_tensor = tflite_tensors[data_index]; + const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; std::vector params; if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { @@ -90,12 +87,12 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_KHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_KHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h index 5d754318b9..58c1a47b5d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h @@ -29,11 +29,8 @@ class TfliteDeConvParser : public TfliteNodeParser { public: TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc index 8d781f3787..611c40e288 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc @@ -22,12 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; if (op == nullptr) { @@ -57,10 +54,10 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_DepthToSpace; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h index 72c212f310..880502cb8b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h @@ -29,11 +29,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { public: TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index 6904acff0e..0d3ab72fbe 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -21,12 +21,9 @@ namespace mindspore { namespace lite { -STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -61,7 +58,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrinputs[1]; - const auto &data_tensor = tflite_tensors[data_index]; + const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; if (data_tensor == nullptr) { MS_LOG(ERROR) << "the data tensor is null"; return RET_NULL_PTR; @@ -71,7 +68,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrinputs[1]; - const auto &weight_tensor = tflite_tensors[weight_index]; + const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; if (weight_tensor == nullptr) { MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; @@ -96,14 +93,14 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_DepthwiseConv2D; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_KHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_KHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h index 20451e7319..73b0b25ea4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h @@ -29,11 +29,8 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { public: TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 1ce4fa1b28..7fbc882b0b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -20,11 +20,9 @@ namespace mindspore { namespace lite { -STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -36,12 +34,12 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &t return RET_NULL_PTR; } - const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; + const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "input tensor is null"; return RET_NULL_PTR; } - const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; + const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; @@ -70,10 +68,10 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &t op->primitive->value.type = schema::PrimitiveType_Cast; } - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h index dc1a3c1545..2897b1857b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h @@ -28,11 +28,8 @@ class TfliteDequantizeParser : public TfliteNodeParser { public: TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index 772bdcff9d..890b700f8d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -21,11 +21,9 @@ namespace mindspore { namespace lite { -STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -43,17 +41,17 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &t return RET_NULL_PTR; } std::vector dims; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, dims)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, dims)) { MS_LOG(ERROR) << "get expand_dims -> dim failed"; return RET_ERROR; } attr->dim = dims[0]; op->primitive->value.type = schema::PrimitiveType_ExpandDims; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h index 09029fe591..4832f1117b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h @@ -29,11 +29,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser { public: TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc index b264bc809d..3805f7f109 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteFillParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteFillParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -44,7 +41,7 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr &tflite_ } if (tflite_op->inputs.size() > 1) { - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->dims)) { MS_LOG(ERROR) << "get fill -> dims failed"; return RET_ERROR; } @@ -53,10 +50,10 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr &tflite_ op->primitive->value.type = schema::PrimitiveType_Fill; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h index 1e454e9fb4..8af709f3a4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h @@ -29,11 +29,8 @@ class TfliteFillParser : public TfliteNodeParser { public: TfliteFillParser() : TfliteNodeParser("Fill") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index d82c6d57ae..8674e52567 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -21,12 +21,9 @@ namespace mindspore { namespace lite { -STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -60,16 +57,16 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_FullConnection; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_KHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_KHWC); if (hasBias) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h index 8178d2b09a..3da6407a60 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h @@ -29,11 +29,8 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { public: TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteFakeQuantParser : public TfliteFullyConnectedParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index 5681b5a12a..b1073f8f27 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -49,11 +46,11 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfl op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h index f8d11b6ccc..9f93547a0b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h @@ -29,11 +29,8 @@ class TfliteGatherNdParser : public TfliteNodeParser { public: TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc index bb215aa0e9..23f2f07611 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteGatherParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -55,11 +52,11 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflit op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h index 4558e02307..6ead6b01d2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h @@ -29,11 +29,8 @@ class TfliteGatherParser : public TfliteNodeParser { public: TfliteGatherParser() : TfliteNodeParser("Gather") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc index 8e989ae164..c4dc02eda9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc @@ -21,12 +21,9 @@ namespace mindspore { namespace lite { -STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,12 +44,12 @@ STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_HashtableLookup; op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } return RET_OK; } @@ -60,4 +57,3 @@ STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc index 1a158fbccb..13b0122689 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteL2NormParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteL2NormParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -51,11 +48,11 @@ STATUS TfliteL2NormParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_L2Norm; op->primitive->value.value = attr.release(); - // set input - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + // set input and output + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h index 7da98c9cfc..3ddb116967 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h @@ -29,11 +29,8 @@ class TfliteL2NormParser : public TfliteNodeParser { public: TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc index c8b249f0b9..da14bb9338 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -70,11 +67,11 @@ STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tfli } for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h index d95145f658..b6a21aeeb4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h @@ -29,11 +29,8 @@ class TfliteLogicalParser : public TfliteNodeParser { public: TfliteLogicalParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteLogicalAndParser : public TfliteLogicalParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc index 8e1371c95e..e33b2d5dd4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteLRNParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteLRNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -56,10 +53,10 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr &tflite_o op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h index 1566650c8c..492d677b63 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h @@ -29,11 +29,8 @@ class TfliteLRNParser : public TfliteNodeParser { public: TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc index 93eed2fd3e..489a2f7f0e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc @@ -21,13 +21,9 @@ namespace mindspore { namespace lite { -STATUS TfliteLshProjectionParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, - std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -60,15 +56,14 @@ STATUS TfliteLshProjectionParser::Parse(const std::unique_ptr op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, - tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, - tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); } // namespace lite } // namespace mindspore - diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h index 145bf3da42..448ceb7fff 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h @@ -29,16 +29,10 @@ class TfliteLshProjectionParser : public TfliteNodeParser { public: TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, - std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H - diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index d9228a3ef0..1837a8d39f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -120,8 +120,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit continue; } if (status == RET_OK) { - status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, - &tensorsFormat, &tensorsIdMap); + status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); if (status != RET_OK) { MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; continue; @@ -138,15 +137,15 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr &tflite_subgraph, const std::vector> &tflite_model_buffer, schema::MetaGraphT *sub_graph) { - for (size_t i = 0; i < tensorsId.size(); i++) { - auto idx = tensorsId[i]; + for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { + auto idx = tensorsInfo.tensorsId[i]; if (idx < 0) { idx += tflite_subgraph->tensors.size(); } const auto &tflite_tensor = tflite_subgraph->tensors[idx]; std::unique_ptr tensor = std::make_unique(); - tensor->format = tensorsFormat[i]; + tensor->format = tensorsInfo.tensorsFormat[i]; tensor->dataType = GetTfliteDataType(tflite_tensor->type); tensor->dims = tflite_tensor->shape; @@ -207,8 +206,8 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr } else { id = idx; } - auto iter = tensorsIdMap.find(id); - if (iter != tensorsIdMap.end()) { + auto iter = tensorsInfo.tensorsIdMap.find(id); + if (iter != tensorsInfo.tensorsIdMap.end()) { graph_inputs.push_back(iter->second); } else { MS_LOG(ERROR) << "get graph input failed"; @@ -226,8 +225,8 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr } else { id = idx; } - auto iter = tensorsIdMap.find(id); - if (iter != tensorsIdMap.end()) { + auto iter = tensorsInfo.tensorsIdMap.find(id); + if (iter != tensorsInfo.tensorsIdMap.end()) { graph_outputs.push_back(iter->second); } else { MS_LOG(ERROR) << "get graph output failed"; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 3cae315af1..b799f02cf6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -65,9 +65,7 @@ class TfliteModelParser : public ModelParser { STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); private: - std::vector tensorsId; - std::vector tensorsFormat; - std::map tensorsIdMap; + TfliteTensorsInfo tensorsInfo; std::vector tensors; std::map opMap; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index 8b8bd0d40e..9085b1d7f9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -38,40 +38,37 @@ class TfliteNodeParser { virtual ~TfliteNodeParser() = default; - virtual STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) = 0; + virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) = 0; - void AddOpInput(schema::CNodeT *op, std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map, int idx, int new_idx, int total, schema::Format format) { - auto iter = tensors_id_map->find(idx); - if (iter != tensors_id_map->end()) { + void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) { + int new_idx = tensors_info->tensorsId.size(); + auto iter = tensors_info->tensorsIdMap.find(idx); + if (iter != tensors_info->tensorsIdMap.end()) { op->inputIndex.emplace_back(iter->second); } else { if (idx < 0) { idx += total; } - tensors_id->emplace_back(idx); - tensors_format->emplace_back(format); - tensors_id_map->insert(std::make_pair(idx, new_idx)); + tensors_info->tensorsId.emplace_back(idx); + tensors_info->tensorsFormat.emplace_back(format); + tensors_info->tensorsIdMap.insert(std::make_pair(idx, new_idx)); op->inputIndex.emplace_back(new_idx); } } - void AddOpOutput(schema::CNodeT *op, std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map, int idx, int new_idx, int total, schema::Format format) { - auto iter = tensors_id_map->find(idx); - if (iter != tensors_id_map->end()) { + void AddOpOutput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) { + int new_idx = tensors_info->tensorsId.size(); + auto iter = tensors_info->tensorsIdMap.find(idx); + if (iter != tensors_info->tensorsIdMap.end()) { op->outputIndex.emplace_back(iter->second); } else { if (idx < 0) { idx += total; } - tensors_id->emplace_back(idx); - tensors_format->emplace_back(format); - tensors_id_map->insert(std::make_pair(idx, new_idx)); + tensors_info->tensorsId.emplace_back(idx); + tensors_info->tensorsFormat.emplace_back(format); + tensors_info->tensorsIdMap.insert(std::make_pair(idx, new_idx)); op->outputIndex.emplace_back(new_idx); } } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc index dcb7eb3dcc..35984b1375 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteOneHotParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -49,7 +46,7 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } auto axis = tflite_attr->axis; - const auto &tensor = tflite_tensors[tflite_op->inputs[0]]; + const auto &tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; @@ -60,11 +57,11 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflit op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h index bac4ce944e..ea3ebe9fb4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h @@ -29,11 +29,8 @@ class TfliteOneHotParser : public TfliteNodeParser { public: TfliteOneHotParser() : TfliteNodeParser("OneHot") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index e131cb27c5..0c7ab33cc1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TflitePadParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -54,7 +51,8 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_o } attr->paddingMode = schema::PaddingMode_CONSTANT; attr->constantValue = 0.0f; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->paddings)) { MS_LOG(ERROR) << "get pad -> paddings failed"; return RET_ERROR; } @@ -74,7 +72,7 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_o default: MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; return RET_INVALID_OP_ATTR; - } + } } else { MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; return RET_NOT_SUPPORT; @@ -83,14 +81,14 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_o op->primitive->value.type = schema::PrimitiveType_Pad; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); if (std::strcmp(node_name, "MirrorPad") == 0) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h index 86e91913fc..d040aebe6b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h @@ -29,11 +29,8 @@ class TflitePadParser : public TfliteNodeParser { public: TflitePadParser() : TfliteNodeParser("Pad") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index 51e096e7f1..3d47e9bbed 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TflitePoolingParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -72,7 +69,7 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli // calculate pad params auto data_index = tflite_op->inputs[0]; - const auto &data_tensor = tflite_tensors[data_index]; + const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; std::vector params; if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms) != RET_OK) { @@ -88,10 +85,10 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Pooling; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h index 914c8db8e1..b6b8b25e51 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h @@ -29,11 +29,8 @@ class TflitePoolingParser : public TfliteNodeParser { public: TflitePoolingParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteMeanPoolingParser : public TflitePoolingParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc index 0fd3d89e3b..0c96375c14 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TflitePReLUParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TflitePReLUParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,12 +44,12 @@ STATUS TflitePReLUParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_PReLU; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h index be03764d25..35bd7936e6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h @@ -29,11 +29,8 @@ class TflitePReLUParser : public TfliteNodeParser { public: TflitePReLUParser() : TfliteNodeParser("PRELU") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index af8abcf176..28bc97e8cf 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -20,11 +20,8 @@ namespace mindspore { namespace lite { -STATUS TfliteQuantizeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteQuantizeNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -36,12 +33,12 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr &tfl return RET_NULL_PTR; } - const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; + const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "input tensor is null"; return RET_NULL_PTR; } - const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; + const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; @@ -70,10 +67,10 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr &tfl op->primitive->value.value = attr.release(); } - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h index 0ee44fc567..f0d29a7653 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h @@ -28,11 +28,8 @@ class TfliteQuantizeParser : public TfliteNodeParser { public: TfliteQuantizeParser() : TfliteNodeParser("Quantize") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc index 8e9d1a1309..ff3244c260 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteRangeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteRangeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -51,10 +48,10 @@ STATUS TfliteRangeParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Range; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h index eedace5def..6b44f7d0f8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h @@ -29,11 +29,8 @@ class TfliteRangeParser : public TfliteNodeParser { public: TfliteRangeParser() : TfliteNodeParser("Range") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc index cd01d674bf..3bdf2f156e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteRankParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteRankParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -46,10 +43,10 @@ STATUS TfliteRankParser::Parse(const std::unique_ptr &tflite_ op->primitive->value.type = schema::PrimitiveType_Rank; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h index 9a910c5774..c732dcf9e7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h @@ -29,11 +29,8 @@ class TfliteRankParser : public TfliteNodeParser { public: TfliteRankParser() : TfliteNodeParser("Rank") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index bb36668fd4..b764a9d3b9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -75,7 +72,7 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit return RET_NOT_FIND_OP; } - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axes)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { MS_LOG(ERROR) << "get reduce -> axes failed"; return RET_ERROR; } @@ -83,10 +80,10 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_Reduce; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h index eb2d422f8d..0179c0f290 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h @@ -29,11 +29,8 @@ class TfliteReduceParser : public TfliteNodeParser { public: TfliteReduceParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteReduceMaxParser : public TfliteReduceParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc index 3aea529d4d..98b1813c9e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteReshapeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -50,18 +47,19 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfli return RET_ERROR; } auto shape_tensor_index = tflite_op->inputs[1]; - const auto &shape_tensor = tflite_tensors[shape_tensor_index]; + const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[shape_tensor_index]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } - auto &buf_data = tflite_model_buffer[shape_tensor->buffer]; + auto &buf_data = tflite_model->buffers[shape_tensor->buffer]; if (buf_data == nullptr) { MS_LOG(ERROR) << "buf_data is null"; return RET_NULL_PTR; } if (!buf_data->data.empty()) { - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->shape)) { MS_LOG(ERROR) << "get reshape -> shape failed"; return RET_ERROR; } @@ -78,11 +76,11 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfli op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h index 6f1d6fcd28..582ba911ce 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h @@ -29,11 +29,8 @@ class TfliteReshapeParser : public TfliteNodeParser { public: TfliteReshapeParser() : TfliteNodeParser("Reshape") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc index 6ba82a7b05..ad16c51833 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -73,13 +70,13 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit attr->preserveAspectRatio = false; auto tfliteResizeTensorIndex = tflite_op->inputs[1]; - const auto &shape_tensor = tflite_tensors[tfliteResizeTensorIndex]; + const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tfliteResizeTensorIndex]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } auto resizeTensorBufferIndex = shape_tensor->buffer; - const auto &buff = tflite_model_buffer.at(resizeTensorBufferIndex); + const auto &buff = tflite_model->buffers.at(resizeTensorBufferIndex); if (buff == nullptr) { MS_LOG(ERROR) << "buff_data is null"; return RET_NULL_PTR; @@ -95,14 +92,14 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_Resize; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); if (buffData == nullptr) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h index 526864e6f7..3fae174d6c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h @@ -29,11 +29,8 @@ class TfliteResizeParser : public TfliteNodeParser { public: TfliteResizeParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; class TfliteResizeBilinearParser : public TfliteResizeParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc index 2662f33e68..4301b81b7e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteReverseParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteReverseParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -43,7 +40,7 @@ STATUS TfliteReverseParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axis)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axis)) { MS_LOG(ERROR) << "get reverse -> axis failed"; return RET_ERROR; } @@ -51,10 +48,10 @@ STATUS TfliteReverseParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Reverse; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h index 0e3771eb1e..34d59ae501 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h @@ -29,11 +29,8 @@ class TfliteReverseParser : public TfliteNodeParser { public: TfliteReverseParser() : TfliteNodeParser("reverse") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc index 0bb7e1770a..cc98fbd540 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -22,12 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -56,12 +53,12 @@ STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_ReverseSequence; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h index 538f859bf5..afd0621940 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h @@ -29,11 +29,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { public: TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc index ccd5e1cbe6..df395a4f7a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc @@ -22,11 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -54,14 +52,14 @@ STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tf // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h index 788c6c522e..cab92dd1f4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h @@ -29,11 +29,8 @@ class TfliteScatterNdParser : public TfliteNodeParser { public: TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc index ef98d0888c..a0f8c7e00c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteShapeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteShapeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -46,10 +43,10 @@ STATUS TfliteShapeParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Shape; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h index c413750764..5020b44b5e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h @@ -29,11 +29,8 @@ class TfliteShapeParser : public TfliteNodeParser { public: TfliteShapeParser() : TfliteNodeParser("Shape") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc index 0329520140..21abe5b5be 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteSkipGramParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -55,10 +52,10 @@ STATUS TfliteSkipGramParser::Parse(const std::unique_ptr &tfl op->primitive->value.type = schema::PrimitiveType_SkipGram; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h index 5ebe6f5846..29ece28820 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h @@ -29,11 +29,8 @@ class TfliteSkipGramParser : public TfliteNodeParser { public: TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc index 9cca9d521c..1e48e19999 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSliceParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -45,11 +42,11 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite attr->format = schema::Format::Format_NHWC; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->begin)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { MS_LOG(ERROR) << "get slice -> begin failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->size)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->size)) { MS_LOG(ERROR) << "get slice -> size failed"; return RET_ERROR; } @@ -62,10 +59,10 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h index 7a4850dca8..d363c453c8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h @@ -29,11 +29,8 @@ class TfliteSliceParser : public TfliteNodeParser { public: TfliteSliceParser() : TfliteNodeParser("Slice") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc index 814c5000f9..0ef5b4d62b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -48,10 +45,10 @@ STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_SoftMax; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h index 5a7387ac58..30585bc0cd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h @@ -29,11 +29,8 @@ class TfliteSoftmaxParser : public TfliteNodeParser { public: TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc index ed5467a878..58e4816dc3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -22,12 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -45,11 +42,12 @@ STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptrinputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->blockShape)) { MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->paddings)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->paddings)) { MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed"; return RET_ERROR; } @@ -57,10 +55,10 @@ STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_SpaceToBatchND; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h index 284396bcbb..e63956ccfb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h @@ -29,11 +29,8 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser { public: TfliteSpaceToBatchNDParser() : TfliteNodeParser("SpaceToBatchND") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc index 08207e02cd..64750d62bb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -22,12 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -56,10 +53,10 @@ STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h index 8ab6e77dbd..4e6e9fd540 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h @@ -29,11 +29,8 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser { public: TfliteSpaceToDepthParser() : TfliteNodeParser("SpaceToDepth") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc index 8e76dbb3c8..ed3f931de2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -22,12 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -49,16 +46,16 @@ STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_SparseToDense; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[3], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h index 6e7f6923dc..32361d19c9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h @@ -29,11 +29,8 @@ class TfliteSparseToDenseParser : public TfliteNodeParser { public: TfliteSparseToDenseParser() : TfliteNodeParser("SparseToDense") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc index a0f989a15c..66da0beff5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSplitParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -50,18 +47,18 @@ STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite } auto num_splits = tflite_attr->num_splits; - const auto &shape_tensor = tflite_tensors[tflite_op->inputs[1]]; + const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[1]]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } const auto tensor_shape = shape_tensor->shape; - const auto &axis_tensor = tflite_tensors[tflite_op->inputs[0]]; + const auto &axis_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; } - auto axis = *(reinterpret_cast(tflite_model_buffer[axis_tensor->buffer]->data.data())); + auto axis = *(reinterpret_cast(tflite_model->buffers[axis_tensor->buffer]->data.data())); if (axis < 0) { axis += tensor_shape.size(); } @@ -83,11 +80,11 @@ STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h index e182f323ed..d2c85bbdfd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h @@ -29,11 +29,8 @@ class TfliteSplitParser : public TfliteNodeParser { public: TfliteSplitParser() : TfliteNodeParser("Split") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc index 8631e35449..dea7b0e9df 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteSplitVParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSplitVParser"; if (op == nullptr) { @@ -51,23 +48,24 @@ STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tflit } attr->numberSplit = tflite_attr->num_splits; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->sizeSplits)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->sizeSplits)) { MS_LOG(ERROR) << "get spliteV -> sizeSplits failed"; return RET_ERROR; } - const auto &tensor = tflite_tensors[tflite_op->inputs[0]]; + const auto &tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor_shape is null"; return RET_NULL_PTR; } auto tensor_shape = tensor->shape; - const auto &axis_tensor = tflite_tensors[tflite_op->inputs[2]]; + const auto &axis_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[2]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; } - auto axis = *(reinterpret_cast(tflite_model_buffer[axis_tensor->buffer]->data.data())); + auto axis = *(reinterpret_cast(tflite_model->buffers[axis_tensor->buffer]->data.data())); if (axis < 0) { axis += tensor_shape.size(); } @@ -80,11 +78,11 @@ STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h index 8b3f37fa81..85427ceab6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h @@ -29,11 +29,8 @@ class TfliteSplitVParser : public TfliteNodeParser { public: TfliteSplitVParser() : TfliteNodeParser("SplitV") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc index aa2b8deaea..882f0fd5e4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteSqueezeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSqueezeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -53,10 +50,10 @@ STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Squeeze; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h index 685866d409..b486cf1e00 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h @@ -29,11 +29,8 @@ class TfliteSqueezeParser : public TfliteNodeParser { public: TfliteSqueezeParser() : TfliteNodeParser("Squeeze") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc index fb6917950c..ad3bd244e8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -21,11 +21,8 @@ namespace mindspore { namespace lite { -STATUS TfliteStackParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteStackParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteStackParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -50,18 +47,18 @@ STATUS TfliteStackParser::Parse(const std::unique_ptr &tflite } attr->axis = tflite_attr->axis; attr->n = tflite_attr->values_count; - attr->isScale.assign(tflite_tensors[tflite_op->inputs[0]]->shape.begin(), - tflite_tensors[tflite_op->inputs[0]]->shape.end()); + attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); op->primitive->value.type = schema::PrimitiveType_Stack; op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h index d4801ec1ed..b30103b0e7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h @@ -29,11 +29,8 @@ class TfliteStackParser : public TfliteNodeParser { public: TfliteStackParser() : TfliteNodeParser("Stack") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index fdfe493027..a66f8936d8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -21,12 +21,9 @@ namespace mindspore { namespace lite { -STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, - std::map *tensors_id_map) { +STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteStridedSliceParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -55,28 +52,28 @@ STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr attr->newAxisMask = tflite_attr->new_axis_mask; attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->begin)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { MS_LOG(ERROR) << "stridedSlice -> begin get failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->end)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end)) { MS_LOG(ERROR) << "stridedSlice -> end get failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[3], tflite_tensors, tflite_model_buffer, attr->stride)) { + if (GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride)) { MS_LOG(ERROR) << "stridedSlice -> stride get failed"; return RET_ERROR; } - attr->isScale.assign(tflite_tensors[tflite_op->inputs[0]]->shape.begin(), - tflite_tensors[tflite_op->inputs[0]]->shape.end()); + attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); op->primitive->value.type = schema::PrimitiveType_StridedSlice; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h index 1a55022eb1..2fb2e6e378 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h @@ -29,11 +29,8 @@ class TfliteStridedSliceParser : public TfliteNodeParser { public: TfliteStridedSliceParser() : TfliteNodeParser("StridedSlice") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc index c523ca833d..0e39f06b41 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteTileParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteTileParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -44,7 +41,8 @@ STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_ return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->multiples)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->multiples)) { MS_LOG(ERROR) << "get tile -> multiples failed"; return RET_ERROR; } @@ -56,10 +54,10 @@ STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_ op->primitive->value.type = schema::PrimitiveType_Tile; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h index 602c577648..bb3a801d15 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h @@ -29,11 +29,8 @@ class TfliteTileParser : public TfliteNodeParser { public: TfliteTileParser() : TfliteNodeParser("Tile") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index 5e864f2d52..248745d24c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteTopKV2Parser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteTopKV2Parser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -46,7 +43,7 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflit attr->sorted = true; std::vector k; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, k)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, k)) { MS_LOG(ERROR) << "get topKV2 -> k failed"; return RET_ERROR; } @@ -55,11 +52,11 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_TopK; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h index 15f1a4812f..1ab9a43e18 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h @@ -29,11 +29,8 @@ class TfliteTopKV2Parser : public TfliteNodeParser { public: TfliteTopKV2Parser() : TfliteNodeParser("TopKV2") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index 5cb6ee3dfa..22a5b8975a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -20,11 +20,9 @@ namespace mindspore { namespace lite { -STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteTransposeParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteTransposeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -42,7 +40,7 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->perm)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->perm)) { MS_LOG(ERROR) << "get transpose -> perm failed"; return RET_ERROR; } @@ -51,12 +49,12 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tf op->primitive->value.type = schema::PrimitiveType_Transpose; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_KHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_KHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h index 2f023946cc..4fe20528d1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h @@ -29,11 +29,8 @@ class TfliteTransposeParser : public TfliteNodeParser { public: TfliteTransposeParser() : TfliteNodeParser("Transpose") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc index d725d0b025..4c8cf480ee 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteUniqueParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteUniqueParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -54,11 +51,11 @@ STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_Unique; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h index 6bf2af8973..ba86414d22 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h @@ -29,11 +29,8 @@ class TfliteUniqueParser : public TfliteNodeParser { public: TfliteUniqueParser() : TfliteNodeParser("Unique") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index 48dba85394..11247e4720 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteUnstackParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "paser TfliteUnstackParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -55,11 +52,11 @@ STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Unstack; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h index 5ffc3bb17c..7d82dcdb94 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h @@ -29,11 +29,8 @@ class TfliteUnstackParser : public TfliteNodeParser { public: TfliteUnstackParser() : TfliteNodeParser("Unstack") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index e9ec14fc12..25a166d093 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -119,6 +119,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_CUSTOM, "Custom"}, {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"}, {tflite::BuiltinOperator_NEG, "Neg"}, + {tflite::BuiltinOperator_PRELU, "PRELU"}, }; std::map tfMsActivationFunctionMap{ diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h index e9a050c4bf..ad0e2e4f71 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h @@ -44,6 +44,12 @@ STATUS getPaddingParam(const std::unique_ptr &tensor, schema::P int strideW, int windowH, int windowW, std::vector *params); void Split(const std::string &src_str, std::vector *dst_str, const std::string &chr); + +struct TfliteTensorsInfo { + std::vector tensorsId; + std::vector tensorsFormat; + std::map tensorsIdMap; +}; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc index 0e0bbd0f9e..7afea69876 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -22,11 +22,8 @@ namespace mindspore { namespace lite { -STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteWhereParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteWhereParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -44,7 +41,8 @@ STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[0], tflite_tensors, tflite_model_buffer, attr->condition)) { + if (GetTfliteData(tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, + attr->condition)) { MS_LOG(ERROR) << "get where -> condition failed"; return RET_ERROR; } @@ -53,11 +51,11 @@ STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h index a98f3401ae..6bdfbbe9f9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h @@ -29,11 +29,8 @@ class TfliteWhereParser : public TfliteNodeParser { public: TfliteWhereParser() : TfliteNodeParser("Where") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc index 5f17f68cc7..f0865edfa7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -22,11 +22,9 @@ namespace mindspore { namespace lite { -STATUS TfliteZerosLikeParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, - schema::CNodeT *op, std::vector *tensors_id, - std::vector *tensors_format, std::map *tensors_id_map) { +STATUS TfliteZerosLikeParser::Parse(TfliteTensorsInfo *tensors_info, + const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteZerosLikeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,10 +45,10 @@ STATUS TfliteZerosLikeParser::Parse(const std::unique_ptr &tf op->primitive->value.type = schema::PrimitiveType_ZerosLike; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); - AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), - tflite_tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h index 74ec9b920a..045a66ad85 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h @@ -29,11 +29,8 @@ class TfliteZerosLikeParser : public TfliteNodeParser { public: TfliteZerosLikeParser() : TfliteNodeParser("ZerosLike") {} - STATUS Parse(const std::unique_ptr &tflite_op, - const std::vector> &tflite_tensors, - const std::vector> &tflite_model_buffer, schema::CNodeT *op, - std::vector *tensors_id, std::vector *tensors_format, - std::map *tensors_id_map) override; + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index 196fd86e1b..5c32e579e1 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -104,8 +104,8 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co } auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); if (CheckIfNodeIsParam(bias_add_weight) != lite::RET_OK) { - delete[] add_bias_data; - return lite::RET_INVALID_OP_ATTR; + delete[] add_bias_data; + return lite::RET_INVALID_OP_ATTR; } auto add_weight_param = bias_add_weight->cast()->default_param(); auto add_weight_tensor = std::dynamic_pointer_cast(add_weight_param); @@ -124,6 +124,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co } if (conv_bias_node != nullptr) { if (CheckIfNodeIsParam(conv_bias_node) != lite::RET_OK) { + delete[] add_bias_data; return lite::RET_INVALID_OP_ATTR; } auto conv_bias_param = conv_bias_node->cast()->default_param();