|
|
@ -89,10 +89,10 @@ void ConvertConvWeight(const ParameterPtr ¶m_node) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
size_t filter_k = weight->tensor_shape()[0];
|
|
|
|
size_t filter_k = weight->tensor_shape().at(0);
|
|
|
|
size_t filter_c = weight->tensor_shape()[1];
|
|
|
|
size_t filter_c = weight->tensor_shape().at(1);
|
|
|
|
size_t filter_h = weight->tensor_shape()[2];
|
|
|
|
size_t filter_h = weight->tensor_shape().at(2);
|
|
|
|
size_t filter_w = weight->tensor_shape()[3];
|
|
|
|
size_t filter_w = weight->tensor_shape().at(3);
|
|
|
|
T *p1Buff = nullptr;
|
|
|
|
T *p1Buff = nullptr;
|
|
|
|
T *p2Buff = nullptr;
|
|
|
|
T *p2Buff = nullptr;
|
|
|
|
for (size_t k = 0; k < filter_k; ++k) {
|
|
|
|
for (size_t k = 0; k < filter_k; ++k) {
|
|
|
@ -145,26 +145,26 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|
|
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
|
|
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
|
|
|
|
attr->padUp = pad_list[0];
|
|
|
|
attr->padUp = pad_list.at(0);
|
|
|
|
attr->padDown = pad_list[1];
|
|
|
|
attr->padDown = pad_list.at(1);
|
|
|
|
attr->padLeft = pad_list[2];
|
|
|
|
attr->padLeft = pad_list.at(2);
|
|
|
|
attr->padRight = pad_list[3];
|
|
|
|
attr->padRight = pad_list.at(3);
|
|
|
|
|
|
|
|
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"));
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"));
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
attr->dilateH = dilation[2];
|
|
|
|
attr->dilateH = dilation.at(2);
|
|
|
|
attr->dilateW = dilation[3];
|
|
|
|
attr->dilateW = dilation.at(3);
|
|
|
|
#else
|
|
|
|
#else
|
|
|
|
attr->dilateH = dilation[0];
|
|
|
|
attr->dilateH = dilation.at(0);
|
|
|
|
attr->dilateW = dilation[1];
|
|
|
|
attr->dilateW = dilation.at(1);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
|
|
|
|
attr->kernelH = kernel_size[0];
|
|
|
|
attr->kernelH = kernel_size.at(0);
|
|
|
|
attr->kernelW = kernel_size[1];
|
|
|
|
attr->kernelW = kernel_size.at(1);
|
|
|
|
|
|
|
|
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"));
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"));
|
|
|
|
attr->strideH = stride[2];
|
|
|
|
attr->strideH = stride.at(2);
|
|
|
|
attr->strideW = stride[3];
|
|
|
|
attr->strideW = stride.at(3);
|
|
|
|
|
|
|
|
|
|
|
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
|
|
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
|
|
|
if (pad_mode == "valid") {
|
|
|
|
if (pad_mode == "valid") {
|
|
|
@ -229,22 +229,22 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
|
|
|
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
|
|
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
|
|
|
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"));
|
|
|
|
attr->padUp = pad_list[0];
|
|
|
|
attr->padUp = pad_list.at(0);
|
|
|
|
attr->padDown = pad_list[1];
|
|
|
|
attr->padDown = pad_list.at(1);
|
|
|
|
attr->padLeft = pad_list[2];
|
|
|
|
attr->padLeft = pad_list.at(2);
|
|
|
|
attr->padRight = pad_list[3];
|
|
|
|
attr->padRight = pad_list.at(3);
|
|
|
|
|
|
|
|
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"));
|
|
|
|
auto dilation = CastToInt(prim.GetAttr("dilation"));
|
|
|
|
attr->dilateH = dilation[2];
|
|
|
|
attr->dilateH = dilation.at(2);
|
|
|
|
attr->dilateW = dilation[3];
|
|
|
|
attr->dilateW = dilation.at(3);
|
|
|
|
|
|
|
|
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
|
|
|
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
|
|
|
|
attr->kernelH = kernel_size[0];
|
|
|
|
attr->kernelH = kernel_size.at(0);
|
|
|
|
attr->kernelW = kernel_size[1];
|
|
|
|
attr->kernelW = kernel_size.at(1);
|
|
|
|
|
|
|
|
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"));
|
|
|
|
auto stride = CastToInt(prim.GetAttr("stride"));
|
|
|
|
attr->strideH = stride[2];
|
|
|
|
attr->strideH = stride.at(2);
|
|
|
|
attr->strideW = stride[3];
|
|
|
|
attr->strideW = stride.at(3);
|
|
|
|
|
|
|
|
|
|
|
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
|
|
|
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
|
|
|
|
|
|
|
|
|
|
|
|