|
|
|
@ -16,6 +16,10 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/utils/Logging.h"
|
|
|
|
|
#include "paddle/utils/Stat.h"
|
|
|
|
|
|
|
|
|
|
DEFINE_bool(use_nnpack,
|
|
|
|
|
false,
|
|
|
|
|
"Whether to use nnpack for convolution calculation.");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
@ -37,26 +41,38 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
|
|
|
|
|
for (int i = 0; i < config_.inputs_size(); i++) {
|
|
|
|
|
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
|
|
|
|
|
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
|
|
|
|
|
createFunction(forward_,
|
|
|
|
|
!isDeconv_ ? "GemmConv" : "GemmConvGradInput",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
|
|
|
|
|
createFunction(backward_,
|
|
|
|
|
!isDeconv_ ? "GemmConvGradInput" : "GemmConv",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
|
|
|
|
|
createFunction(backward_,
|
|
|
|
|
"GemmConvGradFilter",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
|
|
|
|
|
if (FLAGS_use_nnpack) {
|
|
|
|
|
CHECK_EQ(isDeconv_, false);
|
|
|
|
|
createFunction(forward_,
|
|
|
|
|
"NNPACKConv",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("groups", (size_t)groups_[i])
|
|
|
|
|
.set("algo", "auto"));
|
|
|
|
|
} else {
|
|
|
|
|
createFunction(forward_,
|
|
|
|
|
!isDeconv_ ? "GemmConv" : "GemmConvGradInput",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
|
|
|
|
|
createFunction(backward_,
|
|
|
|
|
!isDeconv_ ? "GemmConvGradInput" : "GemmConv",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
|
|
|
|
|
createFunction(backward_,
|
|
|
|
|
"GemmConvGradFilter",
|
|
|
|
|
FuncConfig()
|
|
|
|
|
.set("paddings", paddings)
|
|
|
|
|
.set("strides", strides)
|
|
|
|
|
.set("groups", (size_t)groups_[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|