|
|
|
@ -18,8 +18,6 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
namespace neon {
|
|
|
|
|
|
|
|
|
|
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
|
|
|
|
|
|
|
|
|
|
template <DeviceType Device>
|
|
|
|
@ -45,16 +43,16 @@ public:
|
|
|
|
|
const TensorShape& filter = inputs[1].shape();
|
|
|
|
|
const TensorShape& output = outputs[0].shape();
|
|
|
|
|
|
|
|
|
|
size_t batchSize = input[0];
|
|
|
|
|
size_t inputChannels = input[1];
|
|
|
|
|
size_t inputHeight = input[2];
|
|
|
|
|
size_t inputWidth = input[3];
|
|
|
|
|
size_t filterHeight = getFilterHeight(filter);
|
|
|
|
|
size_t filterWidth = getFilterWidth(filter);
|
|
|
|
|
size_t outputChannels = output[1];
|
|
|
|
|
size_t outputHeight = output[2];
|
|
|
|
|
size_t outputWidth = output[3];
|
|
|
|
|
size_t filterMultiplier = outputChannels / groups_;
|
|
|
|
|
int batchSize = input[0];
|
|
|
|
|
int inputChannels = input[1];
|
|
|
|
|
int inputHeight = input[2];
|
|
|
|
|
int inputWidth = input[3];
|
|
|
|
|
int filterHeight = getFilterHeight(filter);
|
|
|
|
|
int filterWidth = getFilterWidth(filter);
|
|
|
|
|
int outputChannels = output[1];
|
|
|
|
|
int outputHeight = output[2];
|
|
|
|
|
int outputWidth = output[3];
|
|
|
|
|
int filterMultiplier = outputChannels / groups_;
|
|
|
|
|
CHECK_EQ(inputChannels, groups_);
|
|
|
|
|
|
|
|
|
|
// only support strideH() == strideW() and filterHeight == filterWidth.
|
|
|
|
@ -90,18 +88,18 @@ public:
|
|
|
|
|
DepthWiseConv;
|
|
|
|
|
|
|
|
|
|
if (filterWidth == 3 && strideW() == 1) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<3, 1>::run;
|
|
|
|
|
DepthWiseConv = neon::DepthwiseConvKernel<3, 1>::run;
|
|
|
|
|
} else if (filterWidth == 3 && strideW() == 2) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<3, 2>::run;
|
|
|
|
|
DepthWiseConv = neon::DepthwiseConvKernel<3, 2>::run;
|
|
|
|
|
} else if (filterWidth == 4 && strideW() == 1) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<4, 1>::run;
|
|
|
|
|
DepthWiseConv = neon::DepthwiseConvKernel<4, 1>::run;
|
|
|
|
|
} else if (filterWidth == 4 && strideW() == 2) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<4, 2>::run;
|
|
|
|
|
DepthWiseConv = neon::DepthwiseConvKernel<4, 2>::run;
|
|
|
|
|
} else {
|
|
|
|
|
LOG(FATAL) << "Not supported";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
|
for (int i = 0; i < batchSize; i++) {
|
|
|
|
|
DepthWiseConv(inputPadding,
|
|
|
|
|
filterData,
|
|
|
|
|
inputHeight,
|
|
|
|
@ -117,9 +115,10 @@ public:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifndef PADDLE_TYPE_DOUBLE
|
|
|
|
|
REGISTER_TYPED_FUNC(NeonDepthwiseConv, CPU, NeonDepthwiseConvFunction);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
} // namespace neon
|
|
|
|
|
} // namespace paddle
|
|
|
|
|