Reconstruction of GemmConv Based on new im2col.

cblas_new
hedaoyuan 8 years ago
parent eb0c7e5ebc
commit 07cde439aa

@ -12,101 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "GemmConvOp.h" #include "ConvOp.h"
#include "GemmFunctor.h" #include "GemmFunctor.h"
#include "Im2Col.h"
#include "paddle/math/MemoryHandle.h" #include "paddle/math/MemoryHandle.h"
namespace paddle { namespace paddle {
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template <class T>
class Im2ColFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData) {
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[(c * outputHeight + h) * outputWidth + w] = T(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[(c * outputHeight + h) * outputWidth + w] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
}
};
template <class T>
class Col2ImFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* colData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* imData) {
int channelsCol = inputChannels * filterHeight * filterWidth;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 &&
(imColIdx - paddingWidth) < inputWidth) {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
imData[imRowIdx * inputWidth + imColIdx] +=
colData[(c * outputHeight + h) * outputWidth + w];
}
}
}
}
}
};
/* /*
* \brief Forward calculation of convolution. * \brief Forward calculation of convolution.
*/ */
@ -155,15 +67,20 @@ public:
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
size_t size = inputChannels / groups_ * filterHeight * filterWidth * resizeBuffer<Device>(colShape.getElements());
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
@ -171,18 +88,13 @@ public:
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
inputChannels / groups_, imShape,
inputHeight, colData,
inputWidth, colShape,
filterHeight,
filterWidth,
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW(), paddingW());
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
@ -249,15 +161,20 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>(); real* inputGrad = outputs[0].data<real>();
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
size_t size = inputChannels / groups_ * filterHeight * filterWidth * resizeBuffer<Device>(colShape.getElements());
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Col2ImFunctor<Device, real> col2im; Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
@ -280,20 +197,14 @@ public:
0.0f, 0.0f,
colData, colData,
N); N);
col2im(inputGrad + g * inputOffset,
col2im(colData, imShape,
inputChannels / groups_, colData,
inputHeight, colShape,
inputWidth,
filterHeight,
filterWidth,
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW(), paddingW());
outputHeight,
outputWidth,
inputGrad + g * inputOffset);
} }
inputGrad += inputChannels * inputHeight * inputWidth; inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth; outputGrad += outputChannels * outputHeight * outputWidth;
@ -347,33 +258,33 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>(); real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>(); real* filterGrad = outputs[0].data<real>();
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
size_t size = inputChannels / groups_ * filterHeight * filterWidth * resizeBuffer<Device>(colShape.getElements());
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
inputChannels / groups_, imShape,
inputHeight, colData,
inputWidth, colShape,
filterHeight,
filterWidth,
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW(), paddingW());
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_; int M = outputChannels / groups_;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;

Loading…
Cancel
Save