!4189 Merge similar TFLite parsers

Merge pull request !4189 from lyvette/tflite_parser
pull/4189/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit 6a5c00ff7a
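
This change collapses groups of near-identical single-operator TFLite parsers (Abs, Add, Ceil, Relu, BatchToSpaceND, and so on) into shared parser classes that dispatch on the registered node name; the per-operator classes that remain are empty shells kept only so each operator name still has its own registration entry. A minimal, self-contained sketch of that pattern is shown below; the class names, registry, and helper here are illustrative stand-ins, not MindSpore Lite's actual API.

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Illustrative base parser: one shared Parse() used by several operator names.
class DemoNodeParser {
 public:
  explicit DemoNodeParser(std::string name) : name_(std::move(name)) {}
  virtual ~DemoNodeParser() = default;
  virtual void Parse(const std::string &node_name) const {
    std::cout << "shared Parse() handling node " << node_name << std::endl;
  }

 private:
  std::string name_;
};

// Thin subclasses exist only so distinct operator names can be registered.
class DemoReluParser : public DemoNodeParser {
 public:
  DemoReluParser() : DemoNodeParser("Relu") {}
};

class DemoRelu6Parser : public DemoNodeParser {
 public:
  DemoRelu6Parser() : DemoNodeParser("Relu6") {}
};

int main() {
  // Illustrative registry keyed by TFLite operator name.
  std::map<std::string, std::unique_ptr<DemoNodeParser>> registry;
  registry["Relu"] = std::make_unique<DemoReluParser>();
  registry["Relu6"] = std::make_unique<DemoRelu6Parser>();

  // Both names resolve to the same shared Parse() implementation.
  registry["Relu6"]->Parse("Relu6-op12");
  return 0;
}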

@ -228,6 +228,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Elu(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_DeDepthwiseConv2D:
return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Shape:
return new lite::Shape(const_cast<schema::Primitive *>(srcPrim));
default:
break;
}

@ -26,6 +26,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/slice.h"
#include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h"
#include "src/runtime/kernel/arm/nnacl/reshape_parameter.h"
#include "src/runtime/kernel/arm/nnacl/shape.h"
#include "src/runtime/kernel/arm/nnacl/fp32/stack.h"
#include "src/runtime/kernel/arm/nnacl/unstack.h"
#include "src/runtime/kernel/arm/nnacl/depth_to_space.h"
@ -874,6 +875,16 @@ OpParameter *PopulateReshapeParameter(const lite::Primitive *primitive) {
return reinterpret_cast<OpParameter *>(reshape_param);
}
OpParameter *PopulateShapeParameter(const lite::Primitive *primitive) {
ShapeParameter *shape_param = new (std::nothrow) ShapeParameter();
if (shape_param == nullptr) {
MS_LOG(ERROR) << "new ShapeParameter failed.";
return nullptr;
}
shape_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(shape_param);
}
OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) {
auto reverse_attr = primitive->Value()->value_as_Reverse();
ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter();
@ -1306,6 +1317,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter;
populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter;
populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Shape] = PopulateShapeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter;
populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter;
populate_parameter_funcs_[schema::PrimitiveType_TopK] = PopulateTopKParameter;
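
The new PopulateShapeParameter above follows the same shape as its neighbours: allocate the parameter struct with new (std::nothrow), null-check it, stamp the primitive type, and hand it back as an OpParameter *. For orientation, here is a hedged sketch of how a table like populate_parameter_funcs_ is typically consulted; the stand-in types and the lookup function are assumptions for illustration, not the project's real call site.

#include <functional>
#include <map>

// Stand-in types so the sketch compiles on its own; the real code uses
// OpParameter and lite::Primitive from MindSpore Lite.
struct DemoOpParameter { int type_ = 0; };
struct DemoPrimitive {
  int type = 0;
  int Type() const { return type; }
};
using DemoPopulateFunc = std::function<DemoOpParameter *(const DemoPrimitive *)>;

// Lookup pattern analogous to the populate_parameter_funcs_ table registered above.
DemoOpParameter *DemoCreateParameter(const std::map<int, DemoPopulateFunc> &funcs,
                                     const DemoPrimitive *primitive) {
  auto it = funcs.find(primitive->Type());
  if (it == funcs.end() || !it->second) {
    return nullptr;  // unsupported primitive type
  }
  return it->second(primitive);  // e.g. the Shape populate function for PrimitiveType_Shape
}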

@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
cp -fr $TEST_DATA_DIR/testPK ./data
./lite-test --gtest_filter="*MindDataTestTensorDE*"
./lite-test --gtest_filter="*MindDataTestEager*"
./lite-test --gtest_filter="TestTfliteParser*"
./lite-test --gtest_filter="*TestHebing*"
./lite-test --gtest_filter=TestFcFp32*
./lite-test --gtest_filter=TestConv1x1Fp32*
./lite-test --gtest_filter=TestStrassenFp32*
./lite-test --gtest_filter=TestDeConvolutionFp32*
./lite-test --gtest_filter=TestPadInt8.*
./lite-test --gtest_filter=TestDeconvInt8.*
#./lite-test --gtest_filter="*MindDataTestTensorDE*"
#./lite-test --gtest_filter="*MindDataTestEager*"
#
#./lite-test --gtest_filter="TestTfliteParser*"
#
#./lite-test --gtest_filter="*TestHebing*"
#
#./lite-test --gtest_filter=TestFcFp32*
#./lite-test --gtest_filter=TestConv1x1Fp32*
#./lite-test --gtest_filter=TestStrassenFp32*
#./lite-test --gtest_filter=TestDeConvolutionFp32*
#
#./lite-test --gtest_filter=TestPadInt8.*
#./lite-test --gtest_filter=TestDeconvInt8.*
./lite-test --gtest_filter="TestTfliteParser*"

@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_abs_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteAbsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteAbsParser";
std::unique_ptr<schema::AbsT> attr(new schema::AbsT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Abs;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteAbsParser());
} // namespace lite
} // namespace mindspore

@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_ABS_PARSER_H
#define PREDICT_TFLITE_ABS_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteAbsParser : public TfliteNodeParser {
public:
TfliteAbsParser() : TfliteNodeParser("Abs") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_ABS_PARSER_H

@ -0,0 +1,133 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include <string>
#include "tools/converter/parser/tflite/tflite_activation_parser.h"
namespace mindspore {
namespace lite {
STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());
std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU;
} else if (std::strcmp(node_name, "Relu6") == 0) {
MS_LOG(DEBUG) << "parse TfliteRelu6Parser";
attr->type = schema::ActivationType_RELU6;
} else if (std::strcmp(node_name, "Tanh") == 0) {
MS_LOG(DEBUG) << "parse TfliteTanhParser";
attr->type = schema::ActivationType_TANH;
} else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID;
} else {
MS_LOG(ERROR) << "wrong activation type";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TflitePreluParser";
std::unique_ptr<schema::PreluT> attr(new schema::PreluT());
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) {
MS_LOG(ERROR) << "get pRelu -> slope failed";
return RET_ERROR;
}
op->primitive->value.type = schema::PrimitiveType_Prelu;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());
const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->negativeSlope = tflite_attr->alpha;
op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
op->primitive->value.value = attr.release();
return RET_OK;
}
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());
} // namespace lite
} // namespace mindspore
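
The merged activation parser recovers the concrete operator from the node name: op->name apparently encodes the registered op name before a "-" separator (for example "Relu6-op12"), so splitting on "-" and taking the first token selects the ActivationType. A standalone illustration of that dispatch follows; the DemoSplit helper and the enum are stand-ins, not the project's Split function or schema types.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

enum class DemoActivation { RELU, RELU6, TANH, SIGMOID, UNKNOWN };

// Stand-in for the project's Split helper: split on a single-character delimiter.
std::vector<std::string> DemoSplit(const std::string &s, char delim) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) {
    out.push_back(item);
  }
  return out;
}

DemoActivation DispatchByName(const std::string &node_name) {
  const auto parts = DemoSplit(node_name, '-');
  const std::string op = parts.empty() ? node_name : parts.front();  // e.g. "Relu6"
  if (op == "Relu") return DemoActivation::RELU;
  if (op == "Relu6") return DemoActivation::RELU6;
  if (op == "Tanh") return DemoActivation::TANH;
  if (op == "Logistic") return DemoActivation::SIGMOID;
  return DemoActivation::UNKNOWN;
}

int main() {
  std::cout << static_cast<int>(DispatchByName("Relu6-op12")) << std::endl;  // prints 1 (RELU6)
  return 0;
}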

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_RELU_PARSER_H
#define PREDICT_TFLITE_RELU_PARSER_H
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
class TfliteActivationParser : public TfliteNodeParser {
public:
TfliteActivationParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteReluParser : public TfliteActivationParser {
public:
TfliteReluParser() : TfliteActivationParser() {}
};
class TfliteRelu6Parser : public TfliteActivationParser{
public:
TfliteRelu6Parser() : TfliteActivationParser() {}
};
class TfliteTanhParser : public TfliteActivationParser{
public:
TfliteTanhParser() : TfliteActivationParser() {}
};
class TfliteLogisticParser : public TfliteActivationParser {
public:
TfliteLogisticParser() : TfliteActivationParser() {}
};
class TflitePreluParser : public TfliteNodeParser {
public:
TflitePreluParser() : TfliteNodeParser("Prelu") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantized_model) override;
};
class TfliteLeakyReluParser : public TfliteNodeParser {
public:
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_RELU_PARSER_H

@ -1,87 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_add_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
auto x_index = tfliteOp->inputs[0];
const auto &x_tensor = tfliteTensors[x_index];
if (x_tensor == nullptr) {
MS_LOG(ERROR) << "the first input is null";
return RET_NULL_PTR;
}
auto &x_data = tfliteModelBuffer.at(x_tensor->buffer);
if (x_data == nullptr) {
MS_LOG(ERROR) << "the data of the first input is null";
return RET_NULL_PTR;
}
if (x_data->data.size() > 0) {
std::vector<tflite::TensorT *> x_tensors{x_tensor.get()};
if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the first tensor failed";
return RET_ERROR;
}
}
auto y_index = tfliteOp->inputs[1];
const auto &y_tensor = tfliteTensors[y_index];
if (y_tensor == nullptr) {
MS_LOG(ERROR) << "the second input is null";
return RET_NULL_PTR;
}
auto &y_data = tfliteModelBuffer.at(y_tensor->buffer);
if (y_data == nullptr) {
MS_LOG(ERROR) << "the data of the second input is null";
return RET_NULL_PTR;
}
if (y_data->data.size() > 0) {
std::vector<tflite::TensorT *> y_tensors{y_tensor.get()};
if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) {
MS_LOG(ERROR) << "parse the second tensor failed";
return RET_ERROR;
}
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Add;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser());
} // namespace lite
} // namespace mindspore

@ -1,42 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_ADD_PARSER_H
#define PREDICT_TFLITE_ADD_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteAddParser : public TfliteNodeParser {
public:
TfliteAddParser() : TfliteNodeParser("Add") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_ADD_PARSER_H

@ -26,16 +26,23 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteAddNParser";
std::unique_ptr<schema::AddNT> attr(new schema::AddNT());
attr->N = tfliteTensors.size() - 1;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_AddN;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -27,6 +27,16 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
std::unique_ptr<schema::ArgMaxT> attr(new schema::ArgMaxT());
@ -49,11 +59,8 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_ArgMax;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -25,6 +25,16 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteArgminParser";
std::unique_ptr<schema::ArgMinT> attr(new schema::ArgMinT());
@ -47,11 +57,8 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_ArgMin;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -0,0 +1,207 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_MATH_PARSER_H
#define PREDICT_TFLITE_MATH_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteDoubleInputOpParser : public TfliteNodeParser {
public:
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteAddParser : public TfliteDoubleInputOpParser {
public:
TfliteAddParser() : TfliteDoubleInputOpParser() {}
};
class TfliteSubParser : public TfliteDoubleInputOpParser {
public:
TfliteSubParser() : TfliteDoubleInputOpParser() {}
};
class TfliteMulParser : public TfliteDoubleInputOpParser {
public:
TfliteMulParser() : TfliteDoubleInputOpParser() {}
};
class TfliteDivParser : public TfliteDoubleInputOpParser {
public:
TfliteDivParser() : TfliteDoubleInputOpParser() {}
};
class TfliteFloorDivParser : public TfliteDoubleInputOpParser {
public:
TfliteFloorDivParser() : TfliteDoubleInputOpParser() {}
};
class TfliteFloorModParser : public TfliteDoubleInputOpParser {
public:
TfliteFloorModParser() : TfliteDoubleInputOpParser() {}
};
class TfliteSquaredDifferenceParser : public TfliteDoubleInputOpParser {
public:
TfliteSquaredDifferenceParser() : TfliteDoubleInputOpParser() {}
};
class TfliteRealDivParser : public TfliteDoubleInputOpParser {
public:
TfliteRealDivParser() : TfliteDoubleInputOpParser() {}
};
class TflitePowParser : public TfliteDoubleInputOpParser {
public:
TflitePowParser() : TfliteDoubleInputOpParser() {}
};
class TfliteMaximumParser : public TfliteDoubleInputOpParser {
public:
TfliteMaximumParser() : TfliteDoubleInputOpParser() {}
};
class TfliteMinimumParser : public TfliteDoubleInputOpParser {
public:
TfliteMinimumParser() : TfliteDoubleInputOpParser() {}
};
class TfliteSingleInputOpParser : public TfliteNodeParser {
public:
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteAbsParser : public TfliteSingleInputOpParser {
public:
TfliteAbsParser() : TfliteSingleInputOpParser() {}
};
class TfliteExpParser : public TfliteSingleInputOpParser {
public:
TfliteExpParser() : TfliteSingleInputOpParser() {}
};
class TfliteSqrtParser : public TfliteSingleInputOpParser {
public:
TfliteSqrtParser() : TfliteSingleInputOpParser() {}
};
class TfliteSquareParser : public TfliteSingleInputOpParser {
public:
TfliteSquareParser() : TfliteSingleInputOpParser() {}
};
class TfliteSinParser : public TfliteSingleInputOpParser {
public:
TfliteSinParser() : TfliteSingleInputOpParser() {}
};
class TfliteCosParser : public TfliteSingleInputOpParser {
public:
TfliteCosParser() : TfliteSingleInputOpParser() {}
};
class TfliteRsqrtParser : public TfliteSingleInputOpParser {
public:
TfliteRsqrtParser() : TfliteSingleInputOpParser() {}
};
class TfliteLogParser : public TfliteSingleInputOpParser {
public:
TfliteLogParser() : TfliteSingleInputOpParser() {}
};
class TfliteRoundParser : public TfliteSingleInputOpParser {
public:
TfliteRoundParser() : TfliteSingleInputOpParser() {}
};
class TfliteCeilParser : public TfliteSingleInputOpParser {
public:
TfliteCeilParser() : TfliteSingleInputOpParser() {}
};
class TfliteFloorParser : public TfliteSingleInputOpParser {
public:
TfliteFloorParser() : TfliteSingleInputOpParser() {}
};
class TfliteCompareOpParser : public TfliteNodeParser {
public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};
class TfliteEqualParser : public TfliteCompareOpParser {
public:
TfliteEqualParser() : TfliteCompareOpParser() {}
};
class TfliteNotEqualParser : public TfliteCompareOpParser {
public:
TfliteNotEqualParser() : TfliteCompareOpParser() {}
};
class TfliteGreaterParser : public TfliteCompareOpParser {
public:
TfliteGreaterParser() : TfliteCompareOpParser() {}
};
class TfliteGreaterEqualParser : public TfliteCompareOpParser {
public:
TfliteGreaterEqualParser() : TfliteCompareOpParser() {}
};
class TfliteLessParser : public TfliteCompareOpParser {
public:
TfliteLessParser() : TfliteCompareOpParser() {}
};
class TfliteLessEqualParser : public TfliteCompareOpParser {
public:
TfliteLessEqualParser() : TfliteCompareOpParser() {}
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_MATH_PARSER_H
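
This header folds dozens of element-wise TFLite ops into three shared parsers: two-input arithmetic (TfliteDoubleInputOpParser), one-input math (TfliteSingleInputOpParser), and comparisons (TfliteCompareOpParser); the derived classes are empty and only supply a registration name. The grouping is restated here as a small standalone classifier for clarity (illustrative only; the real mapping is carried by the registrations and each shared Parse implementation).

#include <set>
#include <string>

enum class MathParserKind { DOUBLE_INPUT, SINGLE_INPUT, COMPARE, UNKNOWN };

// Mirrors the grouping declared in the header above; not project code.
MathParserKind ClassifyMathOp(const std::string &op) {
  static const std::set<std::string> double_input = {
      "Add", "Sub", "Mul", "Div", "FloorDiv", "FloorMod", "SquaredDifference",
      "RealDiv", "Pow", "Maximum", "Minimum"};
  static const std::set<std::string> single_input = {
      "Abs", "Exp", "Sqrt", "Square", "Sin", "Cos", "Rsqrt", "Log",
      "Round", "Ceil", "Floor"};
  static const std::set<std::string> compare = {
      "Equal", "NotEqual", "Greater", "GreaterEqual", "Less", "LessEqual"};
  if (double_input.count(op) != 0) return MathParserKind::DOUBLE_INPUT;
  if (single_input.count(op) != 0) return MathParserKind::SINGLE_INPUT;
  if (compare.count(op) != 0) return MathParserKind::COMPARE;
  return MathParserKind::UNKNOWN;
}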

@ -1,53 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_batch_to_sapce_nd_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteBatchToSpaceNDParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteBatchToSpaceNDParser";
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
MS_LOG(ERROR) << "get BatchToSpaceNd -> blockShape failed";
return RET_ERROR;
}
if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) {
MS_LOG(ERROR) << "get BatchToSpaceNd -> crops failed";
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser());
} // namespace lite
} // namespace mindspore

@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H
#define PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteBatchToSpaceNDParser : public TfliteNodeParser {
public:
TfliteBatchToSpaceNDParser() : TfliteNodeParser("BatchToSpaceND") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_BATCH_TO_SPACE_ND_PARSER_H

@ -18,6 +18,7 @@
#include "tools/converter/parser/tflite/tflite_batch_to_space_parser.h"
#include <vector>
#include <memory>
#include <string>
namespace mindspore {
namespace lite {
@ -26,7 +27,28 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name.data(), &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "BatchToSpace") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser";
} else if (std::strcmp(node_name, "BatchToSpaceND") == 0) {
MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser";
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
}
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
@ -38,14 +60,13 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT>
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
op->primitive->value.value = attr.release();
return RET_OK;
}
TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser());
TfliteNodeRegister g_TfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceNDParser());
} // namespace lite
} // namespace mindspore

@ -32,9 +32,14 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
TensorCache *tensor_cache, bool quantized_model) override;
};
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
public:
TfliteBatchToSpaceNDParser() : TfliteBatchToSpaceParser() {}
};
} // namespace lite
} // namespace mindspore

@ -26,6 +26,16 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
std::unique_ptr<schema::BroadcastToT> attr(new schema::BroadcastToT());
@ -34,11 +44,8 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &
return RET_ERROR;
}
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -26,6 +26,16 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteCastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT());
@ -43,11 +53,8 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
}
attr->dstT = dtype_map[out_tensor->type];
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_ceil_parser.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
STATUS TfliteCeilParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteCeilParser";
std::unique_ptr<schema::CeilT> attr(new schema::CeilT());
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Ceil;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser());
} // namespace lite
} // namespace mindspore

@ -1,42 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PREDICT_TFLITE_CEIL_PARSER_H
#define PREDICT_TFLITE_CEIL_PARSER_H
#include <memory>
#include <vector>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteCeilParser : public TfliteNodeParser {
public:
TfliteCeilParser() : TfliteNodeParser("Ceil") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) override;
};
} // namespace lite
} // namespace mindspore
#endif // PREDICT_TFLITE_CEIL_PARSER_H

@ -25,6 +25,16 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteConcatParser";
std::unique_ptr<schema::ConcatT> attr(new schema::ConcatT());
@ -37,11 +47,8 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->n = tfliteOp->inputs.size();
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Concat;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -25,6 +25,16 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
MS_LOG(DEBUG) << "parse TfliteConvParser";
std::unique_ptr<schema::Conv2DT> attr(new schema::Conv2DT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions();
@ -49,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
@ -69,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
return RET_NULL_PTR;
}
std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()};
if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) {
if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) {
MS_LOG(ERROR) << "parse bias failed";
return RET_ERROR;
}
@ -77,11 +87,8 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO
// calculate pad params
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
}
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
return RET_OK;
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_
#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
#include <string>
#include <memory>
@ -34,5 +34,5 @@ class TfliteConverter : public Converter {
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_

Some files were not shown because too many files have changed in this diff.