Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_prelu

release/0.11.0
xzl 8 years ago
commit c1b23535c0

2
.gitignore vendored

@ -21,7 +21,7 @@ third_party/
cmake-build-* cmake-build-*
# generated while compiling # generated while compiling
python/paddle/v2/framework/core.so python/paddle/v2/fluid/core.so
paddle/pybind/pybind.h paddle/pybind/pybind.h
CMakeFiles CMakeFiles
cmake_install.cmake cmake_install.cmake

@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h" #include "hl_base.h"
/** /**
* @brief Maximum pool forward. * @brief Maximum pool forward with Mask output.
* *
* @param[in] frameCnt batch size of input image. * @param[in] frameCnt batch size of input image.
* @param[in] inputData input data. * @param[in] inputData input data.
@ -35,7 +35,7 @@ limitations under the License. */
* @param[in] paddingW padding width. * @param[in] paddingW padding width.
* @param[out] tgtData output data. * @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples. * @param[in] tgtStride stride between output data samples.
* * @param[out] maskData the location indices of select max data.
*/ */
extern void hl_maxpool_forward(const int frameCnt, extern void hl_maxpool_forward(const int frameCnt,
const real* inputData, const real* inputData,
@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
const int tgtStride); const int tgtStride,
real* maskData = NULL);
/** /**
* @brief Maximum pool backward. * @brief Maximum pool backward.

@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
const int tgtStride) {} const int tgtStride,
real* MaskData) {}
inline void hl_maxpool_backward(const int frameCnt, inline void hl_maxpool_backward(const int frameCnt,
const real* inputData, const real* inputData,

@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
const int offsetH, const int offsetH,
const int offsetW, const int offsetW,
real* tgtData, real* tgtData,
const int tgtStride) { const int tgtStride,
real* maskData) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) { if (index < nthreads) {
int pw = index % pooledW; int pw = index % pooledW;
@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
real maxval = -FLT_MAX; real maxval = -FLT_MAX;
int max_index = -1;
inputData += (frameNum * channels + c) * height * width; inputData += (frameNum * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[h * width + w]) if (maxval < inputData[h * width + w]) {
maxval = inputData[h * width + w]; max_index = h * width + w;
maxval = inputData[max_index];
}
} }
} }
int tgtIndex = int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride; index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval; tgtData[tgtIndex] = maxval;
if (maskData != NULL) {
maskData[tgtIndex] = max_index;
}
} }
} }
@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
const int tgtStride) { const int tgtStride,
real* maskData) {
int num_kernels = pooledH * pooledW * channels * frameCnt; int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024; int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1); dim3 threads(1024, 1);
@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
paddingH, paddingH,
paddingW, paddingW,
tgtData, tgtData,
tgtStride); tgtStride,
maskData);
CHECK_SYNC("hl_maxpool_forward failed"); CHECK_SYNC("hl_maxpool_forward failed");
} }

@ -377,6 +377,12 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; return grad_op_descs;
} }
static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx);
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx, ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars, std::unordered_set<std::string>* no_grad_vars,
@ -392,13 +398,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
if ((*it)->Type() == "recurrent") { if ((*it)->Type() == "recurrent") {
int step_block_idx = (*it)->GetBlockAttr("step_block"); int step_block_idx = (*it)->GetBlockAttr("step_block");
auto backward_block_op_descs = MakeBlockBackward( BlockDescBind* backward_block = CreateStepBlock(
program_desc, step_block_idx, no_grad_vars, grad_to_var); program_desc, no_grad_vars, grad_to_var, step_block_idx);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else if ((*it)->Type() == "conditional_block") {
BlockDescBind* backward_block = BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx)); CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
for (auto& ptr : backward_block_op_descs) { (*it)->GetBlockAttr("block"));
backward_block->AppendAllocatedOp(std::move(ptr));
}
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else { } else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
@ -449,6 +455,21 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs; return backward_descs;
} }
static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx) {
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
no_grad_vars, grad_to_var);
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(move(ptr));
}
return backward_block;
}
ParamGradInfoMap AppendBackward( ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target, ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {

@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) {
return VarDesc_VarType_LOD_RANK_TABLE; return VarDesc_VarType_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return VarDesc_VarType_LOD_TENSOR_ARRAY; return VarDesc_VarType_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return VarDesc_VarType_SELECTED_ROWS;
} else { } else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
} }
} }
template <typename Visitor>
inline void VisitVarType(const Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case VarDesc_VarType_LOD_TENSOR:
visitor(var.Get<framework::LoDTensor>());
return;
case VarDesc_VarType_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
return;
case VarDesc_VarType_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>());
return;
case VarDesc_VarType_SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -61,6 +61,7 @@ public:
// function arguments // function arguments
strides_ = config.get<std::vector<size_t>>("strides"); strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings"); paddings_ = config.get<std::vector<size_t>>("paddings");
dilations_ = config.get<std::vector<size_t>>("dilations");
groups_ = config.get<size_t>("groups"); groups_ = config.get<size_t>("groups");
// number of inputs and outputs // number of inputs and outputs
@ -118,6 +119,7 @@ protected:
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
std::vector<size_t> dilations_;
/// Group size, refer to grouped convolution in /// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the /// Alex Krizhevsky's paper: when group=2, the first half of the
@ -133,6 +135,10 @@ protected:
inline int paddingW() const { return paddings_[1]; } inline int paddingW() const { return paddings_[1]; }
inline int dilationH() const { return dilations_[0]; }
inline int dilationW() const { return dilations_[1]; }
// A temporary memory in convolution calculation. // A temporary memory in convolution calculation.
MemoryHandlePtr memory_; MemoryHandlePtr memory_;

@ -79,45 +79,59 @@ void Convolution(const std::string& conv1,
if (outputChannels < inputChannels) continue; if (outputChannels < inputChannels) continue;
for (size_t stride : {1, 2}) { for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) { for (size_t padding : {0, 1}) {
if (padding >= filterSize) break; for (size_t dilation : {1, 3}) {
if (padding >= filterSize) break;
size_t filterS = (filterSize - 1) * dilation + 1;
// NNPACK only supports stride = 1 if batchSize > 1 if (inputSize + 2 * padding < filterS) break;
if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") &&
batchSize > 1 && stride > 1)
break;
size_t outputSize = if ((conv1 == "NaiveConv-CPU" || conv2 == "NaiveConv-CPU" ||
(inputSize - filterSize + 2 * padding + stride) / stride; conv1 == "NNPACKConv-CPU" ||
VLOG(3) << " batchSize=" << batchSize conv2 == "NNPACKConv-CPU") &&
<< " inputChannels=" << inputChannels dilation > 1)
<< " inputHeight=" << inputSize break;
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize << " stride=" << stride
<< " padding=" << padding;
std::vector<size_t> paddings = {padding, padding}; // NNPACK only supports stride = 1 if batchSize > 1
std::vector<size_t> strides = {stride, stride}; if ((conv1 == "NNPACKConv-CPU" ||
Compare2Function<DType1, DType2> test( conv2 == "NNPACKConv-CPU") &&
conv1, batchSize > 1 && stride > 1)
conv2, break;
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
TensorShape input{ size_t outputSize =
batchSize, inputChannels, inputSize, inputSize}; (inputSize - filterS + 2 * padding + stride) / stride;
TensorShape filter{ VLOG(3) << " batchSize=" << batchSize
outputChannels, inputChannels, filterSize, filterSize}; << " inputChannels=" << inputChannels
TensorShape output{ << " inputHeight=" << inputSize
batchSize, outputChannels, outputSize, outputSize}; << " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
function(test, input, filter, output); std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {dilation, dilation};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
function(test, input, filter, output);
}
} }
} }
} }
@ -144,6 +158,7 @@ void Convolution2(const std::string& conv1,
for (size_t outputChannels : {7}) { for (size_t outputChannels : {7}) {
size_t stride = 1; size_t stride = 1;
size_t padding = 0; size_t padding = 0;
size_t dilation = 1;
size_t outputHeight = size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) / (inputHeight - filterHeight + 2 * padding + stride) /
stride; stride;
@ -162,6 +177,7 @@ void Convolution2(const std::string& conv1,
std::vector<size_t> paddings = {padding, padding}; std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride}; std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {dilation, dilation};
Compare2Function<DType1, DType2> test( Compare2Function<DType1, DType2> test(
conv1, conv1,
conv2, conv2,
@ -169,6 +185,7 @@ void Convolution2(const std::string& conv1,
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("groups", (size_t)1) .set("groups", (size_t)1)
.set("dilations", dilations)
.set("algo", (std::string) "auto")); .set("algo", (std::string) "auto"));
TensorShape input{ TensorShape input{
@ -223,6 +240,7 @@ void DepthwiseConvolution(const std::string& conv1,
std::vector<size_t> paddings = {padding, padding}; std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride}; std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {1, 1};
size_t groups = inputChannels; size_t groups = inputChannels;
Compare2Function<DType1, DType2> test( Compare2Function<DType1, DType2> test(
conv1, conv1,
@ -231,6 +249,7 @@ void DepthwiseConvolution(const std::string& conv1,
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("groups", groups) .set("groups", groups)
.set("dilations", dilations)
.set("algo", (std::string) "auto")); .set("algo", (std::string) "auto"));
TensorShape input{ TensorShape input{

@ -100,7 +100,9 @@ public:
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW(),
dilationH(),
dilationW());
} else { } else {
colData = inputData + g * inputOffset; colData = inputData + g * inputOffset;
} }
@ -223,7 +225,9 @@ public:
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW(),
dilationH(),
dilationW());
} }
} }
inputGrad += inputChannels * inputHeight * inputWidth; inputGrad += inputChannels * inputHeight * inputWidth;
@ -310,7 +314,9 @@ public:
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW(),
dilationH(),
dilationW());
} else { } else {
colData = inputData + g * inputOffset; colData = inputData + g * inputOffset;
} }

@ -78,7 +78,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth); int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
}; };
template <ColFormat Format, DeviceType Device, class T> template <ColFormat Format, DeviceType Device, class T>
@ -91,7 +93,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth); int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
}; };
} // namespace paddle } // namespace paddle

@ -31,7 +31,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -47,8 +49,8 @@ public:
int c_im = c / filterWidth / filterHeight; int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) { for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset; int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset; int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) < 0 || if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight || (imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 || (imColIdx - paddingWidth) < 0 ||
@ -81,7 +83,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -97,8 +101,8 @@ public:
int c_im = c / filterWidth / filterHeight; int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) { for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset; int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset; int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) >= 0 && if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight && (imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 && (imColIdx - paddingWidth) >= 0 &&
@ -134,7 +138,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -147,9 +153,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) { for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) { for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset = int imRowOffset = outputH * strideHeight +
outputH * strideHeight + filterH - paddingHeight; filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth; int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset = int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels + (((outputH * outputWidth + outputW) * inputChannels +
channel) * channel) *
@ -189,7 +196,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -202,9 +211,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) { for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) { for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset = int imRowOffset = outputH * strideHeight +
outputH * strideHeight + filterH - paddingHeight; filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth; int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset = int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels + (((outputH * outputWidth + outputW) * inputChannels +
channel) * channel) *

@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im,
int strideW, int strideW,
int paddingH, int paddingH,
int paddingW, int paddingW,
int dilationH,
int dilationW,
int height_col, int height_col,
int width_col, int width_col,
T* data_col) { T* data_col) {
@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im,
data_col += (channel_out * height_col + h_out) * width_col + w_out; data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (int i = 0; i < blockH; ++i) { for (int i = 0; i < blockH; ++i) {
for (int j = 0; j < blockW; ++j) { for (int j = 0; j < blockW; ++j) {
int rIdx = int(h_in + i); int rIdx = int(h_in + i * dilationH);
int cIdx = int(w_in + j); int cIdx = int(w_in + j * dilationW);
if ((rIdx - (int)paddingH) >= (int)height || if ((rIdx - (int)paddingH) >= (int)height ||
(rIdx - (int)paddingH) < 0 || (rIdx - (int)paddingH) < 0 ||
(cIdx - (int)paddingW) >= (int)width || (cIdx - (int)paddingW) >= (int)width ||
@ -77,7 +79,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -102,6 +106,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth, outputWidth,
colData); colData);
@ -121,6 +127,8 @@ __global__ void col2im(size_t n,
size_t strideW, size_t strideW,
size_t paddingH, size_t paddingH,
size_t paddingW, size_t paddingW,
size_t dilationH,
size_t dilationW,
size_t height_col, size_t height_col,
size_t width_col, size_t width_col,
T* data_im) { T* data_im) {
@ -131,23 +139,34 @@ __global__ void col2im(size_t n,
int w = int(index % width); int w = int(index % width);
int h = int((index / width) % height); int h = int((index / width) % height);
int c = int(index / (width * height)); int c = int(index / (width * height));
int filterH = (blockH - 1) * dilationH + 1;
int filterW = (blockW - 1) * dilationW + 1;
if ((w - (int)paddingW) >= 0 && if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width - 2 * paddingW) && (w - (int)paddingW) < (width - 2 * paddingW) &&
(h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) { (h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) {
// compute the start and end of the output // compute the start and end of the output
int w_col_start = int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; (w < (int)filterW) ? 0 : (w - int(filterW)) / (int)strideW + 1;
int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col)); int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start = int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; (h < (int)filterH) ? 0 : (h - (int)filterH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col)); int h_col_end = min(int(h / strideH + 1), int(height_col));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out] // the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH * blockW) + int h_k = (h - h_col * strideH);
(h - h_col * (int)strideH) * (int)blockW + int w_k = (w - w_col * strideW);
(w - w_col * (int)strideW); if (h_k % dilationH == 0 && w_k % dilationW == 0) {
val += data_col[(c_col * height_col + h_col) * width_col + w_col]; h_k /= dilationH;
w_k /= dilationW;
int c_col =
(((c * blockH + h_k) * blockW + w_k) * height_col + h_col) *
width_col +
w_col;
val += data_col[c_col];
}
} }
} }
h -= paddingH; h -= paddingH;
@ -173,7 +192,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -205,6 +226,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth, outputWidth,
imData); imData);
@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth, int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight, int outputHeight,
int outputWidth) { int outputWidth) {
int swId = blockIdx.x; int swId = blockIdx.x;
@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData,
channelId += blockDim.z) { channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth; int widthOffset =
int heightOffset = idy + shId * strideHeight - paddingHeight; idx * dilationHeight + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationWidth + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth + int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth; channelId * inputHeight * inputWidth;
@ -273,7 +300,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -312,6 +341,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth); outputWidth);
CHECK_SYNC("Im2ColFunctor GPU failed"); CHECK_SYNC("Im2ColFunctor GPU failed");
@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth, int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight, int outputHeight,
int outputWidth) { int outputWidth) {
int swId = blockIdx.x; int swId = blockIdx.x;
@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData,
channelId += blockDim.z) { channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth; int widthOffset =
int heightOffset = idy + shId * strideHeight - paddingHeight; idx * dilationWidth + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationHeight + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth + int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth; channelId * inputHeight * inputWidth;
@ -372,7 +407,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -411,6 +448,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth); outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed"); CHECK_SYNC("Col2ImFunctor GPU failed");

@ -29,82 +29,98 @@ void TestIm2ColFunctor() {
for (size_t filterWidth : {3, 7}) { for (size_t filterWidth : {3, 7}) {
for (size_t stride : {1, 2}) { for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) { for (size_t padding : {0, 1}) {
if (inputHeight <= filterHeight || inputWidth <= filterWidth) for (size_t dilation : {1, 3}) {
break; size_t filterSizeH = (filterHeight - 1) * dilation + 1;
if (padding >= filterHeight || padding >= filterWidth) break; size_t filterSizeW = (filterWidth - 1) * dilation + 1;
size_t outputHeight = if (inputHeight + 2 * padding < filterSizeH ||
(inputHeight - filterHeight + 2 * padding + stride) / inputWidth + 2 * padding < filterSizeW)
stride; break;
size_t outputWidth = if (padding >= filterSizeH || padding >= filterSizeW) break;
(inputWidth - filterWidth + 2 * padding + stride) / stride; size_t outputHeight =
(inputHeight - filterSizeH + 2 * padding) / stride + 1;
TensorShape imShape = size_t outputWidth =
TensorShape({channels, inputHeight, inputWidth}); (inputWidth - filterSizeW + 2 * padding) / stride + 1;
TensorShape colShape1 = TensorShape({channels,
filterHeight, TensorShape imShape =
filterWidth, TensorShape({channels, inputHeight, inputWidth});
outputHeight, TensorShape colShape1 = TensorShape({channels,
outputWidth}); filterHeight,
TensorShape colShape2 = TensorShape({outputHeight, filterWidth,
outputWidth, outputHeight,
channels, outputWidth});
filterHeight, TensorShape colShape2 = TensorShape({outputHeight,
filterWidth}); outputWidth,
channels,
size_t height = channels * filterHeight * filterWidth; filterHeight,
size_t width = outputHeight * outputWidth; filterWidth});
VectorPtr input1 = Vector::create(imShape.getElements(), false);
VectorPtr input2 = Vector::create(imShape.getElements(), false); size_t height = channels * filterHeight * filterWidth;
MatrixPtr output1 = Matrix::create(height, width, false, false); size_t width = outputHeight * outputWidth;
MatrixPtr output2 = Matrix::create(width, height, false, false); VectorPtr input1 =
input1->uniform(0.001, 1); Vector::create(imShape.getElements(), false);
input2->copyFrom(*input1); VectorPtr input2 =
Vector::create(imShape.getElements(), false);
Im2ColFunctor<kCFO, Device, T> im2Col1; MatrixPtr output1 =
Im2ColFunctor<kOCF, Device, T> im2Col2; Matrix::create(height, width, false, false);
im2Col1(input1->getData(), MatrixPtr output2 =
imShape, Matrix::create(width, height, false, false);
output1->getData(), input1->uniform(0.001, 1);
colShape1, input2->copyFrom(*input1);
stride,
stride, Im2ColFunctor<kCFO, Device, T> im2Col1;
padding, Im2ColFunctor<kOCF, Device, T> im2Col2;
padding); im2Col1(input1->getData(),
im2Col2(input2->getData(), imShape,
imShape, output1->getData(),
output2->getData(), colShape1,
colShape2, stride,
stride, stride,
stride, padding,
padding, padding,
padding); dilation,
dilation);
// The transposition of the result of ColFormat == kCFO im2Col2(input2->getData(),
// is equal to the result of ColFormat == kOCF. imShape,
MatrixPtr test; output2->getData(),
output2->transpose(test, true); colShape2,
autotest::TensorCheckErr(*output1, *test); stride,
stride,
Col2ImFunctor<kCFO, Device, T> col2Im1; padding,
Col2ImFunctor<kOCF, Device, T> col2Im2; padding,
col2Im1(input1->getData(), dilation,
imShape, dilation);
output1->getData(),
colShape1, // The transposition of the result of ColFormat == kCFO
stride, // is equal to the result of ColFormat == kOCF.
stride, MatrixPtr test;
padding, output2->transpose(test, true);
padding); autotest::TensorCheckErr(*output1, *test);
col2Im2(input2->getData(),
imShape, Col2ImFunctor<kCFO, Device, T> col2Im1;
output2->getData(), Col2ImFunctor<kOCF, Device, T> col2Im2;
colShape2,
stride, col2Im1(input1->getData(),
stride, imShape,
padding, output1->getData(),
padding); colShape1,
stride,
autotest::TensorCheckErr(*input1, *input2); stride,
padding,
padding,
dilation,
dilation);
col2Im2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding,
dilation,
dilation);
autotest::TensorCheckErr(*input1, *input2);
}
} }
} }
} }

@ -79,6 +79,10 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for (int i = 0; i < config_.inputs_size(); i++) { for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]}; std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
std::vector<size_t> dilations = {(size_t)dilationY_[i],
(size_t)dilation_[i]};
bool useDilation = ((size_t)dilationY_[i] > 1 || (size_t)dilation_[i] > 1);
// Convolution Layer uses the GemmConv function by default. // Convolution Layer uses the GemmConv function by default.
convType = "GemmConv"; convType = "GemmConv";
@ -97,13 +101,14 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
if ((filterSize_[i] == filterSizeY_[i]) && if ((filterSize_[i] == filterSizeY_[i]) &&
(filterSize_[i] == 3 || filterSize_[i] == 4) && (filterSize_[i] == 3 || filterSize_[i] == 4) &&
(stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2)) { (stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2) &&
!useDilation) {
convType = "NeonDepthwiseConv"; convType = "NeonDepthwiseConv";
} }
#endif #endif
} }
if (FLAGS_use_nnpack && !isDeconv_) { if (FLAGS_use_nnpack && !isDeconv_ && !useDilation) {
createFunction(forward_, createFunction(forward_,
"NNPACKConv", "NNPACKConv",
FuncConfig() FuncConfig()
@ -117,6 +122,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i])); .set("groups", (size_t)groups_[i]));
createFunction(backward_, createFunction(backward_,
@ -124,6 +130,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i])); .set("groups", (size_t)groups_[i]));
createFunction(backward_, createFunction(backward_,
@ -131,6 +138,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i])); .set("groups", (size_t)groups_[i]));
} }
} }

@ -0,0 +1,109 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MaxPoolWithMaskLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
PoolLayer::init(layerMap, parameterMap);
setOutput("mask", &mask_);
return true;
}
size_t MaxPoolWithMaskLayer::getSize() {
CHECK_EQ(inputLayers_.size(), 1UL);
size_t layerSize = 0;
outputY_ = outputSize(imgSizeY_,
sizeY_,
confPaddingY_,
strideY_,
/* caffeMode */ false);
outputX_ = outputSize(imgSize_,
sizeX_,
confPadding_,
stride_,
/* caffeMode */ false);
layerSize = outputX_ * outputY_ * channels_;
getOutput().setFrameHeight(outputY_);
getOutput().setFrameWidth(outputX_);
return layerSize;
}
void MaxPoolWithMaskLayer::forward(PassType passType) {
size_t size = getSize();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
int batchSize = inputV->getHeight();
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
CHECK_EQ(size, outV->getWidth());
resetSpecifyOutput(mask_,
batchSize,
size,
/* isValueClean */ false,
/* isGradClean */ true);
MatrixPtr maskV = mask_.value;
outV->maxPoolForward(*inputV,
imgSizeY_,
imgSize_,
channels_,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
confPaddingY_,
confPadding_,
maskV);
}
void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == getInputGrad(0)) {
return;
}
MatrixPtr outGrad = getOutputGrad();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
MatrixPtr outV = getOutputValue();
MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad();
inputGrad->maxPoolBackward(*inputV,
imgSizeY_,
imgSize_,
*outGrad,
*outV,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
1,
1,
confPaddingY_,
confPadding_);
}
} // namespace paddle

