|
|
|
@ -79,6 +79,10 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
|
|
|
|
|
for (int i = 0; i < config_.inputs_size(); i++) {
|
|
|
|
|
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
|
|
|
|
|
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
|
|
|
|
|
std::vector<size_t> dilations = {(size_t)dilationY_[i],
|
|
|
|
|
(size_t)dilation_[i]};
|
|
|
|
|
|
|
|
|
|
bool useDilation = ((size_t)dilationY_[i] > 1 || (size_t)dilation_[i] > 1);
|
|
|
|
|
|
|
|
|
|
// Convolution Layer uses the GemmConv function by default.
|
|
|
|
|
convType = "GemmConv";
|
|
|
|
@ -97,13 +101,14 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
|
|
|
|
|
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
|
|
|
|
|
if ((filterSize_[i] == filterSizeY_[i]) &&
|
|
|
|
|
(filterSize_[i] == 3 || filterSize_[i] == 4) &&
|
|
|
|
|
(stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2)) {
|
|
|
|
|
(stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2) &&
|
|
|
|
|
!useDilation) {
|
|
|
|
|
convType = "NeonDepthwiseConv";
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_use_nnpack && !isDeconv_) {
|
|
|
|
|
if (FLAGS_use_nnpack && !isDeconv_ && !useDilation) {
|
|
|
|
|
createFunction(forward_,
|
|
|
|
|
"NNPACKConv",
|
|
|
|
|
FuncConfig()
|
|
|
|
@ -117,6 +122,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("dilations", dilations)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
|
|
|
|
|
createFunction(backward_,
|
|
|
|
@ -124,6 +130,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("dilations", dilations)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
|
|
|
|
|
createFunction(backward_,
|
|
|
|
@ -131,6 +138,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("dilations", dilations)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|