From 2d44a2ec5a55699252bb64aa4a57186705c73d5f Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Mon, 30 Oct 2017 19:37:45 -0700 Subject: [PATCH 01/16] deconv cudnn --- paddle/operators/conv2dtranspose_cudnn_op.cc | 50 ++++ paddle/operators/conv2dtranspose_cudnn_op.cu | 276 +++++++++++++++++++ 2 files changed, 326 insertions(+) create mode 100644 paddle/operators/conv2dtranspose_cudnn_op.cc create mode 100644 paddle/operators/conv2dtranspose_cudnn_op.cu diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cc b/paddle/operators/conv2dtranspose_cudnn_op.cc new file mode 100644 index 0000000000..72c470389c --- /dev/null +++ b/paddle/operators/conv2dtranspose_cudnn_op.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2dtranspose_op.h" + +namespace paddle { +namespace operators { + +class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker { + public: + CudnnConv2DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : Conv2DTransposeOpMaker(proto, op_checker) { + AddAttr>("dilations", "dilations of convolution operator.") + .SetDefault(std::vector{1, 1}); + AddAttr("workspace_size_MB", + "workspace size for cudnn, in MB, " + "workspace is a section of GPU memory which will be " + "allocated/freed each time the operator runs, larger " + "workspace size can increase performance but also requires " + "better hardward. This size should be carefully setted.") + .SetDefault(4096); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv2dtranspose_cudnn, ops::Conv2DTransposeOp, + ops::CudnnConv2DTransposeOpMaker, conv2dtranspose_cudnn_grad, + ops::Conv2DTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2dtranspose_cudnn, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv2dtranspose_cudnn_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cu b/paddle/operators/conv2dtranspose_cudnn_op.cu new file mode 100644 index 0000000000..e9bad8c517 --- /dev/null +++ b/paddle/operators/conv2dtranspose_cudnn_op.cu @@ -0,0 +1,276 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memory.h" +#include "paddle/operators/conv2d_op.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; +using DataLayout = platform::DataLayout; +using CUDADeviceContext = platform::CUDADeviceContext; + +static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; + +template +class CudnnConvTransposeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + const T* input_data = input->data(); + const T* filter_data = filter->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedFilterDescriptor filter_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + // N, M, H, W + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + // N, C, O_h, O_w + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + // M, C, K_h, K_w + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims())); + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + int input_channels = input->dims()[1]; // M + int input_height = input->dims()[2]; // H + int input_width = input->dims()[3]; // W + int output_channels = output->dims()[1]; // C + int output_height = output->dims()[2]; // O_H + int output_width = output->dims()[3]; // O_W + + // ------------------- cudnn conv workspace --------------------- + void* cudnn_workspace = nullptr; + size_t workspace_size_in_bytes; // final workspace to allocate. + size_t tmp_size; + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + // ------------------- cudnn conv algorithm --------------------- + cudnnConvolutionBwdAlgo_t algo; + auto handle = ctx.cuda_device_context().cudnn_handle(); + // Get the algorithm + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + // dxDesc: Handle to the previously initialized output tensor + // descriptor. + cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &algo)); + + // get workspace size able to allocate + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_output_desc, algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + + // Allocate on GPU memory + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + + // ------------------- cudnn conv transpose forward --------------------- + T alpha = 1.0f, beta = 0.0f; + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc, + input_data, cudnn_conv_desc, algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_output_desc, output_data)); + + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; + +/* +template +class CudnnConvTransposeGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto input = ctx.Input("Input"); + auto filter = ctx.Input("Filter"); + auto output_grad = ctx.Input(framework::GradVarName("Output")); + auto input_grad = ctx.Output(framework::GradVarName("Input")); + auto filter_grad = ctx.Output(framework::GradVarName("Filter")); + + const T* input_data = input->data(); + const T* output_grad_data = output_grad->data(); + const T* filter_data = filter->data(); + + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + int user_workspace_size = ctx.Attr("workspace_size_MB"); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_grad_desc; + ScopedTensorDescriptor input_grad_desc; + + ScopedFilterDescriptor filter_desc; + ScopedFilterDescriptor filter_grad_desc; + ScopedConvolutionDescriptor conv_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_output_grad_desc = + output_grad_desc.descriptor( + layout, framework::vectorize2int(output_grad->dims()), groups); + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims()), groups); + cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr; + cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr; + + cudnnConvolutionDescriptor_t cudnn_conv_desc = + conv_desc.descriptor(paddings, strides, dilations); + + int input_channels = input->dims()[1]; + int input_height = input->dims()[2]; + int input_width = input->dims()[3]; + int output_grad_channels = filter->dims()[0]; + int output_grad_height = output_grad->dims()[2]; + int output_grad_width = output_grad->dims()[3]; + + int group_offset_in = input_channels / groups * input_height * input_width; + int group_offset_out = + output_grad_channels / groups * output_grad_height * output_grad_width; + int group_offset_filter = filter->numel() / groups; + // ------------------- cudnn backward algorithm --------------------- + cudnnConvolutionBwdDataAlgo_t data_algo; + cudnnConvolutionBwdFilterAlgo_t filter_algo; + size_t workspace_size_in_bytes = 0, tmp_size = 0; + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + if (user_workspace_size > 0) { + workspace_size_limit = user_workspace_size * 1024 * 1024; + } + + auto handle = ctx.cuda_device_context().cudnn_handle(); + if (input_grad) { + cudnn_input_grad_desc = input_grad_desc.descriptor( + layout, framework::vectorize2int(input_grad->dims()), groups); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( + handle, cudnn_filter_desc, + // dyDesc: Handle to the previously initialized input differential + // tensor descriptor. + cudnn_output_grad_desc, cudnn_conv_desc, + // dxDesc: Handle to the previously initialized output tensor + // descriptor. + cudnn_input_grad_desc, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &data_algo)); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, cudnn_filter_desc, cudnn_output_grad_desc, + cudnn_conv_desc, cudnn_input_grad_desc, data_algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + } + + if (filter_grad) { + cudnn_filter_grad_desc = filter_grad_desc.descriptor( + layout, framework::vectorize2int(filter_grad->dims()), groups); + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, + cudnn_filter_desc, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &filter_algo)); + + PADDLE_ENFORCE( + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, + cudnn_filter_desc, filter_algo, &tmp_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + } + // ------------------- cudnn conv workspace --------------------- + // Already on GPU + void* cudnn_workspace = nullptr; + platform::GPUPlace gpu = boost::get(ctx.GetPlace()); + cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); + // ------------------- cudnn conv backward data --------------------- + // FIXME(typhoonzero): template type T may not be the same as cudnn call. + T alpha = 1.0f, beta = 0.0f; + if (input_grad) { + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + for (int i = 0; i < groups; i++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, + filter_data + i * group_offset_filter, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_input_grad_desc, input_grad_data + i * group_offset_in)); + } + } + // ------------------- cudnn conv backward filter --------------------- + if (filter_grad) { + T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*filter_grad); + t.device(ctx.GetEigenDevice()) = + t.constant(static_cast(0)); + for (int i = 0; i < groups; i++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_output_grad_desc, output_grad_data + i * group_offset_out, + cudnn_conv_desc, filter_algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_filter_grad_desc, + filter_grad_data + i * group_offset_filter)); + } + } + // Release the cudnn workspace + paddle::memory::Free(gpu, cudnn_workspace); + } +}; +*/ + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn, + ops::CudnnConvTransposeOpKernel); +// REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad, +// ops::CudnnConvTransposeGradOpKernel); From a349bee6ad4a454187edb5f47c8b7968bbcaa842 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Mon, 30 Oct 2017 19:53:51 -0700 Subject: [PATCH 02/16] deconv2d cudnn --- paddle/operators/conv2dtranspose_cudnn_op.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cu b/paddle/operators/conv2dtranspose_cudnn_op.cu index e9bad8c517..257c1fc62e 100644 --- a/paddle/operators/conv2dtranspose_cudnn_op.cu +++ b/paddle/operators/conv2dtranspose_cudnn_op.cu @@ -79,13 +79,13 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { // ------------------- cudnn conv workspace --------------------- void* cudnn_workspace = nullptr; size_t workspace_size_in_bytes; // final workspace to allocate. - size_t tmp_size; size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; if (user_workspace_size > 0) { workspace_size_limit = user_workspace_size * 1024 * 1024; } // ------------------- cudnn conv algorithm --------------------- - cudnnConvolutionBwdAlgo_t algo; + // cudnnConvolutionBwdAlgo_t algo; + cudnnConvolutionBwdDataAlgo_t algo; auto handle = ctx.cuda_device_context().cudnn_handle(); // Get the algorithm PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( @@ -99,8 +99,8 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { PADDLE_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, - cudnn_output_desc, algo, &tmp_size)); - workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + cudnn_output_desc, algo, &workspace_size_in_bytes)); + // workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); // Allocate on GPU memory platform::GPUPlace gpu = boost::get(ctx.GetPlace()); From b77f9fbf041a458ef25e48139884b425f489579b Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Tue, 31 Oct 2017 11:58:04 -0700 Subject: [PATCH 03/16] deconv2d cudnn --- paddle/operators/conv2dtranspose_cudnn_op.cu | 120 ++++++------------ .../tests/test_conv2dtranspose_op.py | 46 +++---- 2 files changed, 63 insertions(+), 103 deletions(-) diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cu b/paddle/operators/conv2dtranspose_cudnn_op.cu index 257c1fc62e..8485bc65eb 100644 --- a/paddle/operators/conv2dtranspose_cudnn_op.cu +++ b/paddle/operators/conv2dtranspose_cudnn_op.cu @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "glog/logging.h" #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/memory/memory.h" @@ -69,13 +68,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { cudnnConvolutionDescriptor_t cudnn_conv_desc = conv_desc.descriptor(paddings, strides, dilations); - int input_channels = input->dims()[1]; // M - int input_height = input->dims()[2]; // H - int input_width = input->dims()[3]; // W - int output_channels = output->dims()[1]; // C - int output_height = output->dims()[2]; // O_H - int output_width = output->dims()[3]; // O_W - // ------------------- cudnn conv workspace --------------------- void* cudnn_workspace = nullptr; size_t workspace_size_in_bytes; // final workspace to allocate. @@ -118,7 +110,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { } }; -/* template class CudnnConvTransposeGradOpKernel : public framework::OpKernel { public: @@ -130,7 +121,6 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { auto output_grad = ctx.Input(framework::GradVarName("Output")); auto input_grad = ctx.Output(framework::GradVarName("Input")); auto filter_grad = ctx.Output(framework::GradVarName("Filter")); - const T* input_data = input->data(); const T* output_grad_data = output_grad->data(); const T* filter_data = filter->data(); @@ -138,47 +128,33 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); - int groups = ctx.Attr("groups"); int user_workspace_size = ctx.Attr("workspace_size_MB"); // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor input_desc; - ScopedTensorDescriptor output_grad_desc; - ScopedTensorDescriptor input_grad_desc; - + ScopedTensorDescriptor output_desc; ScopedFilterDescriptor filter_desc; - ScopedFilterDescriptor filter_grad_desc; ScopedConvolutionDescriptor conv_desc; DataLayout layout = DataLayout::kNCHW; + // Input: (N, M, H, W) cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - layout, framework::vectorize2int(input->dims()), groups); - cudnnTensorDescriptor_t cudnn_output_grad_desc = - output_grad_desc.descriptor( - layout, framework::vectorize2int(output_grad->dims()), groups); + layout, framework::vectorize2int(input->dims())); + // Output: (N, C, O_H, O_W) + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output_grad->dims())); + // Filter (M, C, K_H, K_W) cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( - layout, framework::vectorize2int(filter->dims()), groups); - cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr; - cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr; + layout, framework::vectorize2int(filter->dims())); cudnnConvolutionDescriptor_t cudnn_conv_desc = conv_desc.descriptor(paddings, strides, dilations); - int input_channels = input->dims()[1]; - int input_height = input->dims()[2]; - int input_width = input->dims()[3]; - int output_grad_channels = filter->dims()[0]; - int output_grad_height = output_grad->dims()[2]; - int output_grad_width = output_grad->dims()[3]; - - int group_offset_in = input_channels / groups * input_height * input_width; - int group_offset_out = - output_grad_channels / groups * output_grad_height * output_grad_width; - int group_offset_filter = filter->numel() / groups; // ------------------- cudnn backward algorithm --------------------- - cudnnConvolutionBwdDataAlgo_t data_algo; + cudnnConvolutionFwdAlgo_t data_algo; cudnnConvolutionBwdFilterAlgo_t filter_algo; - size_t workspace_size_in_bytes = 0, tmp_size = 0; + size_t bwd_filter_ws_size, fwd_ws_size; + size_t workspace_size_in_bytes = 0; size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; if (user_workspace_size > 0) { workspace_size_limit = user_workspace_size * 1024 * 1024; @@ -186,42 +162,35 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { auto handle = ctx.cuda_device_context().cudnn_handle(); if (input_grad) { - cudnn_input_grad_desc = input_grad_desc.descriptor( - layout, framework::vectorize2int(input_grad->dims()), groups); - PADDLE_ENFORCE( - platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( - handle, cudnn_filter_desc, - // dyDesc: Handle to the previously initialized input differential - // tensor descriptor. - cudnn_output_grad_desc, cudnn_conv_desc, - // dxDesc: Handle to the previously initialized output tensor - // descriptor. - cudnn_input_grad_desc, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &data_algo)); - PADDLE_ENFORCE( - platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( - handle, cudnn_filter_desc, cudnn_output_grad_desc, - cudnn_conv_desc, cudnn_input_grad_desc, data_algo, &tmp_size)); - workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + // choose backward algorithm for data + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &data_algo)); + PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc, + cudnn_input_desc, data_algo, &fwd_ws_size)); + workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size); } if (filter_grad) { - cudnn_filter_grad_desc = filter_grad_desc.descriptor( - layout, framework::vectorize2int(filter_grad->dims()), groups); + // choose backward algorithm for filter PADDLE_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( - handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, + handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, cudnn_filter_desc, CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, &filter_algo)); + // get workspace for backwards filter algorithm PADDLE_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( - handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, - cudnn_filter_desc, filter_algo, &tmp_size)); - workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); + handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc, + cudnn_filter_desc, filter_algo, &bwd_filter_ws_size)); + workspace_size_in_bytes = + std::max(workspace_size_in_bytes, bwd_filter_ws_size); } + // ------------------- cudnn conv workspace --------------------- // Already on GPU void* cudnn_workspace = nullptr; @@ -235,35 +204,30 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { auto t = framework::EigenVector::Flatten(*input_grad); t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); - for (int i = 0; i < groups; i++) { - PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, - filter_data + i * group_offset_filter, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, - cudnn_input_grad_desc, input_grad_data + i * group_offset_in)); - } + + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_output_desc, output_grad_data, + cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data)); } + // ------------------- cudnn conv backward filter --------------------- if (filter_grad) { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); auto t = framework::EigenVector::Flatten(*filter_grad); t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); - for (int i = 0; i < groups; i++) { - PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, - cudnn_output_grad_desc, output_grad_data + i * group_offset_out, - cudnn_conv_desc, filter_algo, cudnn_workspace, - workspace_size_in_bytes, &beta, cudnn_filter_grad_desc, - filter_grad_data + i * group_offset_filter)); - } + // Gradient with respect to the filter + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc, + input_data, cudnn_conv_desc, filter_algo, cudnn_workspace, + workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data)); } // Release the cudnn workspace paddle::memory::Free(gpu, cudnn_workspace); } }; -*/ } // namespace operators } // namespace paddle @@ -272,5 +236,5 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn, ops::CudnnConvTransposeOpKernel); -// REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad, -// ops::CudnnConvTransposeGradOpKernel); +REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad, + ops::CudnnConvTransposeGradOpKernel); diff --git a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py index 53604c58b7..4ed6e0bcc4 100644 --- a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py +++ b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py @@ -56,27 +56,9 @@ class TestConv2dTransposeOp(OpTest): self.outputs = {'Output': output} def test_check_output(self): - print 'check output here' + print 'check output here for', self.op_type self.check_output() - def test_check_grad(self): - self.check_grad( - set(['Input', 'Filter']), 'Output', max_relative_error=0.05) - - def test_check_grad_no_filter(self): - self.check_grad( - ['Input'], - 'Output', - max_relative_error=0.05, - no_grad_set=set(['Filter'])) - - def test_check_grad_no_input(self): - self.check_grad( - ['Filter'], - 'Output', - max_relative_error=0.05, - no_grad_set=set(['Input'])) - def init_test_case(self): self.pad = [0, 0] self.stride = [1, 1] @@ -88,15 +70,29 @@ class TestConv2dTransposeOp(OpTest): def init_op_type(self): self.op_type = "conv2dtranspose" + def test_check_grad_no_input(self): + self.check_grad( + ['Filter'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Input'])) + + def test_check_grad_no_filter(self): + self.check_grad( + ['Input'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Filter'])) -""" -class TestCudnn(TestConv2dOp): - def init_group(self): - self.groups = 1 + def test_check_grad(self): + self.check_grad( + set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + +class TestCudnn(TestConv2dTransposeOp): def init_op_type(self): - self.op_type = "conv_cudnn" -""" + self.op_type = "conv2dtranspose_cudnn" + if __name__ == '__main__': unittest.main() From 8013328ed840ab65afbb2bff4eb1e27bc264eea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Tue, 31 Oct 2017 15:37:23 +0800 Subject: [PATCH 04/16] Refine evaluator op types (#5208) * refine evaluator op types * update * follow comments * update * fix v2 mnist case * fix v2 mnist case * update * update --- paddle/operators/accuracy_op.cc | 39 +++++++++++++------ paddle/operators/accuracy_op.cu | 24 +++++++----- paddle/operators/accuracy_op.h | 9 +++-- paddle/operators/auc_op.cc | 38 ++++++++++++------ paddle/operators/auc_op.h | 37 ++++++++---------- python/paddle/v2/framework/layers.py | 7 +++- .../v2/framework/tests/test_accuracy_op.py | 11 +++--- .../paddle/v2/framework/tests/test_auc_op.py | 16 ++++---- 8 files changed, 108 insertions(+), 73 deletions(-) diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 88958e1634..2a2a1e9cfd 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -22,23 +22,35 @@ class AccuracyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Inference"), - "Input(Inference) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Out"), + "Input (Out) of accuracy op should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input (Indices) of accuracy op should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), - "Input(Label) of AccuracyOp should not be null."); + "Input (Label) of accuracy op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Accuracy"), - "Output(Accuracy) of AccuracyOp should not be null."); + "Output (Accuracy) of AccuracyOp should not be null."); - auto inference_dim = ctx->GetInputDim("Inference"); + auto inference_dim = ctx->GetInputDim("Out"); auto label_dim = ctx->GetInputDim("Label"); + // Assume indices has same shape with infernece, because + // it's the output of topk. PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2."); PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1"); PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0], - "inference size must be the same as label size"); + "the inference tensor's num_rows must be" + " the same as label."); ctx->SetOutputDim("Accuracy", {1}); - ctx->ShareLoD("Inference", /*->*/ "Accuracy"); + ctx->ShareLoD("Out", /*->*/ "Accuracy"); + } + + protected: + // IndicateDataType + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return framework::ToDataType(ctx.Input("Out")->type()); } }; @@ -48,7 +60,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { // TODO(typhoonzero): support both inference value and indices. - AddInput("Inference", "topk(indices) the network output"); + AddInput("Out", "topk (inferences) the network output"); + AddInput("Indices", "topk (indices) the network output"); AddInput("Label", "Label of the training data"); // TODO(typhoonzero): AddInput("Weight", ... AddOutput("Accuracy", "The accuracy of current batch"); @@ -59,7 +72,7 @@ The accuracy is: .. math:: accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples}) -Both the input `Inference` and `Label` can carry the LoD (Level of Details) +Both the input `Out` and `Label` can carry the LoD (Level of Details) information, or not. But the output only shares the LoD with input `Inference`. )DOC"); } @@ -71,6 +84,8 @@ information, or not. But the output only shares the LoD with input `Inference`. namespace ops = paddle::operators; REGISTER_OPERATOR(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - accuracy, ops::AccuracyKernel, - ops::AccuracyKernel); +// FIXME(typhoonzero): types of T is for infernece data. +// label data is always int. +REGISTER_OP_CPU_KERNEL(accuracy, + ops::AccuracyKernel, + ops::AccuracyKernel); diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index be58dfbd03..a0483f367e 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -21,9 +21,10 @@ namespace paddle { namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; -template -__global__ void AccuracyCudaKernel(const int N, const int D, const T* Xdata, - const T* labeldata, float* accuracy) { +template +__global__ void AccuracyCudaKernel(const int N, const int D, + const int64_t* Xdata, + const int64_t* labeldata, float* accuracy) { int count = 0; __shared__ int total[BlockSize]; @@ -52,13 +53,14 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use GPUPlace."); - auto* inference = ctx.Input("Inference"); + auto* inference = ctx.Input("Out"); + auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); auto* accuracy = ctx.Output("Accuracy"); // FIXME(typhoonzero): only support indices currently // if add support for output values, how to detect the data type? - const T* inference_data = inference->data(); - const T* label_data = label->data(); + const int64_t* indices_data = indices->data(); + const int64_t* label_data = label->data(); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); size_t num_samples = inference->dims()[0]; @@ -69,11 +71,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { return; } - AccuracyCudaKernel<<< + AccuracyCudaKernel<<< 1, PADDLE_CUDA_NUM_THREADS, 0, reinterpret_cast( ctx.device_context()) - .stream()>>>(num_samples, infer_width, inference_data, label_data, + .stream()>>>(num_samples, infer_width, indices_data, label_data, accuracy_data); } }; @@ -81,5 +83,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel, - paddle::operators::AccuracyOpCUDAKernel); +// FIXME(typhoonzero): types of T is for infernece data. +// label data is always int +REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel, + paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h index 12c6b9aac8..1968b53d19 100644 --- a/paddle/operators/accuracy_op.h +++ b/paddle/operators/accuracy_op.h @@ -38,14 +38,15 @@ template class AccuracyKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Inference"); + auto* inference = ctx.Input("Out"); + auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); auto* accuracy = ctx.Output("Accuracy"); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); - const T* inference_data = inference->data(); - const T* label_data = label->data(); + const int64_t* indices_data = indices->data(); + const int64_t* label_data = label->data(); size_t num_samples = inference->dims()[0]; size_t class_dim = inference->dims()[1]; @@ -60,7 +61,7 @@ class AccuracyKernel : public framework::OpKernel { for (size_t i = 0; i < num_samples; ++i) { PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0"); for (size_t j = 0; j < class_dim; ++j) { - if (inference_data[i * class_dim + j] == label_data[i]) { + if (indices_data[i * class_dim + j] == label_data[i]) { ++num_correct; break; } diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index cf3dbc5d10..f5784922af 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -23,18 +23,26 @@ class AucOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Inference"), - "Input of Inference must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input of Indices must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input of Label must be initialized."); - auto inference_dim = ctx->GetInputDim("Inference"); - auto label_dim = ctx->GetInputDim("Label"); + auto inference_height = ctx->GetInputDim("Out")[0]; + auto label_height = ctx->GetInputDim("Label")[0]; - PADDLE_ENFORCE_EQ(inference_dim, label_dim, - "inference and label should have same shape"); + PADDLE_ENFORCE_EQ(inference_height, label_height, + "Out and Label should have same height."); ctx->SetOutputDim("AUC", {1}); - ctx->ShareLoD("Inference", /*->*/ "AUC"); + ctx->ShareLoD("Out", /*->*/ "AUC"); + } + + protected: + // IndicateDataType + framework::DataType IndicateDataType( + const framework::ExecutionContext &ctx) const override { + return framework::ToDataType(ctx.Input("Out")->type()); } }; @@ -42,12 +50,18 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { public: AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Inference", - "A floating point tensor of arbitrary shape and whose values" - "are in the range [0, 1]."); + AddInput("Out", + "A floating point 2D tensor, values are in the range [0, 1]." + "Each row is descend sorted. This input should be the" + "output of topk." + "Typically, this tensor indicates the probability of each label"); + AddInput("Indices", + "An int 2D tensor, indicating the indices of original" + "tensor before sort. Typically, this tensor indicates which label" + "the probability stands for."); AddInput("Label", - "A tensor whose shape matches " - "Inference. Will be cast to bool."); + "A 2D int tensor indicating the label of the training data." + "The height is batch size and width is always 1."); // TODO(typhoonzero): support weight input AddOutput("AUC", "A scalar representing the " diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index be6ef29d5f..e5ac57b038 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -29,7 +29,7 @@ template class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* inference = ctx.Input("Inference"); + auto* inference = ctx.Input("Out"); auto* label = ctx.Input("Label"); auto* auc = ctx.Output("AUC"); @@ -46,18 +46,11 @@ class AucKernel : public framework::OpKernel { thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; - size_t num_samples = inference->numel(); + size_t batch_size = inference->dims()[0]; + size_t inference_width = inference->dims()[1]; const T* inference_data = inference->data(); - Tensor label_casted; - label_casted.Resize(label->dims()); - bool* label_casted_data = label_casted.mutable_data(ctx.GetPlace()); - - const int* label_data = label->data(); - // cast label_data to bool - for (size_t i = 0; i < num_samples; i++) { - label_casted_data[i] = static_cast(label_data[i]); - } + const int64_t* label_data = label->data(); // Create local tensor for storing the curve: TP, FN, TN, FP // TODO(typhoonzero): use eigen op to caculate these values. @@ -68,23 +61,27 @@ class AucKernel : public framework::OpKernel { true_negative.Resize({num_thresholds}); false_positive.Resize({num_thresholds}); - int* tp_data = true_positive.mutable_data(ctx.GetPlace()); - int* fn_data = false_negative.mutable_data(ctx.GetPlace()); - int* tn_data = true_negative.mutable_data(ctx.GetPlace()); - int* fp_data = false_positive.mutable_data(ctx.GetPlace()); + int64_t* tp_data = true_positive.mutable_data(ctx.GetPlace()); + int64_t* fn_data = false_negative.mutable_data(ctx.GetPlace()); + int64_t* tn_data = true_negative.mutable_data(ctx.GetPlace()); + int64_t* fp_data = false_positive.mutable_data(ctx.GetPlace()); for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { // caculate TP, FN, TN, FP for current thresh - int tp = 0, fn = 0, tn = 0, fp = 0; - for (size_t i = 0; i < num_samples; i++) { - if (label_casted_data[i]) { - if (inference_data[i] >= (thresholds_list[idx_thresh])) { + int64_t tp = 0, fn = 0, tn = 0, fp = 0; + for (size_t i = 0; i < batch_size; i++) { + // NOTE: label_data used as bool, labels >0 will be treated as true. + if (label_data[i]) { + // use first(max) data in each row + if (inference_data[i * inference_width] >= + (thresholds_list[idx_thresh])) { tp++; } else { fn++; } } else { - if (inference_data[i] >= (thresholds_list[idx_thresh])) { + if (inference_data[i * inference_width] >= + (thresholds_list[idx_thresh])) { fp++; } else { tn++; diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 4727d139a2..6451d11e2b 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -243,8 +243,11 @@ def accuracy(input, label, k=1, **kwargs): acc_out = helper.create_tmp_variable(dtype=acc_out_dtype) helper.append_op( type="accuracy", - inputs={"Inference": [topk_indices], - "Label": [label]}, + inputs={ + "Out": [topk_out], + "Indices": [topk_indices], + "Label": [label] + }, outputs={"Accuracy": [acc_out]}) return acc_out diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py index f17edd44ae..6536c297e8 100644 --- a/python/paddle/v2/framework/tests/test_accuracy_op.py +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -7,13 +7,14 @@ class TestAccuracyOp(OpTest): def setUp(self): self.op_type = "accuracy" n = 8192 - infer = np.random.randint(0, 2, (n, 1)).astype("int") - label = np.random.randint(0, 2, (n, 1)).astype("int") - self.inputs = {'Inference': infer, "Label": label} + infer = np.random.random((n, 1)).astype("float32") + indices = np.random.randint(0, 2, (n, 1)) + label = np.random.randint(0, 2, (n, 1)) + self.inputs = {'Out': infer, 'Indices': indices, "Label": label} num_correct = 0 for rowid in xrange(n): - for ele in infer[rowid]: - if ele == label[rowid][0]: + for ele in indices[rowid]: + if ele == label[rowid]: num_correct += 1 break self.outputs = { diff --git a/python/paddle/v2/framework/tests/test_auc_op.py b/python/paddle/v2/framework/tests/test_auc_op.py index 65f679cfcc..26ea905d88 100644 --- a/python/paddle/v2/framework/tests/test_auc_op.py +++ b/python/paddle/v2/framework/tests/test_auc_op.py @@ -6,10 +6,11 @@ from op_test import OpTest class TestAucOp(OpTest): def setUp(self): self.op_type = "auc" - pred = np.random.random((128)).astype("float32") - labels = np.random.randint(0, 2, (128, )) + pred = np.random.random((128, 2)).astype("float32") + indices = np.random.randint(0, 2, (128, 2)) + labels = np.random.randint(0, 2, (128, 1)) num_thresholds = 200 - self.inputs = {'Inference': pred, 'Label': labels} + self.inputs = {'Out': pred, 'Indices': indices, 'Label': labels} self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} # NOTE: sklearn use a different way to generate thresholds # which will cause the result differs slightly: @@ -31,12 +32,12 @@ class TestAucOp(OpTest): tp, fn, tn, fp = 0, 0, 0, 0 for i, lbl in enumerate(labels): if lbl: - if pred[i] >= thresh: + if pred[i, 0] >= thresh: tp += 1 else: fn += 1 else: - if pred[i] >= thresh: + if pred[i, 0] >= thresh: fp += 1 else: tn += 1 @@ -62,6 +63,5 @@ class TestAucOp(OpTest): self.check_output() -# TODO(typhoonzero): add this back till we fix it -#if __name__ == "__main__": -# unittest.main() +if __name__ == "__main__": + unittest.main() From 873ee9ab7e878a1b939183a0dccb946c0467e1d3 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 25 Oct 2017 15:30:24 +0800 Subject: [PATCH 05/16] add test_Expand and simply the gserver/tests/CMakeLists --- paddle/gserver/tests/CMakeLists.txt | 165 ++++++++------------------- paddle/gserver/tests/test_Expand.cpp | 125 ++++++++++++++++++++ 2 files changed, 174 insertions(+), 116 deletions(-) create mode 100644 paddle/gserver/tests/test_Expand.cpp diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 329536afaf..aa94ee406e 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -1,24 +1,29 @@ # gserver pacakge unittests -if(NOT MOBILE_INFERENCE) -################### test_ProtoDataProvider ############ - add_unittest_without_exec(test_ProtoDataProvider - test_ProtoDataProvider.cpp) - - # test_ProtoDataProvider will mkdir as same name, - # so if WORKING_DIRECTORY is default directory, then - # mkdir will get error. - add_test(NAME test_ProtoDataProvider - COMMAND ${CMAKE_CURRENT_BINARY_DIR}/test_ProtoDataProvider - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) -endif() +add_simple_unittest(test_LinearChainCRF) +add_simple_unittest(test_MultinomialSampler) +add_simple_unittest(test_RecurrentLayer) -################# test_LayerGrad ####################### -add_unittest_without_exec(test_LayerGrad - test_LayerGrad.cpp - LayerGradUtil.cpp) -add_test(NAME test_LayerGrad - COMMAND test_LayerGrad) +function(gserver_test TARGET) + add_unittest_without_exec(${TARGET} + ${TARGET}.cpp + LayerGradUtil.cpp) + add_test(NAME ${TARGET} + COMMAND ${TARGET}) +endfunction() + +gserver_test(test_LayerGrad) +gserver_test(test_CRFLayerGrad) +gserver_test(test_CrossEntropyOverBeamGrad) +gserver_test(test_SeqSliceLayerGrad) +gserver_test(test_ActivationGrad) +gserver_test(test_ConvTrans) +gserver_test(test_PriorBox) +gserver_test(test_DetectionOutput) +gserver_test(test_ConvUnify) +gserver_test(test_BatchNorm) +gserver_test(test_KmaxSeqScore) +gserver_test(test_Expand) ########## test_Mkldnn layers and activations ########## if(WITH_MKLDNN) @@ -32,89 +37,6 @@ if(WITH_MKLDNN) WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) endif() -################ test_CRFLayerGrad #################### -add_unittest_without_exec(test_CRFLayerGrad - test_CRFLayerGrad.cpp - LayerGradUtil.cpp) -add_test(NAME test_CRFLayerGrad - COMMAND test_CRFLayerGrad) - -################ test_CrossEntropyOverBeam #################### -add_unittest_without_exec(test_CrossEntropyOverBeam - test_CrossEntropyOverBeamGrad.cpp - LayerGradUtil.cpp) -add_test(NAME test_CrossEntropyOverBeam - COMMAND test_CrossEntropyOverBeam) - -################ test_SeqSliceLayerGrad #################### -add_unittest_without_exec(test_SeqSliceLayerGrad - test_SeqSliceLayerGrad.cpp - LayerGradUtil.cpp) -add_test(NAME test_SeqSliceLayerGrad - COMMAND test_SeqSliceLayerGrad) - -add_unittest_without_exec(test_ActivationGrad - test_ActivationGrad.cpp - LayerGradUtil.cpp) -add_test(NAME test_ActivationGrad - COMMAND test_ActivationGrad) -################# test_ConvTrans ####################### -add_unittest_without_exec(test_ConvTrans - test_ConvTrans.cpp - LayerGradUtil.cpp) - -add_test(NAME test_ConvTrans - COMMAND test_ConvTrans) -################# test_PriorBox ####################### -add_unittest_without_exec(test_PriorBox - test_PriorBox.cpp - LayerGradUtil.cpp) - -add_test(NAME test_PriorBox - COMMAND test_PriorBox) -################# test_DetectionOutput ####################### -add_unittest_without_exec(test_DetectionOutput - test_DetectionOutput.cpp - LayerGradUtil.cpp) - -add_test(NAME test_DetectionOutput - COMMAND test_DetectionOutput) -################# test_ConvUnify ####################### -add_unittest_without_exec(test_ConvUnify - test_ConvUnify.cpp - LayerGradUtil.cpp) - -add_test(NAME test_ConvUnify - COMMAND test_ConvUnify) -################# test_BatchNorm ####################### -add_unittest_without_exec(test_BatchNorm - test_BatchNorm.cpp - LayerGradUtil.cpp) - -add_test(NAME test_BatchNorm - COMMAND test_BatchNorm) - - -################# test_KmaxSeqScore ####################### -add_unittest_without_exec(test_KmaxSeqScore - test_KmaxSeqScore.cpp - LayerGradUtil.cpp) - -add_test(NAME test_KmaxSeqScore - COMMAND test_KmaxSeqScore) - -if(NOT MOBILE_INFERENCE) -################## test_Evaluator ####################### - add_unittest(test_Evaluator - test_Evaluator.cpp) -endif() - -################ test_LinearChainCRF #################### -add_simple_unittest(test_LinearChainCRF) - -############## test_MultinomialSampler ################### -add_simple_unittest(test_MultinomialSampler) - ############## test_PyDataProvider ######################## if(WITH_PYTHON) add_unittest_without_exec(test_PyDataProvider @@ -125,9 +47,6 @@ if(WITH_PYTHON) WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) endif() -############### test_RecurrentLayer ####################### -add_simple_unittest(test_RecurrentLayer) - ############### test_WarpCTCLayer ####################### if(NOT WITH_DOUBLE) add_unittest_without_exec(test_WarpCTCLayer @@ -139,19 +58,33 @@ if(NOT WITH_DOUBLE) endif() if(NOT MOBILE_INFERENCE) -############### test_RecurrentGradientMachine ############### - # TODO(yuyang18): There is some bug in test_RecurrentGradientMachine - # I will fix it. - add_unittest_without_exec(test_RecurrentGradientMachine - test_RecurrentGradientMachine.cpp) - add_test(NAME test_RecurrentGradientMachine - COMMAND .set_python_path.sh -d - ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests - ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) -endif() +################### test_ProtoDataProvider ############ + add_unittest_without_exec(test_ProtoDataProvider + test_ProtoDataProvider.cpp) -if(NOT MOBILE_INFERENCE) + # test_ProtoDataProvider will mkdir as same name, + # so if WORKING_DIRECTORY is default directory, then + # mkdir will get error. + add_test(NAME test_ProtoDataProvider + COMMAND ${CMAKE_CURRENT_BINARY_DIR}/test_ProtoDataProvider + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + +################## test_Evaluator ####################### + add_unittest(test_Evaluator + test_Evaluator.cpp) + +############### test_RecurrentGradientMachine ############### + # TODO(yuyang18): There is some bug in test_RecurrentGradientMachine + # I will fix it. + add_unittest_without_exec(test_RecurrentGradientMachine + test_RecurrentGradientMachine.cpp) + add_test(NAME test_RecurrentGradientMachine + COMMAND .set_python_path.sh -d + ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests + ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) + +############### test_NetworkCompare ############### add_unittest_without_exec(test_NetworkCompare test_NetworkCompare.cpp) if(WITH_GPU) diff --git a/paddle/gserver/tests/test_Expand.cpp b/paddle/gserver/tests/test_Expand.cpp new file mode 100644 index 0000000000..a84a518a01 --- /dev/null +++ b/paddle/gserver/tests/test_Expand.cpp @@ -0,0 +1,125 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "LayerGradUtil.h" +#include "paddle/testing/TestUtil.h" + +using namespace paddle; // NOLINT +using namespace std; // NOLINT + +// Do one forward pass of expand layer and check to see if its output +// matches the given result.(Test onlyCPU currently.) +void doOneExpandTest(string trans_type, + bool hasSubseq, + bool useGpu, + Argument& input1, + Argument& input2, + Argument& result) { + FLAGS_use_gpu = false; + // Setting up the expand layer + TestConfig config; + config.layerConfig.set_type("expand"); + + auto inputType1 = + trans_type == "non-seq" ? INPUT_DENSE_DIM_DATA : INPUT_SEQUENCE_DATA; + config.inputDefs.push_back({inputType1, "layer0", 1, 0}); + auto inputType2 = + hasSubseq ? INPUT_HASSUB_SEQUENCE_DATA : INPUT_SEQUENCE_DATA; + + config.inputDefs.push_back({inputType2, "layer1", 1, 0}); + config.layerConfig.add_inputs(); + config.layerConfig.add_inputs(); + config.layerConfig.set_trans_type(trans_type); + + // data layer initialize + std::vector dataLayers; + LayerMap layerMap; + vector datas; + initDataLayer( + config, &dataLayers, &datas, &layerMap, "expand", 1, false, useGpu); + dataLayers[0]->getOutput() = input1; + dataLayers[1]->getOutput() = input2; + + // test layer initialize + std::vector parameters; + LayerPtr expandLayer; + initTestLayer(config, &layerMap, ¶meters, &expandLayer); + expandLayer->forward(PASS_GC); + checkMatrixEqual(expandLayer->getOutputValue(), result.value); +} + +TEST(Layer, ExpandLayerFwd) { + bool useGpu = false; + + // Assume batch_size =3 in all cases. + + // CPU case 1. non-seq expand to seq + // input1 = 1,2,3 + // input2 = [4,5],[6],[7,8,9] + // result = [1,1],[2],[3,3,3] + Argument input1, input2, result; + input1.value = Matrix::create(3, 1, false, useGpu); + real input1Data[] = {1, 2, 3}; + input1.value->setData(input1Data); + + input2.value = Matrix::create(6, 1, false, useGpu); + real input2Data[] = {4, 5, 6, 7, 8, 9}; + input2.value->setData(input2Data); + input2.sequenceStartPositions = ICpuGpuVector::create(4, useGpu); + int input2Seq[] = {0, 2, 3, 6}; + input2.sequenceStartPositions->copyFrom(input2Seq, 4, useGpu); + + result.value = Matrix::create(6, 1, false, useGpu); + real resultData[] = {1, 1, 2, 3, 3, 3}; + result.value->setData(resultData); + + doOneExpandTest("non-seq", false, useGpu, input1, input2, result); + + // CPU case 2. non-seq expand to sub-seq + // input1 = 1,2,3 + // input2 = [[4,5]],[[6]],[[7],[8,9]] + // result = [[1,1]],[[2]],[[3],[3,3]] + input2.subSequenceStartPositions = ICpuGpuVector::create(5, useGpu); + int input2SubSeq[] = {0, 2, 3, 4, 6}; + input2.subSequenceStartPositions->copyFrom(input2SubSeq, 5, useGpu); + + doOneExpandTest("non-seq", true, useGpu, input1, input2, result); + + // CPU case 3. seq expand to sub-seq + // input1 = [1,2],[3],[4] + // input2 = [[4,5]],[[6]],[[7],[8,9]] + // result = [[1,1]],[[2]],[[3],[4,4]] + Matrix::resizeOrCreate(input1.value, 4, 1, false, useGpu); + real input1Data_case3[] = {1, 2, 3, 4}; + input1.value->setData(input1Data_case3); + + input1.sequenceStartPositions = ICpuGpuVector::create(4, useGpu); + int input1Seq[] = {0, 2, 3, 4}; + input1.sequenceStartPositions->copyFrom(input1Seq, 4, useGpu); + + real resultData_case3[] = {1, 1, 2, 3, 4, 4}; + result.value->setData(resultData_case3); + + doOneExpandTest("seq", true, useGpu, input1, input2, result); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + initMain(argc, argv); + return RUN_ALL_TESTS(); +} From c2f6aa9b4ae4ed18cac09c87c3959f16f9f445d7 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 31 Oct 2017 14:36:38 +0800 Subject: [PATCH 06/16] add comments in test_Expand.cpp --- paddle/gserver/tests/test_Expand.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/gserver/tests/test_Expand.cpp b/paddle/gserver/tests/test_Expand.cpp index a84a518a01..d32bf0152f 100644 --- a/paddle/gserver/tests/test_Expand.cpp +++ b/paddle/gserver/tests/test_Expand.cpp @@ -91,6 +91,8 @@ TEST(Layer, ExpandLayerFwd) { doOneExpandTest("non-seq", false, useGpu, input1, input2, result); // CPU case 2. non-seq expand to sub-seq + // NOTE: input1.batch_size == input2.sequencelength in this case. + // i.e, input1 expands by input2.sequence // input1 = 1,2,3 // input2 = [[4,5]],[[6]],[[7],[8,9]] // result = [[1,1]],[[2]],[[3],[3,3]] From 1e127960cb706d5a77a2566a5d9398b8790553f1 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 31 Oct 2017 18:26:26 +0800 Subject: [PATCH 07/16] correct the index of cluster_train_cn/en.md --- doc/howto/usage/cluster/cluster_train_cn.md | 36 ++++++++++----------- doc/howto/usage/cluster/cluster_train_en.md | 36 ++++++++++----------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/doc/howto/usage/cluster/cluster_train_cn.md b/doc/howto/usage/cluster/cluster_train_cn.md index 93c5544bcf..2e98b3de3f 100644 --- a/doc/howto/usage/cluster/cluster_train_cn.md +++ b/doc/howto/usage/cluster/cluster_train_cn.md @@ -19,7 +19,7 @@ * [启动集群作业](#启动集群作业-1) * [在Kubernetes集群中提交训练作业](#在kubernetes集群中提交训练作业) -# 概述 +## 概述 本文将介绍如何使用PaddlePaddle在不同的集群框架下完成分布式训练。分布式训练架构如下图所示: @@ -32,7 +32,7 @@ 在使用同步SGD训练神经网络时,PaddlePaddle使用同步屏障(barrier),使梯度的提交和参数的更新按照顺序方式执行。在异步SGD中,则并不会等待所有trainer提交梯度才更新参数,这样极大地提高了计算的并行性:参数服务器之间不相互依赖,并行地接收梯度和更新参数,参数服务器也不会等待计算节点全部都提交梯度之后才开始下一步,计算节点之间也不会相互依赖,并行地执行模型的训练。可以看出,虽然异步SGD方式会提高参数更新并行度, 但是并不能保证参数同步更新,在任意时间某一台参数服务器上保存的参数可能比另一台要更新,与同步SGD相比,梯度会有噪声。 -# 环境准备 +## 环境准备 1. 准备您的计算集群。计算集群通常由一组(几台到几千台规模)的Linux服务器组成。服务器之间可以通过局域网(LAN)联通,每台服务器具有集群中唯一的IP地址(或者可被DNS解析的主机名)。集群中的每台计算机通常被成为一个“节点”。 1. 我们需要在集群的所有节点上安装 PaddlePaddle。 如果要启用GPU,还需要在节点上安装对应的GPU驱动以及CUDA。PaddlePaddle的安装可以参考[build_and_install](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/getstarted/build_and_install)的多种安装方式。我们推荐使用[Docker](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/getstarted/build_and_install/docker_install_cn.rst)安装方式来快速安装PaddlePaddle。 @@ -51,8 +51,8 @@ PaddlePaddle 0.10.0, compiled with 下面以`doc/howto/usage/cluster/src/word2vec`中的代码作为实例,介绍使用PaddlePaddle v2 API完成分布式训练。 -# 启动参数说明 -## 启动参数服务器 +## 启动参数说明 +### 启动参数服务器 执行以下的命令启动一个参数服务器并等待和计算节点的数据交互 ```bash $ paddle pserver --port=7164 --ports_num=1 --ports_num_for_sparse=1 --num_gradient_servers=1 @@ -70,7 +70,7 @@ $ stdbuf -oL /usr/bin/nohup paddle pserver --port=7164 --ports_num=1 --ports_num | ports_num_for_sparse | 必选 | 1 | 用于稀疏类型参数通信的端口个数 | | num_gradient_servers | 必选 | 1 | 当前训练任务pserver总数 | -## 启动计算节点 +### 启动计算节点 执行以下命令启动使用python编写的trainer程序(文件名为任意文件名,如train.py) ```bash $ python train.py @@ -117,7 +117,7 @@ paddle.init( | pservers | 必选 | 127.0.0.1 | 当前训练任务启动的pserver的IP列表,多个IP使用“,”隔开 | -## 准备数据集 +### 准备数据集 参考样例数据准备脚本[prepare.py](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/howto/usage/cluster/src/word2vec/prepare.py),准备训练数据和验证数据集,我们使用paddle.dataset.imikolov数据集,并根据分布式训练并发数(trainer节点个数),在`prepare.py`开头部分指定`SPLIT_COUNT`将数据切分成多份。 @@ -149,7 +149,7 @@ test.txt-00002 对于不同的训练任务,训练数据格式和训练程序的`reader()`会大不相同,所以开发者需要根据自己训练任务的实际场景完成训练数据的分割和`reader()`的编写。 -## 准备训练程序 +### 准备训练程序 我们会对每个训练任务都会在每个节点上创建一个工作空间(workspace),其中包含了用户的训练程序、程序依赖、挂载或下载的训练数据分片。 @@ -184,7 +184,7 @@ test.txt-00002 - `train_data_dir`:包含训练数据的目录,可以是从分布式存储挂载过来的,也可以是在任务启动前下载到本地的。 - `test_data_dir`:包含测试数据集的目录。 -# 使用分布式计算平台或工具 +## 使用分布式计算平台或工具 PaddlePaddle可以使用多种分布式计算平台构建分布式计算任务,包括: - [Kubernetes](http://kubernetes.io) Google开源的容器集群的调度框架,支持大规模集群生产环境的完整集群方案。 @@ -195,12 +195,12 @@ PaddlePaddle可以使用多种分布式计算平台构建分布式计算任务 在使用分布式计算平台进行训练时,任务被调度在集群中时,分布式计算平台通常会通过API或者环境变量提供任务运行需要的参数,比如节点的ID、IP和任务节点个数等。 -## 使用Fabric启动集群作业 +### 使用Fabric启动集群作业 -### 准备一个Linux集群 +#### 准备一个Linux集群 可以在`paddle/scripts/cluster_train_v2/fabric/docker_cluster`目录下,执行`kubectl -f ssh_servers.yaml`启动一个测试集群,并使用`kubectl get po -o wide`获得这些节点的IP地址。 -### 启动集群作业 +#### 启动集群作业 `paddle.py` 提供了自动化脚本来启动不同节点中的所有 PaddlePaddle 集群进程。默认情况下,所有命令行选项可以设置为 `paddle.py` 命令选项并且 `paddle.py` 将透明、自动地将这些选项应用到 PaddlePaddle 底层进程。 @@ -216,10 +216,10 @@ sh run.sh 集群作业将会在几秒后启动。 -### 终止集群作业 +#### 终止集群作业 `paddle.py`能获取`Ctrl + C` SIGINT 信号来自动终止它启动的所有进程。只需中断 `paddle.py` 任务来终止集群作业。如果程序崩溃你也可以手动终止。 -### 检查集群训练结果 +#### 检查集群训练结果 详细信息请检查 $workspace/log 里的日志,每一个节点都有相同的日志结构。 `paddle_trainer.INFO` @@ -234,13 +234,13 @@ sh run.sh `train.log` 提供训练过程的 stderr 和 stdout。训练失败时可以检查错误日志。 -### 检查模型输出 +#### 检查模型输出 运行完成后,模型文件将被写入节点 0 的 `output` 目录中。 工作空间中的 `nodefile` 表示当前集群作业的节点 ID。 -## 在OpenMPI集群中提交训练作业 +### 在OpenMPI集群中提交训练作业 -### 准备OpenMPI集群 +#### 准备OpenMPI集群 执行下面的命令以启动3个节点的OpenMPI集群和一个"head"节点: @@ -252,7 +252,7 @@ kubectl create -f mpi-nodes.yaml 然后可以从head节点ssh无密码登录到OpenMPI的每个节点上。 -### 启动集群作业 +#### 启动集群作业 您可以按照下面的步骤在OpenMPI集群中提交paddle训练任务: @@ -280,6 +280,6 @@ scp train.txt-00002 test.txt-00002 [node3IP]:/home/tutorial mpirun -hostfile machines -n 3 /home/tutorial/start_mpi_train.sh ``` -## 在Kubernetes集群中提交训练作业 +### 在Kubernetes集群中提交训练作业 此部分的使用方法可以参考[here](../k8s/k8s_distributed_cn.md)。 diff --git a/doc/howto/usage/cluster/cluster_train_en.md b/doc/howto/usage/cluster/cluster_train_en.md index 1e8b4d54b9..baa97c0c02 100644 --- a/doc/howto/usage/cluster/cluster_train_en.md +++ b/doc/howto/usage/cluster/cluster_train_en.md @@ -19,7 +19,7 @@ * [Launching Cluster Job](#launching-cluster-job-1) * [Cluster Training Using Kubernetes](#cluster-training-using-kubernetes) -# Introduction +## Introduction In this article, we'll explain how to run distributed training jobs with PaddlePaddle on different types of clusters. The diagram below shows the main architecture of a distributed trainning job: @@ -33,7 +33,7 @@ PaddlePaddle can support both synchronize stochastic gradient descent (SGD) and When training with synchronize SGD, PaddlePaddle uses an internal "synchronize barrier" which makes gradients update and parameter download in strict order. On the other hand, asynchronous SGD won't wait for all trainers to finish upload at a single step, this will increase the parallelism of distributed training: parameter servers do not depend on each other, they'll do parameter optimization concurrently. Parameter servers will not wait for trainers, so trainers will also do their work concurrently. But asynchronous SGD will introduce more randomness and noises in the gradient. -# Preparations +## Preparations 1. Prepare your computer cluster. It's normally a bunch of Linux servers connected by LAN. Each server will be assigned a unique IP address. The computers in the cluster can be called "nodes". 2. Install PaddlePaddle on every node. If you are going to take advantage of GPU cards, you'll also need to install proper driver and CUDA libraries. To install PaddlePaddle please read [this build and install](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/getstarted/build_and_install) document. We strongly recommend using [Docker installation](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/getstarted/build_and_install/docker_install_en.rst). @@ -52,9 +52,9 @@ PaddlePaddle 0.10.0rc, compiled with We'll take `doc/howto/usage/cluster/src/word2vec` as an example to introduce distributed training using PaddlePaddle v2 API. -# Command-line arguments +## Command-line arguments -## Starting parameter server +### Starting parameter server Type the below command to start a parameter server which will wait for trainers to connect: @@ -74,7 +74,7 @@ $ stdbuf -oL /usr/bin/nohup paddle pserver --port=7164 --ports_num=1 --ports_num | ports_num_for_sparse | required | 1 | number of ports which serves sparse parameter update | | num_gradient_servers | required | 1 | total number of gradient servers | -## Starting trainer +### Starting trainer Type the command below to start the trainer(name the file whatever you want, like "train.py") ```bash @@ -122,7 +122,7 @@ paddle.init( | trainer_id | required | 0 | ID for every trainer, start from 0 | | pservers | required | 127.0.0.1 | list of IPs of parameter servers, separated by "," | -## Prepare Training Dataset +### Prepare Training Dataset Here's some example code [prepare.py](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/howto/usage/cluster/src/word2vec/prepare.py), it will download public `imikolov` dataset and split it into multiple files according to job parallelism(trainers count). Modify `SPLIT_COUNT` at the begining of `prepare.py` to change the count of output files. @@ -155,7 +155,7 @@ When job started, every trainer needs to get it's own part of data. In some dist Different training jobs may have different data format and `reader()` function, developers may need to write different data prepare scripts and `reader()` functions for their job. -## Prepare Training program +### Prepare Training program We'll create a *workspace* directory on each node, storing your training program, dependencies, mounted or downloaded dataset directory. @@ -191,7 +191,7 @@ Your workspace may looks like: - `train_data_dir`: containing training data. Mount from storage service or copy trainning data to here. - `test_data_dir`: containing testing data. -# Use cluster platforms or cluster management tools +## Use cluster platforms or cluster management tools PaddlePaddle supports running jobs on several platforms including: - [Kubernetes](http://kubernetes.io) open-source system for automating deployment, scaling, and management of containerized applications from Google. @@ -202,13 +202,13 @@ We'll introduce cluster job management on these platforms. The examples can be f These cluster platforms provide API or environment variables for training processes, when the job is dispatched to different nodes. Like node ID, IP or total number of nodes etc. -## Cluster Training Using Fabric +### Cluster Training Using Fabric -### Prepare a Linux cluster +#### Prepare a Linux cluster Run `kubectl -f ssh_servers.yaml` under the directory: `paddle/scripts/cluster_train_v2/fabric/docker_cluster` will launch a demo cluster. Run `kubectl get po -o wide` to get IP addresses of these nodes. -### Launching Cluster Job +#### Launching Cluster Job `paddle.py` provides automatical scripts to start all PaddlePaddle cluster processes in different nodes. By default, all command line options can be set as `paddle.py` command options and `paddle.py` will transparently and automatically set these options to PaddlePaddle lower level processes. `paddle.py`provides two distinguished command option for easy job launching. @@ -224,10 +224,10 @@ sh run.sh The cluster Job will start in several seconds. -### Kill Cluster Job +#### Kill Cluster Job `paddle.py` can capture `Ctrl + C` SIGINT signal to automatically kill all processes launched by it. So just stop `paddle.py` to kill cluster job. You should manually kill the job if the program crashed. -### Check Cluster Training Result +#### Check Cluster Training Result Check log in $workspace/log for details, each node owns same log structure. `paddle_trainer.INFO` @@ -242,13 +242,13 @@ It provides stderr and stdout of parameter server process. Check error log if tr `train.log` It provides stderr and stdout of trainer process. Check error log if training crashes. -### Check Model Output +#### Check Model Output After one pass finished, model files will be written in `output` directory in node 0. `nodefile` in workspace indicates the node id of current cluster job. -## Cluster Training Using OpenMPI +### Cluster Training Using OpenMPI -### Prepare an OpenMPI cluster +#### Prepare an OpenMPI cluster Run the following command to start a 3-node MPI cluster and one "head" node. @@ -260,7 +260,7 @@ kubectl create -f mpi-nodes.yaml Then you can log in to every OpenMPI node using ssh without input any passwords. -### Launching Cluster Job +#### Launching Cluster Job Follow the steps to launch a PaddlePaddle training job in OpenMPI cluster:\ @@ -288,6 +288,6 @@ scp train.txt-00002 test.txt-00002 [node3IP]:/home/tutorial mpirun -hostfile machines -n 3 /home/tutorial/start_mpi_train.sh ``` -## Cluster Training Using Kubernetes +### Cluster Training Using Kubernetes The details can be found [here](../k8s/k8s_cn.md) From 2113d6ed728e0e20ff529a64424f5a05637698b9 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 31 Oct 2017 10:06:44 -0700 Subject: [PATCH 08/16] fix bug (#5233) --- python/paddle/v2/dataset/imdb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/dataset/imdb.py b/python/paddle/v2/dataset/imdb.py index 93dd3e8f7d..cfc1c886e1 100644 --- a/python/paddle/v2/dataset/imdb.py +++ b/python/paddle/v2/dataset/imdb.py @@ -116,7 +116,7 @@ def reader_creator(pos_pattern, neg_pattern, word_idx, buffer_size): yield [word_idx.get(w, UNK) for w in doc], i % 2 doc = qs[i % 2].get() - return reader() + return reader def train(word_idx): From ddde829a1ccf99cecd194fc27e008d49945e921a Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 31 Oct 2017 10:11:35 -0700 Subject: [PATCH 09/16] Fix/sequence pool (#5229) * "modify layers.py" * "fix pool interface" * "add export type to layers" * "fix based on comment" --- python/paddle/v2/framework/layers.py | 75 +++++++++++++++------------- python/paddle/v2/framework/nets.py | 9 +--- 2 files changed, 43 insertions(+), 41 deletions(-) diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 6451d11e2b..5fdad52f21 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -5,7 +5,8 @@ import re __all__ = [ 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', - 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'accuracy' + 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim', + 'batch_norm', 'accuracy' ] @@ -165,18 +166,6 @@ _create_op_func_('dropout') _create_op_func_('reshape') -def cast(x, data_type, program=None): - helper = LayerHelper('cast', **locals()) - out = helper.create_tmp_variable(dtype=data_type) - helper.append_op( - type='cast', - inputs={'X': [x]}, - outputs={'Out': [out]}, - attrs={'in_data_type': x.data_type, - 'out_data_type': out.data_type}) - return out - - def cast(x, data_type, program=None): helper = LayerHelper('cast', **locals()) out = helper.create_tmp_variable(dtype=data_type) @@ -191,9 +180,7 @@ def cast(x, data_type, program=None): def concat(input, axis, program=None, init_program=None): helper = LayerHelper('concat', **locals()) - if not isinstance(input, list) and not isinstance(input, tuple): - input = [input] - out = helper.create_tmp_variable(dtype=input[0].data_type) + out = helper.create_tmp_variable(dtype=helper.input_dtype()) helper.append_op( type='concat', inputs={'X': input}, @@ -202,6 +189,28 @@ def concat(input, axis, program=None, init_program=None): return out +def sums(input, program=None, init_program=None): + helper = LayerHelper('sum', **locals()) + out = helper.create_tmp_variable(dtype=helper.input_dtype()) + helper.append_op(type='sum', inputs={'X': [input]}, outputs={'Out': out}) + return out + + +def cos_sim(X, Y, program=None, init_program=None): + helper = LayerHelper('cos_sim', **locals()) + out = helper.create_tmp_variable(dtype=helper.input_dtype("X")) + xnorm = helper.create_tmp_variable(dtype=helper.input_dtype("X")) + ynorm = helper.create_tmp_variable(dtype=helper.input_dtype("X")) + helper.append_op( + type='cos_sim', + inputs={'X': [X], + 'Y': [Y]}, + outputs={'Out': [out], + 'XNorm': [xnorm], + 'YNorm': [ynorm]}) + return out, xnorm, ynorm + + def cross_entropy(input, label, **kwargs): helper = LayerHelper('cross_entropy', **kwargs) out = helper.create_tmp_variable(dtype=input.data_type) @@ -254,9 +263,7 @@ def accuracy(input, label, k=1, **kwargs): def sequence_conv(input, num_filters, - name=None, filter_size=3, - act=None, stride=1, padding=None, bias_attr=None, @@ -270,7 +277,7 @@ def sequence_conv(input, helper = LayerHelper('sequence_conv', **locals()) dtype = helper.input_dtype() - filter_shape = [num_filters, filter_size] + filter_shape = [filter_size * input.shape[1], num_filters] filter = helper.create_parameter( attr=helper.param_attr, shape=filter_shape, dtype=dtype) pre_bias = helper.create_tmp_variable(dtype) @@ -279,7 +286,7 @@ def sequence_conv(input, type='sequence_conv', inputs={ 'X': [input], - 'Filter': filter, + 'Filter': [filter], }, outputs={"Out": pre_bias}, attrs={ @@ -287,7 +294,6 @@ def sequence_conv(input, 'context_start': 0, 'context_length': filter_size }) - pre_act = helper.append_bias_op(pre_bias) return helper.append_activation(pre_act) @@ -344,31 +350,32 @@ def conv2d(input, return helper.append_activation(pre_act) -def sequence_pool(input, - pool_size, - pool_type, - pool_stride=1, - pool_padding=0, - global_pooling=False, - program=None, - init_program=None): +def sequence_pool(input, pool_type, program=None, init_program=None): # FIXME(dzh) : want to unify the argument of python layer # function. So we ignore some unecessary attributes - ENUM_POOL_TYPE = set(["max", "avg", "sqrt", "last", "first"]) - if pool_type not in ENUM_POOL_TYPE: + ENUM_POOL_TYPE = dict({ + "AVERAGE": 0, + "SUM": 1, + "SQRT": 2, + "MAX": 3, + "LAST": 4, + "FIRST": 5 + }) + if pool_type.upper() not in ENUM_POOL_TYPE: raise ValueError("Unknown pool_type: '%s'. It can only be %s.", - str(pool_type), " ".join(ENUM_POOL_TYPE)) + str(pool_type), " ".join(ENUM_POOL_TYPE.keys())) helper = LayerHelper('sequence_pool', **locals()) dtype = helper.input_dtype() pool_out = helper.create_tmp_variable(dtype) + # FIXME(dzh): strategy helper.append_op( type="sequence_pool", inputs={"X": [input]}, - outputs={"Out": pool_out}, - attrs={"strategy": pool_type}) + outputs={"Out": [pool_out]}, + attrs={"strategy": ENUM_POOL_TYPE[pool_type.upper()]}) return pool_out diff --git a/python/paddle/v2/framework/nets.py b/python/paddle/v2/framework/nets.py index a9998073e1..8191b5ef44 100644 --- a/python/paddle/v2/framework/nets.py +++ b/python/paddle/v2/framework/nets.py @@ -101,24 +101,19 @@ def img_conv_group(input, def sequence_conv_pool(input, num_filters, filter_size, - pool_size, - pool_stride, - act, + pool_type="max", program=None, init_program=None): conv_out = layers.sequence_conv( input=input, num_filters=num_filters, filter_size=filter_size, - act=act, program=program, init_program=init_program) pool_out = layers.sequence_pool( input=conv_out, - pool_size=pool_size, - pool_type='max', - pool_stride=pool_stride, + pool_type=pool_type, program=program, init_program=init_program) return pool_out From e41f28cbcd4c9ab04213a8548470e7c5d040c244 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Tue, 31 Oct 2017 10:40:57 -0700 Subject: [PATCH 10/16] Adding a framework for variable initializers (#5232) --- python/paddle/v2/framework/framework.py | 19 +-- python/paddle/v2/framework/initializer.py | 109 ++++++++++++++++++ python/paddle/v2/framework/layer_helper.py | 19 +-- python/paddle/v2/framework/layers.py | 26 ++--- .../tests/test_recognize_digits_mlp.py | 10 +- 5 files changed, 128 insertions(+), 55 deletions(-) create mode 100644 python/paddle/v2/framework/initializer.py diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index f8d2f67410..b3493fc378 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -354,8 +354,8 @@ class Block(object): def create_var(self, *args, **kwargs): var = Variable(self, *args, **kwargs) - if 'init_attr' in kwargs: - self._prepend_initialize_ops_(var, kwargs['init_attr']) + if 'initializer' in kwargs: + kwargs['initializer'](var, self) return var def has_var(self, name): @@ -364,8 +364,8 @@ class Block(object): def create_parameter(self, *args, **kwargs): global_block = self.program.global_block() param = Parameter(global_block, *args, **kwargs) - if 'init_attr' in kwargs: - self._prepend_initialize_ops_(param, kwargs['init_attr']) + if 'initializer' in kwargs: + kwargs['initializer'](param, self) return param def append_op(self, *args, **kwargs): @@ -424,17 +424,6 @@ class Block(object): for index in range(len(self.ops)): assert self.ops[index].desc == ops_in_cpp[index] - def _prepend_initialize_ops_(self, param, init_attr): - op_type = init_attr['type'] - init_attr['shape'] = param.shape - init_attr['data_type'] = int(param.data_type) - op = self.prepend_op( - type=op_type, - inputs=None, - outputs={'Out': [param]}, - attrs=init_attr) - param.op = op - class Program(object): def __init__(self): diff --git a/python/paddle/v2/framework/initializer.py b/python/paddle/v2/framework/initializer.py new file mode 100644 index 0000000000..377d332713 --- /dev/null +++ b/python/paddle/v2/framework/initializer.py @@ -0,0 +1,109 @@ +import paddle.v2.framework.framework as framework + +__all__ = ['ConstantInitializer', 'UniformInitializer'] + + +class Initializer(object): + """Base class for variable initializers + + Defines the common interface of variable initializers. + They add operations to the init program that are used + to initialize variables. Users should not use this class + directly, but need to use one of its implementations. + """ + + def __init_(self): + pass + + def __call__(self, param, block): + """Add corresponding initialization operations to the network + """ + raise NotImplementedError() + + +class ConstantInitializer(Initializer): + """Implements the constant initializer + """ + + def __init__(self, value=0.0): + """Constructor for ConstantInitializer + + Args: + value: constant value to initialize the variable + """ + assert value is not None + super(ConstantInitializer, self).__init__() + self._value = value + + def __call__(self, var, block): + """Add constant initialization ops for a variable + + Args: + var: Variable that needs to be initialized + block: The block in which initialization ops + should be added + + Returns: + the initialization op + """ + assert isinstance(var, framework.Variable) + assert isinstance(block, framework.Block) + # Initialization Ops should be prepended and not appended + op = block.prepend_op( + type="fill_constant", + outputs={"Out": var}, + attrs={ + "shape": var.shape, + "data_type": int(var.data_type), + "value": self._value + }) + var.op = op + return op + + +class UniformInitializer(Initializer): + """Implements for random uniform distribution initializer + """ + + def __init__(self, low=-1.0, high=1.0, seed=0): + """Constructor for UniformInitializer + + Args: + low: lower boundary of the uniform distribution + high: upper boundary of the uniform distribution + seed: random seed + """ + assert low is not None + assert high is not None + assert seed is not None + super(UniformInitializer, self).__init__() + self._low = low + self._high = high + self._seed = seed + + def __call__(self, var, block): + """Add uniform distribution initialization ops for a variable + + Args: + var: Variable that needs to be initialized + block: The block in which initialization ops + should be added + + Returns: + the initialization op + """ + assert isinstance(var, framework.Variable) + assert isinstance(block, framework.Block) + # Initialization Ops should be prepended and not appended + op = block.prepend_op( + type="uniform_random", + outputs={"Out": var}, + attrs={ + "shape": var.shape, + "data_type": int(var.data_type), + "min": self._low, + "max": self._high, + "seed": self._seed + }) + var.op = op + return op diff --git a/python/paddle/v2/framework/layer_helper.py b/python/paddle/v2/framework/layer_helper.py index d96dbe172c..c57776441c 100644 --- a/python/paddle/v2/framework/layer_helper.py +++ b/python/paddle/v2/framework/layer_helper.py @@ -5,6 +5,8 @@ import paddle.v2.framework.core as core from paddle.v2.framework.framework import Variable, g_program, \ g_init_program +from paddle.v2.framework.initializer import ConstantInitializer, \ + UniformInitializer def unique_name(prefix): @@ -66,14 +68,7 @@ class LayerHelper(object): @property def param_attr(self): - default = { - 'name': None, - 'init_attr': { - 'type': 'uniform_random', - 'min': -1.0, - 'max': 1.0 - } - } + default = {'name': None, 'initializer': UniformInitializer()} actual = self.kwargs.get('param_attr', None) if actual is None: actual = default @@ -83,13 +78,7 @@ class LayerHelper(object): return actual def bias_attr(self): - default = { - 'name': None, - 'init_attr': { - 'type': 'fill_constant', - 'value': 0.0 - } - } + default = {'name': None, 'initializer': ConstantInitializer()} bias_attr = self.kwargs.get('bias_attr', None) if bias_attr is True: bias_attr = default diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 5fdad52f21..dab72f0195 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -1,6 +1,7 @@ from paddle.v2.framework.layer_helper import LayerHelper, unique_name import paddle.v2.framework.core as core from paddle.v2.framework.framework import OpProtoHolder, Variable, Program +from paddle.v2.framework.initializer import ConstantInitializer import re __all__ = [ @@ -440,26 +441,12 @@ def batch_norm(input, else: raise ValueError("unsupported data layout:" + data_layout) - def get_init_attr(value): - if not isinstance(value, float): - raise ValueError("attr value should be a float") - return {'type': 'fill_constant', 'value': value} - - def prepend_init_op(var, init_attr): - assert isinstance(var, Variable) - op_type = init_attr['type'] - init_attr['shape'] = var.shape - init_attr['data_type'] = int(var.data_type) - op = var.block.prepend_op( - type=op_type, inputs=None, outputs={'Out': [var]}, attrs=init_attr) - return op - - def create_persistable_var(dtype, shape, init_attr=None): + def create_persistable_var(dtype, shape, initializer=None): name = unique_name(".".join([helper.name, "xxxx"])) var = init_program.global_block().create_var( dtype=dtype, shape=shape, name=name, persistable=True) - if 'init_attr' is not None: - prepend_init_op(var, init_attr) + if initializer is not None: + initializer(var, var.block) return program.global_block().create_var( name=name, dtype=dtype, shape=shape, persistable=True) @@ -472,8 +459,9 @@ def batch_norm(input, attr=helper.param_attr, shape=param_shape, dtype=dtype) # create input - mean = create_persistable_var(dtype, param_shape, get_init_attr(0.0)) - variance = create_persistable_var(dtype, param_shape, get_init_attr(1.0)) + mean = create_persistable_var(dtype, param_shape, ConstantInitializer(0.0)) + variance = create_persistable_var(dtype, param_shape, + ConstantInitializer(1.0)) # create output # mean and mean_out share the same memory diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py b/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py index a8a34b2a95..9916569d04 100644 --- a/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py +++ b/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py @@ -3,9 +3,10 @@ import paddle.v2.framework.layers as layers import paddle.v2.framework.core as core import paddle.v2.framework.optimizer as optimizer -from paddle.v2.framework.framework import Program, g_program +from paddle.v2.framework.framework import Program from paddle.v2.framework.executor import Executor from paddle.v2.framework.regularizer import L2DecayRegularizer +from paddle.v2.framework.initializer import UniformInitializer import numpy as np @@ -21,11 +22,8 @@ image = layers.data( param_attr = { 'name': None, - 'init_attr': { - 'type': 'uniform_random', - 'min': -1.0, - 'max': 1.0 - }, + 'initializer': UniformInitializer( + low=-1.0, high=1.0), 'regularization': L2DecayRegularizer(0.0005 * BATCH_SIZE) } From 9b65acd586f0c0cc246ca7a763912cb2ea502536 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 1 Nov 2017 02:48:45 +0800 Subject: [PATCH 11/16] memory log level change from 3 to 10 (#5231) --- paddle/memory/detail/buddy_allocator.cc | 55 +++++++++++++------------ paddle/memory/detail/meta_cache.cc | 2 +- paddle/memory/memory.cc | 17 ++++---- 3 files changed, 38 insertions(+), 36 deletions(-) diff --git a/paddle/memory/detail/buddy_allocator.cc b/paddle/memory/detail/buddy_allocator.cc index e212f7737a..64ee538038 100644 --- a/paddle/memory/detail/buddy_allocator.cc +++ b/paddle/memory/detail/buddy_allocator.cc @@ -27,11 +27,11 @@ BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator, system_allocator_(std::move(system_allocator)) {} BuddyAllocator::~BuddyAllocator() { - VLOG(3) << "BuddyAllocator Disconstructor makes sure that all of these " - "have actually been freed"; + VLOG(10) << "BuddyAllocator Disconstructor makes sure that all of these " + "have actually been freed"; while (!pool_.empty()) { auto block = static_cast(std::get<2>(*pool_.begin())); - VLOG(3) << "Free from block (" << block << ", " << max_chunk_size_ << ")"; + VLOG(10) << "Free from block (" << block << ", " << max_chunk_size_ << ")"; system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); cache_.invalidate(block); @@ -51,11 +51,12 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) { // acquire the allocator lock std::lock_guard lock(mutex_); - VLOG(3) << "Allocate " << unaligned_size << " bytes from chunk size " << size; + VLOG(10) << "Allocate " << unaligned_size << " bytes from chunk size " + << size; // if the allocation is huge, send directly to the system allocator if (size > max_chunk_size_) { - VLOG(3) << "Allocate from system allocator."; + VLOG(10) << "Allocate from system allocator."; return SystemAlloc(size); } @@ -70,9 +71,9 @@ void* BuddyAllocator::Alloc(size_t unaligned_size) { return nullptr; } } else { - VLOG(3) << "Allocation from existing memory block " << std::get<2>(*it) - << " at address " - << reinterpret_cast(std::get<2>(*it))->data(); + VLOG(10) << "Allocation from existing memory block " << std::get<2>(*it) + << " at address " + << reinterpret_cast(std::get<2>(*it))->data(); } total_used_ += size; @@ -89,10 +90,10 @@ void BuddyAllocator::Free(void* p) { // Acquire the allocator lock std::lock_guard lock(mutex_); - VLOG(3) << "Free from address " << block; + VLOG(10) << "Free from address " << block; if (block->type(cache_) == MemoryBlock::HUGE_CHUNK) { - VLOG(3) << "Free directly from system allocator"; + VLOG(10) << "Free directly from system allocator"; system_allocator_->Free(block, block->total_size(cache_), block->index(cache_)); @@ -109,8 +110,8 @@ void BuddyAllocator::Free(void* p) { // Trying to merge the right buddy if (block->has_right_buddy(cache_)) { - VLOG(3) << "Merging this block " << block << " with its right buddy " - << block->right_buddy(cache_); + VLOG(10) << "Merging this block " << block << " with its right buddy " + << block->right_buddy(cache_); auto right_buddy = block->right_buddy(cache_); @@ -127,8 +128,8 @@ void BuddyAllocator::Free(void* p) { // Trying to merge the left buddy if (block->has_left_buddy(cache_)) { - VLOG(3) << "Merging this block " << block << " with its left buddy " - << block->left_buddy(cache_); + VLOG(10) << "Merging this block " << block << " with its left buddy " + << block->left_buddy(cache_); auto left_buddy = block->left_buddy(cache_); @@ -144,8 +145,8 @@ void BuddyAllocator::Free(void* p) { } // Dumping this block into pool - VLOG(3) << "Inserting free block (" << block << ", " - << block->total_size(cache_) << ")"; + VLOG(10) << "Inserting free block (" << block << ", " + << block->total_size(cache_) << ")"; pool_.insert( IndexSizeAddress(block->index(cache_), block->total_size(cache_), block)); @@ -164,7 +165,7 @@ void* BuddyAllocator::SystemAlloc(size_t size) { size_t index = 0; void* p = system_allocator_->Alloc(index, size); - VLOG(3) << "Allocated " << p << " from system allocator."; + VLOG(10) << "Allocated " << p << " from system allocator."; if (p == nullptr) return nullptr; @@ -190,8 +191,8 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { if (p == nullptr) return pool_.end(); - VLOG(3) << "Creating and inserting new block " << p - << " from system allocator"; + VLOG(10) << "Creating and inserting new block " << p + << " from system allocator"; static_cast(p)->init(cache_, MemoryBlock::FREE_CHUNK, index, max_chunk_size_, nullptr, nullptr); @@ -235,19 +236,19 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it, auto block = static_cast(std::get<2>(*it)); pool_.erase(it); - VLOG(3) << "Split block (" << block << ", " << block->total_size(cache_) - << ") into"; + VLOG(10) << "Split block (" << block << ", " << block->total_size(cache_) + << ") into"; block->split(cache_, size); - VLOG(3) << "Left block (" << block << ", " << block->total_size(cache_) - << ")"; + VLOG(10) << "Left block (" << block << ", " << block->total_size(cache_) + << ")"; block->set_type(cache_, MemoryBlock::ARENA_CHUNK); // the rest of memory if exist if (block->has_right_buddy(cache_)) { if (block->right_buddy(cache_)->type(cache_) == MemoryBlock::FREE_CHUNK) { - VLOG(3) << "Insert right block (" << block->right_buddy(cache_) << ", " - << block->right_buddy(cache_)->total_size(cache_) << ")"; + VLOG(10) << "Insert right block (" << block->right_buddy(cache_) << ", " + << block->right_buddy(cache_)->total_size(cache_) << ")"; pool_.insert( IndexSizeAddress(block->right_buddy(cache_)->index(cache_), @@ -274,7 +275,7 @@ void BuddyAllocator::CleanIdleFallBackAlloc() { return; } - VLOG(3) << "Return block " << block << " to fallback allocator."; + VLOG(10) << "Return block " << block << " to fallback allocator."; system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); cache_.invalidate(block); @@ -310,7 +311,7 @@ void BuddyAllocator::CleanIdleNormalAlloc() { MemoryBlock* block = static_cast(std::get<2>(*pool)); - VLOG(3) << "Return block " << block << " to base allocator."; + VLOG(10) << "Return block " << block << " to base allocator."; system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); cache_.invalidate(block); diff --git a/paddle/memory/detail/meta_cache.cc b/paddle/memory/detail/meta_cache.cc index f0721c3b94..7e2f92b00c 100644 --- a/paddle/memory/detail/meta_cache.cc +++ b/paddle/memory/detail/meta_cache.cc @@ -30,7 +30,7 @@ Metadata MetadataCache::load(const MemoryBlock* block) { return existing_metadata->second; } else { auto* meta = reinterpret_cast(block); - VLOG(3) << "Load MetaData type=" << meta->type; + VLOG(10) << "Load MetaData type=" << meta->type; PADDLE_ASSERT(meta->check_guards()); return *reinterpret_cast(block); } diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 0b648642f9..5eb1c44eb6 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -39,15 +39,15 @@ BuddyAllocator* GetCPUBuddyAllocator() { template <> void* Alloc(platform::CPUPlace place, size_t size) { - VLOG(3) << "Allocate " << size << " bytes on " << platform::Place(place); + VLOG(10) << "Allocate " << size << " bytes on " << platform::Place(place); void* p = GetCPUBuddyAllocator()->Alloc(size); - VLOG(3) << " pointer=" << p; + VLOG(10) << " pointer=" << p; return p; } template <> void Free(platform::CPUPlace place, void* p) { - VLOG(3) << "Free pointer=" << p << " on " << platform::Place(place); + VLOG(10) << "Free pointer=" << p << " on " << platform::Place(place); GetCPUBuddyAllocator()->Free(p); } @@ -69,11 +69,12 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { platform::GpuMinChunkSize(), platform::GpuMaxChunkSize()); } - VLOG(3) << "\n\nNOTE: each GPU device use " - << FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n" - << "You can set environment variable '" - << platform::kEnvFractionGpuMemoryToUse - << "' to change the fraction of GPU usage.\n\n"; + VLOG(10) << "\n\nNOTE: each GPU device use " + << FLAGS_fraction_of_gpu_memory_to_use * 100 + << "% of GPU memory.\n" + << "You can set environment variable '" + << platform::kEnvFractionGpuMemoryToUse + << "' to change the fraction of GPU usage.\n\n"; } platform::SetDeviceId(gpu_id); return as[gpu_id]; From f354bd98610f184a11f22235d434ceb7bef3811e Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 31 Oct 2017 12:03:07 -0700 Subject: [PATCH 12/16] AddBiasOp does not care num_flatten_dims (#5200) * AddBiasOp does not care num_flatten_dims * Add comments --- python/paddle/v2/framework/layer_helper.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/framework/layer_helper.py b/python/paddle/v2/framework/layer_helper.py index c57776441c..45d9cf3f48 100644 --- a/python/paddle/v2/framework/layer_helper.py +++ b/python/paddle/v2/framework/layer_helper.py @@ -142,8 +142,24 @@ class LayerHelper(object): return self.program.global_block().create_var( *args, persistable=False, **kwargs) - def append_bias_op(self, input_var): - size = list(input_var.shape[1:]) + def append_bias_op(self, input_var, num_flatten_dims=None): + """ + Append bias operator and return its output. If the user does not set + bias_attr, append_bias_op will return input_var + + :param input_var: the input variable. The len(input_var.shape) is larger + or equal than 2. + :param num_flatten_dims: The input tensor will be flatten as a matrix + when adding bias. + `matrix.shape = product(input_var.shape[0:num_flatten_dims]), product( + input_var.shape[num_flatten_dims:])` + """ + if num_flatten_dims is None: + num_flatten_dims = self.kwargs.get('num_flatten_dims', None) + if num_flatten_dims is None: + num_flatten_dims = 1 + + size = list(input_var.shape[num_flatten_dims:]) bias_attr = self.bias_attr() if not bias_attr: return input_var From b720f282b10fbb0baec226b841374c377eaba7f5 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 1 Nov 2017 00:05:49 -0700 Subject: [PATCH 13/16] deconv modify --- paddle/operators/conv2dtranspose_cudnn_op.cc | 8 ++++---- paddle/operators/conv2dtranspose_cudnn_op.cu | 8 +++----- .../paddle/v2/framework/tests/test_conv2dtranspose_op.py | 5 ++--- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cc b/paddle/operators/conv2dtranspose_cudnn_op.cc index 72c470389c..4f05364550 100644 --- a/paddle/operators/conv2dtranspose_cudnn_op.cc +++ b/paddle/operators/conv2dtranspose_cudnn_op.cc @@ -38,13 +38,13 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv2dtranspose_cudnn, ops::Conv2DTransposeOp, - ops::CudnnConv2DTransposeOpMaker, conv2dtranspose_cudnn_grad, +REGISTER_OP(conv2d_transpose_cudnn, ops::Conv2DTransposeOp, + ops::CudnnConv2DTransposeOpMaker, conv2d_transpose_cudnn_grad, ops::Conv2DTransposeOpGrad); REGISTER_OP_CPU_KERNEL( - conv2dtranspose_cudnn, + conv2d_transpose_cudnn, ops::GemmConv2DTransposeKernel); REGISTER_OP_CPU_KERNEL( - conv2dtranspose_cudnn_grad, + conv2d_transpose_cudnn_grad, ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cu b/paddle/operators/conv2dtranspose_cudnn_op.cu index 8485bc65eb..1ec370a556 100644 --- a/paddle/operators/conv2dtranspose_cudnn_op.cu +++ b/paddle/operators/conv2dtranspose_cudnn_op.cu @@ -15,7 +15,7 @@ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/memory/memory.h" -#include "paddle/operators/conv2d_op.h" +#include "paddle/operators/conv2dtranspose_op.h" #include "paddle/platform/assert.h" #include "paddle/platform/cudnn_helper.h" @@ -76,7 +76,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { workspace_size_limit = user_workspace_size * 1024 * 1024; } // ------------------- cudnn conv algorithm --------------------- - // cudnnConvolutionBwdAlgo_t algo; cudnnConvolutionBwdDataAlgo_t algo; auto handle = ctx.cuda_device_context().cudnn_handle(); // Get the algorithm @@ -92,7 +91,6 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, cudnn_output_desc, algo, &workspace_size_in_bytes)); - // workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); // Allocate on GPU memory platform::GPUPlace gpu = boost::get(ctx.GetPlace()); @@ -234,7 +232,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn, +REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, ops::CudnnConvTransposeOpKernel); -REGISTER_OP_GPU_KERNEL(conv2dtranspose_cudnn_grad, +REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, ops::CudnnConvTransposeGradOpKernel); diff --git a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py index 4ed6e0bcc4..0744370813 100644 --- a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py +++ b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py @@ -45,13 +45,12 @@ class TestConv2dTransposeOp(OpTest): filter_ = np.random.random(self.filter_size).astype("float32") output = conv2dtranspose_forward_naive( input_, filter_, conv2dtranspose_param).astype('float32') - # print 'deconv output py', output, output.shape self.inputs = {'Input': input_, 'Filter': filter_} self.attrs = { 'strides': self.stride, 'paddings': self.pad, - # 'dilations': self.dilations + 'dilations': self.dilations } self.outputs = {'Output': output} @@ -91,7 +90,7 @@ class TestConv2dTransposeOp(OpTest): class TestCudnn(TestConv2dTransposeOp): def init_op_type(self): - self.op_type = "conv2dtranspose_cudnn" + self.op_type = "conv2d_transpose_cudnn" if __name__ == '__main__': From 31187e7e7265f67e3b2ca67900b07242ad443b68 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 1 Nov 2017 11:47:09 -0700 Subject: [PATCH 14/16] deconv fix --- ...nspose_cudnn_op.cc => conv2d_transpose_cudnn_op.cc} | 2 +- ...nspose_cudnn_op.cu => conv2d_transpose_cudnn_op.cu} | 2 +- .../{conv2dtranspose_op.cc => conv2d_transpose_op.cc} | 10 +++++----- .../{conv2dtranspose_op.cu => conv2d_transpose_op.cu} | 6 +++--- .../{conv2dtranspose_op.h => conv2d_transpose_op.h} | 2 +- ...nv2dtranspose_op.py => test_conv2d_transpose_op.py} | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) rename paddle/operators/{conv2dtranspose_cudnn_op.cc => conv2d_transpose_cudnn_op.cc} (97%) rename paddle/operators/{conv2dtranspose_cudnn_op.cu => conv2d_transpose_cudnn_op.cu} (99%) rename paddle/operators/{conv2dtranspose_op.cc => conv2d_transpose_op.cc} (95%) rename paddle/operators/{conv2dtranspose_op.cu => conv2d_transpose_op.cu} (89%) rename paddle/operators/{conv2dtranspose_op.h => conv2d_transpose_op.h} (99%) rename python/paddle/v2/framework/tests/{test_conv2dtranspose_op.py => test_conv2d_transpose_op.py} (98%) diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cc b/paddle/operators/conv2d_transpose_cudnn_op.cc similarity index 97% rename from paddle/operators/conv2dtranspose_cudnn_op.cc rename to paddle/operators/conv2d_transpose_cudnn_op.cc index 4f05364550..8ce94e0f04 100644 --- a/paddle/operators/conv2dtranspose_cudnn_op.cc +++ b/paddle/operators/conv2d_transpose_cudnn_op.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv2dtranspose_op.h" +#include "paddle/operators/conv2d_transpose_op.h" namespace paddle { namespace operators { diff --git a/paddle/operators/conv2dtranspose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu similarity index 99% rename from paddle/operators/conv2dtranspose_cudnn_op.cu rename to paddle/operators/conv2d_transpose_cudnn_op.cu index 1ec370a556..3844d9ad25 100644 --- a/paddle/operators/conv2dtranspose_cudnn_op.cu +++ b/paddle/operators/conv2d_transpose_cudnn_op.cu @@ -15,7 +15,7 @@ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/memory/memory.h" -#include "paddle/operators/conv2dtranspose_op.h" +#include "paddle/operators/conv2d_transpose_op.h" #include "paddle/platform/assert.h" #include "paddle/platform/cudnn_helper.h" diff --git a/paddle/operators/conv2dtranspose_op.cc b/paddle/operators/conv2d_transpose_op.cc similarity index 95% rename from paddle/operators/conv2dtranspose_op.cc rename to paddle/operators/conv2d_transpose_op.cc index c1b231906e..348527728b 100644 --- a/paddle/operators/conv2dtranspose_op.cc +++ b/paddle/operators/conv2d_transpose_op.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv2dtranspose_op.h" +#include "paddle/operators/conv2d_transpose_op.h" namespace paddle { namespace operators { @@ -95,13 +95,13 @@ void Conv2DTransposeOpGrad::InferShape( } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp, - ops::Conv2DTransposeOpMaker, conv2dtranspose_grad, +REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp, + ops::Conv2DTransposeOpMaker, conv2d_transpose_grad, ops::Conv2DTransposeOpGrad); REGISTER_OP_CPU_KERNEL( - conv2dtranspose, + conv2d_transpose, ops::GemmConv2DTransposeKernel); REGISTER_OP_CPU_KERNEL( - conv2dtranspose_grad, + conv2d_transpose_grad, ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.cu b/paddle/operators/conv2d_transpose_op.cu similarity index 89% rename from paddle/operators/conv2dtranspose_op.cu rename to paddle/operators/conv2d_transpose_op.cu index 761bc1959e..931ac9eed2 100644 --- a/paddle/operators/conv2dtranspose_op.cu +++ b/paddle/operators/conv2d_transpose_op.cu @@ -12,13 +12,13 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv2dtranspose_op.h" +#include "paddle/operators/conv2d_transpose_op.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - conv2dtranspose, + conv2d_transpose, ops::GemmConv2DTransposeKernel); REGISTER_OP_GPU_KERNEL( - conv2dtranspose_grad, + conv2d_transpose_grad, ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.h b/paddle/operators/conv2d_transpose_op.h similarity index 99% rename from paddle/operators/conv2dtranspose_op.h rename to paddle/operators/conv2d_transpose_op.h index 8c70b3dcec..cab7788227 100644 --- a/paddle/operators/conv2dtranspose_op.h +++ b/paddle/operators/conv2d_transpose_op.h @@ -62,7 +62,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); // TODO(Zhuoyuan): Paddings can be added in future. - // groups will alway be disabled in conv2dtranspose. + // groups will alway be disabled in conv2d_transpose. const int batch_size = input->dims()[0]; const int m = input->dims()[1]; diff --git a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py b/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py similarity index 98% rename from python/paddle/v2/framework/tests/test_conv2dtranspose_op.py rename to python/paddle/v2/framework/tests/test_conv2d_transpose_op.py index 0744370813..999a0bdc62 100644 --- a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_transpose_op.py @@ -67,7 +67,7 @@ class TestConv2dTransposeOp(OpTest): self.filter_size = [f_c, 6, 3, 3] def init_op_type(self): - self.op_type = "conv2dtranspose" + self.op_type = "conv2d_transpose" def test_check_grad_no_input(self): self.check_grad( From 2d956b82cd1d067c3b185423e6d13b0aab0dffb0 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 1 Nov 2017 17:15:07 -0700 Subject: [PATCH 15/16] deconv cudnn --- paddle/operators/conv2d_transpose_cudnn_op.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu index 3844d9ad25..5a286897e0 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cu +++ b/paddle/operators/conv2d_transpose_cudnn_op.cu @@ -29,7 +29,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using DataLayout = platform::DataLayout; using CUDADeviceContext = platform::CUDADeviceContext; -static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; +static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024; template class CudnnConvTransposeOpKernel : public framework::OpKernel { @@ -71,7 +71,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { // ------------------- cudnn conv workspace --------------------- void* cudnn_workspace = nullptr; size_t workspace_size_in_bytes; // final workspace to allocate. - size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes; if (user_workspace_size > 0) { workspace_size_limit = user_workspace_size * 1024 * 1024; } @@ -125,6 +125,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); + // cudnn v5 does not support dilations std::vector dilations = ctx.Attr>("dilations"); int user_workspace_size = ctx.Attr("workspace_size_MB"); @@ -153,7 +154,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { cudnnConvolutionBwdFilterAlgo_t filter_algo; size_t bwd_filter_ws_size, fwd_ws_size; size_t workspace_size_in_bytes = 0; - size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; + size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes; if (user_workspace_size > 0) { workspace_size_limit = user_workspace_size * 1024 * 1024; } From 0efac253d340b22999407d387a4c2098cb5581c2 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 1 Nov 2017 17:16:53 -0700 Subject: [PATCH 16/16] deconv small fix --- paddle/operators/conv2d_transpose_cudnn_op.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu index 5a286897e0..61fcfb3bd8 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cu +++ b/paddle/operators/conv2d_transpose_cudnn_op.cu @@ -43,6 +43,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel { std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); + // cudnn v5 does not support dilations std::vector dilations = ctx.Attr>("dilations"); int user_workspace_size = ctx.Attr("workspace_size_MB");