@ -0,0 +1,40 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "PoolLayer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief Basic parent layer of different kinds of pooling
*/
class MaxPoolWithMaskLayer : public PoolLayer {
protected:
Argument mask_;
public:
explicit MaxPoolWithMaskLayer(const LayerConfig& config)
: PoolLayer(config) {}
size_t getSize();
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
};
} // namespace paddle

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "PoolLayer.h" #include "PoolLayer.h"
#include "MaxPoolWithMaskLayer.h"
#include "PoolProjectionLayer.h" #include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
return true; return true;
} }
@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) {
} else if (CudnnPoolLayer::typeCheck(pool)) { } else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config); return new CudnnPoolLayer(config);
#endif #endif
} else if (pool == "max-pool-with-mask") {
return new MaxPoolWithMaskLayer(config);
} else { } else {
LOG(FATAL) << "Unknown pool type: " << pool; LOG(FATAL) << "Unknown pool type: " << pool;
return nullptr; return nullptr;

@ -98,7 +98,7 @@ void ROIPoolLayer::forward(PassType passType) {
size_t roiStartH = round(bottomROIs[2] * spatialScale_); size_t roiStartH = round(bottomROIs[2] * spatialScale_);
size_t roiEndW = round(bottomROIs[3] * spatialScale_); size_t roiEndW = round(bottomROIs[3] * spatialScale_);
size_t roiEndH = round(bottomROIs[4] * spatialScale_); size_t roiEndH = round(bottomROIs[4] * spatialScale_);
CHECK_GE(roiBatchIdx, 0); CHECK_GE(roiBatchIdx, 0UL);
CHECK_LT(roiBatchIdx, batchSize); CHECK_LT(roiBatchIdx, batchSize);
size_t roiHeight = std::max(roiEndH - roiStartH + 1, 1UL); size_t roiHeight = std::max(roiEndH - roiStartH + 1, 1UL);
size_t roiWidth = std::max(roiEndW - roiStartW + 1, 1UL); size_t roiWidth = std::max(roiEndW - roiStartW + 1, 1UL);

@ -24,6 +24,7 @@ gserver_test(test_ConvUnify)
gserver_test(test_BatchNorm) gserver_test(test_BatchNorm)
gserver_test(test_KmaxSeqScore) gserver_test(test_KmaxSeqScore)
gserver_test(test_Expand) gserver_test(test_Expand)
gserver_test(test_MaxPoolingWithMaskOutput)
########## test_Mkldnn layers and activations ########## ########## test_Mkldnn layers and activations ##########
if(WITH_MKLDNN) if(WITH_MKLDNN)

@ -434,7 +434,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
config.layerConfig.set_partial_sum(1); config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true); config.layerConfig.set_shared_biases(true);
int dilation = 1; int dilation = 2;
if (type == "cudnn_conv") { if (type == "cudnn_conv") {
#if CUDNN_VERSION >= 6000 #if CUDNN_VERSION >= 6000
dilation = 2; dilation = 2;
@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
TEST(Layer, PoolLayer) { TEST(Layer, PoolLayer) {
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false); testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false); testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true); testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) {
testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
#endif #endif
} }

