You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/function/DepthwiseConvOp.cpp

309 lines
10 KiB

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DepthwiseConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle {
/*
 * Data layouts:
 *   imData  = [input_channels, input_height, input_width]
 *   colData = [input_channels, filter_height, filter_width,
 *              output_height, output_width]
 */
/*
 * CPU specialization of the forward functor.
 *
 * This is a placeholder: the call operator accepts every argument and returns
 * immediately. It reads neither inputData nor filterData and never writes to
 * outputData.
 *
 * The caller in this file (DepthwiseConvFunction::calc) passes
 * outputSize = batchSize * outputChannels * outputHeight * outputWidth.
 */
template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> {
public:
  void operator()(int outputSize,
                  const T* inputData, const T* filterData,
                  int batchSize, int outputChannels,
                  int outputHeight, int outputWidth,
                  int filterHeight, int filterWidth,
                  int strideH, int strideW,
                  int paddingH, int paddingW,
                  T* outputData) {
    // NO_IMPLEMENTATION
  }
};
/*
 * CPU specialization of the backward-input functor.
 *
 * This is a placeholder: the call operator has an empty body, so outputGrad
 * and filterData are not read and inputGrad is left untouched.
 *
 * The caller in this file (DepthwiseConvGradInputFunction::calc) passes
 * inputSize = batchSize * inputChannels * inputHeight * inputWidth.
 */
template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_CPU, T> {
public:
  void operator()(int inputSize,
                  const T* outputGrad, const T* filterData,
                  int batchSize, int outputChannels,
                  int outputHeight, int outputWidth,
                  int inputHeight, int inputWidth,
                  int filterHeight, int filterWidth,
                  int strideH, int strideW,
                  int paddingH, int paddingW,
                  T* inputGrad) {
    // Intentionally empty: no CPU implementation.
  }
};
/*
 * CPU specialization of the backward-filter functor.
 *
 * This is a placeholder: the call operator has an empty body, so none of the
 * buffers (colData, multiplierData, filterGrad) are read or written.
 *
 * The caller in this file (DepthwiseConvGradFilterFunction::calc) invokes the
 * functor once per batch element, passing the loop index as num_i and the
 * element count of its scratch buffer as colDataSize.
 */
template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_CPU, T> {
public:
  void operator()(int num_i, int colDataSize,
                  const T* outputGrad, const T* inputData,
                  int batchSize, int outputChannels,
                  int outputHeight, int outputWidth,
                  int inputHeight, int inputWidth,
                  int filterHeight, int filterWidth,
                  int strideH, int strideW,
                  int paddingH, int paddingW,
                  T* colData, T* multiplierData, T* filterGrad) {
    // Intentionally empty: no CPU implementation.
  }
};
/*
* \brief Forward calculation of depthwise convolution.
*/
template <DeviceType Device>
class DepthwiseConvFunction : public ConvFunctionBase {
public:
  // Configuration is handled entirely by the base class.
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  // Shape validation. Argument roles:
  //   inputs[0]  -> input, inputs[1] -> filter, outputs[0] -> output.
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
    const TensorShape& inputShape = inputs[0].shape();
    const TensorShape& filterShape = inputs[1].shape();
    const TensorShape& outputShape = outputs[0].shape();
    checkShape(inputShape, filterShape, outputShape);
  }

  // Runs the forward pass: verifies the argument counts and shapes, then
  // hands the raw buffers and dimensions to DepthwiseConvFunctor<Device>.
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    check(inputs, outputs);

    const TensorShape& inputShape = inputs[0].shape();
    const TensorShape& filterShape = inputs[1].shape();
    const TensorShape& outputShape = outputs[0].shape();

    // Only the batch dimension is taken from the input shape; the functor's
    // interface has no parameters for input channels, height or width.
    const size_t batchSize = inputShape[0];
    const size_t filterHeight = getFilterHeight(filterShape);
    const size_t filterWidth = getFilterWidth(filterShape);
    const size_t outputChannels = outputShape[1];
    const size_t outputHeight = outputShape[2];
    const size_t outputWidth = outputShape[3];

    real* inputData = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* outputData = outputs[0].data<real>();

    // Total number of output elements across the whole batch.
    const size_t outputSize =
        batchSize * outputChannels * outputHeight * outputWidth;

    DepthwiseConvFunctor<Device, real> forwardFunctor;
    forwardFunctor(outputSize,
                   inputData,
                   filterData,
                   batchSize,
                   outputChannels,
                   outputHeight,
                   outputWidth,
                   filterHeight,
                   filterWidth,
                   strideH(),
                   strideW(),
                   paddingH(),
                   paddingW(),
                   outputData);
  }
};
/*
* \brief Backward input calculation of depthwise convolution.
*/
template <DeviceType Device>
class DepthwiseConvGradInputFunction : public ConvFunctionBase {
public:
  // Configuration is handled entirely by the base class.
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  // Shape validation. Argument roles:
  //   inputs[0]  -> output (gradient), inputs[1] -> filter,
  //   outputs[0] -> input (gradient).
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
    const TensorShape& outputShape = inputs[0].shape();
    const TensorShape& filterShape = inputs[1].shape();
    const TensorShape& inputShape = outputs[0].shape();
    checkShape(inputShape, filterShape, outputShape);
  }

  // Runs the backward-input pass: verifies argument counts, shapes and the
  // output argument type, then forwards everything to
  // DepthwiseConvGradInputFunctor<Device>.
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    check(inputs, outputs);

    // Only ADD_TO is accepted for the input-gradient argument; any other
    // argument type fails this check.
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);

    const TensorShape& outputShape = inputs[0].shape();
    const TensorShape& filterShape = inputs[1].shape();
    const TensorShape& inputShape = outputs[0].shape();

    const size_t batchSize = inputShape[0];
    const size_t inputChannels = inputShape[1];
    const size_t inputHeight = inputShape[2];
    const size_t inputWidth = inputShape[3];
    const size_t filterHeight = getFilterHeight(filterShape);
    const size_t filterWidth = getFilterWidth(filterShape);
    const size_t outputChannels = outputShape[1];
    const size_t outputHeight = outputShape[2];
    const size_t outputWidth = outputShape[3];

    real* outputGrad = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* inputGrad = outputs[0].data<real>();

    // Total number of input-gradient elements across the whole batch.
    const size_t inputSize =
        batchSize * inputChannels * inputHeight * inputWidth;

    DepthwiseConvGradInputFunctor<Device, real> gradInputFunctor;
    gradInputFunctor(inputSize,
                     outputGrad,
                     filterData,
                     batchSize,
                     outputChannels,
                     outputHeight,
                     outputWidth,
                     inputHeight,
                     inputWidth,
                     filterHeight,
                     filterWidth,
                     strideH(),
                     strideW(),
                     paddingH(),
                     paddingW(),
                     inputGrad);
  }
};
/*
* \brief Backward filter calculation of depthwise convolution.
*/
template <DeviceType Device>
class DepthwiseConvGradFilterFunction : public ConvFunctionBase {
public:
  // Configuration is handled entirely by the base class.
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  // Shape validation. Argument roles:
  //   inputs[0]  -> output (gradient), inputs[1] -> input,
  //   outputs[0] -> filter (gradient).
  // inputs[2] (the multiplier buffer read by calc) is not inspected here.
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
    const TensorShape& outputShape = inputs[0].shape();
    const TensorShape& inputShape = inputs[1].shape();
    const TensorShape& filterShape = outputs[0].shape();
    checkShape(inputShape, filterShape, outputShape);
  }

  // Runs the backward-filter pass: verifies argument counts and shapes,
  // sizes the scratch buffer, then invokes
  // DepthwiseConvGradFilterFunctor<Device> once per batch element.
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    check(inputs, outputs);

    const TensorShape& outputShape = inputs[0].shape();
    const TensorShape& inputShape = inputs[1].shape();
    const TensorShape& filterShape = outputs[0].shape();

    const size_t batchSize = inputShape[0];
    const size_t inputChannels = inputShape[1];
    const size_t inputHeight = inputShape[2];
    const size_t inputWidth = inputShape[3];
    const size_t filterHeight = getFilterHeight(filterShape);
    const size_t filterWidth = getFilterWidth(filterShape);
    const size_t outputChannels = outputShape[1];
    const size_t outputHeight = outputShape[2];
    const size_t outputWidth = outputShape[3];

    real* outputGrad = inputs[0].data<real>();
    real* inputData = inputs[1].data<real>();
    // NOTE(review): a third input is read here, so numInputs_ must be at
    // least 3 for this to be reachable - confirm against
    // ConvFunctionBase::init, which is what sets numInputs_.
    real* multiplierData = inputs[2].data<real>();
    real* filterGrad = outputs[0].data<real>();

    // Scratch buffer sized for one sample:
    // [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth].
    const size_t colDataSize =
        inputChannels * filterHeight * filterWidth * outputHeight * outputWidth;
    resizeBuffer<Device>(colDataSize);
    real* colData = reinterpret_cast<real*>(memory_->getBuf());

    DepthwiseConvGradFilterFunctor<Device, real> gradFilterFunctor;
    // The same scratch buffer and filterGrad pointer are handed to every call;
    // only the sample index changes between iterations.
    for (size_t sample = 0; sample < batchSize; ++sample) {
      gradFilterFunctor(sample,
                        colDataSize,
                        outputGrad,
                        inputData,
                        batchSize,
                        outputChannels,
                        outputHeight,
                        outputWidth,
                        inputHeight,
                        inputWidth,
                        filterHeight,
                        filterWidth,
                        strideH(),
                        strideW(),
                        paddingH(),
                        paddingW(),
                        colData,
                        multiplierData,
                        filterGrad);
    }
  }
};
// Register the CPU variants of the three depthwise convolution functions.
// NOTE(review): REGISTER_TYPED_FUNC is defined outside this file; presumably
// it associates the given name and device tag with the class template -
// confirm against the macro definition.
// NOTE(review): the CPU functor specializations in this file have empty
// bodies, so the CPU functions registered here invoke functors that do
// nothing.
REGISTER_TYPED_FUNC(DepthwiseConv, CPU, DepthwiseConvFunction);
REGISTER_TYPED_FUNC(DepthwiseConvGradInput,
CPU,
DepthwiseConvGradInputFunction);
REGISTER_TYPED_FUNC(DepthwiseConvGradFilter,
CPU,
DepthwiseConvGradFilterFunction);
// The GPU variants are registered only when PADDLE_ONLY_CPU is not defined.
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(DepthwiseConv, GPU, DepthwiseConvFunction);
REGISTER_TYPED_FUNC(DepthwiseConvGradInput,
GPU,
DepthwiseConvGradInputFunction);
REGISTER_TYPED_FUNC(DepthwiseConvGradFilter,
GPU,
DepthwiseConvGradFilterFunction);
#endif
} // namespace paddle