diff --git a/.gitignore b/.gitignore index 1512c1438e..020d3f0c30 100644 --- a/.gitignore +++ b/.gitignore @@ -21,11 +21,10 @@ third_party/ cmake-build-* # generated while compiling -python/paddle/v2/framework/core.so +python/paddle/v2/fluid/core.so paddle/pybind/pybind.h CMakeFiles cmake_install.cmake paddle/.timestamp python/paddlepaddle.egg-info/ paddle/pybind/pybind.h -python/paddle/v2/framework/tests/tmp/* diff --git a/doc/design/evaluator.md b/doc/design/evaluator.md new file mode 100644 index 0000000000..a62d75ffef --- /dev/null +++ b/doc/design/evaluator.md @@ -0,0 +1,58 @@ +## Evaluator Design + +### The Problem + +During training or serving, we provide the evaluation function to measure the model performance, e.g., accuracy, precision. In the operator based framework design, the data go through the network pipeline batch by batch. As a result, inside the operator, we only can calculate one minibatch metrics. We need to provide a mechanism to calculate the metrics for each N pass/batch the user wanted. + +### Evaluator Design +Currently, every operation is expressed in the graph. we divide the evaluator process into three steps. + +1. Initialize the metric state and add it into the block. + +2. Calculate the statistic of the metric state in every mini-batch. The single operator is only responsible for calculating necessary statistics for one mini-batch. For example, accuracy operator only calculate a minibatch data if run once. + + +3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the value from different devices. + +### Implementation +This design is shown in python API. +Each metric operator need to caculate the metric statistic and return the batch aware states, Python side responsible for accumulate the states for each pass. + + +```python +class Evaluator(object): + """ + Evaluator Base class. + """ + def __init__(self, name, **kwargs): + """ + Different evaluator may has different metric states. E.g, Accuracy need two variables, total and right sample counts. + Auc need four variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append to main_program + + The initialization of Evaluator should be responsible for: + create metric states and append to the main_program + """ + pass + + def _update_ops(self, input, label, **kwargs) + """ + Add mini-batch evaluator caculate operators to the main_program. + Add increment operator to accumulate the metric states. + """ + + + def reset(self, executor, reset_program=None): + """ + Reset metric states at the begin of each pass/user specified batch number. + Execute the reset_program to reset the states. + """ + + + def eval(self, executor, eval_program=None): + """ + Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. + Execute the eval_program and return the result. + """ + return eval_result +``` diff --git a/paddle/capi/Matrix.cpp b/paddle/capi/Matrix.cpp index 53a36f8f20..d5b55e1c95 100644 --- a/paddle/capi/Matrix.cpp +++ b/paddle/capi/Matrix.cpp @@ -121,6 +121,7 @@ paddle_error paddle_matrix_get_shape(paddle_matrix mat, paddle_matrix paddle_matrix_create_sparse( uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu) { +#ifndef PADDLE_MOBILE_INFERENCE auto ptr = new paddle::capi::CMatrix(); ptr->mat = paddle::Matrix::createSparseMatrix( height, @@ -131,6 +132,9 @@ paddle_matrix paddle_matrix_create_sparse( false, useGpu); return ptr; +#else + return nullptr; +#endif } paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat, @@ -140,6 +144,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat, uint64_t colSize, float* valueArray, uint64_t valueSize) { +#ifndef PADDLE_MOBILE_INFERENCE if (mat == nullptr) return kPD_NULLPTR; auto ptr = cast(mat); if (rowArray == nullptr || colArray == nullptr || @@ -160,4 +165,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat, } else { return kPD_NOT_SUPPORTED; } +#else + return kPD_NOT_SUPPORTED; +#endif } diff --git a/paddle/capi/matrix.h b/paddle/capi/matrix.h index bb5223f8a2..01b8bad2ee 100644 --- a/paddle/capi/matrix.h +++ b/paddle/capi/matrix.h @@ -48,6 +48,7 @@ PD_API paddle_matrix paddle_matrix_create(uint64_t height, * @param isBinary is binary (either 1 or 0 in matrix) or not. * @param useGpu is using GPU or not. * @return paddle_matrix. + * @note Mobile inference does not support this interface. */ PD_API paddle_matrix paddle_matrix_create_sparse( uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu); @@ -129,6 +130,7 @@ PD_API paddle_error paddle_matrix_get_shape(paddle_matrix mat, * NULL if the matrix is binary. * @param [in] valueSize length of value array. Zero if the matrix is binary. * @return paddle_error + * @note Mobile inference does not support this interface. */ PD_API paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat, int* rowArray, diff --git a/paddle/cuda/CMakeLists.txt b/paddle/cuda/CMakeLists.txt index 0865b02c4f..efd1b7a73e 100755 --- a/paddle/cuda/CMakeLists.txt +++ b/paddle/cuda/CMakeLists.txt @@ -27,7 +27,9 @@ if(WITH_GPU) set_source_files_properties(${CUDA_CXX_SOURCES} PROPERTIES COMPILE_FLAGS "-D__NVCC__") else() + if (NOT MOBILE_INFERENCE) set(CUDA_CXX_SOURCES src/hl_warpctc_wrap.cc) + endif() endif() set(CUDA_CU_SOURCES diff --git a/paddle/cuda/include/hl_cnn.h b/paddle/cuda/include/hl_cnn.h index 6b56d9ec8d..89c1f48eda 100644 --- a/paddle/cuda/include/hl_cnn.h +++ b/paddle/cuda/include/hl_cnn.h @@ -18,7 +18,7 @@ limitations under the License. */ #include "hl_base.h" /** - * @brief Maximum pool forward. + * @brief Maximum pool forward with Mask output. * * @param[in] frameCnt batch size of input image. * @param[in] inputData input data. @@ -35,7 +35,7 @@ limitations under the License. */ * @param[in] paddingW padding width. * @param[out] tgtData output data. * @param[in] tgtStride stride between output data samples. - * + * @param[out] maskData the location indices of select max data. */ extern void hl_maxpool_forward(const int frameCnt, const real* inputData, @@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt, const int paddingH, const int paddingW, real* tgtData, - const int tgtStride); + const int tgtStride, + real* maskData = NULL); /** * @brief Maximum pool backward. diff --git a/paddle/cuda/include/stub/hl_cnn_stub.h b/paddle/cuda/include/stub/hl_cnn_stub.h index a76dbf0b65..968ed4840f 100644 --- a/paddle/cuda/include/stub/hl_cnn_stub.h +++ b/paddle/cuda/include/stub/hl_cnn_stub.h @@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt, const int paddingH, const int paddingW, real* tgtData, - const int tgtStride) {} + const int tgtStride, + real* MaskData) {} inline void hl_maxpool_backward(const int frameCnt, const real* inputData, diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 58674febdc..3699b1e8ae 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads, const int offsetH, const int offsetW, real* tgtData, - const int tgtStride) { + const int tgtStride, + real* maskData) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < nthreads) { int pw = index % pooledW; @@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads, hstart = max(hstart, 0); wstart = max(wstart, 0); real maxval = -FLT_MAX; + int max_index = -1; inputData += (frameNum * channels + c) * height * width; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - if (maxval < inputData[h * width + w]) - maxval = inputData[h * width + w]; + if (maxval < inputData[h * width + w]) { + max_index = h * width + w; + maxval = inputData[max_index]; + } } } int tgtIndex = index % (pooledW * pooledH * channels) + frameNum * tgtStride; tgtData[tgtIndex] = maxval; + if (maskData != NULL) { + maskData[tgtIndex] = max_index; + } } } @@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt, const int paddingH, const int paddingW, real* tgtData, - const int tgtStride) { + const int tgtStride, + real* maskData) { int num_kernels = pooledH * pooledW * channels * frameCnt; int blocks = (num_kernels + 1024 - 1) / 1024; dim3 threads(1024, 1); @@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt, paddingH, paddingW, tgtData, - tgtStride); + tgtStride, + maskData); CHECK_SYNC("hl_maxpool_forward failed"); } diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 1afc524208..c08e844847 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -38,9 +38,9 @@ py_proto_compile(framework_py_proto SRCS framework.proto) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) add_custom_command(TARGET framework_py_proto POST_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto - COMMAND cp *.py ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto/ - COMMENT "Copy generated python proto into directory paddle/v2/framework/proto." + COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/proto + COMMAND cp *.py ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/proto/ + COMMENT "Copy generated python proto into directory paddle/v2/fluid/proto." WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) cc_library(backward SRCS backward.cc DEPS net_op) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 913cd0f81e..b3b9c45ded 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -377,6 +377,12 @@ std::vector> MakeOpGrad( return grad_op_descs; } +static BlockDescBind* CreateStepBlock( + ProgramDescBind& program_desc, + std::unordered_set* no_grad_vars, + std::unordered_map* grad_to_var, + int step_block_idx); + std::vector> MakeBlockBackward( ProgramDescBind& program_desc, int block_idx, std::unordered_set* no_grad_vars, @@ -392,13 +398,13 @@ std::vector> MakeBlockBackward( if ((*it)->Type() == "recurrent") { int step_block_idx = (*it)->GetBlockAttr("step_block"); - auto backward_block_op_descs = MakeBlockBackward( - program_desc, step_block_idx, no_grad_vars, grad_to_var); + BlockDescBind* backward_block = CreateStepBlock( + program_desc, no_grad_vars, grad_to_var, step_block_idx); + op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); + } else if ((*it)->Type() == "conditional_block") { BlockDescBind* backward_block = - program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx)); - for (auto& ptr : backward_block_op_descs) { - backward_block->AppendAllocatedOp(std::move(ptr)); - } + CreateStepBlock(program_desc, no_grad_vars, grad_to_var, + (*it)->GetBlockAttr("block")); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); } else { op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var); @@ -449,6 +455,21 @@ std::vector> MakeBlockBackward( return backward_descs; } +static BlockDescBind* CreateStepBlock( + ProgramDescBind& program_desc, + std::unordered_set* no_grad_vars, + std::unordered_map* grad_to_var, + int step_block_idx) { + auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx, + no_grad_vars, grad_to_var); + BlockDescBind* backward_block = + program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx)); + for (auto& ptr : backward_block_op_descs) { + backward_block->AppendAllocatedOp(move(ptr)); + } + return backward_block; +} + ParamGradInfoMap AppendBackward( ProgramDescBind& program_desc, const VarDescBind& target, const std::unordered_set& no_grad_vars) { diff --git a/paddle/framework/var_type.h b/paddle/framework/var_type.h index d060196bb2..0f19870bec 100644 --- a/paddle/framework/var_type.h +++ b/paddle/framework/var_type.h @@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) { return VarDesc_VarType_LOD_RANK_TABLE; } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { return VarDesc_VarType_LOD_TENSOR_ARRAY; + } else if (type.hash_code() == typeid(SelectedRows).hash_code()) { + return VarDesc_VarType_SELECTED_ROWS; } else { PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); } } +template +inline void VisitVarType(const Variable& var, Visitor visitor) { + switch (ToVarType(var.Type())) { + case VarDesc_VarType_LOD_TENSOR: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_RANK_TABLE: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_TENSOR_ARRAY: + visitor(var.Get()); + return; + case VarDesc_VarType_SELECTED_ROWS: + visitor(var.Get()); + return; + default: + PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type())); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index baf78bc6c8..062ea25a11 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -61,6 +61,7 @@ public: // function arguments strides_ = config.get>("strides"); paddings_ = config.get>("paddings"); + dilations_ = config.get>("dilations"); groups_ = config.get("groups"); // number of inputs and outputs @@ -118,6 +119,7 @@ protected: std::vector strides_; std::vector paddings_; + std::vector dilations_; /// Group size, refer to grouped convolution in /// Alex Krizhevsky's paper: when group=2, the first half of the @@ -133,6 +135,10 @@ protected: inline int paddingW() const { return paddings_[1]; } + inline int dilationH() const { return dilations_[0]; } + + inline int dilationW() const { return dilations_[1]; } + // A temporary memory in convolution calculation. MemoryHandlePtr memory_; diff --git a/paddle/function/ConvOpTest.h b/paddle/function/ConvOpTest.h index cb02a96d0d..d8d3c792df 100644 --- a/paddle/function/ConvOpTest.h +++ b/paddle/function/ConvOpTest.h @@ -79,45 +79,59 @@ void Convolution(const std::string& conv1, if (outputChannels < inputChannels) continue; for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { - if (padding >= filterSize) break; + for (size_t dilation : {1, 3}) { + if (padding >= filterSize) break; + size_t filterS = (filterSize - 1) * dilation + 1; - // NNPACK only supports stride = 1 if batchSize > 1 - if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") && - batchSize > 1 && stride > 1) - break; + if (inputSize + 2 * padding < filterS) break; - size_t outputSize = - (inputSize - filterSize + 2 * padding + stride) / stride; - VLOG(3) << " batchSize=" << batchSize - << " inputChannels=" << inputChannels - << " inputHeight=" << inputSize - << " inputWidth=" << inputSize - << " outputChannels=" << outputChannels - << " filterHeight=" << filterSize - << " filterWidth=" << filterSize - << " outputHeight=" << outputSize - << " outputWidth=" << outputSize << " stride=" << stride - << " padding=" << padding; + if ((conv1 == "NaiveConv-CPU" || conv2 == "NaiveConv-CPU" || + conv1 == "NNPACKConv-CPU" || + conv2 == "NNPACKConv-CPU") && + dilation > 1) + break; - std::vector paddings = {padding, padding}; - std::vector strides = {stride, stride}; - Compare2Function test( - conv1, - conv2, - FuncConfig() - .set("paddings", paddings) - .set("strides", strides) - .set("groups", (size_t)1) - .set("algo", (std::string) "auto")); + // NNPACK only supports stride = 1 if batchSize > 1 + if ((conv1 == "NNPACKConv-CPU" || + conv2 == "NNPACKConv-CPU") && + batchSize > 1 && stride > 1) + break; - TensorShape input{ - batchSize, inputChannels, inputSize, inputSize}; - TensorShape filter{ - outputChannels, inputChannels, filterSize, filterSize}; - TensorShape output{ - batchSize, outputChannels, outputSize, outputSize}; + size_t outputSize = + (inputSize - filterS + 2 * padding + stride) / stride; + VLOG(3) << " batchSize=" << batchSize + << " inputChannels=" << inputChannels + << " inputHeight=" << inputSize + << " inputWidth=" << inputSize + << " outputChannels=" << outputChannels + << " filterHeight=" << filterSize + << " filterWidth=" << filterSize + << " outputHeight=" << outputSize + << " outputWidth=" << outputSize + << " stride=" << stride << " padding=" << padding; - function(test, input, filter, output); + std::vector paddings = {padding, padding}; + std::vector strides = {stride, stride}; + std::vector dilations = {dilation, dilation}; + Compare2Function test( + conv1, + conv2, + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("dilations", dilations) + .set("groups", (size_t)1) + .set("algo", (std::string) "auto")); + + TensorShape input{ + batchSize, inputChannels, inputSize, inputSize}; + TensorShape filter{ + outputChannels, inputChannels, filterSize, filterSize}; + TensorShape output{ + batchSize, outputChannels, outputSize, outputSize}; + + function(test, input, filter, output); + } } } } @@ -144,6 +158,7 @@ void Convolution2(const std::string& conv1, for (size_t outputChannels : {7}) { size_t stride = 1; size_t padding = 0; + size_t dilation = 1; size_t outputHeight = (inputHeight - filterHeight + 2 * padding + stride) / stride; @@ -162,6 +177,7 @@ void Convolution2(const std::string& conv1, std::vector paddings = {padding, padding}; std::vector strides = {stride, stride}; + std::vector dilations = {dilation, dilation}; Compare2Function test( conv1, conv2, @@ -169,6 +185,7 @@ void Convolution2(const std::string& conv1, .set("paddings", paddings) .set("strides", strides) .set("groups", (size_t)1) + .set("dilations", dilations) .set("algo", (std::string) "auto")); TensorShape input{ @@ -223,6 +240,7 @@ void DepthwiseConvolution(const std::string& conv1, std::vector paddings = {padding, padding}; std::vector strides = {stride, stride}; + std::vector dilations = {1, 1}; size_t groups = inputChannels; Compare2Function test( conv1, @@ -231,6 +249,7 @@ void DepthwiseConvolution(const std::string& conv1, .set("paddings", paddings) .set("strides", strides) .set("groups", groups) + .set("dilations", dilations) .set("algo", (std::string) "auto")); TensorShape input{ diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index bdb56ddac3..8d34eee886 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -100,7 +100,9 @@ public: strideH(), strideW(), paddingH(), - paddingW()); + paddingW(), + dilationH(), + dilationW()); } else { colData = inputData + g * inputOffset; } @@ -223,7 +225,9 @@ public: strideH(), strideW(), paddingH(), - paddingW()); + paddingW(), + dilationH(), + dilationW()); } } inputGrad += inputChannels * inputHeight * inputWidth; @@ -310,7 +314,9 @@ public: strideH(), strideW(), paddingH(), - paddingW()); + paddingW(), + dilationH(), + dilationW()); } else { colData = inputData + g * inputOffset; } diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 1e0cff436f..0c37fc9724 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -78,7 +78,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth); + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1); }; template @@ -91,7 +93,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth); + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1); }; } // namespace paddle diff --git a/paddle/function/Im2ColOp.cpp b/paddle/function/Im2ColOp.cpp index b7d1eb1ede..f864d42f80 100644 --- a/paddle/function/Im2ColOp.cpp +++ b/paddle/function/Im2ColOp.cpp @@ -31,7 +31,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -47,8 +49,8 @@ public: int c_im = c / filterWidth / filterHeight; for (int h = 0; h < outputHeight; ++h) { for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; if ((imRowIdx - paddingHeight) < 0 || (imRowIdx - paddingHeight) >= inputHeight || (imColIdx - paddingWidth) < 0 || @@ -81,7 +83,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -97,8 +101,8 @@ public: int c_im = c / filterWidth / filterHeight; for (int h = 0; h < outputHeight; ++h) { for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; if ((imRowIdx - paddingHeight) >= 0 && (imRowIdx - paddingHeight) < inputHeight && (imColIdx - paddingWidth) >= 0 && @@ -134,7 +138,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -147,9 +153,10 @@ public: for (int channel = 0; channel < inputChannels; ++channel) { for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int imRowOffset = outputH * strideHeight + + filterH * dilationHeight - paddingHeight; + int imColOffset = outputW * strideWidth + + filterW * dilationWidth - paddingWidth; int colDataOffset = (((outputH * outputWidth + outputW) * inputChannels + channel) * @@ -189,7 +196,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -202,9 +211,10 @@ public: for (int channel = 0; channel < inputChannels; ++channel) { for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int imRowOffset = outputH * strideHeight + + filterH * dilationHeight - paddingHeight; + int imColOffset = outputW * strideWidth + + filterW * dilationWidth - paddingWidth; int colDataOffset = (((outputH * outputWidth + outputW) * inputChannels + channel) * diff --git a/paddle/function/Im2ColOpGpu.cu b/paddle/function/Im2ColOpGpu.cu index bd98610498..71da11b955 100644 --- a/paddle/function/Im2ColOpGpu.cu +++ b/paddle/function/Im2ColOpGpu.cu @@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im, int strideW, int paddingH, int paddingW, + int dilationH, + int dilationW, int height_col, int width_col, T* data_col) { @@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im, data_col += (channel_out * height_col + h_out) * width_col + w_out; for (int i = 0; i < blockH; ++i) { for (int j = 0; j < blockW; ++j) { - int rIdx = int(h_in + i); - int cIdx = int(w_in + j); + int rIdx = int(h_in + i * dilationH); + int cIdx = int(w_in + j * dilationW); if ((rIdx - (int)paddingH) >= (int)height || (rIdx - (int)paddingH) < 0 || (cIdx - (int)paddingW) >= (int)width || @@ -77,7 +79,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -102,6 +106,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth, colData); @@ -121,6 +127,8 @@ __global__ void col2im(size_t n, size_t strideW, size_t paddingH, size_t paddingW, + size_t dilationH, + size_t dilationW, size_t height_col, size_t width_col, T* data_im) { @@ -131,23 +139,34 @@ __global__ void col2im(size_t n, int w = int(index % width); int h = int((index / width) % height); int c = int(index / (width * height)); + int filterH = (blockH - 1) * dilationH + 1; + int filterW = (blockW - 1) * dilationW + 1; + if ((w - (int)paddingW) >= 0 && (w - (int)paddingW) < (width - 2 * paddingW) && (h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) { // compute the start and end of the output int w_col_start = - (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; + (w < (int)filterW) ? 0 : (w - int(filterW)) / (int)strideW + 1; int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col)); int h_col_start = - (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; + (h < (int)filterH) ? 0 : (h - (int)filterH) / (int)strideH + 1; int h_col_end = min(int(h / strideH + 1), int(height_col)); + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { // the col location: [c * width * height + h_out, w_out] - int c_col = int(c * blockH * blockW) + - (h - h_col * (int)strideH) * (int)blockW + - (w - w_col * (int)strideW); - val += data_col[(c_col * height_col + h_col) * width_col + w_col]; + int h_k = (h - h_col * strideH); + int w_k = (w - w_col * strideW); + if (h_k % dilationH == 0 && w_k % dilationW == 0) { + h_k /= dilationH; + w_k /= dilationW; + int c_col = + (((c * blockH + h_k) * blockW + w_k) * height_col + h_col) * + width_col + + w_col; + val += data_col[c_col]; + } } } h -= paddingH; @@ -173,7 +192,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -205,6 +226,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth, imData); @@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData, int strideWidth, int paddingHeight, int paddingWidth, + int dilationHeight, + int dilationWidth, int outputHeight, int outputWidth) { int swId = blockIdx.x; @@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData, channelId += blockDim.z) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { - int widthOffset = idx + swId * strideWidth - paddingWidth; - int heightOffset = idy + shId * strideHeight - paddingHeight; + int widthOffset = + idx * dilationHeight + swId * strideWidth - paddingWidth; + int heightOffset = + idy * dilationWidth + shId * strideHeight - paddingHeight; int imOffset = widthOffset + heightOffset * inputWidth + channelId * inputHeight * inputWidth; @@ -273,7 +300,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -312,6 +341,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth); CHECK_SYNC("Im2ColFunctor GPU failed"); @@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData, int strideWidth, int paddingHeight, int paddingWidth, + int dilationHeight, + int dilationWidth, int outputHeight, int outputWidth) { int swId = blockIdx.x; @@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData, channelId += blockDim.z) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { - int widthOffset = idx + swId * strideWidth - paddingWidth; - int heightOffset = idy + shId * strideHeight - paddingHeight; + int widthOffset = + idx * dilationWidth + swId * strideWidth - paddingWidth; + int heightOffset = + idy * dilationHeight + shId * strideHeight - paddingHeight; int imOffset = widthOffset + heightOffset * inputWidth + channelId * inputHeight * inputWidth; @@ -372,7 +407,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -411,6 +448,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth); CHECK_SYNC("Col2ImFunctor GPU failed"); diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index a0a01a5fc7..1f085538d8 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -29,82 +29,98 @@ void TestIm2ColFunctor() { for (size_t filterWidth : {3, 7}) { for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { - if (inputHeight <= filterHeight || inputWidth <= filterWidth) - break; - if (padding >= filterHeight || padding >= filterWidth) break; - size_t outputHeight = - (inputHeight - filterHeight + 2 * padding + stride) / - stride; - size_t outputWidth = - (inputWidth - filterWidth + 2 * padding + stride) / stride; - - TensorShape imShape = - TensorShape({channels, inputHeight, inputWidth}); - TensorShape colShape1 = TensorShape({channels, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - TensorShape colShape2 = TensorShape({outputHeight, - outputWidth, - channels, - filterHeight, - filterWidth}); - - size_t height = channels * filterHeight * filterWidth; - size_t width = outputHeight * outputWidth; - VectorPtr input1 = Vector::create(imShape.getElements(), false); - VectorPtr input2 = Vector::create(imShape.getElements(), false); - MatrixPtr output1 = Matrix::create(height, width, false, false); - MatrixPtr output2 = Matrix::create(width, height, false, false); - input1->uniform(0.001, 1); - input2->copyFrom(*input1); - - Im2ColFunctor im2Col1; - Im2ColFunctor im2Col2; - im2Col1(input1->getData(), - imShape, - output1->getData(), - colShape1, - stride, - stride, - padding, - padding); - im2Col2(input2->getData(), - imShape, - output2->getData(), - colShape2, - stride, - stride, - padding, - padding); - - // The transposition of the result of ColFormat == kCFO - // is equal to the result of ColFormat == kOCF. - MatrixPtr test; - output2->transpose(test, true); - autotest::TensorCheckErr(*output1, *test); - - Col2ImFunctor col2Im1; - Col2ImFunctor col2Im2; - col2Im1(input1->getData(), - imShape, - output1->getData(), - colShape1, - stride, - stride, - padding, - padding); - col2Im2(input2->getData(), - imShape, - output2->getData(), - colShape2, - stride, - stride, - padding, - padding); - - autotest::TensorCheckErr(*input1, *input2); + for (size_t dilation : {1, 3}) { + size_t filterSizeH = (filterHeight - 1) * dilation + 1; + size_t filterSizeW = (filterWidth - 1) * dilation + 1; + if (inputHeight + 2 * padding < filterSizeH || + inputWidth + 2 * padding < filterSizeW) + break; + if (padding >= filterSizeH || padding >= filterSizeW) break; + size_t outputHeight = + (inputHeight - filterSizeH + 2 * padding) / stride + 1; + size_t outputWidth = + (inputWidth - filterSizeW + 2 * padding) / stride + 1; + + TensorShape imShape = + TensorShape({channels, inputHeight, inputWidth}); + TensorShape colShape1 = TensorShape({channels, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + TensorShape colShape2 = TensorShape({outputHeight, + outputWidth, + channels, + filterHeight, + filterWidth}); + + size_t height = channels * filterHeight * filterWidth; + size_t width = outputHeight * outputWidth; + VectorPtr input1 = + Vector::create(imShape.getElements(), false); + VectorPtr input2 = + Vector::create(imShape.getElements(), false); + MatrixPtr output1 = + Matrix::create(height, width, false, false); + MatrixPtr output2 = + Matrix::create(width, height, false, false); + input1->uniform(0.001, 1); + input2->copyFrom(*input1); + + Im2ColFunctor im2Col1; + Im2ColFunctor im2Col2; + im2Col1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation); + im2Col2(input2->getData(), + imShape, + output2->getData(), + colShape2, + stride, + stride, + padding, + padding, + dilation, + dilation); + + // The transposition of the result of ColFormat == kCFO + // is equal to the result of ColFormat == kOCF. + MatrixPtr test; + output2->transpose(test, true); + autotest::TensorCheckErr(*output1, *test); + + Col2ImFunctor col2Im1; + Col2ImFunctor col2Im2; + + col2Im1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation); + col2Im2(input2->getData(), + imShape, + output2->getData(), + colShape2, + stride, + stride, + padding, + padding, + dilation, + dilation); + autotest::TensorCheckErr(*input1, *input2); + } } } } diff --git a/paddle/gserver/CMakeLists.txt b/paddle/gserver/CMakeLists.txt index 5f39167afc..91d732641a 100644 --- a/paddle/gserver/CMakeLists.txt +++ b/paddle/gserver/CMakeLists.txt @@ -85,9 +85,49 @@ if(MOBILE_INFERENCE) gradientmachines/GradientMachineMode.cpp gradientmachines/MultiGradientMachine.cpp) - # Remove useless layers + # Remove layers that used in training list(REMOVE_ITEM GSERVER_SOURCES - layers/RecurrentLayerGroup.cpp) + layers/RecurrentLayerGroup.cpp + layers/CostLayer.cpp + layers/MultiBoxLossLayer.cpp + layers/WarpCTCLayer.cpp + layers/CTCLayer.cpp + layers/LinearChainCTC.cpp + layers/PrintLayer.cpp) + list(REMOVE_ITEM GSERVER_SOURCES + layers/OuterProdLayer.cpp + layers/SumToOneNormLayer.cpp + layers/ConvShiftLayer.cpp + layers/InterpolationLayer.cpp + layers/AgentLayer.cpp + layers/DotMulOperator.cpp + layers/GruStepLayer.cpp + layers/LstmStepLayer.cpp + layers/ConvexCombinationLayer.cpp + layers/Conv3DLayer.cpp + layers/DeConv3DLayer.cpp + layers/CropLayer.cpp + layers/CrossEntropyOverBeam.cpp + layers/DataNormLayer.cpp + layers/FeatureMapExpandLayer.cpp + layers/HierarchicalSigmoidLayer.cpp + layers/MultinomialSampler.cpp + layers/NCELayer.cpp + layers/KmaxSeqScoreLayer.cpp + layers/MDLstmLayer.cpp + layers/MultiplexLayer.cpp + layers/PadLayer.cpp + layers/Pool3DLayer.cpp + layers/ResizeLayer.cpp + layers/RotateLayer.cpp + layers/RowConvLayer.cpp + layers/RowL2NormLayer.cpp + layers/SamplingIdLayer.cpp + layers/ScaleShiftLayer.cpp + layers/SelectiveFullyConnectedLayer.cpp + layers/SpatialPyramidPoolLayer.cpp + layers/BilinearInterpLayer.cpp + layers/ClipLayer.cpp) endif() if(WITH_GPU) diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index dbadc352a4..be112b4123 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -16,7 +16,6 @@ limitations under the License. */ #include "NeuralNetwork.h" #include "hl_gpu.h" -#include "paddle/gserver/layers/AgentLayer.h" #include "paddle/utils/CustomStackTrace.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" @@ -28,6 +27,7 @@ limitations under the License. */ #ifndef PADDLE_MOBILE_INFERENCE #include "MultiNetwork.h" #include "RecurrentGradientMachine.h" +#include "paddle/gserver/layers/AgentLayer.h" #endif namespace paddle { @@ -192,9 +192,11 @@ void NeuralNetwork::init(const ModelConfig& config, void NeuralNetwork::connect(LayerPtr agentLayer, LayerPtr realLayer, int height) { +#ifndef PADDLE_MOBILE_INFERENCE AgentLayer* agent = dynamic_cast(agentLayer.get()); CHECK_NOTNULL(agent); agent->setRealLayer(realLayer, height); +#endif } void NeuralNetwork::connect(std::string agentLayerName, diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index 48dfcb49a4..7ff0c73721 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -79,6 +79,10 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, for (int i = 0; i < config_.inputs_size(); i++) { std::vector paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector strides = {(size_t)strideY_[i], (size_t)stride_[i]}; + std::vector dilations = {(size_t)dilationY_[i], + (size_t)dilation_[i]}; + + bool useDilation = ((size_t)dilationY_[i] > 1 || (size_t)dilation_[i] > 1); // Convolution Layer uses the GemmConv function by default. convType = "GemmConv"; @@ -97,13 +101,14 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, #if defined(__ARM_NEON__) || defined(__ARM_NEON) if ((filterSize_[i] == filterSizeY_[i]) && (filterSize_[i] == 3 || filterSize_[i] == 4) && - (stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2)) { + (stride_[i] == strideY_[i]) && (stride_[i] == 1 || stride_[i] == 2) && + !useDilation) { convType = "NeonDepthwiseConv"; } #endif } - if (FLAGS_use_nnpack && !isDeconv_) { + if (FLAGS_use_nnpack && !isDeconv_ && !useDilation) { createFunction(forward_, "NNPACKConv", FuncConfig() @@ -117,6 +122,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, FuncConfig() .set("paddings", paddings) .set("strides", strides) + .set("dilations", dilations) .set("groups", (size_t)groups_[i])); createFunction(backward_, @@ -124,6 +130,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, FuncConfig() .set("paddings", paddings) .set("strides", strides) + .set("dilations", dilations) .set("groups", (size_t)groups_[i])); createFunction(backward_, @@ -131,6 +138,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, FuncConfig() .set("paddings", paddings) .set("strides", strides) + .set("dilations", dilations) .set("groups", (size_t)groups_[i])); } } diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp index 01f2aae6cf..b55b86221c 100644 --- a/paddle/gserver/layers/Layer.cpp +++ b/paddle/gserver/layers/Layer.cpp @@ -98,6 +98,7 @@ ClassRegistrar Layer::registrar_; LayerPtr Layer::create(const LayerConfig& config) { std::string type = config.type(); +#ifndef PADDLE_MOBILE_INFERENCE // NOTE: As following types have illegal character '-', // they can not use REGISTER_LAYER to registrar. // Besides, to fit with old training models, @@ -106,7 +107,6 @@ LayerPtr Layer::create(const LayerConfig& config) { return LayerPtr(new MultiClassCrossEntropy(config)); else if (type == "rank-cost") return LayerPtr(new RankingCost(config)); -#ifndef PADDLE_MOBILE_INFERENCE else if (type == "auc-validation") return LayerPtr(new AucValidation(config)); else if (type == "pnpair-validation") diff --git a/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp b/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp new file mode 100644 index 0000000000..d810a58d9a --- /dev/null +++ b/paddle/gserver/layers/MaxPoolWithMaskLayer.cpp @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MaxPoolWithMaskLayer.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + PoolLayer::init(layerMap, parameterMap); + setOutput("mask", &mask_); + return true; +} + +size_t MaxPoolWithMaskLayer::getSize() { + CHECK_EQ(inputLayers_.size(), 1UL); + size_t layerSize = 0; + + outputY_ = outputSize(imgSizeY_, + sizeY_, + confPaddingY_, + strideY_, + /* caffeMode */ false); + outputX_ = outputSize(imgSize_, + sizeX_, + confPadding_, + stride_, + /* caffeMode */ false); + + layerSize = outputX_ * outputY_ * channels_; + getOutput().setFrameHeight(outputY_); + getOutput().setFrameWidth(outputX_); + + return layerSize; +} + +void MaxPoolWithMaskLayer::forward(PassType passType) { + size_t size = getSize(); + MatrixPtr inputV = inputLayers_[0]->getOutputValue(); + int batchSize = inputV->getHeight(); + resetOutput(batchSize, size); + + MatrixPtr outV = getOutputValue(); + CHECK_EQ(size, outV->getWidth()); + + resetSpecifyOutput(mask_, + batchSize, + size, + /* isValueClean */ false, + /* isGradClean */ true); + + MatrixPtr maskV = mask_.value; + outV->maxPoolForward(*inputV, + imgSizeY_, + imgSize_, + channels_, + sizeX_, + sizeY_, + strideY_, + stride_, + outputY_, + outputX_, + confPaddingY_, + confPadding_, + maskV); +} + +void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) { + (void)callback; + if (NULL == getInputGrad(0)) { + return; + } + + MatrixPtr outGrad = getOutputGrad(); + MatrixPtr inputV = inputLayers_[0]->getOutputValue(); + MatrixPtr outV = getOutputValue(); + MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad(); + + inputGrad->maxPoolBackward(*inputV, + imgSizeY_, + imgSize_, + *outGrad, + *outV, + sizeX_, + sizeY_, + strideY_, + stride_, + outputY_, + outputX_, + 1, + 1, + confPaddingY_, + confPadding_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/MaxPoolWithMaskLayer.h b/paddle/gserver/layers/MaxPoolWithMaskLayer.h new file mode 100644 index 0000000000..e0174add9d --- /dev/null +++ b/paddle/gserver/layers/MaxPoolWithMaskLayer.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "PoolLayer.h" +#include "paddle/math/Matrix.h" + +namespace paddle { +/** + * @brief Basic parent layer of different kinds of pooling + */ +class MaxPoolWithMaskLayer : public PoolLayer { +protected: + Argument mask_; + +public: + explicit MaxPoolWithMaskLayer(const LayerConfig& config) + : PoolLayer(config) {} + + size_t getSize(); + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; +}; +} // namespace paddle diff --git a/paddle/gserver/layers/PoolLayer.cpp b/paddle/gserver/layers/PoolLayer.cpp index 7b932d5a76..87613a96c5 100644 --- a/paddle/gserver/layers/PoolLayer.cpp +++ b/paddle/gserver/layers/PoolLayer.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "PoolLayer.h" +#include "MaxPoolWithMaskLayer.h" #include "PoolProjectionLayer.h" #include "paddle/utils/Logging.h" #ifdef PADDLE_WITH_CUDA @@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap, strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride(); confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding(); outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); - return true; } @@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) { } else if (CudnnPoolLayer::typeCheck(pool)) { return new CudnnPoolLayer(config); #endif + } else if (pool == "max-pool-with-mask") { + return new MaxPoolWithMaskLayer(config); } else { LOG(FATAL) << "Unknown pool type: " << pool; return nullptr; diff --git a/paddle/gserver/layers/ROIPoolLayer.cpp b/paddle/gserver/layers/ROIPoolLayer.cpp index 99cfddb0cf..35d4b12d3d 100644 --- a/paddle/gserver/layers/ROIPoolLayer.cpp +++ b/paddle/gserver/layers/ROIPoolLayer.cpp @@ -98,7 +98,7 @@ void ROIPoolLayer::forward(PassType passType) { size_t roiStartH = round(bottomROIs[2] * spatialScale_); size_t roiEndW = round(bottomROIs[3] * spatialScale_); size_t roiEndH = round(bottomROIs[4] * spatialScale_); - CHECK_GE(roiBatchIdx, 0); + CHECK_GE(roiBatchIdx, 0UL); CHECK_LT(roiBatchIdx, batchSize); size_t roiHeight = std::max(roiEndH - roiStartH + 1, 1UL); size_t roiWidth = std::max(roiEndW - roiStartW + 1, 1UL); diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index aa94ee406e..4bea348f63 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -1,9 +1,12 @@ # gserver pacakge unittests add_simple_unittest(test_LinearChainCRF) -add_simple_unittest(test_MultinomialSampler) add_simple_unittest(test_RecurrentLayer) +if(NOT MOBILE_INFERENCE) + add_simple_unittest(test_MultinomialSampler) +endif() + function(gserver_test TARGET) add_unittest_without_exec(${TARGET} ${TARGET}.cpp @@ -24,6 +27,7 @@ gserver_test(test_ConvUnify) gserver_test(test_BatchNorm) gserver_test(test_KmaxSeqScore) gserver_test(test_Expand) +gserver_test(test_MaxPoolingWithMaskOutput) ########## test_Mkldnn layers and activations ########## if(WITH_MKLDNN) @@ -48,7 +52,7 @@ if(WITH_PYTHON) endif() ############### test_WarpCTCLayer ####################### -if(NOT WITH_DOUBLE) +if(NOT WITH_DOUBLE AND NOT MOBILE_INFERENCE) add_unittest_without_exec(test_WarpCTCLayer test_WarpCTCLayer.cpp) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index fcbcb5b0f1..3517d293e3 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -434,7 +434,7 @@ void testConvLayer(const string& type, bool trans, bool useGpu) { config.layerConfig.set_partial_sum(1); config.layerConfig.set_shared_biases(true); - int dilation = 1; + int dilation = 2; if (type == "cudnn_conv") { #if CUDNN_VERSION >= 6000 dilation = 2; @@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) { TEST(Layer, PoolLayer) { testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false); testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false); + testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false); #ifdef PADDLE_WITH_CUDA testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true); @@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) { testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true); testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true); + testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true); #endif } diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp index a0e039c2a3..a859e34c89 100644 --- a/paddle/gserver/tests/test_MKLDNN.cpp +++ b/paddle/gserver/tests/test_MKLDNN.cpp @@ -297,7 +297,7 @@ static void getAddtoConfig(TestConfig& cfg, } void testAddtoLayer(const testImageDesc& pm, const size_t nInputs) { - CHECK_GE(nInputs, 1); + CHECK_GE(nInputs, 1UL); TestConfig dnnConfig; getAddtoConfig(dnnConfig, pm, nInputs); dnnConfig.layerConfig.set_type("mkldnn_addto"); diff --git a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp new file mode 100644 index 0000000000..16438886df --- /dev/null +++ b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp @@ -0,0 +1,117 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "LayerGradUtil.h" +#include "paddle/math/MathUtils.h" +#include "paddle/testing/TestUtil.h" + +using namespace paddle; + +void setPoolConfig(TestConfig* config, + PoolConfig* pool, + const string& poolType) { + (*config).biasSize = 0; + (*config).layerConfig.set_type("pool"); + (*config).layerConfig.set_num_filters(1); + + int kw = 3, kh = 3; + int pw = 0, ph = 0; + int sw = 2, sh = 2; + pool->set_pool_type(poolType); + pool->set_channels(1); + pool->set_size_x(kw); + pool->set_size_y(kh); + pool->set_start(0); + pool->set_padding(pw); + pool->set_padding_y(ph); + pool->set_stride(sw); + pool->set_stride_y(sh); + + int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false); + int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false); + pool->set_output_x(ow); + pool->set_output_y(oh); +} + +void doOneMaxPoolingWithMaskOutputTest(MatrixPtr& inputMat, + const string& poolType, + bool use_gpu, + MatrixPtr& maskMat) { + TestConfig config; + config.inputDefs.push_back({INPUT_DATA, "layer_0", 25, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + PoolConfig* pool = input->mutable_pool_conf(); + + pool->set_img_size(5); + pool->set_img_size_y(5); + setPoolConfig(&config, pool, poolType); + config.layerConfig.set_size(pool->output_x() * pool->output_y() * + pool->channels()); + + config.layerConfig.set_name("MaxPoolWithMask"); + + std::vector dataLayers; + LayerMap layerMap; + vector datas; + + initDataLayer(config, + &dataLayers, + &datas, + &layerMap, + "MaxPoolWithMask", + 1, + false, + use_gpu); + + dataLayers[0]->getOutputValue()->copyFrom(*inputMat); + + FLAGS_use_gpu = use_gpu; + std::vector parameters; + LayerPtr maxPoolingWithMaskOutputLayer; + initTestLayer(config, &layerMap, ¶meters, &maxPoolingWithMaskOutputLayer); + maxPoolingWithMaskOutputLayer->forward(PASS_GC); + + checkMatrixEqual(maxPoolingWithMaskOutputLayer->getOutput("mask").value, + maskMat); +} + +TEST(Layer, maxPoolingWithMaskOutputLayerFwd) { + bool useGpu = false; + MatrixPtr inputMat; + MatrixPtr maskMat; + real inputData[] = {0.1, 0.1, 0.5, 0.5, 1.1, 0.2, 0.2, 0.6, 0.1, + 0.1, 0.3, 0.3, 0.7, 0.1, 0.1, 0.4, 0.4, 0.8, + 0.8, 0.1, 1.0, 2.0, 3.0, 0.0, 9.0}; + real maskData[] = {12, 4, 22, 24}; + + inputMat = Matrix::create(1, 25, false, useGpu); + maskMat = Matrix::create(1, 4, false, useGpu); + inputMat->setData(inputData); + maskMat->setData(maskData); + doOneMaxPoolingWithMaskOutputTest( + inputMat, "max-pool-with-mask", useGpu, maskMat); +#ifdef PADDLE_WITH_CUDA + useGpu = true; + inputMat = Matrix::create(1, 25, false, useGpu); + maskMat = Matrix::create(1, 4, false, useGpu); + inputMat->copyFrom(inputData, 25); + maskMat->copyFrom(maskData, 4); + doOneMaxPoolingWithMaskOutputTest( + inputMat, "max-pool-with-mask", useGpu, maskMat); +#endif +} diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index 53dd538360..e3eff59dc5 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -1902,5 +1902,52 @@ void BaseMatrixT::sumOfProducts(BaseMatrixT& b, } template class BaseMatrixT; + +#ifndef PADDLE_MOBILE_INFERENCE + template class BaseMatrixT; + +#else + +template <> +void BaseMatrixT::zero() { + applyUnary(unary::Zero()); +} + +template <> +void BaseMatrixT::assign(int p) { + applyUnary(unary::Assign(p)); +} + +template <> +void BaseMatrixT::isEqualTo(BaseMatrixT& b, int value) { + applyBinary(binary::IsEqual(value), b); +} + +template <> +void BaseMatrixT::neg() { + applyUnary(unary::Neg()); +} + +template <> +void BaseMatrixT::abs2() { + applyUnary(unary::Abs()); +} + +template <> +void BaseMatrixT::add(int p) { + applyUnary(unary::Add(p)); +} + +template <> +void BaseMatrixT::add(int p1, int p2) { + applyUnary(unary::Add2(p1, p2)); +} + +template <> +void BaseMatrixT::applyL1(int learningRate, int decayRate) { + applyUnary(unary::ApplyL1(learningRate * decayRate)); +} + +#endif } // namespace paddle diff --git a/paddle/math/CMakeLists.txt b/paddle/math/CMakeLists.txt index 68b5296228..86bb270a43 100644 --- a/paddle/math/CMakeLists.txt +++ b/paddle/math/CMakeLists.txt @@ -25,6 +25,19 @@ else() message(STATUS "Compile with MKLDNNMatrix") endif() +if(MOBILE_INFERENCE) + list(REMOVE_ITEM MATH_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/SIMDFunctions.cpp) + # Remove sparse + list(REMOVE_ITEM MATH_HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/CpuSparseMatrix.h + ${CMAKE_CURRENT_SOURCE_DIR}/SparseMatrix.h + ${CMAKE_CURRENT_SOURCE_DIR}/SparseRowMatrix.h) + list(REMOVE_ITEM MATH_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/CpuSparseMatrix.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/SparseMatrix.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/SparseRowMatrix.cpp) +endif() set(MATH_SOURCES "${PADDLE_SOURCE_DIR}/paddle/math/BaseMatrix.cu" "${PADDLE_SOURCE_DIR}/paddle/math/TrainingAlgorithmOp.cu" diff --git a/paddle/math/CpuSparseMatrix.h b/paddle/math/CpuSparseMatrix.h index 36d57bbb65..aad1348353 100644 --- a/paddle/math/CpuSparseMatrix.h +++ b/paddle/math/CpuSparseMatrix.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + +#ifndef PADDLE_MOBILE_INFERENCE + #include #include "Matrix.h" @@ -309,3 +312,57 @@ private: using Matrix::subMatrix; }; } // namespace paddle + +#else + +#include "Matrix.h" + +namespace paddle { + +class CpuSparseMatrix : public Matrix { +public: + CpuSparseMatrix(size_t height, + size_t width, + size_t nnz, /* used to allocate space */ + SparseValueType valueType = FLOAT_VALUE, + SparseFormat format = SPARSE_CSR, + bool trans = false) + : Matrix(NULL, height, width, trans, false) {} + + CpuSparseMatrix(real* data, + int* rows, + int* cols, + size_t height, + size_t width, + size_t nnz, + SparseValueType valueType, + SparseFormat format, + bool trans) + : Matrix(NULL, height, width, trans, false) {} + + real* getValue() const { return nullptr; } + size_t getColStartIdx(size_t i) const { return 0; } + size_t getRowStartIdx(size_t i) const { return 0; } + size_t getColNum(size_t i) const { return 0; } + int* getRowCols(size_t i) const { return nullptr; } + + CpuSparseMatrixPtr getTmpSparseMatrix(size_t height, size_t width) { + return nullptr; + } + + void resize(size_t newHeight, + size_t newWidth, + size_t newNnz, /* used to allocate space */ + SparseValueType valueType, + SparseFormat format) {} + void resize(size_t newHeight, size_t newWidth) {} + MatrixPtr getTranspose() { return nullptr; } + void setRow(size_t row, + size_t colNum, + const unsigned int* cols, + const real* values) {} +}; + +} // namespace paddle + +#endif diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c3e34d5309..88e9180690 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -451,6 +451,7 @@ void GpuMatrix::addSharedBias(Matrix& b, real scale) { } void GpuMatrix::collectBias(Matrix& a, real scale) { +#ifdef PADDLE_WITH_CUDA CHECK_EQ(getHeight(), (size_t)1); CHECK_EQ(width_, a.getWidth()); GpuSparseMatrix* sMatPtr = dynamic_cast(&a); @@ -461,6 +462,7 @@ void GpuMatrix::collectBias(Matrix& a, real scale) { hl_sparse_matrix_s A_d = sMatPtr->sMatrix_.get(); hl_sparse_matrix_column_sum(data, A_d, sMatPtr->getHeight(), width_, scale); } +#endif } void GpuMatrix::collectSharedBias(Matrix& a, real scale) { @@ -552,6 +554,7 @@ void GpuMatrix::mul(const GpuSparseMatrix& a, const GpuMatrix& b, real scaleAB, real scaleT) { +#ifdef PADDLE_WITH_CUDA CHECK(isContiguous()); CHECK(b.isContiguous()); CHECK(b.useGpu_ == true) << "Matrix type are not equal"; @@ -578,12 +581,14 @@ void GpuMatrix::mul(const GpuSparseMatrix& a, b.height_, scaleAB, scaleT); +#endif } void GpuMatrix::mul(const GpuMatrix& a, const GpuSparseMatrix& b, real scaleAB, real scaleT) { +#ifdef PADDLE_WITH_CUDA CHECK(isContiguous()); CHECK(a.isContiguous()); CHECK(a.useGpu_ == true) << "Matrix type are not equal"; @@ -622,6 +627,7 @@ void GpuMatrix::mul(const GpuMatrix& a, scaleAB, scaleT); } +#endif } /* this = a*b */ @@ -1028,15 +1034,23 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, size_t outputH, size_t outputW, size_t paddingH, - size_t paddingW) { + size_t paddingW, + MatrixPtr maskMatP) { CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal"; real* inputData = inputMat.getData(); + real* maskData = NULL; size_t frameNum = inputMat.getHeight(); CHECK(imgSizeH * imgSizeW * channels == inputMat.getWidth()); CHECK(height_ == inputMat.getHeight()); CHECK(width_ == outputH * outputW * channels); + if (maskMatP != NULL) { + CHECK(maskMatP->useGpu_ == true) << "Matrix type are not equal"; + CHECK(outputH * outputW * channels == maskMatP->getWidth()); + maskData = maskMatP->getData(); + } + hl_maxpool_forward(frameNum, inputData, channels, @@ -1051,7 +1065,8 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat, paddingH, paddingW, data_, - getStride()); + getStride(), + maskData); } void GpuMatrix::maxPoolBackward(Matrix& inputMat, @@ -1548,6 +1563,7 @@ void GpuMatrix::bilinearBackward(const Matrix& out, } void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) { +#ifdef PADDLE_WITH_CUDA GpuMatrix* outputPtr = dynamic_cast(&output); auto labelPtr = dynamic_cast(&label); @@ -1563,9 +1579,11 @@ void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) { hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get(); hl_matrix_multi_binary_cross_entropy( output_d, entropy_d, mat_d, height_, outputPtr->width_); +#endif } void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) { +#ifdef PADDLE_WITH_CUDA GpuMatrix* outputPtr = dynamic_cast(&output); auto labelPtr = dynamic_cast(&label); @@ -1581,6 +1599,7 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) { hl_sparse_matrix_s mat_d = labelPtr->sMatrix_.get(); hl_matrix_multi_binary_cross_entropy_bp( output_d, grad_d, mat_d, height_, width_); +#endif } void GpuMatrix::vol2Col(real* dataSrc, @@ -1973,9 +1992,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t outputH, size_t outputW, size_t paddingH, - size_t paddingW) { + size_t paddingW, + MatrixPtr maskMatP) { real* inputData = inputMat.getData(); real* outData = data_; + real* maskData = NULL; size_t num = inputMat.getHeight(); size_t inLength = imgSizeH * imgSizeW; size_t outLength = outputH * outputW; @@ -1984,6 +2005,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, CHECK_EQ(channels * outLength, this->getWidth()); size_t outStride = getStride(); + if (maskMatP != NULL) { + maskData = maskMatP->getData(); + CHECK_EQ(channels * outLength, maskMatP->getWidth()); + } + /* initialize the data_ */ for (size_t i = 0; i < height_; i++) { for (size_t j = 0; j < width_; j++) { @@ -2005,10 +2031,21 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, int wstart = pw * strideW - paddingW; int wend = std::min(wstart + sizeX, imgSizeW); wstart = std::max(wstart, 0); - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - outData[ph * outputW + pw] = std::max( - outData[ph * outputW + pw], inputData[h * imgSizeW + w]); + if (maskData == NULL) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + outData[ph * outputW + pw] = std::max( + outData[ph * outputW + pw], inputData[h * imgSizeW + w]); + } + } + } else { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) { + outData[ph * outputW + pw] = inputData[h * imgSizeW + w]; + maskData[ph * outputW + pw] = h * imgSizeW + w; + } + } } } } @@ -2016,6 +2053,8 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, // compute offset inputData += inLength; outData += outLength; + + if (maskData != NULL) maskData += outLength; } } } @@ -3226,6 +3265,7 @@ template void CpuMatrix::mul(CpuSparseMatrix* a, real scaleAB, real scaleT); +#ifndef PADDLE_MOBILE_INFERENCE void SharedCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, @@ -3354,6 +3394,7 @@ void SharedCpuMatrix::initBlock(int blockNum) { } } +#endif /* Add a (column) vector b to matrix a, column by column */ void CpuMatrix::addColumnVector(const Matrix& b) { BaseMatrix::addColVector(const_cast(b)); diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 44180bca8b..e273f11236 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -861,7 +861,8 @@ public: /** * Pooling forward operation, pick out the largest element - * in the sizeX of value + * in the sizeX of value, if the maskMatP is not NULL, it will + * also caculate the location indices. */ virtual void maxPoolForward(Matrix& inputMat, size_t imgSizeH, @@ -874,7 +875,8 @@ public: size_t outputH, size_t outputW, size_t paddingH, - size_t paddingW) { + size_t paddingW, + MatrixPtr maskMatP = NULL) { LOG(FATAL) << "Not implemeted"; } @@ -1426,7 +1428,8 @@ public: size_t outputH, size_t outputW, size_t paddingH, - size_t paddingW); + size_t paddingW, + MatrixPtr maskMatP); void maxPoolBackward(Matrix& image, size_t imgSizeH, @@ -1697,7 +1700,8 @@ public: size_t outputH, size_t outputW, size_t paddingH, - size_t paddingW); + size_t paddingW, + MatrixPtr maskMatP); void maxPoolBackward(Matrix& image, size_t imgSizeH, @@ -2066,6 +2070,7 @@ public: class SharedCpuMatrix : public CpuMatrix { public: +#ifndef PADDLE_MOBILE_INFERENCE /* blockNum is number of partitions of the matrix */ SharedCpuMatrix(int blockNum, size_t height, size_t width, bool trans = false) : CpuMatrix(height, width, trans) { @@ -2111,6 +2116,7 @@ private: ThreadLocal localBuf_; ThreadLocal> localBufRows_; ThreadLocal> blockSeq_; +#endif }; typedef struct { unsigned int col; } sparse_non_value_t; diff --git a/paddle/math/SparseMatrix.h b/paddle/math/SparseMatrix.h index 16300db081..e0a3c6d228 100644 --- a/paddle/math/SparseMatrix.h +++ b/paddle/math/SparseMatrix.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + +#ifndef PADDLE_MOBILE_INFERENCE + #include #include "CpuSparseMatrix.h" #include "Matrix.h" @@ -237,3 +240,47 @@ private: }; } // namespace paddle + +#else + +#include "CpuSparseMatrix.h" + +namespace paddle { + +class GpuSparseMatrix : public Matrix { +public: + GpuSparseMatrix(size_t height, + size_t width, + size_t nnz, /* used to allocate space */ + SparseValueType valueType = FLOAT_VALUE, + SparseFormat format_ = SPARSE_CSR, + bool trans = false) + : Matrix(NULL, height, width, trans, false) {} + + GpuSparseMatrix(real* value, + int* rows, + int* cols, + size_t height, + size_t width, + size_t nnz, + SparseValueType valueType, + SparseFormat format, + bool trans) + : Matrix(NULL, height, width, trans, true) {} + + void resize(size_t newHeight, + size_t newWidth, + size_t newNnz, /* used to allocate space */ + SparseValueType valueType, + SparseFormat format) {} + void resize(size_t newHeight, size_t newWidth) {} + MatrixPtr getTranspose() { return nullptr; } + void setRow(size_t row, + size_t colNum, + const unsigned int* cols, + const real* values) {} +}; + +} // namespace paddle + +#endif diff --git a/paddle/math/SparseRowMatrix.h b/paddle/math/SparseRowMatrix.h index 8704eb038d..ca7a6806da 100644 --- a/paddle/math/SparseRowMatrix.h +++ b/paddle/math/SparseRowMatrix.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#ifndef PADDLE_MOBILE_INFERENCE + #include #include #include @@ -313,3 +315,27 @@ private: }; } // namespace paddle + +#else +namespace paddle { + +class SparseRowCpuMatrix : public CpuMatrix { +public: + void reserveStore() {} + void clearIndices() {} +}; + +class SparsePrefetchRowCpuMatrix : public SparseRowCpuMatrix { +public: + void setupIndices() {} + void addRows(MatrixPtr input) {} + void addRows(IVectorPtr ids) {} +}; + +class SparseAutoGrowRowCpuMatrix : public SparseRowCpuMatrix {}; +class CacheRowCpuMatrix : public SparseAutoGrowRowCpuMatrix {}; +class SparseRowIdsCpuMatrix : public CpuMatrix {}; + +} // namespace paddle + +#endif diff --git a/paddle/math/tests/CMakeLists.txt b/paddle/math/tests/CMakeLists.txt index ceb96b2e25..d8b7f9e3fc 100644 --- a/paddle/math/tests/CMakeLists.txt +++ b/paddle/math/tests/CMakeLists.txt @@ -3,8 +3,10 @@ add_simple_unittest(test_ExecViaCpu) add_simple_unittest(test_SIMDFunctions) add_simple_unittest(test_TrainingAlgorithm) -add_simple_unittest(test_SparseMatrix) add_simple_unittest(test_RowBuffer) +if(NOT MOBILE_INFERENCE) + add_simple_unittest(test_SparseMatrix) +endif() # TODO(yuyang18): Refactor TestUtil.cpp. Remove this cross module reference. add_unittest(test_matrixCompare diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 29ce44c233..709f7de2e4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -214,6 +214,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) +cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc rnn/recurrent_op_utils.cc diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 03c2fa945d..2785a8c6fb 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -30,6 +30,10 @@ class AccuracyOp : public framework::OperatorWithKernel { "Input (Label) of accuracy op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Accuracy"), "Output (Accuracy) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Correct"), + "Output (Correct) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Total"), + "Output (Total) of AccuracyOp should not be null."); auto inference_dim = ctx->GetInputDim("Out"); auto label_dim = ctx->GetInputDim("Label"); @@ -43,6 +47,8 @@ class AccuracyOp : public framework::OperatorWithKernel { " the same as label."); ctx->SetOutputDim("Accuracy", {1}); + ctx->SetOutputDim("Correct", {1}); + ctx->SetOutputDim("Total", {1}); ctx->ShareLoD("Out", /*->*/ "Accuracy"); } @@ -66,6 +72,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Label", "Label of the training data"); // TODO(typhoonzero): AddInput("Weight", ... AddOutput("Accuracy", "The accuracy of current batch"); + AddOutput("Correct", "The correct samples count of current batch"); + AddOutput("Total", "The samples count of current batch"); AddComment(R"DOC( Accuracy Operator. diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 1776f33105..b575c682f0 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -24,7 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS; template __global__ void AccuracyCudaKernel(const int N, const int D, const int64_t* Xdata, - const int64_t* labeldata, float* accuracy) { + const int64_t* labeldata, int* correct_data, + float* accuracy) { int count = 0; __shared__ int total[BlockSize]; @@ -43,6 +44,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, // reduce the count with init value 0, and output accuracy. int result = thrust::reduce(thrust::device, total, total + BlockSize, 0); if (threadIdx.x == 0) { + *correct_data = result; *accuracy = static_cast(result) / static_cast(N); } } @@ -56,31 +58,48 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { auto* inference = ctx.Input("Out"); auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); + auto* accuracy = ctx.Output("Accuracy"); + auto* correct = ctx.Output("Correct"); + auto* total = ctx.Output("Total"); // FIXME(typhoonzero): only support indices currently // if add support for output values, how to detect the data type? const int64_t* indices_data = indices->data(); const int64_t* label_data = label->data(); + + int* correct_data = correct->mutable_data(ctx.GetPlace()); + int* total_data = total->mutable_data(ctx.GetPlace()); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); - size_t num_samples = inference->dims()[0]; + int num_samples = static_cast(inference->dims()[0]); size_t infer_width = inference->dims()[1]; PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float))); + // cudaMemset((void**)&correct_data, 0, sizeof(float)); if (num_samples == 0) { return; } + cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice); AccuracyCudaKernel<<< 1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>( - num_samples, infer_width, indices_data, label_data, accuracy_data); + num_samples, infer_width, indices_data, label_data, correct_data, + accuracy_data); + + int d_num_samples, d_num_correct; + float d_accuracy; + cudaMemcpy(&d_num_correct, correct_data, sizeof(int), + cudaMemcpyDeviceToHost); + cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float), + cudaMemcpyDeviceToHost); } }; } // namespace operators } // namespace paddle -// FIXME(typhoonzero): types of T is for infernece data. -// label data is always int +// FIXME(typhoonzero): types of T is for inference data. +// label data is always int64 REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel, paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h index 28dbc77f64..d060e6eddd 100644 --- a/paddle/operators/accuracy_op.h +++ b/paddle/operators/accuracy_op.h @@ -29,7 +29,11 @@ class AccuracyKernel : public framework::OpKernel { auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); auto* accuracy = ctx.Output("Accuracy"); + auto* correct = ctx.Output("Correct"); + auto* total = ctx.Output("Total"); + int* correct_data = correct->mutable_data(ctx.GetPlace()); + int* total_data = total->mutable_data(ctx.GetPlace()); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); const int64_t* indices_data = indices->data(); @@ -55,7 +59,8 @@ class AccuracyKernel : public framework::OpKernel { } } - // FIXME(typhoonzero): we don't accumulate the accuracy for now. + *correct_data = num_correct; + *total_data = num_samples; *accuracy_data = static_cast(num_correct) / static_cast(num_samples); } diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc new file mode 100644 index 0000000000..609e915b93 --- /dev/null +++ b/paddle/operators/assign_op.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/data_type.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/var_type.h" + +namespace paddle { +namespace operators { +class AssignFunctor { + public: + AssignFunctor(framework::Variable *out, + const platform::DeviceContext &dev_ctx) + : out_(out), dev_ctx_(dev_ctx) {} + + void operator()(const framework::LoDTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + copy_tensor(lod_tensor, &out_tensor); + } + + void operator()(const framework::LoDTensorArray &array) const { + auto &out_array = *out_->GetMutable(); + out_array.resize(array.size()); + for (size_t i = 0; i < array.size(); ++i) { + copy_tensor(array[i], &out_array[i]); + } + } + + void operator()(const framework::SelectedRows &rows) const { + framework::SelectedRows &out_rows = + *out_->GetMutable(); + out_rows.set_rows(rows.rows()); + out_rows.set_height(rows.height()); + auto &t = rows.value(); + out_rows.mutable_value()->CopyFrom(t, t.place(), dev_ctx_); + } + + template + void operator()(const T &v) const { + PADDLE_THROW("Not support type for assign op %s", typeid(T).name()); + } + + private: + void copy_tensor(const framework::LoDTensor &lod_tensor, + framework::LoDTensor *out) const { + auto &out_tensor = *out; + out_tensor.CopyFrom(lod_tensor, lod_tensor.place(), dev_ctx_); + out_tensor.set_lod(lod_tensor.lod()); + } + + framework::Variable *out_; + const platform::DeviceContext &dev_ctx_; +}; + +class AssignOp : public framework::OperatorBase { + public: + AssignOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto *x = scope.FindVar(Input("X")); + if (x == nullptr) { + return; + } + auto *out = scope.FindVar(Output("Out")); + PADDLE_ENFORCE( + out != nullptr, + "The Output(Out) should not be null if the Input(X) is set."); + framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); + } +}; + +class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + AssignOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor, SelectedRows or LoDTensorArray) The input variable " + "could be LoDTensor, SelectedRows or LoDTensorArray.") + .AsDispensable(); + AddOutput("Out", + "(LoDTensor, SelectedRows or LoDTensorArray) The type of output " + "is the same as input X."); + AddComment(R"DOC(Assign Operator + +Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray] +raise error if the type is not listed above. +)DOC"); + } +}; + +class AssignInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + if (context->HasInput("X")) { + auto type = context->GetInputsVarType("X")[0]; + if (type == framework::VarDesc_VarType_SELECTED_ROWS || + type == framework::VarDesc_VarType_LOD_TENSOR) { + context->SetOutputDim("Out", context->GetInputDim("X")); + } + } + } +}; + +class AssignGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op = new framework::OpDescBind(); + op->SetType("assign"); + op->SetInput("X", OutputGrad("Out")); + op->SetOutput("Out", InputGrad("X")); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, + ops::AssignInferShape, ops::AssignOpProtoMaker); diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc new file mode 100644 index 0000000000..3904a97d58 --- /dev/null +++ b/paddle/operators/beam_search_decode_op.cc @@ -0,0 +1,111 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/beam_search_decode_op.h" + +namespace paddle { +namespace operators { + +class BeamSearchDecodeOp : public framework::OperatorBase { + public: + BeamSearchDecodeOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + framework::ExecutionContext ctx(*this, scope, dev_ctx); + + const LoDTensorArray* ids = ctx.Input("Ids"); + const LoDTensorArray* scores = ctx.Input("Scores"); + const size_t step_num = ids->size(); + PADDLE_ENFORCE_GT(step_num, 0UL, + "beam search steps should be larger than 0"); + const size_t source_num = ids->at(0).lod().at(0).size() - 1; + PADDLE_ENFORCE_GT(source_num, 0UL, "source num should be larger than 0"); + + for (size_t i = 0; i < step_num; ++i) { + PADDLE_ENFORCE_EQ(ids->at(i).lod().size(), 2UL, + "Level of LodTensor should be 2"); + } + + // prepare output + LoDTensor* sentenceIds = ctx.Output("SentenceIds"); + LoDTensor* sentenceScores = ctx.Output("SentenceScores"); + + BeamSearchDecoder beam_search_decoder; + beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds, + sentenceScores); + } +}; + +class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + BeamSearchDecodeOpProtoMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Ids", + "(LodTensorArray)" + "score of the candidate words in each step"); + AddInput("Scores", + "(LodTensorArray)" + "score of the candidate words in each step"); + AddOutput("SentenceIds", + "(LodTensor)" + "All possible result sentences of word ids"); + AddOutput("SentenceScores", + "(LodTensor)" + "All possible result sentences of word scores"); + AddComment(R"DOC( +Pack the result of Beam search op into SentenceIds and SentenceScores. +)DOC"); + } +}; + +class BeamSearchDecodeInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* context) const override { + PADDLE_ENFORCE(context->HasInput("Ids"), + "BeamSearchDecodeOp must has input Ids"); + PADDLE_ENFORCE(context->HasInput("Scores"), + "BeamSearchDecodeOp must has input Scores"); + PADDLE_ENFORCE(context->HasOutput("SentenceIds"), + "BeamSearchDecodeOp must has output SentenceIds"); + PADDLE_ENFORCE(context->HasOutput("SentenceScores"), + "BeamSearchDecodeOp must has output SentenceScores"); + } +}; + +class BeamSearchDecodeInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDescBind& op_desc, + framework::BlockDescBind* block) const override { + for (auto& o : op_desc.Output("SentenceIds")) { + block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR); + } + for (auto& o : op_desc.Output("SentenceScores")) { + block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR); + } + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(beam_search_decode, paddle::operators::BeamSearchDecodeOp, + paddle::operators::BeamSearchDecodeOpProtoMaker, + paddle::operators::BeamSearchDecodeInferShape, + paddle::operators::BeamSearchDecodeInferVarType, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/operators/beam_search_decode_op.h b/paddle/operators/beam_search_decode_op.h new file mode 100644 index 0000000000..0f007ec22f --- /dev/null +++ b/paddle/operators/beam_search_decode_op.h @@ -0,0 +1,280 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/lod_tensor_array.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using LoDTensorArray = framework::LoDTensorArray; + +// all the lod have 2 levels. +// The First is source level, the second is sentence level. +// source level describe how many candidate words for this source. +// sentence level describe these candidates belong to which prefix +const size_t kSourceLevel = 0; +const size_t kSentenceLevel = 1; + +template +struct BeamNode { + BeamNode(int64_t word_id, T score) : word_id_(word_id), score_(score) {} + + ~BeamNode() { + if (parent_) { + parent_->DropKid(this); + if (parent_->kids_.size() == 0UL) { + delete parent_; + } + } + VLOG(3) << "Delete BeamNode root with word_id:" << this->word_id_; + } + + void AppendTo(BeamNode* parent) { + parent_ = parent; + parent->kids_.insert(this); + } + + void DropKid(BeamNode* kid) { kids_.erase(kid); } + + BeamNode* parent_ = nullptr; + std::unordered_set kids_; + int64_t word_id_; + T score_; +}; + +template +using BeamNodeVector = std::vector>>; + +template +struct Sentence { + std::vector word_ids; + std::vector scores; +}; + +template +using SentenceVector = std::vector>; + +template +struct BeamSearchDecoder { + /** + * make a BeamNode and all it's related prefix BeanNode into a Sentence. + */ + Sentence MakeSentence(const BeamNode* node) const; + + /** + * Param: + * cur_ids: LoDTensor of One step for word ID + * cur_scores: LoDTensor of One Step for word score + * prefixes_list: prefixes for each source sentence. + * sentence_vector_list: result sentence_vector for each source sentence. + * Return: + * a new prefixes list for each source of current step + */ + std::vector> PackTwoSteps( + const LoDTensor& cur_ids, const LoDTensor& cur_scores, + std::vector>& prefixes_list, + std::vector>* sentence_vector_list) const; + + /** + * convert the result sentence_vector for each source sentence into two + * LodTensor. + * One is all candidate sentences with word id, one is all candidate sentences + * with word score. + * Param: + * sentence_vector_list: sentence_vector for each source sentence. + * id_tensor: result LoDTensor for sentences of id. + * score_tensor: result LoDTensor for sentences of score. + */ + void ConvertSentenceVectorToLodTensor( + std::vector> sentence_vector_list, LoDTensor* id_tensor, + LoDTensor* score_tensor) const; + + /** + * Pack all steps of id/score LodTensor into sentence LoDTensor + * it's main logic is: + * ```python + * prefix + * result_sentence + * result_lod_tensor + * + * for (step in steps): + * prefix = PackTwoSteps(prefix, step, &result_sentence) + * ConvertSentenceVectorToLodTensor(result_sentence, &result_lod_tensor) + * ``` + */ + void PackAllSteps(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, LoDTensor* id_tensor, + LoDTensor* score_tensor) const; +}; + +template +Sentence BeamSearchDecoder::MakeSentence(const BeamNode* node) const { + Sentence sentence; + while (node != nullptr) { + sentence.word_ids.emplace_back(node->word_id_); + sentence.scores.emplace_back(node->score_); + node = node->parent_; + } + + std::reverse(std::begin(sentence.word_ids), std::end(sentence.word_ids)); + std::reverse(std::begin(sentence.scores), std::end(sentence.scores)); + + return sentence; +} + +template +std::vector> BeamSearchDecoder::PackTwoSteps( + const LoDTensor& cur_ids, const LoDTensor& cur_scores, + std::vector>& prefixes_list, + std::vector>* sentence_vector_list) const { + std::vector> result; + + for (size_t src_idx = 0; src_idx < cur_ids.lod()[kSourceLevel].size() - 1; + ++src_idx) { + size_t src_start = cur_ids.lod().at(kSourceLevel)[src_idx]; + size_t src_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1]; + + BeamNodeVector beam_nodes; + + // if prefixes size is 0, it means this is the first step. In this step, + // all candidate id is the start of candidate sentences. + if (prefixes_list.empty()) { + PADDLE_ENFORCE_EQ(cur_ids.lod().at(kSourceLevel).back(), + cur_ids.lod().at(kSentenceLevel).back(), + "in the first step"); + for (size_t id_idx = src_start; id_idx < src_end; ++id_idx) { + beam_nodes.push_back(std::unique_ptr>(new BeamNode( + cur_ids.data()[id_idx], cur_scores.data()[id_idx]))); + } + } else { + BeamNodeVector& prefixes = prefixes_list[src_idx]; + SentenceVector& sentence_vector = (*sentence_vector_list)[src_idx]; + + PADDLE_ENFORCE_EQ(src_end - src_start, prefixes.size(), + "prefix and candidate set number should be the same"); + + auto candidate_offset = cur_ids.lod()[kSentenceLevel]; + for (size_t prefix_idx = 0; prefix_idx < prefixes.size(); ++prefix_idx) { + std::unique_ptr>& prefix = prefixes[prefix_idx]; + size_t candidate_start = candidate_offset[src_start + prefix_idx]; + size_t candidate_end = candidate_offset[src_start + prefix_idx + 1]; + if (candidate_start == candidate_end) { + VLOG(3) << "this sentence has no more candidate, " + "add to result sentence and rm it from beam tree"; + sentence_vector.push_back(MakeSentence(prefix.get())); + prefix.reset(); + } else { + for (size_t candidate_idx = candidate_start; + candidate_idx < candidate_end; ++candidate_idx) { + auto* candidate = + new BeamNode(cur_ids.data()[candidate_idx], + cur_scores.data()[candidate_idx]); + candidate->AppendTo(prefix.get()); + beam_nodes.push_back(std::unique_ptr>(candidate)); + } + prefix.release(); + } + } + } + result.push_back(std::move(beam_nodes)); + } + return result; +} + +template +void BeamSearchDecoder::ConvertSentenceVectorToLodTensor( + std::vector> sentence_vector_list, LoDTensor* id_tensor, + LoDTensor* score_tensor) const { + size_t src_num = sentence_vector_list.size(); + + PADDLE_ENFORCE_NE(src_num, 0, "src_num should not be 0"); + + std::vector source_level_lod = {0}; + std::vector sentence_level_lod = {0}; + std::vector id_data; + std::vector score_data; + + for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { + for (Sentence& sentence : sentence_vector_list[src_idx]) { + id_data.insert(id_data.end(), sentence.word_ids.begin(), + sentence.word_ids.end()); + score_data.insert(score_data.end(), sentence.scores.begin(), + sentence.scores.end()); + sentence_level_lod.push_back(sentence_level_lod.back() + + sentence.word_ids.size()); + } + source_level_lod.push_back(source_level_lod.back() + + sentence_vector_list[src_idx].size()); + } + + auto cpu_place = new paddle::platform::CPUPlace(); + paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place); + + framework::LoD lod; + lod.push_back(source_level_lod); + lod.push_back(sentence_level_lod); + + id_tensor->set_lod(lod); + id_tensor->Resize({static_cast(id_data.size())}); + id_tensor->mutable_data(paddle::platform::CPUPlace()); + id_tensor->CopyFromVector(id_data, cpu_ctx); + + score_tensor->set_lod(lod); + score_tensor->Resize({static_cast(score_data.size())}); + score_tensor->mutable_data(paddle::platform::CPUPlace()); + score_tensor->CopyFromVector(score_data, cpu_ctx); +} + +template +void BeamSearchDecoder::PackAllSteps(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, + LoDTensor* id_tensor, + LoDTensor* score_tensor) const { + PADDLE_ENFORCE(!step_ids.empty(), "step num should be larger than 0"); + PADDLE_ENFORCE_EQ(step_ids.size(), step_scores.size(), + "step_ids and step_scores should be the same"); + const size_t step_num = step_ids.size(); + const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1; + + PADDLE_ENFORCE_GT(src_num, 0UL, "source num should be larger than 0"); + + // previous prefixes for each step, + // the init length is 0, means this is the first step. + std::vector> beamnode_vector_list(0); + std::vector> sentence_vector_list(src_num); + + // pack all steps for one batch first, then another batch + for (size_t step_id = 0; step_id < step_num; ++step_id) { + beamnode_vector_list = + PackTwoSteps(step_ids.at(step_id), step_scores.at(step_id), + beamnode_vector_list, &sentence_vector_list); + } + // append last beam_node to result + for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { + for (auto& beam_node : beamnode_vector_list.at(src_idx)) { + sentence_vector_list[src_idx].push_back(MakeSentence(beam_node.get())); + beam_node.reset(); + } + } + + ConvertSentenceVectorToLodTensor(sentence_vector_list, id_tensor, + score_tensor); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/beam_search_decode_op_test.cc b/paddle/operators/beam_search_decode_op_test.cc new file mode 100644 index 0000000000..5ac23991f3 --- /dev/null +++ b/paddle/operators/beam_search_decode_op_test.cc @@ -0,0 +1,221 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/beam_search_decode_op.h" +#include "gtest/gtest.h" + +using CPUPlace = paddle::platform::CPUPlace; +using LoD = paddle::framework::LoD; +using LoDTensor = paddle::framework::LoDTensor; +using LoDTensorArray = paddle::framework::LoDTensorArray; + +template +using BeamNode = paddle::operators::BeamNode; +template +using BeamSearchDecoder = paddle::operators::BeamSearchDecoder; +template +using Sentence = paddle::operators::Sentence; +template +using BeamNodeVector = paddle::operators::BeamNodeVector; +template +using SentenceVector = paddle::operators::SentenceVector; + +namespace paddle { +namespace test { + +void GenerateExample(const std::vector& level_0, + const std::vector& level_1, + const std::vector& data, LoDTensorArray* ids, + LoDTensorArray* scores) { + PADDLE_ENFORCE_EQ(level_0.back(), level_1.size() - 1, + "source level is used to describe candidate set"); + PADDLE_ENFORCE_EQ(level_1.back(), data.size(), + "the lowest level is used to describe data" + ", so it's last element should be data length"); + + CPUPlace place; + + LoD lod; + lod.push_back(level_0); + lod.push_back(level_1); + + // Ids + LoDTensor tensor_id; + tensor_id.set_lod(lod); + tensor_id.Resize({static_cast(data.size())}); + // malloc memory + int64_t* id_ptr = tensor_id.mutable_data(place); + for (size_t i = 0; i < data.size(); ++i) { + id_ptr[i] = static_cast(data.at(i)); + } + + // Scores + LoDTensor tensor_score; + tensor_score.set_lod(lod); + tensor_score.Resize({static_cast(data.size())}); + // malloc memory + float* score_ptr = tensor_score.mutable_data(place); + for (size_t i = 0; i < data.size(); ++i) { + score_ptr[i] = static_cast(data.at(i)); + } + + ids->push_back(tensor_id); + scores->push_back(tensor_score); +} + +} // namespace test +} // namespace paddle + +TEST(BeamSearchDecodeOp, DeleteBeamNode) { + auto* root = new BeamNode(0, 0); + auto* b1 = new BeamNode(1, 1); + auto* b2 = new BeamNode(2, 2); + auto* b3 = new BeamNode(3, 3); + + b1->AppendTo(root); + b2->AppendTo(root); + b3->AppendTo(b1); + + delete b3; + delete b2; +} + +TEST(BeamSearchDecodeOp, MakeSentence) { + auto* root = new BeamNode(0, 0); + auto* b1 = new BeamNode(1, 1); + auto* end = new BeamNode(2, 2); + b1->AppendTo(root); + end->AppendTo(b1); + + BeamSearchDecoder helper; + Sentence sentence = helper.MakeSentence(end); + delete end; + + std::vector expect_ids = {0, 1, 2}; + ASSERT_EQ(sentence.word_ids, expect_ids); + + std::vector expect_scores = {0, 1, 2}; + ASSERT_EQ(sentence.scores, expect_scores); +} + +TEST(BeamSearchDecodeOp, PackTwoStepsFistStep) { + CPUPlace place; + + LoDTensorArray ids; + LoDTensorArray scores; + + paddle::test::GenerateExample( + std::vector{0, 2, 6}, std::vector{0, 1, 2, 3, 4, 5, 6}, + std::vector{1, 2, 3, 4, 5, 6}, &ids, &scores); + + std::vector> beamnode_vector_list; + std::vector> sentence_vector_list( + 2, SentenceVector()); + + BeamSearchDecoder helper; + beamnode_vector_list = helper.PackTwoSteps( + ids[0], scores[0], beamnode_vector_list, &sentence_vector_list); + ASSERT_EQ(beamnode_vector_list.size(), 2UL); + ASSERT_EQ(beamnode_vector_list[0].size(), 2UL); + ASSERT_EQ(beamnode_vector_list[1].size(), 4UL); +} + +TEST(BeamSearchDecodeOp, PackTwoSteps) { + CPUPlace place; + + // first source has three prefix + BeamNodeVector source0_prefixes; + source0_prefixes.push_back( + std::unique_ptr>(new BeamNode(1, 1))); + source0_prefixes.push_back( + std::unique_ptr>(new BeamNode(0, 0))); + source0_prefixes.push_back( + std::unique_ptr>(new BeamNode(3, 3))); + + // second source has two prefix + BeamNodeVector source1_prefixes; + source1_prefixes.push_back( + std::unique_ptr>(new BeamNode(4, 4))); + source1_prefixes.push_back( + std::unique_ptr>(new BeamNode(5, 5))); + + std::vector> beamnode_vector_list; + std::vector> sentence_vector_list( + 2, SentenceVector()); + + beamnode_vector_list.push_back(std::move(source0_prefixes)); + beamnode_vector_list.push_back(std::move(source1_prefixes)); + + // generate data for one step + LoDTensorArray ids; + LoDTensorArray scores; + + paddle::test::GenerateExample(std::vector{0, 3, 5}, + std::vector{0, 1, 1, 3, 4, 5}, + std::vector{0, 1, 2, 3, 4}, &ids, &scores); + + BeamSearchDecoder helper1; + beamnode_vector_list = helper1.PackTwoSteps( + ids[0], scores[0], beamnode_vector_list, &sentence_vector_list); + + ASSERT_EQ(sentence_vector_list[0].size(), 1UL); + ASSERT_EQ(sentence_vector_list[1].size(), 0UL); + ASSERT_EQ(beamnode_vector_list[0].size(), 3UL); + ASSERT_EQ(beamnode_vector_list[1].size(), 2UL); +} + +TEST(BeamSearchDecodeOp, PackAllSteps) { + CPUPlace place; + + // we will constuct a sample data with 3 steps and 2 source sentences + LoDTensorArray ids; + LoDTensorArray scores; + + paddle::test::GenerateExample( + std::vector{0, 3, 6}, std::vector{0, 1, 2, 3, 4, 5, 6}, + std::vector{1, 2, 3, 4, 5, 6}, &ids, &scores); + paddle::test::GenerateExample( + std::vector{0, 3, 6}, std::vector{0, 1, 1, 3, 5, 5, 6}, + std::vector{0, 1, 2, 3, 4, 5}, &ids, &scores); + paddle::test::GenerateExample(std::vector{0, 3, 6}, + std::vector{0, 0, 1, 2, 3, 4, 5}, + std::vector{0, 1, 2, 3, 4}, &ids, &scores); + + ASSERT_EQ(ids.size(), 3UL); + ASSERT_EQ(scores.size(), 3UL); + + BeamSearchDecoder helper; + + LoDTensor id_tensor; + LoDTensor score_tensor; + helper.PackAllSteps(ids, scores, &id_tensor, &score_tensor); + + LoD lod = id_tensor.lod(); + std::vector expect_source_lod = {0, 4, 8}; + EXPECT_EQ(lod[0], expect_source_lod); + std::vector expect_sentence_lod = {0, 1, 3, 6, 9, 10, 13, 16, 19}; + EXPECT_EQ(lod[1], expect_sentence_lod); + // 2| 1, 0| 3, 1, 0| 3, 2, 1| 5| 4, 3, 2| 4, 4, 3| 6, 5, 4 + std::vector expect_data = {2, 1, 0, 3, 1, 0, 3, 2, 1, 5, + 4, 3, 2, 4, 4, 3, 6, 5, 4}; + ASSERT_EQ(id_tensor.dims()[0], static_cast(expect_data.size())); + for (size_t i = 0; i < expect_data.size(); ++i) { + ASSERT_EQ(id_tensor.data()[i], + static_cast(expect_data[i])); + } + for (int64_t i = 0; i < id_tensor.dims()[0]; ++i) { + ASSERT_EQ(score_tensor.data()[i], + static_cast(id_tensor.data()[i])); + } +} diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc new file mode 100644 index 0000000000..c65ba7eb26 --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class BilinearTensorProductOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL, + "The input(Weight) must be a 3D tensor."); + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The first dimension(batch_size) of input(X) must be " + "equal to the first dimension of the input(Y)."); + PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], + "The second dimension of input(X) must be equal to " + "the second dimension of the input(Weight)."); + PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], + "The second dimension of input(Y) must be equal to " + "the third dimension of the input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL, + "The Input(Bias) must be a 2-D tensor with " + "the 2nd dimension fixed to 1 (a row vector)."); + PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], + "The second dimension of input(Bias) must be equal " + "to the first dimension of the input(Weight)."); + } + + ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BilinearTensorProductOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of bilinear_tensor_product operator."); + AddInput("Y", "The second input of bilinear_tensor_product operator."); + AddInput("Weight", + "The learnable parameters of bilinear_tensor_product operator."); + AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.") + .AsDispensable(); + AddOutput("Out", "The output of bilinear_tensor_product operator."); + AddComment(R"DOC( +Bilinear Tensor Product operator. +Given input X and Y, a 3D tensor weight, and bias. Each column of the +output is computed by one slice i = 1, . . . , k of the tensor: + + M = (X W_i) \cdot Y + Out_i = \sum_i {M_i} + Bias_i + +)DOC"); + } +}; + +class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, + "The input(Out@GRAD) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ( + x_dims[0], out_dims[0], + "The first dimension(batch_size) of input(Out@GRAD) must be " + "equal to the first dimension of the Input(X)."); + PADDLE_ENFORCE_EQ( + weight_dims[0], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the third dimension of the Input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dims[1], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the second dimension of the Input(Bias)."); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + auto weight_grad_name = framework::GradVarName("Weight"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + if (ctx->HasOutput(weight_grad_name)) { + ctx->SetOutputDim(weight_grad_name, weight_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp, + ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad, + ops::BilinearTensorProductOpGrad); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu new file mode 100644 index 0000000000..858d2668d0 --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h new file mode 100644 index 0000000000..ffa4f43a32 --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -0,0 +1,184 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +class BilinearTensorProductKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto y_mat = EigenMatrix::From(*y); + auto output_mat = EigenMatrix::From(*out); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the result of + // Input(X) multiplied by Input(Weight_i), the formula is: + // left_mul = X Weight_i. + Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + + for (int i = 0; i < out_dim; ++i) { + auto output_col_vec = output_mat.chip(i, 1); + Tensor weight_mat = + weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim})); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x->data(), + weight_mat.data(), 0, left_mul.data()); + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + } + if (bias) { + auto bias_vec = EigenMatrix::From(*bias); + Eigen::DSizes bcast(batch_size, 1); + output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + } + } +}; + +template +class BilinearTensorProductGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + const Tensor* weight = ctx.Input("Weight"); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + Tensor* d_y = ctx.Output(framework::GradVarName("Y")); + Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); + Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + + auto x_mat = EigenMatrix::From(*x); + auto y_mat = EigenMatrix::From(*y); + auto d_out_mat = EigenMatrix::From(*d_out); + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the Output(Y@Grad). + Tensor x_scale; + x_scale.mutable_data(framework::make_ddim({batch_size, x_dim}), + ctx.GetPlace()); + auto x_scale_mat = EigenMatrix::From(x_scale); + + // Create the intermediate variable to caculate the Output(X@Grad). + Tensor y_scale; + y_scale.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto y_scale_mat = EigenMatrix::From(y_scale); + + math::SetConstant set_zero; + + // Set Output(X@Grad) be zero. + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_x, static_cast(0)); + } + + // Set Output(Y@Grad) be zero. + if (d_y) { + d_y->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_y, static_cast(0)); + } + + // Caculate the Output(X@Grad) and Output(Y@Grad). + if (d_x || d_y) { + Eigen::DSizes bcast_for_x(1, y_dim); + Eigen::DSizes bcast_for_y(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor weight_i = weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + if (d_x) { + y_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_x) * + y_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, + batch_size, x_dim, y_dim, 1, y_scale.data(), + weight_i.data(), 1, d_x->data()); + } + if (d_y) { + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_y) * + x_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x_scale.data(), + weight_i.data(), 1, d_y->data()); + } + } + } + + // Caculate the gradient of Input(Weight). + if (d_weight) { + d_weight->mutable_data(ctx.GetPlace()); + Eigen::DSizes bcast_for_weight(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_weight) * + x_mat; + math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, + x_dim, y_dim, batch_size, 1, x_scale.data(), + y->data(), 0, d_weight_i.data()); + } + } + + // Caculate the gradient of Input(Bias). + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + auto d_bias_mat = EigenMatrix::From(*d_bias); + d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes(0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index 716b5ee92d..bf7e883681 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -94,5 +94,13 @@ class CompareOp : public framework::OperatorWithKernel { REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); +REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); +REGISTER_LOGICAL_OP(greater_than, "Out = X > Y"); +REGISTER_LOGICAL_KERNEL(greater_than, CPU, + paddle::operators::GreaterThanFunctor); +REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y"); +REGISTER_LOGICAL_KERNEL(greater_equal, CPU, + paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.cu b/paddle/operators/compare_op.cu index 42a5bb2f45..6ac8c124b9 100644 --- a/paddle/operators/compare_op.cu +++ b/paddle/operators/compare_op.cu @@ -15,4 +15,9 @@ #include "paddle/operators/compare_op.h" REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor); +REGISTER_LOGICAL_KERNEL(greater_than, GPU, + paddle::operators::GreaterThanFunctor); +REGISTER_LOGICAL_KERNEL(greater_equal, GPU, + paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h index 04e04e347b..afdf3ab3e0 100644 --- a/paddle/operators/compare_op.h +++ b/paddle/operators/compare_op.h @@ -27,6 +27,24 @@ struct LessThanFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } }; +template +struct LessEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } +}; + +template +struct GreaterThanFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } +}; + +template +struct GreaterEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } +}; + template struct EqualFunctor { using ELEM_TYPE = T; diff --git a/paddle/operators/conditional_block_op.cc b/paddle/operators/conditional_block_op.cc new file mode 100644 index 0000000000..d5b124682d --- /dev/null +++ b/paddle/operators/conditional_block_op.cc @@ -0,0 +1,197 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include +#include "paddle/framework/executor.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class ConditionalOp : public framework::OperatorBase { + public: + ConditionalOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + protected: + std::vector InputTensors( + const framework::Scope &scope) const { + std::vector retv; + auto xs = Inputs("X"); + retv.resize(xs.size(), nullptr); + std::transform( + xs.begin(), xs.end(), retv.begin(), + [&scope](const std::string &var_name) -> const framework::LoDTensor * { + auto *var = scope.FindVar(var_name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name); + return &var->Get(); + }); + return retv; + } +}; + +class ConditionalBlockOp : public ConditionalOp { + public: + ConditionalBlockOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : ConditionalOp(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto xs = InputTensors(scope); + bool need_run = std::all_of( + xs.begin(), xs.end(), + [](const framework::LoDTensor *t) { return t->numel() != 0; }); + + if (need_run) { + auto *scope_var = scope.FindVar(Output("Scope")); + PADDLE_ENFORCE(scope_var != nullptr, "Must set scope"); + auto *scopes = scope_var->GetMutable>(); + scopes->resize(1); + scopes->front() = &scope.NewScope(); + auto &cur_scope = *scopes->front(); + + auto *block = Attr("block"); + framework::Executor exec(dev_ctx); + exec.Run(*block->Program(), &cur_scope, block->ID(), false); + } + } +}; + +class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + ConditionalBlockOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The conditional variable of this operator. If X is empty, the " + "whole sub-block will not be executed.") + .AsDuplicable(); + AddInput("Params", "The input variables of the sub-block.").AsDuplicable(); + AddOutput("Out", "The output variables of the sub-block.").AsDuplicable(); + AddOutput("Scope", + "(std::vector) The step scope of conditional block. To " + "unify the conditional block, rnn and while op, the type of " + "scope is std::vector"); + AddAttr( + "block", "The step block of conditional block operator"); + AddComment(R"DOC(Conditional block operator + +Run the sub-block if X is not empty. Params is the other inputs and Out is the +outputs of the sub-block. +)DOC"); + } +}; + +class ConditionalBlockGradOp : public ConditionalOp { + public: + ConditionalBlockGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : ConditionalOp(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto xs = this->InputTensors(scope); + bool need_run = std::all_of( + xs.begin(), xs.end(), + [](const framework::LoDTensor *t) { return t->numel() != 0; }); + + if (need_run) { + auto *scope_var = scope.FindVar(Input("Scope")); + PADDLE_ENFORCE(scope_var != nullptr, "Must set scope"); + auto &scopes = scope_var->Get>(); + framework::Scope &cur_scope = *scopes[0]; + + auto *block = Attr("block"); + framework::Executor exec(dev_ctx); + exec.Run(*block->Program(), &cur_scope, block->ID(), false); + + AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("Params"), + Outputs(framework::GradVarName("Params"))); + + AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("X"), + Outputs(framework::GradVarName("X"))); + } + } + + private: + void AssignLocalGradientToGlobal( + const platform::DeviceContext &dev_ctx, const framework::Scope &cur_scope, + const std::vector &p_names, + const std::vector &pg_names) const { + for (size_t i = 0; i < p_names.size(); ++i) { + auto out_grad_name = pg_names[i]; + auto in_grad_name = framework::GradVarName(p_names[i]); + auto *in_var = cur_scope.FindVar(in_grad_name); + if (in_var == nullptr) { + continue; + } + auto new_in_grad_name = cur_scope.Rename(in_grad_name); + auto assign = + framework::OpRegistry::CreateOp("assign", {{"X", {new_in_grad_name}}}, + {{"Out", {out_grad_name}}}, {}); + assign->Run(cur_scope, dev_ctx); + cur_scope.Rename(new_in_grad_name, in_grad_name); + } + } +}; + +class ConditionalBlockGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInputs("X")); + if (context->HasInputs("Params")) { + PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Params"))); + context->SetOutputsDim(framework::GradVarName("Params"), + context->GetInputsDim("Params")); + } + PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("X"))); + context->SetOutputsDim(framework::GradVarName("X"), + context->GetInputsDim("X")); + } +}; + +class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto grad_op = new framework::OpDescBind(); + grad_op->SetType("conditional_block_grad"); + grad_op->SetInput("X", Input("X")); + grad_op->SetInput("Params", Input("Params")); + grad_op->SetInput("Out", Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + grad_op->SetInput("Scope", Output("Scope")); + grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params")); + grad_op->SetBlockAttr("block", *this->grad_block_[0]); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp, + ops::ConditionalBlockOpProtoMaker, + ops::ConditionalBlockGradMaker); +REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp, + ops::ConditionalBlockGradInferShape); diff --git a/paddle/operators/elementwise_add_op.cc b/paddle/operators/elementwise_add_op.cc index ebe1de90c7..432b9ba6f7 100644 --- a/paddle/operators/elementwise_add_op.cc +++ b/paddle/operators/elementwise_add_op.cc @@ -34,7 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker, elementwise_add_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_add, - ops::ElementwiseAddKernel); + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel); REGISTER_OP_CPU_KERNEL( elementwise_add_grad, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel); diff --git a/paddle/operators/elementwise_div_op.cc b/paddle/operators/elementwise_div_op.cc index de75816a24..7a325199bd 100644 --- a/paddle/operators/elementwise_div_op.cc +++ b/paddle/operators/elementwise_div_op.cc @@ -35,7 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker, elementwise_div_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_div, - ops::ElementwiseDivKernel); + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel); REGISTER_OP_CPU_KERNEL( elementwise_div_grad, - ops::ElementwiseDivGradKernel); + ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel); diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc index ffa10486f1..8851267a52 100644 --- a/paddle/operators/elementwise_mul_op.cc +++ b/paddle/operators/elementwise_mul_op.cc @@ -37,8 +37,12 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker, REGISTER_OP_CPU_KERNEL( elementwise_mul, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel); REGISTER_OP_CPU_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel); + ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel); diff --git a/paddle/operators/elementwise_sub_op.cc b/paddle/operators/elementwise_sub_op.cc index 39702dad0e..95d7979e39 100644 --- a/paddle/operators/elementwise_sub_op.cc +++ b/paddle/operators/elementwise_sub_op.cc @@ -34,7 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker, elementwise_sub_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_sub, - ops::ElementwiseSubKernel); + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel); REGISTER_OP_CPU_KERNEL( elementwise_sub_grad, - ops::ElementwiseSubGradKernel); + ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel); diff --git a/paddle/operators/l1_norm_op.h b/paddle/operators/l1_norm_op.h index de459818ad..3c60dc3dc7 100644 --- a/paddle/operators/l1_norm_op.h +++ b/paddle/operators/l1_norm_op.h @@ -29,7 +29,7 @@ class L1NormKernel : public framework::OpKernel { Out->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten(*X); - auto out = framework::EigenVector::Flatten(*Out); + auto out = framework::EigenScalar::From(*Out); auto place = context.GetEigenDevice(); out.device(place) = x.abs().sum(); diff --git a/paddle/operators/lod_reset_op.cc b/paddle/operators/lod_reset_op.cc new file mode 100644 index 0000000000..32831cb1e2 --- /dev/null +++ b/paddle/operators/lod_reset_op.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lod_reset_op.h" + +namespace paddle { +namespace operators { + +class LoDResetOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + // input check + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of LoDResetOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LoDResetOp should not be null."); + // If target LoD is not set form Input(), then it must be set from Attr(). + if (!ctx->HasInput("TargetLoD")) { + auto level0 = ctx->Attrs().Get>("target_lod"); + PADDLE_ENFORCE(level0.size() > 1, + "Target LoD is not found, should be set to be a valid one " + "through Input() or Attr()."); + } + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LoDResetOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(LoDTensor) The input tensor of lod_reset operator."); + AddInput("TargetLoD", + "(Tensor, optional) The target level 0 LoD from Input().") + .AsDispensable(); + AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator."); + AddAttr>("target_lod", + "The target level 0 LoD from Attr().") + .SetDefault(std::vector{}); + AddComment(R"DOC(LoDReset operator + +Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or +Attr(target_lod), or set LoD for Input(X) if it doesn't have one. +Currently the lod_reset operator only supports the reset of level 0 LoD. +At least one of Input(TargetLoD) and Attr(target_lod) must be set, +and if both of them are set, Input(TargetLoD) will be chosen as the +target LoD. + +An example: +Given a float LoDTensor X with shape (6, 1), its transpose form represents + + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + +with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like + + [1.0, 2.0], [3.0, 4.0, 5.0], [6.0]. + +If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and +the sequences that the LoDTensor Output(Out) contains becomes: + + [1.0, 2.0, 3.0, 4.0], [5.0, 6.0]. + +)DOC"); + } +}; + +class LoDResetGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad, + ops::LoDResetGradOp); +REGISTER_OP_CPU_KERNEL(lod_reset, + ops::LoDResetKernel, + ops::LoDResetKernel); +REGISTER_OP_CPU_KERNEL( + lod_reset_grad, ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/operators/lod_reset_op.cu b/paddle/operators/lod_reset_op.cu new file mode 100644 index 0000000000..5244a17c3a --- /dev/null +++ b/paddle/operators/lod_reset_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lod_reset_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(lod_reset, + ops::LoDResetKernel, + ops::LoDResetKernel); +REGISTER_OP_GPU_KERNEL( + lod_reset_grad, ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/operators/lod_reset_op.h b/paddle/operators/lod_reset_op.h new file mode 100644 index 0000000000..2bb916ccee --- /dev/null +++ b/paddle/operators/lod_reset_op.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LoDResetKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out = ctx.Output("Out"); + auto* in = ctx.Input("X"); + auto* lod_t = ctx.Input("TargetLoD"); + + std::vector level0; + if (lod_t) { + auto* lod = lod_t->data(); + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::Tensor lod_cpu; + lod_cpu.CopyFrom(*lod_t, platform::CPUPlace(), ctx.device_context()); + lod = lod_cpu.data(); + } + level0 = std::vector(lod, lod + lod_t->numel()); + } else { + level0 = ctx.Attr>("target_lod"); + } + + PADDLE_ENFORCE(level0.size() > 1UL, + "The size of target LoD should be greater than 1."); + PADDLE_ENFORCE(level0[0] == 0, + "Target LoD should be a vector starting from 0."); + PADDLE_ENFORCE(level0.back() == in->dims()[0], + "Target LoD should be a vector end with the " + "first dimension of Input(X)."); + for (size_t i = 0; i < level0.size() - 1; ++i) { + PADDLE_ENFORCE(level0[i + 1] > level0[i], + "Target LoD should be an ascending vector."); + } + + out->ShareDataWith(*in); + // cast level0 to size_t + std::vector ulevel0(level0.size(), 0); + std::transform(level0.begin(), level0.end(), ulevel0.begin(), + [](int a) { return static_cast(a); }); + framework::LoD target_lod; + target_lod.push_back(ulevel0); + out->set_lod(target_lod); + } +}; + +template +class LoDResetGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* d_x = ctx.Output(framework::GradVarName("X")); + + d_x->ShareDataWith(*d_out); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 50cfb88bb5..ead89e146f 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -27,15 +27,15 @@ template class Pool2dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -47,7 +47,7 @@ class Pool2dFunctor { const int output_stride = output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -87,11 +87,12 @@ template class Pool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_grad_process) { + PoolProcess pool_grad_process, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -110,7 +111,7 @@ class Pool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -154,10 +155,11 @@ template class MaxPool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -176,7 +178,7 @@ class MaxPool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -240,17 +242,17 @@ template class Pool3dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -265,7 +267,7 @@ class Pool3dFunctor { const int output_stride = output_depth * output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -315,11 +317,12 @@ template class Pool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_grad_process) { + PoolProcess pool_grad_process, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -343,7 +346,7 @@ class Pool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -398,10 +401,11 @@ template class MaxPool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -425,7 +429,7 @@ class MaxPool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -498,15 +502,15 @@ template class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -517,8 +521,8 @@ class MaxPool2dWithIndexFunctor { const int output_stride = output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -563,13 +567,13 @@ template class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_height = input_grad.dims()[2]; - const int input_width = input_grad.dims()[3]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_height = input_grad->dims()[2]; + const int input_width = input_grad->dims()[3]; const int output_channels = output_grad.dims()[1]; const int output_height = output_grad.dims()[2]; const int output_width = output_grad.dims()[3]; @@ -578,7 +582,7 @@ class MaxPool2dWithIndexGradFunctor { const T* mask_data = mask.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int n = 0; n < batch_size; ++n) { for (int c = 0; c < output_channels; ++c) { @@ -612,17 +616,17 @@ template class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -636,8 +640,8 @@ class MaxPool3dWithIndexFunctor { const int output_stride = output_depth * output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -691,14 +695,14 @@ template class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_depth = input_grad.dims()[2]; - const int input_height = input_grad.dims()[3]; - const int input_width = input_grad.dims()[4]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_depth = input_grad->dims()[2]; + const int input_height = input_grad->dims()[3]; + const int input_width = input_grad->dims()[4]; const int output_channels = output_grad.dims()[1]; const int output_depth = output_grad.dims()[2]; const int output_height = output_grad.dims()[3]; @@ -708,7 +712,7 @@ class MaxPool3dWithIndexGradFunctor { const T* mask_data = mask.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int n = 0; n < batch_size; ++n) { for (int c = 0; c < output_channels; ++c) { diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 736327f4b7..6d1138ad50 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -21,13 +21,13 @@ namespace math { template __global__ void KernelPool2D(const int nthreads, const T* input_data, - T* output_data, const int channels, - const int input_height, const int input_width, - const int output_height, const int output_width, - const int ksize_height, const int ksize_width, - const int stride_height, const int stride_width, - const int padding_height, const int padding_width, - PoolProcess pool_process) { + const int channels, const int input_height, + const int input_width, const int output_height, + const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, + const int stride_width, const int padding_height, + const int padding_width, PoolProcess pool_process, + T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -59,11 +59,11 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, template __global__ void KernelPool2DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_height, const int input_width, const int output_height, - const int output_width, const int ksize_height, const int ksize_width, - const int stride_height, const int stride_width, const int padding_height, - const int padding_width, PoolProcess pool_process) { + const T* output_grad, const int channels, const int input_height, + const int input_width, const int output_height, const int output_width, + const int ksize_height, const int ksize_width, const int stride_height, + const int stride_width, const int padding_height, const int padding_width, + PoolProcess pool_process, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -107,11 +107,11 @@ __global__ void KernelPool2DGrad( template __global__ void KernelMaxPool2DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_height, const int input_width, const int output_height, - const int output_width, const int ksize_height, const int ksize_width, - const int stride_height, const int stride_width, const int padding_height, - const int padding_width) { + const T* output_grad, const int channels, const int input_height, + const int input_width, const int output_height, const int output_width, + const int ksize_height, const int ksize_width, const int stride_height, + const int stride_width, const int padding_height, const int padding_width, + T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -158,16 +158,16 @@ template class Pool2dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -176,7 +176,7 @@ class Pool2dFunctor { const int padding_width = paddings[1]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_height * output_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -187,11 +187,10 @@ class Pool2dFunctor { PoolProcess, T><<(context) - .stream()>>>(nthreads, input_data, output_data, input_channels, - input_height, input_width, output_height, - output_width, ksize_height, ksize_width, - stride_height, stride_width, padding_height, - padding_width, pool_process); + .stream()>>>( + nthreads, input_data, input_channels, input_height, input_width, + output_height, output_width, ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width, pool_process, output_data); } }; @@ -204,11 +203,11 @@ template class Pool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_process) { + PoolProcess pool_process, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -225,7 +224,7 @@ class Pool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_height * input_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -237,10 +236,10 @@ class Pool2dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_height, input_width, output_height, output_width, - ksize_height, ksize_width, stride_height, stride_width, padding_height, - padding_width, pool_process); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_height, input_width, output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, padding_height, padding_width, + pool_process, input_grad_data); } }; @@ -253,10 +252,11 @@ template class MaxPool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -274,7 +274,7 @@ class MaxPool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_height * output_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -285,10 +285,10 @@ class MaxPool2dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_height, input_width, output_height, output_width, - ksize_height, ksize_width, stride_height, stride_width, padding_height, - padding_width); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_height, input_width, output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, padding_height, padding_width, + input_grad_data); } }; @@ -313,14 +313,16 @@ template class Pool2dGradFunctor< platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; template -__global__ void KernelPool3D( - const int nthreads, const T* input_data, T* output_data, const int channels, - const int input_depth, const int input_height, const int input_width, - const int output_depth, const int output_height, const int output_width, - const int ksize_depth, const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, const int stride_width, - const int padding_depth, const int padding_height, const int padding_width, - PoolProcess pool_process) { +__global__ void KernelPool3D(const int nthreads, const T* input_data, + const int channels, const int input_depth, + const int input_height, const int input_width, + const int output_depth, const int output_height, + const int output_width, const int ksize_depth, + const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, + const int padding_height, const int padding_width, + PoolProcess pool_process, T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -358,13 +360,13 @@ __global__ void KernelPool3D( template __global__ void KernelPool3DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_depth, const int input_height, const int input_width, - const int output_depth, const int output_height, const int output_width, - const int ksize_depth, const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, const int stride_width, - const int padding_depth, const int padding_height, const int padding_width, - PoolProcess pool_process) { + const T* output_grad, const int channels, const int input_depth, + const int input_height, const int input_width, const int output_depth, + const int output_height, const int output_width, const int ksize_depth, + const int ksize_height, const int ksize_width, const int stride_depth, + const int stride_height, const int stride_width, const int padding_depth, + const int padding_height, const int padding_width, PoolProcess pool_process, + T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -422,13 +424,12 @@ __global__ void KernelPool3DGrad( template __global__ void KernelMaxPool3DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_depth, const int input_height, const int input_width, - const int output_depth, const int output_height, const int output_width, - const int ksize_depth, const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, const int stride_width, - const int padding_depth, const int padding_height, - const int padding_width) { + const T* output_grad, const int channels, const int input_depth, + const int input_height, const int input_width, const int output_depth, + const int output_height, const int output_width, const int ksize_depth, + const int ksize_height, const int ksize_width, const int stride_depth, + const int stride_height, const int stride_width, const int padding_depth, + const int padding_height, const int padding_width, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -480,18 +481,18 @@ template class Pool3dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -503,7 +504,7 @@ class Pool3dFunctor { const int padding_width = paddings[2]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -516,11 +517,11 @@ class Pool3dFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, input_channels, input_depth, - input_height, input_width, output_depth, output_height, output_width, - ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - pool_process); + nthreads, input_data, input_channels, input_depth, input_height, + input_width, output_depth, output_height, output_width, ksize_depth, + ksize_height, ksize_width, stride_depth, stride_height, stride_width, + padding_depth, padding_height, padding_width, pool_process, + output_data); } }; @@ -533,11 +534,11 @@ template class Pool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_process) { + PoolProcess pool_process, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -560,7 +561,7 @@ class Pool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_depth * input_height * input_width; @@ -573,11 +574,11 @@ class Pool3dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_depth, input_height, input_width, output_depth, - output_height, output_width, ksize_depth, ksize_height, ksize_width, - stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, pool_process); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width, pool_process, input_grad_data); } }; @@ -590,10 +591,11 @@ template class MaxPool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -616,7 +618,7 @@ class MaxPool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -628,11 +630,11 @@ class MaxPool3dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_depth, input_height, input_width, output_depth, - output_height, output_width, ksize_depth, ksize_height, ksize_width, - stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width, input_grad_data); } }; @@ -658,11 +660,11 @@ template class Pool3dGradFunctor< template __global__ void KernelMaxPool2dWithIdx( - const int nthreads, const T* input_data, T* output_data, T* mask_data, - const int channels, const int input_height, const int input_width, - const int output_height, const int output_width, const int ksize_height, - const int ksize_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width) { + const int nthreads, const T* input_data, const int channels, + const int input_height, const int input_width, const int output_height, + const int output_width, const int ksize_height, const int ksize_width, + const int stride_height, const int stride_width, const int padding_height, + const int padding_width, T* output_data, T* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -697,11 +699,11 @@ __global__ void KernelMaxPool2dWithIdx( template __global__ void KernelMaxPool2DWithIdxGrad( - const int nthreads, T* input_grad, const T* output_grad, const T* mask_data, + const int nthreads, const T* output_grad, const T* mask_data, const int channels, const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width) { + const int padding_height, const int padding_width, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; @@ -748,16 +750,16 @@ template class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -766,8 +768,8 @@ class MaxPool2dWithIndexFunctor { const int padding_width = paddings[1]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_height * output_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -777,11 +779,10 @@ class MaxPool2dWithIndexFunctor { KernelMaxPool2dWithIdx< T><<(context) - .stream()>>>(nthreads, input_data, output_data, mask_data, - input_channels, input_height, input_width, - output_height, output_width, ksize_height, - ksize_width, stride_height, stride_width, - padding_height, padding_width); + .stream()>>>( + nthreads, input_data, input_channels, input_height, input_width, + output_height, output_width, ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width, output_data, mask_data); } }; @@ -794,14 +795,14 @@ template class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_channels = input_grad.dims()[1]; - const int input_height = input_grad.dims()[2]; - const int input_width = input_grad.dims()[3]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_channels = input_grad->dims()[1]; + const int input_height = input_grad->dims()[2]; + const int input_width = input_grad->dims()[3]; const int output_height = output_grad.dims()[2]; const int output_width = output_grad.dims()[3]; const int ksize_height = ksize[0]; @@ -813,7 +814,7 @@ class MaxPool2dWithIndexGradFunctor { const T* mask_data = mask.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_height * input_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -823,11 +824,11 @@ class MaxPool2dWithIndexGradFunctor { KernelMaxPool2DWithIdxGrad< T><<(context) - .stream()>>>(nthreads, input_grad_data, output_grad_data, - mask_data, input_channels, input_height, - input_width, output_height, output_width, - ksize_height, ksize_width, stride_height, - stride_width, padding_height, padding_width); + .stream()>>>(nthreads, output_grad_data, mask_data, + input_channels, input_height, input_width, + output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, + padding_height, padding_width, input_grad_data); } }; @@ -838,13 +839,13 @@ template class MaxPool2dWithIndexGradFunctor; template __global__ void KernelMaxPool3DWithIdx( - const int nthreads, const T* input_data, T* output_data, T* mask_data, - const int channels, const int input_depth, const int input_height, - const int input_width, const int output_depth, const int output_height, - const int output_width, const int ksize_depth, const int ksize_height, - const int ksize_width, const int stride_depth, const int stride_height, - const int stride_width, const int padding_depth, const int padding_height, - const int padding_width) { + const int nthreads, const T* input_data, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, const int padding_width, + T* output_data, T* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -886,13 +887,13 @@ __global__ void KernelMaxPool3DWithIdx( template __global__ void KernelMaxPool3DWithIdxGrad( - const int nthreads, T* input_grad, const T* output_grad, const T* mask, - const int channels, const int input_depth, const int input_height, - const int input_width, const int output_depth, const int output_height, - const int output_width, const int ksize_depth, const int ksize_height, - const int ksize_width, const int stride_depth, const int stride_height, - const int stride_width, const int padding_depth, const int padding_height, - const int padding_width) { + const int nthreads, const T* output_grad, const T* mask, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, const int padding_width, + T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; @@ -952,18 +953,18 @@ template class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -975,8 +976,8 @@ class MaxPool3dWithIndexFunctor { const int padding_width = paddings[2]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -988,11 +989,10 @@ class MaxPool3dWithIndexFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, mask_data, input_channels, - input_depth, input_height, input_width, output_depth, output_height, - output_width, ksize_depth, ksize_height, ksize_width, stride_depth, - stride_height, stride_width, padding_depth, padding_height, - padding_width); + nthreads, input_data, input_channels, input_depth, input_height, + input_width, output_depth, output_height, output_width, ksize_depth, + ksize_height, ksize_width, stride_depth, stride_height, stride_width, + padding_depth, padding_height, padding_width, output_data, mask_data); } }; @@ -1005,15 +1005,15 @@ template class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_channels = input_grad.dims()[1]; - const int input_depth = input_grad.dims()[2]; - const int input_height = input_grad.dims()[3]; - const int input_width = input_grad.dims()[4]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_channels = input_grad->dims()[1]; + const int input_depth = input_grad->dims()[2]; + const int input_height = input_grad->dims()[3]; + const int input_width = input_grad->dims()[4]; const int output_depth = output_grad.dims()[2]; const int output_height = output_grad.dims()[3]; const int output_width = output_grad.dims()[4]; @@ -1029,7 +1029,7 @@ class MaxPool3dWithIndexGradFunctor { const T* output_grad_data = output_grad.data(); const T* mask_data = mask.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_depth * input_height * input_width; @@ -1041,11 +1041,11 @@ class MaxPool3dWithIndexGradFunctor { T><<(context) .stream()>>>( - nthreads, input_grad_data, output_grad_data, mask_data, input_channels, - input_depth, input_height, input_width, output_depth, output_height, - output_width, ksize_depth, ksize_height, ksize_width, stride_depth, - stride_height, stride_width, padding_depth, padding_height, - padding_width); + nthreads, output_grad_data, mask_data, input_channels, input_depth, + input_height, input_width, output_depth, output_height, output_width, + ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + input_grad_data); } }; diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index c50c57b5c5..f6719e1e62 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -88,60 +88,62 @@ template class Pool2dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_compute); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_compute, framework::Tensor* output); }; template class Pool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_compute); + PoolProcess pool_compute, framework::Tensor* input_grad); }; template class MaxPool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; template class Pool3dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_compute); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_compute, framework::Tensor* output); }; template class Pool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_compute); + PoolProcess pool_compute, framework::Tensor* input_grad); }; template class MaxPool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; /* @@ -155,38 +157,38 @@ template class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask); }; template class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; template class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask); }; template class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; } // namespace math diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 5ce30740c9..4f565946d5 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -74,11 +74,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context, Tensor output; auto in_dims = input.dims(); if (in_dims.size() == 3) { - output.Resize(in_dims); + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); output.mutable_data(context.GetPlace()); EigenTranspose(context, input, output, {1, 0, 2}); - std::vector out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; - output.Resize(make_ddim(out_dims)); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); } else { output.ShareDataWith(input); } diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc new file mode 100644 index 0000000000..80460c4769 --- /dev/null +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -0,0 +1,182 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +using LoD = framework::LoD; + +class MergeLoDTensorOp : public framework::OperatorBase { + public: + MergeLoDTensorOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto &x = scope.FindVar(Input("X"))->Get(); + auto &mask = scope.FindVar(Input("Mask"))->Get(); + auto &in_true = scope.FindVar(Input("InTrue"))->Get(); + auto &in_false = + scope.FindVar(Input("InFalse"))->Get(); + auto *out = + scope.FindVar(Output("Out"))->GetMutable(); + auto level = static_cast(Attr("level")); + + auto &mask_dim = mask.dims(); + + std::unique_ptr cpu_mask{new framework::LoDTensor()}; + if (platform::is_cpu_place(mask.place())) { + cpu_mask->ShareDataWith(mask); + } else if (platform::is_gpu_place(mask.place())) { +#ifdef PADDLE_WITH_CUDA + cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx); +#else + PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); +#endif + } + auto *mask_data = cpu_mask->data(); + + int rank = in_true.dims().size(); + platform::Place place = in_true.place(); + std::type_index data_type = in_true.type(); + framework::DDim in_true_dims = + framework::slice_ddim(in_true.dims(), 1, rank); + + int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; + + auto in_true_dim_vec = framework::vectorize(in_true_dims); + in_true_dim_vec.insert(in_true_dim_vec.begin(), batch_size); + + framework::DDim out_dims = framework::make_ddim(in_true_dim_vec); + out->Resize(out_dims); + out->mutable_data(place, data_type); + + auto *out_lod = out->mutable_lod(); + out_lod->clear(); + size_t out_offset = 0; + + // Build LoDTensor `out` + + size_t in_true_idx = 0; + size_t in_false_idx = 0; + for (size_t i = 0; i < static_cast(mask_dim[0]); i++) { + const framework::LoDTensor *input = nullptr; + size_t *in_idx = nullptr; + if (static_cast(mask_data[i]) == 0) { + input = &in_false; + in_idx = &in_false_idx; + } else { + input = &in_true; + in_idx = &in_true_idx; + } + auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset( + input->lod(), *in_idx, (*in_idx) + 1, 0); + auto &lod_length = lod_and_offset.first; + + framework::AppendLoD(out_lod, lod_length); + + size_t start_offset = lod_and_offset.second.first; + size_t end_offset = lod_and_offset.second.second; + + PADDLE_ENFORCE_GE(end_offset, start_offset); + size_t len = end_offset - start_offset; + if (len == 0) { + continue; + } + out->Slice(out_offset, out_offset + len) + .CopyFrom(input->Slice(start_offset, end_offset), place, dev_ctx); + out_offset += len; + (*in_idx) += 1; + } + + for (size_t i = 0; i < level; i++) { + out_lod->insert(out_lod->begin(), x.lod()[i]); + } + } +}; + +class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + MergeLoDTensorOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input LoDTensor, contains complete lod information to " + "construct the output"); + AddInput("Mask", "A bool column vector which mask the input"); + AddInput("InTrue", "The True branch to be merged"); + AddInput("InFalse", "The False branch to be merged"); + AddOutput("Out", "The merged output LoDTensor"); + AddAttr("level", "(int) the specific lod level to rank.") + .SetDefault(0) + .EqualGreaterThan(0); + AddComment( + R"DOC( + Merge True and False branches of LoDTensor into a single Output, + with a mask at certain lod level. X is used to obtain complete + lod information. Please refer to SplitLoDTensorOp.)DOC"); + } +}; + +class MergeLoDTensorInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInput("X"), + "MergeLoDTensorOp must has input X."); + PADDLE_ENFORCE(context->HasInput("Mask"), + "MergeLoDTensorOp must has input Mask."); + PADDLE_ENFORCE(context->HasInput("InTrue"), + "MergeLoDTensorOp must has input InTrue."); + PADDLE_ENFORCE(context->HasInput("InFalse"), + "MergeLoDTensorOp must has input InFalse."); + PADDLE_ENFORCE(context->HasOutput("Out"), + "MergeLoDTensorOp must has output Out"); + + auto mask_dim = context->GetInputDim("Mask"); + PADDLE_ENFORCE_EQ(mask_dim.size(), 2); + PADDLE_ENFORCE_EQ(mask_dim[1], 1); + + context->SetOutputDim("Out", context->GetInputDim("InTrue")); + } +}; + +class MergeLoDTensorGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDescBind(); + grad_op->SetType("split_lod_tensor"); + grad_op->SetInput("X", OutputGrad("Out")); + grad_op->SetInput("Mask", Input("Mask")); + grad_op->SetOutput("OutTrue", InputGrad("InTrue")); + grad_op->SetOutput("OutFalse", InputGrad("InFalse")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp, + ops::MergeLoDTensorOpProtoMaker, + ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker); diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index 4da1941ab5..63492a89e8 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -75,16 +75,16 @@ class PoolKernel : public framework::OpKernel { Place, paddle::operators::math::MaxPool, T> pool2d_forward; paddle::operators::math::MaxPool pool_process; - pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool2d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool2dFunctor< Place, paddle::operators::math::AvgPool, T> pool2d_forward; paddle::operators::math::AvgPool pool_process; - pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool2d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } } break; case 3: { @@ -93,15 +93,15 @@ class PoolKernel : public framework::OpKernel { Place, paddle::operators::math::MaxPool, T> pool3d_forward; paddle::operators::math::MaxPool pool_process; - pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool3d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool3dFunctor< Place, paddle::operators::math::AvgPool, T> pool3d_forward; paddle::operators::math::AvgPool pool_process; - pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool3d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } @@ -142,30 +142,30 @@ class PoolGradKernel : public framework::OpKernel { if (pooling_type == "max") { paddle::operators::math::MaxPool2dGradFunctor pool2d_backward; - pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings); + pool2d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, in_x_grad); } else if (pooling_type == "avg") { paddle::operators::math::Pool2dGradFunctor< Place, paddle::operators::math::AvgPoolGrad, T> pool2d_backward; paddle::operators::math::AvgPoolGrad pool_process; - pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings, pool_process); + pool2d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, pool_process, in_x_grad); } } break; case 3: { if (pooling_type == "max") { paddle::operators::math::MaxPool3dGradFunctor pool3d_backward; - pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings); + pool3d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, in_x_grad); } else if (pooling_type == "avg") { paddle::operators::math::Pool3dGradFunctor< Place, paddle::operators::math::AvgPoolGrad, T> pool3d_backward; paddle::operators::math::AvgPoolGrad pool_process; - pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings, pool_process); + pool3d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, pool_process, in_x_grad); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index ea37de84ab..c0e3b117dc 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -46,14 +46,14 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { case 2: { paddle::operators::math::MaxPool2dWithIndexFunctor pool2d_forward; - pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize, - strides, paddings); + pool2d_forward(context.device_context(), *in_x, ksize, strides, + paddings, out, mask); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexFunctor pool3d_forward; - pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, - strides, paddings); + pool3d_forward(context.device_context(), *in_x, ksize, strides, + paddings, out, mask); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } } @@ -89,14 +89,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { case 2: { paddle::operators::math::MaxPool2dWithIndexGradFunctor pool2d_backward; - pool2d_backward(context.device_context(), *in_x_grad, *out_grad, - *mask, ksize, strides, paddings); + pool2d_backward(context.device_context(), *out_grad, *mask, ksize, + strides, paddings, in_x_grad); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexGradFunctor pool3d_backward; - pool3d_backward(context.device_context(), *in_x_grad, *out_grad, - *mask, ksize, strides, paddings); + pool3d_backward(context.device_context(), *out_grad, *mask, ksize, + strides, paddings, in_x_grad); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } } diff --git a/paddle/operators/reduce_op.h b/paddle/operators/reduce_op.h index 45043c440b..dd6547542d 100644 --- a/paddle/operators/reduce_op.h +++ b/paddle/operators/reduce_op.h @@ -14,6 +14,7 @@ #pragma once +#include "glog/logging.h" #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" @@ -26,6 +27,10 @@ template using EigenTensor = framework::EigenTensor; +template +using EigenScalar = framework::EigenScalar; + struct SumFunctor { template void operator()(const Place& place, X& x, Y& y, const Dim& dim) { @@ -133,10 +138,17 @@ class ReduceKernel : public framework::OpKernel { dims_vector.erase(dims_vector.begin() + dim); dims = framework::make_ddim(dims_vector); } - auto out = EigenTensor < T, D == 1 ? 1 : (D - 1) > ::From(*output, dims); + auto& place = context.GetEigenDevice(); Functor functor; - functor(place, x, out, reduce_dim); + + if (D == 1) { + auto out = EigenScalar::From(*output); + functor(place, x, out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, dims); + functor(place, x, out, reduce_dim); + } } }; @@ -186,13 +198,13 @@ class ReduceGradKernel : public framework::OpKernel { auto x_reduce = EigenTensor::From(*input1, dims); auto x_reduce_grad = EigenTensor::From(*input2, dims); - Eigen::array braodcast_dim; - for (size_t i = 0; i < D; ++i) braodcast_dim[i] = 1; - braodcast_dim[dim] = input0->dims()[dim]; + Eigen::array broadcast_dim; + for (size_t i = 0; i < D; ++i) broadcast_dim[i] = 1; + broadcast_dim[dim] = input0->dims()[dim]; auto& place = context.GetEigenDevice(); Functor functor; - functor(place, x, x_reduce, x_grad, x_reduce_grad, braodcast_dim, - braodcast_dim[dim]); + functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim, + broadcast_dim[dim]); } }; diff --git a/paddle/operators/sequence_concat_op.cc b/paddle/operators/sequence_concat_op.cc index db737bed7a..d1de0b4447 100644 --- a/paddle/operators/sequence_concat_op.cc +++ b/paddle/operators/sequence_concat_op.cc @@ -47,7 +47,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "(vector) Input is a vector of LoDTensor, " + "(LodTensorArray) Input is a vector of LoDTensor, " "each of which is a variable-length sequence or nested sequence.") .AsDuplicable(); AddOutput("Out", diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 2b8a25c241..7f136d8cf0 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -126,6 +126,7 @@ class SequencePoolGradKernel : public framework::OpKernel { int64_t h = static_cast(lod[i + 1] - lod[i]); auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); + auto out_g_e_v = EigenVector::Flatten(out_g_t); Eigen::DSizes bcast(h, 1); if (pooltype == "AVERAGE") { @@ -136,9 +137,9 @@ class SequencePoolGradKernel : public framework::OpKernel { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); } else if (pooltype == "LAST") { - in_g_e.chip(h - 1, 0).device(place) = out_g_e; + in_g_e.chip(h - 1, 0).device(place) = out_g_e_v; } else if (pooltype == "FIRST") { - in_g_e.chip(0, 0).device(place) = out_g_e; + in_g_e.chip(0, 0).device(place) = out_g_e_v; } else { PADDLE_THROW("unsupported pooling pooltype"); } diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc new file mode 100644 index 0000000000..db635f2ba0 --- /dev/null +++ b/paddle/operators/split_lod_tensor_op.cc @@ -0,0 +1,186 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +struct CopyRange { + size_t begin; + size_t end; +}; + +using LoD = framework::LoD; + +class SplitLoDTensorOp : public framework::OperatorBase { + public: + SplitLoDTensorOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto &x = scope.FindVar(Input("X"))->Get(); + auto &mask = scope.FindVar(Input("Mask"))->Get(); + auto *out_true = + scope.FindVar(Output("OutTrue"))->GetMutable(); + auto *out_false = + scope.FindVar(Output("OutFalse"))->GetMutable(); + auto level = static_cast(Attr("level")); + auto &x_lod = x.lod(); + auto &mask_dim = mask.dims(); + + std::unique_ptr cpu_mask{new framework::LoDTensor()}; + if (platform::is_cpu_place(mask.place())) { + cpu_mask->ShareDataWith(mask); + } else if (platform::is_gpu_place(mask.place())) { +#ifdef PADDLE_WITH_CUDA + cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx); +#else + PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); +#endif + } + auto *mask_data = cpu_mask->data(); + + std::vector> copy_ranges(mask_dim[0]); + + // set out_true/out_false lod + for (size_t t = 0; t < 2; t++) { + LoD *lod = nullptr; + if (t == 0) { + lod = out_false->mutable_lod(); + } else { + lod = out_true->mutable_lod(); + } + lod->clear(); + for (size_t i = 0; i < static_cast(mask_dim[0]); i++) { + if (static_cast(mask_data[i]) == t) { + size_t start_idx = i; + auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset( + x_lod, start_idx, start_idx + 1, level); + + auto &lod_length = lod_and_offset.first; + framework::AppendLoD(lod, lod_length); + + size_t start_offset = lod_and_offset.second.first; + size_t end_offset = lod_and_offset.second.second; + copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset}); + } + } + } + + for (size_t t = 0; t < 2; ++t) { + framework::LoDTensor *out; + if (t == 0) { + out = out_false; + } else { + out = out_true; + } + auto &ranges = copy_ranges[t]; + size_t height = std::accumulate( + ranges.begin(), ranges.end(), 0UL, + [](size_t a, const CopyRange &b) { return a + b.end - b.begin; }); + auto x_dim = x.dims(); + x_dim[0] = static_cast(height); + out->Resize(x_dim); + out->mutable_data(x.place(), x.type()); + size_t offset = 0; + for (auto &each_range : ranges) { + size_t len = each_range.end - each_range.begin; + if (len == 0) { + continue; + } + // out[offset: offset+len] = x[each_range.begin: each_range.end] + out->Slice(static_cast(offset), static_cast(offset + len)) + .CopyFrom(x.Slice(static_cast(each_range.begin), + static_cast(each_range.end)), + x.place(), dev_ctx); + offset += len; + } + } + } +}; + +class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + SplitLoDTensorOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input LoDTensor"); + AddInput("Mask", "A bool column vector which mask the input"); + AddOutput("OutTrue", "True branch of input LoDTensor"); + AddOutput("OutFalse", "False branch of input LoDTensor"); + AddAttr("level", "(int) the specific lod level to split.") + .SetDefault(0) + .EqualGreaterThan(0); + AddComment( + R"DOC( + Split a LoDTensor with a Mask at certain level. The input LoDTensor + has 3 sequence at certain lod level. The Mask is a bool column vector, + such as [0, 1, 0] at the same level. The first and third sequence will + be send to False Output LoDTensor; whereas the second sequence will + be send to True Output LoDTensor. Please refer to MergeLoDTensorOp.)DOC"); + } +}; + +class SplitLoDTensorInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInput("X"), + "SplitLoDTensorOp must has input X."); + PADDLE_ENFORCE(context->HasInput("Mask"), + "SplitLoDTensorOp must has input Mask."); + PADDLE_ENFORCE(context->HasOutput("OutTrue"), + "SplitLoDTensorOp must has output OutTrue."); + PADDLE_ENFORCE(context->HasOutput("OutFalse"), + "SplitLoDTensorOp must has output OutFalse."); + + auto mask_dim = context->GetInputDim("Mask"); + PADDLE_ENFORCE_EQ(mask_dim.size(), 2); + PADDLE_ENFORCE_EQ(mask_dim[1], 1); + + context->SetOutputDim("OutTrue", context->GetInputDim("X")); + context->SetOutputDim("OutFalse", context->GetInputDim("X")); + } +}; + +class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDescBind(); + grad_op->SetType("merge_lod_tensor"); + grad_op->SetInput("InTrue", OutputGrad("OutTrue")); + grad_op->SetInput("InFalse", OutputGrad("OutFalse")); + grad_op->SetInput("Mask", Input("Mask")); + grad_op->SetInput("X", Input("X")); + grad_op->SetOutput("Out", InputGrad("X")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(split_lod_tensor, ops::SplitLoDTensorOp, + ops::SplitLoDTensorOpProtoMaker, + ops::SplitLoDTensorInferShape, + ops::SplitLoDTensorArrayGradMaker); diff --git a/paddle/operators/squared_l2_norm_op.h b/paddle/operators/squared_l2_norm_op.h index c8d37ac40c..48d7b1c2d5 100644 --- a/paddle/operators/squared_l2_norm_op.h +++ b/paddle/operators/squared_l2_norm_op.h @@ -29,7 +29,7 @@ class SquaredL2NormKernel : public framework::OpKernel { Out->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten(*X); - auto out = framework::EigenVector::Flatten(*Out); + auto out = framework::EigenScalar::From(*Out); auto place = context.GetEigenDevice(); out.device(place) = x.square().sum(); diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index f031109501..3b0f09cea6 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -200,7 +200,10 @@ void Parameter::setMat(ParameterType pType, int matType) { false, useGpu_); } - } else if (matType == MAT_NORMAL_SHARED) { + } +#ifndef PADDLE_MOBILE_INFERENCE + // NOLINTNEXTLINE + else if (matType == MAT_NORMAL_SHARED) { CHECK_EQ(height * width, bufs_[pType]->getSize()); size_t blockNum = 0; CHECK(isGradShared(&blockNum)); @@ -259,7 +262,10 @@ void Parameter::setMat(ParameterType pType, int matType) { } else if (matType == MAT_SPARSE_ROW_AUTO_GROW) { CHECK(isGradSparseUpdate()); mats_[pType] = std::make_shared(height, width); - } else { + } +#endif + // NOLINTNEXTLINE + else { LOG(FATAL) << "Unsupported mat type" << matType; } } diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 0f906e0e47..3d8d3f1d2f 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -42,6 +42,9 @@ limitations under the License. */ #include "paddle/platform/gpu_info.h" #endif +// disable auto conversion to list in Python +PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray); + namespace paddle { namespace pybind { static size_t UniqueIntegerGenerator(const std::string &prefix) { diff --git a/paddle/scripts/deb/postinst b/paddle/scripts/deb/postinst deleted file mode 100644 index 91620b1ee7..0000000000 --- a/paddle/scripts/deb/postinst +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -set -e -echo "Post install paddle debian package." -echo "Install some python package used for paddle. You can run " -echo " pip install /usr/opt/paddle/share/wheels/*.whl to install them." -find /usr/ -name '*paddle*.whl' | xargs pip install diff --git a/paddle/scripts/docker/README.md b/paddle/scripts/docker/README.md index 76bc30e59b..b5fd68839d 100644 --- a/paddle/scripts/docker/README.md +++ b/paddle/scripts/docker/README.md @@ -2,178 +2,198 @@ ## Goals -We want the building procedure generates Docker images so that we can run PaddlePaddle applications on Kubernetes clusters. +We want to make the building procedures: -We want to build .deb packages so that enterprise users can run PaddlePaddle applications without Docker. +1. Static, can reproduce easily. +1. Generate python `whl` packages that can be widely use cross many distributions. +1. Build different binaries per release to satisfy different environments: + - Binaries for different CUDA and CUDNN versions, like CUDA 7.5, 8.0, 9.0 + - Binaries containing only capi + - Binaries for python with wide unicode support or not. +1. Build docker images with PaddlePaddle pre-installed, so that we can run +PaddlePaddle applications directly in docker or on Kubernetes clusters. -We want to minimize the size of generated Docker images and .deb packages so to reduce the download time. +To achieve this, we created a repo: https://github.com/PaddlePaddle/buildtools +which gives several docker images that are `manylinux1` sufficient. Then we +can build PaddlePaddle using these images to generate corresponding `whl` +binaries. -We want to encapsulate building tools and dependencies in a *development* Docker image so to ease the tools installation for developers. +## Run The Build -Developers use various editors (emacs, vim, Eclipse, Jupyter Notebook), so the development Docker image contains only building tools, not editing tools, and developers are supposed to git clone source code into their development computers and map the code into the development container. +### Build Evironments -We want the procedure and tools also work with testing, continuous integration, and releasing. +The pre-built build environment images are: +| Image | Tag | +| ----- | --- | +| paddlepaddle/paddle_manylinux_devel | cuda7.5_cudnn5 | +| paddlepaddle/paddle_manylinux_devel | cuda8.0_cudnn5 | +| paddlepaddle/paddle_manylinux_devel | cuda7.5_cudnn7 | +| paddlepaddle/paddle_manylinux_devel | cuda9.0_cudnn7 | -## Docker Images - -So we need two Docker images for each version of PaddlePaddle: - -1. `paddle:-dev` - - This a development image contains only the development tools and standardizes the building procedure. Users include: +### Start Build - - developers -- no longer need to install development tools on the host, and can build their current work on the host (development computer). - - release engineers -- use this to build the official release from certain branch/tag on Github.com. - - document writers / Website developers -- Our documents are in the source repo in the form of .md/.rst files and comments in source code. We need tools to extract the information, typeset, and generate Web pages. +Choose one docker image that suit your environment and run the following +command to start a build: - Of course, developers can install building tools on their development computers. But different versions of PaddlePaddle might require different set or version of building tools. Also, it makes collaborative debugging easier if all developers use a unified development environment. - - The development image should include the following tools: - - - gcc/clang - - nvcc - - Python - - sphinx - - woboq - - sshd +```bash +git clone https://github.com/PaddlePaddle/Paddle.git +cd Paddle +docker run --rm -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "WITH_TESTING=OFF" -e "RUN_TEST=OFF" -e "PYTHON_ABI=cp27-cp27mu" paddlepaddle/paddle_manylinux_devel /paddle/paddle/scripts/docker/build.sh +``` - Many developers work on a remote computer with GPU; they could ssh into the computer and `docker exec` into the development container. However, running `sshd` in the container allows developers to ssh into the container directly. +After the build finishes, you can get output `whl` package under +`build/python/dist`. -1. `paddle:` +This command mounts the source directory on the host into `/paddle` in the container, then run the build script `/paddle/paddle/scripts/docker/build.sh` +in the container. When it writes to `/paddle/build` in the container, it writes to `$PWD/build` on the host indeed. - This is the production image, generated using the development image. This image might have multiple variants: +### Build Options - - GPU/AVX `paddle:-gpu` - - GPU/no-AVX `paddle:-gpu-noavx` - - no-GPU/AVX `paddle:` - - no-GPU/no-AVX `paddle:-noavx` +Users can specify the following Docker build arguments with either "ON" or "OFF" value: - We allow users to choose between GPU and no-GPU because the GPU version image is much larger than then the no-GPU version. +| Option | Default | Description | +| ------ | -------- | ----------- | +| `WITH_GPU` | OFF | Generates NVIDIA CUDA GPU code and relies on CUDA libraries. | +| `WITH_AVX` | OFF | Set to "ON" to enable AVX support. | +| `WITH_TESTING` | ON | Build unit tests binaries. | +| `WITH_MKLDNN` | ON | Build with [Intel® MKL DNN](https://github.com/01org/mkl-dnn) support. | +| `WITH_MKLML` | ON | Build with [Intel® MKL](https://software.intel.com/en-us/mkl) support. | +| `WITH_GOLANG` | ON | Build fault-tolerant parameter server written in go. | +| `WITH_SWIG_PY` | ON | Build with SWIG python API support. | +| `WITH_C_API` | OFF | Build capi libraries for inference. | +| `WITH_PYTHON` | ON | Build with python support. Turn this off if build is only for capi. | +| `WITH_STYLE_CHECK` | ON | Check the code style when building. | +| `PYTHON_ABI` | "" | Build for different python ABI support, can be cp27-cp27m or cp27-cp27mu | +| `RUN_TEST` | OFF | Run unit test immediently after the build. | +| `WITH_DOC` | OFF | Build docs after build binaries. | +| `WOBOQ` | OFF | Generate WOBOQ code viewer under `build/woboq_out` | - We allow users the choice between AVX and no-AVX, because some cloud providers don't provide AVX-enabled VMs. +## Docker Images -## Development Environment +You can get the latest PaddlePaddle docker images by +`docker pull paddlepaddle/paddle:` or build one by yourself. -Here we describe how to use above two images. We start from considering our daily development environment. +### Official Docker Releases -Developers work on a computer, which is usually a laptop or desktop: +Official docker images at +[here](https://hub.docker.com/r/paddlepaddle/paddle/tags/), +you can choose either latest or images with a release tag like `0.10.0`, +Currently available tags are: - +| Tag | Description | +| ------ | --------------------- | +| latest | latest CPU only image | +| latest-gpu | latest binary with GPU support | +| 0.10.0 | release 0.10.0 CPU only binary image | +| 0.10.0-gpu | release 0.10.0 with GPU support | -or, they might rely on a more sophisticated box (like with GPUs): +### Build Your Own Image - +Build PaddlePaddle docker images are quite simple since PaddlePaddle can +be installed by just running `pip install`. A sample `Dockerfile` is: -A principle here is that source code lies on the development computer (host) so that editors like Eclipse can parse the source code to support auto-completion. +```dockerfile +FROM nvidia/cuda:7.5-cudnn5-runtime-centos6 +RUN yum install -y centos-release-SCL +RUN yum install -y python27 +# This whl package is generated by previous build steps. +ADD python/dist/paddlepaddle-0.10.0-cp27-cp27mu-linux_x86_64.whl / +RUN pip install /paddlepaddle-0.10.0-cp27-cp27mu-linux_x86_64.whl && rm -f /*.whl +``` +Then build the image by running `docker build -t [REPO]/paddle:[TAG] .` under +the directory containing your own `Dockerfile`. -## Usages +- NOTE: note that you can choose different base images for your environment, you can find all the versions [here](https://hub.docker.com/r/nvidia/cuda/). -### Build the Development Docker Image +### Use Docker Images -The following commands check out the source code to the host and build the development image `paddle:dev`: +Suppose that you have written an application program `train.py` using +PaddlePaddle, we can test and run it using docker: ```bash -git clone https://github.com/PaddlePaddle/Paddle paddle -cd paddle -docker build -t paddle:dev . +docker run --rm -it -v $PWD:/work paddlepaddle/paddle /work/a.py ``` -The `docker build` command assumes that `Dockerfile` is in the root source tree. Note that in this design, this `Dockerfile` is this only one in our repo. - -Users can specify a Ubuntu mirror server for faster downloading: - -```bash -docker build -t paddle:dev --build-arg UBUNTU_MIRROR=mirror://mirrors.ubuntu.com/mirrors.txt . -``` +But this works only if all dependencies of `train.py` are in the production image. If this is not the case, we need to build a new Docker image from the production image and with more dependencies installs. -### Build PaddlePaddle from Source Code +### Run PaddlePaddle Book In Docker -Given the development image `paddle:dev`, the following command builds PaddlePaddle from the source tree on the development computer (host): +Our [book repo](https://github.com/paddlepaddle/book) also provide a docker +image to start a jupiter notebook inside docker so that you can run this book +using docker: ```bash -docker run --rm -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "WITH_TESTING=OFF" -e "RUN_TEST=OFF" paddle:dev +docker run -d -p 8888:8888 paddlepaddle/book ``` -This command mounts the source directory on the host into `/paddle` in the container, so the default entry point of `paddle:dev`, `build.sh`, could build the source code with possible local changes. When it writes to `/paddle/build` in the container, it writes to `$PWD/build` on the host indeed. - -`build.sh` builds the following: - -- PaddlePaddle binaries, -- `$PWD/build/paddle-.deb` for production installation, and -- `$PWD/build/Dockerfile`, which builds the production Docker image. +Please refer to https://github.com/paddlepaddle/book if you want to build this +docker image by your self. -Users can specify the following Docker build arguments with either "ON" or "OFF" value: -- `WITH_GPU`: ***Required***. Generates NVIDIA CUDA GPU code and relies on CUDA libraries. -- `WITH_AVX`: ***Required***. Set to "OFF" prevents from generating AVX instructions. If you don't know what is AVX, you might want to set "ON". -- `WITH_TEST`: ***Optional, default OFF***. Build unit tests binaries. Once you've built the unit tests, you can run these test manually by the following command: - ```bash - docker run --rm -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" paddle:dev sh -c "cd /paddle/build; make coverall" - ``` -- `RUN_TEST`: ***Optional, default OFF***. Run unit tests after building. You can't run unit tests without building it. +### Run Distributed Applications -### Build the Production Docker Image +In our [API design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/api.md#distributed-training), we proposed an API that starts a distributed training job on a cluster. This API need to build a PaddlePaddle application into a Docker image as above and calls kubectl to run it on the cluster. This API might need to generate a Dockerfile look like above and call `docker build`. -The following command builds the production image: +Of course, we can manually build an application image and launch the job using the kubectl tool: ```bash -docker build -t paddle -f build/Dockerfile ./build +docker build -f some/Dockerfile -t myapp . +docker tag myapp me/myapp +docker push +kubectl ... ``` -This production image is minimal -- it includes binary `paddle`, the shared library `libpaddle.so`, and Python runtime. +## Docker Images for Developers -### Run PaddlePaddle Applications +We have a special docker image for developers: +`paddlepaddle/paddle:-dev`. This image is also generated from +https://github.com/PaddlePaddle/buildtools -Again the development happens on the host. Suppose that we have a simple application program in `a.py`, we can test and run it using the production image: +This a development image contains only the +development tools and standardizes the building procedure. Users include: -```bash -docker run --rm -it -v $PWD:/work paddle /work/a.py -``` +- developers -- no longer need to install development tools on the host, and can build their current work on the host (development computer). +- release engineers -- use this to build the official release from certain branch/tag on Github.com. +- document writers / Website developers -- Our documents are in the source repo in the form of .md/.rst files and comments in source code. We need tools to extract the information, typeset, and generate Web pages. -But this works only if all dependencies of `a.py` are in the production image. If this is not the case, we need to build a new Docker image from the production image and with more dependencies installs. +Of course, developers can install building tools on their development computers. But different versions of PaddlePaddle might require different set or version of building tools. Also, it makes collaborative debugging easier if all developers use a unified development environment. -### Build and Run PaddlePaddle Applications +The development image contains the following tools: -We need a Dockerfile in https://github.com/paddlepaddle/book that builds Docker image `paddlepaddle/book:`, basing on the PaddlePaddle production image: + - gcc/clang + - nvcc + - Python + - sphinx + - woboq + - sshd -``` -FROM paddlepaddle/paddle: -RUN pip install -U matplotlib jupyter ... -COPY . /book -EXPOSE 8080 -CMD ["jupyter"] -``` +Many developers work on a remote computer with GPU; they could ssh into the computer and `docker exec` into the development container. However, running `sshd` in the container allows developers to ssh into the container directly. -The book image is an example of PaddlePaddle application image. We can build it -```bash -git clone https://github.com/paddlepaddle/book -cd book -docker build -t book . -``` +### Development Workflow -### Build and Run Distributed Applications +Here we describe how the workflow goes on. We start from considering our daily development environment. -In our [API design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/api.md#distributed-training), we proposed an API that starts a distributed training job on a cluster. This API need to build a PaddlePaddle application into a Docker image as above and calls kubectl to run it on the cluster. This API might need to generate a Dockerfile look like above and call `docker build`. +Developers work on a computer, which is usually a laptop or desktop: -Of course, we can manually build an application image and launch the job using the kubectl tool: + -```bash -docker build -f some/Dockerfile -t myapp . -docker tag myapp me/myapp -docker push -kubectl ... -``` +or, they might rely on a more sophisticated box (like with GPUs): + + + +A principle here is that source code lies on the development computer (host) so that editors like Eclipse can parse the source code to support auto-completion. ### Reading source code with woboq codebrowser + For developers who are interested in the C++ source code, please use -e "WOBOQ=ON" to enable the building of C++ source code into HTML pages using [Woboq codebrowser](https://github.com/woboq/woboq_codebrowser). - The following command builds PaddlePaddle, generates HTML pages from C++ source code, and writes HTML pages into `$HOME/woboq_out` on the host: ```bash -docker run -v $PWD:/paddle -v $HOME/woboq_out:/woboq_out -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "WITH_TEST=ON" -e "WOBOQ=ON" paddle:dev +docker run -v $PWD:/paddle -v $HOME/woboq_out:/woboq_out -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "WITH_TEST=ON" -e "WOBOQ=ON" paddlepaddle/paddle:latest-dev ``` - You can open the generated HTML files in your Web browser. Or, if you want to run a Nginx container to serve them for a wider audience, you can run: diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 256500c56a..e9c89eee1a 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -1,23 +1,6 @@ #!/bin/bash -set -xe - - function cmake_gen() { - # Set BASE_IMAGE according to env variables - if [[ ${WITH_GPU} == "ON" ]]; then - BASE_IMAGE="nvidia/cuda:8.0-cudnn5-runtime-ubuntu16.04" - else - BASE_IMAGE="ubuntu:16.04" - fi - - DOCKERFILE_GPU_ENV="" - DOCKERFILE_CUDNN_DSO="" - if [[ ${WITH_GPU:-OFF} == 'ON' ]]; then - DOCKERFILE_GPU_ENV="ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}" - DOCKERFILE_CUDNN_DSO="RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.5 /usr/lib/x86_64-linux-gnu/libcudnn.so" - fi - mkdir -p /paddle/build cd /paddle/build @@ -26,10 +9,29 @@ function cmake_gen() { # delete previous built whl packages rm -rf /paddle/paddle/dist 2>/dev/null || true + # Support build for all python versions, currently + # including cp27-cp27m and cp27-cp27mu. + PYTHON_FLAGS="" + if [ "$1" != "" ]; then + echo "using python abi: $1" + if [ "$1" == "cp27-cp27m" ]; then + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} + PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so" + elif [ "$1" == "cp27-cp27mu" ]; then + export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:} + PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python + -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7 + -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so" + fi + fi + cat < ids(height); std::vector indices(height + 1); indices[0] = 0; @@ -84,6 +85,8 @@ MatrixPtr makeRandomSparseMatrix(size_t height, } return mat; } +#endif + return nullptr; } void generateSequenceStartPositions(size_t batchSize, diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 32578ad779..c8632295a2 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -37,10 +37,10 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) -add_custom_command(OUTPUT ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so - COMMAND cmake -E copy $ ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so +add_custom_command(OUTPUT ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so + COMMAND cmake -E copy $ ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so DEPENDS paddle_pybind) -add_custom_target(copy_paddle_pybind ALL DEPENDS ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so) +add_custom_target(copy_paddle_pybind ALL DEPENDS ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so) add_custom_command(OUTPUT ${PADDLE_PYTHON_BUILD_DIR}/.timestamp @@ -66,7 +66,7 @@ if (WITH_TESTING) add_subdirectory(paddle/v2/tests) add_subdirectory(paddle/v2/reader/tests) add_subdirectory(paddle/v2/plot/tests) - add_subdirectory(paddle/v2/framework/tests) + add_subdirectory(paddle/v2/fluid/tests) endif() endif() install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR} diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 43d02bf70e..5bd68e211a 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1200,8 +1200,14 @@ def TestData(data_config, async_load_data=None): #caffe_mode: compute the output size using floor instead of ceil, # which is consistent of caffe and CuDNN's convention. -def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode): - output = (2 * padding + img_size - filter_size) / float(stride) +def cnn_output_size(img_size, + filter_size, + padding, + stride, + caffe_mode, + dilation=1): + filter_s = (filter_size - 1) * dilation + 1 + output = (2 * padding + img_size - filter_s) / float(stride) if caffe_mode: return 1 + int(math.floor(output)) else: @@ -1210,8 +1216,14 @@ def cnn_output_size(img_size, filter_size, padding, stride, caffe_mode): #calcualte image_size based on output_size for de-convolution (ConvTransLayer). #It is the reverse function of cnn_output_size -def cnn_image_size(output_size, filter_size, padding, stride, caffe_mode): - img_size = (output_size - 1) * stride + filter_size - 2 * padding +def cnn_image_size(output_size, + filter_size, + padding, + stride, + caffe_mode, + dilation=1): + filter_s = (filter_size - 1) * dilation + 1 + img_size = (output_size - 1) * stride + filter_s - 2 * padding if not caffe_mode: img_size = img_size + 1 return img_size @@ -1253,9 +1265,9 @@ def parse_bilinear(bilinear, input_layer_name, bilinear_conf): def parse_pool(pool, input_layer_name, pool_conf, ceil_mode): pool_conf.pool_type = pool.pool_type config_assert(pool.pool_type in [ - 'max-projection', 'avg-projection', 'cudnn-max-pool', 'cudnn-avg-pool' - ], "pool-type %s is not in " - "['max-projection', 'avg-projection', " + 'max-projection', 'avg-projection', 'max-pool-with-mask', 'cudnn-max-pool', 'cudnn-avg-pool' + ], "pool-type %s is not in " \ + "['max-projection', 'avg-projection', 'max-pool-with-mask'," \ "'cudnn-max-pool', 'cudnn-avg-pool']" % pool.pool_type) pool_conf.channels = pool.channels @@ -1376,6 +1388,12 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): conv_conf.stride_y = conv.stride_y conv_conf.groups = conv.groups conv_conf.caffe_mode = conv.caffe_mode + if not conv.dilation: + conv.dilation = 1 + conv.dilation_y = 1 + else: + conv_conf.dilation = conv.dilation + conv_conf.dilation_y = conv.dilation_y if not trans: conv_conf.filter_channels = conv.channels / conv.groups @@ -1383,20 +1401,20 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): get_img_size(input_layer_name, conv.channels) conv_conf.output_x = cnn_output_size( conv_conf.img_size, conv_conf.filter_size, conv_conf.padding, - conv_conf.stride, conv_conf.caffe_mode) + conv_conf.stride, conv_conf.caffe_mode, conv.dilation) conv_conf.output_y = cnn_output_size( conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y, - conv_conf.stride_y, conv_conf.caffe_mode) + conv_conf.stride_y, conv_conf.caffe_mode, conv.dilation_y) else: conv_conf.filter_channels = num_filters / conv.groups conv_conf.output_x, conv_conf.output_y = \ get_img_size(input_layer_name, conv.channels) conv_conf.img_size = cnn_image_size( conv_conf.output_x, conv_conf.filter_size, conv_conf.padding, - conv_conf.stride, conv_conf.caffe_mode) + conv_conf.stride, conv_conf.caffe_mode, conv.dilation) conv_conf.img_size_y = cnn_image_size( conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y, - conv_conf.stride_y, conv_conf.caffe_mode) + conv_conf.stride_y, conv_conf.caffe_mode, conv.dilation_y) #caffe_mode: compute the output size using floor instead of ceil, diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 617fbff948..5de1c18950 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -20,7 +20,7 @@ from paddle.trainer.config_parser import * from .activations import LinearActivation, SigmoidActivation, TanhActivation, \ ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation from .evaluators import * -from .poolings import MaxPooling, AvgPooling, BasePoolingType, \ +from .poolings import MaxPooling, AvgPooling, MaxWithMaskPooling, BasePoolingType, \ CudnnAvgPooling, CudnnMaxPooling from .attrs import * from .default_decorators import * @@ -888,7 +888,7 @@ def mixed_layer(size=0, :type size: int :param input: The input of this layer. It is an optional parameter. If set, then this function will just return layer's name. - :param act: Activation Type. LinearActivation is the default. + :param act: Activation Type. LinearActivation is the default activation. :type act: BaseActivation :param bias_attr: The bias attribute. If the parameter is set to False or an object whose type is not ParameterAttribute, no bias is defined. If the @@ -1030,7 +1030,7 @@ def fc_layer(input, :type input: LayerOutput | list | tuple :param size: The layer dimension. :type size: int - :param act: Activation Type. TanhActivation is the default. + :param act: Activation Type. TanhActivation is the default activation. :type act: BaseActivation :param param_attr: The Parameter Attribute|list. :type param_attr: ParameterAttribute @@ -1527,7 +1527,7 @@ def lstmemory(input, :type input: LayerOutput :param reverse: is sequence process reversed or not. :type reverse: bool - :param act: Activation type. TanhActivation is the default. :math:`h_t` + :param act: Activation type. TanhActivation is the default activation. :type act: BaseActivation :param gate_act: gate activation type, SigmoidActivation by default. :type gate_act: BaseActivation @@ -1920,7 +1920,7 @@ def repeat_layer(input, False for treating input as column vector and repeating in the row direction. :type as_row_vector: bool - :param act: Activation type. IdentityActivation is the default. + :param act: Activation type. IdentityActivation is the default activation. :type act: BaseActivation :type name: basestring :param layer_attr: extra layer attributes. @@ -1974,7 +1974,7 @@ def seq_reshape_layer(input, :type reshape_size: int :param name: The name of this layer. It is optional. :type name: basestring - :param act: Activation type. IdentityActivation is the default. + :param act: Activation type. IdentityActivation is the default activation. :type act: BaseActivation :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. @@ -2487,7 +2487,7 @@ def img_conv_layer(input, shape will be (filter_size, filter_size_y). :type filter_size_y: int | None :param num_filters: Each filter group's number of filter - :param act: Activation type. ReluActivation is the default. + :param act: Activation type. ReluActivation is the default activation. :type act: BaseActivation :param groups: Group size of filters. :type groups: int @@ -2571,7 +2571,9 @@ def img_conv_layer(input, if layer_type: if dilation > 1 or dilation_y > 1: - assert layer_type in ["cudnn_conv", "cudnn_convt"] + assert layer_type in [ + "cudnn_conv", "cudnn_convt", "exconv", "exconvt" + ] if trans: assert layer_type in ["exconvt", "cudnn_convt"] else: @@ -2699,9 +2701,9 @@ def img_pool_layer(input, elif isinstance(pool_type, AvgPooling): pool_type.name = 'avg' - assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling, + assert type(pool_type) in [AvgPooling, MaxPooling, MaxWithMaskPooling, CudnnAvgPooling, CudnnMaxPooling], \ - "only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported" + "only (Cudnn)AvgPooling, (Cudnn)MaxPooling, MaxWithMaskPooling are supported" type_name = pool_type.name + '-projection' \ if ( @@ -3253,7 +3255,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None): :param input: Input layers. It could be a LayerOutput or list/tuple of LayerOutput. :type input: LayerOutput | list | tuple - :param act: Activation Type. LinearActivation is the default. + :param act: Activation Type. LinearActivation is the default activation. :type act: BaseActivation :param bias_attr: The bias attribute. If the parameter is set to False or an object whose type is not ParameterAttribute, no bias is defined. If the @@ -3311,7 +3313,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): :type name: basestring :param input: input layers or projections :type input: list | tuple | collections.Sequence - :param act: Activation type. IdentityActivation is the default. + :param act: Activation type. IdentityActivation is the default activation. :type act: BaseActivation :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute @@ -3406,7 +3408,7 @@ def seq_concat_layer(a, b, act=None, name=None, layer_attr=None, :type a: LayerOutput :param b: input sequence layer :type b: LayerOutput - :param act: Activation type. IdentityActivation is the default. + :param act: Activation type. IdentityActivation is the default activation. :type act: BaseActivation :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute @@ -3572,31 +3574,32 @@ def lstm_step_layer(input, ... - This layer has two outputs. Default output is :math:`h_t`. The other - output is :math:`o_t`, whose name is 'state' and can use + This layer has two outputs. The default output is :math:`h_t`. The other + output is :math:`o_t`, whose name is 'state' and users can use :code:`get_output_layer` to extract this output. :param name: The name of this layer. It is optional. :type name: basestring - :param size: Layer's size. NOTE: lstm layer's size, should be equal to - :code:`input.size/4`, and should be equal to - :code:`state.size`. + :param size: The dimension of this layer's output, which must be + equal to the dimension of the state. :type size: int - :param input: input layer. :math:`Wx_t + Wh_{t-1}` + :param input: The input of this layer. :type input: LayerOutput - :param state: State Layer. :math:`c_{t-1}` + :param state: The state of the LSTM unit. :type state: LayerOutput - :param act: Activation type. TanhActivation is the default. + :param act: Activation type. TanhActivation is the default activation. :type act: BaseActivation - :param gate_act: Gate Activation Type. SigmoidActivation is the default. + :param gate_act: Activation type of the gate. SigmoidActivation is the + default activation. :type gate_act: BaseActivation - :param state_act: State Activation Type. TanhActivation is the default. + :param state_act: Activation type of the state. TanhActivation is the + default activation. :type state_act: BaseActivation :param bias_attr: The bias attribute. If the parameter is set to False or an object whose type is not ParameterAttribute, no bias is defined. If the parameter is set to True, the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any - :param layer_attr: layer's extra attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput @@ -3641,22 +3644,31 @@ def gru_step_layer(input, layer_attr=None): """ - :param input: + :param input: The input of this layer, whose dimension can be divided by 3. :type input: LayerOutput - :param output_mem: - :param size: - :param act: + :param output_mem: A memory which memorizes the output of this layer at previous + time step. + :type output_mem: LayerOutput + :param size: The dimension of this layer's output. If it is not set or set to None, + it will be set to one-third of the dimension of the input automatically. + :type size: int + :param act: Activation type of this layer's output. TanhActivation + is the default activation. :type act: BaseActivation :param name: The name of this layer. It is optional. - :param gate_act: Activation type of this layer's two gates. Default is Sigmoid. + :type name: basestring + :param gate_act: Activation type of this layer's two gates. SigmoidActivation is + the default activation. :type gate_act: BaseActivation - :param bias_attr: The bias attribute. If the parameter is set to False or an object - whose type is not ParameterAttribute, no bias is defined. If the - parameter is set to True, the bias is initialized to zero. + :param bias_attr: The parameter attribute for bias. If this parameter is set to + False or an object whose type is not ParameterAttribute, no bias + is defined. If this parameter is set to True, + the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any - :param param_attr: the parameter_attribute for transforming the output_mem - from previous step. - :param layer_attr: + :param param_attr: The parameter attribute. See ParameterAttribute for details. + :type param_attr: ParameterAttribute + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -3701,24 +3713,34 @@ def gru_step_naive_layer(input, param_attr=None, layer_attr=None): """ - GRU Step Layer, but using MixedLayer to generate. It support ERROR_CLIPPING + GRU Step Layer, which is realized using PaddlePaddle API. It supports ERROR_CLIPPING and DROPOUT. - :param input: - :param output_mem: - :param size: + :param input: The input of this layer, whose dimensionality can be divided by 3. + :param output_mem: A memory which memorizes the output of this layer at previous + time step. + :type output_mem: LayerOutput + :param size: The dimension of this layer's output. If it is not set or set to None, + it will be set to one-third of the dimension of the input automatically. + :type size: int :param name: The name of this layer. It is optional. - :param act: + :type name: basestring + :param act: Activation type of this layer's output. TanhActivation + is the default activation. :type act: BaseActivation - :param gate_act: Activation type of this layer's two gates. Default is Sigmoid. + :param gate_act: Activation type of this layer's two gates. SigmoidActivation + is the default activation. :type gate_act: BaseActivation - :param bias_attr: The bias attribute. If the parameter is set to False or an object - whose type is not ParameterAttribute, no bias is defined. If the - parameter is set to True, the bias is initialized to zero. + :param bias_attr: The parameter attribute for bias. If this parameter is set to + False or an object whose type is not ParameterAttribute, no bias + is defined. If this parameter is set to True, + the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any - :param param_attr: - :param layer_attr: - :return: + :param param_attr: The parameter attribute. See ParameterAttribute for details. + :type param_attr: ParameterAttribute + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. + :type layer_attr: ExtraLayerAttribute + :return: LayerOutput object. :rtype: LayerOutput """ if input.size % 3 != 0: @@ -3780,12 +3802,13 @@ def get_output_layer(input, arg_name, name=None, layer_attr=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input: get output layer's input. And this layer should contains + :param input: The input layer. And this layer should contain multiple outputs. :type input: LayerOutput - :param arg_name: Output name from input. + :param arg_name: The name of the output to be extracted from the input layer. :type arg_name: basestring - :param layer_attr: Layer's extra attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :return: LayerOutput object. :rtype: LayerOutput """ @@ -3842,17 +3865,20 @@ def recurrent_layer(input, :param input: The input of this layer. :type input: LayerOutput - :param act: Activation type. TanhActivation is the default. + :param act: Activation type. TanhActivation is the default activation. :type act: BaseActivation - :param bias_attr: The bias attribute. If the parameter is set to False or an object - whose type is not ParameterAttribute, no bias is defined. If the - parameter is set to True, the bias is initialized to zero. + :param bias_attr: The parameter attribute for bias. If this parameter is set to + False or an object whose type is not ParameterAttribute, + no bias is defined. If the parameter is set to True, + the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any - :param param_attr: parameter attribute. + :param param_attr: The parameter attribute. See ParameterAttribute for + details. :type param_attr: ParameterAttribute :param name: The name of this layer. It is optional. :type name: basestring - :param layer_attr: Layer Attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput @@ -3877,7 +3903,7 @@ def recurrent_layer(input, class StaticInput(object): """ StaticInput is only used in recurrent_group which defines a read-only memory - that can be a sequence or non-sequence. + and can be a sequence or non-sequence. :param size: DEPRECATED :param is_seq: DEPRECATED """ @@ -3910,8 +3936,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): Recurrent layer group is an extremely flexible recurrent unit in PaddlePaddle. As long as the user defines the calculation done within a time step, PaddlePaddle will iterate such a recurrent calculation over - sequence input. This is extremely usefull for attention based model, or - Neural Turning Machine like models. + sequence input. This is useful for attention-based models, or Neural + Turning Machine like models. The basic usage (time steps) is: @@ -3933,18 +3959,17 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): demo/seqToseq/seqToseq_net.py - sequence steps: paddle/gserver/tests/sequence_nest_layer_group.conf - :param step: recurrent one time step function.The input of this function is - input of the group. The return of this function will be - recurrent group's return value. + :param step: A step function which takes the input of recurrent_group as its own + input and returns values as recurrent_group's output every time step. - The recurrent group scatter a sequence into time steps. And - for each time step, will invoke step function, and return - a time step result. Then gather each time step of output into + The recurrent group scatters a sequence into time steps. And + for each time step, it will invoke step function, and return + a time step result. Then gather outputs of each time step into layer group's output. :type step: callable - :param name: recurrent_group's name. + :param name: The recurrent_group's name. It is optional. :type name: basestring :param input: Input links array. @@ -3952,11 +3977,11 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None): LayerOutput will be scattered into time steps. SubsequenceInput will be scattered into sequence steps. StaticInput will be imported to each time step, and doesn't change - through time. It's a mechanism to access layer outside step function. + over time. It's a mechanism to access layer outside step function. :type input: LayerOutput | StaticInput | SubsequenceInput | list | tuple - :param reverse: If reverse is set true, the recurrent unit will process the + :param reverse: If reverse is set to True, the recurrent unit will process the input sequence in a reverse order. :type reverse: bool @@ -4091,7 +4116,8 @@ def maxid_layer(input, name=None, layer_attr=None): :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring - :param layer_attr: extra layer attributes. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute. :return: LayerOutput object. :rtype: LayerOutput @@ -4124,11 +4150,12 @@ def out_prod_layer(input1, input2, name=None, layer_attr=None): :param name: The name of this layer. It is optional. :type name: basestring - :param input1: The first input layer name. + :param input1: The first input layer. :type input: LayerOutput - :param input2: The second input layer name. + :param input2: The second input layer. :type input2: LayerOutput - :param layer_attr: extra layer attributes. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute. :return: LayerOutput object. :rtype: LayerOutput @@ -4167,9 +4194,10 @@ def eos_layer(input, eos_id, name=None, layer_attr=None): :type name: basestring :param input: The input of this layer. :type input: LayerOutput - :param eos_id: end id of sequence + :param eos_id: End id of sequence :type eos_id: int - :param layer_attr: extra layer attributes. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute. :return: LayerOutput object. :rtype: LayerOutput @@ -4230,8 +4258,9 @@ def beam_search(step, - machine translation : demo/seqToseq/translation/gen.conf \ demo/seqToseq/seqToseq_net.py - :param name: Name of the recurrent unit that generates sequences. - :type name: base string + :param name: The name of the recurrent unit that is responsible for + generating sequences. It is optional. + :type name: basestring :param step: A callable function that defines the calculation in a time step, and it is applied to sequences with arbitrary length by sharing a same set of weights. @@ -4356,16 +4385,18 @@ def square_error_cost(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: Network prediction. + :param input: The first input layer. :type input: LayerOutput - :param label: Data label. + :param label: The input label. :type label: LayerOutput - :param weight: The weight affects the cost, namely the scale of cost. - It is an optional argument. + :param weight: The weight layer defines a weight for each sample in the + mini-batch. It is optional. :type weight: LayerOutput - :param coeff: The coefficient affects the gradient in the backward. + :param coeff: The weight of the gradient in the back propagation. + 1.0 is the default value. :type coeff: float - :param layer_attr: layer's extra attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput @@ -4398,17 +4429,20 @@ def classification_cost(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: input layer name. network output. + :param input: The first input layer. :type input: LayerOutput - :param label: label layer name. data_layer often. + :param label: The input label. :type label: LayerOutput - :param weight: The weight affects the cost, namely the scale of cost. - It is an optional argument. + :param weight: The weight layer defines a weight for each sample in the + mini-batch. It is optional. :type weight: LayerOutput - :param evaluator: Evaluator method. - :param layer_attr: layer's extra attribute. + :param evaluator: Evaluator method. classification_error_evaluator is the default. + :type evaluator: Evaluator method + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute - :param coeff: The coefficient affects the gradient in the backward. + :param coeff: The weight of the gradient in the back propagation. + 1.0 is the default value. :type coeff: float :return: LayerOutput object. :rtype: LayerOutput @@ -4461,7 +4495,7 @@ def conv_operator(img, Different from img_conv_layer, conv_op is an Operator, which can be used in mixed_layer. And conv_op takes two inputs to perform convolution. The first input is the image and the second is filter kernel. It only - support GPU mode. + supports GPU mode. The example usage is: @@ -4473,27 +4507,31 @@ def conv_operator(img, num_filters=64, num_channels=64) - :param img: input image + :param img: The input image. :type img: LayerOutput - :param filter: input filter + :param filter: The input filter. :type filter: LayerOutput - :param filter_size: The x dimension of a filter kernel. + :param filter_size: The dimension of the filter kernel on the x axis. :type filter_size: int - :param filter_size_y: The y dimension of a filter kernel. Since - PaddlePaddle now supports rectangular filters, - the filter's shape can be (filter_size, filter_size_y). + :param filter_size_y: The dimension of the filter kernel on the y axis. + If the parameter is not set or set to None, it will + set to 'filter_size' automatically. :type filter_size_y: int - :param num_filters: channel of output data. + :param num_filters: The number of the output channels. :type num_filters: int - :param num_channels: channel of input data. + :param num_channels: The number of the input channels. If the parameter is not set + or set to None, it will be automatically set to the channel + number of the 'img'. :type num_channels: int - :param stride: The x dimension of the stride. + :param stride: The stride on the x axis. :type stride: int - :param stride_y: The y dimension of the stride. + :param stride_y: The stride on the y axis. If the parameter is not set or + set to None, it will be set to 'stride' automatically. :type stride_y: int - :param padding: The x dimension of padding. + :param padding: The padding size on the x axis. :type padding: int - :param padding_y: The y dimension of padding. + :param padding_y: The padding size on the y axis. If the parameter is not set + or set to None, it will be set to 'padding' automatically. :type padding_y: int :return: A ConvOperator Object. :rtype: ConvOperator @@ -4544,9 +4582,9 @@ def conv_projection(input, param_attr=None, trans=False): """ - Different from img_conv_layer and conv_op, conv_projection is an Projection, - which can be used in mixed_layer and conat_layer. It use cudnn to implement - conv and only support GPU mode. + Different from img_conv_layer and conv_op, conv_projection is a Projection, + which can be used in mixed_layer and concat_layer. It uses cudnn to implement + convolution and only supports GPU mode. The example usage is: @@ -4559,32 +4597,45 @@ def conv_projection(input, :param input: The input of this layer. :type input: LayerOutput - :param filter_size: The x dimension of a filter kernel. - :type filter_size: int - :param filter_size_y: The y dimension of a filter kernel. Since - PaddlePaddle now supports rectangular filters, - the filter's shape can be (filter_size, filter_size_y). + :param filter_size: The dimensions of the filter kernel. If the parameter is + set to one integer, the two dimensions on x and y axises + will be same when filter_size_y is not set. If it is set + to a list, the first element indicates the dimension on + the x axis, and the second is used to specify the dimension + on the y axis when filter_size is not provided. + :type filter_size: int | tuple | list + :param filter_size_y: The dimension of the filter kernel on the y axis. If the parameter + is not set, it will be set automatically according to filter_size. :type filter_size_y: int - :param num_filters: channel of output data. + :param num_filters: The number of filters. :type num_filters: int - :param num_channels: channel of input data. + :param num_channels: The number of the input channels. :type num_channels: int - :param stride: The x dimension of the stride. - :type stride: int - :param stride_y: The y dimension of the stride. + :param stride: The strides. If the parameter is set to one integer, the strides + on x and y axises will be same when stride_y is not set. If it is + set to a list, the first element indicates the stride on the x axis, + and the second is used to specify the stride on the y axis when + stride_y is not provided. + :type stride: int | tuple | list + :param stride_y: The stride on the y axis. :type stride_y: int - :param padding: The x dimension of padding. - :type padding: int - :param padding_y: The y dimension of padding. + :param padding: The padding sizes. If the parameter is set to one integer, the padding + sizes on x and y axises will be same when padding_y is not set. If it + is set to a list, the first element indicates the padding size on the + x axis, and the second is used to specify the padding size on the y axis + when padding_y is not provided. + :type padding: int | tuple | list + :param padding_y: The padding size on the y axis. :type padding_y: int :param groups: The group number. :type groups: int - :param param_attr: Convolution param attribute. None means default attribute + :param param_attr: The parameter attribute of the convolution. See ParameterAttribute for + details. :type param_attr: ParameterAttribute - :param trans: whether it is convTrans or conv + :param trans: Whether it is ConvTransProjection or ConvProjection :type trans: bool - :return: A DotMulProjection Object. - :rtype: DotMulProjection + :return: A Projection Object. + :rtype: ConvTransProjection | ConvProjection """ if num_channels is None: assert input.num_filters is not None @@ -4649,13 +4700,13 @@ def pad_layer(input, layer_attr=None): """ This operation pads zeros to the input data according to pad_c,pad_h - and pad_w. pad_c, pad_h, pad_w specifies the which dimension and size - of padding. And the input data shape is NCHW. + and pad_w. pad_c, pad_h, pad_w specify the size in the corresponding + dimension. And the input data shape is NCHW. - For example, pad_c=[2,3] means padding 2 zeros before the - input data and 3 zeros after the input data in channel dimension. - pad_h means padding zeros in height dimension. pad_w means padding zeros - in width dimension. + For example, pad_c=[2,3] means padding 2 zeros before the input data + and 3 zeros after the input data in the channel dimension. pad_h means + padding zeros in the height dimension. pad_w means padding zeros in the + width dimension. For example, @@ -4692,13 +4743,14 @@ def pad_layer(input, :param input: The input of this layer. :type input: LayerOutput - :param pad_c: padding size in channel dimension. + :param pad_c: The padding size in the channel dimension. :type pad_c: list | None - :param pad_h: padding size in height dimension. + :param pad_h: The padding size in the height dimension. :type pad_h: list | None - :param pad_w: padding size in width dimension. + :param pad_w: The padding size in the width dimension. :type pad_w: list | None - :param layer_attr: Extra Layer Attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute :param name: The name of this layer. It is optional. :type name: basestring @@ -4747,7 +4799,7 @@ def pad_layer(input, @layer_support() def conv_shift_layer(a, b, name=None, layer_attr=None): """ - This layer performs cyclic convolution for two input. For example: + This layer performs cyclic convolution on two inputs. For example: - a[in]: contains M elements. - b[in]: contains N elements (N should be odd). - c[out]: contains M elements. @@ -4756,7 +4808,7 @@ def conv_shift_layer(a, b, name=None, layer_attr=None): c[i] = \sum_{j=-(N-1)/2}^{(N-1)/2}a_{i+j} * b_{j} - In this formular: + In this formula: - a's index is computed modulo M. When it is negative, then get item from the right side (which is the end of array) to the left. - b's index is computed modulo N. When it is negative, then get item from @@ -4770,11 +4822,12 @@ def conv_shift_layer(a, b, name=None, layer_attr=None): :param name: The name of this layer. It is optional. :type name: basestring - :param a: Input layer a. + :param a: The first input of this layer. :type a: LayerOutput - :param b: input layer b. + :param b: The second input of this layer. :type b: LayerOutput - :param layer_attr: layer's extra attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput @@ -4805,8 +4858,8 @@ def tensor_layer(a, bias_attr=None, layer_attr=None): """ - This layer performs tensor operation for two input. - For example, each sample: + This layer performs tensor operation on two inputs. + For example: .. math:: y_{i} = a * W_{i} * {b^\mathrm{T}}, i=0,1,...,K-1 @@ -4826,21 +4879,24 @@ def tensor_layer(a, :param name: The name of this layer. It is optional. :type name: basestring - :param a: Input layer a. + :param a: The first input of this layer. :type a: LayerOutput - :param b: input layer b. + :param b: The second input of this layer. :type b: LayerOutput - :param size: the layer dimension. - :type size: int. - :param act: Activation type. LinearActivation is the default. + :param size: The dimension of this layer. + :type size: int + :param act: Activation type. LinearActivation is the default activation. :type act: BaseActivation - :param param_attr: The Parameter Attribute. + :param param_attr: The parameter attribute. See ParameterAttribute for + details. :type param_attr: ParameterAttribute - :param bias_attr: The bias attribute. If the parameter is set to False or an object - whose type is not ParameterAttribute, no bias is defined. If the - parameter is set to True, the bias is initialized to zero. + :param bias_attr: The parameter attribute for bias. If this parameter is set to + False or an object whose type is not ParameterAttribute, + no bias is defined. If this parameter is set to True, + the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any - :param layer_attr: Extra Layer config. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput @@ -4876,7 +4932,7 @@ def selective_fc_layer(input, layer_attr=None): """ Selectived fully connected layer. Different from fc_layer, the output - of this layer maybe sparse. It requires an additional input to indicate + of this layer can be sparse. It requires an additional input to indicate several selected columns for output. If the selected columns is not specified, selective_fc_layer acts exactly like fc_layer. @@ -4890,21 +4946,34 @@ def selective_fc_layer(input, :type name: basestring :param input: The input of this layer. :type input: LayerOutput | list | tuple - :param select: The select layer. The output of select layer should be a - sparse binary matrix, and treat as the mask of selective fc. - If is None, acts exactly like fc_layer. + :param select: The layer to select columns to output. It should be a sparse + binary matrix, and is treated as the mask of selective fc. If + it is not set or set to None, selective_fc_layer acts exactly + like fc_layer. :type select: LayerOutput - :param size: The layer dimension. + :param size: The dimension of this layer, which should be equal to that of + the layer 'select'. :type size: int - :param act: Activation type. TanhActivation is the default. + :param act: Activation type. TanhActivation is the default activation. :type act: BaseActivation - :param param_attr: The Parameter Attribute. + :param pass_generation: The flag which indicates whether it is during generation. + :type pass_generation: bool + :param has_selected_colums: The flag which indicates whether the parameter 'select' + has been set. True is the default. + :type has_selected_colums: bool + :param mul_ratio: A ratio helps to judge how sparse the output is and determine + the computation method for speed consideration. + :type mul_ratio: float + :param param_attr: The parameter attribute. See ParameterAttribute for + details. :type param_attr: ParameterAttribute - :param bias_attr: The bias attribute. If the parameter is set to False or an object - whose type is not ParameterAttribute, no bias is defined. If the - parameter is set to True, the bias is initialized to zero. + :param bias_attr: The parameter attribute for bias. If this parameter is set to + False or an object whose type is not ParameterAttribute, + no bias is defined. If this parameter is set to True, + the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any - :param layer_attr: Extra Layer config. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute | None :return: LayerOutput object. :rtype: LayerOutput @@ -4955,7 +5024,7 @@ def selective_fc_layer(input, @layer_support() def sampling_id_layer(input, name=None, layer_attr=None): """ - A layer for sampling id from multinomial distribution from the input layer. + A layer for sampling id from a multinomial distribution from the input layer. Sampling one id for one sample. The simple usage is: @@ -4968,8 +5037,9 @@ def sampling_id_layer(input, name=None, layer_attr=None): :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -4990,8 +5060,7 @@ def slope_intercept_layer(input, intercept=0.0, layer_attr=None): """ - This layer for applying a slope and an intercept to the input - element-wise. There is no activation and weight. + This layer for applying a slope and an intercept to the input. .. math:: y = slope * x + intercept @@ -5006,12 +5075,13 @@ def slope_intercept_layer(input, :type input: LayerOutput :param name: The name of this layer. It is optional. :type name: basestring - :param slope: the scale factor. - :type slope: float. - :param intercept: the offset. - :type intercept: float. - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :param slope: The scale factor. + :type slope: float + :param intercept: The offset. + :type intercept: float + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -5066,12 +5136,13 @@ def linear_comb_layer(weights, vectors, size=None, name=None, layer_attr=None): :type weights: LayerOutput :param vectors: The vector layer. :type vectors: LayerOutput - :param size: the dimension of this layer. + :param size: The dimension of this layer. :type size: int :param name: The name of this layer. It is optional. :type name: basestring - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -5118,11 +5189,11 @@ def block_expand_layer(input, outputW = 1 + (2 * padding_x + imgSizeW - block_x + stride_x - 1) / stride_x - The expand method is the same with ExpandConvLayer, but saved the transposed + The expanding method is the same with ExpandConvLayer, but saved the transposed value. After expanding, output.sequenceStartPositions will store timeline. - The number of time steps are outputH * outputW and the dimension of each + The number of time steps is outputH * outputW and the dimension of each time step is block_y * block_x * num_channels. This layer can be used after - convolution neural network, and before recurrent neural network. + convolutional neural network, and before recurrent neural network. The simple usage is: @@ -5137,8 +5208,10 @@ def block_expand_layer(input, :param input: The input of this layer. :type input: LayerOutput - :param num_channels: The channel number of input layer. - :type num_channels: int | None + :param num_channels: The number of input channels. If the parameter is not set or + set to None, its actual value will be automatically set to + the channels number of the input. + :type num_channels: int :param block_x: The width of sub block. :type block_x: int :param block_y: The width of sub block. @@ -5152,9 +5225,10 @@ def block_expand_layer(input, :param padding_y: The padding size in vertical direction. :type padding_y: int :param name: The name of this layer. It is optional. - :type name: None | basestring. - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :type name: basestring. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -5184,12 +5258,19 @@ def block_expand_layer(input, @layer_support() def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None): """ - A layer to do max out on conv layer output. - - Input: output of a conv layer. - - Output: feature map size same as input. Channel is (input channel) / groups. + A layer to do max out on convolutional layer output. + - Input: the output of a convolutional layer. + - Output: feature map size same as the input's, and its channel number is + (input channel) / groups. So groups should be larger than 1, and the num of channels should be able - to devided by groups. + to be devided by groups. + + Reference: + Maxout Networks + http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf + Multi-digit Number Recognition from Street View Imagery using Deep Convolutional Neural Networks + https://arxiv.org/pdf/1312.6082v4.pdf .. math:: y_{si+j} = \max_k x_{gsi + sk + j} @@ -5199,12 +5280,6 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None): 0 \le j < s 0 \le k < groups - Please refer to Paper: - - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf - - Multi-digit Number Recognition from Street View \ - Imagery using Deep Convolutional Neural Networks: \ - https://arxiv.org/pdf/1312.6082v4.pdf - The simple usage is: .. code-block:: python @@ -5215,14 +5290,16 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None): :param input: The input of this layer. :type input: LayerOutput - :param num_channels: The channel number of input layer. If None will be set - automatically from previous output. - :type num_channels: int | None + :param num_channels: The number of input channels. If the parameter is not set or + set to None, its actual value will be automatically set to + the channels number of the input. + :type num_channels: int :param groups: The group number of input layer. :type groups: int :param name: The name of this layer. It is optional. - :type name: None | basestring. - :param layer_attr: Extra Layer attribute. + :type name: basestring + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput @@ -5254,20 +5331,20 @@ def ctc_layer(input, layer_attr=None): """ Connectionist Temporal Classification (CTC) is designed for temporal - classication task. That is, for sequence labeling problems where the + classication task. e.g. sequence labeling problems where the alignment between the inputs and the target labels is unknown. - More details can be found by referring to `Connectionist Temporal - Classification: Labelling Unsegmented Sequence Data with Recurrent - Neural Networks `_ + Reference: + Connectionist Temporal Classification: Labelling Unsegmented Sequence Data + with Recurrent Neural Networks + http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf Note: - Considering the 'blank' label needed by CTC, you need to use - (num_classes + 1) as the input size. num_classes is the category number. - And the 'blank' is the last category index. So the size of 'input' layer, such as - fc_layer with softmax activation, should be num_classes + 1. The size of ctc_layer - should also be num_classes + 1. + Considering the 'blank' label needed by CTC, you need to use (num_classes + 1) + as the size of the input, where num_classes is the category number. + And the 'blank' is the last category index. So the size of 'input' layer (e.g. + fc_layer with softmax activation) should be (num_classes + 1). The size of + ctc_layer should also be (num_classes + 1). The example usage is: @@ -5280,16 +5357,17 @@ def ctc_layer(input, :param input: The input of this layer. :type input: LayerOutput - :param label: The data layer of label with variable length. + :param label: The input label. :type label: LayerOutput - :param size: category numbers + 1. + :param size: The dimension of this layer, which must be equal to (category number + 1). :type size: int :param name: The name of this layer. It is optional. - :type name: basestring | None - :param norm_by_times: Whether to normalization by times. False by default. + :type name: basestring + :param norm_by_times: Whether to do normalization by times. False is the default. :type norm_by_times: bool - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -5330,20 +5408,19 @@ def warp_ctc_layer(input, building process, PaddlePaddle will clone the source codes, build and install it to :code:`third_party/install/warpctc` directory. - More details of CTC can be found by referring to `Connectionist Temporal - Classification: Labelling Unsegmented Sequence Data with Recurrent - Neural Networks `_. + Reference: + Connectionist Temporal Classification: Labelling Unsegmented Sequence Data + with Recurrent Neural Networks + http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf Note: - - Let num_classes represent the category number. Considering the 'blank' - label needed by CTC, you need to use (num_classes + 1) as the input size. - Thus, the size of both warp_ctc layer and 'input' layer should be set to - num_classes + 1. + - Let num_classes represents the category number. Considering the 'blank' + label needed by CTC, you need to use (num_classes + 1) as the size of + warp_ctc layer. - You can set 'blank' to any value ranged in [0, num_classes], which - should be consistent as that used in your labels. + should be consistent with those used in your labels. - As a native 'softmax' activation is interated to the warp-ctc library, - 'linear' activation is expected instead in the 'input' layer. + 'linear' activation is expected to be used instead in the 'input' layer. The example usage is: @@ -5357,18 +5434,19 @@ def warp_ctc_layer(input, :param input: The input of this layer. :type input: LayerOutput - :param label: The data layer of label with variable length. + :param label: The input label. :type label: LayerOutput - :param size: category numbers + 1. + :param size: The dimension of this layer, which must be equal to (category number + 1). :type size: int :param name: The name of this layer. It is optional. - :type name: basestring | None - :param blank: the 'blank' label used in ctc + :type name: basestring + :param blank: The 'blank' label used in ctc. :type blank: int - :param norm_by_times: Whether to normalization by times. False by default. + :param norm_by_times: Whether to do normalization by times. False is the default. :type norm_by_times: bool - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -5414,23 +5492,26 @@ def crf_layer(input, label=label, size=label_dim) - :param input: The first input layer is the feature. + :param input: The first input layer. :type input: LayerOutput - :param label: The second input layer is label. + :param label: The input label. :type label: LayerOutput :param size: The category number. :type size: int - :param weight: The third layer is "weight" of each sample, which is an - optional argument. + :param weight: The weight layer defines a weight for each sample in the + mini-batch. It is optional. :type weight: LayerOutput - :param param_attr: Parameter attribute. None means default attribute + :param param_attr: The parameter attribute. See ParameterAttribute for + details. :type param_attr: ParameterAttribute :param name: The name of this layer. It is optional. - :type name: None | basestring - :param coeff: The coefficient affects the gradient in the backward. + :type name: basestring + :param coeff: The weight of the gradient in the back propagation. + 1.0 is the default value. :type coeff: float - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -5476,9 +5557,9 @@ def crf_decoding_layer(input, """ A layer for calculating the decoding sequence of sequential conditional random field model. The decoding sequence is stored in output.ids. - If a second input is provided, it is treated as the ground-truth label, and - this layer will also calculate error. output.value[i] is 1 for incorrect - decoding or 0 for correct decoding. + If the input 'label' is provided, it is treated as the ground-truth label, and + this layer will also calculate error. output.value[i] is 1 for an incorrect + decoding and 0 for the correct. The example usage is: @@ -5489,16 +5570,18 @@ def crf_decoding_layer(input, :param input: The first input layer. :type input: LayerOutput - :param size: size of this layer. + :param size: The dimension of this layer. :type size: int - :param label: None or ground-truth label. - :type label: LayerOutput or None - :param param_attr: Parameter attribute. None means default attribute + :param label: The input label. + :type label: LayerOutput | None + :param param_attr: The parameter attribute. See ParameterAttribute for + details. :type param_attr: ParameterAttribute :param name: The name of this layer. It is optional. - :type name: None | basestring - :param layer_attr: Extra Layer config. - :type layer_attr: ExtraLayerAttribute | None + :type name: basestring + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. + :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput """ @@ -5545,8 +5628,7 @@ def nce_layer(input, bias_attr=None, layer_attr=None): """ - Noise-contrastive estimation. This layer implements the method in the - following paper: + Noise-contrastive estimation. Reference: A fast and simple algorithm for training neural probabilistic language @@ -5562,37 +5644,40 @@ def nce_layer(input, :param name: The name of this layer. It is optional. :type name: basestring - :param input: The input layers. It should be a LayerOutput or a list/tuple - of LayerOutput. + :param input: The first input of this layer. :type input: LayerOutput | list | tuple | collections.Sequence - :param label: The ground truth. + :param label: The input label. :type label: LayerOutput :param weight: The weight layer defines a weight for each sample in the - mini-batch. The default value is None. + mini-batch. It is optional. :type weight: LayerOutput - :param num_classes: The class number. + :param num_classes: The number of classes. :type num_classes: int - :param param_attr: The parameter attributes. - :type param_attr: ParameterAttribute|list - :param num_neg_samples: The number of sampled negative labels. The default - value is 10. + :param act: Activation type. SigmoidActivation is the default activation. + :type act: BaseActivation + :param param_attr: The parameter attribute. See ParameterAttribute for + details. + :type param_attr: ParameterAttribute + :param num_neg_samples: The number of sampled negative labels. 10 is the + default value. :type num_neg_samples: int :param neg_distribution: The discrete noisy distribution over the output space from which num_neg_samples negative labels are sampled. If this parameter is not set, a - uniform distribution will be used. A user defined + uniform distribution will be used. A user-defined distribution is a list whose length must be equal to the num_classes. Each member of the list defines the probability of a class given input x. :type neg_distribution: list | tuple | collections.Sequence | None - :param bias_attr: The attribute for bias. If this parameter is set False or - any object whose type is not ParameterAttribute, no bias - is added. If this parameter is set True, the bias is - initialized to zero. + :param bias_attr: The parameter attribute for bias. If this parameter is set to + False or an object whose type is not ParameterAttribute, + no bias is defined. If this parameter is set to True, + the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any - :param layer_attr: Extra Layer Attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute - :return: The LayerOutput object. + :return: LayerOutput object. :rtype: LayerOutput """ if isinstance(input, LayerOutput): @@ -5659,11 +5744,11 @@ def rank_cost(left, coeff=1.0, layer_attr=None): """ - A cost Layer for learning to rank using gradient descent. Details can refer - to `papers `_. - This layer contains at least three inputs. The weight is an optional - argument, which affects the cost. + A cost Layer for learning to rank using gradient descent. + + Reference: + Learning to Rank using Gradient Descent + http://research.microsoft.com/en-us/um/people/cburges/papers/ICML_ranking.pdf .. math:: @@ -5694,14 +5779,16 @@ def rank_cost(left, :type right: LayerOutput :param label: Label is 1 or 0, means positive order and reverse order. :type label: LayerOutput - :param weight: The weight affects the cost, namely the scale of cost. - It is an optional argument. + :param weight: The weight layer defines a weight for each sample in the + mini-batch. It is optional. :type weight: LayerOutput :param name: The name of this layer. It is optional. - :type name: None | basestring - :param coeff: The coefficient affects the gradient in the backward. + :type name: basestring + :param coeff: The weight of the gradient in the back propagation. + 1.0 is the default value. :type coeff: float - :param layer_attr: Extra Layer Attribute. + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput @@ -5746,25 +5833,25 @@ def lambda_cost(input, NDCG_num=8, max_sort_size=-1) - :param input: Samples of the same query should be loaded as sequence. + :param input: The first input of this layer, which is often a document + samples list of the same query and whose type must be sequence. :type input: LayerOutput - :param score: The 2nd input. Score of each sample. + :param score: The scores of the samples. :type input: LayerOutput :param NDCG_num: The size of NDCG (Normalized Discounted Cumulative Gain), e.g., 5 for NDCG@5. It must be less than or equal to the - minimum size of lists. + minimum size of the list. :type NDCG_num: int - :param max_sort_size: The size of partial sorting in calculating gradient. - If max_sort_size = -1, then for each list, the - algorithm will sort the entire list to get gradient. - In other cases, max_sort_size must be greater than or - equal to NDCG_num. And if max_sort_size is greater - than the size of a list, the algorithm will sort the - entire list of get gradient. + :param max_sort_size: The size of partial sorting in calculating gradient. If + max_sort_size is equal to -1 or greater than the number + of the samples in the list, then the algorithm will sort + the entire list to compute the gradient. In other cases, + max_sort_size must be greater than or equal to NDCG_num. :type max_sort_size: int :param name: The name of this layer. It is optional. - :type name: None | basestring - :param layer_attr: Extra Layer Attribute. + :type name: basestring + :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for + details. :type layer_attr: ExtraLayerAttribute :return: LayerOutput object. :rtype: LayerOutput @@ -5809,11 +5896,10 @@ def cross_entropy(input, :param name: The name of this layer. It is optional. :type name: basestring :param coeff: The weight of the gradient in the back propagation. - 1.0 is the default. + 1.0 is the default value. :type coeff: float - :param weight: The cost of each sample is multiplied with each weight. - The weight should be a layer with size=1. Note that gradient - will not be calculated for weight. + :param weight: The weight layer defines a weight for each sample in the + mini-batch. It is optional. :type weight: LayerOutout :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. @@ -5858,7 +5944,7 @@ def cross_entropy_with_selfnorm(input, :param name: The name of this layer. It is optional. :type name: basestring :param coeff: The weight of the gradient in the back propagation. - 1.0 is the default. + 1.0 is the default value. :type coeff: float :param softmax_selfnorm_alpha: The scale factor affects the cost. :type softmax_selfnorm_alpha: float @@ -5948,7 +6034,7 @@ def huber_regression_cost(input, :param delta: The difference between the observed and predicted values. :type delta: float :param coeff: The weight of the gradient in the back propagation. - 1.0 is the default. + 1.0 is the default value. :type coeff: float :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. @@ -5998,7 +6084,7 @@ def huber_classification_cost(input, :param name: The name of this layer. It is optional. :type name: basestring :param coeff: The weight of the gradient in the back propagation. - 1.0 is the default. + 1.0 is the default value. :type coeff: float :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. @@ -6043,7 +6129,7 @@ def multi_binary_label_cross_entropy(input, :param name: The name of this layer. It is optional. :type name: basestring :param coeff: The weight of the gradient in the back propagation. - 1.0 is the default. + 1.0 is the default value. :type coeff: float :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. @@ -6214,7 +6300,7 @@ def smooth_l1_cost(input, label, name=None, coeff=1.0, layer_attr=None): :param name: The name of this layer. It is optional. :type name: basestring :param coeff: The weight of the gradient in the back propagation. - 1.0 is the default. + 1.0 is the default value. :type coeff: float :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details. @@ -6366,7 +6452,7 @@ def row_conv_layer(input, :param context_len: The context length equals the lookahead step number plus one. :type context_len: int - :param act: Activation Type. LinearActivation is the default. + :param act: Activation Type. LinearActivation is the default activation. :type act: BaseActivation :param param_attr: The parameter attribute. See ParameterAttribute for details. @@ -6488,7 +6574,8 @@ def gated_unit_layer(input, :type input: LayerOutput :param size: The dimension of this layer's output. :type size: int - :param act: Activation type of the projection. LinearActivation is the default. + :param act: Activation type of the projection. LinearActivation is the default + activation. :type act: BaseActivation :param name: The name of this layer. It is optional. :type name: basestring @@ -6498,9 +6585,9 @@ def gated_unit_layer(input, :param gate_param_attr: The parameter attribute of the gate. See ParameterAttribute for details. :type gate_param_attr: ParameterAttribute - :param gate_bias_attr: The bias attribute of the gate. If the parameter is set to False or + :param gate_bias_attr: The bias attribute of the gate. If this parameter is set to False or an object whose type is not ParameterAttribute, no bias is defined. - If the parameter is set to True, the bias is initialized to zero. + If this parameter is set to True, the bias is initialized to zero. :type gate_bias_attr: ParameterAttribute | bool | None | Any :param inproj_attr: Extra layer attributes of the projection. See ExtraLayerAttribute for details. @@ -6508,9 +6595,9 @@ def gated_unit_layer(input, :param inproj_param_attr: The parameter attribute of the projection. See ParameterAttribute for details. :type inproj_param_attr: ParameterAttribute - :param inproj_bias_attr: The bias attribute of the projection. If the parameter is set to False + :param inproj_bias_attr: The bias attribute of the projection. If this parameter is set to False or an object whose type is not ParameterAttribute, no bias is defined. - If the parameter is set to True, the bias is initialized to zero. + If this parameter is set to True, the bias is initialized to zero. :type inproj_bias_attr: ParameterAttribute | bool | None | Any :param layer_attr: Extra layer attribute of the product. See ExtraLayerAttribute for details. @@ -6869,7 +6956,7 @@ def img_conv3d_layer(input, :type filter_size: int | tuple | list :param num_filters: The number of filters in each group. :type num_filters: int - :param act: Activation type. ReluActivation is the default. + :param act: Activation type. ReluActivation is the default activation. :type act: BaseActivation :param groups: The number of the filter groups. :type groups: int @@ -6884,8 +6971,8 @@ def img_conv3d_layer(input, parameter is set to True, the bias is initialized to zero. :type bias_attr: ParameterAttribute | None | bool | Any :param num_channels: The number of input channels. If the parameter is not set or - set to None, its actual value will be automatically set to - the channels number of the input . + set to None, its actual value will be automatically set to + the channels number of the input. :type num_channels: int :param param_attr: The parameter attribute of the convolution. See ParameterAttribute for details. @@ -7061,7 +7148,7 @@ def sub_seq_layer(input, offsets, sizes, act=None, bias_attr=None, name=None): :type offsets: LayerOutput :param sizes: The sizes of the sub-sequences, which should be sequence type. :type sizes: LayerOutput - :param act: Activation type, LinearActivation is the default. + :param act: Activation type, LinearActivation is the default activation. :type act: BaseActivation. :param bias_attr: The bias attribute. If the parameter is set to False or an object whose type is not ParameterAttribute, no bias is defined. If the diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 3821d075cb..d323d34c3f 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -681,34 +681,42 @@ def lstmemory_unit(input, state_act=TanhActivation()) - :param input: input layer. + :param input: Input layer. :type input: LayerOutput - :param out_memory: output of previous time step + :param out_memory: The output of previous time step. :type out_memory: LayerOutput | None - :param name: lstmemory unit name. + :param name: The lstmemory unit name. :type name: basestring - :param size: lstmemory unit size. + :param size: The lstmemory unit size. :type size: int - :param param_attr: parameter attribute, None means default attribute. + :param param_attr: The parameter attribute for the weights in + input to hidden projection. + None means default attribute. :type param_attr: ParameterAttribute - :param act: last activiation type of lstm. + :param act: The last activiation type of lstm. :type act: BaseActivation - :param gate_act: gate activiation type of lstm. + :param gate_act: The gate activiation type of lstm. :type gate_act: BaseActivation - :param state_act: state activiation type of lstm. + :param state_act: The state activiation type of lstm. :type state_act: BaseActivation - :param input_proj_bias_attr: bias attribute for input to hidden projection. - False means no bias, None means default bias. - :type input_proj_bias_attr: ParameterAttribute|False|None - :param input_proj_layer_attr: extra layer attribute for input to hidden - projection of the LSTM unit, such as dropout, error clipping. + :param input_proj_bias_attr: The parameter attribute for the bias in + input to hidden projection. + False or None means no bias. + If this parameter is set to True, + the bias is initialized to zero. + :type input_proj_bias_attr: ParameterAttribute|bool|None + :param input_proj_layer_attr: The extra layer attribute for + input to hidden projection of the LSTM unit, + such as dropout, error clipping. :type input_proj_layer_attr: ExtraLayerAttribute - :param lstm_bias_attr: bias parameter attribute of lstm layer. - False means no bias, None means default bias. - :type lstm_bias_attr: ParameterAttribute|False|None - :param lstm_layer_attr: extra attribute of lstm layer. + :param lstm_bias_attr: The parameter attribute for the bias in lstm layer. + False or None means no bias. + If this parameter is set to True, + the bias is initialized to zero. + :type lstm_bias_attr: ParameterAttribute|True|None + :param lstm_layer_attr: The extra attribute of lstm layer. :type lstm_layer_attr: ExtraLayerAttribute - :return: lstmemory unit name. + :return: The lstmemory unit name. :rtype: LayerOutput """ if size is None: @@ -786,34 +794,42 @@ def lstmemory_group(input, gate_act=SigmoidActivation(), state_act=TanhActivation()) - :param input: input layer. + :param input: Input layer. :type input: LayerOutput - :param size: lstmemory group size. + :param size: The lstmemory group size. :type size: int - :param name: name of lstmemory group. + :param name: The name of lstmemory group. :type name: basestring - :param out_memory: output of previous time step. + :param out_memory: The output of previous time step. :type out_memory: LayerOutput | None - :param reverse: process the input in a reverse order or not. + :param reverse: Process the input in a reverse order or not. :type reverse: bool - :param param_attr: parameter attribute, None means default attribute. + :param param_attr: The parameter attribute for the weights in + input to hidden projection. + None means default attribute. :type param_attr: ParameterAttribute - :param act: last activiation type of lstm. + :param act: The last activiation type of lstm. :type act: BaseActivation - :param gate_act: gate activiation type of lstm. + :param gate_act: The gate activiation type of lstm. :type gate_act: BaseActivation - :param state_act: state activiation type of lstm. + :param state_act: The state activiation type of lstm. :type state_act: BaseActivation - :param lstm_bias_attr: bias parameter attribute of lstm layer. - False means no bias, None means default bias. - :type lstm_bias_attr: ParameterAttribute|False|None - :param input_proj_bias_attr: bias attribute for input to hidden projection. - False means no bias, None means default bias. - :type input_proj_bias_attr: ParameterAttribute|False|None - :param input_proj_layer_attr: extra layer attribute for input to hidden - projection of the LSTM unit, such as dropout, error clipping. + :param input_proj_bias_attr: The parameter attribute for the bias in + input to hidden projection. + False or None means no bias. + If this parameter is set to True, + the bias is initialized to zero. + :type input_proj_bias_attr: ParameterAttribute|bool|None + :param input_proj_layer_attr: The extra layer attribute for + input to hidden projection of the LSTM unit, + such as dropout, error clipping. :type input_proj_layer_attr: ExtraLayerAttribute - :param lstm_layer_attr: lstm layer's extra attribute. + :param lstm_bias_attr: The parameter attribute for the bias in lstm layer. + False or None means no bias. + If this parameter is set to True, + the bias is initialized to zero. + :type lstm_bias_attr: ParameterAttribute|True|None + :param lstm_layer_attr: The extra attribute of lstm layer. :type lstm_layer_attr: ExtraLayerAttribute :return: the lstmemory group. :rtype: LayerOutput diff --git a/python/paddle/trainer_config_helpers/poolings.py b/python/paddle/trainer_config_helpers/poolings.py index 0c38a8dce5..f45616551b 100644 --- a/python/paddle/trainer_config_helpers/poolings.py +++ b/python/paddle/trainer_config_helpers/poolings.py @@ -15,8 +15,8 @@ """ __all__ = [ - "BasePoolingType", "MaxPooling", "AvgPooling", "CudnnMaxPooling", - "CudnnAvgPooling", "SumPooling", "SquareRootNPooling" + "BasePoolingType", "MaxPooling", "AvgPooling", "MaxWithMaskPooling", + "CudnnMaxPooling", "CudnnAvgPooling", "SumPooling", "SquareRootNPooling" ] @@ -55,6 +55,19 @@ class MaxPooling(BasePoolingType): self.output_max_index = output_max_index +class MaxWithMaskPooling(BasePoolingType): + """ + MaxWithMask pooling. + + Not only return the very large values for each dimension in sequence or time steps, + but also the location indices of found maxinum values. + + """ + + def __init__(self): + BasePoolingType.__init__(self, "max-pool-with-mask") + + class CudnnMaxPooling(BasePoolingType): """ Cudnn max pooling only support GPU. Return the maxinum value in the diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr index 5ddf6052df..b14121e82c 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr @@ -28,6 +28,8 @@ layers { stride_y: 1 output_y: 227 img_size_y: 256 + dilation: 1 + dilation_y: 1 } } bias_parameter_name: "___conv_0__.wbias" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr index c0252b945b..c7a487a112 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr @@ -28,6 +28,8 @@ layers { stride_y: 1 output_y: 227 img_size_y: 256 + dilation: 1 + dilation_y: 1 } } bias_parameter_name: "___conv_0__.wbias" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bilinear_interp.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bilinear_interp.protostr index fd5224ca55..25ec632375 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bilinear_interp.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bilinear_interp.protostr @@ -28,6 +28,8 @@ layers { stride_y: 1 output_y: 48 img_size_y: 48 + dilation: 1 + dilation_y: 1 } } bias_parameter_name: "___conv_0__.wbias" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_maxout.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_maxout.protostr index 03f4f3a31d..39dc487146 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_maxout.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_maxout.protostr @@ -30,6 +30,8 @@ layers { stride_y: 1 output_y: 48 img_size_y: 48 + dilation: 1 + dilation_y: 1 } } bias_parameter_name: "___conv_0__.wbias" @@ -105,6 +107,8 @@ layers { stride_y: 1 output_y: 24 img_size_y: 24 + dilation: 1 + dilation_y: 1 } } bias_parameter_name: "___conv_1__.wbias" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_pad.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_pad.protostr index 15c6ab4dc8..d5d6d31a17 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_pad.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_pad.protostr @@ -30,6 +30,8 @@ layers { stride_y: 1 output_y: 48 img_size_y: 48 + dilation: 1 + dilation_y: 1 } } bias_parameter_name: "___conv_0__.wbias" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr index f1bc65b3ae..0ec88aa998 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_roi_pool_layer.protostr @@ -36,6 +36,8 @@ layers { stride_y: 1 output_y: 14 img_size_y: 14 + dilation: 1 + dilation_y: 1 } } bias_parameter_name: "___conv_0__.wbias" diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 1c8d8f4b2f..3d70513843 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -37,6 +37,8 @@ import model import paddle.trainer.config_parser as cp __all__ = [ + 'default_startup_program', + 'default_main_program', 'optimizer', 'layer', 'activation', diff --git a/python/paddle/v2/framework/.gitignore b/python/paddle/v2/fluid/.gitignore similarity index 100% rename from python/paddle/v2/framework/.gitignore rename to python/paddle/v2/fluid/.gitignore diff --git a/python/paddle/v2/framework/__init__.py b/python/paddle/v2/fluid/__init__.py similarity index 100% rename from python/paddle/v2/framework/__init__.py rename to python/paddle/v2/fluid/__init__.py diff --git a/python/paddle/v2/framework/backward.py b/python/paddle/v2/fluid/backward.py similarity index 97% rename from python/paddle/v2/framework/backward.py rename to python/paddle/v2/fluid/backward.py index 678efd5d20..f188582178 100644 --- a/python/paddle/v2/framework/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -1,4 +1,4 @@ -from paddle.v2.framework import framework as framework +from paddle.v2.fluid import framework as framework __all__ = ['append_backward_ops'] diff --git a/python/paddle/v2/framework/default_scope_funcs.py b/python/paddle/v2/fluid/default_scope_funcs.py similarity index 92% rename from python/paddle/v2/framework/default_scope_funcs.py rename to python/paddle/v2/fluid/default_scope_funcs.py index c07f9a6ab9..60c6165b6b 100644 --- a/python/paddle/v2/framework/default_scope_funcs.py +++ b/python/paddle/v2/fluid/default_scope_funcs.py @@ -13,7 +13,7 @@ A `scoped_function` will take a `function` as input. That function will be invoked in a new local scope. """ -import paddle.v2.framework.core +import paddle.v2.fluid.core import threading __tl_scope__ = threading.local() @@ -27,13 +27,13 @@ __all__ = [ def get_cur_scope(): """ Get current scope. - :rtype: paddle.v2.framework.core.Scope + :rtype: paddle.v2.fluid.core.Scope """ cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None) if cur_scope_stack is None: __tl_scope__.cur_scope = list() if len(__tl_scope__.cur_scope) == 0: - __tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope()) + __tl_scope__.cur_scope.append(paddle.v2.fluid.core.Scope()) return __tl_scope__.cur_scope[-1] diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py new file mode 100644 index 0000000000..3a8f1831cf --- /dev/null +++ b/python/paddle/v2/fluid/evaluator.py @@ -0,0 +1,187 @@ +import numpy as np +from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable +import paddle.v2.fluid.core as core + + +def _clone_var_in_block_(block, var): + assert isinstance(var, Variable) + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.data_type, + type=var.type, + lod_level=var.lod_level, + persistable=True) + + +class Evaluator(object): + """ + Evalutor Base class. + + create metric states + add mini-batch evaluator caculate operator + add increment operator to accumulate the metric states + """ + + def __init__(self, name, **kwargs): + """ + init the global states + """ + self._states = {} + if kwargs.has_key("main_program"): + self._main_program = kwargs.get("main_program") + else: + self._main_program = g_main_program + + def _update_ops(self, *args, **kwargs): + """ + append update ops to the global states + """ + raise NotImplementedError() + + def reset(self, executor, reset_program=None): + """ + Clear metric states at the begin of each pass/user specified batch + """ + if reset_program == None: + reset_program = Program() + else: + reset_program = program + block = reset_program.global_block() + for k, var in self._states.iteritems(): + g_var = _clone_var_in_block_(block, var) + zeros = block.create_var(dtype="float32", persistable=True) + block.append_op( + type="fill_constant", + outputs={"Out": [zeros]}, + attrs={ + "shape": g_var.shape, + "value": .0, + "data_type": 5, + }) + block.append_op( + type="scale", inputs={"X": zeros}, outputs={"Out": g_var}) + executor.run(reset_program, fetch_list=self._states.values()) + + def eval(self, executor, eval_program=None): + """ + Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. + """ + raise NotImplementedError() + + +class Accuracy(Evaluator): + """ + Accuracy need two state variable Total, Correct + """ + + def __init__(self, *args, **kwargs): + super(Accuracy, self).__init__("accuracy", **kwargs) + block = self._main_program.global_block() + g_total = block.create_var( + name=unique_name("Total"), + persistable=True, + dtype="int64", + shape=[1]) + g_correct = block.create_var( + name=unique_name("Correct"), + persistable=True, + dtype="int64", + shape=[1]) + self._states["Total"] = g_total + self._states["Correct"] = g_correct + + def _update_ops(self, input, label, k=1, **kwargs): + block = self._main_program.global_block() + topk_out = block.create_var(dtype=input.data_type) + topk_indices = block.create_var(dtype="int64") + block.append_op( + type="top_k", + inputs={"X": [input]}, + outputs={"Out": [topk_out], + "Indices": [topk_indices]}, + attrs={"k": k}) + acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32")) + correct = block.create_var(dtype="int64", persistable=True) + total = block.create_var(dtype="int64", persistable=True) + block.append_op( + type="accuracy", + inputs={ + "Out": [topk_out], + "Indices": [topk_indices], + "Label": [label] + }, + outputs={ + "Accuracy": [acc_out], + "Correct": [correct], + "Total": [total], + }) + + block.append_op( + type="cast", + inputs={"X": [self._states["Total"]]}, + outputs={"Out": [self._states["Total"]]}, + attrs={ + "in_data_type": 5, # float32 + "out_data_type": 2, #int32 + }) + block.append_op( + type="cast", + inputs={"X": [self._states["Correct"]]}, + outputs={"Out": [self._states["Correct"]]}, + attrs={ + "in_data_type": 5, + "out_data_type": 2, + }) + + block.append_op( + type="elementwise_add", + inputs={"X": [self._states["Total"]], + "Y": [total]}, + outputs={"Out": [self._states["Total"]]}) + block.append_op( + type="elementwise_add", + inputs={"X": [self._states["Correct"]], + "Y": [correct]}, + outputs={"Out": [self._states["Correct"]]}) + + return acc_out + + def eval(self, executor, eval_program=None): + if eval_program != None: + eval_program = eval_program + else: + eval_program = Program() + block = eval_program.global_block() + eval_out = block.create_var(dtype=self._states["Total"].data_type) + e_total = _clone_var_in_block_(block, self._states["Total"]) + e_correct = _clone_var_in_block_(block, self._states["Correct"]) + block.append_op( + type="cast", + inputs={"X": [e_total]}, + outputs={"Out": [e_total]}, + attrs={ + "in_data_type": 2, #int32 + "out_data_type": 5, #float32 + }) + block.append_op( + type="cast", + inputs={"X": [e_correct]}, + outputs={"Out": [e_correct]}, + attrs={ + "in_data_type": 2, + "out_data_type": 5, + }) + block.append_op( + type="elementwise_div", + inputs={"X": e_correct, + "Y": e_total}, + outputs={"Out": eval_out}) + out = executor.run(eval_program, fetch_list=[eval_out]) + return np.array(out[0]) + + +def accuracy(*args, **kwargs): + cls = Accuracy(*args, **kwargs) + out = cls._update_ops(*args, **kwargs) + return cls, out diff --git a/python/paddle/v2/framework/executor.py b/python/paddle/v2/fluid/executor.py similarity index 94% rename from python/paddle/v2/framework/executor.py rename to python/paddle/v2/fluid/executor.py index f5c833190e..ed1c2c06da 100644 --- a/python/paddle/v2/framework/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -1,5 +1,5 @@ -import paddle.v2.framework.core as core -from paddle.v2.framework.framework import Block, Program, g_main_program +import paddle.v2.fluid.core as core +from paddle.v2.fluid.framework import Block, Program, g_main_program g_scope = core.Scope() diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/fluid/framework.py similarity index 97% rename from python/paddle/v2/framework/framework.py rename to python/paddle/v2/fluid/framework.py index b9db2707c0..f20567243a 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -1,10 +1,10 @@ -import paddle.v2.framework.core as core -import paddle.v2.framework.proto.framework_pb2 as framework_pb2 +import paddle.v2.fluid.core as core +import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 import collections import numpy as np import copy -__all__ = ['Block', 'Variable', 'Program', 'Operator'] +__all__ = ['Block', 'Variable', 'Program', 'Operator', 'default_startup_program', 'default_main_program'] def unique_name(prefix): @@ -285,7 +285,7 @@ class Operator(object): self.desc.check_attrs() no_kernel_op_set = { 'feed', 'fetch', 'save', 'load', 'recurrent', - 'rnn_memory_helper_grad', 'while' + 'rnn_memory_helper_grad', 'conditional_block', 'while' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) @@ -562,3 +562,9 @@ class Parameter(Variable): # program is a global instance. g_main_program = Program() g_startup_program = Program() + +def default_startup_program(): + return g_startup_program + +def default_main_program(): + return g_main_program diff --git a/python/paddle/v2/framework/initializer.py b/python/paddle/v2/fluid/initializer.py similarity index 99% rename from python/paddle/v2/framework/initializer.py rename to python/paddle/v2/fluid/initializer.py index 98a87bfa86..ded144ecd5 100644 --- a/python/paddle/v2/framework/initializer.py +++ b/python/paddle/v2/fluid/initializer.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.framework as framework +import paddle.v2.fluid.framework as framework import numpy as np __all__ = [ diff --git a/python/paddle/v2/framework/io.py b/python/paddle/v2/fluid/io.py similarity index 98% rename from python/paddle/v2/framework/io.py rename to python/paddle/v2/fluid/io.py index 5c247904a3..394a171c67 100644 --- a/python/paddle/v2/framework/io.py +++ b/python/paddle/v2/fluid/io.py @@ -1,7 +1,7 @@ import os import cPickle as pickle -from paddle.v2.framework.framework import Program, Parameter, g_main_program, \ +from paddle.v2.fluid.framework import Program, Parameter, g_main_program, \ Variable __all__ = [ diff --git a/python/paddle/v2/framework/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py similarity index 98% rename from python/paddle/v2/framework/layer_helper.py rename to python/paddle/v2/fluid/layer_helper.py index 552976185d..9dc3c119ea 100644 --- a/python/paddle/v2/framework/layer_helper.py +++ b/python/paddle/v2/fluid/layer_helper.py @@ -1,9 +1,9 @@ import copy import itertools -from paddle.v2.framework.framework import Variable, g_main_program, \ +from paddle.v2.fluid.framework import Variable, g_main_program, \ g_startup_program, unique_name, Program -from paddle.v2.framework.initializer import ConstantInitializer, \ +from paddle.v2.fluid.initializer import ConstantInitializer, \ UniformInitializer, XavierInitializer diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/fluid/layers.py similarity index 89% rename from python/paddle/v2/framework/layers.py rename to python/paddle/v2/fluid/layers.py index fe3c86febb..b582f2ef6d 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -1,17 +1,17 @@ -import paddle.v2.framework.core as core -import paddle.v2.framework.proto.framework_pb2 as framework_pb2 -from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \ +import paddle.v2.fluid.core as core +import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 +from paddle.v2.fluid.framework import OpProtoHolder, Variable, Program, \ Operator -from paddle.v2.framework.initializer import ConstantInitializer, \ +from paddle.v2.fluid.initializer import ConstantInitializer, \ NormalInitializer -from paddle.v2.framework.layer_helper import LayerHelper, unique_name +from paddle.v2.fluid.layer_helper import LayerHelper, unique_name import re import cStringIO __all__ = [ 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim', - 'batch_norm', 'accuracy' + 'batch_norm', 'accuracy', 'split_lod_tensor' ] @@ -226,6 +226,11 @@ def data(name, stop_gradient=stop_gradient) +def create_tensor(dtype, name=None, main_program=None): + helper = LayerHelper("create_tensor", **locals()) + return helper.create_variable(name=helper.name, dtype=dtype) + + def _convert_(name): """ Formatting. @@ -451,6 +456,56 @@ def sums(input, main_program=None, startup_program=None): return out +def assign(input, output, main_program=None): + helper = LayerHelper('assign', **locals()) + helper.append_op( + type='scale', + inputs={'X': [input]}, + outputs={'Out': [output]}, + attrs={'scale': 1.0}) + return output + + +def split_lod_tensor(input, + mask, + level, + main_program=None, + startup_program=None): + helper = LayerHelper('split_lod_tensor', **locals()) + out_true = helper.create_tmp_variable(dtype=input.data_type) + out_false = helper.create_tmp_variable(dtype=input.data_type) + helper.append_op( + type='split_lod_tensor', + inputs={ + 'X': input, + 'Mask': mask, + }, + outputs={'OutTrue': out_true, + 'OutFalse': out_false}, + attrs={'level': level}) + return out_true, out_false + + +def merge_lod_tensor(in_true, + in_false, + x, + mask, + level, + main_program=None, + startup_program=None): + helper = LayerHelper('merge_lod_tensor', **locals()) + out = helper.create_tmp_variable(dtype=x.data_type) + helper.append_op( + type='merge_lod_tensor', + inputs={'X': x, + 'Mask': mask, + 'InTrue': in_true, + 'InFalse': in_false}, + outputs={'Out': out}, + attrs={'level': level}) + return out + + def cos_sim(X, Y, **kwargs): """ This function performs the cosine similarity between two tensors @@ -519,7 +574,9 @@ def accuracy(input, label, k=1, **kwargs): "Indices": [topk_indices]}, attrs={"k": k}) acc_out_dtype = kwargs.get("out_dtype", "float32") - acc_out = helper.create_tmp_variable(dtype=acc_out_dtype) + acc_out = helper.create_tmp_variable(dtype="float32") + correct = helper.create_tmp_variable(dtype="int64") + total = helper.create_tmp_variable(dtype="int64") helper.append_op( type="accuracy", inputs={ @@ -527,7 +584,11 @@ def accuracy(input, label, k=1, **kwargs): "Indices": [topk_indices], "Label": [label] }, - outputs={"Accuracy": [acc_out]}) + outputs={ + "Accuracy": [acc_out], + "Correct": [correct], + "Total": [total], + }) return acc_out @@ -784,6 +845,23 @@ def batch_norm(input, return helper.append_activation(batch_norm_out) +def beam_search_decode(ids, scores, main_program=None, startup_program=None): + helper = LayerHelper('beam_search_decode', **locals()) + sentence_ids = helper.create_tmp_variable(dtype=ids.data_type) + sentence_scores = helper.create_tmp_variable(dtype=ids.data_type) + + helper.append_op( + type="beam_search_decode", + inputs={"Ids": ids, + "Scores": scores}, + outputs={ + "SentenceIds": sentence_ids, + "SentenceScores": sentence_scores + }) + + return sentence_ids, sentence_scores + + class BlockGuard(object): """ BlockGuard class. @@ -1375,3 +1453,73 @@ def array_length(array, main_program=None): helper.append_op( type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]}) return tmp + + +class ConditionalBlockGuard(BlockGuard): + def __init__(self, block): + if not isinstance(block, ConditionalBlock): + raise TypeError("block should be conditional block") + super(ConditionalBlockGuard, self).__init__(block.helper.main_program) + self.block = block + + def __enter__(self): + return super(ConditionalBlockGuard, self).__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.block.complete() + return super(ConditionalBlockGuard, self).__exit__(exc_type, exc_val, + exc_tb) + + +class ConditionalBlock(object): + def __init__(self, inputs, name=None, main_program=None): + for each_input in inputs: + if not isinstance(each_input, Variable): + raise TypeError("Each input should be variable") + self.inputs = inputs + self.helper = LayerHelper( + 'conditional_block', name=name, main_program=main_program) + + def block(self): + return ConditionalBlockGuard(self) + + def complete(self): + inside_block = self.helper.main_program.current_block() + parent_block = self.helper.main_program.block(inside_block.parent_idx) + + intermediate = set() + params = set() + + for each_op in inside_block.ops: + assert isinstance(each_op, Operator) + for iname in each_op.input_names: + for in_var_name in each_op.input(iname): + if in_var_name not in intermediate: + params.add(in_var_name) + + for oname in each_op.output_names: + for out_var_name in each_op.output(oname): + intermediate.add(out_var_name) + input_set = set([ipt.name for ipt in self.inputs]) + + param_list = [ + parent_block.var(each_name) for each_name in params + if each_name not in input_set + ] + + out_list = [ + parent_block.var(var_name) for var_name in parent_block.vars + if var_name not in intermediate + ] + + step_scope = parent_block.create_var( + type=core.VarDesc.VarType.STEP_SCOPES) + parent_block.append_op( + type='conditional_block', + inputs={ + 'X': self.inputs, + 'Params': param_list, + }, + outputs={'Out': out_list, + 'Scope': [step_scope]}, + attrs={'block': inside_block}) diff --git a/python/paddle/v2/framework/net_drawer.py b/python/paddle/v2/fluid/net_drawer.py similarity index 96% rename from python/paddle/v2/framework/net_drawer.py rename to python/paddle/v2/fluid/net_drawer.py index 045e267c25..17ad547c2b 100644 --- a/python/paddle/v2/framework/net_drawer.py +++ b/python/paddle/v2/fluid/net_drawer.py @@ -3,8 +3,8 @@ import json import logging from collections import defaultdict -import paddle.v2.framework.core as core -import paddle.v2.framework.proto.framework_pb2 as framework_pb2 +import paddle.v2.fluid.core as core +import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) diff --git a/python/paddle/v2/framework/nets.py b/python/paddle/v2/fluid/nets.py similarity index 98% rename from python/paddle/v2/framework/nets.py rename to python/paddle/v2/fluid/nets.py index 725d2fa7f5..5e14ca594b 100644 --- a/python/paddle/v2/framework/nets.py +++ b/python/paddle/v2/fluid/nets.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.layers as layers +import paddle.v2.fluid.layers as layers __all__ = ["simple_img_conv_pool", "sequence_conv_pool"] diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/fluid/op.py similarity index 98% rename from python/paddle/v2/framework/op.py rename to python/paddle/v2/fluid/op.py index bc771a964a..5828803497 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/fluid/op.py @@ -1,5 +1,5 @@ -import paddle.v2.framework.core as core -import paddle.v2.framework.proto.framework_pb2 as framework_pb2 +import paddle.v2.fluid.core as core +import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 def get_all_op_protos(): diff --git a/python/paddle/v2/framework/optimizer.py b/python/paddle/v2/fluid/optimizer.py similarity index 89% rename from python/paddle/v2/framework/optimizer.py rename to python/paddle/v2/fluid/optimizer.py index f06c0fb98d..d2841df6af 100644 --- a/python/paddle/v2/framework/optimizer.py +++ b/python/paddle/v2/fluid/optimizer.py @@ -1,15 +1,15 @@ from collections import defaultdict -import paddle.v2.framework.framework as framework -from paddle.v2.framework.framework import unique_name, Program -from paddle.v2.framework.backward import append_backward_ops -from paddle.v2.framework.initializer import ConstantInitializer -from paddle.v2.framework.regularizer import append_regularization_ops -from paddle.v2.framework.layer_helper import LayerHelper +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.framework import unique_name, Program +from paddle.v2.fluid.backward import append_backward_ops +from paddle.v2.fluid.initializer import ConstantInitializer +from paddle.v2.fluid.regularizer import append_regularization_ops +from paddle.v2.fluid.layer_helper import LayerHelper __all__ = [ 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', - 'AdamaxOptimizer' + 'AdamaxOptimizer', 'DecayedAdagradOptimizer' ] @@ -85,7 +85,7 @@ class Optimizer(object): """ if (name in self._accumulators and param.name in self._accumulators[name]): - raise Exception("Accumulator {} already exists for parmeter {}". + raise Exception("Accumulator {} already exists for parameter {}". format(name, param.name)) assert isinstance(self.helper, LayerHelper) @@ -307,7 +307,7 @@ class AdagradOptimizer(Optimizer): moment_acc = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) - # create the adagrad optimizer op + # Create the adagrad optimizer op adagrad_op = block.append_op( type=self.type, inputs={ @@ -510,3 +510,51 @@ class AdamaxOptimizer(Optimizer): attrs={"scale": self._beta1}) return [scale_beta1] + + +class DecayedAdagradOptimizer(Optimizer): + """Simple Decayed Adagrad optimizer with moment state + """ + _moment_acc_str = "moment" + + def __init__(self, + learning_rate, + decay=0.95, + epsilon=1.0e-6, + global_step=None): + assert learning_rate is not None + assert decay is not None + assert epsilon is not None + + super(DecayedAdagradOptimizer, self).__init__(global_step) + self.type = "decayed_adagrad" + self._learning_rate = learning_rate + self._decay = decay + self._epsilon = epsilon + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for p in parameters: + self._add_accumulator(self._moment_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + moment_acc = self._get_accumulator(self._moment_acc_str, + param_and_grad[0]) + + # Create the decayed adagrad optimizer op + decayed_adagrad_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Moment": moment_acc, + "LearningRate": self._create_param_lr(param_and_grad) + }, + outputs={"ParamOut": param_and_grad[0], + "MomentOut": moment_acc}, + attrs={"epsilon": self._epsilon}) + + return decayed_adagrad_op diff --git a/python/paddle/v2/framework/regularizer.py b/python/paddle/v2/fluid/regularizer.py similarity index 98% rename from python/paddle/v2/framework/regularizer.py rename to python/paddle/v2/fluid/regularizer.py index 5111ac5566..098cd0dd64 100644 --- a/python/paddle/v2/framework/regularizer.py +++ b/python/paddle/v2/fluid/regularizer.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.framework as framework +import paddle.v2.fluid.framework as framework __all__ = [ 'append_regularization_ops', 'L2DecayRegularizer', 'L1DecayRegularizer' diff --git a/python/paddle/v2/framework/tests/.gitignore b/python/paddle/v2/fluid/tests/.gitignore similarity index 100% rename from python/paddle/v2/framework/tests/.gitignore rename to python/paddle/v2/fluid/tests/.gitignore diff --git a/python/paddle/v2/fluid/tests/CMakeLists.txt b/python/paddle/v2/fluid/tests/CMakeLists.txt new file mode 100644 index 0000000000..e795627bfe --- /dev/null +++ b/python/paddle/v2/fluid/tests/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() + +add_subdirectory(book) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/fluid/tests/book/CMakeLists.txt similarity index 100% rename from python/paddle/v2/framework/tests/CMakeLists.txt rename to python/paddle/v2/fluid/tests/book/CMakeLists.txt diff --git a/python/paddle/v2/framework/tests/test_fit_a_line.py b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py similarity index 51% rename from python/paddle/v2/framework/tests/test_fit_a_line.py rename to python/paddle/v2/fluid/tests/book/test_fit_a_line.py index 6e09b88dca..ee677a2c56 100644 --- a/python/paddle/v2/framework/tests/test_fit_a_line.py +++ b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py @@ -1,46 +1,34 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import Program -from paddle.v2.framework.io import save_persistables, load_persistables -from paddle.v2.framework.executor import Executor +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.io import save_persistables, load_persistables +from paddle.v2.fluid.executor import Executor import numpy as np -startup_program = Program() -main_program = Program() x = layers.data( name='x', shape=[13], - data_type='float32', - main_program=main_program, - startup_program=startup_program) + data_type='float32') y_predict = layers.fc(input=x, size=1, - act=None, - main_program=main_program, - startup_program=startup_program) + act=None) y = layers.data( name='y', shape=[1], - data_type='float32', - main_program=main_program, - startup_program=startup_program) + data_type='float32') cost = layers.square_error_cost( input=y_predict, - label=y, - main_program=main_program, - startup_program=startup_program) -avg_cost = layers.mean( - x=cost, main_program=main_program, startup_program=startup_program) + label=y) +avg_cost = layers.mean(x=cost) sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) -opts = sgd_optimizer.minimize(avg_cost, startup_program) +opts = sgd_optimizer.minimize(avg_cost) BATCH_SIZE = 20 @@ -52,12 +40,12 @@ train_reader = paddle.batch( place = core.CPUPlace() exe = Executor(place) -exe.run(startup_program, feed={}, fetch_list=[]) +exe.run(framework.default_startup_program()) PASS_NUM = 100 for pass_id in range(PASS_NUM): - save_persistables(exe, "./fit_a_line.model/", main_program=main_program) - load_persistables(exe, "./fit_a_line.model/", main_program=main_program) + save_persistables(exe, "./fit_a_line.model/") + load_persistables(exe, "./fit_a_line.model/") for data in train_reader(): x_data = np.array(map(lambda x: x[0], data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("float32") @@ -69,7 +57,7 @@ for pass_id in range(PASS_NUM): tensor_y = core.LoDTensor() tensor_y.set(y_data, place) # print tensor_y.get_dims() - outs = exe.run(main_program, + outs = exe.run(framework.default_main_program(), feed={'x': tensor_x, 'y': tensor_y}, fetch_list=[avg_cost]) diff --git a/python/paddle/v2/framework/tests/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py similarity index 57% rename from python/paddle/v2/framework/tests/test_image_classification_train.py rename to python/paddle/v2/fluid/tests/book/test_image_classification_train.py index a4165da970..f4be835b3a 100644 --- a/python/paddle/v2/framework/tests/test_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py @@ -1,23 +1,21 @@ import numpy as np import paddle.v2 as paddle -import paddle.v2.framework.core as core -import paddle.v2.framework.layers as layers -import paddle.v2.framework.nets as nets -import paddle.v2.framework.optimizer as optimizer -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.framework import g_startup_program, g_main_program -from paddle.v2.framework.initializer import XavierInitializer +import paddle.v2.fluid.core as core +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.nets as nets +import paddle.v2.fluid.optimizer as optimizer +from paddle.v2.fluid.executor import Executor +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.initializer import XavierInitializer -def resnet_cifar10(input, depth=32, main_program=None, startup_program=None): +def resnet_cifar10(input, depth=32): def conv_bn_layer(input, ch_out, filter_size, stride, padding, - act='relu', - main_program=None, - startup_program=None): + act='relu'): tmp = layers.conv2d( input=input, filter_size=filter_size, @@ -25,14 +23,10 @@ def resnet_cifar10(input, depth=32, main_program=None, startup_program=None): stride=stride, padding=padding, act=None, - bias_attr=False, - main_program=main_program, - startup_program=startup_program) + bias_attr=False) return layers.batch_norm( input=tmp, - act=act, - main_program=main_program, - startup_program=startup_program) + act=act) def shortcut(input, ch_in, ch_out, stride, program, init_program): if ch_in != ch_out: @@ -44,40 +38,30 @@ def resnet_cifar10(input, depth=32, main_program=None, startup_program=None): def basicblock(input, ch_in, ch_out, - stride, - main_program=main_program, - startup_program=startup_program): + stride): tmp = conv_bn_layer( input, ch_out, 3, stride, - 1, - main_program=main_program, - startup_program=startup_program) + 1) tmp = conv_bn_layer( tmp, ch_out, 3, 1, 1, - act=None, - main_program=main_program, - startup_program=startup_program) - short = shortcut(input, ch_in, ch_out, stride, main_program, - startup_program) + act=None) + short = shortcut(input, ch_in, ch_out, stride) return layers.elementwise_add( x=tmp, y=short, - act='relu', - main_program=main_program, - startup_program=startup_program) + act='relu') - def layer_warp(block_func, input, ch_in, ch_out, count, stride, program, - startup_program): - tmp = block_func(input, ch_in, ch_out, stride, program, startup_program) + def layer_warp(block_func, input, ch_in, ch_out, count, stride): + tmp = block_func(input, ch_in, ch_out, stride) for i in range(1, count): - tmp = block_func(tmp, ch_out, ch_out, 1, program, startup_program) + tmp = block_func(tmp, ch_out, ch_out, 1) return tmp assert (depth - 2) % 6 == 0 @@ -87,53 +71,41 @@ def resnet_cifar10(input, depth=32, main_program=None, startup_program=None): ch_out=16, filter_size=3, stride=1, - padding=1, - main_program=main_program, - startup_program=startup_program) + padding=1) res1 = layer_warp( basicblock, conv1, 16, 16, n, - 1, - main_program=main_program, - startup_program=startup_program) + 1) res2 = layer_warp( basicblock, res1, 16, 32, n, - 2, - main_program=main_program, - startup_program=startup_program) + 2) res3 = layer_warp( basicblock, res2, 32, 64, n, - 2, - main_program=main_program, - startup_program=startup_program) + 2) pool = layers.pool2d( input=res3, pool_size=8, pool_type='avg', - pool_stride=1, - main_program=main_program, - startup_program=startup_program) + pool_stride=1) return pool -def vgg16_bn_drop(input, main_program=None, startup_program=None): +def vgg16_bn_drop(input): def conv_block(input, num_filter, groups, - dropouts, - main_program=None, - startup_program=None): + dropouts): return nets.img_conv_group( input=input, pool_size=2, @@ -143,51 +115,34 @@ def vgg16_bn_drop(input, main_program=None, startup_program=None): conv_act='relu', conv_with_batchnorm=True, conv_batchnorm_drop_rate=dropouts, - pool_type='max', - main_program=main_program, - startup_program=startup_program) + pool_type='max') - conv1 = conv_block(input, 64, 2, [0.3, 0], main_program, startup_program) - conv2 = conv_block(conv1, 128, 2, [0.4, 0], main_program, startup_program) - conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0], main_program, - startup_program) - conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0], main_program, - startup_program) - conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0], main_program, - startup_program) + conv1 = conv_block(input, 64, 2, [0.3, 0]) + conv2 = conv_block(conv1, 128, 2, [0.4, 0]) + conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0]) + conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) + conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) drop = layers.dropout( x=conv5, - dropout_prob=0.5, - main_program=main_program, - startup_program=startup_program) + dropout_prob=0.5) fc1 = layers.fc(input=drop, size=512, act=None, - param_attr={"initializer": XavierInitializer()}, - main_program=main_program, - startup_program=startup_program) + param_attr={"initializer": XavierInitializer()}) reshape1 = layers.reshape( x=fc1, - shape=list(fc1.shape + (1, 1)), - main_program=main_program, - startup_program=startup_program) + shape=list(fc1.shape + (1, 1))) bn = layers.batch_norm( input=reshape1, - act='relu', - main_program=main_program, - startup_program=startup_program) + act='relu') drop2 = layers.dropout( x=bn, - dropout_prob=0.5, - main_program=main_program, - startup_program=startup_program) + dropout_prob=0.5) fc2 = layers.fc(input=drop2, size=512, act=None, - param_attr={"initializer": XavierInitializer()}, - main_program=main_program, - startup_program=startup_program) + param_attr={"initializer": XavierInitializer()}) return fc2 @@ -225,7 +180,7 @@ train_reader = paddle.batch( place = core.CPUPlace() exe = Executor(place) -exe.run(g_startup_program, feed={}, fetch_list=[]) +exe.run(framework.default_startup_program()) for pass_id in range(PASS_NUM): batch_id = 0 @@ -243,7 +198,7 @@ for pass_id in range(PASS_NUM): tensor_img.set(img_data, place) tensor_y.set(y_data, place) - outs = exe.run(g_main_program, + outs = exe.run(framework.default_main_program(), feed={"pixel": tensor_img, "label": tensor_y}, fetch_list=[avg_cost, accuracy]) diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py similarity index 53% rename from python/paddle/v2/framework/tests/test_recognize_digits_conv.py rename to python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py index 66c629eb42..f330ff5813 100644 --- a/python/paddle/v2/framework/tests/test_recognize_digits_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py @@ -1,69 +1,48 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.nets as nets -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import Program -from paddle.v2.framework.executor import Executor +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.nets as nets +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.evaluator as evaluator +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.executor import Executor import numpy as np -startup_program = Program() -main_program = Program() - images = layers.data( name='pixel', shape=[1, 28, 28], - data_type='float32', - main_program=main_program, - startup_program=startup_program) + data_type='float32') label = layers.data( name='label', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') conv_pool_1 = nets.simple_img_conv_pool( input=images, filter_size=5, num_filters=20, pool_size=2, pool_stride=2, - act="relu", - main_program=main_program, - startup_program=startup_program) + act="relu") conv_pool_2 = nets.simple_img_conv_pool( input=conv_pool_1, filter_size=5, num_filters=50, pool_size=2, pool_stride=2, - act="relu", - main_program=main_program, - startup_program=startup_program) + act="relu") predict = layers.fc(input=conv_pool_2, size=10, - act="softmax", - main_program=main_program, - startup_program=startup_program) -cost = layers.cross_entropy( - input=predict, - label=label, - main_program=main_program, - startup_program=startup_program) -avg_cost = layers.mean(x=cost, main_program=main_program) -accuracy = layers.accuracy( - input=predict, - label=label, - main_program=main_program, - startup_program=startup_program) - -# optimizer = optimizer.MomentumOptimizer(learning_rate=0.1 / 128.0, -# momentum=0.9) + act="softmax") +cost = layers.cross_entropy(input=predict, label=label) +avg_cost = layers.mean(x=cost) optimizer = optimizer.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999) -opts = optimizer.minimize(avg_cost, startup_program) +opts = optimizer.minimize(avg_cost) + +accuracy, acc_out = evaluator.accuracy( + input=predict, + label=label) BATCH_SIZE = 50 PASS_NUM = 3 @@ -75,10 +54,11 @@ train_reader = paddle.batch( place = core.CPUPlace() exe = Executor(place) -exe.run(startup_program, feed={}, fetch_list=[]) +exe.run(framework.default_startup_program()) for pass_id in range(PASS_NUM): count = 0 + accuracy.reset(exe) for data in train_reader(): img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]), data)).astype("float32") @@ -90,14 +70,20 @@ for pass_id in range(PASS_NUM): tensor_img.set(img_data, place) tensor_y.set(y_data, place) - outs = exe.run(main_program, + outs = exe.run(framework.default_main_program(), feed={"pixel": tensor_img, "label": tensor_y}, - fetch_list=[avg_cost, accuracy]) + fetch_list=[avg_cost, acc_out]) loss = np.array(outs[0]) acc = np.array(outs[1]) - + pass_acc = accuracy.eval(exe) + print "pass id : ", pass_id, pass_acc + # print loss, acc if loss < 10.0 and acc > 0.9: # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. exit(0) + + pass_acc = accuracy.eval(exe) + print "pass id : ", pass_id, pass_acc + exit(1) diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py similarity index 57% rename from python/paddle/v2/framework/tests/test_recognize_digits_mlp.py rename to python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index 076cf88216..b0164e3e36 100644 --- a/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -1,24 +1,19 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import Program -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.regularizer import L2DecayRegularizer -from paddle.v2.framework.initializer import UniformInitializer +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.regularizer import L2DecayRegularizer +from paddle.v2.fluid.initializer import UniformInitializer import numpy as np BATCH_SIZE = 128 -startup_program = Program() -main_program = Program() image = layers.data( name='x', shape=[784], - data_type='float32', - main_program=main_program, - startup_program=startup_program) + data_type='float32') param_attr = { 'name': None, @@ -30,45 +25,30 @@ param_attr = { hidden1 = layers.fc(input=image, size=128, act='relu', - main_program=main_program, - startup_program=startup_program, param_attr=param_attr) hidden2 = layers.fc(input=hidden1, size=64, act='relu', - main_program=main_program, - startup_program=startup_program, param_attr=param_attr) predict = layers.fc(input=hidden2, size=10, act='softmax', - main_program=main_program, - startup_program=startup_program, param_attr=param_attr) label = layers.data( name='y', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') -cost = layers.cross_entropy( - input=predict, - label=label, - main_program=main_program, - startup_program=startup_program) -avg_cost = layers.mean( - x=cost, main_program=main_program, startup_program=startup_program) +cost = layers.cross_entropy(input=predict, label=label) +avg_cost = layers.mean(x=cost) accuracy = layers.accuracy( input=predict, - label=label, - main_program=main_program, - startup_program=startup_program) + label=label) optimizer = optimizer.MomentumOptimizer(learning_rate=0.001, momentum=0.9) -opts = optimizer.minimize(avg_cost, startup_program) +opts = optimizer.minimize(avg_cost) train_reader = paddle.batch( paddle.reader.shuffle( @@ -78,7 +58,7 @@ train_reader = paddle.batch( place = core.CPUPlace() exe = Executor(place) -exe.run(startup_program, feed={}, fetch_list=[]) +exe.run(framework.default_startup_program()) PASS_NUM = 100 for pass_id in range(PASS_NUM): @@ -93,7 +73,7 @@ for pass_id in range(PASS_NUM): tensor_y = core.LoDTensor() tensor_y.set(y_data, place) - outs = exe.run(main_program, + outs = exe.run(framework.default_main_program(), feed={'x': tensor_x, 'y': tensor_y}, fetch_list=[avg_cost, accuracy]) diff --git a/python/paddle/v2/framework/tests/test_recommender_system.py b/python/paddle/v2/fluid/tests/book/test_recommender_system.py similarity index 58% rename from python/paddle/v2/framework/tests/test_recommender_system.py rename to python/paddle/v2/fluid/tests/book/test_recommender_system.py index 31562b4391..eefcb55beb 100644 --- a/python/paddle/v2/framework/tests/test_recommender_system.py +++ b/python/paddle/v2/fluid/tests/book/test_recommender_system.py @@ -1,18 +1,15 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.nets as nets -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import Program -from paddle.v2.framework.executor import Executor +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.nets as nets +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.executor import Executor import numpy as np -startup_program = Program() -main_program = Program() -is_sparse = True -use_gpu = False +IS_SPARSE = True +USE_GPU = False BATCH_SIZE = 256 @@ -25,99 +22,71 @@ def get_usr_combined_features(): uid = layers.data( name='user_id', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') usr_emb = layers.embedding( input=uid, data_type='float32', size=[USR_DICT_SIZE, 32], param_attr={'name': 'user_table'}, - is_sparse=is_sparse, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE) usr_fc = layers.fc(input=usr_emb, - size=32, - main_program=main_program, - startup_program=startup_program) + size=32) USR_GENDER_DICT_SIZE = 2 usr_gender_id = layers.data( name='gender_id', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') usr_gender_emb = layers.embedding( input=usr_gender_id, size=[USR_GENDER_DICT_SIZE, 16], param_attr={'name': 'gender_table'}, - is_sparse=is_sparse, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE) usr_gender_fc = layers.fc(input=usr_gender_emb, - size=16, - main_program=main_program, - startup_program=startup_program) + size=16) USR_AGE_DICT_SIZE = len(paddle.dataset.movielens.age_table) usr_age_id = layers.data( name='age_id', shape=[1], - data_type="int64", - main_program=main_program, - startup_program=startup_program) + data_type="int64") usr_age_emb = layers.embedding( input=usr_age_id, size=[USR_AGE_DICT_SIZE, 16], - is_sparse=is_sparse, - param_attr={'name': 'age_table'}, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE, + param_attr={'name': 'age_table'}) usr_age_fc = layers.fc(input=usr_age_emb, - size=16, - main_program=main_program, - startup_program=startup_program) + size=16) USR_JOB_DICT_SIZE = paddle.dataset.movielens.max_job_id() + 1 usr_job_id = layers.data( name='job_id', shape=[1], - data_type="int64", - main_program=main_program, - startup_program=startup_program) + data_type="int64") usr_job_emb = layers.embedding( input=usr_job_id, size=[USR_JOB_DICT_SIZE, 16], param_attr={'name': 'job_table'}, - is_sparse=is_sparse, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE) usr_job_fc = layers.fc(input=usr_job_emb, - size=16, - main_program=main_program, - startup_program=startup_program) + size=16) concat_embed = layers.concat( input=[usr_fc, usr_gender_fc, usr_age_fc, usr_job_fc], - axis=1, - main_program=main_program, - startup_program=startup_program) + axis=1) usr_combined_features = layers.fc(input=concat_embed, size=200, - act="tanh", - main_program=main_program, - startup_program=startup_program) + act="tanh") return usr_combined_features @@ -129,83 +98,61 @@ def get_mov_combined_features(): mov_id = layers.data( name='movie_id', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') mov_emb = layers.embedding( input=mov_id, data_type='float32', size=[MOV_DICT_SIZE, 32], param_attr={'name': 'movie_table'}, - is_sparse=is_sparse, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE) mov_fc = layers.fc(input=mov_emb, - size=32, - main_program=main_program, - startup_program=startup_program) + size=32) CATEGORY_DICT_SIZE = len(paddle.dataset.movielens.movie_categories()) category_id = layers.data( name='category_id', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') mov_categories_emb = layers.embedding( input=category_id, size=[CATEGORY_DICT_SIZE, 32], - is_sparse=is_sparse, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE) mov_categories_hidden = layers.sequence_pool( input=mov_categories_emb, - pool_type="sum", - main_program=main_program, - startup_program=startup_program) + pool_type="sum") MOV_TITLE_DICT_SIZE = len(paddle.dataset.movielens.get_movie_title_dict()) mov_title_id = layers.data( name='movie_title', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') mov_title_emb = layers.embedding( input=mov_title_id, size=[MOV_TITLE_DICT_SIZE, 32], - is_sparse=is_sparse, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE) mov_title_conv = nets.sequence_conv_pool( input=mov_title_emb, num_filters=32, filter_size=3, act="tanh", - pool_type="sum", - main_program=main_program, - startup_program=startup_program) + pool_type="sum") concat_embed = layers.concat( input=[mov_fc, mov_categories_hidden, mov_title_conv], - axis=1, - main_program=main_program, - startup_program=startup_program) + axis=1) # FIXME(dzh) : need tanh operator mov_combined_features = layers.fc(input=concat_embed, size=200, - act="tanh", - main_program=main_program, - startup_program=startup_program) + act="tanh") return mov_combined_features @@ -217,27 +164,18 @@ def model(): # need cos sim inference = layers.cos_sim( X=usr_combined_features, - Y=mov_combined_features, - main_program=main_program, - startup_program=startup_program) + Y=mov_combined_features) label = layers.data( name='score', shape=[1], - data_type='float32', - main_program=main_program, - startup_program=startup_program) + data_type='float32') square_cost = layers.square_error_cost( input=inference, - label=label, - main_program=main_program, - startup_program=startup_program) + label=label) - avg_cost = layers.mean( - x=square_cost, - main_program=main_program, - startup_program=startup_program) + avg_cost = layers.mean(x=square_cost) return avg_cost @@ -245,16 +183,15 @@ def model(): def main(): cost = model() sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.2) - opts = sgd_optimizer.minimize(cost, startup_program=startup_program) - block = main_program.block(0) + opts = sgd_optimizer.minimize(cost) - if use_gpu: + if USE_GPU: place = core.GPUPlace(0) else: place = core.CPUPlace() exe = Executor(place) - exe.run(startup_program, feed={}, fetch_list=[]) + exe.run(framework.default_startup_program()) train_reader = paddle.batch( paddle.reader.shuffle( @@ -303,7 +240,7 @@ def main(): PASS_NUM = 100 for pass_id in range(PASS_NUM): for data in train_reader(): - outs = exe.run(main_program, + outs = exe.run(framework.default_main_program(), feed=func_feed(feeding, data), fetch_list=[cost]) out = np.array(outs[0]) diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_conv.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py similarity index 87% rename from python/paddle/v2/framework/tests/test_understand_sentiment_conv.py rename to python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py index eb377e9264..91fc79a987 100644 --- a/python/paddle/v2/framework/tests/test_understand_sentiment_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py @@ -1,11 +1,10 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.nets as nets -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import Program, g_main_program, g_startup_program -from paddle.v2.framework.executor import Executor +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.nets as nets +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.executor import Executor import numpy as np @@ -70,7 +69,7 @@ def main(): place = core.CPUPlace() exe = Executor(place) - exe.run(g_startup_program) + exe.run(framework.default_startup_program()) for pass_id in xrange(PASS_NUM): for data in train_data(): @@ -82,7 +81,7 @@ def main(): tensor_label = core.LoDTensor() tensor_label.set(label, place) - outs = exe.run(g_main_program, + outs = exe.run(framework.default_main_program(), feed={"words": tensor_words, "label": tensor_label}, fetch_list=[cost, acc]) diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_dynamic_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py similarity index 89% rename from python/paddle/v2/framework/tests/test_understand_sentiment_dynamic_lstm.py rename to python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py index 2457c71e1a..8c3d448835 100644 --- a/python/paddle/v2/framework/tests/test_understand_sentiment_dynamic_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py @@ -1,11 +1,10 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.nets as nets -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import Program, g_main_program, g_startup_program -from paddle.v2.framework.executor import Executor +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.nets as nets +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.executor import Executor import numpy as np @@ -81,7 +80,7 @@ def main(): place = core.CPUPlace() exe = Executor(place) - exe.run(g_startup_program) + exe.run(framework.default_startup_program()) for pass_id in xrange(PASS_NUM): for data in train_data(): @@ -93,7 +92,7 @@ def main(): tensor_label = core.LoDTensor() tensor_label.set(label, place) - outs = exe.run(g_main_program, + outs = exe.run(framework.default_main_program(), feed={"words": tensor_words, "label": tensor_label}, fetch_list=[cost, acc]) diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py similarity index 89% rename from python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py rename to python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py index 26cbd01bc0..a7d791c1f3 100644 --- a/python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py @@ -1,10 +1,9 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import g_main_program, g_startup_program -from paddle.v2.framework.executor import Executor +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.executor import Executor import numpy as np @@ -88,10 +87,10 @@ def main(): place = core.CPUPlace() tensor_words, tensor_label = prepare_feed_data(data, place) exe = Executor(place) - exe.run(g_startup_program) + exe.run(framework.default_startup_program()) while True: - outs = exe.run(g_main_program, + outs = exe.run(framework.default_main_program(), feed={"words": tensor_words, "label": tensor_label}, fetch_list=[cost, acc]) diff --git a/python/paddle/v2/framework/tests/test_word2vec.py b/python/paddle/v2/fluid/tests/book/test_word2vec.py similarity index 54% rename from python/paddle/v2/framework/tests/test_word2vec.py rename to python/paddle/v2/fluid/tests/book/test_word2vec.py index cb9fc2ab62..9dcb6f2fea 100644 --- a/python/paddle/v2/framework/tests/test_word2vec.py +++ b/python/paddle/v2/fluid/tests/book/test_word2vec.py @@ -1,21 +1,18 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer - -from paddle.v2.framework.framework import Program -from paddle.v2.framework.executor import Executor +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.executor import Executor import numpy as np -startup_program = Program() -main_program = Program() - -embed_size = 32 -hidden_size = 256 +PASS_NUM = 100 +EMBED_SIZE = 32 +HIDDEN_SIZE = 256 N = 5 -batch_size = 32 -is_sparse = True +BATCH_SIZE = 32 +IS_SPARSE = True word_dict = paddle.dataset.imikolov.build_dict() dict_size = len(word_dict) @@ -23,97 +20,67 @@ dict_size = len(word_dict) first_word = layers.data( name='firstw', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') second_word = layers.data( name='secondw', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') third_word = layers.data( name='thirdw', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') forth_word = layers.data( name='forthw', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') next_word = layers.data( name='nextw', shape=[1], - data_type='int64', - main_program=main_program, - startup_program=startup_program) + data_type='int64') embed_first = layers.embedding( input=first_word, - size=[dict_size, embed_size], + size=[dict_size, EMBED_SIZE], data_type='float32', - is_sparse=is_sparse, - param_attr={'name': 'shared_w'}, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE, + param_attr={'name': 'shared_w'}) embed_second = layers.embedding( input=second_word, - size=[dict_size, embed_size], + size=[dict_size, EMBED_SIZE], data_type='float32', - is_sparse=is_sparse, - param_attr={'name': 'shared_w'}, - main_program=main_program, - startup_program=startup_program) - + is_sparse=IS_SPARSE, + param_attr={'name': 'shared_w'}) embed_third = layers.embedding( input=third_word, - size=[dict_size, embed_size], + size=[dict_size, EMBED_SIZE], data_type='float32', - is_sparse=is_sparse, - param_attr={'name': 'shared_w'}, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE, + param_attr={'name': 'shared_w'}) embed_forth = layers.embedding( input=forth_word, - size=[dict_size, embed_size], + size=[dict_size, EMBED_SIZE], data_type='float32', - is_sparse=is_sparse, - param_attr={'name': 'shared_w'}, - main_program=main_program, - startup_program=startup_program) + is_sparse=IS_SPARSE, + param_attr={'name': 'shared_w'}) concat_embed = layers.concat( input=[embed_first, embed_second, embed_third, embed_forth], - axis=1, - main_program=main_program, - startup_program=startup_program) - + axis=1) hidden1 = layers.fc(input=concat_embed, - size=hidden_size, - act='sigmoid', - main_program=main_program, - startup_program=startup_program) + size=HIDDEN_SIZE, + act='sigmoid') predict_word = layers.fc(input=hidden1, size=dict_size, - act='softmax', - main_program=main_program, - startup_program=startup_program) + act='softmax') cost = layers.cross_entropy( input=predict_word, - label=next_word, - main_program=main_program, - startup_program=startup_program) -avg_cost = layers.mean( - x=cost, main_program=main_program, startup_program=startup_program) - + label=next_word) +avg_cost = layers.mean(x=cost) sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) -opts = sgd_optimizer.minimize(avg_cost, startup_program) +opts = sgd_optimizer.minimize(avg_cost) train_reader = paddle.batch( - paddle.dataset.imikolov.train(word_dict, N), batch_size) + paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE) place = core.CPUPlace() exe = Executor(place) @@ -122,8 +89,8 @@ exe = Executor(place) # below exit line. exit(0) -exe.run(startup_program, feed={}, fetch_list=[]) -PASS_NUM = 100 +exe.run(framework.default_startup_program()) + for pass_id in range(PASS_NUM): for data in train_reader(): input_data = [[data_idx[idx] for data_idx in data] for idx in xrange(5)] @@ -150,7 +117,7 @@ for pass_id in range(PASS_NUM): next_tensor = core.LoDTensor() next_tensor.set(next_data, place) - outs = exe.run(main_program, + outs = exe.run(framework.default_main_program(), feed={ 'firstw': first_tensor, 'secondw': second_tensor, diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py similarity index 98% rename from python/paddle/v2/framework/tests/op_test.py rename to python/paddle/v2/fluid/tests/op_test.py index 4a269341a4..90269e308a 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/fluid/tests/op_test.py @@ -2,12 +2,12 @@ import unittest import numpy as np import random import itertools -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import collections -from paddle.v2.framework.backward import append_backward_ops -from paddle.v2.framework.op import Operator -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.framework import Program, OpProtoHolder +from paddle.v2.fluid.backward import append_backward_ops +from paddle.v2.fluid.op import Operator +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.framework import Program, OpProtoHolder def randomize_probability(batch_size, class_num, dtype='float32'): diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/fluid/tests/test_accuracy_op.py similarity index 86% rename from python/paddle/v2/framework/tests/test_accuracy_op.py rename to python/paddle/v2/fluid/tests/test_accuracy_op.py index 6536c297e8..6f72918b71 100644 --- a/python/paddle/v2/framework/tests/test_accuracy_op.py +++ b/python/paddle/v2/fluid/tests/test_accuracy_op.py @@ -18,7 +18,9 @@ class TestAccuracyOp(OpTest): num_correct += 1 break self.outputs = { - 'Accuracy': np.array([num_correct / float(n)]).astype("float32") + 'Accuracy': np.array([num_correct / float(n)]).astype("float32"), + 'Correct': np.array([num_correct]).astype("int32"), + 'Total': np.array([n]).astype("int32") } def test_check_output(self): diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/fluid/tests/test_activation_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_activation_op.py rename to python/paddle/v2/fluid/tests/test_activation_op.py diff --git a/python/paddle/v2/framework/tests/test_adadelta_op.py b/python/paddle/v2/fluid/tests/test_adadelta_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_adadelta_op.py rename to python/paddle/v2/fluid/tests/test_adadelta_op.py diff --git a/python/paddle/v2/framework/tests/test_adagrad_op.py b/python/paddle/v2/fluid/tests/test_adagrad_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_adagrad_op.py rename to python/paddle/v2/fluid/tests/test_adagrad_op.py diff --git a/python/paddle/v2/framework/tests/test_adam_op.py b/python/paddle/v2/fluid/tests/test_adam_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_adam_op.py rename to python/paddle/v2/fluid/tests/test_adam_op.py diff --git a/python/paddle/v2/framework/tests/test_adamax_op.py b/python/paddle/v2/fluid/tests/test_adamax_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_adamax_op.py rename to python/paddle/v2/fluid/tests/test_adamax_op.py diff --git a/python/paddle/v2/framework/tests/test_array_read_write_op.py b/python/paddle/v2/fluid/tests/test_array_read_write_op.py similarity index 91% rename from python/paddle/v2/framework/tests/test_array_read_write_op.py rename to python/paddle/v2/fluid/tests/test_array_read_write_op.py index 79e9938216..e019a4e15f 100644 --- a/python/paddle/v2/framework/tests/test_array_read_write_op.py +++ b/python/paddle/v2/fluid/tests/test_array_read_write_op.py @@ -1,9 +1,9 @@ import unittest -import paddle.v2.framework.core as core -import paddle.v2.framework.layers as layers -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.backward import append_backward_ops -from paddle.v2.framework.framework import g_main_program +import paddle.v2.fluid.core as core +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.backward import append_backward_ops +from paddle.v2.fluid.framework import g_main_program import numpy diff --git a/python/paddle/v2/fluid/tests/test_assign_op.py b/python/paddle/v2/fluid/tests/test_assign_op.py new file mode 100644 index 0000000000..1b0c145f1a --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_assign_op.py @@ -0,0 +1,21 @@ +import op_test +import numpy +import unittest + + +class TestAssignOp(op_test.OpTest): + def setUp(self): + self.op_type = "assign" + x = numpy.random.random(size=(100, 10)) + self.inputs = {'X': x} + self.outputs = {'Out': x} + + def test_forward(self): + self.check_output() + + def test_backward(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_auc_op.py b/python/paddle/v2/fluid/tests/test_auc_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_auc_op.py rename to python/paddle/v2/fluid/tests/test_auc_op.py diff --git a/python/paddle/v2/framework/tests/test_batch_norm_op.py b/python/paddle/v2/fluid/tests/test_batch_norm_op.py similarity index 99% rename from python/paddle/v2/framework/tests/test_batch_norm_op.py rename to python/paddle/v2/fluid/tests/test_batch_norm_op.py index dee339f43c..71f9599e0d 100644 --- a/python/paddle/v2/framework/tests/test_batch_norm_op.py +++ b/python/paddle/v2/fluid/tests/test_batch_norm_op.py @@ -1,8 +1,8 @@ import unittest import numpy as np from op_test import OpTest -import paddle.v2.framework.core as core -from paddle.v2.framework.op import Operator +import paddle.v2.fluid.core as core +from paddle.v2.fluid.op import Operator def grad_var_name(var_name): diff --git a/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py new file mode 100644 index 0000000000..8a11820d2a --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py @@ -0,0 +1,75 @@ +import unittest + +import numpy as np +import paddle.v2.fluid.core as core +from paddle.v2.fluid.op import Operator + + +class TestBeamSearchDecodeOp(unittest.TestCase): + def setUp(self): + self.scope = core.Scope() + self.cpu_place = core.CPUPlace() + + def append_lod_tensor(self, tensor_array, lod, data): + lod_tensor = core.LoDTensor() + lod_tensor.set_lod(lod) + lod_tensor.set(data, self.cpu_place) + tensor_array.append(lod_tensor) + + def test_get_set(self): + ids = self.scope.var("ids").get_lod_tensor_array() + self.append_lod_tensor( + ids, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]], + np.array( + [1, 2, 3, 4, 5, 6], dtype="int64")) + self.append_lod_tensor( + ids, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]], + np.array( + [0, 1, 2, 3, 4, 5], dtype="int64")) + self.append_lod_tensor( + ids, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]], + np.array( + [0, 1, 2, 3, 4], dtype="int64")) + + scores = self.scope.var("scores").get_lod_tensor_array() + self.append_lod_tensor( + scores, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]], + np.array( + [1, 2, 3, 4, 5, 6], dtype="float32")) + self.append_lod_tensor( + scores, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]], + np.array( + [0, 1, 2, 3, 4, 5], dtype="float32")) + self.append_lod_tensor( + scores, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]], + np.array( + [0, 1, 2, 3, 4], dtype="float32")) + + sentence_ids = self.scope.var("sentence_ids").get_tensor() + sentence_scores = self.scope.var("sentence_scores").get_tensor() + + beam_search_decode_op = Operator( + "beam_search_decode", + # inputs + Ids="ids", + Scores="scores", + # outputs + SentenceIds="sentence_ids", + SentenceScores="sentence_scores") + + ctx = core.DeviceContext.create(self.cpu_place) + beam_search_decode_op.run(self.scope, ctx) + + expected_lod = [[0, 4, 8], [0, 1, 3, 6, 9, 10, 13, 16, 19]] + self.assertEqual(sentence_ids.lod(), expected_lod) + self.assertEqual(sentence_scores.lod(), expected_lod) + + expected_data = np.array( + [2, 1, 0, 3, 1, 0, 3, 2, 1, 5, 4, 3, 2, 4, 4, 3, 6, 5, 4], "int64") + self.assertTrue(np.array_equal(np.array(sentence_ids), expected_data)) + self.assertTrue( + np.array_equal(np.array(sentence_scores), expected_data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/fluid/tests/test_bilinear_tensor_product_op.py new file mode 100644 index 0000000000..080ca43b82 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_bilinear_tensor_product_op.py @@ -0,0 +1,37 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestBilinearTensorProductOp(OpTest): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 6 + size0 = 3 + size1 = 4 + size2 = 5 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = { + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, + } + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_cast_op.py b/python/paddle/v2/fluid/tests/test_cast_op.py similarity index 93% rename from python/paddle/v2/framework/tests/test_cast_op.py rename to python/paddle/v2/fluid/tests/test_cast_op.py index 52ee71a8a4..0c4b631065 100644 --- a/python/paddle/v2/framework/tests/test_cast_op.py +++ b/python/paddle/v2/fluid/tests/test_cast_op.py @@ -1,7 +1,7 @@ import op_test import unittest import numpy as np -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core class TestCastOp(op_test.OpTest): diff --git a/python/paddle/v2/framework/tests/test_chunk_eval_op.py b/python/paddle/v2/fluid/tests/test_chunk_eval_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_chunk_eval_op.py rename to python/paddle/v2/fluid/tests/test_chunk_eval_op.py diff --git a/python/paddle/v2/framework/tests/test_clip_by_norm_op.py b/python/paddle/v2/fluid/tests/test_clip_by_norm_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_clip_by_norm_op.py rename to python/paddle/v2/fluid/tests/test_clip_by_norm_op.py diff --git a/python/paddle/v2/framework/tests/test_clip_op.py b/python/paddle/v2/fluid/tests/test_clip_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_clip_op.py rename to python/paddle/v2/fluid/tests/test_clip_op.py diff --git a/python/paddle/v2/framework/tests/test_compare_op.py b/python/paddle/v2/fluid/tests/test_compare_op.py similarity index 79% rename from python/paddle/v2/framework/tests/test_compare_op.py rename to python/paddle/v2/fluid/tests/test_compare_op.py index bb0256694d..5d0dfab6ff 100644 --- a/python/paddle/v2/framework/tests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/test_compare_op.py @@ -23,6 +23,9 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) + create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) if __name__ == '__main__': diff --git a/python/paddle/v2/framework/tests/test_concat_op.py b/python/paddle/v2/fluid/tests/test_concat_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_concat_op.py rename to python/paddle/v2/fluid/tests/test_concat_op.py diff --git a/python/paddle/v2/framework/tests/test_cond_op.py b/python/paddle/v2/fluid/tests/test_cond_op.py similarity index 97% rename from python/paddle/v2/framework/tests/test_cond_op.py rename to python/paddle/v2/fluid/tests/test_cond_op.py index 09a3f5dc97..9d1df44b90 100644 --- a/python/paddle/v2/framework/tests/test_cond_op.py +++ b/python/paddle/v2/fluid/tests/test_cond_op.py @@ -1,8 +1,8 @@ import logging -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import unittest import numpy as np -from paddle.v2.framework.op import Operator, CondOp +from paddle.v2.fluid.op import Operator, CondOp class PySimpleCond(object): diff --git a/python/paddle/v2/fluid/tests/test_conditional_block.py b/python/paddle/v2/fluid/tests/test_conditional_block.py new file mode 100644 index 0000000000..293803f004 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_conditional_block.py @@ -0,0 +1,40 @@ +import unittest +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.core as core +from paddle.v2.fluid.framework import g_startup_program, g_main_program +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.backward import append_backward_ops +import numpy + + +class ConditionalBlock(unittest.TestCase): + def test_forward(self): + data = layers.data(name='X', shape=[1], data_type='float32') + data.stop_gradient = False + cond = layers.ConditionalBlock(inputs=[data]) + out = layers.create_tensor(dtype='float32') + with cond.block(): + hidden = layers.fc(input=data, size=10) + layers.assign(hidden, out) + + cpu = core.CPUPlace() + exe = Executor(cpu) + exe.run(g_startup_program) + + x = core.LoDTensor() + x.set(numpy.random.random(size=(10, 1)).astype('float32'), cpu) + + outs = map(numpy.array, exe.run(feed={'X': x}, fetch_list=[out]))[0] + print outs + loss = layers.mean(x=out) + append_backward_ops(loss=loss) + outs = map(numpy.array, + exe.run(feed={'X': x}, + fetch_list=[ + g_main_program.block(0).var(data.name + "@GRAD") + ]))[0] + print outs + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/fluid/tests/test_conv2d_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_conv2d_op.py rename to python/paddle/v2/fluid/tests/test_conv2d_op.py diff --git a/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py b/python/paddle/v2/fluid/tests/test_conv2d_transpose_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_conv2d_transpose_op.py rename to python/paddle/v2/fluid/tests/test_conv2d_transpose_op.py diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/fluid/tests/test_conv3d_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_conv3d_op.py rename to python/paddle/v2/fluid/tests/test_conv3d_op.py diff --git a/python/paddle/v2/framework/tests/test_conv3d_transpose_op.py b/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_conv3d_transpose_op.py rename to python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py diff --git a/python/paddle/v2/framework/tests/test_conv_shift_op.py b/python/paddle/v2/fluid/tests/test_conv_shift_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_conv_shift_op.py rename to python/paddle/v2/fluid/tests/test_conv_shift_op.py diff --git a/python/paddle/v2/framework/tests/test_cos_sim_op.py b/python/paddle/v2/fluid/tests/test_cos_sim_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_cos_sim_op.py rename to python/paddle/v2/fluid/tests/test_cos_sim_op.py diff --git a/python/paddle/v2/framework/tests/test_create_op_doc_string.py b/python/paddle/v2/fluid/tests/test_create_op_doc_string.py similarity index 80% rename from python/paddle/v2/framework/tests/test_create_op_doc_string.py rename to python/paddle/v2/fluid/tests/test_create_op_doc_string.py index d21e96df2a..42b6f7a361 100644 --- a/python/paddle/v2/framework/tests/test_create_op_doc_string.py +++ b/python/paddle/v2/fluid/tests/test_create_op_doc_string.py @@ -1,5 +1,5 @@ import unittest -import paddle.v2.framework.layers as layers +import paddle.v2.fluid.layers as layers class TestDocString(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_crf_decoding_op.py b/python/paddle/v2/fluid/tests/test_crf_decoding_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_crf_decoding_op.py rename to python/paddle/v2/fluid/tests/test_crf_decoding_op.py diff --git a/python/paddle/v2/framework/tests/test_crop_op.py b/python/paddle/v2/fluid/tests/test_crop_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_crop_op.py rename to python/paddle/v2/fluid/tests/test_crop_op.py diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/fluid/tests/test_cross_entropy_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_cross_entropy_op.py rename to python/paddle/v2/fluid/tests/test_cross_entropy_op.py diff --git a/python/paddle/v2/framework/tests/test_decayed_adagrad_op.py b/python/paddle/v2/fluid/tests/test_decayed_adagrad_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_decayed_adagrad_op.py rename to python/paddle/v2/fluid/tests/test_decayed_adagrad_op.py diff --git a/python/paddle/v2/framework/tests/test_default_scope_funcs.py b/python/paddle/v2/fluid/tests/test_default_scope_funcs.py similarity index 94% rename from python/paddle/v2/framework/tests/test_default_scope_funcs.py rename to python/paddle/v2/fluid/tests/test_default_scope_funcs.py index 09a9850d05..738e69529e 100644 --- a/python/paddle/v2/framework/tests/test_default_scope_funcs.py +++ b/python/paddle/v2/fluid/tests/test_default_scope_funcs.py @@ -1,4 +1,4 @@ -from paddle.v2.framework.default_scope_funcs import * +from paddle.v2.fluid.default_scope_funcs import * import unittest diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/fluid/tests/test_dropout_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_dropout_op.py rename to python/paddle/v2/fluid/tests/test_dropout_op.py diff --git a/python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py b/python/paddle/v2/fluid/tests/test_dynamic_recurrent_op.py similarity index 98% rename from python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py rename to python/paddle/v2/fluid/tests/test_dynamic_recurrent_op.py index 70af9dbc49..c2d8b48ea9 100644 --- a/python/paddle/v2/framework/tests/test_dynamic_recurrent_op.py +++ b/python/paddle/v2/fluid/tests/test_dynamic_recurrent_op.py @@ -1,7 +1,7 @@ import logging -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import unittest -from paddle.v2.framework.op import Operator, DynamicRecurrentOp +from paddle.v2.fluid.op import Operator, DynamicRecurrentOp import numpy as np # for siplicity, just one level LoD diff --git a/python/paddle/v2/framework/tests/test_elementwise_add_op.py b/python/paddle/v2/fluid/tests/test_elementwise_add_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_elementwise_add_op.py rename to python/paddle/v2/fluid/tests/test_elementwise_add_op.py diff --git a/python/paddle/v2/framework/tests/test_elementwise_div_op.py b/python/paddle/v2/fluid/tests/test_elementwise_div_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_elementwise_div_op.py rename to python/paddle/v2/fluid/tests/test_elementwise_div_op.py diff --git a/python/paddle/v2/framework/tests/test_elementwise_mul_op.py b/python/paddle/v2/fluid/tests/test_elementwise_mul_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_elementwise_mul_op.py rename to python/paddle/v2/fluid/tests/test_elementwise_mul_op.py diff --git a/python/paddle/v2/framework/tests/test_elementwise_sub_op.py b/python/paddle/v2/fluid/tests/test_elementwise_sub_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_elementwise_sub_op.py rename to python/paddle/v2/fluid/tests/test_elementwise_sub_op.py diff --git a/python/paddle/v2/framework/tests/test_exception.py b/python/paddle/v2/fluid/tests/test_exception.py similarity index 89% rename from python/paddle/v2/framework/tests/test_exception.py rename to python/paddle/v2/fluid/tests/test_exception.py index 5ae048817c..b871f40c4a 100644 --- a/python/paddle/v2/framework/tests/test_exception.py +++ b/python/paddle/v2/fluid/tests/test_exception.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import unittest diff --git a/python/paddle/v2/framework/tests/test_executor_and_mul.py b/python/paddle/v2/fluid/tests/test_executor_and_mul.py similarity index 83% rename from python/paddle/v2/framework/tests/test_executor_and_mul.py rename to python/paddle/v2/fluid/tests/test_executor_and_mul.py index c885cfbebd..709250d0c8 100644 --- a/python/paddle/v2/framework/tests/test_executor_and_mul.py +++ b/python/paddle/v2/fluid/tests/test_executor_and_mul.py @@ -1,8 +1,8 @@ import unittest -from paddle.v2.framework.layers import mul, data -import paddle.v2.framework.core as core -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.framework import g_main_program +from paddle.v2.fluid.layers import mul, data +import paddle.v2.fluid.core as core +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.framework import g_main_program import numpy diff --git a/python/paddle/v2/framework/tests/test_expand_op.py b/python/paddle/v2/fluid/tests/test_expand_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_expand_op.py rename to python/paddle/v2/fluid/tests/test_expand_op.py diff --git a/python/paddle/v2/framework/tests/test_feed_fetch_method.py b/python/paddle/v2/fluid/tests/test_feed_fetch_method.py similarity index 95% rename from python/paddle/v2/framework/tests/test_feed_fetch_method.py rename to python/paddle/v2/fluid/tests/test_feed_fetch_method.py index fbd659ece0..178c85b0dd 100644 --- a/python/paddle/v2/framework/tests/test_feed_fetch_method.py +++ b/python/paddle/v2/fluid/tests/test_feed_fetch_method.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import unittest import numpy as np diff --git a/python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py b/python/paddle/v2/fluid/tests/test_fill_constant_batch_size_like_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_fill_constant_batch_size_like_op.py rename to python/paddle/v2/fluid/tests/test_fill_constant_batch_size_like_op.py diff --git a/python/paddle/v2/framework/tests/test_fill_constant_op.py b/python/paddle/v2/fluid/tests/test_fill_constant_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_fill_constant_op.py rename to python/paddle/v2/fluid/tests/test_fill_constant_op.py diff --git a/python/paddle/v2/framework/tests/test_fill_zeros_like_op.py b/python/paddle/v2/fluid/tests/test_fill_zeros_like_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_fill_zeros_like_op.py rename to python/paddle/v2/fluid/tests/test_fill_zeros_like_op.py diff --git a/python/paddle/v2/framework/tests/test_framework_debug_str.py b/python/paddle/v2/fluid/tests/test_framework_debug_str.py similarity index 85% rename from python/paddle/v2/framework/tests/test_framework_debug_str.py rename to python/paddle/v2/fluid/tests/test_framework_debug_str.py index 8fdf8f9117..a4cbabdb36 100644 --- a/python/paddle/v2/framework/tests/test_framework_debug_str.py +++ b/python/paddle/v2/fluid/tests/test_framework_debug_str.py @@ -1,5 +1,5 @@ import unittest -from paddle.v2.framework.framework import Program +from paddle.v2.fluid.framework import Program class TestDebugStringFramework(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/fluid/tests/test_gather_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_gather_op.py rename to python/paddle/v2/fluid/tests/test_gather_op.py diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/fluid/tests/test_gaussian_random_op.py similarity index 91% rename from python/paddle/v2/framework/tests/test_gaussian_random_op.py rename to python/paddle/v2/fluid/tests/test_gaussian_random_op.py index 0dc7e091a5..627ab4e235 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/fluid/tests/test_gaussian_random_op.py @@ -1,6 +1,6 @@ import unittest -import paddle.v2.framework.core as core -from paddle.v2.framework.op import Operator +import paddle.v2.fluid.core as core +from paddle.v2.fluid.op import Operator import numpy diff --git a/python/paddle/v2/framework/tests/test_gru_op.py b/python/paddle/v2/fluid/tests/test_gru_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_gru_op.py rename to python/paddle/v2/fluid/tests/test_gru_op.py diff --git a/python/paddle/v2/framework/tests/test_gru_unit_op.py b/python/paddle/v2/fluid/tests/test_gru_unit_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_gru_unit_op.py rename to python/paddle/v2/fluid/tests/test_gru_unit_op.py diff --git a/python/paddle/v2/framework/tests/test_huber_loss_op.py b/python/paddle/v2/fluid/tests/test_huber_loss_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_huber_loss_op.py rename to python/paddle/v2/fluid/tests/test_huber_loss_op.py diff --git a/python/paddle/v2/framework/tests/test_image_classification_layer.py b/python/paddle/v2/fluid/tests/test_image_classification_layer.py similarity index 95% rename from python/paddle/v2/framework/tests/test_image_classification_layer.py rename to python/paddle/v2/fluid/tests/test_image_classification_layer.py index b1a267ec32..bf5444107f 100644 --- a/python/paddle/v2/framework/tests/test_image_classification_layer.py +++ b/python/paddle/v2/fluid/tests/test_image_classification_layer.py @@ -1,8 +1,8 @@ import unittest -import paddle.v2.framework.layers as layers -import paddle.v2.framework.nets as nets -from paddle.v2.framework.framework import Program +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.nets as nets +from paddle.v2.fluid.framework import Program def conv_block(input, diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/fluid/tests/test_infer_shape.py similarity index 98% rename from python/paddle/v2/framework/tests/test_infer_shape.py rename to python/paddle/v2/fluid/tests/test_infer_shape.py index 2b2995f5e2..9f6695ce02 100644 --- a/python/paddle/v2/framework/tests/test_infer_shape.py +++ b/python/paddle/v2/fluid/tests/test_infer_shape.py @@ -1,6 +1,6 @@ import unittest -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core class TestInferShape(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_inference_model_io.py b/python/paddle/v2/fluid/tests/test_inference_model_io.py similarity index 90% rename from python/paddle/v2/framework/tests/test_inference_model_io.py rename to python/paddle/v2/fluid/tests/test_inference_model_io.py index 48984f86a1..98b95713b7 100644 --- a/python/paddle/v2/framework/tests/test_inference_model_io.py +++ b/python/paddle/v2/fluid/tests/test_inference_model_io.py @@ -1,11 +1,11 @@ import paddle.v2 as paddle -import paddle.v2.framework.layers as layers -import paddle.v2.framework.core as core -import paddle.v2.framework.optimizer as optimizer +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.core as core +import paddle.v2.fluid.optimizer as optimizer -from paddle.v2.framework.framework import Program -from paddle.v2.framework.io import save_inference_model, load_inference_model -import paddle.v2.framework.executor as executor +from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.io import save_inference_model, load_inference_model +import paddle.v2.fluid.executor as executor import unittest import numpy as np diff --git a/python/paddle/v2/framework/tests/test_initializer.py b/python/paddle/v2/fluid/tests/test_initializer.py similarity index 98% rename from python/paddle/v2/framework/tests/test_initializer.py rename to python/paddle/v2/fluid/tests/test_initializer.py index bd4d2e39d7..f2eb79b209 100644 --- a/python/paddle/v2/framework/tests/test_initializer.py +++ b/python/paddle/v2/fluid/tests/test_initializer.py @@ -1,8 +1,8 @@ import numpy as np import unittest -import paddle.v2.framework.framework as framework -import paddle.v2.framework.initializer as initializer +import paddle.v2.fluid.framework as framework +import paddle.v2.fluid.initializer as initializer DELTA = 0.00001 diff --git a/python/paddle/v2/framework/tests/test_l1_norm_op.py b/python/paddle/v2/fluid/tests/test_l1_norm_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_l1_norm_op.py rename to python/paddle/v2/fluid/tests/test_l1_norm_op.py diff --git a/python/paddle/v2/framework/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py similarity index 97% rename from python/paddle/v2/framework/tests/test_layers.py rename to python/paddle/v2/fluid/tests/test_layers.py index b42af5ea45..3d18e7ce3a 100644 --- a/python/paddle/v2/framework/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -1,7 +1,7 @@ -import paddle.v2.framework.layers as layers -import paddle.v2.framework.nets as nets -from paddle.v2.framework.framework import Program -import paddle.v2.framework.core as core +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.nets as nets +from paddle.v2.fluid.framework import Program +import paddle.v2.fluid.core as core import unittest diff --git a/python/paddle/v2/framework/tests/test_linear_chain_crf_op.py b/python/paddle/v2/fluid/tests/test_linear_chain_crf_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_linear_chain_crf_op.py rename to python/paddle/v2/fluid/tests/test_linear_chain_crf_op.py diff --git a/python/paddle/v2/framework/tests/test_lod_array_length_op.py b/python/paddle/v2/fluid/tests/test_lod_array_length_op.py similarity index 79% rename from python/paddle/v2/framework/tests/test_lod_array_length_op.py rename to python/paddle/v2/fluid/tests/test_lod_array_length_op.py index af2b4d705e..a01ae83772 100644 --- a/python/paddle/v2/framework/tests/test_lod_array_length_op.py +++ b/python/paddle/v2/fluid/tests/test_lod_array_length_op.py @@ -1,7 +1,7 @@ import unittest -import paddle.v2.framework.layers as layers -from paddle.v2.framework.executor import Executor -import paddle.v2.framework.core as core +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.executor import Executor +import paddle.v2.fluid.core as core import numpy diff --git a/python/paddle/v2/framework/tests/test_lod_rank_table.py b/python/paddle/v2/fluid/tests/test_lod_rank_table.py similarity index 78% rename from python/paddle/v2/framework/tests/test_lod_rank_table.py rename to python/paddle/v2/fluid/tests/test_lod_rank_table.py index 408145c10f..bbc11930b9 100644 --- a/python/paddle/v2/framework/tests/test_lod_rank_table.py +++ b/python/paddle/v2/fluid/tests/test_lod_rank_table.py @@ -1,7 +1,7 @@ -from paddle.v2.framework.layers import lod_rank_table, data -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.framework import g_main_program -import paddle.v2.framework.core as core +from paddle.v2.fluid.layers import lod_rank_table, data +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.framework import g_main_program +import paddle.v2.fluid.core as core import numpy import unittest diff --git a/python/paddle/v2/fluid/tests/test_lod_reset_op.py b/python/paddle/v2/fluid/tests/test_lod_reset_op.py new file mode 100644 index 0000000000..652ccecfa4 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_lod_reset_op.py @@ -0,0 +1,64 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestLodResetOpByAttr(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + target_lod_0 = [0, 7, 10] + self.inputs = {'X': (x, lod)} + self.attrs = {'target_lod': target_lod_0} + self.outputs = {'Out': (x, [target_lod_0])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestLodResetOpByInput(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + target_lod_0 = [0, 4, 7, 10] + self.inputs = { + 'X': (x, lod), + 'TargetLoD': np.array([target_lod_0]).astype('int32') + } + self.outputs = {'Out': (x, [target_lod_0])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + + +class TestLodResetOpBoth(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + target_lod_0_attr = [0, 7, 10] + target_lod_0_in = [0, 4, 7, 10] + self.inputs = { + 'X': (x, lod), + 'TargetLoD': np.array(target_lod_0_in).astype('int32') + } + self.attrs = {'target_lod': target_lod_0_attr} + self.outputs = {'Out': (x, [target_lod_0_in])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_lod_tensor_array.py b/python/paddle/v2/fluid/tests/test_lod_tensor_array.py similarity index 96% rename from python/paddle/v2/framework/tests/test_lod_tensor_array.py rename to python/paddle/v2/fluid/tests/test_lod_tensor_array.py index a433bcf622..d6d3e23fd8 100644 --- a/python/paddle/v2/framework/tests/test_lod_tensor_array.py +++ b/python/paddle/v2/fluid/tests/test_lod_tensor_array.py @@ -1,5 +1,5 @@ import unittest -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import numpy diff --git a/python/paddle/v2/framework/tests/test_lod_tensor_array_ops.py b/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py similarity index 96% rename from python/paddle/v2/framework/tests/test_lod_tensor_array_ops.py rename to python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py index e9713666b3..b18cb6b49f 100644 --- a/python/paddle/v2/framework/tests/test_lod_tensor_array_ops.py +++ b/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py @@ -1,10 +1,10 @@ import unittest -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import numpy -import paddle.v2.framework.layers as layers -from paddle.v2.framework.framework import Program -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.backward import append_backward_ops +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.backward import append_backward_ops class TestCPULoDTensorArrayOps(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_lookup_table_op.py b/python/paddle/v2/fluid/tests/test_lookup_table_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_lookup_table_op.py rename to python/paddle/v2/fluid/tests/test_lookup_table_op.py diff --git a/python/paddle/v2/framework/tests/test_lrn_op.py b/python/paddle/v2/fluid/tests/test_lrn_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_lrn_op.py rename to python/paddle/v2/fluid/tests/test_lrn_op.py diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/fluid/tests/test_lstm_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_lstm_op.py rename to python/paddle/v2/fluid/tests/test_lstm_op.py diff --git a/python/paddle/v2/framework/tests/test_lstm_unit_op.py b/python/paddle/v2/fluid/tests/test_lstm_unit_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_lstm_unit_op.py rename to python/paddle/v2/fluid/tests/test_lstm_unit_op.py diff --git a/python/paddle/v2/framework/tests/test_margin_rank_loss_op.py b/python/paddle/v2/fluid/tests/test_margin_rank_loss_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_margin_rank_loss_op.py rename to python/paddle/v2/fluid/tests/test_margin_rank_loss_op.py diff --git a/python/paddle/v2/framework/tests/test_matmul_op.py b/python/paddle/v2/fluid/tests/test_matmul_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_matmul_op.py rename to python/paddle/v2/fluid/tests/test_matmul_op.py diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/fluid/tests/test_mean_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_mean_op.py rename to python/paddle/v2/fluid/tests/test_mean_op.py diff --git a/python/paddle/v2/framework/tests/test_minus_op.py b/python/paddle/v2/fluid/tests/test_minus_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_minus_op.py rename to python/paddle/v2/fluid/tests/test_minus_op.py diff --git a/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py b/python/paddle/v2/fluid/tests/test_modified_huber_loss_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_modified_huber_loss_op.py rename to python/paddle/v2/fluid/tests/test_modified_huber_loss_op.py diff --git a/python/paddle/v2/framework/tests/test_momentum_op.py b/python/paddle/v2/fluid/tests/test_momentum_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_momentum_op.py rename to python/paddle/v2/fluid/tests/test_momentum_op.py diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/fluid/tests/test_mul_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_mul_op.py rename to python/paddle/v2/fluid/tests/test_mul_op.py diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/fluid/tests/test_multiplex_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_multiplex_op.py rename to python/paddle/v2/fluid/tests/test_multiplex_op.py diff --git a/python/paddle/v2/framework/tests/test_nccl_init_op.py b/python/paddle/v2/fluid/tests/test_nccl_init_op.py similarity index 91% rename from python/paddle/v2/framework/tests/test_nccl_init_op.py rename to python/paddle/v2/fluid/tests/test_nccl_init_op.py index 054909fdf5..a536800ccd 100644 --- a/python/paddle/v2/framework/tests/test_nccl_init_op.py +++ b/python/paddle/v2/fluid/tests/test_nccl_init_op.py @@ -1,8 +1,8 @@ import unittest, os import numpy as np import paddle.v2 as paddle -from paddle.v2.framework.op import Operator -import paddle.v2.framework.core as core +from paddle.v2.fluid.op import Operator +import paddle.v2.fluid.core as core from op_test import OpTest, create_op, set_input if not core.is_compile_gpu(): diff --git a/python/paddle/v2/framework/tests/test_net.py b/python/paddle/v2/fluid/tests/test_net.py similarity index 93% rename from python/paddle/v2/framework/tests/test_net.py rename to python/paddle/v2/fluid/tests/test_net.py index 8503257feb..318df08a9e 100644 --- a/python/paddle/v2/framework/tests/test_net.py +++ b/python/paddle/v2/fluid/tests/test_net.py @@ -1,5 +1,5 @@ -import paddle.v2.framework.core as core -from paddle.v2.framework.op import Operator +import paddle.v2.fluid.core as core +from paddle.v2.fluid.op import Operator import unittest diff --git a/python/paddle/v2/framework/tests/test_op_support_gpu.py b/python/paddle/v2/fluid/tests/test_op_support_gpu.py similarity index 84% rename from python/paddle/v2/framework/tests/test_op_support_gpu.py rename to python/paddle/v2/fluid/tests/test_op_support_gpu.py index dd36c666c4..a0eb4bd5fd 100644 --- a/python/paddle/v2/framework/tests/test_op_support_gpu.py +++ b/python/paddle/v2/fluid/tests/test_op_support_gpu.py @@ -1,5 +1,5 @@ import unittest -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core class TestOpSupportGPU(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_operator.py b/python/paddle/v2/fluid/tests/test_operator.py similarity index 97% rename from python/paddle/v2/framework/tests/test_operator.py rename to python/paddle/v2/fluid/tests/test_operator.py index 98f6b2f5ee..4aa022ef90 100644 --- a/python/paddle/v2/framework/tests/test_operator.py +++ b/python/paddle/v2/fluid/tests/test_operator.py @@ -1,7 +1,7 @@ import unittest -import paddle.v2.framework.op as op -import paddle.v2.framework.core as core -import paddle.v2.framework.proto.framework_pb2 as framework_pb2 +import paddle.v2.fluid.op as op +import paddle.v2.fluid.core as core +import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 class TestGetAllProtos(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/fluid/tests/test_operator_desc.py similarity index 96% rename from python/paddle/v2/framework/tests/test_operator_desc.py rename to python/paddle/v2/fluid/tests/test_operator_desc.py index a0bc4e0b91..e8362d2e9c 100644 --- a/python/paddle/v2/framework/tests/test_operator_desc.py +++ b/python/paddle/v2/fluid/tests/test_operator_desc.py @@ -1,6 +1,6 @@ import unittest -from paddle.v2.framework.framework import Variable, Program, g_main_program -import paddle.v2.framework.core as core +from paddle.v2.fluid.framework import Variable, Program, g_main_program +import paddle.v2.fluid.core as core class TestOperator(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_optimizer.py b/python/paddle/v2/fluid/tests/test_optimizer.py similarity index 84% rename from python/paddle/v2/framework/tests/test_optimizer.py rename to python/paddle/v2/fluid/tests/test_optimizer.py index a39e740260..7b4237e7fd 100644 --- a/python/paddle/v2/framework/tests/test_optimizer.py +++ b/python/paddle/v2/fluid/tests/test_optimizer.py @@ -1,8 +1,8 @@ import unittest -import paddle.v2.framework.framework as framework -import paddle.v2.framework.optimizer as optimizer -from paddle.v2.framework.backward import append_backward_ops +import paddle.v2.fluid.framework as framework +import paddle.v2.fluid.optimizer as optimizer +from paddle.v2.fluid.backward import append_backward_ops class TestOptimizer(unittest.TestCase): @@ -198,7 +198,7 @@ class TestAdagradOptimizer(unittest.TestCase): adagrad_op = opts[0] self.assertEqual(adagrad_op.type, "adagrad") - # check accumulators + # Check accumulators accumulators = adagrad_optimizer.get_accumulators() self.assertEqual(len(accumulators), 1) self.assertTrue(adagrad_optimizer.get_moment_str() in accumulators) @@ -331,5 +331,59 @@ class TestAdamaxOptimizer(unittest.TestCase): self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate) +class TestDecayedAdagradOptimizer(unittest.TestCase): + class MockDecayedAdagrad(optimizer.DecayedAdagradOptimizer): + def get_accumulators(self): + return self._accumulators + + def get_moment_str(self): + return self._moment_acc_str + + def test_decayed_adagrad_optimizer(self): + init_program = framework.Program() + program = framework.Program() + block = program.global_block() + mul_x = block.create_parameter( + dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") + mul_y = block.create_var( + dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") + mul_out = block.create_var( + dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") + block.append_op( + type="mul", + inputs={"X": mul_x, + "Y": mul_y}, + outputs={"Out": mul_out}, + attrs={"x_num_col_dims": 1}) + learning_rate = 0.01 + decayed_adagrad_optimizer = self.MockDecayedAdagrad( + learning_rate=learning_rate, decay=0.95, epsilon=1.0e-6) + params_grads = append_backward_ops(mul_out) + self.assertEqual(len(params_grads), 1) + self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0) + opts = decayed_adagrad_optimizer.create_optimization_pass( + params_grads, mul_out, init_program) + self.assertEqual(len(opts), 1) + decayed_adagrad_op = opts[0] + self.assertEqual(decayed_adagrad_op.type, "decayed_adagrad") + + # Check accumulators + accumulators = decayed_adagrad_optimizer.get_accumulators() + self.assertEqual(len(accumulators), 1) + self.assertTrue( + decayed_adagrad_optimizer.get_moment_str() in accumulators) + moment_acc = accumulators[decayed_adagrad_optimizer.get_moment_str()] + self.assertEqual(len(moment_acc), 1) + self.assertTrue(mul_x.name in moment_acc) + + # Check init_program + init_ops = init_program.global_block().ops + self.assertEqual(len(init_ops), 2) + self.assertEqual(init_ops[0].type, "fill_constant") + self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate) + self.assertEqual(init_ops[1].type, "fill_constant") + self.assertAlmostEqual(init_ops[1].attr('value'), 0.0) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/fluid/tests/test_pad_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_pad_op.py rename to python/paddle/v2/fluid/tests/test_pad_op.py diff --git a/python/paddle/v2/framework/tests/test_parameter.py b/python/paddle/v2/fluid/tests/test_parameter.py similarity index 87% rename from python/paddle/v2/framework/tests/test_parameter.py rename to python/paddle/v2/fluid/tests/test_parameter.py index f04eb4cf27..71a1bd2aaf 100644 --- a/python/paddle/v2/framework/tests/test_parameter.py +++ b/python/paddle/v2/fluid/tests/test_parameter.py @@ -1,6 +1,6 @@ import unittest -from paddle.v2.framework.framework import g_main_program -import paddle.v2.framework.core as core +from paddle.v2.fluid.framework import g_main_program +import paddle.v2.fluid.core as core class TestParameter(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_pool2d_op.py b/python/paddle/v2/fluid/tests/test_pool2d_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_pool2d_op.py rename to python/paddle/v2/fluid/tests/test_pool2d_op.py diff --git a/python/paddle/v2/framework/tests/test_pool3d_op.py b/python/paddle/v2/fluid/tests/test_pool3d_op.py similarity index 99% rename from python/paddle/v2/framework/tests/test_pool3d_op.py rename to python/paddle/v2/fluid/tests/test_pool3d_op.py index a3aedf8d28..2ba86665a7 100644 --- a/python/paddle/v2/framework/tests/test_pool3d_op.py +++ b/python/paddle/v2/fluid/tests/test_pool3d_op.py @@ -4,7 +4,6 @@ from op_test import OpTest def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): - N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] @@ -28,7 +27,6 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): def avg_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): - N, C, D, H, W = x.shape if global_pool == 1: ksize = [D, H, W] diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/fluid/tests/test_pool_max_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_pool_max_op.py rename to python/paddle/v2/fluid/tests/test_pool_max_op.py diff --git a/python/paddle/v2/framework/tests/test_positive_negative_pair_op.py b/python/paddle/v2/fluid/tests/test_positive_negative_pair_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_positive_negative_pair_op.py rename to python/paddle/v2/fluid/tests/test_positive_negative_pair_op.py diff --git a/python/paddle/v2/framework/tests/test_precision_recall_op.py b/python/paddle/v2/fluid/tests/test_precision_recall_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_precision_recall_op.py rename to python/paddle/v2/fluid/tests/test_precision_recall_op.py diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/fluid/tests/test_prelu_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_prelu_op.py rename to python/paddle/v2/fluid/tests/test_prelu_op.py diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/fluid/tests/test_program.py similarity index 96% rename from python/paddle/v2/framework/tests/test_program.py rename to python/paddle/v2/fluid/tests/test_program.py index 7be67b6614..ef2daf6916 100644 --- a/python/paddle/v2/framework/tests/test_program.py +++ b/python/paddle/v2/fluid/tests/test_program.py @@ -1,8 +1,8 @@ import unittest -import paddle.v2.framework.core as core -from paddle.v2.framework.framework import Program -from paddle.v2.framework.framework import g_main_program +import paddle.v2.fluid.core as core +from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.framework import g_main_program class TestProgram(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_protobuf.py b/python/paddle/v2/fluid/tests/test_protobuf.py similarity index 92% rename from python/paddle/v2/framework/tests/test_protobuf.py rename to python/paddle/v2/fluid/tests/test_protobuf.py index 848a396b3b..e064374176 100644 --- a/python/paddle/v2/framework/tests/test_protobuf.py +++ b/python/paddle/v2/fluid/tests/test_protobuf.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.proto.framework_pb2 as framework_pb2 +import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 import unittest diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/fluid/tests/test_protobuf_descs.py similarity index 99% rename from python/paddle/v2/framework/tests/test_protobuf_descs.py rename to python/paddle/v2/fluid/tests/test_protobuf_descs.py index 2fd3d5d165..098a9802df 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/fluid/tests/test_protobuf_descs.py @@ -1,5 +1,5 @@ import unittest -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core class TestOpDesc(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_proximal_adagrad_op.py b/python/paddle/v2/fluid/tests/test_proximal_adagrad_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_proximal_adagrad_op.py rename to python/paddle/v2/fluid/tests/test_proximal_adagrad_op.py diff --git a/python/paddle/v2/framework/tests/test_proximal_gd_op.py b/python/paddle/v2/fluid/tests/test_proximal_gd_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_proximal_gd_op.py rename to python/paddle/v2/fluid/tests/test_proximal_gd_op.py diff --git a/python/paddle/v2/framework/tests/test_rank_loss_op.py b/python/paddle/v2/fluid/tests/test_rank_loss_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_rank_loss_op.py rename to python/paddle/v2/fluid/tests/test_rank_loss_op.py diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/fluid/tests/test_recurrent_op.py similarity index 98% rename from python/paddle/v2/framework/tests/test_recurrent_op.py rename to python/paddle/v2/fluid/tests/test_recurrent_op.py index 16100429dd..b623d12318 100644 --- a/python/paddle/v2/framework/tests/test_recurrent_op.py +++ b/python/paddle/v2/fluid/tests/test_recurrent_op.py @@ -1,11 +1,11 @@ import unittest -import paddle.v2.framework.layers as layers -from paddle.v2.framework.framework import Program -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.backward import append_backward_ops +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.backward import append_backward_ops import numpy as np -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core class PyRNNBase(object): diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/fluid/tests/test_reduce_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_reduce_op.py rename to python/paddle/v2/fluid/tests/test_reduce_op.py diff --git a/python/paddle/v2/framework/tests/test_regularizer.py b/python/paddle/v2/fluid/tests/test_regularizer.py similarity index 92% rename from python/paddle/v2/framework/tests/test_regularizer.py rename to python/paddle/v2/fluid/tests/test_regularizer.py index b21dceb584..f5d1eb3b96 100644 --- a/python/paddle/v2/framework/tests/test_regularizer.py +++ b/python/paddle/v2/fluid/tests/test_regularizer.py @@ -1,9 +1,9 @@ import unittest -import paddle.v2.framework.framework as framework -import paddle.v2.framework.optimizer as optimizer -import paddle.v2.framework.regularizer as regularizer -from paddle.v2.framework.backward import append_backward_ops +import paddle.v2.fluid.framework as framework +import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.regularizer as regularizer +from paddle.v2.fluid.backward import append_backward_ops class TestL2DecayRegularizer(unittest.TestCase): diff --git a/python/paddle/v2/framework/tests/test_reshape_op.py b/python/paddle/v2/fluid/tests/test_reshape_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_reshape_op.py rename to python/paddle/v2/fluid/tests/test_reshape_op.py diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/fluid/tests/test_rmsprop_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_rmsprop_op.py rename to python/paddle/v2/fluid/tests/test_rmsprop_op.py diff --git a/python/paddle/v2/framework/tests/test_rnn_memory_helper_op.py b/python/paddle/v2/fluid/tests/test_rnn_memory_helper_op.py similarity index 95% rename from python/paddle/v2/framework/tests/test_rnn_memory_helper_op.py rename to python/paddle/v2/fluid/tests/test_rnn_memory_helper_op.py index 731beff17c..a3cba92504 100644 --- a/python/paddle/v2/framework/tests/test_rnn_memory_helper_op.py +++ b/python/paddle/v2/fluid/tests/test_rnn_memory_helper_op.py @@ -1,10 +1,10 @@ import unittest -from paddle.v2.framework.framework import Program -from paddle.v2.framework.executor import Executor -from paddle.v2.framework.backward import append_backward_ops +from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.backward import append_backward_ops import numpy as np -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core def create_tensor(np_data, place): diff --git a/python/paddle/v2/framework/tests/test_scale_op.py b/python/paddle/v2/fluid/tests/test_scale_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_scale_op.py rename to python/paddle/v2/fluid/tests/test_scale_op.py diff --git a/python/paddle/v2/framework/tests/test_scatter_op.py b/python/paddle/v2/fluid/tests/test_scatter_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_scatter_op.py rename to python/paddle/v2/fluid/tests/test_scatter_op.py diff --git a/python/paddle/v2/framework/tests/test_scope.py b/python/paddle/v2/fluid/tests/test_scope.py similarity index 81% rename from python/paddle/v2/framework/tests/test_scope.py rename to python/paddle/v2/fluid/tests/test_scope.py index 1474365479..e4857b590a 100644 --- a/python/paddle/v2/framework/tests/test_scope.py +++ b/python/paddle/v2/fluid/tests/test_scope.py @@ -1,22 +1,22 @@ -import paddle.v2.framework.core +import paddle.v2.fluid.core import unittest class TestScope(unittest.TestCase): def test_create_destroy(self): - paddle_c = paddle.v2.framework.core + paddle_c = paddle.v2.fluid.core scope = paddle_c.Scope() self.assertIsNotNone(scope) scope_with_parent = scope.new_scope() self.assertIsNotNone(scope_with_parent) def test_none_variable(self): - paddle_c = paddle.v2.framework.core + paddle_c = paddle.v2.fluid.core scope = paddle_c.Scope() self.assertIsNone(scope.find_var("test")) def test_create_var_get_var(self): - paddle_c = paddle.v2.framework.core + paddle_c = paddle.v2.fluid.core scope = paddle_c.Scope() var_a = scope.var("var_a") self.assertIsNotNone(var_a) @@ -25,7 +25,7 @@ class TestScope(unittest.TestCase): self.assertIsNotNone(scope2.find_var('var_a')) def test_var_get_int(self): - paddle_c = paddle.v2.framework.core + paddle_c = paddle.v2.fluid.core scope = paddle_c.Scope() var = scope.var("test_int") var.set_int(10) diff --git a/python/paddle/v2/framework/tests/test_selected_rows.py b/python/paddle/v2/fluid/tests/test_selected_rows.py similarity index 96% rename from python/paddle/v2/framework/tests/test_selected_rows.py rename to python/paddle/v2/fluid/tests/test_selected_rows.py index e8a930cb08..93daf37aa2 100644 --- a/python/paddle/v2/framework/tests/test_selected_rows.py +++ b/python/paddle/v2/fluid/tests/test_selected_rows.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import unittest import numpy as np diff --git a/python/paddle/v2/framework/tests/test_seq_concat_op.py b/python/paddle/v2/fluid/tests/test_seq_concat_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_seq_concat_op.py rename to python/paddle/v2/fluid/tests/test_seq_concat_op.py diff --git a/python/paddle/v2/framework/tests/test_seq_conv.py b/python/paddle/v2/fluid/tests/test_seq_conv.py similarity index 100% rename from python/paddle/v2/framework/tests/test_seq_conv.py rename to python/paddle/v2/fluid/tests/test_seq_conv.py diff --git a/python/paddle/v2/framework/tests/test_seq_expand.py b/python/paddle/v2/fluid/tests/test_seq_expand.py similarity index 100% rename from python/paddle/v2/framework/tests/test_seq_expand.py rename to python/paddle/v2/fluid/tests/test_seq_expand.py diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/fluid/tests/test_seq_pool.py similarity index 100% rename from python/paddle/v2/framework/tests/test_seq_pool.py rename to python/paddle/v2/fluid/tests/test_seq_pool.py diff --git a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py b/python/paddle/v2/fluid/tests/test_sequence_softmax_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_sequence_softmax_op.py rename to python/paddle/v2/fluid/tests/test_sequence_softmax_op.py diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/fluid/tests/test_sgd_op.py similarity index 97% rename from python/paddle/v2/framework/tests/test_sgd_op.py rename to python/paddle/v2/fluid/tests/test_sgd_op.py index 01262bba4d..ca05a381f0 100644 --- a/python/paddle/v2/framework/tests/test_sgd_op.py +++ b/python/paddle/v2/fluid/tests/test_sgd_op.py @@ -1,7 +1,7 @@ import unittest import numpy as np -import paddle.v2.framework.core as core -from paddle.v2.framework.op import Operator +import paddle.v2.fluid.core as core +from paddle.v2.fluid.op import Operator from op_test import OpTest diff --git a/python/paddle/v2/framework/tests/test_shrink_rnn_memory.py b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py similarity index 86% rename from python/paddle/v2/framework/tests/test_shrink_rnn_memory.py rename to python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py index 2090455b96..1a3b88e18e 100644 --- a/python/paddle/v2/framework/tests/test_shrink_rnn_memory.py +++ b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py @@ -1,9 +1,9 @@ import unittest -import paddle.v2.framework.core as core -from paddle.v2.framework.executor import Executor -import paddle.v2.framework.layers as layers -from paddle.v2.framework.backward import append_backward_ops -from paddle.v2.framework.framework import g_main_program +import paddle.v2.fluid.core as core +from paddle.v2.fluid.executor import Executor +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.backward import append_backward_ops +from paddle.v2.fluid.framework import g_main_program import numpy diff --git a/python/paddle/v2/framework/tests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/v2/fluid/tests/test_sigmoid_cross_entropy_with_logits_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_sigmoid_cross_entropy_with_logits_op.py rename to python/paddle/v2/fluid/tests/test_sigmoid_cross_entropy_with_logits_op.py diff --git a/python/paddle/v2/framework/tests/test_sign_op.py b/python/paddle/v2/fluid/tests/test_sign_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_sign_op.py rename to python/paddle/v2/fluid/tests/test_sign_op.py diff --git a/python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py b/python/paddle/v2/fluid/tests/test_smooth_l1_loss_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_smooth_l1_loss_op.py rename to python/paddle/v2/fluid/tests/test_smooth_l1_loss_op.py diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/fluid/tests/test_softmax_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_softmax_op.py rename to python/paddle/v2/fluid/tests/test_softmax_op.py diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/fluid/tests/test_softmax_with_cross_entropy_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py rename to python/paddle/v2/fluid/tests/test_softmax_with_cross_entropy_op.py diff --git a/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py b/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py new file mode 100644 index 0000000000..3aed83b2ea --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py @@ -0,0 +1,181 @@ +import unittest +import paddle.v2.fluid.core as core +import numpy as np +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.executor import Executor +from paddle.v2.fluid.backward import append_backward_ops + + +class TestCPULoDTensorArrayOps(unittest.TestCase): + def place(self): + return core.CPUPlace() + + def test_split_and_merge_lod_tensor_no_lod(self): + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place()) + + mask_np = np.array([0, 0, 1, 1, 1, 1, 0, 0, 0, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, self.place()) + + expect_true_tensor = np.array([2, 3, 4, 5]).astype('int32') + expect_true_tensor = np.expand_dims(expect_true_tensor, axis=1) + expect_true = core.LoDTensor() + expect_true.set(expect_true_tensor, self.place()) + + expect_false_tensor = np.array([0, 1, 6, 7, 8, 9]).astype('int32') + expect_false_tensor = np.expand_dims(expect_false_tensor, axis=1) + + expect_false = core.LoDTensor() + expect_false.set(expect_false_tensor, self.place()) + + self.main( + tensor=tensor, + mask=mask, + expect_true=expect_true, + expect_false=expect_false, + expect_out=tensor) + + def test_split_and_merge_lod_tensor_level_0(self): + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place()) + tensor.set_lod([[0, 3, 9, 10]]) + + mask_np = np.array([0, 1, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, self.place()) + + expect_true_tensor = np.array([3, 4, 5, 6, 7, 8]).astype('int32') + expect_true_tensor = np.expand_dims(expect_true_tensor, axis=1) + expect_true = core.LoDTensor() + expect_true.set(expect_true_tensor, self.place()) + expect_true.set_lod([[0, 6]]) + + expect_false_tensor = np.array([0, 1, 2, 9]).astype('int32') + expect_false_tensor = np.expand_dims(expect_false_tensor, axis=1) + expect_false_lod = [[0, 3, 4]] + + expect_false = core.LoDTensor() + expect_false.set(expect_false_tensor, self.place()) + expect_false.set_lod(expect_false_lod) + + self.main( + tensor=tensor, + mask=mask, + expect_true=expect_true, + expect_false=expect_false, + expect_out=tensor) + + def main(self, tensor, mask, expect_true, expect_false, expect_out, + level=0): + place = self.place() + program = Program() + x = layers.data(name='x', shape=[1], main_program=program) + x.persistable = True + + y = layers.data(name='y', shape=[1], main_program=program) + y.persistable = True + + out_true, out_false = layers.split_lod_tensor( + input=x, mask=y, level=level, main_program=program) + out_true.persistable = True + out_false.persistable = True + + out = layers.merge_lod_tensor( + in_true=out_true, + in_false=out_false, + mask=y, + x=x, + level=level, + main_program=program) + + out.persistable = True + + exe = Executor(place) + scope = core.Scope() + exe.run(program, feed={'x': tensor, 'y': mask}, scope=scope) + + var_true = scope.find_var(out_true.name).get_tensor() + + var_false = scope.find_var(out_false.name).get_tensor() + + var_out = scope.find_var(out.name).get_tensor() + + self.check_tensor_same(var_true, expect_true) + self.check_tensor_same(var_false, expect_false) + self.check_tensor_same(var_out, expect_out) + + def check_tensor_same(self, actual, expect): + self.assertTrue(np.allclose(np.array(actual), np.array(expect))) + self.assertEqual(actual.lod(), expect.lod()) + + +class TestCPUSplitMergeLoDTensorGrad(unittest.TestCase): + def test_grad(self): + place = core.CPUPlace() + program = Program() + + x = layers.data( + name='x', + shape=[1], + data_type='float32', + main_program=program, + stop_gradient=False) + y = layers.data( + name='y', + shape=[1], + data_type='bool', + main_program=program, + stop_gradient=False) + + level = 0 + + out_true, out_false = layers.split_lod_tensor( + input=x, mask=y, level=level, main_program=program) + out = layers.merge_lod_tensor( + in_true=out_true, + in_false=out_false, + mask=y, + x=x, + level=level, + main_program=program) + mean = layers.mean(x=out, main_program=program) + + append_backward_ops(mean) + + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('float32'), place) + tensor.set_lod([[0, 3, 9, 10]]) + + mask_np = np.array([0, 1, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, place) + + exe = Executor(place) + scope = core.Scope() + + g_vars = program.global_block().var(x.name + "@GRAD") + g_out = [ + item.sum() + for item in map(np.array, + exe.run(program, + feed={'x': tensor, + 'y': mask}, + fetch_list=[g_vars], + scope=scope)) + ] + + g_out_sum = np.array(g_out).sum() + + self.assertAlmostEqual(1.0, g_out_sum, delta=0.1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_split_op.py b/python/paddle/v2/fluid/tests/test_split_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_split_op.py rename to python/paddle/v2/fluid/tests/test_split_op.py diff --git a/python/paddle/v2/framework/tests/test_squared_l2_distance_op.py b/python/paddle/v2/fluid/tests/test_squared_l2_distance_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_squared_l2_distance_op.py rename to python/paddle/v2/fluid/tests/test_squared_l2_distance_op.py diff --git a/python/paddle/v2/framework/tests/test_squared_l2_norm_op.py b/python/paddle/v2/fluid/tests/test_squared_l2_norm_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_squared_l2_norm_op.py rename to python/paddle/v2/fluid/tests/test_squared_l2_norm_op.py diff --git a/python/paddle/v2/framework/tests/test_sum_op.py b/python/paddle/v2/fluid/tests/test_sum_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_sum_op.py rename to python/paddle/v2/fluid/tests/test_sum_op.py diff --git a/python/paddle/v2/framework/tests/test_tensor.py b/python/paddle/v2/fluid/tests/test_tensor.py similarity index 98% rename from python/paddle/v2/framework/tests/test_tensor.py rename to python/paddle/v2/fluid/tests/test_tensor.py index e0cd2fa8aa..9f870d9eb3 100644 --- a/python/paddle/v2/framework/tests/test_tensor.py +++ b/python/paddle/v2/fluid/tests/test_tensor.py @@ -1,4 +1,4 @@ -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import unittest import numpy diff --git a/python/paddle/v2/framework/tests/test_tensor_array.py b/python/paddle/v2/fluid/tests/test_tensor_array.py similarity index 98% rename from python/paddle/v2/framework/tests/test_tensor_array.py rename to python/paddle/v2/fluid/tests/test_tensor_array.py index 50b3e09162..d6929ba16e 100644 --- a/python/paddle/v2/framework/tests/test_tensor_array.py +++ b/python/paddle/v2/fluid/tests/test_tensor_array.py @@ -1,5 +1,5 @@ import logging -import paddle.v2.framework.core as core +import paddle.v2.fluid.core as core import unittest import numpy as np diff --git a/python/paddle/v2/framework/tests/test_top_k_op.py b/python/paddle/v2/fluid/tests/test_top_k_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_top_k_op.py rename to python/paddle/v2/fluid/tests/test_top_k_op.py diff --git a/python/paddle/v2/framework/tests/test_transpose_op.py b/python/paddle/v2/fluid/tests/test_transpose_op.py similarity index 100% rename from python/paddle/v2/framework/tests/test_transpose_op.py rename to python/paddle/v2/fluid/tests/test_transpose_op.py diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/fluid/tests/test_uniform_random_op.py similarity index 90% rename from python/paddle/v2/framework/tests/test_uniform_random_op.py rename to python/paddle/v2/fluid/tests/test_uniform_random_op.py index ded777105e..f736dfb2e8 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/fluid/tests/test_uniform_random_op.py @@ -1,6 +1,6 @@ import unittest -from paddle.v2.framework.op import Operator -import paddle.v2.framework.core as core +from paddle.v2.fluid.op import Operator +import paddle.v2.fluid.core as core import numpy diff --git a/python/paddle/v2/framework/tests/test_variable.py b/python/paddle/v2/fluid/tests/test_variable.py similarity index 93% rename from python/paddle/v2/framework/tests/test_variable.py rename to python/paddle/v2/fluid/tests/test_variable.py index 03115f10a5..a3e60a7517 100644 --- a/python/paddle/v2/framework/tests/test_variable.py +++ b/python/paddle/v2/fluid/tests/test_variable.py @@ -1,6 +1,6 @@ import unittest -from paddle.v2.framework.framework import Variable, g_main_program, Program -import paddle.v2.framework.core as core +from paddle.v2.fluid.framework import Variable, g_main_program, Program +import paddle.v2.fluid.core as core import numpy as np diff --git a/python/paddle/v2/framework/tests/test_while_op.py b/python/paddle/v2/fluid/tests/test_while_op.py similarity index 94% rename from python/paddle/v2/framework/tests/test_while_op.py rename to python/paddle/v2/fluid/tests/test_while_op.py index 1c344eae49..0f01acb3b9 100644 --- a/python/paddle/v2/framework/tests/test_while_op.py +++ b/python/paddle/v2/fluid/tests/test_while_op.py @@ -1,7 +1,7 @@ import unittest -import paddle.v2.framework.layers as layers -from paddle.v2.framework.executor import Executor -import paddle.v2.framework.core as core +import paddle.v2.fluid.layers as layers +from paddle.v2.fluid.executor import Executor +import paddle.v2.fluid.core as core import numpy diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py deleted file mode 100644 index 254dd5f1a3..0000000000 --- a/python/paddle/v2/framework/evaluator.py +++ /dev/null @@ -1,59 +0,0 @@ -import paddle.v2.framework.op as op -import numpy as np -import paddle.v2.framework.core as core - - -def avg_accumulate(accumulated_var, per_eval, num_batches, place): - t = np.array(accumulated_var.get_tensor()) - t[0] += per_eval[0] - accumulated_var.get_tensor().set([t[0] / float(num_batches)], place) - - -class Evaluator(object): - def __init__(self, - scope, - operator='accuracy', - input='Inference', - label='Label', - output='Output', - place=core.CPUPlace()): - """ - create an evaluator for evaluating the inference. - NOTE: default run on CPUPlace(), running on GPUPlace doesn't improve performance much. - - :param scope: the scope instance contains the input. - :type scope: paddle.v2.framework.core.scope - :param operator: operator name for caculating the evaluation for each mini-batch. - :type operator: string - :param input: output variable name of forward network. - :type input: string - :param label: variable name of label - :type label: string - """ - self.scope = scope - self.place = place - self.output_name = output - self.num_batches = 0 - # create variable to store accumulated evaluator output - eval_name = ''.join([operator, "@Eval"]) - if scope.find_var(eval_name): - raise Exception("evaluator already exist in scope: %s" % eval_name) - self.accumulated_var = scope.var(eval_name) - t = self.accumulated_var.get_tensor() - t.set_dims((1, )) - t.set([0.0], place) - # self.accumulated_var = block.create_var(block, name=eval_name, shape=(1,)) - # self.accumulated_var.get_tensor().set([0.0]) - # create operator of evaluation - var_map = dict() # var name -> variable - var_map[input] = [input] - var_map[label] = [label] - var_map[output] = [output] - self.op = op.Operator(operator, **var_map) - - def evaluate(self, ctx, accumulator=avg_accumulate): - self.op.run(self.scope, ctx) - per_eval = np.array(self.scope.find_var(self.output_name).get_tensor()) - self.num_batches += 1 - accumulator(self.accumulated_var, per_eval, self.num_batches, - self.place) diff --git a/python/paddle/v2/framework/tests/test_evaluator.py b/python/paddle/v2/framework/tests/test_evaluator.py deleted file mode 100644 index 37dbfbc06b..0000000000 --- a/python/paddle/v2/framework/tests/test_evaluator.py +++ /dev/null @@ -1,64 +0,0 @@ -from paddle.v2.framework.evaluator import Evaluator -from paddle.v2.framework.op import Operator -import paddle.v2.framework.core as core -import unittest -import op_test -import numpy as np - - -class TestEvaluator(unittest.TestCase): - def setup(self, scope, inputs, outputs): - def __create_var__(var_name, arr): - np_arr = np.array(arr) - scope.var(var_name) - # tensor = var.get_tensor() - # tensor.set_dims(np_arr.shape) - - for var_name, arr in inputs.iteritems(): - __create_var__(var_name, arr) - - for var_name, arr in outputs.iteritems(): - __create_var__(var_name, arr) - - def test_evaluator(self): - - inputs = { - 'Inference': np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 1]]).T, - 'Label': np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) - } - outputs = {'Accuracy': np.array([0.9])} - out_name = 'Accuracy' - - places = [core.CPUPlace()] - if core.is_compile_gpu(): - places.append(core.GPUPlace(0)) - - for place in places: - scope = core.Scope() - self.setup(scope, inputs, outputs) - - evaluator = Evaluator( - scope, - operator='accuracy', - input='Inference', - label='Label', - output=out_name, - place=place) - op_test.set_input(scope, evaluator.op, inputs, place) - ctx = core.DeviceContext.create(place) - - for i in range(10): # simulate 10 mini-batches - evaluator.evaluate(ctx) - - actual = np.array(scope.find_var(out_name).get_tensor()) - print actual - - self.assertTrue( - np.allclose( - actual, outputs[out_name], atol=1e-5), - "output name: " + out_name + " has diff.") - - -if __name__ == '__main__': - exit(0) - unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 5348c2d8d7..fe91df10da 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -13,8 +13,8 @@ packages=['paddle', 'paddle.v2.reader', 'paddle.v2.master', 'paddle.v2.plot', - 'paddle.v2.framework', - 'paddle.v2.framework.proto', + 'paddle.v2.fluid', + 'paddle.v2.fluid.proto', 'py_paddle'] with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: @@ -44,14 +44,14 @@ setup(name='paddlepaddle', ext_modules=[Extension('_foo', ['stub.cc'])], package_data={ 'paddle.v2.master': ['libpaddle_master.so'], - 'paddle.v2.framework': ['core.so'], + 'paddle.v2.fluid': ['core.so'], 'py_paddle':['*.py','_swig_paddle.so'] }, package_dir={ '': '${CMAKE_CURRENT_SOURCE_DIR}', - # The paddle.v2.framework.proto will be generated while compiling. + # The paddle.v2.fluid.proto will be generated while compiling. # So that package points to other directory. - 'paddle.v2.framework.proto': '${PADDLE_BINARY_DIR}/paddle/framework', + 'paddle.v2.fluid.proto': '${PADDLE_BINARY_DIR}/paddle/framework', 'py_paddle': '${PADDLE_SOURCE_DIR}/paddle/py_paddle' }, scripts=paddle_bins,