|
|
|
@ -509,10 +509,9 @@ public:
|
|
|
|
|
size_t filterMultiplier = outputChannels / groups_;
|
|
|
|
|
CHECK_EQ(inputChannels, groups_);
|
|
|
|
|
|
|
|
|
|
// only support
|
|
|
|
|
// only support strideH() == strideW() and filterHeight == filterWidth.
|
|
|
|
|
CHECK_EQ(strideH(), strideW());
|
|
|
|
|
CHECK_EQ(filterHeight, filterWidth);
|
|
|
|
|
CHECK_LT(strideH(), size_t(3));
|
|
|
|
|
|
|
|
|
|
float* inputData = inputs[0].data<float>();
|
|
|
|
|
float* filterData = inputs[1].data<float>();
|
|
|
|
@ -538,49 +537,32 @@ public:
|
|
|
|
|
inputWidth += 2 * paddingW();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
|
if (filterWidth == 3 && strideH() == 1) {
|
|
|
|
|
DepthwiseConvKernel<3, 1>::run(inputPadding,
|
|
|
|
|
filterData,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
outputChannels,
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
filterMultiplier,
|
|
|
|
|
outputData);
|
|
|
|
|
} else if (filterWidth == 3 && strideH() == 2) {
|
|
|
|
|
DepthwiseConvKernel<3, 2>::run(inputPadding,
|
|
|
|
|
filterData,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
outputChannels,
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
filterMultiplier,
|
|
|
|
|
outputData);
|
|
|
|
|
} else if (filterWidth == 4 && strideH() == 1) {
|
|
|
|
|
DepthwiseConvKernel<4, 1>::run(inputPadding,
|
|
|
|
|
filterData,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
outputChannels,
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
filterMultiplier,
|
|
|
|
|
outputData);
|
|
|
|
|
} else if (filterWidth == 4 && strideH() == 2) {
|
|
|
|
|
DepthwiseConvKernel<4, 2>::run(inputPadding,
|
|
|
|
|
filterData,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
outputChannels,
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
filterMultiplier,
|
|
|
|
|
outputData);
|
|
|
|
|
}
|
|
|
|
|
std::function<void(
|
|
|
|
|
const float*, const float*, int, int, int, int, int, int, float*)>
|
|
|
|
|
DepthWiseConv;
|
|
|
|
|
|
|
|
|
|
if (filterWidth == 3 && strideW() == 1) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<3, 1>::run;
|
|
|
|
|
} else if (filterWidth == 3 && strideW() == 2) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<3, 2>::run;
|
|
|
|
|
} else if (filterWidth == 4 && strideW() == 1) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<4, 1>::run;
|
|
|
|
|
} else if (filterWidth == 4 && strideW() == 2) {
|
|
|
|
|
DepthWiseConv = DepthwiseConvKernel<4, 2>::run;
|
|
|
|
|
} else {
|
|
|
|
|
LOG(FATAL) << "Not supported";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < batchSize; i++) {
|
|
|
|
|
DepthWiseConv(inputPadding,
|
|
|
|
|
filterData,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
outputChannels,
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
filterMultiplier,
|
|
|
|
|
outputData);
|
|
|
|
|
inputPadding += inputChannels * inputHeight * inputWidth;
|
|
|
|
|
outputData += outputChannels * outputHeight * outputWidth;
|
|
|
|
|
}
|
|
|
|
|