@ -297,7 +297,7 @@ static void getAddtoConfig(TestConfig& cfg,
} }
void testAddtoLayer(const testImageDesc& pm, const size_t nInputs) { void testAddtoLayer(const testImageDesc& pm, const size_t nInputs) {
CHECK_GE(nInputs, 1); CHECK_GE(nInputs, 1UL);
TestConfig dnnConfig; TestConfig dnnConfig;
getAddtoConfig(dnnConfig, pm, nInputs); getAddtoConfig(dnnConfig, pm, nInputs);
dnnConfig.layerConfig.set_type("mkldnn_addto"); dnnConfig.layerConfig.set_type("mkldnn_addto");

@ -0,0 +1,117 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "LayerGradUtil.h"
#include "paddle/math/MathUtils.h"
#include "paddle/testing/TestUtil.h"
using namespace paddle;
void setPoolConfig(TestConfig* config,
PoolConfig* pool,
const string& poolType) {
(*config).biasSize = 0;
(*config).layerConfig.set_type("pool");
(*config).layerConfig.set_num_filters(1);
int kw = 3, kh = 3;
int pw = 0, ph = 0;
int sw = 2, sh = 2;
pool->set_pool_type(poolType);
pool->set_channels(1);
pool->set_size_x(kw);
pool->set_size_y(kh);
pool->set_start(0);
pool->set_padding(pw);
pool->set_padding_y(ph);
pool->set_stride(sw);
pool->set_stride_y(sh);
int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false);
int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false);
pool->set_output_x(ow);
pool->set_output_y(oh);
}
void doOneMaxPoolingWithMaskOutputTest(MatrixPtr& inputMat,
const string& poolType,
bool use_gpu,
MatrixPtr& maskMat) {
TestConfig config;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 25, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
PoolConfig* pool = input->mutable_pool_conf();
pool->set_img_size(5);
pool->set_img_size_y(5);
setPoolConfig(&config, pool, poolType);
config.layerConfig.set_size(pool->output_x() * pool->output_y() *
pool->channels());
config.layerConfig.set_name("MaxPoolWithMask");
std::vector<DataLayerPtr> dataLayers;
LayerMap layerMap;
vector<Argument> datas;
initDataLayer(config,
&dataLayers,
&datas,
&layerMap,
"MaxPoolWithMask",
1,
false,
use_gpu);
dataLayers[0]->getOutputValue()->copyFrom(*inputMat);
FLAGS_use_gpu = use_gpu;
std::vector<ParameterPtr> parameters;
LayerPtr maxPoolingWithMaskOutputLayer;
initTestLayer(config, &layerMap, &parameters, &maxPoolingWithMaskOutputLayer);
maxPoolingWithMaskOutputLayer->forward(PASS_GC);
checkMatrixEqual(maxPoolingWithMaskOutputLayer->getOutput("mask").value,
maskMat);
}
TEST(Layer, maxPoolingWithMaskOutputLayerFwd) {
bool useGpu = false;
MatrixPtr inputMat;
MatrixPtr maskMat;
real inputData[] = {0.1, 0.1, 0.5, 0.5, 1.1, 0.2, 0.2, 0.6, 0.1,
0.1, 0.3, 0.3, 0.7, 0.1, 0.1, 0.4, 0.4, 0.8,
0.8, 0.1, 1.0, 2.0, 3.0, 0.0, 9.0};
real maskData[] = {12, 4, 22, 24};
inputMat = Matrix::create(1, 25, false, useGpu);
maskMat = Matrix::create(1, 4, false, useGpu);
inputMat->setData(inputData);
maskMat->setData(maskData);
doOneMaxPoolingWithMaskOutputTest(
inputMat, "max-pool-with-mask", useGpu, maskMat);
#ifdef PADDLE_WITH_CUDA
useGpu = true;
inputMat = Matrix::create(1, 25, false, useGpu);
maskMat = Matrix::create(1, 4, false, useGpu);
inputMat->copyFrom(inputData, 25);
maskMat->copyFrom(maskData, 4);
doOneMaxPoolingWithMaskOutputTest(
inputMat, "max-pool-with-mask", useGpu, maskMat);
#endif
}

