pull/11972/head
yefeng 4 years ago
parent f1009cb21b
commit 1a22007f1f

@ -61,6 +61,8 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log); status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log);
} else if (tf_op.op() == "Sqrt") { } else if (tf_op.op() == "Sqrt") {
status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt); status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt);
} else if (tf_op.op() == "Pow") {
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
} }
if (status != RET_OK) { if (status != RET_OK) {
return status; return status;
@ -84,5 +86,6 @@ TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -0,0 +1,67 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_one_hot_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFOneHotParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF OneHotParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::OneHotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) {
MS_LOG(ERROR) << "The axis attr should be specified";
return RET_ERROR;
}
attr->axis = static_cast<int32_t>(attr_value.i());
primitive->value.type = schema::PrimitiveType_OneHot;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfOneHotParser("OneHot", new TFOneHotParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFOneHotParser : public TFNodeParser {
public:
TFOneHotParser() = default;
~TFOneHotParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_

@ -56,6 +56,8 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,
attr->method = schema::ResizeMethod_LINEAR; attr->method = schema::ResizeMethod_LINEAR;
} else if (tf_op.op() == "ResizeNearestNeighbor") { } else if (tf_op.op() == "ResizeNearestNeighbor") {
attr->method = schema::ResizeMethod_NEAREST; attr->method = schema::ResizeMethod_NEAREST;
} else if (tf_op.op() == "ResizeBicubic") {
attr->method = schema::ResizeMethod_CUBIC;
} else { } else {
attr->method = schema::ResizeMethod_UNKNOWN; attr->method = schema::ResizeMethod_UNKNOWN;
} }
@ -90,5 +92,6 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,
} }
TFNodeRegistrar g_tfResizeBilinearParser("ResizeBilinear", new TFResizeParser()); TFNodeRegistrar g_tfResizeBilinearParser("ResizeBilinear", new TFResizeParser());
TFNodeRegistrar g_tfResizeNearestNeighborParser("ResizeNearestNeighbor", new TFResizeParser()); TFNodeRegistrar g_tfResizeNearestNeighborParser("ResizeNearestNeighbor", new TFResizeParser());
TFNodeRegistrar g_tfResizeBicubicParser("ResizeBicubic", new TFResizeParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

Loading…
Cancel
Save