From 3772d27dfbf83b22333b7cc0eacfb3acd805c036 Mon Sep 17 00:00:00 2001 From: zlx Date: Mon, 22 Jan 2018 21:09:41 +0800 Subject: [PATCH 01/29] add depthwise conv forward --- paddle/operators/conv_op.cc | 7 + paddle/operators/conv_op.cu.cc | 5 + paddle/operators/conv_op.h | 30 ++ paddle/operators/math/depthwise_conv.cu | 347 ++++++++++++++++++++++++ paddle/operators/math/depthwise_conv.h | 57 ++++ 5 files changed, 446 insertions(+) create mode 100644 paddle/operators/math/depthwise_conv.cu create mode 100644 paddle/operators/math/depthwise_conv.h diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index d6882b275b..55a78efea1 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -318,9 +318,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( namespace ops = paddle::operators; REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, ops::ConvOpGrad); +REGISTER_OP(depthwiseConv, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, + ops::ConvOpGrad); REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, ops::ConvOpGrad); +REGISTER_OP_CPU_KERNEL( + depthwiseConv, + ops::DepthwiseConvKernel, + ops::DepthwiseConvKernel); + REGISTER_OP_CPU_KERNEL( conv2d, ops::GemmConvKernel, ops::GemmConvKernel); diff --git a/paddle/operators/conv_op.cu.cc b/paddle/operators/conv_op.cu.cc index 4f942444f3..4c7a345784 100644 --- a/paddle/operators/conv_op.cu.cc +++ b/paddle/operators/conv_op.cu.cc @@ -16,6 +16,11 @@ limitations under the License. */ namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + depthwiseConv, + ops::DepthwiseConvKernel, + ops::DepthwiseConvKernel); + REGISTER_OP_CUDA_KERNEL( conv2d, ops::GemmConvKernel, ops::GemmConvKernel); diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 5a8933e791..ca61f1c6e6 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/depthwise_conv.h" #include "paddle/operators/math/im2col.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/vol2col.h" @@ -350,5 +351,34 @@ class GemmConvGradKernel : public framework::OpKernel { } } }; + +template +class DepthwiseConvKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + Tensor filter = *context.Input("Filter"); + Tensor* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + + math::DepthwiseConvFunctor depthwiseConv; + + auto& dev_ctx = context.template device_context(); + depthwiseConv(dev_ctx, input, filter, filter_shape_vec, strides, paddings, + output); + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/depthwise_conv.cu b/paddle/operators/math/depthwise_conv.cu new file mode 100644 index 0000000000..16a0037ab1 --- /dev/null +++ b/paddle/operators/math/depthwise_conv.cu @@ -0,0 +1,347 @@ +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/pooling.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +// CUDA kernel to compute the depthwise convolution forward pass +template +__global__ void KernelDepthwiseConv( + const int nthreads, const T* const input_data, const T* const filter_data, + const int batch_size, const int output_channels, const int output_height, + const int output_width, const int input_channels, const int input_height, + const int input_width, const int filter_multiplier, const int filter_height, + const int filter_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width, T* const output_data) { + int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + + if (index < nthreads) { + const int batch = index / output_channels / output_height / output_width; + const int c_out = (index / output_height / output_width) % output_channels; + const int h_out = (index / output_width) % output_height; + const int w_out = index % output_width; + + const int c_in = c_out / filter_multiplier; + const T* weight = filter_data + c_out * filter_height * filter_width; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = + -padding_height + h_out * stride_height + filter_height - 1; + const int w_in_end = + -padding_width + w_out * stride_width + filter_width - 1; + if ((h_in_start >= 0) && (h_in_end < input_height) && (w_in_start >= 0) && + (w_in_end < input_width)) { + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + const int h_in = -padding_height + h_out * stride_height + kh; + const int w_in = -padding_width + w_out * stride_width + kw; + const int offset = + ((batch * input_channels + c_in) * input_height + h_in) * + input_width + + w_in; + value += (*weight) * input_data[offset]; + ++weight; + } + } + } else { + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + const int h_in = -padding_height + h_out * stride_height + kh; + const int w_in = -padding_width + w_out * stride_width + kw; + if ((h_in >= 0) && (h_in < input_height) && (w_in >= 0) && + (w_in < input_width)) { + const int offset = + ((batch * input_channels + c_in) * input_height + h_in) * + input_width + + w_in; + value += (*weight) * input_data[offset]; + } + ++weight; + } + } + } + output_data[index] = value; + } +} +/* +// CUDA kernel to compute the depthwise convolution backprop w.r.t input. +template +__global__ void KernelDepthwiseConvInputGrad(const int nthreads, + const T* const top_diff, + const T* const weight_data, + const int num, + const int outputChannels, + const int outputHeight, + const int outputWidth, + const int inputChannels, + const int inputHeight, + const int inputWidth, + const int filterMultiplier, + const int filterHeight, + const int filterWidth, + const int strideH, + const int strideW, + const int paddingH, + const int paddingW, + T* const bottom_diff) { + int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < nthreads) { + const int batch = index / inputChannels / inputHeight / inputWidth; + const int c_in = (index / inputHeight / inputWidth) % inputChannels; + const int h_in = (index / inputWidth) % inputHeight; + const int w_in = index % inputWidth; + + const int c_out_start = c_in * filterMultiplier; + + int h_out_start = (h_in - filterHeight + paddingH + strideH) / strideH; + h_out_start = 0 > h_out_start ? 0 : h_out_start; + int h_out_end = (h_in + paddingH) / strideH; + h_out_end = outputHeight - 1 < h_out_end ? outputHeight - 1 : h_out_end; + int w_out_start = (w_in - filterWidth + paddingW + strideW) / strideW; + w_out_start = 0 > w_out_start ? 0 : w_out_start; + int w_out_end = (w_in + paddingW) / strideW; + w_out_end = outputWidth - 1 < w_out_end ? outputWidth - 1 : w_out_end; + + T value = 0; + + for (int c_out = c_out_start; c_out < c_out_start + filterMultiplier; + c_out++) { + for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) { + const int filter_h = h_in + paddingH - h_out * strideH; + for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) { + const int filter_w = w_in + paddingW - w_out * strideW; + const int filter_offset = c_out * filterHeight * filterWidth + + filter_h * filterWidth + filter_w; + const int top_diff_offset = + ((batch * outputChannels + c_out) * outputHeight + h_out) * + outputWidth + + w_out; + value += top_diff[top_diff_offset] * weight_data[filter_offset]; + } + } + } + bottom_diff[index] += value; + } +} + +// CUDA kernel to compute the depthwise convolution backprop w.r.t filter. +template +__global__ void KernelDepthwiseConvFilterGrad(const int num_i, + const int nthreads, + const T* const top_diff, + const T* const inputData, + const int num, + const int outputChannels, + const int outputHeight, + const int outputWidth, + const int inputChannels, + const int inputHeight, + const int inputWidth, + const int filterMultiplier, + const int filterHeight, + const int filterWidth, + const int strideH, + const int strideW, + const int paddingH, + const int paddingW, + T* const buffer_data) { + int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < nthreads) { + const int h_out = (index / outputWidth) % outputHeight; + const int w_out = index % outputWidth; + const int kh = + (index / filterWidth / outputHeight / outputWidth) % filterHeight; + const int kw = (index / outputHeight / outputWidth) % filterWidth; + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && + (w_in < inputWidth)) { + const int c_out = + index / (filterHeight * filterWidth * outputHeight * outputWidth); + const int c_in = c_out / filterMultiplier; + const int batch = num_i; + const int top_offset = + ((batch * outputChannels + c_out) * outputHeight + h_out) * + outputWidth + w_out; + const int bottom_offset = + ((batch * inputChannels + c_in) * inputHeight + h_in) * inputWidth + + w_in; + buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset]; + } else { + buffer_data[index] = 0; + } + } +} +*/ + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class DepthwiseConvFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& filter, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + const T* filter_data = filter.data(); + T* output_data = output->mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelDepthwiseConv<<>>( + nthreads, input_data, filter_data, batch_size, output_channels, + output_height, output_width, input_channels, input_height, input_width, + output_channels / input_channels, ksize_height, ksize_width, + stride_height, stride_width, padding_height, padding_width, + output_data); + } +}; + +/* + +template +class DepthwiseConvInputGradFunctor +{ + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + + int nthreads = batch_size * input_channels * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelPool2DGrad<<>>( + nthreads, input_data, output_data, output_grad_data, input_channels, + input_height, input_width, output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, padding_height, padding_width, + pool_process, input_grad_data); + } +}; + +template +class DepthwiseConvdFilterGradFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2DGrad<<>>( + nthreads, input_data, output_data, output_grad_data, input_channels, + input_height, input_width, output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, padding_height, padding_width, + input_grad_data); + } +}; +*/ + +template class DepthwiseConvFunctor, + float>; + +/* +template class DepthwiseConvInputGradFunctor, + float>; +template class DepthwiseConvFilterGradFunctor, + float>; + +template class DepthwiseConvFunctor, double>; +template class DepthwiseConvInputGradFunctor, + double>; +template class DepthwiseConvFilterGradFunctor, + double>; +*/ + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/depthwise_conv.h b/paddle/operators/math/depthwise_conv.h new file mode 100644 index 0000000000..2e48fe5912 --- /dev/null +++ b/paddle/operators/math/depthwise_conv.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class DepthwiseConvFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& input, + const framework::Tensor& filter, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output); +}; + +/* +template +class DepthwiseConvInputGradFunctor { +public: + void operator()(const DeviceContext& context, + const framework::Tensor& filter, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); +}; + +template +class DepthwiseConvFilterGradFunctor { +public: + void operator()(const DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* filter_grad); +}; +*/ + +} // namespace math +} // namespace operators +} // namespace paddle From 06db70384397a4d5b61cd7493ebab9b06faf3244 Mon Sep 17 00:00:00 2001 From: xzl Date: Tue, 23 Jan 2018 14:22:01 +0800 Subject: [PATCH 02/29] ../../../../../paddle/api --- paddle/operators/CMakeLists.txt | 3 ++- paddle/operators/conv_op.cc | 11 ++++++++--- paddle/operators/conv_op.h | 7 ++----- paddle/operators/math/CMakeLists.txt | 1 + paddle/operators/math/depthwise_conv.cu | 18 ++++++------------ 5 files changed, 19 insertions(+), 21 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 6745a8da17..fa2f8caacf 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -155,7 +155,8 @@ op_library(parallel_do_op DEPS executor) # Regist multiple Kernel to pybind if (WITH_GPU) -op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col) +op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS + vol2col depthwise_conv) op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling) op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc conv_transpose_cudnn_op.cu.cc DEPS vol2col) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 55a78efea1..a53b11615c 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -318,15 +318,20 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( namespace ops = paddle::operators; REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, ops::ConvOpGrad); -REGISTER_OP(depthwiseConv, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, +REGISTER_OP(depthwiseConv, ops::ConvOp, ops::Conv2DOpMaker, depthwiseConv_grad, ops::ConvOpGrad); REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, ops::ConvOpGrad); REGISTER_OP_CPU_KERNEL( depthwiseConv, - ops::DepthwiseConvKernel, - ops::DepthwiseConvKernel); + ops::GemmConvKernel, + ops::GemmConvKernel); + +REGISTER_OP_CPU_KERNEL( + depthwiseConv_grad, + ops::GemmConvGradKernel, + ops::GemmConvGradKernel); REGISTER_OP_CPU_KERNEL( conv2d, ops::GemmConvKernel, diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index ca61f1c6e6..a9138dbf93 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -364,18 +364,15 @@ class DepthwiseConvKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); + std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; - filter.Resize(filter_matrix_shape); - math::DepthwiseConvFunctor depthwiseConv; auto& dev_ctx = context.template device_context(); - depthwiseConv(dev_ctx, input, filter, filter_shape_vec, strides, paddings, + depthwiseConv(dev_ctx, *input, filter, ksize, strides, paddings, output); } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index c607704efa..6fb1531236 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -8,6 +8,7 @@ if(WITH_GPU) nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) + nv_library(depthwise_conv SRCS depthwise_conv.cu DEPS device_context) nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) diff --git a/paddle/operators/math/depthwise_conv.cu b/paddle/operators/math/depthwise_conv.cu index 16a0037ab1..aee052d379 100644 --- a/paddle/operators/math/depthwise_conv.cu +++ b/paddle/operators/math/depthwise_conv.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/math/pooling.h" +#include "paddle/operators/math/depthwise_conv.h" #include "paddle/platform/cuda_helper.h" namespace paddle { @@ -195,7 +195,7 @@ __global__ void KernelDepthwiseConvFilterGrad(const int num_i, * Ksize, strides, paddings are two elements. These two elements represent * height and width, respectively. */ -template +template class DepthwiseConvFunctor { public: void operator()(const platform::CUDADeviceContext& context, @@ -226,7 +226,7 @@ class DepthwiseConvFunctor { dim3 threads(1024, 1); dim3 grid(blocks, 1); - KernelDepthwiseConv<<>>( + KernelDepthwiseConv<<>>( nthreads, input_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, output_channels / input_channels, ksize_height, ksize_width, @@ -236,7 +236,6 @@ class DepthwiseConvFunctor { }; /* - template class DepthwiseConvInputGradFunctor { @@ -254,8 +253,7 @@ class DepthwiseConvInputGradFunctor const int output_height = output.dims()[2]; const int output_width = output.dims()[3]; const int ksize_height = ksize[0]; - const int ksize_width = ksize[1]; - const int stride_height = strides[0]; + const int ksize_width = ksize[1]; const int stride_height = strides[0]; const int stride_width = strides[1]; const int padding_height = paddings[0]; const int padding_width = paddings[1]; @@ -321,24 +319,20 @@ class DepthwiseConvdFilterGradFunctor { */ template class DepthwiseConvFunctor, float>; +template class DepthwiseConvFunctor; /* template class DepthwiseConvInputGradFunctor, float>; template class DepthwiseConvFilterGradFunctor, float>; template class DepthwiseConvFunctor, double>; template class DepthwiseConvInputGradFunctor, double>; template class DepthwiseConvFilterGradFunctor, double>; */ From 6e17babe49a7fdeb4f345c83d347f217d05e7e77 Mon Sep 17 00:00:00 2001 From: xzl Date: Tue, 30 Jan 2018 19:05:53 +0800 Subject: [PATCH 03/29] More efficient, add check on python side --- paddle/operators/CMakeLists.txt | 1 - paddle/operators/math/depthwise_conv.cu | 52 ++++++++++++------------- python/paddle/v2/fluid/layers/nn.py | 3 +- 3 files changed, 26 insertions(+), 30 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 8b442af45b..f7d600414f 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -159,7 +159,6 @@ if (WITH_GPU) op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col depthwise_conv) -# op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col) op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function) op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling) op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc diff --git a/paddle/operators/math/depthwise_conv.cu b/paddle/operators/math/depthwise_conv.cu index 23e26e8827..4aa38151e6 100644 --- a/paddle/operators/math/depthwise_conv.cu +++ b/paddle/operators/math/depthwise_conv.cu @@ -46,16 +46,18 @@ __global__ void KernelDepthwiseConv( -padding_height + h_out * stride_height + filter_height - 1; const int w_in_end = -padding_width + w_out * stride_width + filter_width - 1; + + const int in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + if ((h_in_start >= 0) && (h_in_end < input_height) && (w_in_start >= 0) && (w_in_end < input_width)) { for (int kh = 0; kh < filter_height; ++kh) { for (int kw = 0; kw < filter_width; ++kw) { - const int h_in = -padding_height + h_out * stride_height + kh; - const int w_in = -padding_width + w_out * stride_width + kw; - const int offset = - ((batch * input_channels + c_in) * input_height + h_in) * - input_width + - w_in; + const int h_in = h_in_start + kh; + const int w_in = w_in_start + kw; + const int offset = in_offset + h_in * input_width + w_in; + value += (*weight) * input_data[offset]; ++weight; } @@ -63,14 +65,11 @@ __global__ void KernelDepthwiseConv( } else { for (int kh = 0; kh < filter_height; ++kh) { for (int kw = 0; kw < filter_width; ++kw) { - const int h_in = -padding_height + h_out * stride_height + kh; - const int w_in = -padding_width + w_out * stride_width + kw; + const int h_in = h_in_start + kh; + const int w_in = w_in_start + kw; if ((h_in >= 0) && (h_in < input_height) && (w_in >= 0) && (w_in < input_width)) { - const int offset = - ((batch * input_channels + c_in) * input_height + h_in) * - input_width + - w_in; + const int offset = in_offset + h_in * input_width + w_in; value += (*weight) * input_data[offset]; } ++weight; @@ -159,36 +158,33 @@ __global__ void KernelDepthwiseConvFilterGrad( const int h_in_end = -padding_height + h_out * stride_height + filter_height; const int w_in_end = -padding_width + w_out * stride_width + filter_width; + const int in_offset = + (batch * input_channels + c_in) * input_height * input_width; + + T* addr_offset = filter_grad_data + c_out * filter_height * filter_width; + if ((h_in_start >= 0) && (h_in_end < input_height) && (w_in_start >= 0) && (w_in_end < input_width)) { for (int kw = 0; kw < filter_width; kw++) { for (int kh = 0; kh < filter_height; kh++) { - const int h_in = -padding_height + h_out * stride_height + kh; - const int w_in = -padding_width + w_out * stride_width + kw; - const int offset = - ((batch * input_channels + c_in) * input_height + h_in) * - input_width + - w_in; + const int h_in = h_in_start + kh; + const int w_in = w_in_start + kw; + const int offset = in_offset + h_in * input_width + w_in; const T diff_temp = output_grad_data[index] * input_data[offset]; - T* addr = filter_grad_data + c_out * filter_height * filter_width + - kh * filter_width + kw; + T* addr = addr_offset + kh * filter_width + kw; paddle::platform::CudaAtomicAdd(addr, diff_temp); } } } else { for (int kw = 0; kw < filter_width; kw++) { for (int kh = 0; kh < filter_height; kh++) { - const int h_in = -padding_height + h_out * stride_height + kh; - const int w_in = -padding_width + w_out * stride_width + kw; + const int h_in = h_in_start + kh; + const int w_in = w_in_start + kw; if ((h_in >= 0) && (h_in < input_height) && (w_in >= 0) && (w_in < input_width)) { - const int offset = - ((batch * input_channels + c_in) * input_height + h_in) * - input_width + - w_in; + const int offset = in_offset + h_in * input_width + w_in; const T diff_temp = output_grad_data[index] * input_data[offset]; - T* addr = filter_grad_data + c_out * filter_height * filter_width + - kh * filter_width + kw; + T* addr = addr_offset + kh * filter_width + kw; paddle::platform::CudaAtomicAdd(addr, diff_temp); } } diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 40c7ec5866..a047cc4eec 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1013,7 +1013,8 @@ def conv2d(input, num_channels = input.shape[1] l_type = 'conv2d' - if num_channels == groups and not use_cudnn: + if (num_channels == groups and num_filters % num_channels == 0 and + not use_cudnn): l_type = 'depthwise_conv' helper = LayerHelper(l_type, **locals()) From 84ded49d6632aec9733bbbcd242c539029711cd8 Mon Sep 17 00:00:00 2001 From: xzl Date: Thu, 1 Feb 2018 23:46:43 +0800 Subject: [PATCH 04/29] fix comments --- paddle/operators/conv_op.h | 3 +++ paddle/operators/math/depthwise_conv.cu | 11 +++++++---- paddle/operators/math/depthwise_conv.h | 11 +++++++---- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 5b47eefb83..3c1d0e9c1c 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -361,6 +361,9 @@ class DepthwiseConvKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); + PADDLE_ENFORCE_EQ( + output->dims()[1] % input->dims()[1], 0, + "The output channels must be a multiple of the input channels"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); diff --git a/paddle/operators/math/depthwise_conv.cu b/paddle/operators/math/depthwise_conv.cu index 4aa38151e6..b9b958c92b 100644 --- a/paddle/operators/math/depthwise_conv.cu +++ b/paddle/operators/math/depthwise_conv.cu @@ -203,8 +203,9 @@ class DepthwiseConvFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, - const framework::Tensor& filter, std::vector& strides, - std::vector& paddings, framework::Tensor* output) { + const framework::Tensor& filter, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -244,7 +245,8 @@ class DepthwiseConvInputGradFunctor { const framework::Tensor& input, const framework::Tensor& filter, const framework::Tensor& output_grad, - std::vector& strides, std::vector& paddings, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; @@ -284,7 +286,8 @@ class DepthwiseConvFilterGradFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, const framework::Tensor& output_grad, - std::vector& strides, std::vector& paddings, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* filter_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; diff --git a/paddle/operators/math/depthwise_conv.h b/paddle/operators/math/depthwise_conv.h index 34eecca7b6..4708920bb4 100644 --- a/paddle/operators/math/depthwise_conv.h +++ b/paddle/operators/math/depthwise_conv.h @@ -29,8 +29,9 @@ template class DepthwiseConvFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor& input, - const framework::Tensor& filter, std::vector& strides, - std::vector& paddings, framework::Tensor* output); + const framework::Tensor& filter, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* output); }; template @@ -39,7 +40,8 @@ class DepthwiseConvInputGradFunctor { void operator()(const DeviceContext& context, const framework::Tensor& input, const framework::Tensor& filter, const framework::Tensor& output_grad, - std::vector& strides, std::vector& paddings, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* input_grad); }; @@ -48,7 +50,8 @@ class DepthwiseConvFilterGradFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor& input, const framework::Tensor& output_grad, - std::vector& strides, std::vector& paddings, + const std::vector& strides, + const std::vector& paddings, framework::Tensor* filter_grad); }; From 6695a204cd739a000ea1d647143d5145c0e6974f Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Wed, 10 Jan 2018 14:38:15 -0800 Subject: [PATCH 05/29] helper functions fetch_var and get_var fetch_var for getting the values of a variable with given name get_var for getting the Variable with given name --- python/paddle/v2/fluid/executor.py | 48 ++++++++++++++----- python/paddle/v2/fluid/framework.py | 20 ++++++++ python/paddle/v2/fluid/layers/tensor.py | 8 ++-- .../paddle/v2/fluid/tests/test_fetch_var.py | 23 +++++++++ 4 files changed, 85 insertions(+), 14 deletions(-) create mode 100644 python/paddle/v2/fluid/tests/test_fetch_var.py diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index 9f48815b8b..af69ce2abc 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -17,7 +17,9 @@ import contextlib from framework import Program, default_main_program from . import core -__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope'] +__all__ = [ + 'Executor', 'global_scope', 'scope_guard', 'switch_scope', 'fetch_var' +] g_scope = core.Scope() @@ -80,12 +82,12 @@ def has_feed_operators(block, feed_targets, feed_holder_name): Args: block: a block instance (typically global block of a program) feed_targets: a dictionary of {feed_target_name: feed_target_data} - feed_holder_name: the name of the variable that holds the data of - all feed targets. The type of this feed_holder variable is + feed_holder_name: the name of the variable that holds the data of + all feed targets. The type of this feed_holder variable is FEED_MINIBATCH, which is essentially vector. Returns: - A boolean value that indicates whether a block has feed operators + A boolean value that indicates whether a block has feed operators that match the info contained in feed_targets and feed_holder_name. """ @@ -108,7 +110,7 @@ def has_feed_operators(block, feed_targets, feed_holder_name): def has_fetch_operators(block, fetch_targets, fetch_holder_name): """ Check whether the block already has fetch operators. - + Return false if the block does not have any fetch operators. If some fetch operators have been appended to the block, check that the info contained in these fetch operators matches the fetch_targets @@ -118,13 +120,13 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name): Args: block: a block instance (typically global block of a program) fetch_targets: a dictionary of {fetch_target_name: fetch_target_data} - fetch_holder_name: the name of the variable that holds the data of - all fetch targets. The type of this fetch_holder variable is - FETCH_LIST, which is essentially vector. + fetch_holder_name: the name of the variable that holds the data of + all fetch targets. The type of this fetch_holder variable is + FETCH_LIST, which is essentially vector. - Return: - A boolean value that indicates whether a block has fetch operators - that match the info contained in fetch_targets and fetch_holder_name. + Return: + A boolean value that indicates whether a block has fetch operators + that match the info contained in fetch_targets and fetch_holder_name. """ fetch_count = 0 @@ -146,6 +148,30 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name): return fetch_count > 0 +def fetch_var(name, scope=None, return_numpy=True): + """ + Fetch the value of the variable with the given name from the given scope + Args: + name(str): name of the variable + scope(core.Scope|None): scope object. + If None, global_scope() will be used. + return_numpy(bool): whether convert the tensor to numpy.ndarray + Returns: + LodTensor|numpy.ndarray + """ + assert isinstance(name, str) + if scope is None: + scope = global_scope() + assert isinstance(scope, core.Scope) + + var = global_scope().find_var(name) + assert var is not None, "Cannot find '%s' in scope." % name + tensor = var.get_tensor() + if return_numpy: + tensor = as_numpy(tensor) + return tensor + + class Executor(object): def __init__(self, places): if not isinstance(places, list) and not isinstance(places, tuple): diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 7f5187d299..7fcd19b215 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -31,6 +31,7 @@ __all__ = [ 'program_guard', 'switch_startup_program', 'switch_main_program', + 'get_var', ] EMPTY_VAR_NAME = core.kEmptyVarName() @@ -1124,3 +1125,22 @@ def program_guard(main_program, startup_program=None): switch_main_program(main_program) if startup_program is not None: switch_startup_program(startup_program) + + +def get_var(name, program=None): + """ + Get a variable by name from the global block of a program + Args: + name(str): name of the variable + program(Program|None): program object. + If None, default_global_program() will be used. + + Returns: + Variable + """ + if program is None: + program = default_main_program() + assert isinstance(name, str) + assert isinstance(name, Program) + + return program.global_block().var(name) diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py index c435c5206d..27067d458d 100644 --- a/python/paddle/v2/fluid/layers/tensor.py +++ b/python/paddle/v2/fluid/layers/tensor.py @@ -35,13 +35,15 @@ __all__ = [ ] -def create_tensor(dtype, name=None): +def create_tensor(dtype, name=None, persistable=False): helper = LayerHelper("create_tensor", **locals()) - return helper.create_variable(name=helper.name, dtype=dtype) + return helper.create_variable( + name=helper.name, dtype=dtype, persistable=persistable) def create_parameter(shape, dtype, + name=None, attr=None, is_bias=False, default_initializer=None): @@ -62,7 +64,7 @@ def create_parameter(shape, """ helper = LayerHelper("create_parameter", **locals()) if attr is None: - attr = ParamAttr() + attr = ParamAttr(name=name) return helper.create_parameter(attr, shape, dtype, is_bias, default_initializer) diff --git a/python/paddle/v2/fluid/tests/test_fetch_var.py b/python/paddle/v2/fluid/tests/test_fetch_var.py new file mode 100644 index 0000000000..670ab54f51 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_fetch_var.py @@ -0,0 +1,23 @@ +import paddle.v2.fluid as fluid +import paddle.v2.fluid.layers as layers +import op_test +import numpy +import unittest + + +class TestFetchVar(op_test.OpTest): + def test_fetch_var(self): + val = numpy.array([1, 3, 5]).astype(numpy.int32) + x = layers.create_tensor(dtype="int32", persistable=True, name="x") + layers.assign(input=val, output=x) + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_main_program(), feed={}, fetch_list=[]) + fetched_x = fluid.fetch_var("x") + self.assertTrue( + numpy.array_equal(fetched_x, val), + "fetch_x=%s val=%s" % (fetched_x, val)) + self.assertEqual(fetched_x.dtype, val.dtype) + + +if __name__ == '__main__': + unittest.main() From 7208190701d9a3c6d1e4dc507940f5d89d12024f Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Fri, 12 Jan 2018 09:27:38 -0800 Subject: [PATCH 06/29] More informative comment and error message for fetch_var() --- python/paddle/v2/fluid/executor.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index af69ce2abc..0eddcc3a5a 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -152,8 +152,10 @@ def fetch_var(name, scope=None, return_numpy=True): """ Fetch the value of the variable with the given name from the given scope Args: - name(str): name of the variable - scope(core.Scope|None): scope object. + name(str): name of the variable. Typically, only persistable variables + can be found in the scope used for running the program. + scope(core.Scope|None): scope object. It should be the scope where + you pass to Executor.run() when running your program. If None, global_scope() will be used. return_numpy(bool): whether convert the tensor to numpy.ndarray Returns: @@ -165,7 +167,10 @@ def fetch_var(name, scope=None, return_numpy=True): assert isinstance(scope, core.Scope) var = global_scope().find_var(name) - assert var is not None, "Cannot find '%s' in scope." % name + assert var is not None, ( + "Cannot find " + name + " in scope. Perhaps you need to make the" + " variable persistable by using var.persistable = True in your" + " program.") tensor = var.get_tensor() if return_numpy: tensor = as_numpy(tensor) From 901cab9ed3e0838954f0015221093fc1d64b5795 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Fri, 2 Feb 2018 13:52:41 +0800 Subject: [PATCH 07/29] Add `make clean` in docker/build.sh --- paddle/scripts/docker/build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index df7310d6b7..59f3af0398 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -79,6 +79,7 @@ function run_build() { Building in /paddle/build ... ============================================ EOF + make clean make -j `nproc` } From 2ffa3a8bf6a7cb0e3d5e1ac211417c234ab04f04 Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 2 Feb 2018 18:28:23 +0800 Subject: [PATCH 08/29] rename op to depthwise_conv2d, more efficient --- paddle/operators/conv_op.cc | 8 +- paddle/operators/conv_op.cu.cc | 4 +- paddle/operators/math/depthwise_conv.cu | 79 ++++++------------- python/paddle/v2/fluid/layers/nn.py | 2 +- .../paddle/v2/fluid/tests/test_conv2d_op.py | 4 +- 5 files changed, 34 insertions(+), 63 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index d25f3fd1a0..cef7ddd5fe 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -320,20 +320,20 @@ REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, ops::ConvOpGrad); // depthwise convolution op -REGISTER_OP(depthwise_conv, ops::ConvOp, ops::Conv2DOpMaker, - depthwise_conv_grad, ops::ConvOpGrad); +REGISTER_OP(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker, + depthwise_conv2d_grad, ops::ConvOpGrad); REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, ops::ConvOpGrad); // depthwise conv kernel // TODO(xingzhaolong): neon kernel for mobile REGISTER_OP_CPU_KERNEL( - depthwise_conv, + depthwise_conv2d, ops::GemmConvKernel, ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( - depthwise_conv_grad, + depthwise_conv2d_grad, ops::GemmConvGradKernel, ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_op.cu.cc b/paddle/operators/conv_op.cu.cc index 02a4e52466..d0bd40ee95 100644 --- a/paddle/operators/conv_op.cu.cc +++ b/paddle/operators/conv_op.cu.cc @@ -17,12 +17,12 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - depthwise_conv, + depthwise_conv2d, ops::DepthwiseConvKernel, ops::DepthwiseConvKernel); REGISTER_OP_CUDA_KERNEL( - depthwise_conv_grad, + depthwise_conv2d_grad, ops::DepthwiseConvGradKernel, ops::DepthwiseConvGradKernel); diff --git a/paddle/operators/math/depthwise_conv.cu b/paddle/operators/math/depthwise_conv.cu index b9b958c92b..b212e78208 100644 --- a/paddle/operators/math/depthwise_conv.cu +++ b/paddle/operators/math/depthwise_conv.cu @@ -42,38 +42,23 @@ __global__ void KernelDepthwiseConv( T value = 0; const int h_in_start = -padding_height + h_out * stride_height; const int w_in_start = -padding_width + w_out * stride_width; - const int h_in_end = - -padding_height + h_out * stride_height + filter_height - 1; - const int w_in_end = - -padding_width + w_out * stride_width + filter_width - 1; + const int h_in_end = h_in_start + filter_height; + const int w_in_end = w_in_start + filter_width; const int in_offset = ((batch * input_channels + c_in) * input_height) * input_width; - if ((h_in_start >= 0) && (h_in_end < input_height) && (w_in_start >= 0) && - (w_in_end < input_width)) { - for (int kh = 0; kh < filter_height; ++kh) { - for (int kw = 0; kw < filter_width; ++kw) { - const int h_in = h_in_start + kh; - const int w_in = w_in_start + kw; - const int offset = in_offset + h_in * input_width + w_in; - - value += (*weight) * input_data[offset]; - ++weight; - } - } - } else { - for (int kh = 0; kh < filter_height; ++kh) { - for (int kw = 0; kw < filter_width; ++kw) { - const int h_in = h_in_start + kh; - const int w_in = w_in_start + kw; - if ((h_in >= 0) && (h_in < input_height) && (w_in >= 0) && - (w_in < input_width)) { - const int offset = in_offset + h_in * input_width + w_in; - value += (*weight) * input_data[offset]; - } - ++weight; - } + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? w_in_start : 0; + + for (int h_in = h_start; h_in < h_end; h_in++) { + for (int w_in = w_start; w_in < w_end; w_in++) { + const int offset = in_offset + h_in * input_width + w_in; + value += + weight[(h_in - h_in_start) * filter_width + (w_in - w_in_start)] * + input_data[offset]; } } output_data[index] = value; @@ -162,32 +147,18 @@ __global__ void KernelDepthwiseConvFilterGrad( (batch * input_channels + c_in) * input_height * input_width; T* addr_offset = filter_grad_data + c_out * filter_height * filter_width; - - if ((h_in_start >= 0) && (h_in_end < input_height) && (w_in_start >= 0) && - (w_in_end < input_width)) { - for (int kw = 0; kw < filter_width; kw++) { - for (int kh = 0; kh < filter_height; kh++) { - const int h_in = h_in_start + kh; - const int w_in = w_in_start + kw; - const int offset = in_offset + h_in * input_width + w_in; - const T diff_temp = output_grad_data[index] * input_data[offset]; - T* addr = addr_offset + kh * filter_width + kw; - paddle::platform::CudaAtomicAdd(addr, diff_temp); - } - } - } else { - for (int kw = 0; kw < filter_width; kw++) { - for (int kh = 0; kh < filter_height; kh++) { - const int h_in = h_in_start + kh; - const int w_in = w_in_start + kw; - if ((h_in >= 0) && (h_in < input_height) && (w_in >= 0) && - (w_in < input_width)) { - const int offset = in_offset + h_in * input_width + w_in; - const T diff_temp = output_grad_data[index] * input_data[offset]; - T* addr = addr_offset + kh * filter_width + kw; - paddle::platform::CudaAtomicAdd(addr, diff_temp); - } - } + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? w_in_start : 0; + + for (int h_in = h_start; h_in < h_end; h_in++) { + for (int w_in = w_start; w_in < w_end; w_in++) { + const int offset = in_offset + h_in * input_width + w_in; + const T diff_temp = output_grad_data[index] * input_data[offset]; + T* addr = addr_offset + (h_in - h_in_start) * filter_width + + (w_in - w_in_start); + paddle::platform::CudaAtomicAdd(addr, diff_temp); } } } diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 4be6ae8ed6..aaf096f0dd 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1237,7 +1237,7 @@ def conv2d(input, l_type = 'conv2d' if (num_channels == groups and num_filters % num_channels == 0 and not use_cudnn): - l_type = 'depthwise_conv' + l_type = 'depthwise_conv2d' helper = LayerHelper(l_type, **locals()) dtype = helper.input_dtype() diff --git a/python/paddle/v2/fluid/tests/test_conv2d_op.py b/python/paddle/v2/fluid/tests/test_conv2d_op.py index a034d0ab91..7512ea333e 100644 --- a/python/paddle/v2/fluid/tests/test_conv2d_op.py +++ b/python/paddle/v2/fluid/tests/test_conv2d_op.py @@ -250,7 +250,7 @@ class TestDepthwiseConv(TestConv2dOp): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3] - self.op_type = "depthwise_conv" + self.op_type = "depthwise_conv2d" class TestDepthwiseConv2(TestConv2dOp): @@ -262,7 +262,7 @@ class TestDepthwiseConv2(TestConv2dOp): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3] - self.op_type = "depthwise_conv" + self.op_type = "depthwise_conv2d" # cudnn v5 does not support dilation conv. From 37a251ebafce61776b2fea7a2fb2ee16defd14ea Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Fri, 2 Feb 2018 15:46:51 -0800 Subject: [PATCH 09/29] Fix copyright for test_fetch_var.py --- python/paddle/v2/fluid/tests/test_fetch_var.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/paddle/v2/fluid/tests/test_fetch_var.py b/python/paddle/v2/fluid/tests/test_fetch_var.py index 670ab54f51..ed75a350b0 100644 --- a/python/paddle/v2/fluid/tests/test_fetch_var.py +++ b/python/paddle/v2/fluid/tests/test_fetch_var.py @@ -1,3 +1,17 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle.v2.fluid as fluid import paddle.v2.fluid.layers as layers import op_test From dbe06551b86460d5ebf18ee33218cd6d11cd07e4 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Fri, 2 Feb 2018 16:12:53 -0800 Subject: [PATCH 10/29] Channel should notify both condition variables on close --- paddle/framework/details/buffered_channel.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h index b093e15892..9c806461aa 100644 --- a/paddle/framework/details/buffered_channel.h +++ b/paddle/framework/details/buffered_channel.h @@ -49,6 +49,7 @@ class Buffered : public paddle::framework::Channel { } void NotifyAllSenders(std::unique_lock*); + void NotifyAllParticipants(std::unique_lock*); }; template @@ -80,7 +81,7 @@ template void Buffered::Close() { std::unique_lock lock(mu_); closed_ = true; - NotifyAllSenders(&lock); + NotifyAllParticipants(&lock); } template @@ -88,7 +89,7 @@ Buffered::~Buffered() { std::unique_lock lock(mu_); closed_ = true; channel_.clear(); - NotifyAllSenders(&lock); + NotifyAllParticipants(&lock); } template @@ -97,6 +98,13 @@ void Buffered::NotifyAllSenders(std::unique_lock* lock) { full_cond_var_.notify_all(); } +template +void Buffered::NotifyAllParticipants(std::unique_lock* lock) { + lock->unlock(); + full_cond_var_.notify_all(); + empty_cond_var_.notify_all(); +} + } // namespace details } // namespace framework } // namespace paddle From 022e5dee8e685134e6c0199d7d0ee8762a03eb80 Mon Sep 17 00:00:00 2001 From: kavyasrinet Date: Fri, 2 Feb 2018 20:08:39 -0800 Subject: [PATCH 11/29] Added more receivers less senders. Receivers should block. (#8061) * Adding more receivers less senders * Added more receivers less senders * Added more send * Updated comment * Fixed code style * Fixing review comments --- paddle/framework/channel_test.cc | 36 +++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc index 31ac72eda9..c3533bbb1a 100644 --- a/paddle/framework/channel_test.cc +++ b/paddle/framework/channel_test.cc @@ -67,7 +67,7 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { std::thread t([&]() { // Try to write more than buffer size. for (size_t i = 0; i < 2 * buffer_size; ++i) { - ch->Send(&i); // should not block + ch->Send(&i); // should block after 10 iterations sum += i; } }); @@ -207,3 +207,37 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) { t.join(); delete ch; } + +TEST(Channel, UnbufferedMoreReceiveLessSendTest) { + auto ch = MakeChannel(0); + unsigned sum_send = 0; + unsigned sum_receive = 0; + // The receiver should block after 5 + // iterations, since there are only 5 senders. + std::thread t([&]() { + for (int i = 0; i < 8; i++) { + int recv; + ch->Receive(&recv); // should block after the fifth iteration. + EXPECT_EQ(recv, i); + sum_receive += i; + } + }); + for (int i = 0; i < 5; i++) { + ch->Send(&i); + sum_send += i; + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + EXPECT_EQ(sum_send, 10U); + EXPECT_EQ(sum_receive, 10U); + // send three more elements + for (int i = 5; i < 8; i++) { + ch->Send(&i); + sum_send += i; + } + + CloseChannel(ch); + t.join(); + EXPECT_EQ(sum_send, 28U); + EXPECT_EQ(sum_receive, 28U); + delete ch; +} From b60da6729fa2484506869bc29271761de91676b7 Mon Sep 17 00:00:00 2001 From: chengduo Date: Sat, 3 Feb 2018 23:32:56 +0800 Subject: [PATCH 12/29] Refine buffer channel (#8098) * refine buffer channel * refine Receive and Send * follow comments --- paddle/framework/channel.h | 4 +-- paddle/framework/details/buffered_channel.h | 25 ++++++++----------- paddle/framework/details/unbuffered_channel.h | 14 ++++++++--- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/paddle/framework/channel.h b/paddle/framework/channel.h index 0570980c5a..b679387b11 100644 --- a/paddle/framework/channel.h +++ b/paddle/framework/channel.h @@ -23,8 +23,8 @@ namespace framework { template class Channel { public: - virtual void Send(T*) = 0; - virtual void Receive(T*) = 0; + virtual bool Send(T*) = 0; + virtual bool Receive(T*) = 0; virtual size_t Cap() = 0; virtual void Close() = 0; virtual ~Channel() {} diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h index 9c806461aa..7ac234b8d4 100644 --- a/paddle/framework/details/buffered_channel.h +++ b/paddle/framework/details/buffered_channel.h @@ -30,8 +30,8 @@ class Buffered : public paddle::framework::Channel { friend void paddle::framework::CloseChannel(Channel*); public: - virtual void Send(T*); - virtual void Receive(T*); + virtual bool Send(T*); + virtual bool Receive(T*); virtual size_t Cap() { return cap_; } virtual void Close(); virtual ~Buffered(); @@ -48,33 +48,36 @@ class Buffered : public paddle::framework::Channel { PADDLE_ENFORCE_GT(cap, 0); } - void NotifyAllSenders(std::unique_lock*); void NotifyAllParticipants(std::unique_lock*); }; template -void Buffered::Send(T* item) { +bool Buffered::Send(T* item) { std::unique_lock lock(mu_); full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_ || closed_; }); + bool ret = false; if (!closed_) { channel_.push_back(std::move(*item)); lock.unlock(); empty_cond_var_.notify_one(); + ret = true; } + return ret; } template -void Buffered::Receive(T* item) { +bool Buffered::Receive(T* item) { std::unique_lock lock(mu_); empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; }); + bool ret = false; if (!closed_) { *item = std::move(channel_.front()); channel_.pop_front(); - NotifyAllSenders(&lock); - } else { - item = nullptr; + full_cond_var_.notify_one(); + ret = true; } + return ret; } template @@ -92,12 +95,6 @@ Buffered::~Buffered() { NotifyAllParticipants(&lock); } -template -void Buffered::NotifyAllSenders(std::unique_lock* lock) { - lock->unlock(); - full_cond_var_.notify_all(); -} - template void Buffered::NotifyAllParticipants(std::unique_lock* lock) { lock->unlock(); diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h index 0dc5afd7e5..f86a894bb4 100644 --- a/paddle/framework/details/unbuffered_channel.h +++ b/paddle/framework/details/unbuffered_channel.h @@ -29,8 +29,8 @@ class UnBuffered : public paddle::framework::Channel { friend void paddle::framework::CloseChannel(Channel*); public: - virtual void Send(T*); - virtual void Receive(T*); + virtual bool Send(T*); + virtual bool Receive(T*); virtual size_t Cap() { return 0; } virtual void Close(); virtual ~UnBuffered(); @@ -57,7 +57,7 @@ class UnBuffered : public paddle::framework::Channel { // This function implements the concept of how data should // be sent from a writer to a reader. template -void UnBuffered::Send(T* data) { +bool UnBuffered::Send(T* data) { // Prevent other writers from entering std::unique_lock writer_lock(mu_write_); writer_found_ = true; @@ -66,6 +66,7 @@ void UnBuffered::Send(T* data) { cv_writer_.wait(cv_lock, [this]() { return reader_found_ == true || closed_; }); cv_reader_.notify_one(); + bool ret = false; if (!closed_) { std::unique_lock channel_lock(mu_ch_); item = data; @@ -74,14 +75,16 @@ void UnBuffered::Send(T* data) { channel_lock.lock(); cv_channel_.wait(channel_lock, [this]() { return item == nullptr || closed_; }); + ret = true; } writer_found_ = false; + return ret; } // This function implements the concept of how // data that was sent by a writer is read from a reader. template -void UnBuffered::Receive(T* data) { +bool UnBuffered::Receive(T* data) { // Prevent other readers from entering std::unique_lock read_lock{mu_read_}; reader_found_ = true; @@ -90,6 +93,7 @@ void UnBuffered::Receive(T* data) { cv_reader_.wait(cv_lock, [this]() { return writer_found_ == true || closed_; }); cv_writer_.notify_one(); + bool ret = false; if (!closed_) { std::unique_lock lock_ch{mu_ch_}; // Reader should wait for the writer to first write its data @@ -98,10 +102,12 @@ void UnBuffered::Receive(T* data) { *data = std::move(*item); item = nullptr; lock_ch.unlock(); + ret = true; } cv_channel_.notify_one(); } reader_found_ = false; + return ret; } // This function implements the sequence of events From be65516876ae32fe2f8cfde1aaa2d22926ccc583 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sun, 4 Feb 2018 16:37:02 +0000 Subject: [PATCH 13/29] Fix the error when sorted_key is none in profiler --- paddle/platform/profiler.cc | 2 +- python/paddle/v2/fluid/profiler.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/platform/profiler.cc b/paddle/platform/profiler.cc index 2a8afc9403..6df087d154 100644 --- a/paddle/platform/profiler.cc +++ b/paddle/platform/profiler.cc @@ -233,7 +233,7 @@ void ParseEvents(std::vector>& events, }; break; default: - sorted_domain = "event end time"; + sorted_domain = "event first end time"; } std::vector> events_table; diff --git a/python/paddle/v2/fluid/profiler.py b/python/paddle/v2/fluid/profiler.py index d4a2cd7eea..d33a4c52a8 100644 --- a/python/paddle/v2/fluid/profiler.py +++ b/python/paddle/v2/fluid/profiler.py @@ -103,10 +103,10 @@ def profiler(state, sorted_key=None): core.enable_profiler(prof_state) yield - if sorted_key not in ['calls', 'total', 'max', 'min', 'ave']: - raise ValueError("The state must be in 'calls', 'total', " - "'max', 'min', 'ave'") sorted_key = 'default' if sorted_key is None else sorted_key + if sorted_key not in ['default', 'calls', 'total', 'max', 'min', 'ave']: + raise ValueError("The sorted_key must be None or in 'calls', 'total', " + "'max', 'min' and 'ave'") key_map = { 'default': core.EventSortingKey.kDefault, 'calls': core.EventSortingKey.kCalls, From 1d2dd9c4a5b99074cec3cb642f64bfd2124e6412 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Sun, 4 Feb 2018 10:04:53 -0800 Subject: [PATCH 14/29] Close buffered channel should unblock the blocked senders and receivers (#8109) --- paddle/framework/channel_test.cc | 113 +++++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 7 deletions(-) diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc index c3533bbb1a..444d68498c 100644 --- a/paddle/framework/channel_test.cc +++ b/paddle/framework/channel_test.cc @@ -48,12 +48,12 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) { const size_t buffer_size = 10; auto ch = MakeChannel(buffer_size); for (size_t i = 0; i < buffer_size; ++i) { - ch->Send(&i); // should not block + EXPECT_EQ(ch->Send(&i), true); // should not block } size_t out; for (size_t i = 0; i < buffer_size; ++i) { - ch->Receive(&out); // should not block + EXPECT_EQ(ch->Receive(&out), true); // should not block EXPECT_EQ(out, i); } CloseChannel(ch); @@ -67,7 +67,10 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { std::thread t([&]() { // Try to write more than buffer size. for (size_t i = 0; i < 2 * buffer_size; ++i) { - ch->Send(&i); // should block after 10 iterations + if (i < buffer_size) + EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations + else + EXPECT_EQ(ch->Send(&i), false); sum += i; } }); @@ -84,13 +87,13 @@ TEST(Channel, SimpleUnbufferedChannelTest) { unsigned sum_send = 0; std::thread t([&]() { for (int i = 0; i < 5; i++) { - ch->Send(&i); + EXPECT_EQ(ch->Send(&i), true); sum_send += i; } }); for (int i = 0; i < 5; i++) { int recv; - ch->Receive(&recv); + EXPECT_EQ(ch->Receive(&recv), true); EXPECT_EQ(recv, i); } @@ -100,6 +103,102 @@ TEST(Channel, SimpleUnbufferedChannelTest) { delete ch; } +// This tests that closing a buffered channel also unblocks +// any receivers waiting on the channel +TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) { + auto ch = MakeChannel(1); + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to read and are blocked because of no writers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data; + // All reads should return false + EXPECT_EQ(ch->Receive(&data), false); + *p = true; + }, + &thread_ended[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + + // Verify that all threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + + // Explicitly close the channel + // This should unblock all receivers + CloseChannel(ch); + + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); + delete ch; +} + +// This tests that closing a buffered channel also unblocks +// any senders waiting for channel to have write space +TEST(Channel, BufferedChannelCloseUnblocksSendersTest) { + auto ch = MakeChannel(1); + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + bool send_success[num_threads]; + + // Launches threads that try to write and are blocked because of no readers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + send_success[i] = false; + t[i] = std::thread( + [&](bool *ended, bool *success) { + int data = 10; + *success = ch->Send(&data); + *ended = true; + }, + &thread_ended[i], &send_success[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + + // Verify that atleast 4 threads are blocked + int ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (thread_ended[i] == false) ct++; + } + // Atleast 4 threads must be blocked + EXPECT_GE(ct, 4); + + // Explicitly close the thread + // This should unblock all senders + CloseChannel(ch); + + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + // Verify that only 1 send was successful + ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (send_success[i]) ct++; + } + // Only 1 send must be successful + EXPECT_EQ(ct, 1); + + for (size_t i = 0; i < num_threads; i++) t[i].join(); + delete ch; +} + // This tests that closing an unbuffered channel also unblocks // unblocks any receivers waiting for senders TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { @@ -114,7 +213,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { t[i] = std::thread( [&](bool *p) { int data; - ch->Receive(&data); + EXPECT_EQ(ch->Receive(&data), false); *p = true; }, &thread_ended[i]); @@ -155,7 +254,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) { t[i] = std::thread( [&](bool *p) { int data = 10; - ch->Send(&data); + EXPECT_EQ(ch->Send(&data), false); *p = true; }, &thread_ended[i]); From 6f28084b4d062100336fd3889012b91c6e278bcc Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 5 Feb 2018 10:53:26 +0800 Subject: [PATCH 15/29] debug/format protobuf to human-readable codes (#8086) --- python/paddle/v2/fluid/debuger.py | 192 ++++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) diff --git a/python/paddle/v2/fluid/debuger.py b/python/paddle/v2/fluid/debuger.py index d379352442..db1808c647 100644 --- a/python/paddle/v2/fluid/debuger.py +++ b/python/paddle/v2/fluid/debuger.py @@ -12,10 +12,202 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import re from graphviz import GraphPreviewGenerator import proto.framework_pb2 as framework_pb2 +_vartype2str_ = [ + "UNK", + "LoDTensor", + "SelectedRows", + "FeedMinibatch", + "FetchList", + "StepScopes", + "LodRankTable", + "LoDTensorArray", + "PlaceList", +] +_dtype2str_ = [ + "bool", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", +] + + +def repr_data_type(type): + return _dtype2str_[type] + + +def repr_tensor(proto): + return "tensor(type={}, shape={})".format(_dtype2str_[int(proto.data_type)], + str(proto.dims)) + + +reprtpl = "{ttype} {name} ({reprs})" + + +def repr_lodtensor(proto): + if not proto.lod_tensor: return + level = proto.lod_tensor.lod_level + reprs = repr_tensor(proto.lod_tensor.tensor) + return reprtpl.format( + ttype="LoDTensor" if level > 0 else "Tensor", + name=proto.name, + reprs="level=%d, %s" % (level, reprs) if level > 0 else reprs) + + +def repr_selected_rows(proto): + if not proto.selected_rows: return + return reprtpl.format( + ttype="SelectedRows", + name=proto.name, + reprs=repr_tensor(proto.selected_rows)) + + +def repr_tensor_array(proto): + if not proto.tensor_array: return + return reprtpl.format( + ttype="TensorArray", + name=proto.name, + reprs="level=%d, %s" % (proto.tensor_array.lod_level, + repr_tensor(proto.lod_tensor))) + + +type_handlers = [ + repr_lodtensor, + repr_selected_rows, + repr_tensor_array, +] + + +def repr_var(vardesc): + for handler in type_handlers: + res = handler(vardesc) + if res: + return res + + +def pprint_program_codes(program_desc): + reprs = [] + for block_idx in range(program_desc.num_blocks()): + block_desc = program_desc.block(block_idx) + block_repr = pprint_block_codes(block_desc) + reprs.append(block_repr) + return '\n'.join(reprs) + + +def pprint_block_codes(block_desc, show_backward=False): + def is_op_backward(op_desc): + if op_desc.type.endswith('_grad'): return True + + def is_var_backward(var): + if "@GRAD" in var.parameter: return True + for arg in var.arguments: + if "@GRAD" in arg: return True + + for var in op_desc.inputs: + if is_var_backward(var): return True + for var in op_desc.outputs: + if is_var_backward(var): return True + return False + + def is_var_backward(var_desc): + return "@GRAD" in var_desc.name + + if type(block_desc) is not framework_pb2.BlockDesc: + block_desc = framework_pb2.BlockDesc.FromString( + block_desc.serialize_to_string()) + var_reprs = [] + op_reprs = [] + for var in block_desc.vars: + if not show_backward and is_var_backward(var): + continue + var_reprs.append(repr_var(var)) + + for op in block_desc.ops: + if not show_backward and is_op_backward(op): continue + op_reprs.append(repr_op(op)) + + tpl = "// block-{idx} parent-{pidx}\n// variables\n{vars}\n\n// operators\n{ops}\n" + return tpl.format( + idx=block_desc.idx, + pidx=block_desc.parent_idx, + vars='\n'.join(var_reprs), + ops='\n'.join(op_reprs), ) + + +def repr_attr(desc): + tpl = "{key}={value}" + valgetter = [ + lambda attr: attr.i, + lambda attr: attr.f, + lambda attr: attr.s, + lambda attr: attr.ints, + lambda attr: attr.floats, + lambda attr: attr.strings, + lambda attr: attr.b, + lambda attr: attr.bools, + lambda attr: attr.block_idx, + lambda attr: attr.l, + ] + key = desc.name + value = valgetter[desc.type](desc) + if key == "dtype": + value = repr_data_type(value) + return tpl.format(key=key, value=str(value)), (key, value) + + +def _repr_op_fill_constant(optype, inputs, outputs, attrs): + if optype == "fill_constant": + return "{output} = {data} [shape={shape}]".format( + output=','.join(outputs), + data=attrs['value'], + shape=str(attrs['shape'])) + + +op_repr_handlers = [_repr_op_fill_constant, ] + + +def repr_op(opdesc): + optype = None + attrs = [] + attr_dict = {} + is_target = None + inputs = [] + outputs = [] + + tpl = "{outputs} = {optype}({inputs}{is_target}) [{attrs}]" + args2value = lambda args: args[0] if len(args) == 1 else str(list(args)) + for var in opdesc.inputs: + key = var.parameter + value = args2value(var.arguments) + inputs.append("%s=%s" % (key, value)) + for var in opdesc.outputs: + value = args2value(var.arguments) + outputs.append(value) + for attr in opdesc.attrs: + attr_repr, attr_pair = repr_attr(attr) + attrs.append(attr_repr) + attr_dict[attr_pair[0]] = attr_pair[1] + + is_target = opdesc.is_target + + for handler in op_repr_handlers: + res = handler(opdesc.type, inputs, outputs, attr_dict) + if res: return res + + return tpl.format( + outputs=', '.join(outputs), + optype=opdesc.type, + inputs=', '.join(inputs), + attrs="{%s}" % ','.join(attrs), + is_target=", is_target" if is_target else "") + def draw_block_graphviz(block, highlights=None, path="./temp.dot"): ''' From 96d4bf5337c985feff01a549c26133e3ed1c3bde Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 5 Feb 2018 12:38:37 +0800 Subject: [PATCH 16/29] prevent make clean from cleaning ExternalProject boost --- CMakeLists.txt | 2 +- cmake/external/boost.cmake | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e8ea828dd2..49334279f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -137,7 +137,7 @@ include(external/openblas) # download, build, install openblas include(external/mkldnn) # download, build, install mkldnn include(external/swig) # download, build, install swig include(external/warpctc) # download, build, install warpctc -include(external/boost) # download, build, install boost +include(external/boost) # download boost include(external/any) # download libn::any include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 diff --git a/cmake/external/boost.cmake b/cmake/external/boost.cmake index c70d83b3f4..dbc676bdac 100644 --- a/cmake/external/boost.cmake +++ b/cmake/external/boost.cmake @@ -21,6 +21,7 @@ set(BOOST_URL "http://sourceforge.net/projects/boost/files/boost/${BOO set(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost) set(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}") set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}/${BOOST_TAR}" CACHE PATH "boost include directory." FORCE) +set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1) include_directories(${BOOST_INCLUDE_DIR}) From eef381d07482f845a875269f1b963f1d135e2cdc Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 5 Feb 2018 12:47:25 +0800 Subject: [PATCH 17/29] remove duplicated mobile index --- doc/index_cn.rst | 1 - doc/index_en.rst | 1 - doc/mobile/index_cn.rst | 9 --------- doc/mobile/index_en.rst | 9 --------- 4 files changed, 20 deletions(-) delete mode 100644 doc/mobile/index_cn.rst delete mode 100644 doc/mobile/index_en.rst diff --git a/doc/index_cn.rst b/doc/index_cn.rst index ada51c2d73..9279bac7f4 100644 --- a/doc/index_cn.rst +++ b/doc/index_cn.rst @@ -8,4 +8,3 @@ PaddlePaddle 文档 howto/index_cn.rst api/index_cn.rst faq/index_cn.rst - mobile/index_cn.rst diff --git a/doc/index_en.rst b/doc/index_en.rst index 23b64b6cad..64684b8b9b 100644 --- a/doc/index_en.rst +++ b/doc/index_en.rst @@ -7,4 +7,3 @@ PaddlePaddle Documentation getstarted/index_en.rst howto/index_en.rst api/index_en.rst - mobile/index_en.rst diff --git a/doc/mobile/index_cn.rst b/doc/mobile/index_cn.rst deleted file mode 100644 index 1d99666e58..0000000000 --- a/doc/mobile/index_cn.rst +++ /dev/null @@ -1,9 +0,0 @@ -MOBILE -====== - -.. toctree:: - :maxdepth: 1 - - cross_compiling_for_android_cn.md - cross_compiling_for_ios_cn.md - cross_compiling_for_raspberry_cn.md diff --git a/doc/mobile/index_en.rst b/doc/mobile/index_en.rst deleted file mode 100644 index ef421dacad..0000000000 --- a/doc/mobile/index_en.rst +++ /dev/null @@ -1,9 +0,0 @@ -MOBILE -====== - -.. toctree:: - :maxdepth: 1 - - cross_compiling_for_android_en.md - cross_compiling_for_ios_en.md - cross_compiling_for_raspberry_en.md From 7dabee27960b5e043b85aca3ee51568443b326f4 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 5 Feb 2018 15:00:03 +0800 Subject: [PATCH 18/29] Add type Reader for VarDesc Add a new type `Reader` for `VarDesc`, which can holds more than one LoDTensor. --- paddle/framework/backward.cc | 4 +- paddle/framework/framework.proto | 10 +- paddle/framework/op_desc.cc | 4 +- paddle/framework/program_desc_test.cc | 4 +- paddle/framework/var_desc.cc | 174 ++++++++++++++++-- paddle/framework/var_desc.h | 20 +- paddle/inference/io.cc | 2 +- paddle/pybind/protobuf.cc | 14 +- .../v2/fluid/tests/test_protobuf_descs.py | 38 ++++ 9 files changed, 246 insertions(+), 24 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 85e693434a..f52a51519f 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward( auto root_block = program_desc.MutableBlock(root_block_idx); std::string fill_one_op_out = GradVarName(target.Name()); - bool is_scalar = target.Shape() == std::vector{1}; + bool is_scalar = target.GetShape() == std::vector{1}; PADDLE_ENFORCE(is_scalar, "target should be scalar"); VLOG(3) << "backward from loss=" << target.Name() << " data_type=" << target.GetDataType(); @@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward( auto var = root_block->Var(fill_one_op_out); var->SetDataType(target.GetDataType()); - var->SetShape(target.Shape()); + var->SetShape(target.GetShape()); auto& target_grad = retv[target.Name()]; target_grad.name_ = fill_one_op_out; target_grad.block_idx_ = root_block_idx; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 5b6ef03f61..f65ccae6e6 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -116,6 +116,8 @@ message LoDTensorArrayDesc { optional int32 lod_level = 2 [ default = 0 ]; } +message Reader { repeated LoDTensorDesc lod_tensor = 1; } + message VarDesc { enum VarType { LOD_TENSOR = 1; @@ -126,13 +128,15 @@ message VarDesc { LOD_RANK_TABLE = 6; LOD_TENSOR_ARRAY = 7; PLACE_LIST = 8; + READER = 9; } required string name = 1; required VarType type = 2; - optional LoDTensorDesc lod_tensor = 3; - optional TensorDesc selected_rows = 4; + optional bool persistable = 3 [ default = false ]; + optional LoDTensorDesc lod_tensor = 4; + optional TensorDesc selected_rows = 5; optional LoDTensorArrayDesc tensor_array = 6; - optional bool persistable = 5 [ default = false ]; + optional Reader reader = 7; } message BlockDesc { diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index f554c77845..ad361852ec 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { auto var = block_.FindVarRecursive(name); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); try { - auto shape = var->Shape(); + auto shape = var->GetShape(); if (shape.empty()) { return framework::make_ddim({0UL}); } else { - return framework::make_ddim(var->Shape()); + return framework::make_ddim(var->GetShape()); } } catch (...) { VLOG(5) << "GetDim of variable " << name << " error"; diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc index 59947c9f21..9945aee31b 100644 --- a/paddle/framework/program_desc_test.cc +++ b/paddle/framework/program_desc_test.cc @@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) { ASSERT_NE(copy, var_before); ASSERT_EQ(copy->Name(), var_before->Name()); ASSERT_EQ(copy->GetType(), var_before->GetType()); - ASSERT_EQ(copy->Shape(), var_before->Shape()); + ASSERT_EQ(copy->GetShape(), var_before->GetShape()); ASSERT_EQ(copy->Proto()->SerializeAsString(), var_before->Proto()->SerializeAsString()); }; @@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ASSERT_NE(restored, var_before); ASSERT_EQ(restored->Name(), var_before->Name()); ASSERT_EQ(restored->GetType(), var_before->GetType()); - ASSERT_EQ(restored->Shape(), var_before->Shape()); + ASSERT_EQ(restored->GetShape(), var_before->GetShape()); ASSERT_EQ(restored->Proto()->SerializeAsString(), var_before->Proto()->SerializeAsString()); }; diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index 62ab6593ef..44bd2363c8 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector &dims) { VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); } +void VarDesc::SetTensorDescNum(size_t num) { + switch (desc_.type()) { + case proto::VarDesc::READER: { + auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor(); + lod_tensors_ptr->Clear(); + for (size_t i = 0; i < num; ++i) { + lod_tensors_ptr->Add(); + } + return; + } break; + default: + PADDLE_THROW( + "Setting 'sub_tensor_number' is not supported by the type of var %s.", + this->Name()); + } +} + +size_t VarDesc::GetTensorDescNum() const { + switch (desc_.type()) { + case proto::VarDesc::READER: + return desc_.reader().lod_tensor_size(); + break; + default: + PADDLE_THROW( + "Getting 'sub_tensor_number' is not supported by the type of var %s.", + this->Name()); + } +} + +void VarDesc::SetShapes( + const std::vector> &multiple_dims) { + PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(), + "The number of given shapes(%d) doesn't equal to the " + "number of sub tensor.", + multiple_dims.size(), GetTensorDescNum()); + std::vector tensors = mutable_tensor_descs(); + for (size_t i = 0; i < multiple_dims.size(); ++i) { + VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); + } +} + +std::vector VarDesc::GetShape() const { + return RepeatedToVector(tensor_desc().dims()); +} + +std::vector> VarDesc::GetShapes() const { + std::vector descs = tensor_descs(); + std::vector> res; + res.reserve(descs.size()); + for (const auto &tensor_desc : descs) { + res.push_back(RepeatedToVector(tensor_desc.dims())); + } + return res; +} + void VarDesc::SetDataType(proto::DataType data_type) { mutable_tensor_desc()->set_data_type(data_type); } -std::vector VarDesc::Shape() const { - return RepeatedToVector(tensor_desc().dims()); +void VarDesc::SetDataTypes( + const std::vector &multiple_data_type) { + PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(), + "The number of given data types(%d) doesn't equal to the " + "number of sub tensor.", + multiple_data_type.size(), GetTensorDescNum()); + std::vector tensor_descs = mutable_tensor_descs(); + for (size_t i = 0; i < multiple_data_type.size(); ++i) { + tensor_descs[i]->set_data_type(multiple_data_type[i]); + } } proto::DataType VarDesc::GetDataType() const { return tensor_desc().data_type(); } +std::vector VarDesc::GetDataTypes() const { + std::vector descs = tensor_descs(); + std::vector res; + res.reserve(descs.size()); + for (const auto &tensor_desc : descs) { + res.push_back(tensor_desc.data_type()); + } + return res; +} + void VarDesc::SetLoDLevel(int32_t lod_level) { switch (desc_.type()) { case proto::VarDesc::LOD_TENSOR: @@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) { desc_.mutable_tensor_array()->set_lod_level(lod_level); break; default: - PADDLE_THROW("Tensor type=%d does not support LoDLevel", - desc_.tensor_array().lod_level()); + PADDLE_THROW( + "Setting 'lod_level' is not supported by the type of var %s.", + this->Name()); + } +} + +void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { + PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(), + "The number of given data types(%d) doesn't equal to the " + "number of sub tensor.", + multiple_lod_level.size(), GetTensorDescNum()); + switch (desc_.type()) { + case proto::VarDesc::READER: { + size_t i = 0; + for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { + lod_tensor.set_lod_level(multiple_lod_level[i++]); + } + } break; + default: + PADDLE_THROW( + "Setting 'lod_levels' is not supported by the type of var %s.", + this->Name()); } } @@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.tensor_array().lod_level(); default: - PADDLE_THROW("Tensor type=%d does not support LoDLevel", - desc_.tensor_array().lod_level()); + PADDLE_THROW( + "Getting 'lod_level' is not supported by the type of var %s.", + this->Name()); + } +} + +std::vector VarDesc::GetLoDLevels() const { + std::vector res; + switch (desc_.type()) { + case proto::VarDesc::READER: + res.reserve(desc_.reader().lod_tensor_size()); + for (auto &lod_tensor : desc_.reader().lod_tensor()) { + res.push_back(lod_tensor.lod_level()); + } + return res; + break; + default: + PADDLE_THROW( + "Getting 'lod_levels' is not supported by the type of var %s.", + this->Name()); } } const proto::TensorDesc &VarDesc::tensor_desc() const { - PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type"); + PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set."); switch (desc_.type()) { case proto::VarDesc::SELECTED_ROWS: return desc_.selected_rows(); @@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.tensor_array().tensor(); default: - PADDLE_THROW("The type of var %s is unsupported.", this->Name()); + PADDLE_THROW( + "Getting 'tensor_desc' is not supported by the type of var %s.", + this->Name()); + } +} + +std::vector VarDesc::tensor_descs() const { + PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); + std::vector res; + res.reserve(GetTensorDescNum()); + switch (desc_.type()) { + case proto::VarDesc::READER: + for (const auto &lod_tensor : desc_.reader().lod_tensor()) { + res.push_back(lod_tensor.tensor()); + } + return res; + default: + PADDLE_THROW( + "Getting 'tensor_descs' is not supported by the type of var " + "%s.", + this->Name()); } } proto::TensorDesc *VarDesc::mutable_tensor_desc() { - PADDLE_ENFORCE(desc_.has_type(), - "invoke MutableTensorDesc must after set type"); + PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); switch (desc_.type()) { case proto::VarDesc::SELECTED_ROWS: return desc_.mutable_selected_rows(); @@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.mutable_tensor_array()->mutable_tensor(); default: - PADDLE_THROW("Unexpected branch."); + PADDLE_THROW( + "Getting 'mutable_tensor_desc' is not supported by the type of var " + "%s.", + this->Name()); } } + +std::vector VarDesc::mutable_tensor_descs() { + PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); + std::vector res; + res.reserve(GetTensorDescNum()); + switch (desc_.type()) { + case proto::VarDesc::READER: + for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { + res.push_back(lod_tensor.mutable_tensor()); + } + return res; + default: + PADDLE_THROW( + "Getting 'tensor_descs' is not supported by the type of var " + "%s.", + this->Name()); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 9316b14bb6..862b9a5d80 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -68,18 +68,34 @@ class VarDesc { void SetName(std::string name) { desc_.set_name(name); } + void SetTensorDescNum(size_t num); + + size_t GetTensorDescNum() const; + void SetShape(const std::vector &dims); + void SetShapes(const std::vector> &multiple_dims); + + std::vector GetShape() const; + + std::vector> GetShapes() const; + void SetDataType(proto::DataType data_type); - std::vector Shape() const; + void SetDataTypes(const std::vector &multiple_data_type); proto::DataType GetDataType() const; + std::vector GetDataTypes() const; + void SetLoDLevel(int32_t lod_level); + void SetLoDLevels(const std::vector &multiple_lod_level); + int32_t GetLoDLevel() const; + std::vector GetLoDLevels() const; + proto::VarDesc::VarType GetType() const; void SetType(proto::VarDesc::VarType type); @@ -90,7 +106,9 @@ class VarDesc { private: const proto::TensorDesc &tensor_desc() const; + std::vector tensor_descs() const; proto::TensorDesc *mutable_tensor_desc(); + std::vector mutable_tensor_descs(); proto::VarDesc desc_; }; diff --git a/paddle/inference/io.cc b/paddle/inference/io.cc index 60ad7af1c0..1ed14b69c8 100644 --- a/paddle/inference/io.cc +++ b/paddle/inference/io.cc @@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor, VLOG(3) << "parameter's name: " << var->Name(); framework::VarDesc* new_var = load_block->Var(var->Name()); - new_var->SetShape(var->Shape()); + new_var->SetShape(var->GetShape()); new_var->SetDataType(var->GetDataType()); new_var->SetType(var->GetType()); new_var->SetLoDLevel(var->GetLoDLevel()); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 371d6119d4..0f1953abe0 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) { py::return_value_policy::reference) .def("set_name", &VarDesc::SetName) .def("set_shape", &VarDesc::SetShape) + .def("set_shapes", &VarDesc::SetShapes) .def("set_dtype", &VarDesc::SetDataType) - .def("shape", &VarDesc::Shape, py::return_value_policy::reference) + .def("set_dtypes", &VarDesc::SetDataTypes) + .def("set_tensor_num", &VarDesc::SetTensorDescNum) + .def("tensor_num", &VarDesc::GetTensorDescNum) + .def("shape", &VarDesc::GetShape, py::return_value_policy::reference) + .def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference) .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference) + .def("dtypes", &VarDesc::GetDataTypes, py::return_value_policy::reference) .def("lod_level", &VarDesc::GetLoDLevel) + .def("lod_levels", &VarDesc::GetLoDLevels, + py::return_value_policy::reference) .def("set_lod_level", &VarDesc::SetLoDLevel) + .def("set_lod_levels", &VarDesc::SetLoDLevels) .def("type", &VarDesc::GetType) .def("set_type", &VarDesc::SetType) .def("serialize_to_string", SerializeMessage) @@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) { .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES) .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE) .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY) - .value("PLACE_LIST", proto::VarDesc::PLACE_LIST); + .value("PLACE_LIST", proto::VarDesc::PLACE_LIST) + .value("READER", proto::VarDesc::READER); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/fluid/tests/test_protobuf_descs.py b/python/paddle/v2/fluid/tests/test_protobuf_descs.py index 9034b2f4ef..ac6de68b5f 100644 --- a/python/paddle/v2/fluid/tests/test_protobuf_descs.py +++ b/python/paddle/v2/fluid/tests/test_protobuf_descs.py @@ -115,6 +115,20 @@ class TestVarDesc(unittest.TestCase): self.assertEqual(src_shape, res_shape) self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) + def test_multiple_shape(self): + program_desc = core.ProgramDesc() + block = program_desc.block(0) + var = block.var('my_reader') + var.set_type(core.VarDesc.VarType.READER) + var.set_tensor_num(3) + src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]] + var.set_shapes(src_shapes) + #import pdb + # pdb.set_trace() + res_shapes = var.shapes() + self.assertEqual(src_shapes, res_shapes) + self.assertEqual(core.VarDesc.VarType.READER, var.type()) + def test_dtype(self): program_desc = core.ProgramDesc() block = program_desc.block(0) @@ -124,6 +138,30 @@ class TestVarDesc(unittest.TestCase): self.assertEqual(core.DataType.INT32, var.dtype()) self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type()) + def test_multiple_dtype(self): + program_desc = core.ProgramDesc() + block = program_desc.block(0) + var = block.var('my_reader') + var.set_type(core.VarDesc.VarType.READER) + var.set_tensor_num(3) + src_types = [ + core.DataType.INT32, core.DataType.FP64, core.DataType.FP32 + ] + var.set_dtypes(src_types) + self.assertEqual(src_types, var.dtypes()) + self.assertEqual(core.VarDesc.VarType.READER, var.type()) + + def test_multiple_lod_level(self): + program_desc = core.ProgramDesc() + block = program_desc.block(0) + var = block.var('my_reader') + var.set_type(core.VarDesc.VarType.READER) + var.set_tensor_num(3) + src_types = [3, 1, 2] + var.set_lod_levels(src_types) + self.assertEqual(src_types, var.lod_levels()) + self.assertEqual(core.VarDesc.VarType.READER, var.type()) + class TestBlockDesc(unittest.TestCase): def test_add_var(self): From 0d03cab5e9b16dba434ed4a25b5dff887d60a897 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 5 Feb 2018 15:18:10 +0800 Subject: [PATCH 19/29] fix a compile error --- paddle/framework/var_desc.cc | 2 +- paddle/framework/var_desc.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index 44bd2363c8..6d83e2e411 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -56,7 +56,7 @@ size_t VarDesc::GetTensorDescNum() const { } void VarDesc::SetShapes( - const std::vector> &multiple_dims) { + const std::vector> &multiple_dims) { PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(), "The number of given shapes(%d) doesn't equal to the " "number of sub tensor.", diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 862b9a5d80..72da2fbb0a 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -74,7 +74,7 @@ class VarDesc { void SetShape(const std::vector &dims); - void SetShapes(const std::vector> &multiple_dims); + void SetShapes(const std::vector> &multiple_dims); std::vector GetShape() const; From 4e5202647684f4ff6525775ce62a6dd674257917 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 5 Feb 2018 16:55:53 +0800 Subject: [PATCH 20/29] add independent sphinx tree for api --- doc/CMakeLists.txt | 2 ++ doc/api/CMakeLists.txt | 20 ++++++++++++++++++++ paddle/scripts/docker/build.sh | 2 +- paddle/scripts/travis/build_doc.sh | 6 ++++-- 4 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 doc/api/CMakeLists.txt diff --git a/doc/CMakeLists.txt b/doc/CMakeLists.txt index 94dd3457fb..58ce5d61c9 100644 --- a/doc/CMakeLists.txt +++ b/doc/CMakeLists.txt @@ -47,3 +47,5 @@ sphinx_add_target(paddle_docs_cn ${SPHINX_CACHE_DIR_CN} ${CMAKE_CURRENT_SOURCE_DIR} ${SPHINX_HTML_DIR_CN}) + +add_subdirectory(api) diff --git a/doc/api/CMakeLists.txt b/doc/api/CMakeLists.txt new file mode 100644 index 0000000000..4e0bc1d5b8 --- /dev/null +++ b/doc/api/CMakeLists.txt @@ -0,0 +1,20 @@ +# configured documentation tools and intermediate build results +set(BINARY_BUILD_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_build") + +# Sphinx cache with pickled ReST documents +set(SPHINX_CACHE_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_doctrees") + +# HTML output director +set(SPHINX_HTML_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/html") + +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.en.in" + "${BINARY_BUILD_DIR_EN}/conf.py" + @ONLY) + +sphinx_add_target(paddle_api_docs + html + ${BINARY_BUILD_DIR_EN} + ${SPHINX_CACHE_DIR_EN} + ${CMAKE_CURRENT_SOURCE_DIR} + ${SPHINX_HTML_DIR_EN}) diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 59f3af0398..ba496db5f8 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -117,7 +117,7 @@ EOF -DWITH_STYLE_CHECK=OFF make -j `nproc` gen_proto_py make -j `nproc` paddle_python - make -j `nproc` paddle_docs paddle_docs_cn + make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs make -j `nproc` print_operators_doc paddle/pybind/print_operators_doc > doc/en/html/operators.json popd diff --git a/paddle/scripts/travis/build_doc.sh b/paddle/scripts/travis/build_doc.sh index 0db8d33bbc..4af4ac4f5e 100755 --- a/paddle/scripts/travis/build_doc.sh +++ b/paddle/scripts/travis/build_doc.sh @@ -9,13 +9,14 @@ cd $TRAVIS_BUILD_DIR/build cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_MKL=OFF -DWITH_DOC=ON make -j `nproc` gen_proto_py make -j `nproc` paddle_python -make -j `nproc` paddle_docs paddle_docs_cn +make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs make -j `nproc` print_operators_doc paddle/pybind/print_operators_doc > doc/en/html/operators.json # check websites for broken links linkchecker doc/en/html/index.html linkchecker doc/cn/html/index.html +linkchecker doc/api/en/html/index.html # Parse Github URL REPO=`git config remote.origin.url` @@ -54,10 +55,11 @@ function deploy_docs() { mkdir -p ${DIR} # remove old docs. mv new docs. set +e - rm -rf ${DIR}/doc ${DIR}/doc_cn + rm -rf ${DIR}/doc ${DIR}/doc_cn ${DIR}/api_doc set -e cp -r ../doc/cn/html ${DIR}/doc_cn cp -r ../doc/en/html ${DIR}/doc + cp -r ../doc/api/en/html ${DIR}/api_doc git add . } From 93734a79138945e6a603b1c9b28ea8cb1b32569e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 5 Feb 2018 19:01:26 +0800 Subject: [PATCH 21/29] fix bug --- paddle/operators/prior_box_op.cc | 69 ++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index 105ff4ac3e..f35273bf41 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -44,12 +44,6 @@ class PriorBoxOp : public framework::OperatorWithKernel { auto aspect_ratios = ctx->Attrs().Get>("aspect_ratios"); bool flip = ctx->Attrs().Get("flip"); - PADDLE_ENFORCE_GT(min_sizes.size(), 0, - "Size of min_sizes must be at least 1."); - for (size_t i = 0; i < min_sizes.size(); ++i) { - PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i); - } - std::vector aspect_ratios_vec; ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec); @@ -65,17 +59,6 @@ class PriorBoxOp : public framework::OperatorWithKernel { } } - PADDLE_ENFORCE_EQ(variances.size(), 4, "Must and only provide 4 variance."); - for (size_t i = 0; i < variances.size(); ++i) { - PADDLE_ENFORCE_GT(variances[i], 0.0, - "variance[%d] must be greater than 0.", i); - } - - const float step_h = ctx->Attrs().Get("step_h"); - PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0."); - const float step_w = ctx->Attrs().Get("step_w"); - PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0."); - std::vector dim_vec(4); dim_vec[0] = input_dims[2]; dim_vec[1] = input_dims[3]; @@ -106,26 +89,54 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "PriorBoxOp. The layout is [H, W, num_priors, 4]. " "H is the height of input, W is the width of input, num_priors " "is the box count of each position."); - AddAttr>("min_sizes", "(vector) ", - "List of min sizes of generated prior boxes."); - AddAttr>("max_sizes", "(vector) ", - "List of max sizes of generated prior boxes."); + + AddAttr>("min_sizes", + "(vector) List of min sizes " + "of generated prior boxes.") + .AddCustomChecker([](const std::vector& min_sizes) { + PADDLE_ENFORCE_GT(min_sizes.size(), 0, + "Size of min_sizes must be at least 1."); + for (size_t i = 0; i < min_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(min_sizes[i], 0, + "min_sizes[%d] must be positive.", i); + } + }); + AddAttr>( + "max_sizes", + "(vector) List of max sizes of generated prior boxes."); AddAttr>( - "aspect_ratios", "(vector) ", - "List of aspect ratios of generated prior boxes."); + "aspect_ratios", + "(vector) List of aspect ratios of generated prior boxes."); + AddAttr>( - "variances", "(vector) ", - "List of variances to be encoded in prior boxes."); - AddAttr("flip", "(bool) ", "Whether to flip aspect ratios.") + "variances", + "(vector) List of variances to be encoded in prior boxes.") + .AddCustomChecker([](const std::vector& variances) { + PADDLE_ENFORCE_EQ(variances.size(), 4, + "Must and only provide 4 variance."); + for (size_t i = 0; i < variances.size(); ++i) { + PADDLE_ENFORCE_GT(variances[i], 0.0, + "variance[%d] must be greater than 0.", i); + } + }); + AddAttr("flip", "(bool) Whether to flip aspect ratios.") .SetDefault(true); - AddAttr("clip", "(bool) ", "Whether to clip out-of-boundary boxes.") + AddAttr("clip", "(bool) Whether to clip out-of-boundary boxes.") .SetDefault(true); + AddAttr("step_w", "Prior boxes step across width, 0 for auto calculation.") - .SetDefault(0.0); + .SetDefault(0.0) + .AddCustomChecker([](const float& step_w) { + PADDLE_ENFORCE_GT(step_w, 0.0, "step_h should be larger than 0."); + }); AddAttr("step_h", "Prior boxes step across height, 0 for auto calculation.") - .SetDefault(0.0); + .SetDefault(0.0) + .AddCustomChecker([](const float& step_h) { + PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0."); + }); + AddAttr("offset", "(float) " "Prior boxes center offset.") From d7a371cbf25f4dcc5dcbfbf0a043e6dc98ae322a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 5 Feb 2018 19:51:42 +0800 Subject: [PATCH 22/29] follow comments --- paddle/operators/prior_box_op.cc | 2 +- paddle/operators/prior_box_op.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index f35273bf41..1dc4b28855 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -128,7 +128,7 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "Prior boxes step across width, 0 for auto calculation.") .SetDefault(0.0) .AddCustomChecker([](const float& step_w) { - PADDLE_ENFORCE_GT(step_w, 0.0, "step_h should be larger than 0."); + PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0."); }); AddAttr("step_h", "Prior boxes step across height, 0 for auto calculation.") diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index e0a663ace8..12ff162356 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -25,7 +25,7 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, std::vector& output_aspect_ratior) { constexpr float epsilon = 1e-6; output_aspect_ratior.clear(); - output_aspect_ratior.push_back(1.); + output_aspect_ratior.push_back(1.0f); for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { float ar = input_aspect_ratior[i]; bool already_exist = false; @@ -38,7 +38,7 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, if (!already_exist) { output_aspect_ratior.push_back(ar); if (flip) { - output_aspect_ratior.push_back(1. / ar); + output_aspect_ratior.push_back(1.0f / ar); } } } From f367ad6c6cae825c46b7262c77fa0cf6f8394796 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 5 Feb 2018 20:03:50 +0800 Subject: [PATCH 23/29] add "inline" for ClipFunctor and refine code --- paddle/operators/prior_box_op.h | 39 ++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index 12ff162356..6b221cb74e 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -46,7 +46,7 @@ inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, template struct ClipFunctor { - HOSTDEVICE T operator()(T in) const { + HOSTDEVICE inline T operator()(T in) const { return std::min(std::max(in, 0.), 1.); } }; @@ -97,6 +97,9 @@ class PriorBoxOpKernel : public framework::OpKernel { boxes->mutable_data(ctx.GetPlace()); vars->mutable_data(ctx.GetPlace()); + T inv_img_width = 1.0 / img_width; + T inv_img_height = 1.0 / img_height; + auto e_boxes = framework::EigenTensor::From(*boxes); for (int h = 0; h < feature_height; ++h) { for (int w = 0; w < feature_width; ++w) { @@ -109,13 +112,15 @@ class PriorBoxOpKernel : public framework::OpKernel { // first prior: aspect_ratio = 1, size = min_size box_width = box_height = min_size; // xmin - e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width; + e_boxes(h, w, idx, 0) = (center_x - box_width * 0.5) * inv_img_width; // ymin - e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height; + e_boxes(h, w, idx, 1) = + (center_y - box_height * 0.5) * inv_img_height; // xmax - e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width; + e_boxes(h, w, idx, 2) = (center_x + box_width * 0.5) * inv_img_width; // ymax - e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + e_boxes(h, w, idx, 3) = + (center_y + box_height * 0.5) * inv_img_height; idx++; if (max_sizes.size() > 0) { @@ -124,13 +129,17 @@ class PriorBoxOpKernel : public framework::OpKernel { // size = sqrt(min_size * max_size) box_width = box_height = sqrt(min_size * max_size); // xmin - e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width; + e_boxes(h, w, idx, 0) = + (center_x - box_width * 0.5) * inv_img_width; // ymin - e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height; + e_boxes(h, w, idx, 1) = + (center_y - box_height * 0.5) * inv_img_height; // xmax - e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width; + e_boxes(h, w, idx, 2) = + (center_x + box_width * 0.5) * inv_img_width; // ymax - e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + e_boxes(h, w, idx, 3) = + (center_y + box_height * 0.5) * inv_img_height; idx++; } @@ -143,13 +152,17 @@ class PriorBoxOpKernel : public framework::OpKernel { box_width = min_size * sqrt(ar); box_height = min_size / sqrt(ar); // xmin - e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width; + e_boxes(h, w, idx, 0) = + (center_x - box_width * 0.5) * inv_img_width; // ymin - e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height; + e_boxes(h, w, idx, 1) = + (center_y - box_height * 0.5) * inv_img_height; // xmax - e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width; + e_boxes(h, w, idx, 2) = + (center_x + box_width * 0.5) * inv_img_width; // ymax - e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + e_boxes(h, w, idx, 3) = + (center_y + box_height * 0.5) * inv_img_height; idx++; } } From e9e24249217c1b234a9ce8f8d0d9c1e6e18fd2d3 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Mon, 5 Feb 2018 21:38:53 +0800 Subject: [PATCH 24/29] Fix warnings in multiclass_nms_op.cc. --- paddle/operators/multiclass_nms_op.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc index 8a65fe69f1..41b9335fb8 100644 --- a/paddle/operators/multiclass_nms_op.cc +++ b/paddle/operators/multiclass_nms_op.cc @@ -85,7 +85,7 @@ static inline void GetMaxScoreIndex( std::stable_sort(sorted_indices->begin(), sorted_indices->end(), SortScorePairDescend); // Keep top_k scores if needed. - if (top_k > -1 && top_k < sorted_indices->size()) { + if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { sorted_indices->resize(top_k); } } @@ -151,7 +151,7 @@ class MultiClassNMSKernel : public framework::OpKernel { while (sorted_indices.size() != 0) { const int idx = sorted_indices.front().second; bool keep = true; - for (int k = 0; k < selected_indices->size(); ++k) { + for (size_t k = 0; k < selected_indices->size(); ++k) { if (keep) { const int kept_idx = (*selected_indices)[k]; T overlap = JaccardOverlap(bbox_data + idx * box_size, @@ -201,7 +201,7 @@ class MultiClassNMSKernel : public framework::OpKernel { int label = it.first; const T* sdata = scores_data + label * predict_dim; const std::vector& label_indices = it.second; - for (int j = 0; j < label_indices.size(); ++j) { + for (size_t j = 0; j < label_indices.size(); ++j) { int idx = label_indices[j]; PADDLE_ENFORCE_LT(idx, predict_dim); score_index_pairs.push_back( @@ -215,7 +215,7 @@ class MultiClassNMSKernel : public framework::OpKernel { // Store the new indices. std::map> new_indices; - for (int j = 0; j < score_index_pairs.size(); ++j) { + for (size_t j = 0; j < score_index_pairs.size(); ++j) { int label = score_index_pairs[j].second.first; int idx = score_index_pairs[j].second.second; new_indices[label].push_back(idx); @@ -238,7 +238,7 @@ class MultiClassNMSKernel : public framework::OpKernel { int label = it.first; const T* sdata = scores_data + label * predict_dim; const std::vector& indices = it.second; - for (int j = 0; j < indices.size(); ++j) { + for (size_t j = 0; j < indices.size(); ++j) { int idx = indices[j]; const T* bdata = bboxes_data + idx * kBBoxSize; odata[count * kOutputDim] = label; // label From 497a131e53316fc3d81cf92e68845d2fd33243e3 Mon Sep 17 00:00:00 2001 From: kavyasrinet Date: Mon, 5 Feb 2018 10:45:43 -0800 Subject: [PATCH 25/29] Proposing Python syntax for send and recv in design doc (#8093) * Adding send and recv in design doc * fix typo * fixed code * Adding threading --- doc/design/csp.md | 76 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/doc/design/csp.md b/doc/design/csp.md index ba9cacfdea..2f6ce8d6fa 100644 --- a/doc/design/csp.md +++ b/doc/design/csp.md @@ -71,14 +71,14 @@ ch1 := make(chan int, 100) // a channel that can buffer 100 ints. In Fluid, we should be able to do the same: ```python -ch = fluid.make_chan(dtype=INT) -ch1 = fluid.make_chan(dtype=INT, 100) +ch = fluid.make_channel(dtype=INT) +ch1 = fluid.make_channel(dtype=INT, 100) ``` In addition to that, we want channels that can hold more complex element types, e.g., Tensors of float16: ```python -ch = fluid.make_chan(dtype=Tensor, etype=float16) +ch = fluid.make_channel(dtype=Tensor, etype=float16) ``` or Tensors of Tensors of float16 etc. @@ -87,6 +87,76 @@ The point here is that we need a consistent way to compose types, like in C++ we ### Send and Recv +In Go, we first create a channel as explained in the section above and then perform read and write operations on top of the channels. + +```go +ch1 := make(chan int) +ch2 := make(chan int, 100) +``` + +To write (or perform a `Send` operation) the value of a variable `x`, to channel `ch1` above, we perform the following: + +```go +ch1 <- x +fmt.Println("Written to the channel") +``` +Now to read (or perform a `Recv` operation) the value stored in `ch2` into a variable `y`, we perform the following: + +```go +y <- ch2 +fmt.Println("Received from channel") +``` + +In Fluid, we should be able to perform the above operations on the channel objects as well. As of now, we support two different kinds of channels : [Buffered Channel](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/details/buffered_channel.h) and [UnBuffered Channel](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/details/unbuffered_channel.h) + +Send and Receive can be performed as following on a buffered channel: + +```python +import threading + +def send_to_channel(channel, num_time=1): + for i in xrange(num_time): + channel.send(i) + +# Create a buffered channel of capacity 10 +buffer_size = 10; +ch = fluid.make_channel(dtype=INT, buffer_size) + +# Now write three elements to the channel +thread = threading.Thread(target=send_to_channel, args=(ch, 3, )) +thread.daemon = True +thread.start() + +# Read all the data from the channel +for i in xrange(3): + y = ch.recv() + +# Done receiving , now close the channel +ch.close() +``` + +The send and receive operations will be similar for unbuffered channel as well, except for the fact that there is no buffer in an unbuffered channel, so the operations are completely synchronized. For example: + +```python +import threading + +def send_to_channel(channel, data): + channel.send(data) + +# Create an unbuffered channel +ch = fluid.make_channel(dtype=INT) + +# Writes and Reads are synchronous otherwise the calls will block. +thread = threading.Thread(target=send_to_channel, args=(ch, 10, )) +thread.daemon = True +thread.start() + +y = ch.recv() + +# Done receiving , now close the channel +ch.close() +``` + ### Select ## Example Programs From 1ead6c2691be09f34303c06d119c17ba4e4aeab7 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Mon, 5 Feb 2018 11:06:02 -0800 Subject: [PATCH 26/29] Add proposed fluid syntax for select statement in Fluid's implementation of CSP (#7908) * Add proposed fluid syntax for select statement in Fluid's implementation of CSP * Fix Typo --- doc/design/csp.md | 49 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/doc/design/csp.md b/doc/design/csp.md index 2f6ce8d6fa..36422d8236 100644 --- a/doc/design/csp.md +++ b/doc/design/csp.md @@ -159,6 +159,55 @@ ch.close() ### Select +In Go, the `select` statement lets a goroutine wait on multiple communication operations. A `select` blocks untill one of its cases can run, then it executes that case. It chooses one at random if multiple are ready. + +```go + +ch1 := make(chan int) +ch2 := make(chan int, 100) + +x := 0 + +for { + select { + case ch1 <- x: + x := x + 1 + case y <- ch2: + fmt.Println("Received on channel") + default: + fmt.Println("Default") + } + } + +``` + +In Fluid, we should be able to do the same: + +```python +ch1 = fluid.make_chan(dtype=INT) +ch2 = fluid.make_chan(dtype=INT, 100) + +sel = fluid.select() + +with sel.case(ch1, 'w', X): + fluid.layers.increment(X) + +with sel.case(ch2, 'r', Y): + fluid.print("Received on Channel") + +with sel.default(): + fluid.print("Default") + +``` + +In the above code snippet, `X` and `Y` are variables. Now let us look at each of these statements one by one. + +- `sel.case(ch1, 'w', X)` : This specifies that we are writing to `ch1` and we want to write the integer in variable `X` to the channel. The character `w` is used here to make the syntax familar to write syntax in Python I/O. + +- `sel.case(ch2, 'r', Y)` : This specifies that we would like to read the result from `ch2` into variable `Y`. The character `r` is used here to make the syntax familar to read syntax in Python I/O. + +- `sel.default()` : This is equivalent to the default in Go `select`. If none of the channels are ready for read or write, then the fluid code in the default block will be executed. + ## Example Programs ### 1. RPC between Trainers and Parameter Servers From b0ecb36583ed97737bd5c43cbafbdc8fa29cbd68 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 5 Feb 2018 17:11:11 -0800 Subject: [PATCH 27/29] Rewrite the Send/Recv part of csp.md (#8164) * Update csp.md * Update csp.md * Update csp.md --- doc/design/csp.md | 110 +++++++++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 51 deletions(-) diff --git a/doc/design/csp.md b/doc/design/csp.md index 36422d8236..ae2e3e1b99 100644 --- a/doc/design/csp.md +++ b/doc/design/csp.md @@ -42,7 +42,7 @@ The type *channel* is conceptually the blocking queue. In Go, its implemented i The `select` operation has been in OS kernels long before Go language. All Unix kernels implement system calls *poll* and *select*. They monitor multiple file descriptors to see if I/O is possible on any of them. This takes O(N) time. Since Linux 2.6, a new system call, *epoll*, can do the same in O(1) time. In BSD systems, there is a similar system call *kqueue*. Go's Linux implementation uses epoll. -It might be a good idea to implement Fluid's select using epoll too. In this design doc, we start from the O(N) way, so we could focus on Python binding and the syntax. +It might be a good idea to implement Fluid's select using epoll too. In this design doc, we start from the O(N) way so that we could focus on Python binding and the syntax. ### Type Channel @@ -87,79 +87,87 @@ The point here is that we need a consistent way to compose types, like in C++ we ### Send and Recv -In Go, we first create a channel as explained in the section above and then perform read and write operations on top of the channels. +Go's CSP implementation depends on data type *channel*. There are two types of channels: -```go -ch1 := make(chan int) -ch2 := make(chan int, 100) -``` +1. The unblocked channel, or buffered channel, is a blocking queue with a non-zero sized buffer. The sending to buffered channel blocks if the buffer is full, and the receive operation blocks if the buffer is empty. +1. blocked channel, or unbuffered channel, is a blocking queue with no buffer. Both sending and receiving block with unbuffered channels. -To write (or perform a `Send` operation) the value of a variable `x`, to channel `ch1` above, we perform the following: +There are four types of actions with a channel: -```go -ch1 <- x -fmt.Println("Written to the channel") -``` -Now to read (or perform a `Recv` operation) the value stored in `ch2` into a variable `y`, we perform the following: +1. Create a channel -```go -y <- ch2 -fmt.Println("Received from channel") -``` + ```go + ch := make(chan int) // this is an unbuffered channel + ch := make(chan int, 100) // this is a buffered channel of 100 ints. + ``` -In Fluid, we should be able to perform the above operations on the channel objects as well. As of now, we support two different kinds of channels : [Buffered Channel](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/details/buffered_channel.h) and [UnBuffered Channel](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/details/unbuffered_channel.h) +1. Send -Send and Receive can be performed as following on a buffered channel: + ```go + ch <- 111 + ``` -```python -import threading +1. Recv -def send_to_channel(channel, num_time=1): - for i in xrange(num_time): - channel.send(i) + ```go + y, ok <- ch + ``` -# Create a buffered channel of capacity 10 -buffer_size = 10; -ch = fluid.make_channel(dtype=INT, buffer_size) +1. Close -# Now write three elements to the channel -thread = threading.Thread(target=send_to_channel, args=(ch, 3, )) -thread.daemon = True -thread.start() + ```go + close(ch) + ``` + + Please be aware that a closed channel is not a nil channel, which is `var ch chan int`. + +There are some [axioms with channels](https://dave.cheney.net/2014/03/19/channel-axioms): -# Read all the data from the channel -for i in xrange(3): - y = ch.recv() +1. A send to a nil channel blocks forever -# Done receiving , now close the channel -ch.close() -``` +1. A receive from a nil channel blocks forever + +1. A send to a closed channel panics + +1. A receive from a closed channel returns the residual values and then zeros. -The send and receive operations will be similar for unbuffered channel as well, except for the fact that there is no buffer in an unbuffered channel, so the operations are completely synchronized. For example: +In Fluid, we have [buffered channels](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/details/buffered_channel.h) and [unbuffered channels](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/details/unbuffered_channel.h) + +The following program illustrates the Python syntax for accessing Fluid buffers. ```python -import threading +import fluid + +buffer_size = 10 +ch = fluid.make_channel(dtype=INT, buffer_size) -def send_to_channel(channel, data): - channel.send(data) +# Now write three elements to the channel +with fluid.while(steps=buffer_size): + fluid.send(ch, step) + fluid.close_channel(ch) + +with fluid.while(steps=buffer_size): + fluid.print(fluid.recv(ch)) +``` + +The following example shows that to avoid the always-blocking behavior of unbuffered channels, we need to use Fluid's goroutines. + +```python +import fluid -# Create an unbuffered channel ch = fluid.make_channel(dtype=INT) -# Writes and Reads are synchronous otherwise the calls will block. -thread = threading.Thread(target=send_to_channel, args=(ch, 10, )) -thread.daemon = True -thread.start() +with fluid.go(): + fluid.send(ch) -y = ch.recv() +y = fluid.recv(ch) -# Done receiving , now close the channel -ch.close() +fluid.close_channel(ch) ``` ### Select -In Go, the `select` statement lets a goroutine wait on multiple communication operations. A `select` blocks untill one of its cases can run, then it executes that case. It chooses one at random if multiple are ready. +In Go, the `select` statement lets a goroutine wait on multiple communication operations. A `select` blocks until one of its cases can run, then it executes that case. It chooses one at random if multiple are ready. ```go @@ -202,9 +210,9 @@ with sel.default(): In the above code snippet, `X` and `Y` are variables. Now let us look at each of these statements one by one. -- `sel.case(ch1, 'w', X)` : This specifies that we are writing to `ch1` and we want to write the integer in variable `X` to the channel. The character `w` is used here to make the syntax familar to write syntax in Python I/O. +- `sel.case(ch1, 'w', X)` : This specifies that we are writing to `ch1` and we want to write the integer in variable `X` to the channel. The character `w` is used here to make the syntax familiar to write syntax in Python I/O. -- `sel.case(ch2, 'r', Y)` : This specifies that we would like to read the result from `ch2` into variable `Y`. The character `r` is used here to make the syntax familar to read syntax in Python I/O. +- `sel.case(ch2, 'r', Y)` : This specifies that we would like to read the result from `ch2` into variable `Y`. The character `r` is used here to make the syntax familiar to read syntax in Python I/O. - `sel.default()` : This is equivalent to the default in Go `select`. If none of the channels are ready for read or write, then the fluid code in the default block will be executed. From 165450ff6ca5bc0f02ffe63ec11f50ed4c240f09 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 6 Feb 2018 09:52:18 +0800 Subject: [PATCH 28/29] Refine the inference unittest recognize_digits. (#8147) --- .../book/test_inference_recognize_digits.cc | 63 ++++++++++++++----- .../fluid/tests/book/test_recognize_digits.py | 4 +- 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc index 26dc2aee04..ce8772587f 100644 --- a/paddle/inference/tests/book/test_inference_recognize_digits.cc +++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc @@ -58,6 +58,47 @@ void TestInference(const std::string& dirname, delete scope; } +template +void SetupTensor(paddle::framework::LoDTensor& input, + paddle::framework::DDim dims, + T lower, + T upper) { + srand(time(0)); + float* input_ptr = input.mutable_data(dims, paddle::platform::CPUPlace()); + for (int i = 0; i < input.numel(); ++i) { + input_ptr[i] = + (static_cast(rand()) / static_cast(RAND_MAX)) * (upper - lower) + + lower; + } +} + +template +void CheckError(paddle::framework::LoDTensor& output1, + paddle::framework::LoDTensor& output2) { + // Check lod information + EXPECT_EQ(output1.lod(), output2.lod()); + + EXPECT_EQ(output1.dims(), output2.dims()); + EXPECT_EQ(output1.numel(), output2.numel()); + + T err = static_cast(0); + if (typeid(T) == typeid(float)) { + err = 1E-3; + } else if (typeid(T) == typeid(double)) { + err = 1E-6; + } else { + err = 0; + } + + size_t count = 0; + for (int64_t i = 0; i < output1.numel(); ++i) { + if (fabs(output1.data()[i] - output2.data()[i]) > err) { + count++; + } + } + EXPECT_EQ(count, 0) << "There are " << count << " different elements."; +} + TEST(inference, recognize_digits) { if (FLAGS_dirname.empty()) { LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; @@ -70,12 +111,10 @@ TEST(inference, recognize_digits) { // In unittests, this is done in paddle/testing/paddle_gtest_main.cc paddle::framework::LoDTensor input; - srand(time(0)); - float* input_ptr = - input.mutable_data({1, 28, 28}, paddle::platform::CPUPlace()); - for (int i = 0; i < 784; ++i) { - input_ptr[i] = rand() / (static_cast(RAND_MAX)); - } + // Use normilized image pixels as input data, + // which should be in the range [-1.0, 1.0]. + SetupTensor( + input, {1, 28, 28}, static_cast(-1), static_cast(1)); std::vector cpu_feeds; cpu_feeds.push_back(&input); @@ -98,16 +137,6 @@ TEST(inference, recognize_digits) { dirname, cpu_feeds, cpu_fetchs2); LOG(INFO) << output2.dims(); - EXPECT_EQ(output1.dims(), output2.dims()); - EXPECT_EQ(output1.numel(), output2.numel()); - - float err = 1E-3; - int count = 0; - for (int64_t i = 0; i < output1.numel(); ++i) { - if (fabs(output1.data()[i] - output2.data()[i]) > err) { - count++; - } - } - EXPECT_EQ(count, 0) << "There are " << count << " different elements."; + CheckError(output1, output2); #endif } diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py index b8f55c813b..fb6b1f7192 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py @@ -166,7 +166,9 @@ def infer(use_cuda, save_dirname=None): fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) # The input's dimension of conv should be 4-D or 5-D. - tensor_img = numpy.random.rand(1, 1, 28, 28).astype("float32") + # Use normilized image pixels as input data, which should be in the range [-1.0, 1.0]. + tensor_img = numpy.random.uniform(-1.0, 1.0, + [1, 1, 28, 28]).astype("float32") # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. From 9a1fa890a0c510ca1863eea358423bc89fd4fdef Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 6 Feb 2018 11:10:34 +0800 Subject: [PATCH 29/29] remove unnecessary comments --- python/paddle/v2/fluid/tests/test_protobuf_descs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/v2/fluid/tests/test_protobuf_descs.py b/python/paddle/v2/fluid/tests/test_protobuf_descs.py index ac6de68b5f..8f335d13db 100644 --- a/python/paddle/v2/fluid/tests/test_protobuf_descs.py +++ b/python/paddle/v2/fluid/tests/test_protobuf_descs.py @@ -123,8 +123,6 @@ class TestVarDesc(unittest.TestCase): var.set_tensor_num(3) src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]] var.set_shapes(src_shapes) - #import pdb - # pdb.set_trace() res_shapes = var.shapes() self.assertEqual(src_shapes, res_shapes) self.assertEqual(core.VarDesc.VarType.READER, var.type())