@ -1028,15 +1028,23 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputH, size_t outputH,
size_t outputW, size_t outputW,
size_t paddingH, size_t paddingH,
size_t paddingW) { size_t paddingW,
MatrixPtr maskMatP) {
CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal"; CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal";
real* inputData = inputMat.getData(); real* inputData = inputMat.getData();
real* maskData = NULL;
size_t frameNum = inputMat.getHeight(); size_t frameNum = inputMat.getHeight();
CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth());
CHECK(height_ == inputMat.getHeight()); CHECK(height_ == inputMat.getHeight());
CHECK(width_ == outputH * outputW * channels); CHECK(width_ == outputH * outputW * channels);
if (maskMatP != NULL) {
CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal";
CHECK(outputH * outputW * channels == maskMatP->getWidth());
maskData = maskMatP->getData();
}
hl_maxpool_forward(frameNum, hl_maxpool_forward(frameNum,
inputData, inputData,
channels, channels,
@ -1051,7 +1059,8 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
paddingH, paddingH,
paddingW, paddingW,
data_, data_,
getStride()); getStride(),
maskData);
} }
void GpuMatrix::maxPoolBackward(Matrix& inputMat, void GpuMatrix::maxPoolBackward(Matrix& inputMat,
@ -1973,9 +1982,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t outputH, size_t outputH,
size_t outputW, size_t outputW,
size_t paddingH, size_t paddingH,
size_t paddingW) { size_t paddingW,
MatrixPtr maskMatP) {
real* inputData = inputMat.getData(); real* inputData = inputMat.getData();
real* outData = data_; real* outData = data_;
real* maskData = NULL;
size_t num = inputMat.getHeight(); size_t num = inputMat.getHeight();
size_t inLength = imgSizeH * imgSizeW; size_t inLength = imgSizeH * imgSizeW;
size_t outLength = outputH * outputW; size_t outLength = outputH * outputW;
@ -1984,6 +1995,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
CHECK_EQ(channels * outLength, this->getWidth()); CHECK_EQ(channels * outLength, this->getWidth());
size_t outStride = getStride(); size_t outStride = getStride();
if (maskMatP != NULL) {
maskData = maskMatP->getData();
CHECK_EQ(channels * outLength, maskMatP->getWidth());
}
/* initialize the data_ */ /* initialize the data_ */
for (size_t i = 0; i < height_; i++) { for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) { for (size_t j = 0; j < width_; j++) {
@ -2005,10 +2021,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
int wstart = pw * strideW - paddingW; int wstart = pw * strideW - paddingW;
int wend = std::min(wstart + sizeX, imgSizeW); int wend = std::min(wstart + sizeX, imgSizeW);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
for (int h = hstart; h < hend; ++h) { if (maskData == NULL) {
for (int w = wstart; w < wend; ++w) { for (int h = hstart; h < hend; ++h) {
outData[ph * outputW + pw] = std::max( for (int w = wstart; w < wend; ++w) {
outData[ph * outputW + pw], inputData[h * imgSizeW + w]); outData[ph * outputW + pw] = std::max(
outData[ph * outputW + pw], inputData[h * imgSizeW + w]);
}
}
} else {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) {
outData[ph * outputW + pw] = inputData[h * imgSizeW + w];
maskData[ph * outputW + pw] = h * imgSizeW + w;
}
}
} }
} }
} }
@ -2016,6 +2043,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
// compute offset // compute offset
inputData += inLength; inputData += inLength;
outData += outLength; outData += outLength;
if (maskData != NULL) maskData += outLength;
} }
} }
} }

