|
|
|
@ -18,11 +18,6 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* imData = [input_channels, input_height, input_width]
|
|
|
|
|
* colData = [input_channels, filter_height, filter_width,
|
|
|
|
|
* output_height, output_width]
|
|
|
|
|
*/
|
|
|
|
|
template <class T>
|
|
|
|
|
class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> {
|
|
|
|
|
public:
|
|
|
|
@ -33,6 +28,8 @@ public:
|
|
|
|
|
int outputChannels,
|
|
|
|
|
int outputHeight,
|
|
|
|
|
int outputWidth,
|
|
|
|
|
int inputHeight,
|
|
|
|
|
int inputWidth,
|
|
|
|
|
int filterHeight,
|
|
|
|
|
int filterWidth,
|
|
|
|
|
int strideH,
|
|
|
|
@ -40,7 +37,7 @@ public:
|
|
|
|
|
int paddingH,
|
|
|
|
|
int paddingW,
|
|
|
|
|
T* outputData) {
|
|
|
|
|
// NO_IMPLEMENTATION
|
|
|
|
|
// TODO(zhaolong): add the CPU implementation of depthwise convolution; this specialization is currently a no-op.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -118,8 +115,8 @@ public:
|
|
|
|
|
|
|
|
|
|
size_t batchSize = input[0];
|
|
|
|
|
// size_t inputChannels = input[1];
|
|
|
|
|
// size_t inputHeight = input[2];
|
|
|
|
|
// size_t inputWidth = input[3];
|
|
|
|
|
size_t inputHeight = input[2];
|
|
|
|
|
size_t inputWidth = input[3];
|
|
|
|
|
size_t filterHeight = getFilterHeight(filter);
|
|
|
|
|
size_t filterWidth = getFilterWidth(filter);
|
|
|
|
|
size_t outputChannels = output[1];
|
|
|
|
@ -139,6 +136,8 @@ public:
|
|
|
|
|
outputChannels,
|
|
|
|
|
outputHeight,
|
|
|
|
|
outputWidth,
|
|
|
|
|
inputHeight,
|
|
|
|
|
inputWidth,
|
|
|
|
|
filterHeight,
|
|
|
|
|
filterWidth,
|
|
|
|
|
strideH(),
|
|
|
|
@ -233,8 +232,8 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ(numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ(numOutputs_, outputs.size());
|
|
|
|
|
// CHECK_EQ(numInputs_, inputs.size());
|
|
|
|
|
// CHECK_EQ(numOutputs_, outputs.size());
|
|
|
|
|
check(inputs, outputs);
|
|
|
|
|
const TensorShape& output = inputs[0].shape();
|
|
|
|
|
const TensorShape& input = inputs[1].shape();
|
|
|
|
|