!9100 add more tf parsers
From: @wangzhe128 Reviewed-by: @hangangqiang Signed-off-by:pull/9100/MERGE
commit
50a265b6f4
@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_cast_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFCastParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF CastParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto src_type = TensorFlowUtils::ParseAttrDataType(tf_op, "SrcT");
|
||||
if (src_type == kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Get attr SrcT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto dst_type = TensorFlowUtils::ParseAttrDataType(tf_op, "DstT");
|
||||
if (dst_type == kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Get attr DstT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->srcT = src_type;
|
||||
attr->dstT = dst_type;
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Cast;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfCastParser("Cast", new TFCastParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFCastParser : public TFNodeParser {
|
||||
public:
|
||||
TFCastParser() = default;
|
||||
~TFCastParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CAST_PARSER_H_
|
@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_concat_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFConcatParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF ConcatParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::ConcatT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (tf_node_map.find(tf_op.input(tf_op.input_size() - 1)) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "Find Concat input axis failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1));
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_proto = attr_value.tensor();
|
||||
attr->axis = tensor_proto.int_val(0);
|
||||
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "N", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The N attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->n = (int32_t)attr_value.i();
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Concat;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size() - 1; ++i) {
|
||||
auto status = AddOpInput(tf_op, i, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfConcatV2Parser("ConcatV2", new TFConcatParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,37 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFConcatParser : public TFNodeParser {
|
||||
public:
|
||||
TFConcatParser() = default;
|
||||
~TFConcatParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONCAT_PARSER_H_
|
@ -0,0 +1,102 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_conv_base_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
const uint32_t STRIDE_DEFAULT_VALUE = 1;
|
||||
const uint32_t DILATION_DEFAULT_VALUE = 1;
|
||||
} // namespace
|
||||
STATUS TFConvBaseParser::ParseStrides(const tensorflow::NodeDef &node_def, const schema::Format &format,
|
||||
std::vector<int64_t> *strides) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(node_def, "strides", &attr_value)) {
|
||||
strides->at(0) = STRIDE_DEFAULT_VALUE;
|
||||
strides->at(1) = STRIDE_DEFAULT_VALUE;
|
||||
} else {
|
||||
auto stride_list = attr_value.list();
|
||||
if (format == schema::Format_NHWC) {
|
||||
strides->at(0) = stride_list.i(1);
|
||||
strides->at(1) = stride_list.i(2);
|
||||
} else {
|
||||
strides->at(0) = stride_list.i(2);
|
||||
strides->at(1) = stride_list.i(3);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TFConvBaseParser::ParseDilations(const tensorflow::NodeDef &node_def, const schema::Format &format,
|
||||
std::vector<int64_t> *dilations) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(node_def, "dilations", &attr_value)) {
|
||||
dilations->at(0) = DILATION_DEFAULT_VALUE;
|
||||
dilations->at(1) = DILATION_DEFAULT_VALUE;
|
||||
} else {
|
||||
auto dilation_list = attr_value.list();
|
||||
if (format == schema::Format_NHWC) {
|
||||
dilations->at(0) = dilation_list.i(1);
|
||||
dilations->at(1) = dilation_list.i(2);
|
||||
} else {
|
||||
dilations->at(0) = dilation_list.i(2);
|
||||
dilations->at(1) = dilation_list.i(3);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TFConvBaseParser::ParseKernels(const tensorflow::NodeDef &node_def, const schema::Format &format,
|
||||
std::vector<int64_t> *kernel) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(node_def, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The kernels should be specified";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto shape = attr_value.tensor().tensor_shape();
|
||||
if (shape.dim().size() != 4) {
|
||||
MS_LOG(ERROR) << "Dims of Kernel should be 4.";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
kernel->at(0) = shape.dim(0).size();
|
||||
kernel->at(1) = shape.dim(1).size();
|
||||
kernel->at(2) = shape.dim(2).size();
|
||||
kernel->at(3) = shape.dim(3).size();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TFConvBaseParser::ParsePadMode(const tensorflow::NodeDef &node_def, schema::PadMode *pad_mode) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(node_def, "padding", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The attr padding should be specified";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (attr_value.s() == "VALID") {
|
||||
*pad_mode = schema::PadMode_VALID;
|
||||
} else if (attr_value.s() == "SAME") {
|
||||
*pad_mode = schema::PadMode_SAME_UPPER;
|
||||
} else {
|
||||
*pad_mode = schema::PadMode_NOTSET;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,39 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFConvBaseParser : public TFNodeParser {
|
||||
public:
|
||||
TFConvBaseParser() = default;
|
||||
~TFConvBaseParser() override = default;
|
||||
STATUS ParseStrides(const tensorflow::NodeDef &node_def, const schema::Format &format, std::vector<int64_t> *strides);
|
||||
STATUS ParseDilations(const tensorflow::NodeDef &node_def, const schema::Format &format,
|
||||
std::vector<int64_t> *dilations);
|
||||
STATUS ParseKernels(const tensorflow::NodeDef &node_def, const schema::Format &format, std::vector<int64_t> *kernel);
|
||||
STATUS ParsePadMode(const tensorflow::NodeDef &node_def, schema::PadMode *pad_mode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_BASE_PARSER_H_
|
@ -0,0 +1,103 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_conv_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/parser/tf/tf_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF ConvParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::Conv2DT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->group = 1;
|
||||
attr->format = TensorFlowUtils::ParseNodeFormat(tf_op);
|
||||
|
||||
std::vector<int64_t> dilations(2);
|
||||
auto status = ParseDilations(tf_op, attr->format, &dilations);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
attr->dilateH = dilations[0];
|
||||
attr->dilateW = dilations[1];
|
||||
|
||||
std::vector<int64_t> strides(2);
|
||||
status = ParseStrides(tf_op, attr->format, &strides);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
attr->strideH = strides[0];
|
||||
attr->strideW = strides[1];
|
||||
|
||||
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "Find Conv2D input weights failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_node = tf_node_map.at(tf_op.input(1));
|
||||
std::vector<int64_t> kernels(4);
|
||||
status = ParseKernels(*weight_node, attr->format, &kernels);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
attr->kernelH = kernels[0];
|
||||
attr->kernelW = kernels[1];
|
||||
attr->channelIn = kernels[2];
|
||||
attr->channelOut = kernels[3];
|
||||
|
||||
status = ParsePadMode(tf_op, &attr->padMode);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
status = AddOpInput(tf_op, 0, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
status = AddOpInput(tf_op, 1, inputs); // weights
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfConvParser("Conv2D", new TFConvParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_conv_base_parser.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFConvParser : public TFConvBaseParser {
|
||||
public:
|
||||
TFConvParser() = default;
|
||||
~TFConvParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONV_PARSER_H_
|
@ -0,0 +1,76 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_expand_dims_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF ExpandDimsParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::ExpandDimsT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "Find ExpandDims input axis failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto axis_node = tf_node_map.at(tf_op.input(1));
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_proto = attr_value.tensor();
|
||||
if (tensor_proto.int_val_size() > 0) {
|
||||
attr->dim = tensor_proto.int_val(0);
|
||||
} else {
|
||||
attr->dim = (reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()))[0];
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_ExpandDims;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfExpandDimsParser("ExpandDims", new TFExpandDimsParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,37 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFExpandDimsParser : public TFNodeParser {
|
||||
public:
|
||||
TFExpandDimsParser() = default;
|
||||
~TFExpandDimsParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_EXPAND_DIMS_PARSER_H_
|
@ -0,0 +1,102 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_gather_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF GatherParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::GatherT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (TensorFlowUtils::FindAttrValue(tf_op, "batch_dims", &attr_value)) {
|
||||
attr->batchDims = attr_value.i();
|
||||
}
|
||||
|
||||
bool axis_is_set = false;
|
||||
if (tf_op.input_size() == 3) {
|
||||
axis_is_set = true;
|
||||
if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "Find Gather input axis failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto axis_node = tf_node_map.at(tf_op.input(2));
|
||||
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_proto = attr_value.tensor();
|
||||
if (tensor_proto.dtype() == tensorflow::DT_INT32) {
|
||||
if (tensor_proto.int_val_size() > 0) {
|
||||
attr->axis = tensor_proto.int_val(0);
|
||||
} else {
|
||||
attr->axis = (reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()))[0];
|
||||
}
|
||||
} else if (tensor_proto.dtype() == tensorflow::DT_INT64) {
|
||||
if (tensor_proto.int64_val_size() > 0) {
|
||||
attr->axis = tensor_proto.int64_val(0);
|
||||
} else {
|
||||
attr->axis = (reinterpret_cast<const int64_t *>(tensor_proto.tensor_content().data()))[0];
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "axis must be int32 or int64";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (attr->batchDims != 0 && !axis_is_set) {
|
||||
attr->axis = attr->batchDims;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Gather;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
status = AddOpInput(tf_op, 1, inputs);
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfGatherV2Parser("GatherV2", new TFGatherParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,37 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFGatherParser : public TFNodeParser {
|
||||
public:
|
||||
TFGatherParser() = default;
|
||||
~TFGatherParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_PARSER_H_
|
@ -0,0 +1,77 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_pack_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFPackParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF PackParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::StackT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The axis attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->axis = static_cast<int32_t>(attr_value.i());
|
||||
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "N", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The axis attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->n = static_cast<int32_t>(attr_value.i());
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Stack;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); ++i) {
|
||||
auto status = AddOpInput(tf_op, i, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfPackParser("Pack", new TFPackParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFPackParser : public TFNodeParser {
|
||||
public:
|
||||
TFPackParser() = default;
|
||||
~TFPackParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PACK_PARSER_H_
|
@ -0,0 +1,110 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_reduce_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF ReduceParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::ReduceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (tf_op.op() == "Sum") {
|
||||
attr->mode = schema::ReduceMode_ReduceSum;
|
||||
} else if (tf_op.op() == "Max") {
|
||||
attr->mode = schema::ReduceMode_ReduceMax;
|
||||
} else if (tf_op.op() == "Min") {
|
||||
attr->mode = schema::ReduceMode_ReduceMin;
|
||||
} else if (tf_op.op() == "Mean") {
|
||||
attr->mode = schema::ReduceMode_ReduceMean;
|
||||
} else if (tf_op.op() == "Prod") {
|
||||
attr->mode = schema::ReduceMode_ReduceProd;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op();
|
||||
return RET_ERROR;
|
||||
}
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(tf_op, "keep_dims", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The keep_dims attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (attr_value.value_case() != tensorflow::AttrValue::kB) {
|
||||
MS_LOG(ERROR) << "the keep_dims attr of reduce should be bool type";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->keepDims = attr_value.b();
|
||||
|
||||
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
|
||||
MS_LOG(ERROR) << "Find Reduce input axis failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto axis_node = tf_node_map.at(tf_op.input(1));
|
||||
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The value attr should be specified";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_proto = attr_value.tensor();
|
||||
if (tensor_proto.int_val_size() > 0) {
|
||||
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
|
||||
attr->axes.push_back(tensor_proto.int_val(i));
|
||||
}
|
||||
} else {
|
||||
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
|
||||
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
|
||||
for (size_t i = 0; i < data_num; ++i) {
|
||||
attr->axes.push_back(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfSumParser("Sum", new TFReduceParser());
|
||||
TFNodeRegistrar g_tfMaxParser("Max", new TFReduceParser());
|
||||
TFNodeRegistrar g_tfMinParser("Min", new TFReduceParser());
|
||||
TFNodeRegistrar g_tfMeanParser("Mean", new TFReduceParser());
|
||||
TFNodeRegistrar g_tfProdParser("Prod", new TFReduceParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFReduceParser : public TFNodeParser {
|
||||
public:
|
||||
TFReduceParser() = default;
|
||||
~TFReduceParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REDUCE_PARSER_H_
|
@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_reshape_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFReshapeParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF ReshapeParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::ReshapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->format = schema::Format_NHWC;
|
||||
// attr->shape is omitted cause input[1] provide shape info
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
status = AddOpInput(tf_op, 1, inputs);
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfReshapeParser("Reshape", new TFReshapeParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFReshapeParser : public TFNodeParser {
|
||||
public:
|
||||
TFReshapeParser() = default;
|
||||
~TFReshapeParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RESHAPE_PARSER_H_
|
@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_round_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TFRoundParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF RoundParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto attr = std::make_unique<schema::RoundT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Round;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
auto status = AddOpInput(tf_op, 0, inputs);
|
||||
return status;
|
||||
}
|
||||
TFNodeRegistrar g_tfRoundParser("Round", new TFRoundParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFRoundParser : public TFNodeParser {
|
||||
public:
|
||||
TFRoundParser() = default;
|
||||
~TFRoundParser() override = default;
|
||||
|
||||
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue