!11972 [MS_LITE] tf_parser-onehot-pow

From: @YeFeng_24
Reviewed-by: @hangangqiang,@hangangqiang
Signed-off-by: @hangangqiang
pull/11972/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit a911936e73

@ -61,6 +61,8 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log);
} else if (tf_op.op() == "Sqrt") {
status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt);
} else if (tf_op.op() == "Pow") {
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
}
if (status != RET_OK) {
return status;
@ -84,5 +86,6 @@ TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,67 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_one_hot_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFOneHotParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF OneHotParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::OneHotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) {
MS_LOG(ERROR) << "The axis attr should be specified";
return RET_ERROR;
}
attr->axis = static_cast<int32_t>(attr_value.i());
primitive->value.type = schema::PrimitiveType_OneHot;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfOneHotParser("OneHot", new TFOneHotParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFOneHotParser : public TFNodeParser {
public:
TFOneHotParser() = default;
~TFOneHotParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_

@ -56,6 +56,8 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,
attr->method = schema::ResizeMethod_LINEAR;
} else if (tf_op.op() == "ResizeNearestNeighbor") {
attr->method = schema::ResizeMethod_NEAREST;
} else if (tf_op.op() == "ResizeBicubic") {
attr->method = schema::ResizeMethod_CUBIC;
} else {
attr->method = schema::ResizeMethod_UNKNOWN;
}
@ -90,5 +92,6 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,
}
TFNodeRegistrar g_tfResizeBilinearParser("ResizeBilinear", new TFResizeParser());
TFNodeRegistrar g_tfResizeNearestNeighborParser("ResizeNearestNeighbor", new TFResizeParser());
TFNodeRegistrar g_tfResizeBicubicParser("ResizeBicubic", new TFResizeParser());
} // namespace lite
} // namespace mindspore

Loading…
Cancel
Save