|
|
|
@ -143,21 +143,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|
|
|
|
} else {
|
|
|
|
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
|
|
|
|
}
|
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
|
|
|
|
|
attr->padUp = pad_list[0];
|
|
|
|
|
attr->padDown = pad_list[1];
|
|
|
|
|
attr->padLeft = pad_list[2];
|
|
|
|
|
attr->padRight = pad_list[3];
|
|
|
|
|
|
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"));
|
|
|
|
|
attr->dilateH = dilation[0];
|
|
|
|
|
attr->dilateW = dilation[1];
|
|
|
|
|
|
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
|
|
|
|
|
attr->kernelH = kernel_size[0];
|
|
|
|
|
attr->kernelW = kernel_size[1];
|
|
|
|
|
|
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"));
|
|
|
|
|
attr->strideH = stride[2];
|
|
|
|
|
attr->strideW = stride[3];
|
|
|
|
|
|
|
|
|
@ -179,7 +179,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|
|
|
|
|
|
|
|
|
int channel_mutiplier = 1;
|
|
|
|
|
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
|
|
|
|
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
|
|
|
|
|
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier")).front();
|
|
|
|
|
}
|
|
|
|
|
attr->channelMultiplier = channel_mutiplier;
|
|
|
|
|
|
|
|
|
@ -220,25 +220,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
|
|
|
|
|
} else {
|
|
|
|
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
|
|
|
|
}
|
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
|
|
|
|
|
attr->padUp = pad_list[0];
|
|
|
|
|
attr->padDown = pad_list[1];
|
|
|
|
|
attr->padLeft = pad_list[2];
|
|
|
|
|
attr->padRight = pad_list[3];
|
|
|
|
|
|
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"));
|
|
|
|
|
attr->dilateH = dilation[2];
|
|
|
|
|
attr->dilateW = dilation[3];
|
|
|
|
|
|
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
|
|
|
|
|
attr->kernelH = kernel_size[0];
|
|
|
|
|
attr->kernelW = kernel_size[1];
|
|
|
|
|
|
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"));
|
|
|
|
|
attr->strideH = stride[2];
|
|
|
|
|
attr->strideW = stride[3];
|
|
|
|
|
|
|
|
|
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
|
|
|
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
|
|
|
|
|
|
|
|
|
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
|
|
|
|
if (pad_mode == "valid") {
|
|
|
|
@ -278,7 +278,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|
|
|
|
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
int group = CastToInt(groupAttr, false).front();
|
|
|
|
|
int group = CastToInt(groupAttr).front();
|
|
|
|
|
if (group > 1) {
|
|
|
|
|
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
|
|
|
|
|
} else {
|
|
|
|
|