!9100 add more tf parsers

From: @wangzhe128
Reviewed-by: @hangangqiang
Signed-off-by:
pull/9100/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 50a265b6f4

@ -46,6 +46,10 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
attr->type = schema::ActivationType_RELU; attr->type = schema::ActivationType_RELU;
} else if (tf_op.op() == "Relu6") { } else if (tf_op.op() == "Relu6") {
attr->type = schema::ActivationType_RELU6; attr->type = schema::ActivationType_RELU6;
} else if (tf_op.op() == "Sigmoid") {
attr->type = schema::ActivationType_SIGMOID;
} else if (tf_op.op() == "Tanh") {
attr->type = schema::ActivationType_TANH;
} else { } else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
} }
@ -64,5 +68,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
} }
TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -69,6 +69,54 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
} }
primitive->value.type = schema::PrimitiveType_Div; primitive->value.type = schema::PrimitiveType_Div;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
} else if (tf_op.op() == "Maximum") {
auto attr = std::make_unique<schema::MaximumT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Maximum;
primitive->value.value = attr.release();
} else if (tf_op.op() == "Minimum") {
auto attr = std::make_unique<schema::MinimumT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Minimum;
primitive->value.value = attr.release();
} else if (tf_op.op() == "Greater") {
auto attr = std::make_unique<schema::GreaterT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Greater;
primitive->value.value = attr.release();
} else if (tf_op.op() == "GreaterEqual") {
auto attr = std::make_unique<schema::GreaterEqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_GreaterEqual;
primitive->value.value = attr.release();
} else if (tf_op.op() == "Less") {
auto attr = std::make_unique<schema::LessT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Less;
primitive->value.value = attr.release();
} else if (tf_op.op() == "LessEqual") {
auto attr = std::make_unique<schema::LessEqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LessEqual;
primitive->value.value = attr.release();
} }
*primitiveC = PrimitiveC::Create(primitive.release()); *primitiveC = PrimitiveC::Create(primitive.release());
@ -86,8 +134,15 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
return status; return status;
} }
TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser()); TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser());
TFNodeRegistrar g_tfAddV2Parser("AddV2", new TFArithmeticParser());
TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser());
TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser());
TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser());
TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser());
TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser());
TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser());
TFNodeRegistrar g_tfGreaterEqualParser("GreaterEqual", new TFArithmeticParser());
TFNodeRegistrar g_tfLessParser("Less", new TFArithmeticParser());
TFNodeRegistrar g_tfLessEqualParser("LessEqual", new TFArithmeticParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -44,7 +44,7 @@ STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op,
attr->axis = {1}; attr->axis = {1};
primitive->value.type = schema::PrimitiveType_Add; primitive->value.type = schema::PrimitiveType_BiasAdd;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release()); *primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) { if (*primitiveC == nullptr) {

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_cast_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFCastParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF CastParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
auto src_type = TensorFlowUtils::ParseAttrDataType(tf_op, "SrcT");
if (src_type == kTypeUnknown) {
MS_LOG(ERROR) << "Get attr SrcT failed";
return RET_ERROR;
}
auto dst_type = TensorFlowUtils::ParseAttrDataType(tf_op, "DstT");
if (dst_type == kTypeUnknown) {
MS_LOG(ERROR) << "Get attr DstT failed";
return RET_ERROR;
}
attr->srcT = src_type;
attr->dstT = dst_type;
primitive->value.type = schema::PrimitiveType_Cast;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfCastParser("Cast", new TFCastParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFCastParser : public TFNodeParser {
public:
TFCastParser() = default;
~TFCastParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_

@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_concat_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFConcatParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ConcatParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
if (tf_node_map.find(tf_op.input(tf_op.input_size() - 1)) == tf_node_map.end()) {
MS_LOG(ERROR) << "Find Concat input axis failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1));
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
attr->axis = tensor_proto.int_val(0);
if (!TensorFlowUtils::FindAttrValue(tf_op, "N", &attr_value)) {
MS_LOG(ERROR) << "The N attr should be specified";
return RET_ERROR;
}
attr->n = (int32_t)attr_value.i();
primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size() - 1; ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfConcatV2Parser("ConcatV2", new TFConcatParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFConcatParser : public TFNodeParser {
public:
TFConcatParser() = default;
~TFConcatParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_

@ -0,0 +1,102 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_conv_base_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
namespace {
const uint32_t STRIDE_DEFAULT_VALUE = 1;
const uint32_t DILATION_DEFAULT_VALUE = 1;
} // namespace
STATUS TFConvBaseParser::ParseStrides(const tensorflow::NodeDef &node_def, const schema::Format &format,
std::vector<int64_t> *strides) {
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(node_def, "strides", &attr_value)) {
strides->at(0) = STRIDE_DEFAULT_VALUE;
strides->at(1) = STRIDE_DEFAULT_VALUE;
} else {
auto stride_list = attr_value.list();
if (format == schema::Format_NHWC) {
strides->at(0) = stride_list.i(1);
strides->at(1) = stride_list.i(2);
} else {
strides->at(0) = stride_list.i(2);
strides->at(1) = stride_list.i(3);
}
}
return RET_OK;
}
STATUS TFConvBaseParser::ParseDilations(const tensorflow::NodeDef &node_def, const schema::Format &format,
std::vector<int64_t> *dilations) {
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(node_def, "dilations", &attr_value)) {
dilations->at(0) = DILATION_DEFAULT_VALUE;
dilations->at(1) = DILATION_DEFAULT_VALUE;
} else {
auto dilation_list = attr_value.list();
if (format == schema::Format_NHWC) {
dilations->at(0) = dilation_list.i(1);
dilations->at(1) = dilation_list.i(2);
} else {
dilations->at(0) = dilation_list.i(2);
dilations->at(1) = dilation_list.i(3);
}
}
return RET_OK;
}
STATUS TFConvBaseParser::ParseKernels(const tensorflow::NodeDef &node_def, const schema::Format &format,
std::vector<int64_t> *kernel) {
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(node_def, "value", &attr_value)) {
MS_LOG(ERROR) << "The kernels should be specified";
return RET_PARAM_INVALID;
}
auto shape = attr_value.tensor().tensor_shape();
if (shape.dim().size() != 4) {
MS_LOG(ERROR) << "Dims of Kernel should be 4.";
return RET_PARAM_INVALID;
}
kernel->at(0) = shape.dim(0).size();
kernel->at(1) = shape.dim(1).size();
kernel->at(2) = shape.dim(2).size();
kernel->at(3) = shape.dim(3).size();
return RET_OK;
}
STATUS TFConvBaseParser::ParsePadMode(const tensorflow::NodeDef &node_def, schema::PadMode *pad_mode) {
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(node_def, "padding", &attr_value)) {
MS_LOG(ERROR) << "The attr padding should be specified";
return RET_PARAM_INVALID;
}
if (attr_value.s() == "VALID") {
*pad_mode = schema::PadMode_VALID;
} else if (attr_value.s() == "SAME") {
*pad_mode = schema::PadMode_SAME_UPPER;
} else {
*pad_mode = schema::PadMode_NOTSET;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFConvBaseParser : public TFNodeParser {
public:
TFConvBaseParser() = default;
~TFConvBaseParser() override = default;
STATUS ParseStrides(const tensorflow::NodeDef &node_def, const schema::Format &format, std::vector<int64_t> *strides);
STATUS ParseDilations(const tensorflow::NodeDef &node_def, const schema::Format &format,
std::vector<int64_t> *dilations);
STATUS ParseKernels(const tensorflow::NodeDef &node_def, const schema::Format &format, std::vector<int64_t> *kernel);
STATUS ParsePadMode(const tensorflow::NodeDef &node_def, schema::PadMode *pad_mode);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_conv_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/parser/tf/tf_util.h"
namespace mindspore {
namespace lite {
STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ConvParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
attr->group = 1;
attr->format = TensorFlowUtils::ParseNodeFormat(tf_op);
std::vector<int64_t> dilations(2);
auto status = ParseDilations(tf_op, attr->format, &dilations);
if (status != RET_OK) {
return status;
}
attr->dilateH = dilations[0];
attr->dilateW = dilations[1];
std::vector<int64_t> strides(2);
status = ParseStrides(tf_op, attr->format, &strides);
if (status != RET_OK) {
return status;
}
attr->strideH = strides[0];
attr->strideW = strides[1];
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
MS_LOG(ERROR) << "Find Conv2D input weights failed";
return RET_ERROR;
}
auto weight_node = tf_node_map.at(tf_op.input(1));
std::vector<int64_t> kernels(4);
status = ParseKernels(*weight_node, attr->format, &kernels);
if (status != RET_OK) {
return status;
}
attr->kernelH = kernels[0];
attr->kernelW = kernels[1];
attr->channelIn = kernels[2];
attr->channelOut = kernels[3];
status = ParsePadMode(tf_op, &attr->padMode);
if (status != RET_OK) {
return status;
}
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
status = AddOpInput(tf_op, 1, inputs); // weights
return status;
}
TFNodeRegistrar g_tfConvParser("Conv2D", new TFConvParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_conv_base_parser.h"
namespace mindspore {
namespace lite {
class TFConvParser : public TFConvBaseParser {
public:
TFConvParser() = default;
~TFConvParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_

@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_expand_dims_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ExpandDimsParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ExpandDimsT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
MS_LOG(ERROR) << "Find ExpandDims input axis failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(1));
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
attr->dim = tensor_proto.int_val(0);
} else {
attr->dim = (reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()))[0];
}
primitive->value.type = schema::PrimitiveType_ExpandDims;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfExpandDimsParser("ExpandDims", new TFExpandDimsParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFExpandDimsParser : public TFNodeParser {
public:
TFExpandDimsParser() = default;
~TFExpandDimsParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_

@ -0,0 +1,102 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_gather_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF GatherParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::GatherT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(tf_op, "batch_dims", &attr_value)) {
attr->batchDims = attr_value.i();
}
bool axis_is_set = false;
if (tf_op.input_size() == 3) {
axis_is_set = true;
if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) {
MS_LOG(ERROR) << "Find Gather input axis failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(2));
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
if (tensor_proto.dtype() == tensorflow::DT_INT32) {
if (tensor_proto.int_val_size() > 0) {
attr->axis = tensor_proto.int_val(0);
} else {
attr->axis = (reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()))[0];
}
} else if (tensor_proto.dtype() == tensorflow::DT_INT64) {
if (tensor_proto.int64_val_size() > 0) {
attr->axis = tensor_proto.int64_val(0);
} else {
attr->axis = (reinterpret_cast<const int64_t *>(tensor_proto.tensor_content().data()))[0];
}
} else {
MS_LOG(ERROR) << "axis must be int32 or int64";
return RET_ERROR;
}
}
if (attr->batchDims != 0 && !axis_is_set) {
attr->axis = attr->batchDims;
}
primitive->value.type = schema::PrimitiveType_Gather;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
status = AddOpInput(tf_op, 1, inputs);
return status;
}
TFNodeRegistrar g_tfGatherV2Parser("GatherV2", new TFGatherParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFGatherParser : public TFNodeParser {
public:
TFGatherParser() = default;
~TFGatherParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_

@ -27,22 +27,6 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = {
{tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8},
{tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16},
{tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64},
{tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32},
{tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64},
{tensorflow::DT_BOOL, mindspore::kNumberTypeBool}};
TypeId GetTFDataType(const tensorflow::DataType &tf_data_type) {
auto iter = TF_TYPE_MAP.find(tf_data_type);
if (iter == TF_TYPE_MAP.end()) {
MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type;
return kTypeUnknown;
}
return iter->second;
}
AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) { AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) {
AnfNodePtr ret = nullptr; AnfNodePtr ret = nullptr;
@ -151,7 +135,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
tensorflow::AttrValue attr_value; tensorflow::AttrValue attr_value;
TypeId type = kNumberTypeFloat32; TypeId type = kNumberTypeFloat32;
if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) { if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) {
type = GetTFDataType(attr_value.type()); type = TensorFlowUtils::GetTFDataType(attr_value.type());
} }
auto type_ptr = TypeIdToType(type); auto type_ptr = TypeIdToType(type);
@ -413,7 +397,7 @@ STATUS TFModelParser::ConvertGraphOutputs() {
auto cnode = funcGraphPtr->NewCNode(op_inputs); auto cnode = funcGraphPtr->NewCNode(op_inputs);
cnode->set_fullname_with_scope("return"); cnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(cnode); funcGraphPtr->set_return(cnode);
} else { } else if (output_nodes.size() == 1) {
auto return_prim_ptr = GetReturnPrim(); auto return_prim_ptr = GetReturnPrim();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr"; MS_LOG(ERROR) << "GetReturnPrim return nullptr";

@ -15,9 +15,7 @@
*/ */
#include "tools/converter/parser/tf/tf_node_parser.h" #include "tools/converter/parser/tf/tf_node_parser.h"
#include <string> #include <string>
#include <memory>
#include <vector> #include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_pack_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFPackParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF PackParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::StackT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) {
MS_LOG(ERROR) << "The axis attr should be specified";
return RET_ERROR;
}
attr->axis = static_cast<int32_t>(attr_value.i());
if (!TensorFlowUtils::FindAttrValue(tf_op, "N", &attr_value)) {
MS_LOG(ERROR) << "The axis attr should be specified";
return RET_ERROR;
}
attr->n = static_cast<int32_t>(attr_value.i());
primitive->value.type = schema::PrimitiveType_Stack;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfPackParser("Pack", new TFPackParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFPackParser : public TFNodeParser {
public:
TFPackParser() = default;
~TFPackParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_

@ -0,0 +1,110 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_reduce_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ReduceParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ReduceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
if (tf_op.op() == "Sum") {
attr->mode = schema::ReduceMode_ReduceSum;
} else if (tf_op.op() == "Max") {
attr->mode = schema::ReduceMode_ReduceMax;
} else if (tf_op.op() == "Min") {
attr->mode = schema::ReduceMode_ReduceMin;
} else if (tf_op.op() == "Mean") {
attr->mode = schema::ReduceMode_ReduceMean;
} else if (tf_op.op() == "Prod") {
attr->mode = schema::ReduceMode_ReduceProd;
} else {
MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op();
return RET_ERROR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "keep_dims", &attr_value)) {
MS_LOG(ERROR) << "The keep_dims attr should be specified";
return RET_ERROR;
}
if (attr_value.value_case() != tensorflow::AttrValue::kB) {
MS_LOG(ERROR) << "the keep_dims attr of reduce should be bool type";
return RET_ERROR;
}
attr->keepDims = attr_value.b();
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
MS_LOG(ERROR) << "Find Reduce input axis failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(1));
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->axes.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->axes.push_back(data[i]);
}
}
primitive->value.type = schema::PrimitiveType_Reduce;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfSumParser("Sum", new TFReduceParser());
TFNodeRegistrar g_tfMaxParser("Max", new TFReduceParser());
TFNodeRegistrar g_tfMinParser("Min", new TFReduceParser());
TFNodeRegistrar g_tfMeanParser("Mean", new TFReduceParser());
TFNodeRegistrar g_tfProdParser("Prod", new TFReduceParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFReduceParser : public TFNodeParser {
public:
TFReduceParser() = default;
~TFReduceParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_

@ -0,0 +1,66 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_reshape_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFReshapeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ReshapeParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NHWC;
// attr->shape is omitted cause input[1] provide shape info
primitive->value.type = schema::PrimitiveType_Reshape;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
status = AddOpInput(tf_op, 1, inputs);
return status;
}
TFNodeRegistrar g_tfReshapeParser("Reshape", new TFReshapeParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFReshapeParser : public TFNodeParser {
public:
TFReshapeParser() = default;
~TFReshapeParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_round_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFRoundParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF RoundParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::RoundT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Round;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfRoundParser("Round", new TFRoundParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFRoundParser : public TFNodeParser {
public:
TFRoundParser() = default;
~TFRoundParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save