@ -861,7 +861,8 @@ public:
/** /**
* Pooling forward operation, pick out the largest element * Pooling forward operation, pick out the largest element
* in the sizeX of value * in the sizeX of value, if the maskMatP is not NULL, it will
* also caculate the location indices.
*/ */
virtual void maxPoolForward(Matrix& inputMat, virtual void maxPoolForward(Matrix& inputMat,
size_t imgSizeH, size_t imgSizeH,
@ -874,7 +875,8 @@ public:
size_t outputH, size_t outputH,
size_t outputW, size_t outputW,
size_t paddingH, size_t paddingH,
size_t paddingW) { size_t paddingW,
MatrixPtr maskMatP = NULL) {
LOG(FATAL) << "Not implemeted"; LOG(FATAL) << "Not implemeted";
} }
@ -1426,7 +1428,8 @@ public:
size_t outputH, size_t outputH,
size_t outputW, size_t outputW,
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW,
MatrixPtr maskMatP);
void maxPoolBackward(Matrix& image, void maxPoolBackward(Matrix& image,
size_t imgSizeH, size_t imgSizeH,
@ -1697,7 +1700,8 @@ public:
size_t outputH, size_t outputH,
size_t outputW, size_t outputW,
size_t paddingH, size_t paddingH,
size_t paddingW); size_t paddingW,
MatrixPtr maskMatP);
void maxPoolBackward(Matrix& image, void maxPoolBackward(Matrix& image,
size_t imgSizeH, size_t imgSizeH,

@ -214,6 +214,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc
rnn/recurrent_op_utils.cc rnn/recurrent_op_utils.cc

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save