|
|
|
@ -15,6 +15,8 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "src/ops/deconv2d.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
@ -56,7 +58,86 @@ void DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()-
|
|
|
|
|
void DeConv2D::SetActivationType(int activation_type) {
|
|
|
|
|
this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type;
|
|
|
|
|
}
|
|
|
|
|
void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
|
|
|
|
|
auto attr = std::make_unique<schema::DeConv2DT>();
|
|
|
|
|
attr->group = group;
|
|
|
|
|
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
|
|
|
|
if (format == "NCHW") {
|
|
|
|
|
attr->format = schema::Format_NCHW;
|
|
|
|
|
} else if (format == "NHWC") {
|
|
|
|
|
attr->format = schema::Format_NHWC;
|
|
|
|
|
} else {
|
|
|
|
|
attr->format = schema::Format_NUM_OF_FORMAT;
|
|
|
|
|
}
|
|
|
|
|
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
|
|
|
|
attr->padUp = pad_list[0];
|
|
|
|
|
attr->padDown = pad_list[1];
|
|
|
|
|
attr->padLeft = pad_list[2];
|
|
|
|
|
attr->padRight = pad_list[3];
|
|
|
|
|
|
|
|
|
|
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
|
|
|
|
attr->dilateH = dilation[0];
|
|
|
|
|
attr->dilateW = dilation[1];
|
|
|
|
|
|
|
|
|
|
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
|
|
|
|
attr->kernelH = kernel_size[0];
|
|
|
|
|
attr->kernelW = kernel_size[1];
|
|
|
|
|
|
|
|
|
|
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
|
|
|
|
attr->strideH = stride[0];
|
|
|
|
|
attr->strideW = stride[1];
|
|
|
|
|
|
|
|
|
|
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
|
|
|
|
|
|
|
|
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
|
|
|
|
if (pad_mode == "valid" || pad_mode == "VALID") {
|
|
|
|
|
attr->padMode = schema::PadMode_VALID;
|
|
|
|
|
} else if (pad_mode == "same" || pad_mode == "SAME") {
|
|
|
|
|
attr->padMode = schema::PadMode_SAME;
|
|
|
|
|
} else {
|
|
|
|
|
attr->padMode = schema::PadMode_NOTSET;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (prim.GetAttr("activation_name") != nullptr) {
|
|
|
|
|
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
|
|
|
|
|
attr->activationType = kActivationTypeMap[activate_name];
|
|
|
|
|
} else {
|
|
|
|
|
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// attr->padMode = schema::PadMode_SAME;
|
|
|
|
|
// attr->activationType = schema::ActivationType_RELU;
|
|
|
|
|
primitive->value.type = schema::PrimitiveType_DeConv2D;
|
|
|
|
|
primitive->value.value = attr.release();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
|
|
|
|
if (this->primitive_ == nullptr) {
|
|
|
|
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
|
|
|
|
if (this->primitive_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new primitiveT failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
this->primitive_->value.type = schema::PrimitiveType_DeConv2D;
|
|
|
|
|
}
|
|
|
|
|
if (this->primitive_->value.type != schema::PrimitiveType_DeConv2D) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
int group = GetValue<int>(prim.GetAttr("group"));
|
|
|
|
|
if (group == 1) {
|
|
|
|
|
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (GetQuantType() == schema::QuantType_AwareTraining) {
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
|
|
|
|
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
|
|
|
|
SetInputQuantParam(vecInputQuantParam);
|
|
|
|
|
SetOutputQuantParam(vecOutputQuantParam);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
int DeConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
|
|
|
|
MS_ASSERT(nullptr != primitive);
|
|
|
|
|