deconv2d add outputPadding

pull/11599/head
yeyunpeng 4 years ago
parent 6d2ed6cafc
commit a7bae1413d

@ -480,6 +480,8 @@ table DeConv2D {
dilateH: int;
hasBias: bool = false; // DEPRECATED
activationType: ActivationType = 0;
outputPaddingW: int;
outputPaddingH: int;
}
table DeConv2DGradFilter {
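
Note (illustrative, not part of this patch): for a strided transposed convolution, several output extents are consistent with the same input extent, and output padding resolves the ambiguity by adding extra rows/columns on one side of the output. The new schema fields feed the standard ONNX ConvTranspose size formula; a minimal standalone sketch, where the helper name and free-function form are hypothetical:

#include <cstdint>

// Hypothetical helper: output extent of a transposed convolution along one
// spatial dimension, following the ONNX ConvTranspose formula. output_padding
// adds extra size on one side (bottom/right) of the output.
int32_t DeconvOutputSize(int32_t input, int32_t stride, int32_t kernel,
                         int32_t dilation, int32_t pad_begin, int32_t pad_end,
                         int32_t output_padding) {
  return stride * (input - 1) + output_padding +
         ((kernel - 1) * dilation + 1) - pad_begin - pad_end;
}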

@ -47,6 +47,8 @@ int DeConv2D::GetPadRight() const { return this->primitive_->value.AsDeConv2D()-
int DeConv2D::GetDilateW() const { return this->primitive_->value.AsDeConv2D()->dilateW; }
int DeConv2D::GetDilateH() const { return this->primitive_->value.AsDeConv2D()->dilateH; }
int DeConv2D::GetActivationType() const { return this->primitive_->value.AsDeConv2D()->activationType; }
int DeConv2D::GetOutputPaddingW() const { return this->primitive_->value.AsDeConv2D()->outputPaddingW; }
int DeConv2D::GetOutputPaddingH() const { return this->primitive_->value.AsDeConv2D()->outputPaddingH; }
void DeConv2D::SetFormat(int format) { this->primitive_->value.AsDeConv2D()->format = (schema::Format)format; }
void DeConv2D::SetGroup(int group) { this->primitive_->value.AsDeConv2D()->group = group; }
@ -295,6 +297,8 @@ int DeConv2D::GetPadRight() const { return this->primitive_->value_as_DeConv2D()
int DeConv2D::GetDilateW() const { return this->primitive_->value_as_DeConv2D()->dilateW(); }
int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()->dilateH(); }
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }
int DeConv2D::GetOutputPaddingW() const { return this->primitive_->value_as_DeConv2D()->outputPaddingW(); }
int DeConv2D::GetOutputPaddingH() const { return this->primitive_->value_as_DeConv2D()->outputPaddingH(); }
PrimitiveC *DeConv2DCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<DeConv2D>(primitive);
@ -347,6 +351,8 @@ int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::
MS_LOG(ERROR) << "unsupported pad mode for deconv";
return RET_ERROR;
}
output_h += GetOutputPaddingH();
output_w += GetOutputPaddingW();
std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
output->set_shape(out_shape);
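
Note (worked example, assuming the usual transposed-convolution arithmetic): with input height 4, strideH = 2, kernelH = 3, dilateH = 1 and zero pads, the base output height is 2 * (4 - 1) + 3 = 9; with outputPaddingH = 1 the two added lines bump the inferred height to 10, matching exporters that emit output_padding to hit an exact target shape.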

@ -71,7 +71,8 @@ class DeConv2D : public PrimitiveC {
int GetDilateW() const;
int GetDilateH() const;
int GetActivationType() const;
int GetOutputPaddingW() const;
int GetOutputPaddingH() const;
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }
int PadLeft() const { return this->pad_l_; }

File diff suppressed because it is too large

@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
@ -39,6 +40,24 @@ class AnfTransform {
std::vector<ValueNodePtr> *vnodes);
FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
int AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
int AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
int AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
int RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);
};
} // namespace lite
} // namespace mindspore
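
Note (illustrative assumption, not taken from this diff): the new private helpers suggest the transform pipeline is now assembled from small status-returning steps. A rough sketch of that composition pattern only; the actual pass registration and ordering live in anf_transform.cc:

// Hypothetical call sites inside TransformSingleFuncGraph (sketch only).
auto optimizer = std::make_shared<opt::GraphOptimizer>();
if (AddConstFoldPass(optimizer, config) != RET_OK ||
    AddConvertPass(optimizer, config) != RET_OK ||
    AddFusionPass(optimizer, config) != RET_OK ||
    AddGraphPass(optimizer, config) != RET_OK) {
  MS_LOG(ERROR) << "add pass failed";
  return nullptr;
}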

@ -53,42 +53,36 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
return true;
}
lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx DeConvParser";
auto attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
int OnnxDeConvParser::ParseParameters(const onnx::NodeProto &onnx_node,
const std::unique_ptr<schema::DeConv2DT> &attr) {
attr->padMode = schema::PadMode_NOTSET;
attr->group = 1;
attr->strideW = 1;
attr->strideH = 1;
attr->dilateW = 1;
attr->dilateH = 1;
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
attr->group = static_cast<int32_t>(onnx_node_attr.i());
} else if (onnx_node_attr.name() == "dilations") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernel_shape") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -97,7 +91,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
} else if (onnx_node_attr.name() == "pads") {
if (onnx_node_attr.ints().size() != 4) {
MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
return nullptr;
return RET_ERROR;
}
attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -106,7 +100,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
} else if (onnx_node_attr.name() == "strides") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return nullptr;
return RET_ERROR;
}
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
@ -115,13 +109,30 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
attr->format = schema::Format::Format_NHWC;
} else {
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
return nullptr;
return RET_ERROR;
}
} else if (onnx_node_attr.name() == "output_padding") {
MS_LOG(ERROR) << "output_padding param hasn't been supported";
return nullptr;
attr->outputPaddingH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->outputPaddingW = static_cast<int32_t>(onnx_node_attr.ints(1));
}
}
return RET_OK;
}
lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx DeConvParser";
auto attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto status = ParseParameters(onnx_node, attr);
if (status != RET_OK) {
MS_LOG(ERROR) << "Parse parameters failed.";
return nullptr;
}
const auto &onnx_conv_weight = onnx_node.input(1);
auto node_iter =
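
Note (illustrative, not part of this patch): output_padding reaches the parser as a repeated-ints attribute on the ONNX ConvTranspose node, and the new branch stores the first value as outputPaddingH and the second as outputPaddingW. A minimal sketch of such an attribute built with the protobuf-generated onnx API; the surrounding node setup is hypothetical:

onnx::NodeProto node;
node.set_op_type("ConvTranspose");
auto *output_padding = node.add_attribute();
output_padding->set_name("output_padding");
output_padding->set_type(onnx::AttributeProto::INTS);
output_padding->add_ints(1);  // parsed into attr->outputPaddingH
output_padding->add_ints(1);  // parsed into attr->outputPaddingW

Unlike the dilations/pads/strides branches, the new branch reads ints(0) and ints(1) without first checking that exactly two values are present.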

@ -32,6 +32,8 @@ class OnnxDeConvParser : public OnnxNodeParser {
private:
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::PrimitiveT *primitive);
int ParseParameters(const onnx::NodeProto &onnx_node, const std::unique_ptr<schema::DeConv2DT> &attr);
};
} // namespace lite
} // namespace mindspore
