!7559 add onnx parser of pow

Merge pull request !7559 from yankai10/1021merge
pull/7559/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit de93d9bff1

@ -19,6 +19,7 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/ops/conv2d.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@ -80,6 +81,11 @@ void ConvolutionBaseCPUKernel::FreeQuantParam() {
}
int ConvolutionBaseCPUKernel::Init() {
auto conv2d_lite_primitive = (lite::Conv2D *)primitive_;
conv_param_->pad_u_ = conv2d_lite_primitive->PadUp();
conv_param_->pad_d_ = conv2d_lite_primitive->PadDown();
conv_param_->pad_l_ = conv2d_lite_primitive->PadLeft();
conv_param_->pad_r_ = conv2d_lite_primitive->PadRight();
auto input = this->in_tensors_.front();
auto output = this->out_tensors_.front();
conv_param_->input_batch_ = input->Batch();

@ -129,6 +129,24 @@ STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
return RET_NULL_PTR;
}
const auto &onnx_pow_power = onnx_node.input(1);
auto nodeIter =
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
[onnx_pow_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_pow_power; });
if (nodeIter == onnx_graph.node().end()) {
MS_LOG(ERROR) << "can not find node: " << onnx_pow_power;
return RET_ERROR;
}
const float *pW = nullptr;
for (const auto &attrPower : nodeIter->attribute()) {
if (attrPower.name() == "value") {
const auto &t = attrPower.t();
pW = reinterpret_cast<const float *>(t.raw_data().data());
}
}
attr->power = *pW;
attr->scale = 1.0f;
attr->shift = 0.0f;
op->primitive->value.type = schema::PrimitiveType_Power;
op->primitive->value.value = attr.release();
return RET_OK;
@ -675,7 +693,7 @@ OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser());
OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser());
OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser());
OnnxNodeRegistrar g_onnxPowParser("Pow", new OnnxPowParser());
OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser());
OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser());
OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser());

@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_onehot_parser.h"
#include <memory>
namespace mindspore {
namespace lite {
STATUS OnnxOneHotParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx OneHotParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "axis") {
attr->axis = static_cast<int32_t>(onnx_node_attr.i());
}
}
op->primitive->value.type = schema::PrimitiveType_OneHot;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxOneHotParser("OneHot", new OnnxOneHotParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxOneHotParser : public OnnxNodeParser {
public:
OnnxOneHotParser() : OnnxNodeParser("OneHot") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONEHOT_PARSER_H
Loading…
Cancel
Save