parent
5635e40b7c
commit
a6fde99750
@ -0,0 +1,68 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_activation_parser.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(INFO) << "TF ActivationParser";
|
||||||
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitive is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto attr = std::make_unique<schema::ActivationT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new op failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tf_op.op() == "Relu") {
|
||||||
|
attr->type = schema::ActivationType_RELU;
|
||||||
|
} else if (tf_op.op() == "Relu6") {
|
||||||
|
attr->type = schema::ActivationType_RELU6;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
|
||||||
|
}
|
||||||
|
|
||||||
|
primitive->value.type = schema::PrimitiveType_Activation;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
if (*primitiveC == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
*output_size = 1;
|
||||||
|
auto status = AddOpInput(tf_op, 0, inputs);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
|
||||||
|
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,38 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class TFActivationParser : public TFNodeParser {
|
||||||
|
public:
|
||||||
|
TFActivationParser() = default;
|
||||||
|
~TFActivationParser() override = default;
|
||||||
|
|
||||||
|
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_
|
@ -0,0 +1,93 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_arithmetic_parser.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(INFO) << "TF ArithmeticParser";
|
||||||
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tf_op.op() == "Add") {
|
||||||
|
auto attr = std::make_unique<schema::AddT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
primitive->value.type = schema::PrimitiveType_Add;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
} else if (tf_op.op() == "Sub") {
|
||||||
|
auto attr = std::make_unique<schema::SubT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
primitive->value.type = schema::PrimitiveType_Sub;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
} else if (tf_op.op() == "Mul") {
|
||||||
|
auto attr = std::make_unique<schema::MulT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
primitive->value.type = schema::PrimitiveType_Mul;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
} else if (tf_op.op() == "Div") {
|
||||||
|
auto attr = std::make_unique<schema::DivT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
primitive->value.type = schema::PrimitiveType_Div;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
if (*primitiveC == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
*output_size = 1;
|
||||||
|
auto status = AddOpInput(tf_op, 0, inputs);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = AddOpInput(tf_op, 1, inputs);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser());
|
||||||
|
TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser());
|
||||||
|
TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser());
|
||||||
|
TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser());
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,36 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class TFArithmeticParser : public TFNodeParser {
|
||||||
|
public:
|
||||||
|
TFArithmeticParser() = default;
|
||||||
|
~TFArithmeticParser() override = default;
|
||||||
|
|
||||||
|
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_
|
@ -0,0 +1,61 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_biasadd_parser.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||||
|
std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(INFO) << "TF BiasAddParser";
|
||||||
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New PrimitiveT failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto attr = std::make_unique<schema::BiasAddT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new attr failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
attr->axis = {1};
|
||||||
|
|
||||||
|
primitive->value.type = schema::PrimitiveType_Add;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
if (*primitiveC == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
*output_size = 1;
|
||||||
|
auto status = AddOpInput(tf_op, 0, inputs);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser());
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,37 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class TFBiasAddParser : public TFNodeParser {
|
||||||
|
public:
|
||||||
|
TFBiasAddParser() = default;
|
||||||
|
~TFBiasAddParser() override = default;
|
||||||
|
|
||||||
|
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||||
|
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_
|
@ -0,0 +1,22 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_converter.h"
|
||||||
|
#include "tools/converter/parser/tf/tf_model_parser.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
TFConverter::TFConverter() { modelParser = new TFModelParser(); }
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,70 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/converter/parser/tf/tf_matmul_parser.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||||
|
std::vector<std::string> *inputs, int *output_size) {
|
||||||
|
MS_LOG(INFO) << "TF MatMulParser";
|
||||||
|
if (primitiveC == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitive is nullptr";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
auto attr = std::make_unique<schema::MatMulT>();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new op failed";
|
||||||
|
return RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
tensorflow::AttrValue attr_value;
|
||||||
|
if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &attr_value)) {
|
||||||
|
attr->transposeA = attr_value.b();
|
||||||
|
}
|
||||||
|
if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &attr_value)) {
|
||||||
|
attr->transposeB = attr_value.b();
|
||||||
|
}
|
||||||
|
|
||||||
|
primitive->value.type = schema::PrimitiveType_MatMul;
|
||||||
|
primitive->value.value = attr.release();
|
||||||
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
|
if (*primitiveC == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
*output_size = 1;
|
||||||
|
auto status = AddOpInput(tf_op, 0, inputs);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = AddOpInput(tf_op, 1, inputs);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser());
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue