Update code and fix conflicts

updateWriteDocsCN
dangqingqing 7 years ago
commit d9a305cb93

2
.gitignore vendored

@ -21,7 +21,7 @@ third_party/
cmake-build-* cmake-build-*
# generated while compiling # generated while compiling
python/paddle/v2/framework/core.so python/paddle/v2/fluid/core.so
paddle/pybind/pybind.h paddle/pybind/pybind.h
CMakeFiles CMakeFiles
cmake_install.cmake cmake_install.cmake

@ -0,0 +1,58 @@
## Evaluator Design
### The Problem
During training or serving, we provide the evaluation function to measure the model performance, e.g., accuracy, precision. In the operator based framework design, the data go through the network pipeline batch by batch. As a result, inside the operator, we only can calculate one minibatch metrics. We need to provide a mechanism to calculate the metrics for each N pass/batch the user wanted.
### Evaluator Design
Currently, every operation is expressed in the graph. we divide the evaluator process into three steps.
1. Initialize the metric state and add it into the block.
2. Calculate the statistic of the metric state in every mini-batch. The single operator is only responsible for calculating necessary statistics for one mini-batch. For example, accuracy operator only calculate a minibatch data if run once.
3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the value from different devices.
### Implementation
This design is shown in python API.
Each metric operator need to caculate the metric statistic and return the batch aware states, Python side responsible for accumulate the states for each pass.
```python
class Evaluator(object):
"""
Evaluator Base class.
"""
def __init__(self, name, **kwargs):
"""
Different evaluator may has different metric states. E.g, Accuracy need two variables, total and right sample counts.
Auc need four variables, `true_positives`,
`true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append to main_program
The initialization of Evaluator should be responsible for:
create metric states and append to the main_program
"""
pass
def _update_ops(self, input, label, **kwargs)
"""
Add mini-batch evaluator caculate operators to the main_program.
Add increment operator to accumulate the metric states.
"""
def reset(self, executor, reset_program=None):
"""
Reset metric states at the begin of each pass/user specified batch number.
Execute the reset_program to reset the states.
"""
def eval(self, executor, eval_program=None):
"""
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
Execute the eval_program and return the result.
"""
return eval_result
```

@ -121,6 +121,7 @@ paddle_error paddle_matrix_get_shape(paddle_matrix mat,
paddle_matrix paddle_matrix_create_sparse( paddle_matrix paddle_matrix_create_sparse(
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu) { uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu) {
#ifndef PADDLE_MOBILE_INFERENCE
auto ptr = new paddle::capi::CMatrix(); auto ptr = new paddle::capi::CMatrix();
ptr->mat = paddle::Matrix::createSparseMatrix( ptr->mat = paddle::Matrix::createSparseMatrix(
height, height,
@ -131,6 +132,9 @@ paddle_matrix paddle_matrix_create_sparse(
false, false,
useGpu); useGpu);
return ptr; return ptr;
#else
return nullptr;
#endif
} }
paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat, paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
@ -140,6 +144,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
uint64_t colSize, uint64_t colSize,
float* valueArray, float* valueArray,
uint64_t valueSize) { uint64_t valueSize) {
#ifndef PADDLE_MOBILE_INFERENCE
if (mat == nullptr) return kPD_NULLPTR; if (mat == nullptr) return kPD_NULLPTR;
auto ptr = cast(mat); auto ptr = cast(mat);
if (rowArray == nullptr || colArray == nullptr || if (rowArray == nullptr || colArray == nullptr ||
@ -160,4 +165,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
} else { } else {
return kPD_NOT_SUPPORTED; return kPD_NOT_SUPPORTED;
} }
#else
return kPD_NOT_SUPPORTED;
#endif
} }

@ -48,6 +48,7 @@ PD_API paddle_matrix paddle_matrix_create(uint64_t height,
* @param isBinary is binary (either 1 or 0 in matrix) or not. * @param isBinary is binary (either 1 or 0 in matrix) or not.
* @param useGpu is using GPU or not. * @param useGpu is using GPU or not.
* @return paddle_matrix. * @return paddle_matrix.
* @note Mobile inference does not support this interface.
*/ */
PD_API paddle_matrix paddle_matrix_create_sparse( PD_API paddle_matrix paddle_matrix_create_sparse(
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu); uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu);
@ -129,6 +130,7 @@ PD_API paddle_error paddle_matrix_get_shape(paddle_matrix mat,
* NULL if the matrix is binary. * NULL if the matrix is binary.
* @param [in] valueSize length of value array. Zero if the matrix is binary. * @param [in] valueSize length of value array. Zero if the matrix is binary.
* @return paddle_error * @return paddle_error
* @note Mobile inference does not support this interface.
*/ */
PD_API paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat, PD_API paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
int* rowArray, int* rowArray,

@ -27,7 +27,9 @@ if(WITH_GPU)
set_source_files_properties(${CUDA_CXX_SOURCES} set_source_files_properties(${CUDA_CXX_SOURCES}
PROPERTIES COMPILE_FLAGS "-D__NVCC__") PROPERTIES COMPILE_FLAGS "-D__NVCC__")
else() else()
if (NOT MOBILE_INFERENCE)
set(CUDA_CXX_SOURCES src/hl_warpctc_wrap.cc) set(CUDA_CXX_SOURCES src/hl_warpctc_wrap.cc)
endif()
endif() endif()
set(CUDA_CU_SOURCES set(CUDA_CU_SOURCES

@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h" #include "hl_base.h"
/** /**
* @brief Maximum pool forward. * @brief Maximum pool forward with Mask output.
* *
* @param[in] frameCnt batch size of input image. * @param[in] frameCnt batch size of input image.
* @param[in] inputData input data. * @param[in] inputData input data.
@ -35,7 +35,7 @@ limitations under the License. */
* @param[in] paddingW padding width. * @param[in] paddingW padding width.
* @param[out] tgtData output data. * @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples. * @param[in] tgtStride stride between output data samples.
* * @param[out] maskData the location indices of select max data.
*/ */
extern void hl_maxpool_forward(const int frameCnt, extern void hl_maxpool_forward(const int frameCnt,
const real* inputData, const real* inputData,
@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
const int tgtStride); const int tgtStride,
real* maskData = NULL);
/** /**
* @brief Maximum pool backward. * @brief Maximum pool backward.

@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
const int tgtStride) {} const int tgtStride,
real* MaskData) {}
inline void hl_maxpool_backward(const int frameCnt, inline void hl_maxpool_backward(const int frameCnt,
const real* inputData, const real* inputData,

@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
const int offsetH, const int offsetH,
const int offsetW, const int offsetW,
real* tgtData, real* tgtData,
const int tgtStride) { const int tgtStride,
real* maskData) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) { if (index < nthreads) {
int pw = index % pooledW; int pw = index % pooledW;
@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
real maxval = -FLT_MAX; real maxval = -FLT_MAX;
int max_index = -1;
inputData += (frameNum * channels + c) * height * width; inputData += (frameNum * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
if (maxval < inputData[h * width + w]) if (maxval < inputData[h * width + w]) {
maxval = inputData[h * width + w]; max_index = h * width + w;
maxval = inputData[max_index];
}
} }
} }
int tgtIndex = int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride; index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = maxval; tgtData[tgtIndex] = maxval;
if (maskData != NULL) {
maskData[tgtIndex] = max_index;
}
} }
} }
@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
const int paddingH, const int paddingH,
const int paddingW, const int paddingW,
real* tgtData, real* tgtData,
const int tgtStride) { const int tgtStride,
real* maskData) {
int num_kernels = pooledH * pooledW * channels * frameCnt; int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024; int blocks = (num_kernels + 1024 - 1) / 1024;
dim3 threads(1024, 1); dim3 threads(1024, 1);
@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
paddingH, paddingH,
paddingW, paddingW,
tgtData, tgtData,
tgtStride); tgtStride,
maskData);
CHECK_SYNC("hl_maxpool_forward failed"); CHECK_SYNC("hl_maxpool_forward failed");
} }

@ -377,6 +377,12 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; return grad_op_descs;
} }
static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx);
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx, ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars, std::unordered_set<std::string>* no_grad_vars,
@ -392,13 +398,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
if ((*it)->Type() == "recurrent") { if ((*it)->Type() == "recurrent") {
int step_block_idx = (*it)->GetBlockAttr("step_block"); int step_block_idx = (*it)->GetBlockAttr("step_block");
auto backward_block_op_descs = MakeBlockBackward( BlockDescBind* backward_block = CreateStepBlock(
program_desc, step_block_idx, no_grad_vars, grad_to_var); program_desc, no_grad_vars, grad_to_var, step_block_idx);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else if ((*it)->Type() == "conditional_block") {
BlockDescBind* backward_block = BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx)); CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
for (auto& ptr : backward_block_op_descs) { (*it)->GetBlockAttr("block"));
backward_block->AppendAllocatedOp(std::move(ptr));
}
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else { } else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
@ -449,6 +455,21 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs; return backward_descs;
} }
static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx) {
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
no_grad_vars, grad_to_var);
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(move(ptr));
}
return backward_block;
}
ParamGradInfoMap AppendBackward( ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target, ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {

@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) {
return VarDesc_VarType_LOD_RANK_TABLE; return VarDesc_VarType_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return VarDesc_VarType_LOD_TENSOR_ARRAY; return VarDesc_VarType_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return VarDesc_VarType_SELECTED_ROWS;
} else { } else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
} }
} }
template <typename Visitor>
inline void VisitVarType(const Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case VarDesc_VarType_LOD_TENSOR:
visitor(var.Get<framework::LoDTensor>());
return;
case VarDesc_VarType_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
return;
case VarDesc_VarType_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>());
return;
case VarDesc_VarType_SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -61,6 +61,7 @@ public:
// function arguments // function arguments
strides_ = config.get<std::vector<size_t>>("strides"); strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings"); paddings_ = config.get<std::vector<size_t>>("paddings");
dilations_ = config.get<std::vector<size_t>>("dilations");
groups_ = config.get<size_t>("groups"); groups_ = config.get<size_t>("groups");
// number of inputs and outputs // number of inputs and outputs
@ -118,6 +119,7 @@ protected:
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
std::vector<size_t> dilations_;
/// Group size, refer to grouped convolution in /// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the /// Alex Krizhevsky's paper: when group=2, the first half of the
@ -133,6 +135,10 @@ protected:
inline int paddingW() const { return paddings_[1]; } inline int paddingW() const { return paddings_[1]; }
inline int dilationH() const { return dilations_[0]; }
inline int dilationW() const { return dilations_[1]; }
// A temporary memory in convolution calculation. // A temporary memory in convolution calculation.
MemoryHandlePtr memory_; MemoryHandlePtr memory_;

@ -79,45 +79,59 @@ void Convolution(const std::string& conv1,
if (outputChannels < inputChannels) continue; if (outputChannels < inputChannels) continue;
for (size_t stride : {1, 2}) { for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) { for (size_t padding : {0, 1}) {
if (padding >= filterSize) break; for (size_t dilation : {1, 3}) {
if (padding >= filterSize) break;
size_t filterS = (filterSize - 1) * dilation + 1;
// NNPACK only supports stride = 1 if batchSize > 1 if (inputSize + 2 * padding < filterS) break;
if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") &&
batchSize > 1 && stride > 1)
break;
size_t outputSize = if ((conv1 == "NaiveConv-CPU" || conv2 == "NaiveConv-CPU" ||
(inputSize - filterSize + 2 * padding + stride) / stride; conv1 == "NNPACKConv-CPU" ||
VLOG(3) << " batchSize=" << batchSize conv2 == "NNPACKConv-CPU") &&
<< " inputChannels=" << inputChannels dilation > 1)
<< " inputHeight=" << inputSize break;
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize << " stride=" << stride
<< " padding=" << padding;
std::vector<size_t> paddings = {padding, padding}; // NNPACK only supports stride = 1 if batchSize > 1
std::vector<size_t> strides = {stride, stride}; if ((conv1 == "NNPACKConv-CPU" ||
Compare2Function<DType1, DType2> test( conv2 == "NNPACKConv-CPU") &&
conv1, batchSize > 1 && stride > 1)
conv2, break;
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
TensorShape input{ size_t outputSize =
batchSize, inputChannels, inputSize, inputSize}; (inputSize - filterS + 2 * padding + stride) / stride;
TensorShape filter{ VLOG(3) << " batchSize=" << batchSize
outputChannels, inputChannels, filterSize, filterSize}; << " inputChannels=" << inputChannels
TensorShape output{ << " inputHeight=" << inputSize
batchSize, outputChannels, outputSize, outputSize}; << " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
function(test, input, filter, output); std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {dilation, dilation};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
function(test, input, filter, output);
}
} }
} }
} }
@ -144,6 +158,7 @@ void Convolution2(const std::string& conv1,
for (size_t outputChannels : {7}) { for (size_t outputChannels : {7}) {
size_t stride = 1; size_t stride = 1;
size_t padding = 0; size_t padding = 0;
size_t dilation = 1;
size_t outputHeight = size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) / (inputHeight - filterHeight + 2 * padding + stride) /
stride; stride;
@ -162,6 +177,7 @@ void Convolution2(const std::string& conv1,
std::vector<size_t> paddings = {padding, padding}; std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride}; std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {dilation, dilation};
Compare2Function<DType1, DType2> test( Compare2Function<DType1, DType2> test(
conv1, conv1,
conv2, conv2,
@ -169,6 +185,7 @@ void Convolution2(const std::string& conv1,
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("groups", (size_t)1) .set("groups", (size_t)1)
.set("dilations", dilations)
.set("algo", (std::string) "auto")); .set("algo", (std::string) "auto"));
TensorShape input{ TensorShape input{
@ -223,6 +240,7 @@ void DepthwiseConvolution(const std::string& conv1,
std::vector<size_t> paddings = {padding, padding}; std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride}; std::vector<size_t> strides = {stride, stride};
std::vector<size_t> dilations = {1, 1};
size_t groups = inputChannels; size_t groups = inputChannels;
Compare2Function<DType1, DType2> test( Compare2Function<DType1, DType2> test(
conv1, conv1,
@ -231,6 +249,7 @@ void DepthwiseConvolution(const std::string& conv1,
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("groups", groups) .set("groups", groups)
.set("dilations", dilations)
.set("algo", (std::string) "auto")); .set("algo", (std::string) "auto"));
TensorShape input{ TensorShape input{

@ -100,7 +100,9 @@ public:
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW(),
dilationH(),
dilationW());
} else { } else {
colData = inputData + g * inputOffset; colData = inputData + g * inputOffset;
} }
@ -223,7 +225,9 @@ public:
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW(),
dilationH(),
dilationW());
} }
} }
inputGrad += inputChannels * inputHeight * inputWidth; inputGrad += inputChannels * inputHeight * inputWidth;
@ -310,7 +314,9 @@ public:
strideH(), strideH(),
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW(),
dilationH(),
dilationW());
} else { } else {
colData = inputData + g * inputOffset; colData = inputData + g * inputOffset;
} }

@ -78,7 +78,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth); int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
}; };
template <ColFormat Format, DeviceType Device, class T> template <ColFormat Format, DeviceType Device, class T>
@ -91,7 +93,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth); int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
}; };
} // namespace paddle } // namespace paddle

@ -31,7 +31,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -47,8 +49,8 @@ public:
int c_im = c / filterWidth / filterHeight; int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) { for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset; int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset; int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) < 0 || if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight || (imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 || (imColIdx - paddingWidth) < 0 ||
@ -81,7 +83,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -97,8 +101,8 @@ public:
int c_im = c / filterWidth / filterHeight; int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) { for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset; int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset; int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) >= 0 && if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight && (imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 && (imColIdx - paddingWidth) >= 0 &&
@ -134,7 +138,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -147,9 +153,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) { for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) { for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset = int imRowOffset = outputH * strideHeight +
outputH * strideHeight + filterH - paddingHeight; filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth; int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset = int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels + (((outputH * outputWidth + outputW) * inputChannels +
channel) * channel) *
@ -189,7 +196,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -202,9 +211,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) { for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) { for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset = int imRowOffset = outputH * strideHeight +
outputH * strideHeight + filterH - paddingHeight; filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth; int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset = int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels + (((outputH * outputWidth + outputW) * inputChannels +
channel) * channel) *

@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im,
int strideW, int strideW,
int paddingH, int paddingH,
int paddingW, int paddingW,
int dilationH,
int dilationW,
int height_col, int height_col,
int width_col, int width_col,
T* data_col) { T* data_col) {
@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im,
data_col += (channel_out * height_col + h_out) * width_col + w_out; data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (int i = 0; i < blockH; ++i) { for (int i = 0; i < blockH; ++i) {
for (int j = 0; j < blockW; ++j) { for (int j = 0; j < blockW; ++j) {
int rIdx = int(h_in + i); int rIdx = int(h_in + i * dilationH);
int cIdx = int(w_in + j); int cIdx = int(w_in + j * dilationW);
if ((rIdx - (int)paddingH) >= (int)height || if ((rIdx - (int)paddingH) >= (int)height ||
(rIdx - (int)paddingH) < 0 || (rIdx - (int)paddingH) < 0 ||
(cIdx - (int)paddingW) >= (int)width || (cIdx - (int)paddingW) >= (int)width ||
@ -77,7 +79,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -102,6 +106,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth, outputWidth,
colData); colData);
@ -121,6 +127,8 @@ __global__ void col2im(size_t n,
size_t strideW, size_t strideW,
size_t paddingH, size_t paddingH,
size_t paddingW, size_t paddingW,
size_t dilationH,
size_t dilationW,
size_t height_col, size_t height_col,
size_t width_col, size_t width_col,
T* data_im) { T* data_im) {
@ -131,23 +139,34 @@ __global__ void col2im(size_t n,
int w = int(index % width); int w = int(index % width);
int h = int((index / width) % height); int h = int((index / width) % height);
int c = int(index / (width * height)); int c = int(index / (width * height));
int filterH = (blockH - 1) * dilationH + 1;
int filterW = (blockW - 1) * dilationW + 1;
if ((w - (int)paddingW) >= 0 && if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width - 2 * paddingW) && (w - (int)paddingW) < (width - 2 * paddingW) &&
(h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) { (h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) {
// compute the start and end of the output // compute the start and end of the output
int w_col_start = int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; (w < (int)filterW) ? 0 : (w - int(filterW)) / (int)strideW + 1;
int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col)); int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start = int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; (h < (int)filterH) ? 0 : (h - (int)filterH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col)); int h_col_end = min(int(h / strideH + 1), int(height_col));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out] // the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH * blockW) + int h_k = (h - h_col * strideH);
(h - h_col * (int)strideH) * (int)blockW + int w_k = (w - w_col * strideW);
(w - w_col * (int)strideW); if (h_k % dilationH == 0 && w_k % dilationW == 0) {
val += data_col[(c_col * height_col + h_col) * width_col + w_col]; h_k /= dilationH;
w_k /= dilationW;
int c_col =
(((c * blockH + h_k) * blockW + w_k) * height_col + h_col) *
width_col +
w_col;
val += data_col[c_col];
}
} }
} }
h -= paddingH; h -= paddingH;
@ -173,7 +192,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -205,6 +226,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth, outputWidth,
imData); imData);
@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth, int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight, int outputHeight,
int outputWidth) { int outputWidth) {
int swId = blockIdx.x; int swId = blockIdx.x;
@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData,
channelId += blockDim.z) { channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth; int widthOffset =
int heightOffset = idy + shId * strideHeight - paddingHeight; idx * dilationHeight + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationWidth + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth + int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth; channelId * inputHeight * inputWidth;
@ -273,7 +300,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -312,6 +341,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth); outputWidth);
CHECK_SYNC("Im2ColFunctor GPU failed"); CHECK_SYNC("Im2ColFunctor GPU failed");
@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth, int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight, int outputHeight,
int outputWidth) { int outputWidth) {
int swId = blockIdx.x; int swId = blockIdx.x;
@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData,
channelId += blockDim.z) { channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth; int widthOffset =
int heightOffset = idy + shId * strideHeight - paddingHeight; idx * dilationWidth + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationHeight + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth + int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth; channelId * inputHeight * inputWidth;
@ -372,7 +407,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
@ -411,6 +448,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth); outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed"); CHECK_SYNC("Col2ImFunctor GPU failed");

@ -29,82 +29,98 @@ void TestIm2ColFunctor() {
for (size_t filterWidth : {3, 7}) { for (size_t filterWidth : {3, 7}) {
for (size_t stride : {1, 2}) { for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) { for (size_t padding : {0, 1}) {
if (inputHeight <= filterHeight || inputWidth <= filterWidth) for (size_t dilation : {1, 3}) {
break; size_t filterSizeH = (filterHeight - 1) * dilation + 1;
if (padding >= filterHeight || padding >= filterWidth) break; size_t filterSizeW = (filterWidth - 1) * dilation + 1;
size_t outputHeight = if (inputHeight + 2 * padding < filterSizeH ||
(inputHeight - filterHeight + 2 * padding + stride) / inputWidth + 2 * padding < filterSizeW)
stride; break;
size_t outputWidth = if (padding >= filterSizeH || padding >= filterSizeW) break;
(inputWidth - filterWidth + 2 * padding + stride) / stride; size_t outputHeight =
(inputHeight - filterSizeH + 2 * padding) / stride + 1;
TensorShape imShape = size_t outputWidth =
TensorShape({channels, inputHeight, inputWidth}); (inputWidth - filterSizeW + 2 * padding) / stride + 1;
TensorShape colShape1 = TensorShape({channels,
filterHeight, TensorShape imShape =
filterWidth, TensorShape({channels, inputHeight, inputWidth});
outputHeight, TensorShape colShape1 = TensorShape({channels,
outputWidth}); filterHeight,
TensorShape colShape2 = TensorShape({outputHeight, filterWidth,
outputWidth, outputHeight,
channels, outputWidth});
filterHeight, TensorShape colShape2 = TensorShape({outputHeight,
filterWidth}); outputWidth,
channels,
size_t height = channels * filterHeight * filterWidth; filterHeight,
size_t width = outputHeight * outputWidth; filterWidth});
VectorPtr input1 = Vector::create(imShape.getElements(), false);
VectorPtr input2 = Vector::create(imShape.getElements(), false); size_t height = channels * filterHeight * filterWidth;
MatrixPtr output1 = Matrix::create(height, width, false, false); size_t width = outputHeight * outputWidth;
MatrixPtr output2 = Matrix::create(width, height, false, false); VectorPtr input1 =
input1->uniform(0.001, 1); Vector::create(imShape.getElements(), false);
input2->copyFrom(*input1); VectorPtr input2 =
Vector::create(imShape.getElements(), false);
Im2ColFunctor<kCFO, Device, T> im2Col1; MatrixPtr output1 =
Im2ColFunctor<kOCF, Device, T> im2Col2; Matrix::create(height, width, false, false);
im2Col1(input1->getData(), MatrixPtr output2 =
imShape, Matrix::create(width, height, false, false);
output1->getData(), input1->uniform(0.001, 1);
colShape1, input2->copyFrom(*input1);
stride,
stride, Im2ColFunctor<kCFO, Device, T> im2Col1;
padding, Im2ColFunctor<kOCF, Device, T> im2Col2;
padding); im2Col1(input1->getData(),
im2Col2(input2->getData(), imShape,
imShape, output1->getData(),
output2->getData(), colShape1,
colShape2, stride,
stride, stride,
stride, padding,
padding, padding,
padding); dilation,
dilation);
// The transposition of the result of ColFormat == kCFO im2Col2(input2->getData(),
// is equal to the result of ColFormat == kOCF. imShape,
MatrixPtr test; output2->getData(),
output2->transpose(test, true); colShape2,
autotest::TensorCheckErr(*output1, *test); stride,
stride,
Col2ImFunctor<kCFO, Device, T> col2Im1; padding,
Col2ImFunctor<kOCF, Device, T> col2Im2; padding,
col2Im1(input1->getData(), dilation,
imShape, dilation);
output1->getData(),
colShape1, // The transposition of the result of ColFormat == kCFO
stride, // is equal to the result of ColFormat == kOCF.
stride, MatrixPtr test;
padding, output2->transpose(test, true);
padding); autotest::TensorCheckErr(*output1, *test);
col2Im2(input2->getData(),
imShape, Col2ImFunctor<kCFO, Device, T> col2Im1;
output2->getData(), Col2ImFunctor<kOCF, Device, T> col2Im2;
colShape2,
stride, col2Im1(input1->getData(),
stride, imShape,
padding, output1->getData(),
padding); colShape1,
stride,
autotest::TensorCheckErr(*input1, *input2); stride,
padding,
padding,
dilation,
dilation);
col2Im2(input2->getData(),
imShape,
output2->getData(),
colShape2,
stride,
stride,
padding,
padding,
dilation,
dilation);
autotest::TensorCheckErr(*input1, *input2);
}
} }
} }
} }

@ -85,9 +85,49 @@ if(MOBILE_INFERENCE)
gradientmachines/GradientMachineMode.cpp gradientmachines/GradientMachineMode.cpp
gradientmachines/MultiGradientMachine.cpp) gradientmachines/MultiGradientMachine.cpp)
# Remove useless layers # Remove layers that used in training
list(REMOVE_ITEM GSERVER_SOURCES list(REMOVE_ITEM GSERVER_SOURCES
layers/RecurrentLayerGroup.cpp) layers/RecurrentLayerGroup.cpp
layers/CostLayer.cpp
layers/MultiBoxLossLayer.cpp
layers/WarpCTCLayer.cpp
layers/CTCLayer.cpp
layers/LinearChainCTC.cpp
layers/PrintLayer.cpp)
list(REMOVE_ITEM GSERVER_SOURCES
layers/OuterProdLayer.cpp
layers/SumToOneNormLayer.cpp
layers/ConvShiftLayer.cpp
layers/InterpolationLayer.cpp
layers/AgentLayer.cpp
layers/DotMulOperator.cpp
layers/GruStepLayer.cpp
layers/LstmStepLayer.cpp
layers/ConvexCombinationLayer.cpp
layers/Conv3DLayer.cpp
layers/DeConv3DLayer.cpp
layers/CropLayer.cpp
layers/CrossEntropyOverBeam.cpp
layers/DataNormLayer.cpp
layers/FeatureMapExpandLayer.cpp
layers/HierarchicalSigmoidLayer.cpp
layers/MultinomialSampler.cpp
layers/NCELayer.cpp
layers/KmaxSeqScoreLayer.cpp
layers/MDLstmLayer.cpp
layers/MultiplexLayer.cpp
layers/PadLayer.cpp
layers/Pool3DLayer.cpp
layers/ResizeLayer.cpp
layers/RotateLayer.cpp
layers/RowConvLayer.cpp
layers/RowL2NormLayer.cpp
layers/SamplingIdLayer.cpp
layers/ScaleShiftLayer.cpp
layers/SelectiveFullyConnectedLayer.cpp
layers/SpatialPyramidPoolLayer.cpp
layers/BilinearInterpLayer.cpp
layers/ClipLayer.cpp)
endif() endif()
if(WITH_GPU) if(WITH_GPU)

@ -16,7 +16,6 @@ limitations under the License. */
#include "NeuralNetwork.h" #include "NeuralNetwork.h"
#include "hl_gpu.h" #include "hl_gpu.h"
#include "paddle/gserver/layers/AgentLayer.h"
#include "paddle/utils/CustomStackTrace.h" #include "paddle/utils/CustomStackTrace.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
@ -28,6 +27,7 @@ limitations under the License. */
#ifndef PADDLE_MOBILE_INFERENCE #ifndef PADDLE_MOBILE_INFERENCE
#include "MultiNetwork.h" #include "MultiNetwork.h"
#include "RecurrentGradientMachine.h" #include "RecurrentGradientMachine.h"
#include "paddle/gserver/layers/AgentLayer.h"
#endif #endif
namespace paddle { namespace paddle {
@ -192,9 +192,11 @@ void NeuralNetwork::init(const ModelConfig& config,
void NeuralNetwork::connect(LayerPtr agentLayer, void NeuralNetwork::connect(LayerPtr agentLayer,
LayerPtr realLayer, LayerPtr realLayer,
int height) { int height) {
#ifndef PADDLE_MOBILE_INFERENCE
AgentLayer* agent = dynamic_cast<AgentLayer*>(agentLayer.get()); AgentLayer* agent = dynamic_cast<AgentLayer*>(agentLayer.get());
CHECK_NOTNULL(agent); CHECK_NOTNULL(agent);
agent->setRealLayer(realLayer, height); agent->setRealLayer(realLayer, height);
#endif
} }
void NeuralNetwork::connect(std::string agentLayerName, void NeuralNetwork::connect(std::string agentLayerName,

@ -79,6 +79,10 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for (int i = 0; i < config_.inputs_size(); i++) { for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]}; std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
std::vector<size_t> dilations = {(size_t)dilationY_[i],
(size_t)dilation_[i]};
bool useDilation = ((size_t)dilationY_[i] > 1 || (size_t)dilation_[i] > 1);
// Convolution Layer uses the GemmConv function by default. // Convolution Layer uses the GemmConv function by default.
convType = "GemmConv"; convType = "GemmConv";
@ -97,13 +101,14 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
if ((filterSize_[i] == filterSizeY_[i]) && if ((filterSize_[i] == filterSizeY_[i]) &&
(filterSize_[i] == 3 || filterSize_[i] == 4) && (filterSize_[i] == 3 || filterSize_[i] == 4) &&
(stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2)) { (stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2) &&
!useDilation) {
convType = "NeonDepthwiseConv"; convType = "NeonDepthwiseConv";
} }
#endif #endif
} }
if (FLAGS_use_nnpack && !isDeconv_) { if (FLAGS_use_nnpack && !isDeconv_ && !useDilation) {
createFunction(forward_, createFunction(forward_,
"NNPACKConv", "NNPACKConv",
FuncConfig() FuncConfig()
@ -117,6 +122,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i])); .set("groups", (size_t)groups_[i]));
createFunction(backward_, createFunction(backward_,
@ -124,6 +130,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i])); .set("groups", (size_t)groups_[i]));
createFunction(backward_, createFunction(backward_,
@ -131,6 +138,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("dilations", dilations)
.set("groups", (size_t)groups_[i])); .set("groups", (size_t)groups_[i]));
} }
} }

@ -98,6 +98,7 @@ ClassRegistrar<Layer, LayerConfig> Layer::registrar_;
LayerPtr Layer::create(const LayerConfig& config) { LayerPtr Layer::create(const LayerConfig& config) {
std::string type = config.type(); std::string type = config.type();
#ifndef PADDLE_MOBILE_INFERENCE
// NOTE: As following types have illegal character '-', // NOTE: As following types have illegal character '-',
// they can not use REGISTER_LAYER to registrar. // they can not use REGISTER_LAYER to registrar.
// Besides, to fit with old training models, // Besides, to fit with old training models,
@ -106,7 +107,6 @@ LayerPtr Layer::create(const LayerConfig& config) {
return LayerPtr(new MultiClassCrossEntropy(config)); return LayerPtr(new MultiClassCrossEntropy(config));
else if (type == "rank-cost") else if (type == "rank-cost")
return LayerPtr(new RankingCost(config)); return LayerPtr(new RankingCost(config));
#ifndef PADDLE_MOBILE_INFERENCE
else if (type == "auc-validation") else if (type == "auc-validation")
return LayerPtr(new AucValidation(config)); return LayerPtr(new AucValidation(config));
else if (type == "pnpair-validation") else if (type == "pnpair-validation")

@ -0,0 +1,109 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MaxPoolWithMaskLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
PoolLayer::init(layerMap, parameterMap);
setOutput("mask", &mask_);
return true;
}
size_t MaxPoolWithMaskLayer::getSize() {
CHECK_EQ(inputLayers_.size(), 1UL);
size_t layerSize = 0;
outputY_ = outputSize(imgSizeY_,
sizeY_,
confPaddingY_,
strideY_,
/* caffeMode */ false);
outputX_ = outputSize(imgSize_,
sizeX_,
confPadding_,
stride_,
/* caffeMode */ false);
layerSize = outputX_ * outputY_ * channels_;
getOutput().setFrameHeight(outputY_);
getOutput().setFrameWidth(outputX_);
return layerSize;
}
void MaxPoolWithMaskLayer::forward(PassType passType) {
size_t size = getSize();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
int batchSize = inputV->getHeight();
resetOutput(batchSize, size);
MatrixPtr outV = getOutputValue();
CHECK_EQ(size, outV->getWidth());
resetSpecifyOutput(mask_,
batchSize,
size,
/* isValueClean */ false,
/* isGradClean */ true);
MatrixPtr maskV = mask_.value;
outV->maxPoolForward(*inputV,
imgSizeY_,
imgSize_,
channels_,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
confPaddingY_,
confPadding_,
maskV);
}
void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) {
(void)callback;
if (NULL == getInputGrad(0)) {
return;
}
MatrixPtr outGrad = getOutputGrad();
MatrixPtr inputV = inputLayers_[0]->getOutputValue();
MatrixPtr outV = getOutputValue();
MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad();
inputGrad->maxPoolBackward(*inputV,
imgSizeY_,
imgSize_,
*outGrad,
*outV,
sizeX_,
sizeY_,
strideY_,
stride_,
outputY_,
outputX_,
1,
1,
confPaddingY_,
confPadding_);
}
} // namespace paddle

@ -0,0 +1,40 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "PoolLayer.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief Basic parent layer of different kinds of pooling
*/
class MaxPoolWithMaskLayer : public PoolLayer {
protected:
Argument mask_;
public:
explicit MaxPoolWithMaskLayer(const LayerConfig& config)
: PoolLayer(config) {}
size_t getSize();
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
};
} // namespace paddle

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "PoolLayer.h" #include "PoolLayer.h"
#include "MaxPoolWithMaskLayer.h"
#include "PoolProjectionLayer.h" #include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
return true; return true;
} }
@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) {
} else if (CudnnPoolLayer::typeCheck(pool)) { } else if (CudnnPoolLayer::typeCheck(pool)) {
return new CudnnPoolLayer(config); return new CudnnPoolLayer(config);
#endif #endif
} else if (pool == "max-pool-with-mask") {
return new MaxPoolWithMaskLayer(config);
} else { } else {
LOG(FATAL) << "Unknown pool type: " << pool; LOG(FATAL) << "Unknown pool type: " << pool;
return nullptr; return nullptr;

@ -1,9 +1,12 @@
# gserver pacakge unittests # gserver pacakge unittests
add_simple_unittest(test_LinearChainCRF) add_simple_unittest(test_LinearChainCRF)
add_simple_unittest(test_MultinomialSampler)
add_simple_unittest(test_RecurrentLayer) add_simple_unittest(test_RecurrentLayer)
if(NOT MOBILE_INFERENCE)
add_simple_unittest(test_MultinomialSampler)
endif()
function(gserver_test TARGET) function(gserver_test TARGET)
add_unittest_without_exec(${TARGET} add_unittest_without_exec(${TARGET}
${TARGET}.cpp ${TARGET}.cpp
@ -24,6 +27,7 @@ gserver_test(test_ConvUnify)
gserver_test(test_BatchNorm) gserver_test(test_BatchNorm)
gserver_test(test_KmaxSeqScore) gserver_test(test_KmaxSeqScore)
gserver_test(test_Expand) gserver_test(test_Expand)
gserver_test(test_MaxPoolingWithMaskOutput)
########## test_Mkldnn layers and activations ########## ########## test_Mkldnn layers and activations ##########
if(WITH_MKLDNN) if(WITH_MKLDNN)
@ -48,7 +52,7 @@ if(WITH_PYTHON)
endif() endif()
############### test_WarpCTCLayer ####################### ############### test_WarpCTCLayer #######################
if(NOT WITH_DOUBLE) if(NOT WITH_DOUBLE AND NOT MOBILE_INFERENCE)
add_unittest_without_exec(test_WarpCTCLayer add_unittest_without_exec(test_WarpCTCLayer
test_WarpCTCLayer.cpp) test_WarpCTCLayer.cpp)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save