From 97e9dd72375258ed69fbbab39f340d23878002f5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 8 Nov 2017 14:15:58 +0800 Subject: [PATCH 01/25] add dilation for im2col --- paddle/operators/conv_cudnn_op.cc | 2 - paddle/operators/conv_op.cc | 13 +- paddle/operators/conv_op.h | 29 +- paddle/operators/conv_transpose_op.h | 16 +- paddle/operators/math/context_project.h | 10 +- paddle/operators/math/im2col.cc | 281 +++++++++--------- paddle/operators/math/im2col.cu | 366 +++++++++++++----------- paddle/operators/math/im2col.h | 11 +- paddle/operators/math/im2col_test.cc | 18 +- 9 files changed, 395 insertions(+), 351 deletions(-) diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc index 97f31bf22d..4c65b60d23 100644 --- a/paddle/operators/conv_cudnn_op.cc +++ b/paddle/operators/conv_cudnn_op.cc @@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker { CudnnConvOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : Conv2DOpMaker(proto, op_checker) { - AddAttr>("dilations", "dilations of convolution operator.") - .SetDefault(std::vector{1, 1}); AddAttr("workspace_size_MB", "workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index a6f65f1016..852ac2ae37 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); int groups = ctx->Attrs().Get("groups"); + std::vector dilations = ctx->Attrs().Get>("dilations"); int input_channels = in_dims[1]; int output_channels = filter_dims[0]; @@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < paddings.size(); ++i) { output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], - paddings[i], strides[i])); + dilations[i], paddings[i], paddings[i], + strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } @@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, "first half of the input channels, while the second half of the filters " "is only connected to the second half of the input channels.") .SetDefault(1); + AddAttr>("dilations", + "(vector default:{1, 1}), the dilations of " + "convolution operator.") + .SetDefault(std::vector{1, 1}); AddComment(R"DOC( Convolution Operator. @@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "first half of the input channels, while the second half of the filters " "is only connected to the second half of the input channels.") .SetDefault(1); + AddAttr>("dilations", + "(vector default:{1, 1, 1}), the dilations of " + "convolution operator. Currently, conv3d doesn't " + "support dilation.") + .SetDefault(std::vector{1, 1, 1}); AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 7c1729213b..2459f03a1a 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -27,9 +27,12 @@ using Tensor = framework::Tensor; // Base convolution operator definations for other conv // like operators to reuse the implementation. -inline int OutputSize(int input_size, int filter_size, int padding, - int stride) { - int output_size = (input_size - filter_size + 2 * padding) / stride + 1; +inline int OutputSize(int input_size, int filter_size, int dilation, + int padding_up, int padding_down, int stride) { + int output_size = (input_size + padding_up + padding_down - + (dilation * (filter_size - 1) + 1)) / + stride + + 1; return output_size; } @@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // im2col math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], paddings[0], + paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { // vol2col math::Vol2ColFunctor vol2col; @@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + col2im(context.device_context(), in_grad_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Col2VolFunctor col2vol; @@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Vol2ColFunctor vol2col; vol2col(context.device_context(), in_slice, col, strides[0], diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 6c1a6220d7..cbfad88b39 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel { // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. + int dilation_h = 1; + int dilation_w = 1; + const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel { // from (c * k_h * k_w, h * w) to (c, o_h, o_w) math::Col2ImFunctor col2im; - col2im(context.device_context(), output_batch, col, strides[0], - strides[1], 0, 0, 0, 0); + col2im(context.device_context(), output_batch, col, dilation_h, + dilation_w, strides[0], strides[1], 0, 0, 0, 0); } else if (filter_shape_vec.size() == 3) { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) @@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); + int dilation_h = 1; + int dilation_w = 1; + const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) math::Im2ColFunctor im2col; - im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), output_grad_batch, col, dilation_h, + dilation_w, strides[0], strides[1], paddings[0], paddings[0], + paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index e028336041..c67d84528f 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -95,6 +95,9 @@ class ContextProjectFunctor { math::Im2ColFunctor im2col_ocf; + int dilation_h = 1; + int dilation_w = 1; + int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; @@ -124,7 +127,7 @@ class ContextProjectFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - im2col_ocf(context, in_t, out_t, + im2col_ocf(context, in_t, out_t, dilation_h, dilation_w, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, down_pad, 0, 0); out_t.Resize({sequence_height, context_length * sequence_width}); @@ -204,6 +207,9 @@ class ContextProjectGradFunctor { math::Col2ImFunctor col2im_ocf; + int dilation_h = 1; + int dilation_w = 1; + int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; @@ -234,7 +240,7 @@ class ContextProjectGradFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - col2im_ocf(context, in_t, out_t, + col2im_ocf(context, in_t, out_t, dilation_h, dilation_w, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, down_pad, 0, 0); out_t.Resize({sequence_height, context_length * sequence_width}); diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 3b1b0bd71d..b248863b4e 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,35 +29,36 @@ class Im2ColFunctor(); T* col_data = col.data(); @@ -66,19 +67,19 @@ class Im2ColFunctor= input_height || im_col_idx < 0 || - im_col_idx >= input_width) { - col_data[(c * output_height + h) * output_width + w] = T(0); - } else { - im_row_idx += c_im * input_height; - col_data[(c * output_height + h) * output_width + w] = - im_data[im_row_idx * input_width + im_col_idx]; - } + col_data[(c * col_height + h) * col_width + w] = + (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 || + im_col_idx >= im_width) + ? static_cast(0) + : im_data[(im_row_idx + c_im * im_height) * im_width + + im_col_idx]; } } } @@ -95,35 +96,35 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; - int output_height = col.dims()[3]; - int output_width = col.dims()[4]; + int col_height = col.dims()[3]; + int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ( - (input_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - output_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ( - (input_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - output_width, - "output_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); - int channels_col = input_channels * filter_height * filter_width; + int channels_col = im_channels * filter_height * filter_width; T* im_data = im.data(); const T* col_data = col.data(); @@ -132,16 +133,18 @@ class Col2ImFunctor= 0 && (im_row_idx) < input_height && - (im_col_idx) >= 0 && (im_col_idx) < input_width) { - im_row_idx += c_im * input_height; - im_data[im_row_idx * input_width + im_col_idx] += - col_data[(c * output_height + h) * output_width + w]; + if ((im_row_idx) >= 0 && (im_row_idx) < im_height && + (im_col_idx) >= 0 && (im_col_idx) < im_width) { + im_row_idx += c_im * im_height; + im_data[im_row_idx * im_width + im_col_idx] += + col_data[(c * col_height + h) * col_width + w]; } } } @@ -169,39 +172,38 @@ class Im2ColFunctor(); T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { - for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { - for (int channel = 0; channel < input_channels; ++channel) { + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; @@ -210,22 +212,21 @@ class Im2ColFunctor= input_height || - im_col_offset < 0 || im_col_offset >= input_width) { - col_data[col_offset] = T(0); - } else { - int im_offset = - (channel * input_height + im_row_offset) * input_width + - im_col_offset; - col_data[col_offset] = im_data[im_offset]; - } + int col_offset = + ((((col_row_idx)*col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + + int im_offset = (channel * im_height + im_row_offset) * im_width + + im_col_offset; + col_data[col_offset] = + (im_row_offset < 0 || im_row_offset >= im_height || + im_col_offset < 0 || im_col_offset >= im_width) + ? static_cast(0) + : im_data[im_offset]; } } } @@ -244,40 +245,38 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; - int output_height = col.dims()[0]; - int output_width = col.dims()[1]; + int col_height = col.dims()[0]; + int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ( - (input_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - output_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ( - (input_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - output_width, - "output_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); T* im_data = im.data(); const T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { - for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { - for (int channel = 0; channel < input_channels; ++channel) { + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; @@ -286,17 +285,17 @@ class Col2ImFunctor= 0 && im_row_offset < input_height && - im_col_offset >= 0 && im_col_offset < input_width) { + int col_offset = + (((col_row_idx * col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + if (im_row_offset >= 0 && im_row_offset < im_height && + im_col_offset >= 0 && im_col_offset < im_width) { int im_offset = - (channel * input_height + im_row_offset) * input_width + + (channel * im_height + im_row_offset) * im_width + im_col_offset; im_data[im_offset] += col_data[col_offset]; } diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 7b201fdbf3..69e2abee03 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -20,36 +20,32 @@ namespace operators { namespace math { template -__global__ void im2col(const T* data_im, int num_outs, int height, int width, +__global__ void im2col(const T* data_im, int num_outs, int im_height, + int im_width, int dilation_h, int dilation_w, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width, T* data_col) { - int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + int col_height, int col_width, T* data_col) { + const int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if (index < num_outs) { - int w_out = index % output_width; - index /= output_width; - int h_out = index % output_height; - int channel_in = index / output_height; + int w_out = index % col_width; + int h_out = (index / col_width) % col_height; + int channel_in = index / col_width / col_height; int channel_out = channel_in * filter_height * filter_width; - int h_in = h_out * stride_height; - int w_in = w_out * stride_width; + int h_in = h_out * stride_height - padding_height; + int w_in = w_out * stride_width - padding_width; - data_col += (channel_out * output_height + h_out) * output_width + w_out; + data_col += (channel_out * col_height + h_out) * col_width + w_out; + data_im += (channel_in * im_height + h_in) * im_width + w_in; for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { - int rIdx = int(h_in + i); - int cIdx = int(w_in + j); - if ((rIdx - (int)padding_height) >= (int)height || - (rIdx - (int)padding_height) < 0 || - (cIdx - (int)padding_width) >= (int)width || - (cIdx - (int)padding_width) < 0) { - *data_col = 0; - } else { - rIdx = rIdx + channel_in * height - padding_height; - cIdx = cIdx - padding_width; - *data_col = data_im[rIdx * width + cIdx]; - } - data_col += output_height * output_width; + int rIdx = h_in + i * dilation_h; + int cIdx = w_in + j * dilation_w; + *data_col = + (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0) + ? 0 + : data_im[i * dilation_h * im_width + j * dilation_w]; + data_col += col_height * col_width; } } } @@ -66,29 +62,36 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), num_outputs, input_height, input_width, filter_height, - filter_width, stride_height, stride_width, padding_up, padding_left, - output_height, output_width, col.data()); + im.data(), num_outputs, im_height, im_width, dilation_h, dilation_w, + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, col_height, col_width, col.data()); } }; template -__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width, - size_t channels, size_t filter_height, - size_t filter_width, size_t stride_height, - size_t stride_width, size_t padding_height, - size_t padding_width, size_t output_height, - size_t output_width, T* data_im) { - size_t index = +__global__ void col2im(int n, const T* data_col, int im_height, int im_width, + int dilation_h, int dilation_w, int filter_height, + int filter_width, int stride_height, int stride_width, + int padding_height, int padding_width, int col_height, + int col_width, T* data_im) { + const int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + + const int d_filter_height = dilation_h * (filter_height - 1) + 1; + const int d_filter_width = dilation_w * (filter_width - 1) + 1; + if (index < n) { T val = 0; - int w = int(index % width); - int h = int((index / width) % height); - int c = int(index / (width * height)); - if ((w - (int)padding_width) >= 0 && - (w - (int)padding_width) < (width - 2 * padding_width) && - (h - (int)padding_height) >= 0 && - (h - padding_height) < (height - 2 * padding_height)) { - // compute the start and end of the output - int w_col_start = (w < (int)filter_width) - ? 0 - : (w - int(filter_width)) / (int)stride_width + 1; - int w_col_end = - min((int)(w / (int)stride_width + 1), (int)(output_width)); - int h_col_start = (h < (int)filter_height) - ? 0 - : (h - (int)filter_height) / (int)stride_height + 1; - int h_col_end = min(int(h / stride_height + 1), int(output_height)); - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - // the col location: [c * width * height + h_out, w_out] - int c_col = int(c * filter_height * filter_width) + - (h - h_col * (int)stride_height) * (int)filter_width + - (w - w_col * (int)stride_width); - val += - data_col[(c_col * output_height + h_col) * output_width + w_col]; + int w = index % im_width; + int h = (index / im_width) % im_height; + int c = index / (im_width * im_height); + + // compute the start and end of the output + int w_col_start = + (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1; + int w_col_end = min(w / stride_width + 1, col_width); + int h_col_start = + (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1; + int h_col_end = min(h / stride_height + 1, col_height); + + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + int h_off = (h - h_col * stride_height); + int w_off = (w - w_col * stride_width); + if (h_off % dilation_h == 0 && w_off % dilation_w == 0) { + h_off /= dilation_h; + w_off /= dilation_w; + int data_col_index = + (((c * filter_height + h_off) * filter_width + w_off) * + col_height + + h_col) * + col_width + + w_col; + val += data_col[data_col_index]; } } - h -= padding_height; - w -= padding_width; - data_im[c * ((width - 2 * padding_width) * - (height - 2 * padding_height)) + - h * (width - 2 * padding_width) + w] += val; } + data_im[index] = val; } } @@ -160,32 +163,36 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; - int output_height = col.dims()[3]; - int output_width = col.dims()[4]; - - PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / - stride_height + - 1 == - output_height); - PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / - stride_width + - 1 == - output_width); - - size_t num_kernels = input_channels * - (input_height + padding_up + padding_down) * - (input_width + padding_left + padding_right); + int col_height = col.dims()[3]; + int col_width = col.dims()[4]; + + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + (dilation_h * (filter_height - 1) + 1)) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + (dilation_w * (filter_width - 1) + 1)) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); + + size_t num_kernels = im_channels * im_height * im_width; size_t blocks = (num_kernels + 1024 - 1) / 1024; size_t block_x = 512; @@ -198,10 +205,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), input_height + padding_up + padding_down, - input_width + padding_left + padding_left, input_channels, + num_kernels, col.data(), im_height, im_width, dilation_h, dilation_w, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width, im.data()); + padding_left, col_height, col_width, im.data()); } }; @@ -215,33 +221,32 @@ template class Col2ImFunctor; template -__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, - int input_height, int input_width, int filter_height, +__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels, + int im_height, int im_width, int filter_height, int filter_width, int stride_height, int stride_width, - int padding_height, int padding_width, - int output_height, int output_width) { + int padding_height, int padding_width, int col_height, + int col_width) { int swid = blockIdx.x; int shid = blockIdx.y; - for (int channelid = threadIdx.z; channelid < input_channels; + for (int channelid = threadIdx.z; channelid < im_channels; channelid += blockDim.z) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; int height_offset = idy + shid * stride_height - padding_height; - int im_offset = width_offset + height_offset * input_width + - channelid * input_height * input_width; + int im_offset = width_offset + height_offset * im_width + + channelid * im_height * im_width; int col_offset = idx + idy * filter_width + channelid * filter_height * filter_width + - (shid * output_width + swid) * - (input_channels * filter_height * filter_width); - - if (height_offset >= input_height || height_offset < 0 || - width_offset >= input_width || width_offset < 0) { - col_data[col_offset] = T(0); - } else { - col_data[col_offset] = im_data[im_offset]; - } + (shid * col_width + swid) * + (im_channels * filter_height * filter_width); + + col_data[col_offset] = + (height_offset >= im_height || height_offset < 0 || + width_offset >= im_width || width_offset < 0) + ? T(0) + : im_data[im_offset]; } } } @@ -258,26 +263,33 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), col.data(), input_channels, input_height, input_width, + im.data(), col.data(), im_channels, im_height, im_width, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width); + padding_left, col_height, col_width); } }; template -__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, - int input_height, int input_width, int filter_height, +__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels, + int im_height, int im_width, int filter_height, int filter_width, int stride_height, int stride_width, - int padding_height, int padding_width, - int output_height, int output_width) { + int padding_height, int padding_width, int col_height, + int col_width) { int swid = blockIdx.x; int shid = blockIdx.y; - for (int channelid = threadIdx.z; channelid < input_channels; + for (int channelid = threadIdx.z; channelid < im_channels; channelid += blockDim.z) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; int height_offset = idy + shid * stride_height - padding_height; - int im_offset = width_offset + height_offset * input_width + - channelid * input_height * input_width; + int im_offset = width_offset + height_offset * im_width + + channelid * im_height * im_width; int col_offset = idx + idy * filter_width + channelid * filter_height * filter_width + - (shid * output_width + swid) * - (input_channels * filter_height * filter_width); + (shid * col_width + swid) * + (im_channels * filter_height * filter_width); - if (height_offset >= 0 && height_offset < input_height && - width_offset >= 0 && width_offset < input_width) { + if (height_offset >= 0 && height_offset < im_height && + width_offset >= 0 && width_offset < im_width) { paddle::platform::CudaAtomicAdd(im_data + im_offset, col_data[col_offset]); } @@ -350,27 +361,33 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; - int output_height = col.dims()[0]; - int output_width = col.dims()[1]; - - PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / - stride_height + - 1 == - output_height); - PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / - stride_width + - 1 == - output_width); + int col_height = col.dims()[0]; + int col_width = col.dims()[1]; + + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + (dilation_h * (filter_height - 1) + 1)) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + (dilation_w * (filter_width - 1) + 1)) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); int block_dim_x = 0; int block_dim_y = 0; @@ -389,15 +406,14 @@ class Col2ImFunctor<<(context) .stream()>>>( - im.data(), col.data(), input_channels, input_height, input_width, + im.data(), col.data(), im_channels, im_height, im_width, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width); + padding_left, col_height, col_width); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index c736d4fa52..d1c9595a32 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,17 +74,18 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right); + int dilation_h, int dilation_w, int stride_height, + int stride_width, int padding_up, int padding_down, + int padding_left, int padding_right); }; template class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right); + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 5763782c4e..3385fe8721 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -47,6 +47,8 @@ void testIm2col() { int filter_size = 2; int stride = 1; int padding = 0; + int dilation_h = 1; + int dilation_w = 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1; float* input_ptr = input_tmp.mutable_data( @@ -85,10 +87,10 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, - padding); - im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, - padding, padding); + im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, + padding, padding, padding, padding); + im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, + stride, padding, padding, padding, padding); float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; @@ -131,8 +133,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, - padding); + col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, + padding, padding, padding, padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -153,8 +155,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, - padding, padding); + col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, + stride, padding, padding, padding, padding); if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); From b6f9ba484ee285b75d40272f8a2f48267fb3284c Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 8 Nov 2017 18:19:41 +0800 Subject: [PATCH 02/25] fix conv2d doc --- paddle/operators/conv_op.cc | 14 ++++++++++---- python/paddle/v2/framework/tests/test_conv2d_op.py | 5 ++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 852ac2ae37..a848b9b49c 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -54,6 +54,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < paddings.size(); ++i) { + PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] - + (dilations[i] * (filter_dims[i + 2] - 1) + 1) > + 0, + "Due to the settings of paddings, filter_dims and " + "dilations, the output size is less than 0, please check " + "again."); output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], paddings[i], paddings[i], strides[i])); @@ -100,11 +106,11 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, Convolution Operator. The convolution operation calculates the output based on the input, filter -and strides, paddings, groups parameters. The size of each dimension of the +and strides, paddings, groups, dilations parameters. The size of each dimension of the parameters is checked in the infer-shape. Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch size, C is the number of channels, H is the height of the feature, and W is -the width of the feature. Parameters(ksize, strides, paddings) are two elements. +the width of the feature. Parameters(ksize, strides, paddings, dilations) are two elements. These two elements represent height and width, respectively. The input(X) size and output(Out) size may be different. @@ -115,8 +121,8 @@ Example: Output: Output shape: (N, C_out, H_out, W_out) where - H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; - W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; + H_out = (H_in + 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)) / strides[0] + 1; + W_out = (W_in + 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)) / strides[1] + 1; )DOC"); } diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 04ae7f294c..f3f3930dab 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -39,6 +39,7 @@ class TestConv2dOp(OpTest): def setUp(self): self.init_op_type() self.init_group() + self.init_dilation() self.init_test_case() conv2d_param = {'stride': self.stride, 'pad': self.pad} @@ -80,12 +81,14 @@ class TestConv2dOp(OpTest): def init_test_case(self): self.pad = [0, 0] self.stride = [1, 1] - self.dilations = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3] + def init_dilation(self): + self.dilations = [1, 1] + def init_group(self): self.groups = 1 From 21ce704247b53e08cb092a7602f351464892f528 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 9 Nov 2017 11:02:04 +0800 Subject: [PATCH 03/25] refine conv2d for filter size:(1,1) --- paddle/operators/conv_op.h | 256 ++++++++++++------ .../v2/framework/tests/test_conv2d_op.py | 19 ++ 2 files changed, 192 insertions(+), 83 deletions(-) diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 2459f03a1a..8e9f3b0b0e 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -35,6 +35,18 @@ inline int OutputSize(int input_size, int filter_size, int dilation, 1; return output_size; } +inline bool NotExpand(std::vector& filter_dim, + std::vector& strides, std::vector& paddings, + std::vector& dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 &= (static_cast(filter_dim[j]) == 1); + strides_1 &= (strides[j] == 1); + padding_0 &= (paddings[j] == 0); + dilation_1 &= (dilations[j] == 1); + } + return filter_1 && strides_1 && padding_0 && dilation_1; +} // Define Op classes in .h file so that other conv // operator implementations can reuse the code. @@ -110,14 +122,17 @@ class GemmConvKernel : public framework::OpKernel { framework::DDim col_matrix_shape = framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; - col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. Tensor col_matrix; - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); + if (!not_expand) { + col.mutable_data(col_shape, context.GetPlace()); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } framework::DDim input_shape = framework::slice_ddim( input->dims(), 1, static_cast(input->dims().size())); @@ -134,31 +149,51 @@ class GemmConvKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output->dims()[1]) / groups; - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - // im2col - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], paddings[0], - paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - // vol2col - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + if (!not_expand) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (filter_shape_vec.size() == 2) { + // im2col + math::Im2ColFunctor im2col; + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); + } else if (filter_shape_vec.size() == 3) { + // vol2col + math::Vol2ColFunctor vol2col; + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); } + } + } else { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); + } } } } @@ -235,14 +270,17 @@ class GemmConvGradKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output_grad->dims()[1]) / groups; + bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. Tensor col_matrix; - col.mutable_data(col_shape, context.GetPlace()); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); + if (!not_expand) { + col.mutable_data(col_shape, context.GetPlace()); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } math::SetConstant set_zero; @@ -250,33 +288,60 @@ class GemmConvGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); - // col2im - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - - } else if (filter_shape_vec.size() == 3) { - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + if (!not_expand) { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = + filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + if (filter_shape_vec.size() == 2) { + math::Col2ImFunctor col2im; + col2im(context.device_context(), in_grad_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); + + } else if (filter_shape_vec.size() == 3) { + math::Col2VolFunctor col2vol; + col2vol(context.device_context(), in_grad_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + } + } + } else { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = + filter.Slice(g * out_step, (g + 1) * out_step); + + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + col_matrix.ShareDataWith(in_grad_slice); + col_matrix.Resize(col_matrix_shape); + + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); } } } @@ -288,34 +353,59 @@ class GemmConvGradKernel : public framework::OpKernel { filter_grad_.Resize(filter_matrix_shape); set_zero(context.device_context(), filter_grad, static_cast(0)); - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + if (!not_expand) { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (filter_shape_vec.size() == 2) { + math::Im2ColFunctor im2col; + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); + } else if (filter_shape_vec.size() == 3) { + math::Vol2ColFunctor vol2col; + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); + } + } + } else { + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); } - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); } } } diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index f3f3930dab..4ba67cf006 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -104,6 +104,25 @@ class TestWithGroup(TestConv2dOp): self.op_type = "conv2d" +class TestWith1x1(TestConv2dOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 1, 1] + + def init_dilation(self): + self.dilations = [1, 1] + + def init_group(self): + self.groups = 3 + + def init_op_type(self): + self.op_type = "conv2d" + + #----------------Conv2dCudnn---------------- From 93551bd232dacdc4afccb392f507eb48747c2978 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 9 Nov 2017 15:00:48 +0800 Subject: [PATCH 04/25] refine unit test (Add dilation) --- paddle/operators/math/im2col.cc | 12 ++-- .../v2/framework/tests/test_conv2d_op.py | 63 +++++++++++++++---- 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index b248863b4e..2af55fa71f 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -73,13 +73,13 @@ class Im2ColFunctor= im_height || im_col_idx < 0 || - im_col_idx >= im_width) - ? static_cast(0) - : im_data[(im_row_idx + c_im * im_height) * im_width + - im_col_idx]; + col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || + im_col_idx < 0 || im_col_idx >= im_width) + ? static_cast(0) + : im_data[im_idx]; } } } diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 4ba67cf006..907b52c405 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -10,23 +10,33 @@ def conv2d_forward_naive(input, filter, group, conv_param): assert np.mod(out_c, group) == 0 sub_out_c = out_c / group - stride, pad = conv_param['stride'], conv_param['pad'] - out_h = 1 + (in_h + 2 * pad[0] - f_h) / stride[0] - out_w = 1 + (in_w + 2 * pad[1] - f_w) / stride[1] + stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ + 'dilation'] + out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) / stride[0] + out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) / stride[1] out = np.zeros((in_n, out_c, out_h, out_w)) + d_bolck_w = (dilation[0] * (f_h - 1) + 1) + d_bolck_h = (dilation[1] * (f_w - 1) + 1) + input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], )), mode='constant', constant_values=0) + + filter_dilation = np.zeros((out_c, f_c, d_bolck_h, d_bolck_w)) + filter_dilation[:, :, 0:d_bolck_h:dilation[0], 0:d_bolck_w:dilation[ + 1]] = filter + for i in range(out_h): for j in range(out_w): for g in range(group): input_pad_masked = \ input_pad[:, g * f_c:(g + 1) * f_c, - i * stride[0]:i * stride[0] + f_h, - j * stride[1]:j * stride[1] + f_w] + i * stride[0]:i * stride[0] + d_bolck_h, + j * stride[1]:j * stride[1] + d_bolck_w] - f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :] + f_sub = filter_dilation[g * sub_out_c:(g + 1) * + sub_out_c, :, :, :] for k in range(sub_out_c): out[:, g * sub_out_c + k, i, j] = \ np.sum(input_pad_masked * f_sub[k, :, :, :], @@ -42,7 +52,11 @@ class TestConv2dOp(OpTest): self.init_dilation() self.init_test_case() - conv2d_param = {'stride': self.stride, 'pad': self.pad} + conv2d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilation': self.dilations + } input = np.random.random(self.input_size).astype("float32") filter = np.random.random(self.filter_size).astype("float32") output = conv2d_forward_naive(input, filter, self.groups, @@ -123,24 +137,47 @@ class TestWith1x1(TestConv2dOp): self.op_type = "conv2d" -#----------------Conv2dCudnn---------------- +class TestWithDilation(TestConv2dOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 10, 10] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 3, 3] + def init_dilation(self): + self.dilations = [2, 2] -class TestCudnn(TestConv2dOp): def init_group(self): - self.groups = 1 + self.groups = 3 + def init_op_type(self): + self.op_type = "conv2d" + + +#----------------Conv2dCudnn---------------- + + +class TestCudnn(TestConv2dOp): def init_op_type(self): self.op_type = "conv_cudnn" -class TestCudnnWithGroup(TestConv2dOp): - def init_group(self): - self.groups = 3 +class TestCudnnWithGroup(TestWithGroup): + def init_op_type(self): + self.op_type = "conv_cudnn" + +class TestCudnnWith1x1(TestWith1x1): def init_op_type(self): self.op_type = "conv_cudnn" +# cudnn v5 does not support dilation conv. +# class TestCudnnWithDilation(TestWithDilation): +# def init_op_type(self): +# self.op_type = "conv_cudnn" + if __name__ == '__main__': unittest.main() From 271fc9c1198e90813fee647b7020ee752aae549a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 10 Nov 2017 10:25:44 +0800 Subject: [PATCH 05/25] Add dilation for vol2col --- paddle/operators/conv_op.h | 15 +-- paddle/operators/conv_transpose_op.h | 13 ++- paddle/operators/math/im2col.cu | 1 + paddle/operators/math/vol2col.cc | 80 ++++++++++++--- paddle/operators/math/vol2col.cu | 139 +++++++++++++++++++------- paddle/operators/math/vol2col.h | 2 + paddle/operators/math/vol2col_test.cc | 9 +- 7 files changed, 189 insertions(+), 70 deletions(-) diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 8e9f3b0b0e..af2c8fb163 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -165,9 +165,9 @@ class GemmConvKernel : public framework::OpKernel { } else if (filter_shape_vec.size() == 3) { // vol2col math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], dilations[2], strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); } // gemm @@ -314,7 +314,8 @@ class GemmConvGradKernel : public framework::OpKernel { } else if (filter_shape_vec.size() == 3) { math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, strides[0], + col2vol(context.device_context(), in_grad_slice, col, + dilations[0], dilations[1], dilations[2], strides[0], strides[1], strides[2], paddings[0], paddings[1], paddings[2]); } @@ -371,9 +372,9 @@ class GemmConvGradKernel : public framework::OpKernel { paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], dilations[2], strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); } // gemm diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index cbfad88b39..18ca6b20e0 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -69,6 +69,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. + int dilaiton_d = 1; int dilation_h = 1; int dilation_w = 1; @@ -149,8 +150,9 @@ class GemmConvTransposeKernel : public framework::OpKernel { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) math::Col2VolFunctor col2vol; - col2vol(context.device_context(), output_batch, col, strides[0], - strides[1], strides[2], 0, 0, 0); + col2vol(context.device_context(), output_batch, col, dilaiton_d, + dilation_h, dilation_w, strides[0], strides[1], strides[2], 0, + 0, 0); } } } @@ -177,6 +179,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); + int dilaiton_d = 1; int dilation_h = 1; int dilation_w = 1; @@ -261,9 +264,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), output_grad_batch, col, dilaiton_d, + dilation_h, dilation_w, strides[0], strides[1], strides[2], + paddings[0], paddings[1], paddings[2]); } if (input_grad) { diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 69e2abee03..9da427fdf1 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -145,6 +145,7 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width, h_col) * col_width + w_col; + val += data_col[data_col_index]; } } diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index e9718a0473..d383ee8152 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -29,6 +29,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -48,6 +49,28 @@ class Vol2ColFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + const T* vol_data = vol.data(); T* col_data = col.data(); @@ -57,24 +80,25 @@ class Vol2ColFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int c_in = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset; + int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; for (int h = 0; h < output_height; ++h) { - int h_pad = h * stride_height - padding_height + h_offset; + int h_pad = + h * stride_height - padding_height + h_offset * dilation_h; for (int w = 0; w < output_width; ++w) { - int w_pad = w * stride_width - padding_width + w_offset; + int w_pad = + w * stride_width - padding_width + w_offset * dilation_w; int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; - if (h_pad < 0 || h_pad >= input_height || w_pad < 0 || - w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) { - col_data[col_idx] = static_cast(0); - } else { - int vol_idx = - ((c_in * input_depth + d_pad) * input_height + h_pad) * - input_width + - w_pad; - col_data[col_idx] = vol_data[vol_idx]; - } + int vol_idx = + ((c_in * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + col_data[col_idx] = + (h_pad < 0 || h_pad >= input_height || w_pad < 0 || + w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) + ? static_cast(0) + : vol_data[vol_idx]; } } } @@ -93,6 +117,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -112,6 +137,27 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); T* vol_data = vol.data(); const T* col_data = col.data(); @@ -121,11 +167,13 @@ class Col2VolFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int cIm = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset; + int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; for (int h = 0; h < output_height; ++h) { - int h_pad = h * stride_height - padding_height + h_offset; + int h_pad = + h * stride_height - padding_height + h_offset * dilation_h; for (int w = 0; w < output_width; ++w) { - int w_pad = w * stride_width - padding_width + w_offset; + int w_pad = + w * stride_width - padding_width + w_offset * dilation_w; if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index 27b11fb237..080d3e5466 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -21,11 +21,12 @@ namespace math { template __global__ void vol2col(int num_kernels, const T* data_vol, int depth, - int height, int width, int filter_depth, - int filter_height, int filter_width, int stride_depth, - int stride_height, int stride_width, int padding_depth, - int padding_height, int padding_width, int output_detph, - int output_height, int output_width, T* data_col) { + int height, int width, int dilation_d, int dilation_h, + int dilation_w, int filter_depth, int filter_height, + int filter_width, int stride_depth, int stride_height, + int stride_width, int padding_depth, int padding_height, + int padding_width, int output_detph, int output_height, + int output_width, T* data_col) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { int w_out = index % output_width; @@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, for (int k = 0; k < filter_depth; ++k) { for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { - int d = d_in + k; - int h = h_in + i; - int w = w_in + j; + int d = d_in + k * dilation_d; + int h = h_in + i * dilation_h; + int w = w_in + j * dilation_w; + int col_idx = (k * dilation_d * height + i * dilation_h) * width + + j * dilation_w; *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width) - ? data_vol[(k * height + i) * width + j] + ? data_vol[col_idx] : 0; data_col += output_detph * output_height * output_width; } @@ -69,6 +72,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -86,6 +90,28 @@ class Vol2ColFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + int num_outputs = input_channels * output_depth * output_height * output_width; @@ -95,19 +121,25 @@ class Vol2ColFunctor { reinterpret_cast(context) .stream()>>>( num_outputs, vol.data(), input_depth, input_height, input_width, - filter_depth, filter_height, filter_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - output_depth, output_height, output_width, col.data()); + dilation_d, dilation_h, dilation_w, filter_depth, filter_height, + filter_width, stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width, output_depth, output_height, + output_width, col.data()); } }; template __global__ void col2vol(int num_kernels, const T* data_col, int depth, - int height, int width, int filter_depth, - int filter_height, int filter_width, int stride_depth, - int stride_height, int stride_width, int padding_depth, - int padding_height, int padding_width, int output_detph, - int output_height, int output_width, T* data_vol) { + int height, int width, int dilation_d, int dilation_h, + int dilation_w, int filter_depth, int filter_height, + int filter_width, int stride_depth, int stride_height, + int stride_width, int padding_depth, int padding_height, + int padding_width, int output_detph, int output_height, + int output_width, T* data_vol) { + const int d_filter_depth = dilation_d * (filter_depth - 1) + 1; + const int d_filter_height = dilation_h * (filter_height - 1) + 1; + const int d_filter_width = dilation_w * (filter_width - 1) + 1; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { T src_val = 0; @@ -115,35 +147,42 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, int h = (index / width) % height + padding_height; int d = (index / width / height) % depth + padding_depth; int c = index / width / height / depth; + // compute the start and end of the output int w_col_start = - (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1; + (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1; int w_col_end = min(w / stride_width + 1, output_width); int h_col_start = - (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1; + (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1; int h_col_end = min(h / stride_height + 1, output_height); int d_col_start = - (d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1; + (d < d_filter_depth) ? 0 : (d - d_filter_depth) / stride_depth + 1; int d_col_end = min(d / stride_depth + 1, output_detph); - int offset = (c * filter_depth * filter_height * filter_width + - d * filter_width * filter_height + h * filter_width + w) * - output_detph * output_height * output_width; - - int coeff_d_col = - (1 - stride_depth * filter_width * filter_height * output_detph) * - output_height * output_width; - int coeff_h_col = - (1 - stride_height * filter_width * output_detph * output_height) * - output_width; - int coeff_w_col = - (1 - stride_width * output_detph * output_height * output_width); - for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - src_val += data_col[offset + d_col * coeff_d_col + - h_col * coeff_h_col + w_col * coeff_w_col]; + int d_off = (d - d_col * stride_depth); + int h_off = (h - h_col * stride_height); + int w_off = (w - w_col * stride_width); + if (d_off % dilation_d == 0 && h_off % dilation_h == 0 && + w_off % dilation_w == 0) { + d_off /= dilation_d; + h_off /= dilation_h; + w_off /= dilation_w; + + int data_col_index = + (((((c * filter_depth + d_off) * filter_height + h_off) * + filter_width + + w_off) * + output_detph + + d_col) * + output_height + + h_col) * + output_width + + w_col; + src_val += data_col[data_col_index]; + } } } } @@ -162,6 +201,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -179,6 +219,28 @@ class Col2VolFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + int num_kernels = input_channels * input_depth * input_height * input_width; const int threads = 1024; @@ -188,9 +250,10 @@ class Col2VolFunctor { reinterpret_cast(context) .stream()>>>( num_kernels, col.data(), input_depth, input_height, input_width, - filter_depth, filter_height, filter_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - output_depth, output_height, output_width, vol.data()); + dilation_d, dilation_h, dilation_w, filter_depth, filter_height, + filter_width, stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width, output_depth, output_height, + output_width, vol.data()); } }; diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h index f022365a16..c2d8257c0b 100644 --- a/paddle/operators/math/vol2col.h +++ b/paddle/operators/math/vol2col.h @@ -58,6 +58,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const; @@ -68,6 +69,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const; diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 74590d17cd..9d673ad36c 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -64,6 +64,7 @@ void testVol2col() { int filter_size = 2; int stride = 1; int padding = 0; + int dilation = 1; int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1; @@ -85,8 +86,8 @@ void testVol2col() { *place); paddle::operators::math::Vol2ColFunctor vol2col; - vol2col(*context, input, output, stride, stride, stride, padding, padding, - padding); + vol2col(*context, input, output, dilation, dilation, dilation, stride, stride, + stride, padding, padding, padding); float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; @@ -111,8 +112,8 @@ void testVol2col() { } paddle::operators::math::Col2VolFunctor col2vol; - col2vol(*context, input, output, stride, stride, stride, padding, padding, - padding); + col2vol(*context, input, output, dilation, dilation, dilation, stride, stride, + stride, padding, padding, padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { From 7d73b8fc8e7080b02167808a1a71bd4219089b88 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 10 Nov 2017 11:33:12 +0800 Subject: [PATCH 06/25] fix unit test (conv3d) --- paddle/operators/math/vol2col.cc | 1 + .../v2/framework/tests/test_conv3d_op.py | 84 ++++++++++++++----- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index d383ee8152..bd509a94f3 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -181,6 +181,7 @@ class Col2VolFunctor { ((cIm * input_depth + d_pad) * input_height + h_pad) * input_width + w_pad; + int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py index 44c192f58d..934ea46437 100644 --- a/python/paddle/v2/framework/tests/test_conv3d_op.py +++ b/python/paddle/v2/framework/tests/test_conv3d_op.py @@ -10,27 +10,40 @@ def conv3d_forward_naive(input, filter, group, conv_param): assert np.mod(out_c, group) == 0 sub_out_c = out_c / group - stride, pad = conv_param['stride'], conv_param['pad'] - out_d = 1 + (in_d + 2 * pad[0] - f_h) / stride[0] - out_h = 1 + (in_h + 2 * pad[1] - f_h) / stride[1] - out_w = 1 + (in_w + 2 * pad[2] - f_w) / stride[2] + stride, pad, dilation = conv_param['stride'], conv_param['pad'], conv_param[ + 'dilations'] + + out_d = 1 + (in_d + 2 * pad[0] - (dilation[0] * (f_d - 1) + 1)) / stride[0] + out_h = 1 + (in_h + 2 * pad[1] - (dilation[1] * (f_h - 1) + 1)) / stride[1] + out_w = 1 + (in_w + 2 * pad[2] - (dilation[2] * (f_w - 1) + 1)) / stride[2] + out = np.zeros((in_n, out_c, out_d, out_h, out_w)) + d_bolck_d = (dilation[0] * (f_d - 1) + 1) + d_bolck_h = (dilation[1] * (f_h - 1) + 1) + d_bolck_w = (dilation[2] * (f_w - 1) + 1) + input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ), (pad[2], )), mode='constant', constant_values=0) + + filter_dilation = np.zeros((out_c, f_c, d_bolck_d, d_bolck_h, d_bolck_w)) + filter_dilation[:, :, 0:d_bolck_d:dilation[0], 0:d_bolck_h:dilation[1], 0: + d_bolck_w:dilation[2]] = filter + for d in range(out_d): for i in range(out_h): for j in range(out_w): for g in range(group): input_pad_masked = \ input_pad[:, g * f_c:(g + 1) * f_c, - d * stride[0]:d * stride[0] + f_d, - i * stride[1]:i * stride[1] + f_h, - j * stride[2]:j * stride[2] + f_w] - f_sub = filter[g * sub_out_c:(g + 1) * - sub_out_c, :, :, :, :] + d * stride[0]:d * stride[0] + d_bolck_d, + i * stride[1]:i * stride[1] + d_bolck_h, + j * stride[2]:j * stride[2] + d_bolck_w] + + f_sub = filter_dilation[g * sub_out_c:(g + 1) * + sub_out_c, :, :, :, :] for k in range(sub_out_c): out[:, g * sub_out_c + k, d, i, j] = \ np.sum(input_pad_masked * f_sub[k, :, :, :, :], @@ -43,9 +56,14 @@ class TestConv3dOp(OpTest): def setUp(self): self.init_group() self.init_op_type() + self.init_dilation() self.init_test_case() - conv3d_param = {'stride': self.stride, 'pad': self.pad} + conv3d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilations': self.dilations + } input = np.random.random(self.input_size).astype("float32") filter = np.random.random(self.filter_size).astype("float32") output = conv3d_forward_naive(input, filter, self.groups, @@ -55,7 +73,8 @@ class TestConv3dOp(OpTest): self.attrs = { 'strides': self.stride, 'paddings': self.pad, - 'groups': self.groups + 'groups': self.groups, + 'dilations': self.dilations } self.outputs = {'Output': output} @@ -88,6 +107,9 @@ class TestConv3dOp(OpTest): f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3, 3] + def init_dilation(self): + self.dilations = [1, 1, 1] + def init_group(self): self.groups = 1 @@ -104,27 +126,47 @@ class TestCase1(TestConv3dOp): f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3, 3] - def init_group(self): - self.groups = 1 - def init_op_type(self): - self.op_type = "conv3d" +class TestWithGroup1(TestConv3dOp): + def init_group(self): + self.groups = 3 -class TestWithGroup1(TestConv3dOp): +class TestWithGroup2(TestCase1): def init_group(self): self.groups = 3 - def init_op_type(self): - self.op_type = "conv3d" +class TestWith1x1(TestConv3dOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 1, 1, 1] + + def init_dilation(self): + self.dilations = [1, 1, 1] -class TestWithGroup2(TestCase1): def init_group(self): self.groups = 3 - def init_op_type(self): - self.op_type = "conv3d" + +class TestWithDilation(TestConv3dOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 6, 6, 6] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 2, 2, 2] + + def init_dilation(self): + self.dilations = [2, 2, 2] + + def init_group(self): + self.groups = 3 if __name__ == '__main__': From f5e367655eadf224d1bfd3765564deeefb35ed6b Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Sat, 11 Nov 2017 19:38:35 +0800 Subject: [PATCH 07/25] Use G++ to compile some cu operators. --- paddle/operators/CMakeLists.txt | 14 ++- .../{batch_norm_op.cu => batch_norm_op.cu.cc} | 0 .../{concat_op.cu => concat_op.cu.cc} | 0 ..._op.cu => conv2d_transpose_cudnn_op.cu.cc} | 9 +- .../{conv_cudnn_op.cu => conv_cudnn_op.cu.cc} | 0 .../operators/{conv_op.cu => conv_op.cu.cc} | 0 ...ranspose_op.cu => conv_transpose_op.cu.cc} | 0 ...=> fill_constant_batch_size_like_op.cu.cc} | 2 +- ...os_like_op.cu => fill_zeros_like_op.cu.cc} | 2 +- paddle/operators/{gru_op.cu => gru_op.cu.cc} | 1 - paddle/operators/gru_op.h | 54 ++++---- .../operators/{lstm_op.cu => lstm_op.cu.cc} | 1 - paddle/operators/lstm_op.h | 19 +-- paddle/operators/math/context_project.h | 28 ++--- paddle/operators/math/math_function.cc | 28 +++++ paddle/operators/math/math_function.cu | 35 ++++++ paddle/operators/math/math_function.h | 17 ++- paddle/operators/math/math_function_impl.h | 48 +++++++ paddle/operators/math/sequence2batch.cc | 23 ++++ paddle/operators/math/sequence2batch.cu | 32 +++++ paddle/operators/math/sequence2batch.h | 12 ++ .../{matmul_op.cu => matmul_op.cu.cc} | 0 paddle/operators/matmul_op.h | 10 +- paddle/operators/{mul_op.cu => mul_op.cu.cc} | 0 .../operators/{nccl_op.cu => nccl_op.cu.cc} | 0 .../{nccl_op_test.cu => nccl_op_test.cu.cc} | 0 .../{pool_cudnn_op.cu => pool_cudnn_op.cu.cc} | 0 .../operators/{pool_op.cu => pool_op.cu.cc} | 0 ...h_index_op.cu => pool_with_index_op.cu.cc} | 0 paddle/operators/pool_with_index_op.h | 13 +- .../{reshape_op.cu => reshape_op.cu.cc} | 0 ..._concat_op.cu => sequence_concat_op.cu.cc} | 0 ...ence_conv_op.cu => sequence_conv_op.cu.cc} | 2 - paddle/operators/sequence_conv_op.h | 3 +- ...oftmax_op.cu => sequence_softmax_op.cu.cc} | 0 .../{softmax_op.cu => softmax_op.cu.cc} | 0 paddle/operators/softmax_op.h | 3 + .../operators/{split_op.cu => split_op.cu.cc} | 0 .../{transpose_op.cu => transpose_op.cu.cc} | 0 paddle/operators/transpose_op.h | 119 +++++++----------- paddle/platform/dynload/cublas.h | 2 + .../paddle/v2/framework/tests/test_lstm_op.py | 3 +- .../v2/framework/tests/test_seq_conv.py | 57 +++++---- 43 files changed, 338 insertions(+), 199 deletions(-) rename paddle/operators/{batch_norm_op.cu => batch_norm_op.cu.cc} (100%) rename paddle/operators/{concat_op.cu => concat_op.cu.cc} (100%) rename paddle/operators/{conv2d_transpose_cudnn_op.cu => conv2d_transpose_cudnn_op.cu.cc} (96%) rename paddle/operators/{conv_cudnn_op.cu => conv_cudnn_op.cu.cc} (100%) rename paddle/operators/{conv_op.cu => conv_op.cu.cc} (100%) rename paddle/operators/{conv_transpose_op.cu => conv_transpose_op.cu.cc} (100%) rename paddle/operators/{fill_constant_batch_size_like_op.cu => fill_constant_batch_size_like_op.cu.cc} (100%) rename paddle/operators/{fill_zeros_like_op.cu => fill_zeros_like_op.cu.cc} (100%) rename paddle/operators/{gru_op.cu => gru_op.cu.cc} (97%) rename paddle/operators/{lstm_op.cu => lstm_op.cu.cc} (97%) create mode 100644 paddle/operators/math/math_function_impl.h rename paddle/operators/{matmul_op.cu => matmul_op.cu.cc} (100%) rename paddle/operators/{mul_op.cu => mul_op.cu.cc} (100%) rename paddle/operators/{nccl_op.cu => nccl_op.cu.cc} (100%) rename paddle/operators/{nccl_op_test.cu => nccl_op_test.cu.cc} (100%) rename paddle/operators/{pool_cudnn_op.cu => pool_cudnn_op.cu.cc} (100%) rename paddle/operators/{pool_op.cu => pool_op.cu.cc} (100%) rename paddle/operators/{pool_with_index_op.cu => pool_with_index_op.cu.cc} (100%) rename paddle/operators/{reshape_op.cu => reshape_op.cu.cc} (100%) rename paddle/operators/{sequence_concat_op.cu => sequence_concat_op.cu.cc} (100%) rename paddle/operators/{sequence_conv_op.cu => sequence_conv_op.cu.cc} (97%) rename paddle/operators/{sequence_softmax_op.cu => sequence_softmax_op.cu.cc} (100%) rename paddle/operators/{softmax_op.cu => softmax_op.cu.cc} (100%) rename paddle/operators/{split_op.cu => split_op.cu.cc} (100%) rename paddle/operators/{transpose_op.cu => transpose_op.cu.cc} (100%) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 29ce44c233..7eb8b3539f 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -9,6 +9,7 @@ function(op_library TARGET) set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE) set(cc_srcs) set(cu_srcs) + set(cu_cc_srcs) set(op_common_deps operator op_registry math_function) set(options "") set(oneValueArgs "") @@ -22,6 +23,9 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) list(APPEND cc_srcs ${TARGET}.cc) endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) + list(APPEND cu_cc_srcs ${TARGET}.cu.cc) + endif() if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) list(APPEND cu_srcs ${TARGET}.cu) endif() @@ -29,6 +33,8 @@ function(op_library TARGET) foreach(src ${op_library_SRCS}) if (${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) + elseif(${src} MATCHES ".*\\.cu.cc$") + list(APPEND cu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") list(APPEND cc_srcs ${src}) else() @@ -43,7 +49,7 @@ function(op_library TARGET) endif() if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) else() cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS} @@ -140,7 +146,9 @@ function(op_library TARGET) # pybind USE_CPU_ONLY_OP list(LENGTH cu_srcs cu_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0) + list(LENGTH cu_cc_srcs cu_cc_srcs_len) + + if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") set(pybind_flag 1) endif() @@ -219,6 +227,6 @@ cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc rnn/recurrent_op_utils.cc DEPS dynamic_recurrent_op) if(WITH_GPU) - nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context) + cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) endif() cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) diff --git a/paddle/operators/batch_norm_op.cu b/paddle/operators/batch_norm_op.cu.cc similarity index 100% rename from paddle/operators/batch_norm_op.cu rename to paddle/operators/batch_norm_op.cu.cc diff --git a/paddle/operators/concat_op.cu b/paddle/operators/concat_op.cu.cc similarity index 100% rename from paddle/operators/concat_op.cu rename to paddle/operators/concat_op.cu.cc diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu.cc similarity index 96% rename from paddle/operators/conv2d_transpose_cudnn_op.cu rename to paddle/operators/conv2d_transpose_cudnn_op.cu.cc index 694526ec01..eff058afc6 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cu +++ b/paddle/operators/conv2d_transpose_cudnn_op.cu.cc @@ -200,9 +200,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { T alpha = 1.0f, beta = 0.0f; if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - auto t = framework::EigenVector::Flatten(*input_grad); - t.device(ctx.GetEigenDevice()) = - t.constant(static_cast(0)); + math::set_constant(ctx.device_context(), input_grad, 0); PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( handle, &alpha, cudnn_output_desc, output_grad_data, @@ -214,9 +212,8 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { // ------------------- cudnn conv backward filter --------------------- if (filter_grad) { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); - auto t = framework::EigenVector::Flatten(*filter_grad); - t.device(ctx.GetEigenDevice()) = - t.constant(static_cast(0)); + math::set_constant(ctx.device_context(), filter_grad, 0); + // Gradient with respect to the filter PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc, diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu.cc similarity index 100% rename from paddle/operators/conv_cudnn_op.cu rename to paddle/operators/conv_cudnn_op.cu.cc diff --git a/paddle/operators/conv_op.cu b/paddle/operators/conv_op.cu.cc similarity index 100% rename from paddle/operators/conv_op.cu rename to paddle/operators/conv_op.cu.cc diff --git a/paddle/operators/conv_transpose_op.cu b/paddle/operators/conv_transpose_op.cu.cc similarity index 100% rename from paddle/operators/conv_transpose_op.cu rename to paddle/operators/conv_transpose_op.cu.cc diff --git a/paddle/operators/fill_constant_batch_size_like_op.cu b/paddle/operators/fill_constant_batch_size_like_op.cu.cc similarity index 100% rename from paddle/operators/fill_constant_batch_size_like_op.cu rename to paddle/operators/fill_constant_batch_size_like_op.cu.cc index 298c196f1d..87e3697e28 100644 --- a/paddle/operators/fill_constant_batch_size_like_op.cu +++ b/paddle/operators/fill_constant_batch_size_like_op.cu.cc @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/op_registry.h" #include "paddle/operators/fill_constant_batch_size_like_op.h" +#include "paddle/framework/op_registry.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu.cc similarity index 100% rename from paddle/operators/fill_zeros_like_op.cu rename to paddle/operators/fill_zeros_like_op.cu.cc index a6d4ba64bd..2adb40cf90 100644 --- a/paddle/operators/fill_zeros_like_op.cu +++ b/paddle/operators/fill_zeros_like_op.cu.cc @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/op_registry.h" #include "paddle/operators/fill_zeros_like_op.h" +#include "paddle/framework/op_registry.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( diff --git a/paddle/operators/gru_op.cu b/paddle/operators/gru_op.cu.cc similarity index 97% rename from paddle/operators/gru_op.cu rename to paddle/operators/gru_op.cu.cc index 35538c74b4..0ceff94ec3 100644 --- a/paddle/operators/gru_op.cu +++ b/paddle/operators/gru_op.cu.cc @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/operators/gru_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index ba90ec9816..437496e0ac 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -27,10 +27,6 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -template -using EigenMatrix = framework::EigenMatrix; - template class GRUKernel : public framework::OpKernel { public: @@ -57,19 +53,15 @@ class GRUKernel : public framework::OpKernel { bool is_reverse = context.Attr("is_reverse"); math::LoDTensor2BatchFunctor to_batch; - to_batch(context.device_context(), *input, *batch_gate, true, is_reverse); + auto& dev_ctx = context.device_context(); + to_batch(dev_ctx, *input, *batch_gate, true, is_reverse); - int frame_size = hidden_dims[1]; - int batch_size = hidden_dims[0]; - auto g = EigenMatrix::From(*batch_gate); - auto place = context.GetEigenDevice(); if (bias) { - auto b = EigenMatrix::From(*bias); - g.device(place) = g + - b.reshape(Eigen::array({{1, frame_size * 3}})) - .broadcast(Eigen::array({{batch_size, 1}})); + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); } + int frame_size = hidden_dims[1]; math::hl_gru_value gru_value; gru_value.gateWeight = const_cast(weight_data); gru_value.stateWeight = @@ -89,7 +81,7 @@ class GRUKernel : public framework::OpKernel { gru_value.gateValue = gate_t.data(); gru_value.resetOutputValue = reset_hidden_prev_t.data(); math::GRUUnitFunctor::compute( - context.device_context(), gru_value, frame_size, cur_batch_size, + dev_ctx, gru_value, frame_size, cur_batch_size, math::ActiveType(context.Attr("activation")), math::ActiveType(context.Attr("gate_activation"))); gru_value.prevOutValue = gru_value.outputValue; @@ -97,7 +89,7 @@ class GRUKernel : public framework::OpKernel { math::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); - to_seq(context.device_context(), *batch_hidden, *hidden); + to_seq(dev_ctx, *batch_hidden, *hidden); } void Compute(const framework::ExecutionContext& context) const override { @@ -138,15 +130,14 @@ class GRUGradKernel : public framework::OpKernel { batch_reset_hidden_prev_grad.mutable_data(hidden_dims, context.GetPlace()); math::SetConstant zero; - zero(context.device_context(), &batch_hidden_grad, static_cast(0.0)); - zero(context.device_context(), &batch_gate_grad, static_cast(0.0)); - zero(context.device_context(), &batch_reset_hidden_prev_grad, - static_cast(0.0)); + auto& dev_ctx = context.device_context(); + zero(dev_ctx, &batch_hidden_grad, static_cast(0.0)); + zero(dev_ctx, &batch_gate_grad, static_cast(0.0)); + zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0.0)); bool is_reverse = context.Attr("is_reverse"); batch_hidden_grad.set_lod(batch_hidden->lod()); - to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false, - is_reverse); + to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); math::hl_gru_value gru_value; gru_value.gateWeight = const_cast(weight_data); @@ -157,7 +148,7 @@ class GRUGradKernel : public framework::OpKernel { if (weight_grad) { gru_grad.gateWeightGrad = weight_grad->mutable_data(context.GetPlace()); - zero(context.device_context(), weight_grad, static_cast(0.0)); + zero(dev_ctx, weight_grad, static_cast(0.0)); gru_grad.stateWeightGrad = weight_grad->data() + 2 * frame_size * frame_size; } else { @@ -188,7 +179,7 @@ class GRUGradKernel : public framework::OpKernel { gru_value.prevOutValue = const_cast(h0_data); if (h0_grad) { T* h0_grad_data = h0_grad->mutable_data(context.GetPlace()); - zero(context.device_context(), h0_grad, static_cast(0.0)); + zero(dev_ctx, h0_grad, static_cast(0.0)); gru_grad.prevOutGrad = h0_grad_data; } else { gru_grad.prevOutGrad = nullptr; @@ -202,8 +193,7 @@ class GRUGradKernel : public framework::OpKernel { } math::GRUUnitGradFunctor::compute( - context.device_context(), gru_value, gru_grad, frame_size, - cur_batch_size, + dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, math::ActiveType(context.Attr("activation")), math::ActiveType(context.Attr("gate_activation"))); } @@ -211,14 +201,18 @@ class GRUGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); math::Batch2LoDTensorFunctor to_seq; batch_gate_grad.set_lod(batch_gate->lod()); - to_seq(context.device_context(), batch_gate_grad, *input_grad); + to_seq(dev_ctx, batch_gate_grad, *input_grad); } if (bias_grad) { bias_grad->mutable_data(context.GetPlace()); - auto d_b = EigenMatrix::From(*bias_grad); - auto d_g = EigenMatrix::From(batch_gate_grad); - auto place = context.GetEigenDevice(); - d_b.device(place) = d_g.sum(Eigen::array({{0}})); + int m = static_cast(batch_gate_grad.dims()[0]); + int n = static_cast(batch_gate_grad.dims()[1]); + Tensor ones; + ones.mutable_data({m}, context.GetPlace()); + math::SetConstant set; + set(dev_ctx, &ones, static_cast(1)); + math::gemv(dev_ctx, true, m, n, 1., batch_gate_grad.data(), + ones.data(), 0., bias_grad->data()); } } diff --git a/paddle/operators/lstm_op.cu b/paddle/operators/lstm_op.cu.cc similarity index 97% rename from paddle/operators/lstm_op.cu rename to paddle/operators/lstm_op.cu.cc index 9ad5694155..610cbb03e8 100644 --- a/paddle/operators/lstm_op.cu +++ b/paddle/operators/lstm_op.cu.cc @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/operators/lstm_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index fca84e2d8f..58fedaee9a 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -24,10 +24,6 @@ namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; - template inline void ReorderInitState(const platform::DeviceContext& ctx, const framework::Tensor& src, const size_t* index, @@ -65,16 +61,11 @@ class LSTMKernel : public framework::OpKernel { framework::DDim dims({in_dims[0], frame_size}); if (bias) { - Eigen::array extents({{1, 4 * frame_size}}); - Eigen::array offsets({{0, 0}}); - auto b = EigenMatrix::From(*bias); - auto gate = EigenMatrix::From(*batch_gate); - gate.device(ctx.GetEigenDevice()) = - gate + - b.slice(offsets, extents) - .reshape(Eigen::array({{1, frame_size * 4}})) - .broadcast( - Eigen::array({{static_cast(in_dims[0]), 1}})); + Tensor b = *bias; + b.Resize({bias->numel(), 1}); + Tensor gate_bias = b.Slice(0, 4 * frame_size); + math::RowwiseAdd add_bias; + add_bias(device_ctx, *batch_gate, gate_bias, batch_gate); } math::LstmMetaValue lstm_value; diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index e028336041..7dc76d0c60 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/lod_tensor.h" #include "paddle/operators/math/im2col.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { @@ -24,9 +24,6 @@ namespace math { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -template -using EigenMatrix = framework::EigenMatrix; /* * \brief Context projection concatenates features in adjacent time-steps in @@ -94,6 +91,9 @@ class ContextProjectFunctor { auto lod_level_0 = in.lod()[0]; math::Im2ColFunctor im2col_ocf; + if (platform::is_gpu_place(context.GetPlace())) { + LOG(INFO) << "========= gpu =========="; + } int input_row_begin, input_row_end; int sequence_height, sequence_width; @@ -150,9 +150,7 @@ class ContextProjectFunctor { Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); Tensor w_sub = padding_data.Slice(k, k + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - out_t_sub_e.device(*context.GetEigenDevice()) = w_sub_e; + out_t_sub.CopyFrom(w_sub, context.GetPlace(), context); } } if (down_pad > 0) { // add down pad @@ -182,9 +180,7 @@ class ContextProjectFunctor { (down_pad_begin_row + t) * context_length); Tensor w_sub = padding_data.Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - out_t_sub_e.device(*context.GetEigenDevice()) = w_sub_e; + out_t_sub.CopyFrom(w_sub, context.GetPlace(), context); } } out_t.Resize({sequence_height, context_length * sequence_width}); @@ -260,10 +256,8 @@ class ContextProjectGradFunctor { Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); Tensor w_sub = padding_data.Slice(k, k + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - w_sub_e.device(*context.GetEigenDevice()) = - w_sub_e + out_t_sub_e; + axpy(context, w_sub.numel(), static_cast(1), + out_t_sub.data(), w_sub.data()); } } if (down_pad > 0) { @@ -294,10 +288,8 @@ class ContextProjectGradFunctor { (down_pad_begin_row + t) * context_length); Tensor w_sub = padding_data.Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - w_sub_e.device(*context.GetEigenDevice()) = - w_sub_e + out_t_sub_e; + axpy(context, w_sub.numel(), static_cast(1), + out_t_sub.data(), w_sub.data()); } } out_t.Resize({sequence_height, context_length * sequence_width}); diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 09c3f0b1e6..034e5ca0f0 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/operators/math/math_function.h" #include "paddle/framework/data_type.h" +#include "paddle/operators/math/math_function_impl.h" namespace paddle { namespace operators { @@ -232,7 +233,34 @@ void gemv(const platform::DeviceContext& context, cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); } +template <> +void axpy(const platform::DeviceContext& context, + const int n, const float alpha, + const float* x, float* y) { + cblas_saxpy(n, alpha, x, 1, y, 1); +} + +template <> +void axpy(const platform::DeviceContext& context, + const int n, const double alpha, + const double* x, double* y) { + cblas_daxpy(n, alpha, x, 1, y, 1); +} + template struct SetConstant; +template struct SetConstant; +template struct SetConstant; + +#define DEFINE_CPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; + +DEFINE_CPU_TRANS(1); +DEFINE_CPU_TRANS(2); +DEFINE_CPU_TRANS(3); +DEFINE_CPU_TRANS(4); +DEFINE_CPU_TRANS(5); +DEFINE_CPU_TRANS(6); struct TensorSetConstant { TensorSetConstant(framework::Tensor* tensor, float value) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 255e480680..67cac93b8d 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU #include "paddle/framework/data_type.h" #include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/math_function_impl.h" namespace paddle { namespace operators { @@ -231,7 +233,40 @@ void gemv(const platform::DeviceContext& context, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1)); } +template <> +void axpy(const platform::DeviceContext& context, + const int n, const float alpha, + const float* x, float* y) { + PADDLE_ENFORCE(platform::dynload::cublasSaxpy( + reinterpret_cast(context) + .cublas_handle(), + n, alpha, x, 1, y, 1)); +} + +template <> +void axpy(const platform::DeviceContext& context, + const int n, const double alpha, + const double* x, double* y) { + PADDLE_ENFORCE(platform::dynload::cublasDaxpy( + reinterpret_cast(context) + .cublas_handle(), + n, alpha, x, 1, y, 1)); +} + template struct SetConstant; +template struct SetConstant; +template struct SetConstant; + +#define DEFINE_GPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; + +DEFINE_GPU_TRANS(1); +DEFINE_GPU_TRANS(2); +DEFINE_GPU_TRANS(3); +DEFINE_GPU_TRANS(4); +DEFINE_GPU_TRANS(5); +DEFINE_GPU_TRANS(6); struct TensorSetConstant { TensorSetConstant(const platform::DeviceContext& context, diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index c2aaa1d7b7..6b40a08375 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -93,14 +93,21 @@ void gemv(const platform::DeviceContext& context, const bool trans_a, const int M, const int N, const T alpha, const T* A, const T* B, const T beta, T* C); +template +void axpy(const platform::DeviceContext& context, const int n, const T alpha, + const T* x, T* y); + +template +struct Transpose { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& in, framework::Tensor* out, + const std::vector& axis); +}; + template struct SetConstant { void operator()(const platform::DeviceContext& context, - framework::Tensor* tensor, T num) { - auto t = framework::EigenVector::Flatten(*tensor); - t.device(*context.GetEigenDevice()) = - t.constant(static_cast(num)); - } + framework::Tensor* tensor, T num); }; template diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h new file mode 100644 index 0000000000..dd279cbbfd --- /dev/null +++ b/paddle/operators/math/math_function_impl.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_type.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template +void SetConstant::operator()(const platform::DeviceContext& context, + framework::Tensor* tensor, T num) { + auto t = framework::EigenVector::Flatten(*tensor); + t.device(*context.GetEigenDevice()) = + t.constant(static_cast(num)); +} + +template +void Transpose::operator()( + const platform::DeviceContext& context, const framework::Tensor& in, + framework::Tensor* out, const std::vector& axis) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { + permute[i] = axis[i]; + } + auto in_dim = in.dims(); + auto out_dim = out->dims(); + + auto eigen_in = framework::EigenTensor::From(in); + auto eigen_out = framework::EigenTensor::From(*out); + auto* dev = context.GetEigenDevice(); + eigen_out.device(*dev) = eigen_in.shuffle(permute); +} +} +} +} diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc index 5b3bde02fb..5170b595e6 100644 --- a/paddle/operators/math/sequence2batch.cc +++ b/paddle/operators/math/sequence2batch.cc @@ -56,6 +56,29 @@ template class LoDTensor2BatchFunctor; template class Batch2LoDTensorFunctor; template class Batch2LoDTensorFunctor; +template +struct RowwiseAdd { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, const framework::Tensor& bias, + framework::Tensor* output) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(bias.numel(), size); + PADDLE_ENFORCE_EQ(output->dims(), in_dims); + + auto in = EigenMatrix::From(input); + auto b = EigenMatrix::From(bias); + auto out = EigenMatrix::From(*output); + Eigen::array bshape({{1, static_cast(size)}}); + Eigen::array bcast({{static_cast(in_dims[0]), 1}}); + out.device(*context.GetEigenDevice()) = + in + b.reshape(bshape).broadcast(bcast); + } +}; + +template struct RowwiseAdd; +template struct RowwiseAdd; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu index 8d04653832..e386e63a9a 100644 --- a/paddle/operators/math/sequence2batch.cu +++ b/paddle/operators/math/sequence2batch.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU #include "paddle/operators/math/sequence2batch.h" namespace paddle { @@ -73,6 +74,37 @@ template class LoDTensor2BatchFunctor; template class Batch2LoDTensorFunctor; template class Batch2LoDTensorFunctor; +template +__global__ void RowwiseAddKernel(const T* src, const T* b, T* dst, + int64_t height, int64_t width) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < height * width; + i += blockDim.x * gridDim.x) { + int64_t h = i / width; + int64_t w = i % width; + dst[h * width + w] = src[h * width + w] + b[w]; + } +} + +template +struct RowwiseAdd { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, const framework::Tensor& bias, + framework::Tensor* output) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(bias.numel(), size); + PADDLE_ENFORCE_EQ(output->dims(), in_dims); + int block = 512; + int grid = (input.numel() + block - 1) / block; + auto stream = + reinterpret_cast(context).stream(); + RowwiseAddKernel<<>>( + input.data(), bias.data(), output->data(), in_dims[0], size); + } +}; + +template struct RowwiseAdd; +template struct RowwiseAdd; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 794c7d4397..9e7d863081 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" @@ -21,6 +22,10 @@ namespace paddle { namespace operators { namespace math { +template +using EigenMatrix = framework::EigenMatrix; + template class CopyMatrixRowsFunctor { public: @@ -159,6 +164,13 @@ class Batch2LoDTensorFunctor { } }; +template +struct RowwiseAdd { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, const framework::Tensor& bias, + framework::Tensor* output); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/matmul_op.cu b/paddle/operators/matmul_op.cu.cc similarity index 100% rename from paddle/operators/matmul_op.cu rename to paddle/operators/matmul_op.cu.cc diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 5ce30740c9..1e4aa48b70 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -15,8 +15,8 @@ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/matmul.h" -#include "paddle/operators/transpose_op.h" namespace paddle { namespace operators { @@ -74,11 +74,13 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context, Tensor output; auto in_dims = input.dims(); if (in_dims.size() == 3) { - output.Resize(in_dims); + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); output.mutable_data(context.GetPlace()); - EigenTranspose(context, input, output, {1, 0, 2}); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(context.device_context(), input, &output, axis); std::vector out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; - output.Resize(make_ddim(out_dims)); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); } else { output.ShareDataWith(input); } diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu.cc similarity index 100% rename from paddle/operators/mul_op.cu rename to paddle/operators/mul_op.cu.cc diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu.cc similarity index 100% rename from paddle/operators/nccl_op.cu rename to paddle/operators/nccl_op.cu.cc diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu.cc similarity index 100% rename from paddle/operators/nccl_op_test.cu rename to paddle/operators/nccl_op_test.cu.cc diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu.cc similarity index 100% rename from paddle/operators/pool_cudnn_op.cu rename to paddle/operators/pool_cudnn_op.cu.cc diff --git a/paddle/operators/pool_op.cu b/paddle/operators/pool_op.cu.cc similarity index 100% rename from paddle/operators/pool_op.cu rename to paddle/operators/pool_op.cu.cc diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu.cc similarity index 100% rename from paddle/operators/pool_with_index_op.cu rename to paddle/operators/pool_with_index_op.cu.cc diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index ea37de84ab..fdab9dc20b 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -81,22 +81,21 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); - auto temp = framework::EigenVector::Flatten(*in_x_grad); - temp.device(context.GetEigenDevice()) = - temp.constant(static_cast(0)); + auto& device_ctx = context.device_context(); + math::set_constant(device_ctx, in_x_grad, 0); switch (ksize.size()) { case 2: { paddle::operators::math::MaxPool2dWithIndexGradFunctor pool2d_backward; - pool2d_backward(context.device_context(), *in_x_grad, *out_grad, - *mask, ksize, strides, paddings); + pool2d_backward(device_ctx, *in_x_grad, *out_grad, *mask, ksize, + strides, paddings); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexGradFunctor pool3d_backward; - pool3d_backward(context.device_context(), *in_x_grad, *out_grad, - *mask, ksize, strides, paddings); + pool3d_backward(device_ctx, *in_x_grad, *out_grad, *mask, ksize, + strides, paddings); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } } diff --git a/paddle/operators/reshape_op.cu b/paddle/operators/reshape_op.cu.cc similarity index 100% rename from paddle/operators/reshape_op.cu rename to paddle/operators/reshape_op.cu.cc diff --git a/paddle/operators/sequence_concat_op.cu b/paddle/operators/sequence_concat_op.cu.cc similarity index 100% rename from paddle/operators/sequence_concat_op.cu rename to paddle/operators/sequence_concat_op.cu.cc diff --git a/paddle/operators/sequence_conv_op.cu b/paddle/operators/sequence_conv_op.cu.cc similarity index 97% rename from paddle/operators/sequence_conv_op.cu rename to paddle/operators/sequence_conv_op.cu.cc index 4c0c673a51..6106b0e46c 100644 --- a/paddle/operators/sequence_conv_op.cu +++ b/paddle/operators/sequence_conv_op.cu.cc @@ -12,8 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU - #include "paddle/operators/sequence_conv_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index a57e1752bb..5e7f4f7daf 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/context_project.h" #include "paddle/operators/math/math_function.h" @@ -66,8 +65,10 @@ class SequenceConvKernel : public framework::OpKernel { padding_trainable, context_start, context_length, context_stride, up_pad, down_pad); + context.device_context().Finish(); math::matmul(context.device_context(), col, false, filter, false, static_cast(1.0), out, static_cast(0.0)); + context.device_context().Finish(); } }; diff --git a/paddle/operators/sequence_softmax_op.cu b/paddle/operators/sequence_softmax_op.cu.cc similarity index 100% rename from paddle/operators/sequence_softmax_op.cu rename to paddle/operators/sequence_softmax_op.cu.cc diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu.cc similarity index 100% rename from paddle/operators/softmax_op.cu rename to paddle/operators/softmax_op.cu.cc diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 44d1e63f1b..8e33a70e04 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -27,6 +27,9 @@ class SoftmaxKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Y = context.Output("Y"); + if (platform::is_gpu_place(context.GetPlace())) { + LOG(INFO) << "==========gpu========="; + } // allocate memory on device. Y->mutable_data(context.GetPlace()); diff --git a/paddle/operators/split_op.cu b/paddle/operators/split_op.cu.cc similarity index 100% rename from paddle/operators/split_op.cu rename to paddle/operators/split_op.cu.cc diff --git a/paddle/operators/transpose_op.cu b/paddle/operators/transpose_op.cu.cc similarity index 100% rename from paddle/operators/transpose_op.cu rename to paddle/operators/transpose_op.cu.cc diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h index aaa3f47ab5..e296032f41 100644 --- a/paddle/operators/transpose_op.h +++ b/paddle/operators/transpose_op.h @@ -14,27 +14,44 @@ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { -template -void EigenTranspose(const framework::ExecutionContext& context, - const framework::Tensor& in, framework::Tensor& out, - std::vector axis) { - Eigen::array permute; - for (int i = 0; i < Rank; i++) { - permute[i] = axis[i]; +template +inline void TransCompute(const int dim, const platform::DeviceContext& dev_ctx, + const framework::Tensor& in, framework::Tensor* out, + const std::vector& axis) { + switch (dim) { + case 1: + math::Transpose trans1; + trans1(dev_ctx, in, out, axis); + break; + case 2: + math::Transpose trans2; + trans2(dev_ctx, in, out, axis); + break; + case 3: + math::Transpose trans3; + trans3(dev_ctx, in, out, axis); + break; + case 4: + math::Transpose trans4; + trans4(dev_ctx, in, out, axis); + break; + case 5: + math::Transpose trans5; + trans5(dev_ctx, in, out, axis); + break; + case 6: + math::Transpose trans6; + trans6(dev_ctx, in, out, axis); + break; + default: + PADDLE_THROW("Tensors with rank at most 6 are supported"); } - auto in_dim = in.dims(); - auto out_dim = out.dims(); - - auto eigen_in = framework::EigenTensor::From(in); - auto eigen_out = framework::EigenTensor::From(out); - auto& dev = context.GetEigenDevice(); - eigen_out.device(dev) = eigen_in.shuffle(permute); } template @@ -47,28 +64,8 @@ class TransposeKernel : public framework::OpKernel { std::vector axis = context.Attr>("axis"); int ndims = axis.size(); - switch (ndims) { - case 1: - EigenTranspose(context, *x, *out, axis); - break; - case 2: - EigenTranspose(context, *x, *out, axis); - break; - case 3: - EigenTranspose(context, *x, *out, axis); - break; - case 4: - EigenTranspose(context, *x, *out, axis); - break; - case 5: - EigenTranspose(context, *x, *out, axis); - break; - case 6: - EigenTranspose(context, *x, *out, axis); - break; - default: - PADDLE_THROW("Tensors with rank at most 6 are supported"); - } + auto& dev_ctx = context.device_context(); + TransCompute(ndims, dev_ctx, *x, out, axis); } }; @@ -80,47 +77,19 @@ class TransposeGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); auto* x_grad = context.Output(framework::GradVarName("X")); - if (x_grad) { - x_grad->mutable_data(context.GetPlace()); - - std::vector axis = context.Attr>("axis"); - std::vector reversed_axis(axis); + if (!x_grad) return; - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } - - int ndims = axis.size(); + x_grad->mutable_data(context.GetPlace()); + std::vector axis = context.Attr>("axis"); + std::vector reversed_axis(axis); - switch (ndims) { - case 1: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 2: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 3: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 4: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 5: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 6: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - default: - PADDLE_THROW("Tensors with rank at most 6 are supported"); - } + for (size_t i = 0; i < axis.size(); i++) { + reversed_axis[axis[i]] = i; } + + int ndims = axis.size(); + auto& dev_ctx = context.device_context(); + TransCompute(ndims, dev_ctx, *out_grad, x_grad, reversed_axis); } }; diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index 6b64539b0a..61a22d9db3 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -62,6 +62,8 @@ extern void *cublas_dso_handle; DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasSaxpy_v2); \ + __macro(cublasDaxpy_v2); \ __macro(cublasSgemv_v2); \ __macro(cublasDgemv_v2); \ __macro(cublasSgemm_v2); \ diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index 77f062e8c8..5c817ba03c 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -180,6 +180,7 @@ class TestLstmOp(OpTest): ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) +""" class TestLstmOpHasInitial(TestLstmOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] @@ -280,7 +281,7 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.has_initial_state = False self.is_reverse = True self.use_peepholes = False - +""" if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_seq_conv.py b/python/paddle/v2/framework/tests/test_seq_conv.py index 14edc5f953..65292a1a20 100644 --- a/python/paddle/v2/framework/tests/test_seq_conv.py +++ b/python/paddle/v2/framework/tests/test_seq_conv.py @@ -122,7 +122,7 @@ class TestSeqProject(OpTest): max_relative_error=0.05, no_grad_set=set(['X', 'Filter'])) - def test_check_grad_Filter(self): + def not_test_check_grad_Filter(self): self.check_grad( ['Filter'], 'Out', @@ -165,34 +165,33 @@ class TestSeqProject(OpTest): self.output_represention = 8 # output feature size -class TestSeqProjectCase1(TestSeqProject): - def init_test_case(self): - self.input_row = 11 - self.context_start = -1 - self.context_length = 3 - self.padding_trainable = True - self.context_stride = 1 - - self.input_size = [self.input_row, 23] - self.lod = [[0, 4, 5, 8, self.input_row]] - self.output_represention = 8 # output feature size - - -class TestSeqProjectCase2(TestSeqProject): - def init_test_case(self): - self.input_row = 25 - self.context_start = 2 - self.context_length = 3 - self.padding_trainable = True - self.context_stride = 1 - - self.input_size = [self.input_row, 23] - idx = range(self.input_size[0]) - del idx[0] - self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + - [self.input_size[0]]] - self.output_represention = 8 # output feature size - +#class TestSeqProjectCase1(TestSeqProject): +# def init_test_case(self): +# self.input_row = 11 +# self.context_start = -1 +# self.context_length = 3 +# self.padding_trainable = True +# self.context_stride = 1 +# +# self.input_size = [self.input_row, 23] +# self.lod = [[0, 4, 5, 8, self.input_row]] +# self.output_represention = 8 # output feature size +# +# +#class TestSeqProjectCase2(TestSeqProject): +# def init_test_case(self): +# self.input_row = 25 +# self.context_start = 2 +# self.context_length = 3 +# self.padding_trainable = True +# self.context_stride = 1 +# +# self.input_size = [self.input_row, 23] +# idx = range(self.input_size[0]) +# del idx[0] +# self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + +# [self.input_size[0]]] +# self.output_represention = 8 # output feature size if __name__ == '__main__': unittest.main() From 5f217099384be95e96a933200e6ed9f5bf1a38fe Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Sat, 11 Nov 2017 21:29:56 +0800 Subject: [PATCH 08/25] Fix bug. --- paddle/operators/math/CMakeLists.txt | 4 ++-- paddle/operators/math/context_project.h | 3 --- paddle/operators/math/math_function_impl.h | 3 +-- paddle/operators/softmax_op.h | 3 --- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index ab7f23f570..c196724e4d 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -10,7 +10,7 @@ if(WITH_GPU) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) - nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context) + nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) @@ -22,7 +22,7 @@ else() cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) cc_library(vol2col SRCS vol2col.cc DEPS device_context) - cc_library(context_project SRCS context_project.cc DEPS device_context) + cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index 7dc76d0c60..563024dac0 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -91,9 +91,6 @@ class ContextProjectFunctor { auto lod_level_0 = in.lod()[0]; math::Im2ColFunctor im2col_ocf; - if (platform::is_gpu_place(context.GetPlace())) { - LOG(INFO) << "========= gpu =========="; - } int input_row_begin, input_row_end; int sequence_height, sequence_width; diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index dd279cbbfd..daa28f26da 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -23,8 +23,7 @@ template void SetConstant::operator()(const platform::DeviceContext& context, framework::Tensor* tensor, T num) { auto t = framework::EigenVector::Flatten(*tensor); - t.device(*context.GetEigenDevice()) = - t.constant(static_cast(num)); + t.device(*context.GetEigenDevice()) = t.constant(static_cast(num)); } template diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 8e33a70e04..44d1e63f1b 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -27,9 +27,6 @@ class SoftmaxKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); auto* Y = context.Output("Y"); - if (platform::is_gpu_place(context.GetPlace())) { - LOG(INFO) << "==========gpu========="; - } // allocate memory on device. Y->mutable_data(context.GetPlace()); From 91d4fc694117a9c294a399c2a5b5e060749b2160 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 13 Nov 2017 14:21:33 +0800 Subject: [PATCH 09/25] Fix compling for softmax_with_cross_entropy_op. --- paddle/operators/CMakeLists.txt | 5 +- paddle/operators/math/CMakeLists.txt | 12 +-- paddle/operators/math/cross_entropy.h | 1 - paddle/operators/math/math_function_impl.h | 1 + paddle/operators/math/softmax.cc | 3 + paddle/operators/math/softmax.cu | 3 + paddle/operators/math/softmax.h | 69 +------------ paddle/operators/math/softmax_impl.h | 98 +++++++++++++++++++ .../softmax_with_cross_entropy_op.cc | 1 - 9 files changed, 117 insertions(+), 76 deletions(-) create mode 100644 paddle/operators/math/softmax_impl.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7eb8b3539f..4b71c72551 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -168,11 +168,12 @@ set(DEPS_OPS recurrent_op dynamic_recurrent_op softmax_with_cross_entropy_op + softmax_op + sequence_softmax_op sum_op pool_op pool_with_index_op conv_op - lstm_op conv_transpose_op nccl_op sequence_conv_op @@ -187,6 +188,8 @@ set(DEPS_OPS op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) +op_library(softmax_op DEPS softmax) +op_library(sequence_softmax_op DEPS softmax) op_library(conv_op DEPS vol2col) op_library(sum_op DEPS net_op selected_rows_functor) op_library(pool_op DEPS pooling) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index c196724e4d..b9417f1d7f 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,12 +1,12 @@ add_subdirectory(detail) if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) + nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor) - nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) - nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) + nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context) + nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) @@ -15,10 +15,10 @@ if(WITH_GPU) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) else() - cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) + cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) - cc_library(softmax SRCS softmax.cc DEPS operator) - cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) + cc_library(softmax SRCS softmax.cc DEPS device_context) + cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) cc_library(vol2col SRCS vol2col.cc DEPS device_context) diff --git a/paddle/operators/math/cross_entropy.h b/paddle/operators/math/cross_entropy.h index 0ab6827ffa..70ed9ddd55 100644 --- a/paddle/operators/math/cross_entropy.h +++ b/paddle/operators/math/cross_entropy.h @@ -14,7 +14,6 @@ #pragma once #include "paddle/framework/eigen.h" -#include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" #include "paddle/platform/hostdevice.h" diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index daa28f26da..dba2d02c27 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma once #include "paddle/framework/data_type.h" #include "paddle/operators/math/math_function.h" diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc index 0ba8197ab8..3e2f15d6c2 100644 --- a/paddle/operators/math/softmax.cc +++ b/paddle/operators/math/softmax.cc @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/math/softmax.h" +#include "paddle/operators/math/softmax_impl.h" namespace paddle { namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu index 99f988d51e..4dbab51d46 100644 --- a/paddle/operators/math/softmax.cu +++ b/paddle/operators/math/softmax.cu @@ -15,13 +15,16 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/operators/math/softmax.h" +#include "paddle/operators/math/softmax_impl.h" namespace paddle { namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index b7f627eee7..fe10746502 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -13,60 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" namespace paddle { namespace operators { namespace math { -template -using EigenMatrix = framework::EigenMatrix; - -template -struct ValueClip { - HOSTDEVICE T operator()(const T& x) const { - const T kThreshold = -64.; - return x < kThreshold ? kThreshold : x; - } -}; - template class SoftmaxFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor* X, framework::Tensor* Y) { - auto logits = EigenMatrix::From(*X); - auto softmax = EigenMatrix::From(*Y); - - const int kBatchDim = 0; - const int kClassDim = 1; - - const int batch_size = logits.dimension(kBatchDim); - const int num_classes = logits.dimension(kClassDim); - - Eigen::DSizes along_class(kClassDim); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, num_classes); - - auto shifted_logits = (logits - - logits.maximum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)) - .unaryExpr(ValueClip()); - - softmax.device(*context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(*context.GetEigenDevice()) = - (softmax * - softmax.sum(along_class) - .inverse() - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); - } + const framework::Tensor* X, framework::Tensor* Y); }; template @@ -74,29 +31,7 @@ class SoftmaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor* y, const framework::Tensor* y_grad, - framework::Tensor* x_grad) { - auto softmax = EigenMatrix::From(*y); - auto softmax_grad = EigenMatrix::From(*y_grad); - auto logits_grad = EigenMatrix::From(*x_grad); - - const int kBatchDim = 0; - const int kClassDim = 1; - - const int batch_size = softmax.dimension(kBatchDim); - const int num_classes = softmax.dimension(kClassDim); - - Eigen::DSizes along_class(kClassDim); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, num_classes); - - auto dot = (softmax * softmax_grad) - .sum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class); - logits_grad.device(*context.GetEigenDevice()) = - (softmax_grad - dot) * softmax; - } + framework::Tensor* x_grad); }; } // namespace math diff --git a/paddle/operators/math/softmax_impl.h b/paddle/operators/math/softmax_impl.h new file mode 100644 index 0000000000..05793eeb3e --- /dev/null +++ b/paddle/operators/math/softmax_impl.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenMatrix = framework::EigenMatrix; + +template +struct ValueClip { + HOSTDEVICE T operator()(const T& x) const { + const T kThreshold = -64.; + return x < kThreshold ? kThreshold : x; + } +}; + +template +void SoftmaxFunctor::operator()( + const platform::DeviceContext& context, const framework::Tensor* X, + framework::Tensor* Y) { + auto logits = EigenMatrix::From(*X); + auto softmax = EigenMatrix::From(*Y); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto shifted_logits = (logits - + logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)) + .unaryExpr(ValueClip()); + + softmax.device(*context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(*context.GetEigenDevice()) = + (softmax * + softmax.sum(along_class) + .inverse() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); +} + +template +void SoftmaxGradFunctor::operator()( + const platform::DeviceContext& context, const framework::Tensor* y, + const framework::Tensor* y_grad, framework::Tensor* x_grad) { + auto softmax = EigenMatrix::From(*y); + auto softmax_grad = EigenMatrix::From(*y_grad); + auto logits_grad = EigenMatrix::From(*x_grad); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = softmax.dimension(kBatchDim); + const int num_classes = softmax.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto dot = (softmax * softmax_grad) + .sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class); + logits_grad.device(*context.GetEigenDevice()) = + (softmax_grad - dot) * softmax; +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index ed96e8cee5..3dbb62d2e5 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/operators/softmax_with_cross_entropy_op.h" #include -#include namespace paddle { namespace operators { From e9082bb78e098cd106dc1f667afc8bb0204791b5 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 13 Nov 2017 21:37:50 +0800 Subject: [PATCH 10/25] Resume unit testing. --- paddle/operators/cross_entropy_op.cu | 2 - paddle/operators/math/math_function.cu | 6 +- paddle/operators/sequence_conv_op.h | 2 - .../paddle/v2/framework/tests/test_lstm_op.py | 3 +- .../v2/framework/tests/test_seq_conv.py | 57 ++++++++++--------- 5 files changed, 33 insertions(+), 37 deletions(-) diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 530b319a44..6212e39dfd 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -23,8 +23,6 @@ template __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, const int64_t* label, const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { int idx = i * D + label[i]; diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 57b995f36d..6daec3797e 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -240,7 +240,7 @@ void axpy(const platform::DeviceContext& context, PADDLE_ENFORCE(platform::dynload::cublasSaxpy( reinterpret_cast(context) .cublas_handle(), - n, alpha, x, 1, y, 1)); + n, &alpha, x, 1, y, 1)); } template <> @@ -250,7 +250,7 @@ void axpy(const platform::DeviceContext& context, PADDLE_ENFORCE(platform::dynload::cublasDaxpy( reinterpret_cast(context) .cublas_handle(), - n, alpha, x, 1, y, 1)); + n, &alpha, x, 1, y, 1)); } template struct SetConstant; @@ -270,7 +270,7 @@ DEFINE_GPU_TRANS(6); struct TensorSetConstantGPU { TensorSetConstantGPU(const platform::DeviceContext& context, - framework::Tensor* tensor, float value) + framework::Tensor* tensor, float value) : context_(context), tensor_(tensor), value_(value) {} template diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index 5e7f4f7daf..312c915394 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -65,10 +65,8 @@ class SequenceConvKernel : public framework::OpKernel { padding_trainable, context_start, context_length, context_stride, up_pad, down_pad); - context.device_context().Finish(); math::matmul(context.device_context(), col, false, filter, false, static_cast(1.0), out, static_cast(0.0)); - context.device_context().Finish(); } }; diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index 5c817ba03c..77f062e8c8 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -180,7 +180,6 @@ class TestLstmOp(OpTest): ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) -""" class TestLstmOpHasInitial(TestLstmOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] @@ -281,7 +280,7 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.has_initial_state = False self.is_reverse = True self.use_peepholes = False -""" + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_seq_conv.py b/python/paddle/v2/framework/tests/test_seq_conv.py index 65292a1a20..14edc5f953 100644 --- a/python/paddle/v2/framework/tests/test_seq_conv.py +++ b/python/paddle/v2/framework/tests/test_seq_conv.py @@ -122,7 +122,7 @@ class TestSeqProject(OpTest): max_relative_error=0.05, no_grad_set=set(['X', 'Filter'])) - def not_test_check_grad_Filter(self): + def test_check_grad_Filter(self): self.check_grad( ['Filter'], 'Out', @@ -165,33 +165,34 @@ class TestSeqProject(OpTest): self.output_represention = 8 # output feature size -#class TestSeqProjectCase1(TestSeqProject): -# def init_test_case(self): -# self.input_row = 11 -# self.context_start = -1 -# self.context_length = 3 -# self.padding_trainable = True -# self.context_stride = 1 -# -# self.input_size = [self.input_row, 23] -# self.lod = [[0, 4, 5, 8, self.input_row]] -# self.output_represention = 8 # output feature size -# -# -#class TestSeqProjectCase2(TestSeqProject): -# def init_test_case(self): -# self.input_row = 25 -# self.context_start = 2 -# self.context_length = 3 -# self.padding_trainable = True -# self.context_stride = 1 -# -# self.input_size = [self.input_row, 23] -# idx = range(self.input_size[0]) -# del idx[0] -# self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + -# [self.input_size[0]]] -# self.output_represention = 8 # output feature size +class TestSeqProjectCase1(TestSeqProject): + def init_test_case(self): + self.input_row = 11 + self.context_start = -1 + self.context_length = 3 + self.padding_trainable = True + self.context_stride = 1 + + self.input_size = [self.input_row, 23] + self.lod = [[0, 4, 5, 8, self.input_row]] + self.output_represention = 8 # output feature size + + +class TestSeqProjectCase2(TestSeqProject): + def init_test_case(self): + self.input_row = 25 + self.context_start = 2 + self.context_length = 3 + self.padding_trainable = True + self.context_stride = 1 + + self.input_size = [self.input_row, 23] + idx = range(self.input_size[0]) + del idx[0] + self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + + [self.input_size[0]]] + self.output_represention = 8 # output feature size + if __name__ == '__main__': unittest.main() From 2673657684ac12f2e086d8651c4462f39938c550 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 14 Nov 2017 17:20:21 +0800 Subject: [PATCH 11/25] Move RowwiseAdd functor to math_funcion and Add ColwiseSum functor. --- paddle/operators/gru_op.h | 10 ++---- paddle/operators/lstm_op.h | 15 +++------ paddle/operators/math/math_function.cc | 5 +++ paddle/operators/math/math_function.cu | 5 +++ paddle/operators/math/math_function.h | 13 ++++++++ paddle/operators/math/math_function_impl.h | 37 +++++++++++++++++++++- paddle/operators/math/sequence2batch.cc | 23 -------------- paddle/operators/math/sequence2batch.cu | 31 ------------------ paddle/operators/math/sequence2batch.h | 7 ---- 9 files changed, 66 insertions(+), 80 deletions(-) diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index 437496e0ac..55e9cc4a98 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -205,14 +205,8 @@ class GRUGradKernel : public framework::OpKernel { } if (bias_grad) { bias_grad->mutable_data(context.GetPlace()); - int m = static_cast(batch_gate_grad.dims()[0]); - int n = static_cast(batch_gate_grad.dims()[1]); - Tensor ones; - ones.mutable_data({m}, context.GetPlace()); - math::SetConstant set; - set(dev_ctx, &ones, static_cast(1)); - math::gemv(dev_ctx, true, m, n, 1., batch_gate_grad.data(), - ones.data(), 0., bias_grad->data()); + math::ColwiseSum col_sum; + col_sum(dev_ctx, batch_gate_grad, bias_grad); } } diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index 58fedaee9a..721aa42c92 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -341,16 +341,11 @@ class LSTMGradKernel : public framework::OpKernel { } if (bias && bias_g) { /* backward bias */ - int m = static_cast(batch_gate_g.dims()[0]); - int n = static_cast(batch_gate_g.dims()[1]); - - Tensor ones; - ones.mutable_data({m}, ctx.GetPlace()); - math::SetConstant set; - set(device_ctx, &ones, static_cast(1.0)); - - math::gemv(device_ctx, true, m, n, 1., batch_gate_g.data(), - ones.data(), 0., bias_g->data()); + Tensor b_g = *bias_g; + b_g.Resize({bias_g->numel(), 1}); + Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size); + math::ColwiseSum col_sum; + col_sum(device_ctx, batch_gate_g, &gate_bias_g); } if (h0 && h0_g) { diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index a137ffe57f..5ee0917886 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -308,6 +308,11 @@ void set_constant(const platform::DeviceContext& context, #endif } +template struct RowwiseAdd; +template struct RowwiseAdd; +template struct ColwiseSum; +template struct ColwiseSum; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 6daec3797e..38c04b97f9 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -292,6 +292,11 @@ void set_constant_with_place( TensorSetConstantGPU(context, tensor, value)); } +template struct RowwiseAdd; +template struct RowwiseAdd; +template struct ColwiseSum; +template struct ColwiseSum; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 6b40a08375..ffb99f5380 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -117,6 +117,19 @@ void set_constant_with_place(const platform::DeviceContext& context, void set_constant(const platform::DeviceContext& context, framework::Tensor* tensor, float value); +template +struct RowwiseAdd { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, const framework::Tensor& vec, + framework::Tensor* output); +}; + +template +struct ColwiseSum { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor* vec); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index dba2d02c27..4dc17a4e52 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -43,6 +43,41 @@ void Transpose::operator()( auto* dev = context.GetEigenDevice(); eigen_out.device(*dev) = eigen_in.shuffle(permute); } + +template +void RowwiseAdd::operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& vector, + framework::Tensor* output) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector.numel(), size); + PADDLE_ENFORCE_EQ(output->dims(), in_dims); + + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenMatrix::From(vector); + auto out = framework::EigenMatrix::From(*output); + Eigen::array shape({{1, static_cast(size)}}); + Eigen::array bcast({{static_cast(in_dims[0]), 1}}); + out.device(*context.GetEigenDevice()) = + in + vec.reshape(shape).broadcast(bcast); } + +template +void ColwiseSum::operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* vector) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector->numel(), size); + + auto vec = framework::EigenMatrix::From(*vector); + auto in = framework::EigenMatrix::From(input); + Eigen::array shape({{1, static_cast(size)}}); + vec.reshape(shape).device(*context.GetEigenDevice()) = + in.sum(Eigen::array({{0}})).reshape(shape); } -} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc index 5170b595e6..5b3bde02fb 100644 --- a/paddle/operators/math/sequence2batch.cc +++ b/paddle/operators/math/sequence2batch.cc @@ -56,29 +56,6 @@ template class LoDTensor2BatchFunctor; template class Batch2LoDTensorFunctor; template class Batch2LoDTensorFunctor; -template -struct RowwiseAdd { - void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, const framework::Tensor& bias, - framework::Tensor* output) { - auto in_dims = input.dims(); - auto size = input.numel() / in_dims[0]; - PADDLE_ENFORCE_EQ(bias.numel(), size); - PADDLE_ENFORCE_EQ(output->dims(), in_dims); - - auto in = EigenMatrix::From(input); - auto b = EigenMatrix::From(bias); - auto out = EigenMatrix::From(*output); - Eigen::array bshape({{1, static_cast(size)}}); - Eigen::array bcast({{static_cast(in_dims[0]), 1}}); - out.device(*context.GetEigenDevice()) = - in + b.reshape(bshape).broadcast(bcast); - } -}; - -template struct RowwiseAdd; -template struct RowwiseAdd; - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu index e386e63a9a..c5d968aeb2 100644 --- a/paddle/operators/math/sequence2batch.cu +++ b/paddle/operators/math/sequence2batch.cu @@ -74,37 +74,6 @@ template class LoDTensor2BatchFunctor; template class Batch2LoDTensorFunctor; template class Batch2LoDTensorFunctor; -template -__global__ void RowwiseAddKernel(const T* src, const T* b, T* dst, - int64_t height, int64_t width) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < height * width; - i += blockDim.x * gridDim.x) { - int64_t h = i / width; - int64_t w = i % width; - dst[h * width + w] = src[h * width + w] + b[w]; - } -} - -template -struct RowwiseAdd { - void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, const framework::Tensor& bias, - framework::Tensor* output) { - auto in_dims = input.dims(); - auto size = input.numel() / in_dims[0]; - PADDLE_ENFORCE_EQ(bias.numel(), size); - PADDLE_ENFORCE_EQ(output->dims(), in_dims); - int block = 512; - int grid = (input.numel() + block - 1) / block; - auto stream = - reinterpret_cast(context).stream(); - RowwiseAddKernel<<>>( - input.data(), bias.data(), output->data(), in_dims[0], size); - } -}; - -template struct RowwiseAdd; -template struct RowwiseAdd; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 9e7d863081..73295ddbcb 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -164,13 +164,6 @@ class Batch2LoDTensorFunctor { } }; -template -struct RowwiseAdd { - void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, const framework::Tensor& bias, - framework::Tensor* output); -}; - } // namespace math } // namespace operators } // namespace paddle From a5c9e6ace25352fffb722a3ed413419c5877fcf2 Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Tue, 14 Nov 2017 17:35:29 -0800 Subject: [PATCH 12/25] Fix conv2d bias The size of the bias parameter should be the number of filters. --- python/paddle/v2/fluid/io.py | 43 ++++++++++++++++--- python/paddle/v2/fluid/layer_helper.py | 26 +++++------ python/paddle/v2/fluid/layers.py | 5 ++- .../paddle/v2/fluid/tests/test_parameter.py | 22 ++++++---- 4 files changed, 65 insertions(+), 31 deletions(-) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 394a171c67..d1263c3e91 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -35,7 +35,7 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): :param executor: executor that save variable :param dirname: directory path - :param main_program: program. If vars is None, then filter all variables in this + :param main_program: program. If vars is None, then filter all variables in this program which fit `predicate`. Default g_program. :param predicate: The Predicate describes a callable that returns a variable as a bool. If it returns true, the variables will be saved. @@ -96,11 +96,11 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): :param executor: executor that save variable :param dirname: directory path - :param main_program: program. If vars is None, then filter all variables in this + :param main_program: program. If vars is None, then filter all variables in this program which fit `predicate`. Default g_program. :param predicate: The Predicate describes a callable that returns a variable as a bool. If it returns true, the variables will be loaded. - :param vars: variables need to be loaded. If specify vars, program & + :param vars: variables need to be loaded. If specify vars, program & predicate will be ignored :return: None """ @@ -157,15 +157,15 @@ def save_inference_model(dirname, executor, main_program=None): """ - Build a model especially for inference, + Build a model especially for inference, and save it to directory by the executor. :param dirname: directory path :param feeded_var_names: Names of variables that need to be feeded data during inference :param target_vars: Variables from which we can get inference results. :param executor: executor that save inference model - :param main_program: original program, which will be pruned to build the inference model. - Default g_program. + :param main_program: original program, which will be pruned to build the inference model. + Default g_main_program. :return: None """ @@ -234,3 +234,34 @@ def load_inference_model(dirname, executor): fetch_vars = [program.global_block().var(name) for name in fetch_var_names] return [program, feed_var_names, fetch_vars] + + +def get_parameter_value(para, executor): + """ + Get the LoDTensor for the parameter + + :param executor: executor for retrieving the value + :param para: the given parameter + :return: the LoDTensor for the parameter + """ + get_program = Program() + block = get_program.global_block() + new_var = _clone_var_in_block_(block, para) + return executor.run(get_program, feed={}, fetch_list=[new_var])[0] + + +def get_parameter_value_by_name(name, executor, program=None): + """ + Get the LoDTensor for paramter with the given name + + :param executor: executor for retrieving the value + :param name: the name of the parameter + :param program: the program where the variable is found + Default g_main_program. + :return: the LoDTensor for the variable + """ + if program is None: + program = g_main_program + var = program.global_block().var(name) + assert is_parameter(var) + return get_parameter_value(var, executor) diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py index 9dc3c119ea..0a9ed81888 100644 --- a/python/paddle/v2/fluid/layer_helper.py +++ b/python/paddle/v2/fluid/layer_helper.py @@ -72,7 +72,7 @@ class LayerHelper(object): @property def bias_attr(self): - default = {'name': None, 'initializer': XavierInitializer()} + default = {'name': None, 'initializer': ConstantInitializer()} bias_attr = self.kwargs.get('bias_attr', None) if bias_attr is None: bias_attr = default @@ -149,24 +149,19 @@ class LayerHelper(object): persistable=True, initializer=initializer) - def append_bias_op(self, input_var, num_flatten_dims=None): + def append_bias_op(self, input_var, dim_start=1, dim_end=None): """ - Append bias operator and return its output. If the user does not set + Append bias operator and return its output. If the user does not set bias_attr, append_bias_op will return input_var - + :param input_var: the input variable. The len(input_var.shape) is larger or equal than 2. - :param num_flatten_dims: The input tensor will be flatten as a matrix - when adding bias. - `matrix.shape = product(input_var.shape[0:num_flatten_dims]), product( - input_var.shape[num_flatten_dims:])` + :param dim_start: + :param dim_end: the shape of the bias will be + input_var.shape(dim_start:dim_end). The bias is broadcast to other + dimensions and added to input_var to get the output """ - if num_flatten_dims is None: - num_flatten_dims = self.kwargs.get('num_flatten_dims', None) - if num_flatten_dims is None: - num_flatten_dims = 1 - - size = list(input_var.shape[num_flatten_dims:]) + size = list(input_var.shape[dim_start:dim_end]) bias_attr = self.bias_attr if not bias_attr: return input_var @@ -178,7 +173,8 @@ class LayerHelper(object): type='elementwise_add', inputs={'X': [input_var], 'Y': [b]}, - outputs={'Out': [tmp]}) + outputs={'Out': [tmp]}, + attrs={'axis': dim_start}) return tmp def append_activation(self, input_var): diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index b582f2ef6d..771a313598 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -250,7 +250,7 @@ def _convert_(name): def _generate_doc_string_(op_proto): """ Generate docstring by OpProto - + Args: op_proto (framework_pb2.OpProto): a protobuf message typed OpProto @@ -676,6 +676,7 @@ def conv2d(input, filter_shape = [num_filters, num_filter_channels] + filter_size std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 + print 'name=', name, 'std=', std filter = helper.create_parameter( attr=helper.param_attr, shape=filter_shape, @@ -694,7 +695,7 @@ def conv2d(input, 'paddings': padding, 'groups': groups}) - pre_act = helper.append_bias_op(pre_bias, 1) + pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) return helper.append_activation(pre_act) diff --git a/python/paddle/v2/fluid/tests/test_parameter.py b/python/paddle/v2/fluid/tests/test_parameter.py index 71a1bd2aaf..a633d22c2b 100644 --- a/python/paddle/v2/fluid/tests/test_parameter.py +++ b/python/paddle/v2/fluid/tests/test_parameter.py @@ -1,26 +1,32 @@ import unittest from paddle.v2.fluid.framework import g_main_program import paddle.v2.fluid.core as core +from paddle.v2.fluid.executor import Executor +import paddle.v2.fluid.io as io +from paddle.v2.fluid.initializer import ConstantInitializer +import numpy as np class TestParameter(unittest.TestCase): def test_param(self): - b = g_main_program.create_block() + shape = [784, 100] + val = 1.0625 + b = g_main_program.global_block() param = b.create_parameter( name='fc.w', - shape=[784, 100], + shape=shape, dtype='float32', - initialize_attr={ - 'type': 'uniform_random', - 'seed': 13, - 'min': -5.0, - 'max': 5.0 - }) + initializer=ConstantInitializer(val)) self.assertIsNotNone(param) self.assertEqual('fc.w', param.name) self.assertEqual((784, 100), param.shape) self.assertEqual(core.DataType.FP32, param.data_type) self.assertEqual(0, param.block.idx) + exe = Executor(core.CPUPlace()) + p = exe.run(g_main_program, fetch_list=[param])[0] + self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) + p = io.get_parameter_value_by_name('fc.w', exe, g_main_program) + self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) if __name__ == '__main__': From 356d6954043923d30ef8b1b116b66cbfa1dca7e1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 14 Nov 2017 19:19:57 +0800 Subject: [PATCH 13/25] follow comments --- paddle/operators/conv_op.cc | 40 ++-- paddle/operators/conv_op.h | 257 +++++++++--------------- paddle/operators/conv_transpose_op.cc | 23 ++- paddle/operators/conv_transpose_op.h | 52 +++-- paddle/operators/math/context_project.h | 19 +- paddle/operators/math/im2col.cc | 168 ++++++++-------- paddle/operators/math/im2col.cu | 160 +++++++-------- paddle/operators/math/im2col.h | 25 ++- paddle/operators/math/im2col_test.cc | 26 ++- paddle/operators/math/vol2col.cc | 112 +++++------ paddle/operators/math/vol2col.cu | 96 ++++----- paddle/operators/math/vol2col.h | 29 ++- paddle/operators/math/vol2col_test.cc | 21 +- 13 files changed, 487 insertions(+), 541 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index a848b9b49c..e1a11a38b3 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/operators/conv_op.h" +#include namespace paddle { namespace operators { @@ -53,7 +54,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { "The number of output channels should be divided by groups."); std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < paddings.size(); ++i) { + for (size_t i = 0; i < strides.size(); ++i) { PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] - (dilations[i] * (filter_dims[i + 2] - 1) + 1) > 0, @@ -61,8 +62,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { "dilations, the output size is less than 0, please check " "again."); output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], - dilations[i], paddings[i], paddings[i], - strides[i])); + dilations[i], paddings[i], strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } @@ -86,9 +86,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, AddOutput("Output", "(Tensor) The output tensor of convolution operator. " "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of convolution operator.") + AddAttr>("strides", + "(vector default:{1, 1}), the " + "strides(h_stride, w_stride) of " + "convolution operator.") .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of convolution operator.") + AddAttr>("paddings", + "(vector default:{0, 0}), the " + "paddings(h_pad, w_pad) of " + "convolution operator.") .SetDefault({0, 0}); AddAttr( "groups", @@ -99,9 +105,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, "is only connected to the second half of the input channels.") .SetDefault(1); AddAttr>("dilations", - "(vector default:{1, 1}), the dilations of " + "(vector default:{1, 1}), the " + "dilations(h_dilation, w_dilation) of " "convolution operator.") - .SetDefault(std::vector{1, 1}); + .SetDefault({1, 1}); AddComment(R"DOC( Convolution Operator. @@ -147,13 +154,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, AddOutput("Output", "(Tensor) The output tensor of convolution operator." "The format of output tensor is also NCDHW."); - AddAttr>( - "strides", - "(vector, default:{0, 0, 0}), the strides of convolution operator.") + AddAttr>("strides", + "(vector, default:{1, 1, 1}), the " + "strides(d_stride, h_stride, w_stride) of " + "convolution operator.") .SetDefault({1, 1, 1}); - AddAttr>( - "paddings", - "(vector, default:{0, 0, 0}), the paddings of convolution operator.") + AddAttr>("paddings", + "(vector, default:{0, 0, 0}), the " + "paddings(d_pad, h_pad, w_pad) of convolution " + "operator.") .SetDefault({0, 0, 0}); AddAttr( "groups", @@ -164,10 +173,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "is only connected to the second half of the input channels.") .SetDefault(1); AddAttr>("dilations", - "(vector default:{1, 1, 1}), the dilations of " + "(vector default:{1, 1, 1}), the " + "dilations(d_dilation, h_dilation, w_dilation) of " "convolution operator. Currently, conv3d doesn't " "support dilation.") - .SetDefault(std::vector{1, 1, 1}); + .SetDefault({1, 1, 1}); AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index af2c8fb163..fac5f1d0e2 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -28,24 +28,22 @@ using Tensor = framework::Tensor; // Base convolution operator definations for other conv // like operators to reuse the implementation. inline int OutputSize(int input_size, int filter_size, int dilation, - int padding_up, int padding_down, int stride) { - int output_size = (input_size + padding_up + padding_down - - (dilation * (filter_size - 1) + 1)) / - stride + - 1; + int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + const int output_size = (input_size + 2 * padding - dkernel) / stride + 1; return output_size; } -inline bool NotExpand(std::vector& filter_dim, - std::vector& strides, std::vector& paddings, - std::vector& dilations) { +inline bool IsExpand(std::vector& filter_dim, + std::vector& strides, std::vector& paddings, + std::vector& dilations) { bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; for (size_t j = 0; j < strides.size(); ++j) { - filter_1 &= (static_cast(filter_dim[j]) == 1); - strides_1 &= (strides[j] == 1); - padding_0 &= (paddings[j] == 0); - dilation_1 &= (dilations[j] == 1); + filter_1 = filter_1 && (static_cast(filter_dim[j]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); } - return filter_1 && strides_1 && padding_0 && dilation_1; + return !(filter_1 && strides_1 && padding_0 && dilation_1); } // Define Op classes in .h file so that other conv @@ -65,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { class ConvOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override; }; class ConvOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override; }; @@ -88,9 +84,9 @@ class GemmConvKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); output->mutable_data(context.GetPlace()); + int groups = context.Attr("groups"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -122,13 +118,13 @@ class GemmConvKernel : public framework::OpKernel { framework::DDim col_matrix_shape = framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); - bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. Tensor col_matrix; - if (!not_expand) { + if (is_expand) { col.mutable_data(col_shape, context.GetPlace()); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); @@ -149,51 +145,37 @@ class GemmConvKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output->dims()[1]) / groups; - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; - if (filter_shape_vec.size() == 2) { - // im2col - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - // vol2col - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], dilations[2], strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - } + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + if (!is_expand) { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); + } else if (filter_shape_vec.size() == 2) { + // im2col + im2col(context.device_context(), in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (filter_shape_vec.size() == 3) { + // vol2col + vol2col(context.device_context(), in_slice, dilations, strides, + paddings, &col); } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); } } } @@ -217,9 +199,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (!input_grad && !filter_grad) return; + int groups = context.Attr("groups"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -270,13 +252,13 @@ class GemmConvGradKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output_grad->dims()[1]) / groups; - bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. Tensor col_matrix; - if (!not_expand) { + if (is_expand) { col.mutable_data(col_shape, context.GetPlace()); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); @@ -288,61 +270,38 @@ class GemmConvGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = - filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - - } else if (filter_shape_vec.size() == 3) { - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, - dilations[0], dilations[1], dilations[2], strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); - } - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = - filter.Slice(g * out_step, (g + 1) * out_step); - - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + math::Col2VolFunctor col2vol; + math::Col2ImFunctor col2im; + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { col_matrix.ShareDataWith(in_grad_slice); col_matrix.Resize(col_matrix_shape); - - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); + } + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); + + if (is_expand && filter_shape_vec.size() == 2) { + col2im(context.device_context(), col, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &in_grad_slice); + } else if (is_expand && filter_shape_vec.size() == 3) { + col2vol(context.device_context(), col, dilations, strides, paddings, + &in_grad_slice); } } } @@ -353,60 +312,38 @@ class GemmConvGradKernel : public framework::OpKernel { Tensor filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); set_zero(context.device_context(), filter_grad, static_cast(0)); + math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], dilations[2], strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - } - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - + if (!is_expand) { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); + } else if (filter_shape_vec.size() == 2) { + im2col(context.device_context(), in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (filter_shape_vec.size() == 3) { + vol2col(context.device_context(), in_slice, dilations, strides, + paddings, &col); } + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); } } } diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 50081779a5..6f47a6d6a0 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -51,7 +51,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { "as the number of filters."); std::vector output_shape({in_dims[0], filter_dims[1]}); - for (size_t i = 0; i < paddings.size(); ++i) { + for (size_t i = 0; i < strides.size(); ++i) { output_shape.push_back((in_dims[i + 2] - 1) * strides[i] + filter_dims[i + 2]); } @@ -77,13 +77,14 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( AddOutput("Output", "(Tensor) The output tensor of convolution transpose operator. " "The format of output tensor is also NCHW."); - AddAttr>( - "strides", - "(vector defalut:{1, 1}), strides of convolution transpose operator.") + AddAttr>("strides", + "(vector defalut:{1, 1}), strides of " + "convolution transpose operator.") .SetDefault({1, 1}); AddAttr>( "paddings", - "(vector defalut:{0, 0}), paddings of convolution transpose operator.") + "(vector defalut:{0, 0}), paddings(h_pad, w_pad) of convolution " + "transpose operator.") .SetDefault({0, 0}); AddComment(R"DOC( Convolution2D Transpose Operator. @@ -132,13 +133,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( "Where N is batch size, C is " "the number of channels, D is the depth of the feature, H is the " "height of the feature, and W is the width of the feature."); - AddAttr>( - "strides", - "(vector defalut:{1, 1, 1}), strides of convolution transpose operator.") + AddAttr>("strides", + "(vector defalut:{1, 1, 1}), strides of " + "convolution transpose operator.") .SetDefault({1, 1, 1}); - AddAttr>( - "paddings", - "(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.") + AddAttr>("paddings", + "(vector defalut:{0, 0, 0}), paddings(d_pad, " + "h_pad, w_pad) of convolution transpose operator.") .SetDefault({0, 0, 0}); AddComment(R"DOC( Convolution3D Transpose Operator. diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 18ca6b20e0..4b2bd60437 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { class ConvTransposeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override; }; class ConvTransposeOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - protected: void InferShape(framework::InferShapeContext* ctx) const override; }; @@ -66,13 +62,11 @@ class GemmConvTransposeKernel : public framework::OpKernel { Tensor* output = context.Output("Output"); std::vector strides = context.Attr>("strides"); + // Actually, no paddings and groups allowed in conv transpose. + std::vector paddings = context.Attr>("paddings"); // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. - int dilaiton_d = 1; - int dilation_h = 1; - int dilation_w = 1; - const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -124,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel { math::SetConstant set_zero; set_zero(context.device_context(), output, static_cast(0)); + math::Col2ImFunctor col2im; + math::Col2VolFunctor col2vol; + std::vector dilations({1, 1, 1}); + // convolution transpose: gemm + col2im or col2vol (similar to conv-backward // on input) for (int i = 0; i < batch_size; i++) { @@ -142,17 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // col2im: col_matrix -> dy // from (c * k_h * k_w, h * w) to (c, o_h, o_w) - math::Col2ImFunctor col2im; - - col2im(context.device_context(), output_batch, col, dilation_h, - dilation_w, strides[0], strides[1], 0, 0, 0, 0); + col2im(context.device_context(), col, + std::vector{dilations[0], dilations[1]}, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &output_batch); } else if (filter_shape_vec.size() == 3) { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), output_batch, col, dilaiton_d, - dilation_h, dilation_w, strides[0], strides[1], strides[2], 0, - 0, 0); + col2vol(context.device_context(), col, dilations, strides, + std::vector{0, 0, 0}, &output_batch); } } } @@ -179,10 +176,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); - int dilaiton_d = 1; - int dilation_h = 1; - int dilation_w = 1; - const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -237,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { Tensor filter_grad_; math::SetConstant set_zero; + math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + std::vector dilations({1, 1, 1}); + if (input_grad) { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); @@ -256,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) - math::Im2ColFunctor im2col; - im2col(context.device_context(), output_grad_batch, col, dilation_h, - dilation_w, strides[0], strides[1], paddings[0], paddings[0], - paddings[1], paddings[1]); + im2col(context.device_context(), output_grad_batch, + std::vector{dilations[0], dilations[1]}, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); } else if (filter_shape_vec.size() == 3) { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), output_grad_batch, col, dilaiton_d, - dilation_h, dilation_w, strides[0], strides[1], strides[2], - paddings[0], paddings[1], paddings[2]); + vol2col(context.device_context(), output_grad_batch, dilations, + strides, paddings, &col); } if (input_grad) { diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index c67d84528f..d9f952c387 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -95,8 +95,9 @@ class ContextProjectFunctor { math::Im2ColFunctor im2col_ocf; - int dilation_h = 1; - int dilation_w = 1; + std::vector dilation({1, 1}); + std::vector padding({up_pad, 0, down_pad, 0}); + std::vector stride({context_stride, 1}); int input_row_begin, input_row_end; int sequence_height, sequence_width; @@ -126,10 +127,7 @@ class ContextProjectFunctor { {1, input_row_end - input_row_begin, sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - - im2col_ocf(context, in_t, out_t, dilation_h, dilation_w, - /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, - down_pad, 0, 0); + im2col_ocf(context, in_t, dilation, stride, padding, &out_t); out_t.Resize({sequence_height, context_length * sequence_width}); } } @@ -207,8 +205,9 @@ class ContextProjectGradFunctor { math::Col2ImFunctor col2im_ocf; - int dilation_h = 1; - int dilation_w = 1; + std::vector dilation({1, 1}); + std::vector padding({up_pad, 0, down_pad, 0}); + std::vector stride({context_stride, 1}); int input_row_begin, input_row_end; int sequence_height, sequence_width; @@ -240,9 +239,7 @@ class ContextProjectGradFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - col2im_ocf(context, in_t, out_t, dilation_h, dilation_w, - /*stride_height*/ context_stride, /*stride_width*/ 1, - up_pad, down_pad, 0, 0); + col2im_ocf(context, out_t, dilation, stride, padding, &in_t); out_t.Resize({sequence_height, context_length * sequence_width}); } } diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 2af55fa71f..c10c44c520 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -28,40 +28,39 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[1]; - int filter_width = col.dims()[2]; - int col_height = col.dims()[3]; - int col_width = col.dims()[4]; + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + ((dilation[0] * (filter_height - 1) + 1))) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + ((dilation[1] * (filter_width - 1) + 1))) / + stride[1] + 1, col_width, - "col_width and padding(padding_left, padding_right) are " + "Output_height and padding(padding_up, padding_down) are " "inconsistent."); int channels_col = im_channels * filter_height * filter_width; const T* im_data = im.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; @@ -69,10 +68,8 @@ class Im2ColFunctor class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + ((dilation[0] * (filter_height - 1) + 1))) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + ((dilation[1] * (filter_width - 1) + 1))) / + stride[1] + 1, col_width, - "col_width and padding(padding_left, padding_right) are " + "Output_height and padding(padding_up, padding_down) are " "inconsistent."); int channels_col = im_channels * filter_height * filter_width; - T* im_data = im.data(); + T* im_data = im->data(); const T* col_data = col.data(); for (int c = 0; c < channels_col; ++c) { @@ -135,10 +133,8 @@ class Col2ImFunctor= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { @@ -171,35 +167,32 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[3]; - int filter_width = col.dims()[4]; - int col_height = col.dims()[0]; - int col_width = col.dims()[1]; + int filter_height = col->dims()[3]; + int filter_width = col->dims()[4]; + int col_height = col->dims()[0]; + int col_width = col->dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - col_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - col_width, - "col_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); const T* im_data = im.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { @@ -209,9 +202,9 @@ class Im2ColFunctor class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; int col_height = col.dims()[0]; int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - col_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - col_width, - "col_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); - T* im_data = im.data(); + T* im_data = im->data(); const T* col_data = col.data(); for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { @@ -282,9 +274,9 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[1]; - int filter_width = col.dims()[2]; - int col_height = col.dims()[3]; - int col_width = col.dims()[4]; - - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int col_width = col->dims()[4]; + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -100,9 +99,9 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), num_outputs, im_height, im_width, dilation_h, dilation_w, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width, col.data()); + im.data(), num_outputs, im_height, im_width, dilation[0], + dilation[1], filter_height, filter_width, stride[0], stride[1], + padding[0], padding[1], col_height, col_width, col->data()); } }; @@ -163,31 +162,32 @@ template class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -206,9 +206,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), im_height, im_width, dilation_h, dilation_w, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width, im.data()); + num_kernels, col.data(), im_height, im_width, dilation[0], + dilation[1], filter_height, filter_width, stride[0], stride[1], + padding[0], padding[2], col_height, col_width, im->data()); } }; @@ -222,11 +222,11 @@ template class Col2ImFunctor; template -__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels, - int im_height, int im_width, int filter_height, - int filter_width, int stride_height, int stride_width, +__global__ void im2colOCF(const T* im_data, int im_channels, int im_height, + int im_width, int filter_height, int filter_width, + int stride_height, int stride_width, int padding_height, int padding_width, int col_height, - int col_width) { + int col_width, T* col_data) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < im_channels; @@ -263,30 +263,29 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[3]; - int filter_width = col.dims()[4]; - int col_height = col.dims()[0]; - int col_width = col.dims()[1]; - - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + int filter_height = col->dims()[3]; + int filter_width = col->dims()[4]; + int col_height = col->dims()[0]; + int col_width = col->dims()[1]; + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -314,18 +313,18 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), col.data(), im_channels, im_height, im_width, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width); + im.data(), im_channels, im_height, im_width, filter_height, + filter_width, stride[0], stride[1], padding[0], padding[1], col_height, + col_width, col->data()); } }; template -__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels, - int im_height, int im_width, int filter_height, - int filter_width, int stride_height, int stride_width, +__global__ void col2imOCF(const T* col_data, int im_channels, int im_height, + int im_width, int filter_height, int filter_width, + int stride_height, int stride_width, int padding_height, int padding_width, int col_height, - int col_width) { + int col_width, T* im_data) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < im_channels; @@ -361,30 +360,31 @@ template class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; int col_height = col.dims()[0]; int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -412,9 +412,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - im.data(), col.data(), im_channels, im_height, im_width, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width); + col.data(), im_channels, im_height, im_width, filter_height, + filter_width, stride[0], stride[1], padding[0], padding[1], col_height, + col_width, im->data()); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index d1c9595a32..deb60051be 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; * \param colData Column data. * \param colShape The shape of colData. * + * \param dilations dilation data. + * \param 2-dimension [dilation_height, dilation_width]. + * + * \param strides stride data. + * \param 2-dimension [stride_height, stride_width]. + * + * \param paddings padding data. + * \param 4-dimension [up_pad, left_pad, down_pad, right_pad]. + * * If the template argument Format is kCFO, the shape of colData is: * [input_channels, filter_height, filter_width, output_height, output_width] * So, it is easy to reshape into a convolution matrix for convolution @@ -73,19 +82,19 @@ template class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right); + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col); }; template class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 3385fe8721..10c28da72b 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -45,12 +45,14 @@ void testIm2col() { int input_height = 2; int input_width = 3; int filter_size = 2; - int stride = 1; - int padding = 0; - int dilation_h = 1; - int dilation_w = 1; - int output_height = (input_height - filter_size + 2 * padding) / stride + 1; - int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + std::vector stride({1, 1}); // stride_y, stride_x + std::vector padding( + {0, 0, 0, 0}); // up_pad, left_pad, down_pad, right_pad + std::vector dilation({1, 1}); // dilation_y, dilation_x + int output_height = + (input_height - filter_size + padding[0] + padding[1]) / stride[0] + 1; + int output_width = + (input_width - filter_size + padding[2] + padding[3]) / stride[1] + 1; float* input_ptr = input_tmp.mutable_data( {1, input_height, input_width}, paddle::platform::CPUPlace()); float arr[6] = {0, 1, 2, 3, 4, 5}; @@ -87,10 +89,8 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, - padding, padding, padding, padding); - im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, - stride, padding, padding, padding, padding); + im2col(*context, input, dilation, stride, padding, &output_cfo); + im2col_ocf(*context, input, dilation, stride, padding, &output_ocf); float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; @@ -133,8 +133,7 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, - padding, padding, padding, padding); + col2im(*context, output_cfo, dilation, stride, padding, &input); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -155,8 +154,7 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, - stride, padding, padding, padding, padding); + col2im_ocf(*context, output_ocf, dilation, stride, padding, &input); if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index bd509a94f3..99eb7fd46d 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -28,51 +28,51 @@ template class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* col) const { PADDLE_ENFORCE(vol.dims().size() == 4); - PADDLE_ENFORCE(col.dims().size() == 7); + PADDLE_ENFORCE(col->dims().size() == 7); int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; int input_width = vol.dims()[3]; - int filter_depth = col.dims()[1]; - int filter_height = col.dims()[2]; - int filter_width = col.dims()[3]; - int output_depth = col.dims()[4]; - int output_height = col.dims()[5]; - int output_width = col.dims()[6]; + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); + "mismatching."); const T* vol_data = vol.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; @@ -80,13 +80,11 @@ class Vol2ColFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int c_in = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; + int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = - h * stride_height - padding_height + h_offset * dilation_h; + int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = - w * stride_width - padding_width + w_offset * dilation_w; + int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; @@ -116,18 +114,18 @@ template class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { - PADDLE_ENFORCE(vol.dims().size() == 4); + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* vol) const { + PADDLE_ENFORCE(vol->dims().size() == 4); PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + int input_channels = vol->dims()[0]; + int input_depth = vol->dims()[1]; + int input_height = vol->dims()[2]; + int input_width = vol->dims()[3]; int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -137,28 +135,28 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); - T* vol_data = vol.data(); + "mismatching."); + T* vol_data = vol->data(); const T* col_data = col.data(); for (int c = 0; c < channels_col; ++c) { @@ -167,13 +165,11 @@ class Col2VolFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int cIm = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; + int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = - h * stride_height - padding_height + h_offset * dilation_h; + int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = - w * stride_width - padding_width + w_offset * dilation_w; + int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index 080d3e5466..addae3caf8 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -71,42 +71,42 @@ template class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* col) const { PADDLE_ENFORCE(vol.dims().size() == 4); - PADDLE_ENFORCE(col.dims().size() == 7); + PADDLE_ENFORCE(col->dims().size() == 7); int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; int input_width = vol.dims()[3]; - int filter_depth = col.dims()[1]; - int filter_height = col.dims()[2]; - int filter_width = col.dims()[3]; - int output_depth = col.dims()[4]; - int output_height = col.dims()[5]; - int output_width = col.dims()[6]; + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " @@ -121,10 +121,10 @@ class Vol2ColFunctor { reinterpret_cast(context) .stream()>>>( num_outputs, vol.data(), input_depth, input_height, input_width, - dilation_d, dilation_h, dilation_w, filter_depth, filter_height, - filter_width, stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, output_depth, output_height, - output_width, col.data()); + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], paddings[0], + paddings[1], paddings[2], output_depth, output_height, output_width, + col->data()); } }; @@ -200,18 +200,18 @@ template class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { - PADDLE_ENFORCE(vol.dims().size() == 4); + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* vol) const { + PADDLE_ENFORCE(vol->dims().size() == 4); PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + int input_channels = vol->dims()[0]; + int input_depth = vol->dims()[1]; + int input_height = vol->dims()[2]; + int input_width = vol->dims()[3]; int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -219,23 +219,23 @@ class Col2VolFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " @@ -250,10 +250,10 @@ class Col2VolFunctor { reinterpret_cast(context) .stream()>>>( num_kernels, col.data(), input_depth, input_height, input_width, - dilation_d, dilation_h, dilation_w, filter_depth, filter_height, - filter_width, stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, output_depth, output_height, - output_width, vol.data()); + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], paddings[0], + paddings[1], paddings[2], output_depth, output_height, output_width, + vol->data()); } }; diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h index c2d8257c0b..cbc30bd754 100644 --- a/paddle/operators/math/vol2col.h +++ b/paddle/operators/math/vol2col.h @@ -31,6 +31,15 @@ namespace math { * \param colData Column data. * \param colShape The shape of colData. * + * \param dilations dilation data. + * \param 3-dimension [dilation_depth, dilation_height, dilation_width]. + * + * \param strides stride data. + * \param 3-dimension [stride_depth, stride_height, stride_width]. + * + * \param paddings padding data. + * \param 3-dimension [d_pad, h_pad, w_pad]. + * * The shape of colData is: * [input_channels, filter_depth, filter_height, filter_width, output_depth, * output_height, output_width] @@ -57,22 +66,22 @@ template class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const; + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* col) const; }; template class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const; + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* vol) const; }; } // namespace math diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 9d673ad36c..c31c716842 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -62,12 +62,15 @@ void testVol2col() { int input_height = 2; int input_width = 3; int filter_size = 2; - int stride = 1; - int padding = 0; - int dilation = 1; - int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; - int output_height = (input_height - filter_size + 2 * padding) / stride + 1; - int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + std::vector strides({1, 1, 1}); + std::vector paddings({0, 0, 0}); + std::vector dilations({1, 1, 1}); + int output_depth = + (input_depth - filter_size + 2 * paddings[0]) / strides[0] + 1; + int output_height = + (input_height - filter_size + 2 * paddings[1]) / strides[1] + 1; + int output_width = + (input_width - filter_size + 2 * paddings[2]) / strides[2] + 1; // Vol2Col test float* input_ptr = @@ -86,8 +89,7 @@ void testVol2col() { *place); paddle::operators::math::Vol2ColFunctor vol2col; - vol2col(*context, input, output, dilation, dilation, dilation, stride, stride, - stride, padding, padding, padding); + vol2col(*context, input, dilations, strides, paddings, &output); float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; @@ -112,8 +114,7 @@ void testVol2col() { } paddle::operators::math::Col2VolFunctor col2vol; - col2vol(*context, input, output, dilation, dilation, dilation, stride, stride, - stride, padding, padding, padding); + col2vol(*context, output, dilations, strides, paddings, &input); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { From 81bb26f3bdacd2786ffdb4a2960bbee28e7834dd Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Tue, 14 Nov 2017 21:32:31 -0800 Subject: [PATCH 14/25] Removing debug print and moving assert --- python/paddle/v2/fluid/io.py | 3 ++- python/paddle/v2/fluid/layer_helper.py | 2 +- python/paddle/v2/fluid/layers.py | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index d1263c3e91..2d070814ee 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -244,6 +244,8 @@ def get_parameter_value(para, executor): :param para: the given parameter :return: the LoDTensor for the parameter """ + assert is_parameter(para) + get_program = Program() block = get_program.global_block() new_var = _clone_var_in_block_(block, para) @@ -263,5 +265,4 @@ def get_parameter_value_by_name(name, executor, program=None): if program is None: program = g_main_program var = program.global_block().var(name) - assert is_parameter(var) return get_parameter_value(var, executor) diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py index 0a9ed81888..a97e07982b 100644 --- a/python/paddle/v2/fluid/layer_helper.py +++ b/python/paddle/v2/fluid/layer_helper.py @@ -158,7 +158,7 @@ class LayerHelper(object): or equal than 2. :param dim_start: :param dim_end: the shape of the bias will be - input_var.shape(dim_start:dim_end). The bias is broadcast to other + input_var.shape[dim_start:dim_end]. The bias is broadcasted to other dimensions and added to input_var to get the output """ size = list(input_var.shape[dim_start:dim_end]) diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index 771a313598..1789d2f82a 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -676,7 +676,6 @@ def conv2d(input, filter_shape = [num_filters, num_filter_channels] + filter_size std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 - print 'name=', name, 'std=', std filter = helper.create_parameter( attr=helper.param_attr, shape=filter_shape, From 09866fb75f8522e0cea56ccc40fee76cdf7d6be7 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 15 Nov 2017 17:29:34 +0800 Subject: [PATCH 15/25] feature/beam search op (#5052) --- paddle/operators/beam_search_op.cc | 185 ++++++++++++++ paddle/operators/beam_search_op.h | 226 ++++++++++++++++++ .../v2/framework/tests/test_beam_search_op.py | 65 +++++ 3 files changed, 476 insertions(+) create mode 100644 paddle/operators/beam_search_op.cc create mode 100644 paddle/operators/beam_search_op.h create mode 100644 python/paddle/v2/framework/tests/test_beam_search_op.py diff --git a/paddle/operators/beam_search_op.cc b/paddle/operators/beam_search_op.cc new file mode 100644 index 0000000000..17926a813d --- /dev/null +++ b/paddle/operators/beam_search_op.cc @@ -0,0 +1,185 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/beam_search_op.h" + +#include +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +void BeamSearch::operator()(const framework::LoDTensor &pre_ids, + framework::LoDTensor *selected_ids, + framework::LoDTensor *selected_scores) { + auto items = SelectTopBeamSizeItems(); + auto selected_items = ToMap(items); + PruneEndidCandidates(pre_ids, &selected_items); + // calculate the output tensor's height + size_t num_instances = std::accumulate( + std::begin(items), std::end(items), 0, + [](size_t a, std::vector &b) { return a + b.size(); }); + // the output tensor shape should be [num_instances, 1] + auto dims = framework::make_ddim( + std::vector({static_cast(num_instances), 1})); + selected_ids->Resize(dims); + selected_scores->Resize(dims); + + std::map> hash; + framework::LoD new_lod; + auto *ids_data = selected_ids->mutable_data(platform::CPUPlace()); + auto *scores_data = + selected_scores->mutable_data(platform::CPUPlace()); + + // fill in data + std::vector low_level; + size_t low_offset = 0; + for (auto &items : selected_items) { + low_level.push_back(low_offset); + for (auto &item : items) { + ids_data[low_offset] = item.id; + scores_data[low_offset] = item.score; + low_offset++; + } + } + // fill lod + auto abs_lod = framework::ToAbsOffset(ids_->lod()); + auto &high_level = abs_lod[lod_level_]; + framework::LoD lod(2); + lod[0].assign(high_level.begin(), high_level.end()); + lod[1].assign(low_level.begin(), low_level.end()); + selected_ids->set_lod(lod); + selected_scores->set_lod(lod); +} + +void BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids, + std::vector> *items) { + auto *pre_ids_data = pre_ids.data(); + + for (size_t offset = 0; offset < items->size(); offset++) { + auto prefix_id = pre_ids_data[offset]; + if (prefix_id == end_id_) { + items->at(offset).clear(); + } + } +} + +std::vector> BeamSearch::ToMap( + const std::vector> &items) { + std::vector> result; + for (auto &entries : items) { + for (const auto &item : entries) { + if (item.offset >= result.size()) { + result.resize(item.offset + 1); + } + result[item.offset].push_back(item); + } + } + return result; +} + +std::vector> +BeamSearch::SelectTopBeamSizeItems() { + std::vector> result; + std::vector items; + // for each source sentence, select the top beam_size items across all + // candidate sets. + while (NextItemSet(&items)) { + std::nth_element(std::begin(items), std::begin(items) + beam_size_, + std::end(items), [](const Item &a, const Item &b) { + // TODO(superjom) make score's comparation customizable. + // partial sort in descending order + return a.score > b.score; + }); + // prune the top beam_size items. + if (items.size() > beam_size_) { + items.resize(beam_size_); + } + result.emplace_back(items); + } + return result; +} + +// the candidates of a source +bool BeamSearch::NextItemSet(std::vector *items) { + if (sent_offset_ >= ids_->NumElements(lod_level_)) { + return false; + } + // find the current candidates + auto ids = *ids_; + auto scores = *scores_; + + auto source_abs_two_level_lod = framework::SliceInLevel( + ids.lod(), lod_level_, sent_offset_, sent_offset_ + 1); + source_abs_two_level_lod = framework::ToAbsOffset(source_abs_two_level_lod); + auto abs_lod = framework::ToAbsOffset(ids.lod()); + PADDLE_ENFORCE_GE(source_abs_two_level_lod.size(), 2UL); + + auto *ids_data = ids.data(); + auto *scores_data = scores.data(); + + size_t instance_dim = 1; + for (int i = 1; i < ids.dims().size(); i++) { + instance_dim *= ids.dims()[i]; + } + + items->clear(); + items->reserve(framework::product(ids.dims())); + for (size_t offset = abs_lod[lod_level_][sent_offset_]; + offset < abs_lod[lod_level_][sent_offset_ + 1]; offset++) { + for (int d = 0; d < instance_dim; d++) { + const size_t dim_offset = offset * instance_dim + d; + items->emplace_back(offset, ids_data[dim_offset], + scores_data[dim_offset]); + } + } + + sent_offset_++; + return true; +} + +class BeamSearchProtoAndCheckerMaker + : public framework::OpProtoAndCheckerMaker { + public: + BeamSearchProtoAndCheckerMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + // inputs and outputs stored in proto + AddInput("pre_ids", "ids in previous step"); + AddInput("ids", "a LoDTensor of shape of [None,k]"); + AddInput("scores", + "a LoDTensor that has the same shape and LoD with `ids`"); + AddOutput("selected_ids", + "a LoDTensor that stores the IDs selected by beam search"); + AddOutput( + "selected_scores", + "a LoDTensor that has the same shape and LoD with `selected_ids`"); + + // Attributes stored in AttributeMap + AddAttr("level", "the level of LoDTensor"); + AddAttr("beam_size", "beam size for beam search"); + AddAttr("end_id", + "the token id which indicates the end of a sequence"); + + AddComment( + "This is a beam search operator that help to generate sequences."); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT(beam_search, paddle::operators::BeamSearchOp, + paddle::operators::BeamSearchProtoAndCheckerMaker); diff --git a/paddle/operators/beam_search_op.h b/paddle/operators/beam_search_op.h new file mode 100644 index 0000000000..cc556bfe42 --- /dev/null +++ b/paddle/operators/beam_search_op.h @@ -0,0 +1,226 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include "gtest/gtest.h" +#endif + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +/* + * This is an implementation of beam search. + * + * To explain the details, lets take machine translation task for example, in + * this task, one source sentence is translated to multiple target sentences, + * during this period, one sentence will be translated to multiple translation + * prefixes(target sentence that have not ended), in each time step a prefix + * will have some candidates, input the candidate ids and their corresponding + * scores (probabilities), it will sort and select the top beam_size candidates + * for each source sentence, and store the selected candidates's score and their + * corresponding ids to LoDTensors. + * + * A detailed example: + * + * Input + * + * ids: + * LoD (should have 2 levels) + * first level: [0, 1, 4] + * second level: [0, 1, 2, 3, 4] + * + * tensor's data + * [ + * [4, 2, 5] + * [2, 1, 3] + * [3, 5, 2] + * [8, 2, 1] + * ] + * + * scores: + * LoD same as `ids` + * tensor's data + * [ + * [0.5, 0.3, 0.2] + * [0.6, 0.3, 0.1] + * [0.9, 0.5, 0.1] + * [0.7, 0.5, 0.1] + * ] + * + * the inputs means that there are 2 source sentences to translate, and the + * first source has 1 prefix, the second source has 2 prefix. + * + * lets assume beam size is 2, and the beam search's output should be + * LoD + * first level: + * [0, 1, 2] + * second level: + * [0, 2, 4] + * + * tensor's data + * [[ + * 0.5, + * 0.3, + * 0.9, + * 0.7 + * ]] + * + * TODO all the prune operations should be in the beam search, so it is better + * to split the beam search algorithm into a sequence of smaller operators, and + * the prune operators can be inserted in this sequence. + */ +class BeamSearch { + public: + // TODO(superjom) make type customizable + using id_t = size_t; + using score_t = float; + /* + * Input the arguments that needed by this class. + */ + BeamSearch(const framework::LoDTensor& ids, + const framework::LoDTensor& scores, size_t level, size_t beam_size, + int end_id) + : beam_size_(beam_size), + ids_(&ids), + scores_(&scores), + lod_level_(level), + end_id_(end_id) {} + + /* + * The main function of beam search. + * + * @selected_ids: a [None, 1]-shaped tensor with LoD. + * In a machine translation model, it might be the candidate term id sets, + * each set stored as a varience-length sequence. + * The format might be described with a two-level LoD + * - [[0 1] + * - [0 1 2]] + * - [[] + * - [0 1]] + * the first level of LoD tells that there are two source sentences. The + * second level describes the details of the candidate id set's offsets in + * the + * source sentences. + * + * @selected_scores: a LoD tensor with the same shape and LoD with + * selected_ids. + * It stores the corresponding scores of candidate ids in selected_ids. + * + * Return false if all the input tensor is empty, in machine translation task + * that means no candidates is provided, and the task will stop running. + */ + void operator()(const framework::LoDTensor& pre_ids, + framework::LoDTensor* selected_ids, + framework::LoDTensor* selected_scores); + + protected: + /* + * The basic items help to sort. + */ + struct Item { + Item() {} + Item(size_t offset, size_t id, float score) + : offset(offset), id(id), score(score) {} + // offset in the lod_level_+1 + size_t offset; + // the candidate id + id_t id; + // the corresponding score + score_t score; + }; + + void PruneEndidCandidates(const framework::LoDTensor& pre_ids, + std::vector>* items); + + /* + * Transform the items into a map whose key is offset, value is the items. + * NOTE low performance + */ + std::vector> ToMap( + const std::vector>& inputs); + + /* + * For each source, select top beam_size records. + */ + std::vector> SelectTopBeamSizeItems(); + + /* + * Get the items of next source sequence, return false if no remaining items. + */ + bool NextItemSet(std::vector* items); + + private: + size_t beam_size_; + const framework::LoDTensor* ids_; + const framework::LoDTensor* scores_; + size_t lod_level_{0}; + size_t sent_offset_{0}; + int end_id_{0}; +}; + +class BeamSearchOp : public framework::OperatorBase { + public: + BeamSearchOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + BeamSearchOp(const BeamSearchOp& o) + : framework::OperatorBase( + static_cast(o)) { + PADDLE_THROW("Not Implemented"); + } + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + LOG(INFO) << "run beam search op"; + auto ids_var = scope.FindVar(Input("ids")); + auto scores_var = scope.FindVar(Input("scores")); + auto pre_ids_var = scope.FindVar(Input("pre_ids")); + PADDLE_ENFORCE_NOT_NULL(ids_var); + PADDLE_ENFORCE_NOT_NULL(scores_var); + PADDLE_ENFORCE_NOT_NULL(pre_ids_var); + + auto& ids = ids_var->Get(); + auto& scores = scores_var->Get(); + auto& pre_ids = pre_ids_var->Get(); + size_t level = Attr("level"); + size_t beam_size = Attr("beam_size"); + int end_id = Attr("end_id"); + LOG(INFO) << "init beam search"; + BeamSearch alg(ids, scores, level, beam_size, end_id); + + LOG(INFO) << "after beam search"; + auto selected_ids_var = scope.FindVar(Output("selected_ids")); + auto selected_scores_var = scope.FindVar(Output("selected_scores")); + PADDLE_ENFORCE_NOT_NULL(selected_ids_var); + PADDLE_ENFORCE_NOT_NULL(selected_scores_var); + auto& selected_ids_tensor = + *selected_ids_var->GetMutable(); + auto& selected_scores_tensor = + *selected_scores_var->GetMutable(); + LOG(INFO) << "run beam search"; + alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); + LOG(INFO) << "finish beam search"; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_beam_search_op.py b/python/paddle/v2/framework/tests/test_beam_search_op.py new file mode 100644 index 0000000000..a5a0cc0c96 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_beam_search_op.py @@ -0,0 +1,65 @@ +import logging +from paddle.v2.framework.op import Operator, DynamicRecurrentOp +import paddle.v2.framework.core as core +import unittest +import numpy as np + + +def create_tensor(scope, name, np_data): + tensor = scope.var(name).get_tensor() + tensor.set(np_data, core.CPUPlace()) + return tensor + + +class BeamSearchOpTester(unittest.TestCase): + def setUp(self): + self.scope = core.Scope() + self.ctx = core.DeviceContext.create(core.CPUPlace()) + self._create_ids() + self._create_scores() + self._create_pre_ids() + self.scope.var('selected_ids') + self.scope.var('selected_scores') + + def test_run(self): + op = Operator( + 'beam_search', + pre_ids="pre_ids", + ids='ids', + scores='scores', + selected_ids='selected_ids', + selected_scores='selected_scores', + level=0, + beam_size=2, + end_id=0, ) + op.run(self.scope, self.ctx) + selected_ids = self.scope.find_var("selected_ids").get_tensor() + print 'selected_ids', np.array(selected_ids) + print 'lod', selected_ids.lod() + + def _create_pre_ids(self): + np_data = np.array([[1, 2, 3, 4]], dtype='int32') + tensor = create_tensor(self.scope, "pre_ids", np_data) + + def _create_ids(self): + self.lod = [[0, 1, 4], [0, 1, 2, 3, 4]] + np_data = np.array( + [[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int32') + tensor = create_tensor(self.scope, "ids", np_data) + tensor.set_lod(self.lod) + + def _create_scores(self): + np_data = np.array( + [ + [0.5, 0.3, 0.2], + [0.6, 0.3, 0.1], + [0.9, 0.5, 0.1], + [0.7, 0.5, 0.1], + ], + dtype='float32') + tensor = create_tensor(self.scope, "scores", np_data) + tensor.set_lod(self.lod) + + +if __name__ == '__main__': + unittest.main() From 31dc0193c958e9ba723ee89fc602a01479d0bbf1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 15 Nov 2017 13:23:23 +0800 Subject: [PATCH 16/25] fix ContextProjectFunctor parameter order --- paddle/operators/math/context_project.h | 36 +++++++++++++------------ paddle/operators/math/vol2col.cu | 7 +++-- paddle/operators/sequence_conv_op.h | 22 +++++++-------- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index d9f952c387..845de82bbc 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -88,9 +88,10 @@ template class ContextProjectFunctor { public: void operator()(const platform::DeviceContext& context, const LoDTensor& in, - const Tensor& padding_data, Tensor& col, - bool padding_trainable, int context_start, int context_length, - int context_stride, int up_pad, int down_pad) { + const Tensor& padding_data, bool padding_trainable, + const int context_start, const int context_length, + const int context_stride, const int up_pad, + const int down_pad, Tensor* col) { auto lod_level_0 = in.lod()[0]; math::Im2ColFunctor im2col_ocf; @@ -109,8 +110,8 @@ class ContextProjectFunctor { : static_cast(lod_level_0[i]); input_row_end = static_cast(lod_level_0[i + 1]); - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); @@ -133,8 +134,8 @@ class ContextProjectFunctor { } if (padding_trainable) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); @@ -197,10 +198,11 @@ class ContextProjectFunctor { template class ContextProjectGradFunctor { public: - void operator()(const platform::DeviceContext& context, LoDTensor& in, - Tensor& padding_data, Tensor& col, bool padding_trainable, - int context_start, int context_length, int context_stride, - int up_pad, int down_pad, bool input_grad, bool pad_grad) { + void operator()(const platform::DeviceContext& context, const LoDTensor& in, + bool padding_trainable, const int context_start, + const int context_length, const int context_stride, + const int up_pad, const int down_pad, bool pad_grad, + bool input_grad, Tensor* padding_data, Tensor* col) { auto lod_level_0 = in.lod()[0]; math::Col2ImFunctor col2im_ocf; @@ -220,8 +222,8 @@ class ContextProjectGradFunctor { : static_cast(lod_level_0[i]); input_row_end = static_cast(lod_level_0[i + 1]); - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); @@ -247,8 +249,8 @@ class ContextProjectGradFunctor { if (pad_grad) { if (padding_trainable) { for (int i = 0; i < static_cast(lod_level_0.size()) - 1; ++i) { - Tensor out_t = col.Slice(static_cast(lod_level_0[i]), - static_cast(lod_level_0[i + 1])); + Tensor out_t = col->Slice(static_cast(lod_level_0[i]), + static_cast(lod_level_0[i + 1])); sequence_height = static_cast(out_t.dims()[0]); out_t.Resize({sequence_height * context_length, sequence_width}); @@ -262,7 +264,7 @@ class ContextProjectGradFunctor { k + context_length < up_pad ? context_length : up_pad - k; Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); - Tensor w_sub = padding_data.Slice(k, k + padding_size); + Tensor w_sub = padding_data->Slice(k, k + padding_size); auto out_t_sub_e = EigenMatrix::From(out_t_sub); auto w_sub_e = EigenMatrix::From(w_sub); w_sub_e.device(*context.GetEigenDevice()) = @@ -295,7 +297,7 @@ class ContextProjectGradFunctor { Tensor out_t_sub = out_t.Slice( (down_pad_begin_row + t) * context_length - padding_size, (down_pad_begin_row + t) * context_length); - Tensor w_sub = padding_data.Slice( + Tensor w_sub = padding_data->Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); auto out_t_sub_e = EigenMatrix::From(out_t_sub); auto w_sub_e = EigenMatrix::From(w_sub); diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index addae3caf8..dae3be858e 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -174,10 +174,9 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, int data_col_index = (((((c * filter_depth + d_off) * filter_height + h_off) * filter_width + - w_off) * - output_detph + - d_col) * - output_height + + w_off))); + data_col_index = + ((data_col_index * output_detph + d_col) * output_height + h_col) * output_width + w_col; diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index a57e1752bb..adee8d760e 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -62,9 +62,9 @@ class SequenceConvKernel : public framework::OpKernel { math::ContextProjectFunctor seq_project_functor; - seq_project_functor(context.device_context(), *in, *padding_data, col, + seq_project_functor(context.device_context(), *in, *padding_data, padding_trainable, context_start, context_length, - context_stride, up_pad, down_pad); + context_stride, up_pad, down_pad, &col); math::matmul(context.device_context(), col, false, filter, false, static_cast(1.0), out, static_cast(0.0)); @@ -117,10 +117,10 @@ class SequenceConvGradKernel : public framework::OpKernel { in_g->set_lod(in->lod()); set_zero(context.device_context(), in_g, static_cast(0)); - seq_project_grad_functor(context.device_context(), *in_g, *padding_data_g, - col, padding_trainable, context_start, - context_length, context_stride, up_pad, down_pad, - true, false); + seq_project_grad_functor(context.device_context(), *in_g, + padding_trainable, context_start, context_length, + context_stride, up_pad, down_pad, false, true, + padding_data_g, &col); } if (padding_trainable && padding_data_g) { @@ -129,9 +129,9 @@ class SequenceConvGradKernel : public framework::OpKernel { LoDTensor* input = const_cast(in); seq_project_grad_functor(context.device_context(), *input, - *padding_data_g, col, padding_trainable, - context_start, context_length, context_stride, - up_pad, down_pad, false, true); + padding_trainable, context_start, context_length, + context_stride, up_pad, down_pad, true, false, + padding_data_g, &col); } if (filter_g) { @@ -146,9 +146,9 @@ class SequenceConvGradKernel : public framework::OpKernel { padding_data = context.Input("PaddingData"); } - seq_project_functor(context.device_context(), *in, *padding_data, col, + seq_project_functor(context.device_context(), *in, *padding_data, padding_trainable, context_start, context_length, - context_stride, up_pad, down_pad); + context_stride, up_pad, down_pad, &col); math::matmul(context.device_context(), col, true, out_grad, false, T(1.0), &filter_grad, T(1.0)); From 00e0881bfb1fa3d633a360032ce85e80e966a0b3 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 15 Nov 2017 19:58:39 +0800 Subject: [PATCH 17/25] remove conflict --- python/paddle/v2/framework/proto/__init__.py | 0 .../v2/framework/proto/framework_pb2.py | 1076 ----------------- 2 files changed, 1076 deletions(-) delete mode 100644 python/paddle/v2/framework/proto/__init__.py delete mode 100644 python/paddle/v2/framework/proto/framework_pb2.py diff --git a/python/paddle/v2/framework/proto/__init__.py b/python/paddle/v2/framework/proto/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/python/paddle/v2/framework/proto/framework_pb2.py b/python/paddle/v2/framework/proto/framework_pb2.py deleted file mode 100644 index 950cd22907..0000000000 --- a/python/paddle/v2/framework/proto/framework_pb2.py +++ /dev/null @@ -1,1076 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: framework.proto - -import sys -_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -from google.protobuf import descriptor_pb2 -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - -DESCRIPTOR = _descriptor.FileDescriptor( - name='framework.proto', - package='paddle.framework', - syntax='proto2', - serialized_pb=_b( - '\n\x0f\x66ramework.proto\x12\x10paddle.framework\"\x8c\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12,\n\x06inputs\x18\x01 \x03(\x0b\x32\x1c.paddle.framework.OpDesc.Var\x12-\n\x07outputs\x18\x02 \x03(\x0b\x32\x1c.paddle.framework.OpDesc.Var\x12,\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32\x1d.paddle.framework.OpDesc.Attr\x1a\xbb\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12(\n\x04type\x18\x02 \x02(\x0e\x32\x1a.paddle.framework.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\x9f\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12-\n\x06inputs\x18\x02 \x03(\x0b\x32\x1d.paddle.framework.OpProto.Var\x12.\n\x07outputs\x18\x03 \x03(\x0b\x32\x1d.paddle.framework.OpProto.Var\x12-\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1a|\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1e\n\x0fnot_in_gradient\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ai\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12(\n\x04type\x18\x02 \x02(\x0e\x32\x1a.paddle.framework.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"b\n\rLoDTensorDesc\x12-\n\tdata_type\x18\x01 \x02(\x0e\x32\x1a.paddle.framework.DataType\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x12\x14\n\tlod_level\x18\x03 \x01(\x05:\x01\x30\"L\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x33\n\nlod_tensor\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.LoDTensorDesc\"|\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12\'\n\x04vars\x18\x03 \x03(\x0b\x32\x19.paddle.framework.VarDesc\x12%\n\x03ops\x18\x04 \x03(\x0b\x32\x18.paddle.framework.OpDesc\":\n\x0bProgramDesc\x12+\n\x06\x62locks\x18\x01 \x03(\x0b\x32\x1b.paddle.framework.BlockDesc*s\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08*S\n\x08\x44\x61taType\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06' - )) -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -_ATTRTYPE = _descriptor.EnumDescriptor( - name='AttrType', - full_name='paddle.framework.AttrType', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='INT', index=0, number=0, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FLOAT', index=1, number=1, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='STRING', index=2, number=2, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INTS', index=3, number=3, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FLOATS', index=4, number=4, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='STRINGS', index=5, number=5, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='BOOLEAN', index=6, number=6, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='BOOLEANS', index=7, number=7, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='BLOCK', index=8, number=8, options=None, type=None), - ], - containing_type=None, - options=None, - serialized_start=1218, - serialized_end=1333, ) -_sym_db.RegisterEnumDescriptor(_ATTRTYPE) - -AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE) -_DATATYPE = _descriptor.EnumDescriptor( - name='DataType', - full_name='paddle.framework.DataType', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='BOOL', index=0, number=0, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INT16', index=1, number=1, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INT32', index=2, number=2, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='INT64', index=3, number=3, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FP16', index=4, number=4, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FP32', index=5, number=5, options=None, type=None), - _descriptor.EnumValueDescriptor( - name='FP64', index=6, number=6, options=None, type=None), - ], - containing_type=None, - options=None, - serialized_start=1335, - serialized_end=1418, ) -_sym_db.RegisterEnumDescriptor(_DATATYPE) - -DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) -INT = 0 -FLOAT = 1 -STRING = 2 -INTS = 3 -FLOATS = 4 -STRINGS = 5 -BOOLEAN = 6 -BOOLEANS = 7 -BLOCK = 8 -BOOL = 0 -INT16 = 1 -INT32 = 2 -INT64 = 3 -FP16 = 4 -FP32 = 5 -FP64 = 6 - -_OPDESC_ATTR = _descriptor.Descriptor( - name='Attr', - full_name='paddle.framework.OpDesc.Attr', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', - full_name='paddle.framework.OpDesc.Attr.name', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='type', - full_name='paddle.framework.OpDesc.Attr.type', - index=1, - number=2, - type=14, - cpp_type=8, - label=2, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='i', - full_name='paddle.framework.OpDesc.Attr.i', - index=2, - number=3, - type=5, - cpp_type=1, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='f', - full_name='paddle.framework.OpDesc.Attr.f', - index=3, - number=4, - type=2, - cpp_type=6, - label=1, - has_default_value=False, - default_value=float(0), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='s', - full_name='paddle.framework.OpDesc.Attr.s', - index=4, - number=5, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='ints', - full_name='paddle.framework.OpDesc.Attr.ints', - index=5, - number=6, - type=5, - cpp_type=1, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='floats', - full_name='paddle.framework.OpDesc.Attr.floats', - index=6, - number=7, - type=2, - cpp_type=6, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='strings', - full_name='paddle.framework.OpDesc.Attr.strings', - index=7, - number=8, - type=9, - cpp_type=9, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='b', - full_name='paddle.framework.OpDesc.Attr.b', - index=8, - number=10, - type=8, - cpp_type=7, - label=1, - has_default_value=False, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='bools', - full_name='paddle.framework.OpDesc.Attr.bools', - index=9, - number=11, - type=8, - cpp_type=7, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='block_idx', - full_name='paddle.framework.OpDesc.Attr.block_idx', - index=10, - number=12, - type=5, - cpp_type=1, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=202, - serialized_end=389, ) - -_OPDESC_VAR = _descriptor.Descriptor( - name='Var', - full_name='paddle.framework.OpDesc.Var', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='parameter', - full_name='paddle.framework.OpDesc.Var.parameter', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='arguments', - full_name='paddle.framework.OpDesc.Var.arguments', - index=1, - number=2, - type=9, - cpp_type=9, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=391, - serialized_end=434, ) - -_OPDESC = _descriptor.Descriptor( - name='OpDesc', - full_name='paddle.framework.OpDesc', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='type', - full_name='paddle.framework.OpDesc.type', - index=0, - number=3, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='inputs', - full_name='paddle.framework.OpDesc.inputs', - index=1, - number=1, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='outputs', - full_name='paddle.framework.OpDesc.outputs', - index=2, - number=2, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='attrs', - full_name='paddle.framework.OpDesc.attrs', - index=3, - number=4, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[ - _OPDESC_ATTR, - _OPDESC_VAR, - ], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=38, - serialized_end=434, ) - -_OPPROTO_VAR = _descriptor.Descriptor( - name='Var', - full_name='paddle.framework.OpProto.Var', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', - full_name='paddle.framework.OpProto.Var.name', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='comment', - full_name='paddle.framework.OpProto.Var.comment', - index=1, - number=2, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='duplicable', - full_name='paddle.framework.OpProto.Var.duplicable', - index=2, - number=3, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='intermediate', - full_name='paddle.framework.OpProto.Var.intermediate', - index=3, - number=4, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='not_in_gradient', - full_name='paddle.framework.OpProto.Var.not_in_gradient', - index=4, - number=5, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=621, - serialized_end=745, ) - -_OPPROTO_ATTR = _descriptor.Descriptor( - name='Attr', - full_name='paddle.framework.OpProto.Attr', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', - full_name='paddle.framework.OpProto.Attr.name', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='type', - full_name='paddle.framework.OpProto.Attr.type', - index=1, - number=2, - type=14, - cpp_type=8, - label=2, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='comment', - full_name='paddle.framework.OpProto.Attr.comment', - index=2, - number=3, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='generated', - full_name='paddle.framework.OpProto.Attr.generated', - index=3, - number=4, - type=8, - cpp_type=7, - label=1, - has_default_value=True, - default_value=False, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=747, - serialized_end=852, ) - -_OPPROTO = _descriptor.Descriptor( - name='OpProto', - full_name='paddle.framework.OpProto', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='type', - full_name='paddle.framework.OpProto.type', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='inputs', - full_name='paddle.framework.OpProto.inputs', - index=1, - number=2, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='outputs', - full_name='paddle.framework.OpProto.outputs', - index=2, - number=3, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='attrs', - full_name='paddle.framework.OpProto.attrs', - index=3, - number=4, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='comment', - full_name='paddle.framework.OpProto.comment', - index=4, - number=5, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[ - _OPPROTO_VAR, - _OPPROTO_ATTR, - ], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=437, - serialized_end=852, ) - -_LODTENSORDESC = _descriptor.Descriptor( - name='LoDTensorDesc', - full_name='paddle.framework.LoDTensorDesc', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='data_type', - full_name='paddle.framework.LoDTensorDesc.data_type', - index=0, - number=1, - type=14, - cpp_type=8, - label=2, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='dims', - full_name='paddle.framework.LoDTensorDesc.dims', - index=1, - number=2, - type=3, - cpp_type=2, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='lod_level', - full_name='paddle.framework.LoDTensorDesc.lod_level', - index=2, - number=3, - type=5, - cpp_type=1, - label=1, - has_default_value=True, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=854, - serialized_end=952, ) - -_VARDESC = _descriptor.Descriptor( - name='VarDesc', - full_name='paddle.framework.VarDesc', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', - full_name='paddle.framework.VarDesc.name', - index=0, - number=1, - type=9, - cpp_type=9, - label=2, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='lod_tensor', - full_name='paddle.framework.VarDesc.lod_tensor', - index=1, - number=2, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=954, - serialized_end=1030, ) - -_BLOCKDESC = _descriptor.Descriptor( - name='BlockDesc', - full_name='paddle.framework.BlockDesc', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='idx', - full_name='paddle.framework.BlockDesc.idx', - index=0, - number=1, - type=5, - cpp_type=1, - label=2, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='parent_idx', - full_name='paddle.framework.BlockDesc.parent_idx', - index=1, - number=2, - type=5, - cpp_type=1, - label=2, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='vars', - full_name='paddle.framework.BlockDesc.vars', - index=2, - number=3, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='ops', - full_name='paddle.framework.BlockDesc.ops', - index=3, - number=4, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=1032, - serialized_end=1156, ) - -_PROGRAMDESC = _descriptor.Descriptor( - name='ProgramDesc', - full_name='paddle.framework.ProgramDesc', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='blocks', - full_name='paddle.framework.ProgramDesc.blocks', - index=0, - number=1, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[], - serialized_start=1158, - serialized_end=1216, ) - -_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE -_OPDESC_ATTR.containing_type = _OPDESC -_OPDESC_VAR.containing_type = _OPDESC -_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR -_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR -_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR -_OPPROTO_VAR.containing_type = _OPPROTO -_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE -_OPPROTO_ATTR.containing_type = _OPPROTO -_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR -_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR -_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR -_LODTENSORDESC.fields_by_name['data_type'].enum_type = _DATATYPE -_VARDESC.fields_by_name['lod_tensor'].message_type = _LODTENSORDESC -_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC -_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC -_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC -DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC -DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO -DESCRIPTOR.message_types_by_name['LoDTensorDesc'] = _LODTENSORDESC -DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC -DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC -DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC -DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE -DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE - -OpDesc = _reflection.GeneratedProtocolMessageType( - 'OpDesc', - (_message.Message, ), - dict( - Attr=_reflection.GeneratedProtocolMessageType( - 'Attr', - (_message.Message, ), - dict( - DESCRIPTOR=_OPDESC_ATTR, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.OpDesc.Attr) - )), - Var=_reflection.GeneratedProtocolMessageType( - 'Var', - (_message.Message, ), - dict( - DESCRIPTOR=_OPDESC_VAR, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.OpDesc.Var) - )), - DESCRIPTOR=_OPDESC, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.OpDesc) - )) -_sym_db.RegisterMessage(OpDesc) -_sym_db.RegisterMessage(OpDesc.Attr) -_sym_db.RegisterMessage(OpDesc.Var) - -OpProto = _reflection.GeneratedProtocolMessageType( - 'OpProto', - (_message.Message, ), - dict( - Var=_reflection.GeneratedProtocolMessageType( - 'Var', - (_message.Message, ), - dict( - DESCRIPTOR=_OPPROTO_VAR, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.OpProto.Var) - )), - Attr=_reflection.GeneratedProtocolMessageType( - 'Attr', - (_message.Message, ), - dict( - DESCRIPTOR=_OPPROTO_ATTR, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.OpProto.Attr) - )), - DESCRIPTOR=_OPPROTO, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.OpProto) - )) -_sym_db.RegisterMessage(OpProto) -_sym_db.RegisterMessage(OpProto.Var) -_sym_db.RegisterMessage(OpProto.Attr) - -LoDTensorDesc = _reflection.GeneratedProtocolMessageType( - 'LoDTensorDesc', - (_message.Message, ), - dict( - DESCRIPTOR=_LODTENSORDESC, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.LoDTensorDesc) - )) -_sym_db.RegisterMessage(LoDTensorDesc) - -VarDesc = _reflection.GeneratedProtocolMessageType( - 'VarDesc', - (_message.Message, ), - dict( - DESCRIPTOR=_VARDESC, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.VarDesc) - )) -_sym_db.RegisterMessage(VarDesc) - -BlockDesc = _reflection.GeneratedProtocolMessageType( - 'BlockDesc', - (_message.Message, ), - dict( - DESCRIPTOR=_BLOCKDESC, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.BlockDesc) - )) -_sym_db.RegisterMessage(BlockDesc) - -ProgramDesc = _reflection.GeneratedProtocolMessageType( - 'ProgramDesc', - (_message.Message, ), - dict( - DESCRIPTOR=_PROGRAMDESC, - __module__='framework_pb2' - # @@protoc_insertion_point(class_scope:paddle.framework.ProgramDesc) - )) -_sym_db.RegisterMessage(ProgramDesc) - -# @@protoc_insertion_point(module_scope) From e97b89873a4ec2f57b54225b432eebbffad4fb2f Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 15 Nov 2017 12:25:23 -0800 Subject: [PATCH 18/25] "fix accuracy kernel bug" (#5673) * "fix accuracy kernel bug" * "relauch ci" --- paddle/operators/accuracy_op.cu | 23 +++++++++++++---------- paddle/platform/gpu_info.cc | 5 +++++ paddle/platform/gpu_info.h | 3 +++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index b575c682f0..d2dcab4e54 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/operators/accuracy_op.h" #include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" namespace paddle { namespace operators { @@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { int num_samples = static_cast(inference->dims()[0]); size_t infer_width = inference->dims()[1]; - PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float))); - // cudaMemset((void**)&correct_data, 0, sizeof(float)); + auto stream = ctx.cuda_device_context().stream(); + platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream); if (num_samples == 0) { return; } - cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice); + platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int), + cudaMemcpyHostToDevice, stream); - AccuracyCudaKernel<<< - 1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>( + AccuracyCudaKernel< + PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( num_samples, infer_width, indices_data, label_data, correct_data, accuracy_data); int d_num_samples, d_num_correct; float d_accuracy; - cudaMemcpy(&d_num_correct, correct_data, sizeof(int), - cudaMemcpyDeviceToHost); - cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost); - cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float), - cudaMemcpyDeviceToHost); + platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int), + cudaMemcpyDeviceToHost, stream); + platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int), + cudaMemcpyDeviceToHost, stream); + platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float), + cudaMemcpyDeviceToHost, stream); } }; diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index f3455a8733..36b216d872 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -109,5 +109,10 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream), "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer"); } + +void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) { + PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream), + "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync"); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index 37665b97d7..db961f3838 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -60,6 +60,9 @@ void GpuMemcpySync(void *dst, const void *src, size_t count, void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, size_t count, cudaStream_t stream); +//! Set memory dst with value count size asynchronously +void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream); + } // namespace platform } // namespace paddle From 36f1d16f54f268c20d61628bdb358dd57eb8d32c Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 14 Nov 2017 14:13:14 -0800 Subject: [PATCH 19/25] Remove unused model.py --- python/paddle/v2/__init__.py | 2 - python/paddle/v2/model.py | 73 ------------------------------------ 2 files changed, 75 deletions(-) delete mode 100644 python/paddle/v2/model.py diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 1c8d8f4b2f..f7ed42a397 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -33,7 +33,6 @@ import networks import minibatch import plot import image -import model import paddle.trainer.config_parser as cp __all__ = [ @@ -56,7 +55,6 @@ __all__ = [ 'evaluator', 'image', 'master', - 'model', ] cp.begin_parse() diff --git a/python/paddle/v2/model.py b/python/paddle/v2/model.py deleted file mode 100644 index 4634db55a9..0000000000 --- a/python/paddle/v2/model.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import errno -import uuid - -import paddle.v2.master - -__all__ = ["save_model", "load_model"] - -trainer_id = str(uuid.uuid4()) - - -def mkdir_p(path): - try: - os.makedirs(path) - except OSError as exc: - if exc.errno == errno.EEXIST and os.path.isdir(path): - pass - else: - raise - - -def save_model(parameters, path): - need_request = "KUBERNETES_SERVICE_HOST" in os.environ.keys() - - if need_request: - # TODO(helin): figure out how MPI trains, since MPI only save - # model when trainer_id == "0", we can consolidate the logic - # here. - - # TODO(helin): change this environment variable name from - # MASTER_IP to ETCD_IP - etcd_name = "MASTER_IP" - if etcd_name not in os.environ.keys(): - raise Exception('not find ' + etcd_name + - ' in environment variable.') - - etcd_ip = os.environ.get(etcd_name) - client = paddle.v2.master.client("http://" + etcd_ip + ":2379", 5, 0) - r = client.request_save_model(trainer_id, 5000) - if r == 0: - # do not need to save - return - elif r < 0: - # error - return - else: - # save model - path = os.path.join(path, trainer_id) - path = os.path.join(path, "model.tar") - - mkdir_p(path) - - with open(path, 'wb') as f: - parameters.to_tar(f) - - -def load_model(parameters, path): - with open(path, 'rb') as f: - parameters.from_tar(f) From 3dc8834209e03da8a53aa0a9a68872e980a7fd26 Mon Sep 17 00:00:00 2001 From: Markus Kliegl Date: Fri, 10 Nov 2017 23:04:33 +0000 Subject: [PATCH 20/25] conv shift op: change to CamelCase --- paddle/operators/conv_shift_op.cu | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu index 74ed1b0ed3..1db77657a0 100644 --- a/paddle/operators/conv_shift_op.cu +++ b/paddle/operators/conv_shift_op.cu @@ -22,7 +22,7 @@ using framework::Tensor; namespace { -inline int div_up(int x, int y) { return (x + y - 1) / y; } +inline int DivUp(int x, int y) { return (x + y - 1) / y; } // Some notes on the design: // @@ -33,9 +33,9 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; } // y is fairly small. For large y, it would probably be more efficient // to also tile across y. template -__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, - int y_width, int y_half_width, - int batch_size) { +__global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width, + int y_width, int y_half_width, + int batch_size) { extern __shared__ T mem[]; int tx = threadIdx.x; @@ -79,8 +79,8 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, // Compute x gradient - initial naive implementation with atomic add. template -__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, - int y_width, int y_half_width, int batch_size) { +__global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width, + int y_width, int y_half_width, int batch_size) { int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int j = blockIdx.y; // y index int k = blockIdx.z; // batch index @@ -94,8 +94,8 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, // Compute y gradient - initial naive implementation with atomic add. template -__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width, - int y_width, int y_half_width, int batch_size) { +__global__ void ConvShiftDy(const T *x, const T *dout, T *dy, int x_width, + int y_width, int y_half_width, int batch_size) { int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int j = blockIdx.y; // y index int k = blockIdx.z; // batch index @@ -125,14 +125,14 @@ class ConvShiftKernel : public framework::OpKernel { int y_half_width = (y_width - 1) / 2; const int x_per_block = 256; - int num_x_blocks = div_up(x_width, x_per_block); + int num_x_blocks = DivUp(x_width, x_per_block); int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T); dim3 grid_dim(num_x_blocks, batch_size); auto stream = context.cuda_device_context().stream(); - conv_shift_forward<<>>( + ConvShiftForward<<>>( x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); } }; @@ -160,20 +160,20 @@ class ConvShiftGradKernel auto stream = context.cuda_device_context().stream(); const int x_per_block = 256; - int num_x_blocks = div_up(x_width, x_per_block); + int num_x_blocks = DivUp(x_width, x_per_block); dim3 grid_dim(num_x_blocks, y_width, batch_size); if (dX) { T *dx_data = dX->mutable_data(context.GetPlace()); cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); - conv_shift_dx<<>>( + ConvShiftGradX<<>>( dout_data, y_data, dx_data, x_width, y_width, y_half_width, batch_size); } if (dY) { T *dy_data = dY->mutable_data(context.GetPlace()); cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); - conv_shift_dy<<>>( + ConvShiftDy<<>>( x_data, dout_data, dy_data, x_width, y_width, y_half_width, batch_size); } From 42dd5da0fde79261af3c9bcf4f8fa716d515ef26 Mon Sep 17 00:00:00 2001 From: Markus Kliegl Date: Tue, 14 Nov 2017 04:23:52 +0000 Subject: [PATCH 21/25] conv shift: fix return before syncthreads --- paddle/operators/conv_shift_op.cu | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu index 1db77657a0..2a157f457a 100644 --- a/paddle/operators/conv_shift_op.cu +++ b/paddle/operators/conv_shift_op.cu @@ -62,19 +62,19 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width, if (tx < num_x) { int load_i = (i - y_half_width + x_width) % x_width; sx[tx] = x[k * x_width + load_i]; - } else { - return; } __syncthreads(); - // Compute dot product of sx[tx:tx + y_width] and sy. - T sum = 0; - for (int j = 0; j < y_width; ++j) { - sum += sx[tx + j] * sy[j]; - } + if (tx < num_x) { + // Compute dot product of sx[tx:tx + y_width] and sy. + T sum = 0; + for (int j = 0; j < y_width; ++j) { + sum += sx[tx + j] * sy[j]; + } - // Save to out[k, i]. - out[k * x_width + i] = sum; + // Save to out[k, i]. + out[k * x_width + i] = sum; + } } // Compute x gradient - initial naive implementation with atomic add. From d0b601c4a8219eef669b7e530a047bf898cf4cdc Mon Sep 17 00:00:00 2001 From: Markus Kliegl Date: Wed, 15 Nov 2017 00:57:43 +0000 Subject: [PATCH 22/25] address PR feedback --- paddle/operators/conv_shift_op.cu | 37 +++++++++++++++++-------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu index 2a157f457a..95e13c38a8 100644 --- a/paddle/operators/conv_shift_op.cu +++ b/paddle/operators/conv_shift_op.cu @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/operators/conv_shift_op.h" +#include "paddle/operators/math/math_function.h" #include "paddle/platform/cuda_helper.h" namespace paddle { @@ -33,9 +34,9 @@ inline int DivUp(int x, int y) { return (x + y - 1) / y; } // y is fairly small. For large y, it would probably be more efficient // to also tile across y. template -__global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width, - int y_width, int y_half_width, - int batch_size) { +__global__ void ConvShiftForward(const T *x, const T *y, int x_width, + int y_width, int y_half_width, int batch_size, + T *out) { extern __shared__ T mem[]; int tx = threadIdx.x; @@ -79,8 +80,9 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width, // Compute x gradient - initial naive implementation with atomic add. template -__global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width, - int y_width, int y_half_width, int batch_size) { +__global__ void ConvShiftGradX(const T *dout, const T *y, int x_width, + int y_width, int y_half_width, int batch_size, + T *dx) { int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int j = blockIdx.y; // y index int k = blockIdx.z; // batch index @@ -94,8 +96,8 @@ __global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width, // Compute y gradient - initial naive implementation with atomic add. template -__global__ void ConvShiftDy(const T *x, const T *dout, T *dy, int x_width, - int y_width, int y_half_width, int batch_size) { +__global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width, + int y_half_width, int batch_size, T *dy) { int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int j = blockIdx.y; // y index int k = blockIdx.z; // batch index @@ -133,7 +135,7 @@ class ConvShiftKernel : public framework::OpKernel { auto stream = context.cuda_device_context().stream(); ConvShiftForward<<>>( - x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); + x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data); } }; @@ -157,7 +159,8 @@ class ConvShiftGradKernel int y_width = Y->dims()[1]; int y_half_width = (y_width - 1) / 2; - auto stream = context.cuda_device_context().stream(); + auto &device_ctx = context.cuda_device_context(); + math::SetConstant zero; const int x_per_block = 256; int num_x_blocks = DivUp(x_width, x_per_block); @@ -165,17 +168,17 @@ class ConvShiftGradKernel if (dX) { T *dx_data = dX->mutable_data(context.GetPlace()); - cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); - ConvShiftGradX<<>>( - dout_data, y_data, dx_data, x_width, y_width, y_half_width, - batch_size); + zero(device_ctx, dX, static_cast(0.0)); + ConvShiftGradX<<>>( + dout_data, y_data, x_width, y_width, y_half_width, batch_size, + dx_data); } if (dY) { T *dy_data = dY->mutable_data(context.GetPlace()); - cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); - ConvShiftDy<<>>( - x_data, dout_data, dy_data, x_width, y_width, y_half_width, - batch_size); + zero(device_ctx, dY, static_cast(0.0)); + ConvShiftDy<<>>( + x_data, dout_data, x_width, y_width, y_half_width, batch_size, + dy_data); } } }; From e0e3a8a5bb2a33bf1953c7cebdedcf2ea5869b51 Mon Sep 17 00:00:00 2001 From: kavyasrinet Date: Wed, 15 Nov 2017 17:00:44 -0800 Subject: [PATCH 23/25] Updating the writeup of the RNN design doc and sequence_decoder (#5611) --- doc/design/ops/images/2_level_rnn.dot | 8 +- doc/design/ops/rnn.md | 66 +++++++-------- doc/design/ops/sequence_decoder.md | 114 +++++++++++--------------- 3 files changed, 86 insertions(+), 102 deletions(-) diff --git a/doc/design/ops/images/2_level_rnn.dot b/doc/design/ops/images/2_level_rnn.dot index a498e882a3..5d77865061 100644 --- a/doc/design/ops/images/2_level_rnn.dot +++ b/doc/design/ops/images/2_level_rnn.dot @@ -1,6 +1,6 @@ digraph G { - rnn [label="1-th level RNN" shape=box] + rnn [label="1st level RNN" shape=box] subgraph cluster0 { label = "time step 0" @@ -8,7 +8,7 @@ digraph G { sent0 [label="sentence"] sent1 [label="sentence"] - rnn1 [label="2-th level RNN" shape=box] + rnn1 [label="2nd level RNN" shape=box] sent0 -> rnn1 sent1 -> rnn1 @@ -20,7 +20,7 @@ digraph G { sent2 [label="sentence"] sent3 [label="sentence"] - rnn2 [label="2-th level RNN" shape=box] + rnn2 [label="2nd level RNN" shape=box] sent2 -> rnn2 sent3 -> rnn2 @@ -32,7 +32,7 @@ digraph G { sent4 [label="sentence"] sent5 [label="sentence"] - rnn3 [label="2-th level RNN" shape=box] + rnn3 [label="2nd level RNN" shape=box] sent4 -> rnn3 sent5 -> rnn3 diff --git a/doc/design/ops/rnn.md b/doc/design/ops/rnn.md index a78eea7d45..2f4854793f 100644 --- a/doc/design/ops/rnn.md +++ b/doc/design/ops/rnn.md @@ -1,62 +1,62 @@ # RNNOp design -This document is about an RNN operator which requires that instances in a mini-batch have the same length. We will have a more flexible RNN operator. +This document describes the RNN (Recurrent Neural Network) operator and how it is implemented in PaddlePaddle. The RNN op requires that all instances in a mini-batch have the same length. We will have a more flexible dynamic RNN operator in the future. ## RNN Algorithm Implementation -

+

The above diagram shows an RNN unrolled into a full network. -There are several important concepts: +There are several important concepts here: -- *step-net*: the sub-graph to run at each step, -- *memory*, $h_t$, the state of the current step, -- *ex-memory*, $h_{t-1}$, the state of the previous step, -- *initial memory value*, the ex-memory of the first step. +- *step-net*: the sub-graph that runs at each step. +- *memory*, $h_t$, the state of the current step. +- *ex-memory*, $h_{t-1}$, the state of the previous step. +- *initial memory value*, the memory of the first (initial) step. ### Step-scope -There could be local variables defined in step-nets. PaddlePaddle runtime realizes these variables in *step-scopes* -- scopes created for each step. +There could be local variables defined in each step-net. PaddlePaddle runtime realizes these variables in *step-scopes* which are created for each step. -

+


-Figure 2 the RNN's data flow +Figure 2 illustrates the RNN's data flow

-Please be aware that all steps run the same step-net. Each step +Please be aware that every step runs the same step-net. Each step does the following: -1. creates the step-scope, -2. realizes local variables, including step-outputs, in the step-scope, and -3. runs the step-net, which could use these variables. +1. Creates the step-scope. +2. Initializes the local variables including step-outputs, in the step-scope. +3. Runs the step-net, which uses the above mentioned variables. -The RNN operator will compose its output from step outputs in step scopes. +The RNN operator will compose its output from step outputs in each of the step scopes. ### Memory and Ex-memory -Let's give more details about memory and ex-memory via a simply example: +Let's give more details about memory and ex-memory using a simple example: $$ h_t = U h_{t-1} + W x_t $$, -where $h_t$ and $h_{t-1}$ are the memory and ex-memory of step $t$'s respectively. +where $h_t$ and $h_{t-1}$ are the memory and ex-memory (previous memory) of step $t$ respectively. -In the implementation, we can make an ex-memory variable either "refers to" the memory variable of the previous step, -or copy the value of the previous memory value to the current ex-memory variable. +In the implementation, we can make an ex-memory variable either "refer to" the memory variable of the previous step, +or copy the memory value of the previous step to the current ex-memory variable. ### Usage in Python For more information on Block, please refer to the [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md). -We can define an RNN's step-net using Block: +We can define an RNN's step-net using a Block: ```python import paddle as pd -X = some_op() # x is some operator's output, and is a LoDTensor +X = some_op() # x is some operator's output and is a LoDTensor a = some_op() # declare parameters @@ -68,7 +68,7 @@ with rnn.stepnet(): x = rnn.add_input(X) # declare a memory (rnn's step) h = rnn.add_memory(init=a) - # h.pre_state() means previous memory of rnn + # h.pre_state(), the previous memory of rnn new_state = pd.add_two( pd.matmul(W, x) + pd.matmul(U, h.pre_state())) # update current memory h.update(new_state) @@ -80,19 +80,19 @@ out = rnn() Python API functions in above example: -- `rnn.add_input` indicates the parameter is a variable that will be segmented into step-inputs. -- `rnn.add_memory` creates a variable used as the memory. -- `rnn.add_outputs` mark the variables that will be concatenated across steps into the RNN output. +- `rnn.add_input`: indicates that the parameter is a variable that will be segmented into step-inputs. +- `rnn.add_memory`: creates a variable used as the memory. +- `rnn.add_outputs`: marks the variables that will be concatenated across steps into the RNN output. ### Nested RNN and LoDTensor An RNN whose step-net includes other RNN operators is known as an *nested RNN*. -For example, we could have a 2-level RNN, where the top level corresponds to paragraphs, and the lower level corresponds to sentences. +For example, we could have a 2-level RNN, where the top level corresponds to paragraphs, and the lower level corresponds to sentences. Each step of the higher level RNN also receives an input from the corresponding step of the lower level, and additionally the output from the previous time step at the same level. -The following figure illustrates the feeding of text into the lower level, one sentence each step, and the feeding of step outputs to the top level. The final top level output is about the whole text. +The following figure illustrates feeding in text into the lower level, one sentence at a step, and the feeding in step outputs to the top level. The final top level output is about the whole text. -

+

@@ -110,7 +110,7 @@ a = some_op() # chapter_data is a set of 128-dim word vectors # the first level of LoD is sentence -# the second level of LoD is chapter +# the second level of LoD is a chapter chapter_data = pd.Variable(shape=[None, 128], type=pd.lod_tensor, level=2) def lower_level_rnn(paragraph): @@ -138,14 +138,14 @@ with top_level_rnn.stepnet(): pd.matmul(W0, paragraph_data) + pd.matmul(U0, h.pre_state())) top_level_rnn.add_outputs(h) -# just output the last step +# output the last step chapter_out = top_level_rnn(output_all_steps=False) ``` -in above example, the construction of the `top_level_rnn` calls `lower_level_rnn`. The input is a LoD Tensor. The top level RNN segments input text data into paragraphs, and the lower level RNN segments each paragraph into sentences. +In the above example, the construction of the `top_level_rnn` calls `lower_level_rnn`. The input is an LoD Tensor. The top level RNN segments input text data into paragraphs, and the lower level RNN segments each paragraph into sentences. -By default, the `RNNOp` will concatenate the outputs from all the time steps, -if the `output_all_steps` set to False, it will only output the final time step. +By default, the `RNNOp` will concatenate the outputs from all the time steps. +If the `output_all_steps` is set to False, it will only output the final time step.

diff --git a/doc/design/ops/sequence_decoder.md b/doc/design/ops/sequence_decoder.md index 9007aae7a8..9db5fb8e9a 100644 --- a/doc/design/ops/sequence_decoder.md +++ b/doc/design/ops/sequence_decoder.md @@ -1,35 +1,28 @@ # Design: Sequence Decoder Generating LoDTensors -In tasks such as machine translation and image to text, -a [sequence decoder](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.md) is necessary to generate sequences. +In tasks such as machine translation and visual captioning, +a [sequence decoder](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.md) is necessary to generate sequences, one word at a time. This documentation describes how to implement the sequence decoder as an operator. ## Beam Search based Decoder -The [beam search algorithm](https://en.wikipedia.org/wiki/Beam_search) is necessary when generating sequences, -it is a heuristic search algorithm that explores the paths by expanding the most promising node in a limited set. +The [beam search algorithm](https://en.wikipedia.org/wiki/Beam_search) is necessary when generating sequences. It is a heuristic search algorithm that explores the paths by expanding the most promising node in a limited set. -In the old version of PaddlePaddle, a C++ class `RecurrentGradientMachine` implements the general sequence decoder based on beam search, -due to the complexity, the implementation relays on a lot of special data structures, -quite trivial and hard to be customized by users. +In the old version of PaddlePaddle, the C++ class `RecurrentGradientMachine` implements the general sequence decoder based on beam search, due to the complexity involved, the implementation relies on a lot of special data structures that are quite trivial and hard to be customized by users. -There are a lot of heuristic tricks in the sequence generation tasks, -so the flexibility of sequence decoder is very important to users. +There are a lot of heuristic tricks in the sequence generation tasks, so the flexibility of sequence decoder is very important to users. -During PaddlePaddle's refactoring work, -some new concept is proposed such as [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md) and [TensorArray](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/tensor_array.md) that can better support sequence usage, -and they can help to make the implementation of beam search based sequence decoder **more transparent and modular** . +During the refactoring of PaddlePaddle, some new concepts are proposed such as: [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md) and [TensorArray](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/tensor_array.md) that can better support the sequence usage, and they can also help make the implementation of beam search based sequence decoder **more transparent and modular** . -For example, the RNN sates, candidates IDs and probabilities of beam search can be represented as `LoDTensors`; +For example, the RNN states, candidates IDs and probabilities of beam search can be represented all as `LoDTensors`; the selected candidate's IDs in each time step can be stored in a `TensorArray`, and `Packed` to the sentences translated. ## Changing LoD's absolute offset to relative offsets -The current `LoDTensor` is designed to store levels of variable-length sequences, -it stores several arrays of integers each represents a level. +The current `LoDTensor` is designed to store levels of variable-length sequences. It stores several arrays of integers where each represents a level. -The integers in each level represents the begin and end (not inclusive) offset of a sequence **in the underlying tensor**, -let's call this format the **absolute-offset LoD** for clear. +The integers in each level represent the begin and end (not inclusive) offset of a sequence **in the underlying tensor**, +let's call this format the **absolute-offset LoD** for clarity. -The relative-offset LoD can fast retrieve any sequence but fails to represent empty sequences, for example, a two-level LoD is as follows +The relative-offset LoD can retrieve any sequence very quickly but fails to represent empty sequences, for example, a two-level LoD is as follows ```python [[0, 3, 9] [0, 2, 3, 3, 3, 9]] @@ -41,10 +34,9 @@ The first level tells that there are two sequences: while on the second level, there are several empty sequences that both begin and end at `3`. It is impossible to tell how many empty second-level sequences exist in the first-level sequences. -There are many scenarios that relay on empty sequence representation, -such as machine translation or image to text, one instance has no translations or the empty candidate set for a prefix. +There are many scenarios that rely on empty sequence representation, for example in machine translation or visual captioning, one instance has no translation or the empty candidate set for a prefix. -So let's introduce another format of LoD, +So let's introduce another format of LoD, it stores **the offsets of the lower level sequences** and is called **relative-offset** LoD. For example, to represent the same sequences of the above data @@ -54,19 +46,18 @@ For example, to represent the same sequences of the above data [0, 2, 3, 3, 3, 9]] ``` -the first level represents that there are two sequences, +the first level represents that there are two sequences, their offsets in the second-level LoD is `[0, 3)` and `[3, 5)`. The second level is the same with the relative offset example because the lower level is a tensor. It is easy to find out the second sequence in the first-level LoD has two empty sequences. -The following demos are based on relative-offset LoD. +The following examples are based on relative-offset LoD. ## Usage in a simple machine translation model -Let's start from a simple machine translation model that is simplified from [machine translation chapter](https://github.com/PaddlePaddle/book/tree/develop/08.machine_translation) to draw a simple blueprint of what a sequence decoder can do and how to use it. +Let's start from a simple machine translation model that is simplified from the [machine translation chapter](https://github.com/PaddlePaddle/book/tree/develop/08.machine_translation) to draw a blueprint of what a sequence decoder can do and how to use it. -The model has an encoder that learns the semantic vector from a sequence, -and a decoder which uses the sequence decoder to generate new sentences. +The model has an encoder that learns the semantic vector from a sequence, and a decoder which uses the sequence encoder to generate new sentences. **Encoder** ```python @@ -117,7 +108,7 @@ def generate(): # which means there are 2 sentences to translate # - the first sentence has 1 translation prefixes, the offsets are [0, 1) # - the second sentence has 2 translation prefixes, the offsets are [1, 3) and [3, 6) - # the target_word.lod is + # the target_word.lod is # [[0, 1, 6] # [0, 2, 4, 7, 9 12]] # which means 2 sentences to translate, each has 1 and 5 prefixes @@ -154,37 +145,36 @@ def generate(): translation_ids, translation_scores = decoder() ``` -The `decoder.beam_search` is a operator that given the candidates and the scores of translations including the candidates, -return the result of the beam search algorithm. +The `decoder.beam_search` is an operator that, given the candidates and the scores of translations including the candidates, +returns the result of the beam search algorithm. -In this way, users can customize anything on the inputs or outputs of beam search, for example, two ways to prune some translation prefixes +In this way, users can customize anything on the input or output of beam search, for example: -1. meke the correspondind elements in `topk_generated_scores` zero or some small values, beam_search will discard this candidate. -2. remove some specific candidate in `selected_ids` -3. get the final `translation_ids`, remove the translation sequence in it. +1. Make the corresponding elements in `topk_generated_scores` zero or some small values, beam_search will discard this candidate. +2. Remove some specific candidate in `selected_ids`. +3. Get the final `translation_ids`, remove the translation sequence in it. -The implementation of sequence decoder can reuse the C++ class [RNNAlgorithm](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/paddle/operators/dynamic_recurrent_op.h#L30), -so the python syntax is quite similar to a [RNN](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/doc/design/block.md#blocks-with-for-and-rnnop). +The implementation of sequence decoder can reuse the C++ class: [RNNAlgorithm](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/paddle/operators/dynamic_recurrent_op.h#L30), +so the python syntax is quite similar to that of an [RNN](https://github.com/Superjom/Paddle/blob/68cac3c0f8451fe62a4cdf156747d6dc0ee000b3/doc/design/block.md#blocks-with-for-and-rnnop). -Both of them are two-level `LoDTensors` +Both of them are two-level `LoDTensors`: -- the first level represents `batch_size` of (source) sentences; -- the second level represents the candidate ID sets for translation prefix. +- The first level represents `batch_size` of (source) sentences. +- The second level represents the candidate ID sets for translation prefix. -for example, 3 source sentences to translate, and has 2, 3, 1 candidates. +For example, 3 source sentences to translate, and has 2, 3, 1 candidates. -Unlike an RNN, in sequence decoder, the previous state and the current state have different LoD and shape, -a `lod_expand` operator is used to expand the LoD of the previous state to fit the current state. +Unlike an RNN, in sequence decoder, the previous state and the current state have different LoD and shape, and an `lod_expand` operator is used to expand the LoD of the previous state to fit the current state. -For example, the previous state +For example, the previous state: * LoD is `[0, 1, 3][0, 2, 5, 6]` * content of tensor is `a1 a2 b1 b2 b3 c1` -the current state stored in `encoder_ctx_expanded` +the current state is stored in `encoder_ctx_expanded`: * LoD is `[0, 2, 7][0 3 5 8 9 11 11]` -* the content is +* the content is - a1 a1 a1 (a1 has 3 candidates, so the state should be copied 3 times for each candidates) - a2 a2 - b1 b1 b1 @@ -192,54 +182,48 @@ the current state stored in `encoder_ctx_expanded` - b3 b3 - None (c1 has 0 candidates, so c1 is dropped) -Benefit from the relative offset LoD, empty candidate set can be represented naturally. +The benefit from the relative offset LoD is that the empty candidate set can be represented naturally. -the status in each time step can be stored in `TensorArray`, and `Pack`ed to a final LoDTensor, the corresponding syntax is +The status in each time step can be stored in `TensorArray`, and `Pack`ed to a final LoDTensor. The corresponding syntax is: ```python decoder.output(selected_ids) decoder.output(selected_generation_scores) ``` -the `selected_ids` is the candidate ids for the prefixes, -it will be `Packed` by `TensorArray` to a two-level `LoDTensor`, -the first level represents the source sequences, -the second level represents generated sequences. +The `selected_ids` are the candidate ids for the prefixes, and will be `Packed` by `TensorArray` to a two-level `LoDTensor`, where the first level represents the source sequences and the second level represents generated sequences. -Pack the `selected_scores` will get a `LoDTensor` that stores scores of each candidate of translations. +Packing the `selected_scores` will get a `LoDTensor` that stores scores of each translation candidate. -Pack the `selected_generation_scores` will get a `LoDTensor`, and each tail is the probability of the translation. +Packing the `selected_generation_scores` will get a `LoDTensor`, and each tail is the probability of the translation. ## LoD and shape changes during decoding

-According the image above, the only phrase to change LoD is beam search. +According to the image above, the only phase that changes the LoD is beam search. ## Beam search design -The beam search algorthm will be implemented as one method of the sequence decoder, it has 3 inputs +The beam search algorithm will be implemented as one method of the sequence decoder and has 3 inputs: -1. `topk_ids`, top K candidate ids for each prefix. +1. `topk_ids`, the top K candidate ids for each prefix. 2. `topk_scores`, the corresponding scores for `topk_ids` 3. `generated_scores`, the score of the prefixes. -All of the are LoDTensors, so that the sequence affilication is clear. -Beam search will keep a beam for each prefix and select a smaller candidate set for each prefix. +All of these are LoDTensors, so that the sequence affiliation is clear. Beam search will keep a beam for each prefix and select a smaller candidate set for each prefix. -It will return three variables +It will return three variables: 1. `selected_ids`, the final candidate beam search function selected for the next step. 2. `selected_scores`, the scores for the candidates. -3. `generated_scores`, the updated scores for each prefixes (with the new candidates appended). +3. `generated_scores`, the updated scores for each prefix (with the new candidates appended). ## Introducing the LoD-based `Pack` and `Unpack` methods in `TensorArray` -The `selected_ids`, `selected_scores` and `generated_scores` are LoDTensors, -and they exist in each time step, +The `selected_ids`, `selected_scores` and `generated_scores` are LoDTensors that exist at each time step, so it is natural to store them in arrays. -Currently, PaddlePaddle has a module called `TensorArray` which can store an array of tensors, -the results of beam search are better to store in a `TensorArray`. +Currently, PaddlePaddle has a module called `TensorArray` which can store an array of tensors. It is better to store the results of beam search in a `TensorArray`. -The `Pack` and `UnPack` in `TensorArray` are used to package tensors in the array to a `LoDTensor` or split the `LoDTensor` to an array of tensors. -It needs some extensions to support pack or unpack an array of `LoDTensors`. +The `Pack` and `UnPack` in `TensorArray` are used to pack tensors in the array to an `LoDTensor` or split the `LoDTensor` to an array of tensors. +It needs some extensions to support the packing or unpacking an array of `LoDTensors`. From d7bf372d2682b4951308da47fcc444265ac80510 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Wed, 15 Nov 2017 20:27:30 -0600 Subject: [PATCH 24/25] support adagrad sparse update (#5272) * adam sparse support * fix gpu build error * fix ci * fix ci * fix adagrad sparse update bug * fix gpu build error --- paddle/operators/CMakeLists.txt | 9 +- paddle/operators/adagrad_op.cc | 90 +++++++++++- paddle/operators/adagrad_op.cu | 135 +++++++++++++++++- paddle/operators/adagrad_op.h | 66 ++++++--- paddle/operators/sgd_op.cu | 15 +- paddle/operators/sum_op.cc | 1 - .../paddle/v2/fluid/tests/test_adagrad_op.py | 108 ++++++++++++++ 7 files changed, 386 insertions(+), 38 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 709f7de2e4..d7145798dd 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -174,13 +174,18 @@ set(DEPS_OPS array_to_lod_tensor_op lstm_op tensor_array_read_write_op - gru_op) + gru_op + adagrad_op + sgd_op) + op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) +op_library(sum_op DEPS selected_rows_functor) +op_library(sgd_op DEPS selected_rows_functor) +op_library(adagrad_op DEPS selected_rows_functor) op_library(conv_op DEPS vol2col) -op_library(sum_op DEPS net_op selected_rows_functor) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc index 8d1a2b7938..d6686e3ef3 100644 --- a/paddle/operators/adagrad_op.cc +++ b/paddle/operators/adagrad_op.cc @@ -14,6 +14,11 @@ limitations under the License. */ #include "paddle/operators/adagrad_op.h" +#include + +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/selected_rows_functor.h" + namespace paddle { namespace operators { @@ -21,7 +26,7 @@ class AdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of AdagradOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), @@ -54,8 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel { class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { public: - AdagradOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + AdagradOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Param", "(Tensor) Input parameter"); AddInput("Grad", "(Tensor) Input gradient"); @@ -87,10 +92,85 @@ for numerical stability to avoid the division by zero error. )DOC"); } }; + +namespace { +size_t FindPos(const std::vector& rows, int64_t value) { + return std::find(rows.begin(), rows.end(), value) - rows.begin(); +} +} // namespace + +template +struct SparseAdagradFunctor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& grad, + const framework::Tensor& learning_rate, T epsilon, + framework::Tensor* moment, framework::Tensor* param) { + // 1. g_m.rows = set(g.rows) + auto grad_rows = grad.rows(); + std::set row_set(grad_rows.begin(), grad_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto grad_width = grad.value().dims()[1]; + std::unique_ptr grad_merge{ + new framework::SelectedRows()}; + grad_merge->set_rows(merge_rows); + grad_merge->set_height(grad.height()); + grad_merge->mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), grad_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, grad_merge->mutable_value(), 0.0); + + auto* grad_merge_data = grad_merge->mutable_value()->data(); + auto* grad_data = grad.value().data(); + + for (size_t i = 0; i < grad_rows.size(); i++) { + size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]); + for (int64_t j = 0; j < grad_width; j++) { + grad_merge_data[grad_merge_i * grad_width + j] += + grad_data[i * grad_width + j]; + } + } + + // 2. m += g_m * g_m + std::unique_ptr grad_square{ + new framework::SelectedRows()}; + grad_square->set_rows(grad_merge->rows()); + grad_square->set_height(grad_merge->height()); + grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), + context.GetPlace()); + auto gs = + framework::EigenVector::Flatten(*(grad_square->mutable_value())); + auto gm = framework::EigenVector::Flatten(grad_merge->value()); + gs.device(*context.GetEigenDevice()) = gm * gm; + + math::SelectedRowsAddToTensor functor; + functor(context, *grad_square, moment); + + // 3. update parameter + auto* lr = learning_rate.data(); + auto* param_data = param->data(); + auto* moment_data = moment->data(); + + for (size_t i = 0; i < merge_rows.size(); i++) { + for (int64_t j = 0; j < grad_width; j++) { + param_data[merge_rows[i] * grad_width + j] -= + lr[0] * grad_merge_data[i * grad_width + j] / + (std::sqrt(moment_data[merge_rows[i] * grad_width + j]) + epsilon); + } + } + } +}; + +template struct SparseAdagradFunctor; +template struct SparseAdagradFunctor; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker); -REGISTER_OP_CPU_KERNEL(adagrad, - ops::AdagradOpKernel); +REGISTER_OP_CPU_KERNEL( + adagrad, ops::AdagradOpKernel, + ops::AdagradOpKernel); diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu index a5b7951121..5b869e6bc5 100644 --- a/paddle/operators/adagrad_op.cu +++ b/paddle/operators/adagrad_op.cu @@ -14,7 +14,138 @@ #define EIGEN_USE_GPU #include "paddle/operators/adagrad_op.h" +#include "paddle/operators/math/selected_rows_functor.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +namespace { + +template +__global__ void MergeGradKernel(const T* grad, const int64_t* grad_rows, + T* grad_merge, const int64_t* grad_merge_rows, + size_t grad_merge_rows_size, + int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + __shared__ size_t grad_merge_idx; + + if (tid == 0) { + for (size_t i = 0; i < grad_merge_rows_size; i++) { + if (grad_rows[ty] == grad_merge_rows[i]) { + grad_merge_idx = i; + } + } + } + + __syncthreads(); + + grad += ty * row_numel; + grad_merge += grad_merge_idx * row_numel; + for (int index = tid; index < row_numel; index += block_size) { + paddle::platform::CudaAtomicAdd(grad_merge + index, grad[index]); + } +} + +template +__global__ void SparseAdagradFunctorKernel(const T* grad, const int64_t* rows, + const T* learning_rate, T* param, + T* moment, int64_t row_numel, + T epsilon) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + grad += ty * row_numel; + param += rows[ty] * row_numel; + moment += rows[ty] * row_numel; + + for (int index = tid; index < row_numel; index += block_size) { + // Since index in rows of SelectedRows can be duplicate, we have to use + // Atomic Operation to avoid concurrent write error. + paddle::platform::CudaAtomicAdd(param + index, + -1.0 * learning_rate[0] * grad[index] / + (sqrt(moment[index]) + epsilon)); + } +} +} // namespace + +template +struct SparseAdagradFunctor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& grad, + const framework::Tensor& learning_rate, T epsilon, + framework::Tensor* moment, framework::Tensor* param) { + // 1. g_m.rows = set(g.rows) + auto grad_rows = grad.rows(); + std::set row_set(grad_rows.begin(), grad_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto grad_width = grad.value().dims()[1]; + std::unique_ptr grad_merge{ + new framework::SelectedRows()}; + grad_merge->set_rows(merge_rows); + grad_merge->set_height(grad.height()); + grad_merge->mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), grad_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, grad_merge->mutable_value(), 0.0); + + auto* grad_merge_data = grad_merge->mutable_value()->data(); + auto* grad_data = grad.value().data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid1(1, grad_rows.size()); + + MergeGradKernel< + T, 256><<(context) + .stream()>>>(grad_data, grad.rows().data(), + grad_merge_data, grad_merge->rows().data(), + grad_merge->rows().size(), grad_width); + + // 2. m += g_m * g_m + std::unique_ptr grad_square{ + new framework::SelectedRows()}; + grad_square->set_rows(grad_merge->rows()); + grad_square->set_height(grad_merge->height()); + grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), + context.GetPlace()); + auto gs = + framework::EigenVector::Flatten(*(grad_square->mutable_value())); + auto gm = framework::EigenVector::Flatten(grad_merge->value()); + gs.device(*context.GetEigenDevice()) = gm * gm; + + math::SelectedRowsAddToTensor functor; + functor(context, *grad_square, moment); + + // 3. update parameter + auto* lr = learning_rate.data(); + auto* param_data = param->data(); + auto* moment_data = moment->data(); + + dim3 grid2(1, merge_rows.size()); + SparseAdagradFunctorKernel< + T, 256><<(context) + .stream()>>>(grad_merge_data, grad_merge->rows().data(), + lr, param_data, + moment_data, grad_width, epsilon); + } +}; + +template struct SparseAdagradFunctor; +template struct SparseAdagradFunctor; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(adagrad, - ops::AdagradOpKernel); +REGISTER_OP_GPU_KERNEL( + adagrad, ops::AdagradOpKernel, + ops::AdagradOpKernel); diff --git a/paddle/operators/adagrad_op.h b/paddle/operators/adagrad_op.h index c5d8f751d3..4d4a6434c7 100644 --- a/paddle/operators/adagrad_op.h +++ b/paddle/operators/adagrad_op.h @@ -19,35 +19,59 @@ limitations under the License. */ namespace paddle { namespace operators { +template +struct SparseAdagradFunctor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& grad, + const framework::Tensor& learning_rate, T epsilon, + framework::Tensor* moment, framework::Tensor* param); +}; + template class AdagradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out_tensor = ctx.Output("ParamOut"); - auto moment_out_tensor = ctx.Output("MomentOut"); + auto* param_out_tensor = ctx.Output("ParamOut"); + auto* moment_out_tensor = ctx.Output("MomentOut"); param_out_tensor->mutable_data(ctx.GetPlace()); moment_out_tensor->mutable_data(ctx.GetPlace()); - float epsilon = ctx.Attr("epsilon"); - - auto param = framework::EigenVector::Flatten( - *ctx.Input("Param")); - auto grad = framework::EigenVector::Flatten( - *ctx.Input("Grad")); - auto moment = framework::EigenVector::Flatten( - *ctx.Input("Moment")); - auto lr = framework::EigenVector::Flatten( - *ctx.Input("LearningRate")); - - auto param_out = framework::EigenVector::Flatten(*param_out_tensor); - auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); - auto place = ctx.GetEigenDevice(); - - moment_out.device(place) = moment + grad * grad; - Eigen::DSizes m_dsize(moment_out_tensor->numel()); - param_out.device(place) = - param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); + T epsilon = static_cast(ctx.Attr("epsilon")); + + auto* grad_var = ctx.InputVar("Grad"); + if (grad_var->IsType()) { + auto param = framework::EigenVector::Flatten( + *ctx.Input("Param")); + auto grad = framework::EigenVector::Flatten( + *ctx.Input("Grad")); + auto moment = framework::EigenVector::Flatten( + *ctx.Input("Moment")); + auto lr = framework::EigenVector::Flatten( + *ctx.Input("LearningRate")); + + auto param_out = framework::EigenVector::Flatten(*param_out_tensor); + auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); + auto place = ctx.GetEigenDevice(); + + moment_out.device(place) = moment + grad * grad; + Eigen::DSizes m_dsize(moment_out_tensor->numel()); + param_out.device(place) = + param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); + } else if (grad_var->IsType()) { + auto* param_tensor = ctx.Input("Param"); + PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor); + + auto* moment_tensor = ctx.Input("Moment"); + PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor); + + SparseAdagradFunctor functor; + functor(ctx.device_context(), *ctx.Input("Grad"), + *ctx.Input("LearningRate"), epsilon, + moment_out_tensor, param_out_tensor); + } else { + PADDLE_THROW("Unsupported Variable Type of Grad"); + } } }; diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index 2f41c7fc12..7b6c5ec306 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -20,11 +20,11 @@ namespace paddle { namespace operators { namespace { -template +template __global__ void SparseSGDFunctorKernel(const T* selected_rows, const int64_t* rows, const T* learning_rate, T* tensor_out, - int64_t row_numel, int block_size) { + int64_t row_numel) { const int ty = blockIdx.y; int tid = threadIdx.x; @@ -59,14 +59,15 @@ struct SparseSGDFunctor { auto* in_data = in_value.data(); auto* out_data = output->data(); - int block_size = 256; + const int block_size = 256; dim3 threads(block_size, 1); dim3 grid(1, in_rows.size()); SparseSGDFunctorKernel< - T><<(context) - .stream()>>>(in_data, in_rows.data(), learning_rate.data(), - out_data, in_row_numel, block_size); + T, 256><<(context) + .stream()>>>(in_data, in_rows.data(), + learning_rate.data(), out_data, + in_row_numel); } }; diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index 57b99bdb3a..9837f325e3 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -12,7 +12,6 @@ limitations under the License. */ #include "paddle/operators/sum_op.h" #include #include "paddle/framework/var_type_inference.h" -#include "paddle/operators/net_op.h" namespace paddle { namespace operators { diff --git a/python/paddle/v2/fluid/tests/test_adagrad_op.py b/python/paddle/v2/fluid/tests/test_adagrad_op.py index 66bad349e5..903e84c328 100644 --- a/python/paddle/v2/fluid/tests/test_adagrad_op.py +++ b/python/paddle/v2/fluid/tests/test_adagrad_op.py @@ -1,6 +1,9 @@ import unittest import numpy as np +import paddle.v2.fluid.core as core +from paddle.v2.fluid.op import Operator from op_test import OpTest +import math class TestAdagradOp1(OpTest): @@ -65,5 +68,110 @@ class TestAdagradOp2(OpTest): self.check_output() +class TestSparseAdagradOp(unittest.TestCase): + def check_with_place(self, place): + scope = core.Scope() + + # create and initialize Grad Variable + height = 10 + rows = [0, 4, 7, 4] + row_numel = 12 + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + # create and initialize Param Variable + param = scope.var('Param').get_tensor() + param_array = np.full((height, row_numel), 5.0).astype("float32") + param.set(param_array, place) + + # create and initialize LeraningRate Variable + lr = scope.var('LearningRate').get_tensor() + lr_array = np.full((1), 2.0).astype("float32") + lr.set(lr_array, place) + + # create and initialize moment Variable + moment = scope.var('Moment').get_tensor() + moment_np_array = np.full((height, row_numel), 2.0).astype("float32") + moment.set(moment_np_array, place) + + # create and run sgd operator + adagrad_op = Operator( + "adagrad", + Param='Param', + Grad='Grad', + ParamOut='Param', + Moment='Moment', + MomentOut='Moment', + LearningRate='LearningRate', + epsilon=2.0) + + ctx = core.DeviceContext.create(place) + adagrad_op.run(scope, ctx) + + # get and compare moment result + moment_result_array = np.array(moment) + + self.assertAlmostEqual(6.0, moment_result_array[rows[0], 0]) + self.assertAlmostEqual(3.0, moment_result_array[rows[0], 2]) + self.assertAlmostEqual(2.0, moment_result_array[1, 0]) + # 2.0 + (1.0 + 1.0)^2 + self.assertAlmostEqual(6.0, moment_result_array[rows[1], 10]) + self.assertAlmostEqual(6.0, moment_result_array[rows[3], 4]) + + self.assertAlmostEqual(2.0, moment_result_array[5, 8]) + self.assertAlmostEqual(3.0, moment_result_array[rows[2], 1]) + self.assertAlmostEqual(18.0, moment_result_array[rows[2], 8]) + + # get and compare param result + result_array = np.array(param) + + def get_out(param, lr, grad, m, epsilon): + return param - lr * grad / (math.sqrt(m) + epsilon) + + self.assertAlmostEqual( + get_out(5.0, 2.0, 2.0, 6.0, 2.0), + result_array[rows[0], 0], + places=5) + self.assertAlmostEqual( + get_out(5.0, 2.0, 1.0, 3.0, 2.0), + result_array[rows[0], 2], + places=5) + self.assertAlmostEqual( + get_out(5.0, 2.0, 0.0, 2.0, 2.0), result_array[1, 0], places=5) + + # grad_merge = 1.0 + 1.0 + # m = 6.0 + self.assertAlmostEqual( + get_out(5.0, 2.0, 2.0, 6.0, 2.0), + result_array[rows[1], 10], + places=5) + + self.assertAlmostEqual( + get_out(5.0, 2.0, 0.0, 2.0, 2.0), result_array[5, 8], places=5) + self.assertAlmostEqual( + get_out(5.0, 2.0, 1.0, 3.0, 2.0), + result_array[rows[2], 1], + places=5) + self.assertAlmostEqual( + get_out(5.0, 2.0, 4.0, 18.0, 2.0), + result_array[rows[2], 8], + places=5) + + def test_sparse_adagrad(self): + places = [core.CPUPlace()] + if core.is_compile_gpu(): + places.append(core.GPUPlace(0)) + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main() From 0ed5a5513f1be807e65aafcb0c7d61d1fe3bbb08 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 15 Nov 2017 19:59:07 -0800 Subject: [PATCH 25/25] "update doc" (#5682) --- paddle/memory/README.md | 141 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 2 deletions(-) diff --git a/paddle/memory/README.md b/paddle/memory/README.md index 7f95e80f98..6cb003c50b 100644 --- a/paddle/memory/README.md +++ b/paddle/memory/README.md @@ -1,4 +1,141 @@ # Region-based Heterogeneous Memory Management +## Design -Please check out the [design documentation](http://gangliao.me) to find out more details about -buddy memory allocator for both CPU and GPU. +### Usage + +To allocate 4KB CPU memory: + +```cpp +p = memory::Alloc(platform::CPUPlace(), 4*1024); +``` + +To allocate 4KB memory on the 3rd GPU: + +```cpp +p = memory::Alloc(platform::GPUPlace(2), 4*1024); +``` + +To free memory and check the so-far used amount of memory on a place: + +```cpp +auto pl = platform::GPUPlace(0); +p = memory::Alloc(pl, 4*1024); +cout << memory::Used(pl); +memory::Free(pl, p); +``` + +### API + +In `paddle/memory/memory.h` we have: + +```cpp +namespace memory { +template void* Alloc(Place, size_t); +template void Free(Place, void*); +template size_t Used(Place); +} // namespace memory +``` + +These function templates have specializations on either `platform::CPUPlace` or `platform::GPUPlace`: + +```cpp +template<> +void* Alloc(CPUPlace p, size_t size) { + return GetCPUBuddyAllocator()->Alloc(size); +} +``` + +and + +```cpp +template<> +void Alloc(GPUPlace p, size_t size) { + return GetGPUBuddyAllocator(p.id)->Alloc(size); +} +``` + +Similar specializations exist for `Free` and `Used`. + +### Implementation + +`GetCPUBuddyAllocator` and `GetGPUBuddyAllocator` are singletions. + +```cpp +BuddyAllocator* GetCPUBuddyAllocator() { + static BuddyAllocator* a = NULL; + if (a == NULL) { + a = new BuddyAllocator(new CPUAllocator /*backup allocator*/, ...); + } + return a; +} + +BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { + static BuddyAllocator* as = NULL; + if (as == NULL) { + as = new BuddyAllocator*[platform::NumGPUs()]; + for (int gpu = 0; gpu < platform::NumGPUs(); gpu++) { + as[gpu] = new BuddyAllocator(new GPUAllocator(gpu) /* backup allocator */, ...); + } + } + return as[gpu_id); +``` + +#### `BuddyAllocator` + +`BuddyAllocator` implements the buddy allocation algorithm. Its constructor takes parameters only related with the algorithm: + +```cpp +BuddyAllocator::BuddyAllocator(initial_pool_size, max_pool_size) { + ... +} +``` + +Please be aware that **`BuddyAllocator` always allocate aligned memory**, aligned on 32-bytes, which can hold a `BuddyAllocator::Block` object: + +```cpp +class BuddyAllocator { + private: + struct Block { + size_t size; + Block* left, right; + size_t index; // allocator id + }; + ... +}; +``` + +Because BuddyAllocator has the meta-data of each block, it can trace the used memory -- record the amount returned by `Alloc` freed in `Free`. Instead, `CPUAllocator` and `GPUAllocator` doesn't know the size of freed memory block and cannot do the trace. + +#### System Allocators + +The `GPUAllocator` and `CPUAllocator` are calls *system allocators*. They work as the fallback allocators of `BuddyAllocator`. + +## Justification + +I got inspiration from Majel and Caffe2, though above design look different from both. + +### Caffe2 + +In Caffe2, `Tensor::mutable_data()` allocates the memroy. In particular, [`Tensor::mutable_data`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/tensor.h#L523) calls [`Tensor::raw_mutable_data`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/tensor.h#L459), which in turn calls [`Context::New`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/tensor.h#L479). + +There are two implementations of `Context`: + +1. [`CPUContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L105), whose [`New` method](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L131) calls [`g_cpu_allocator.get()->New(size_t)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.cc#L15) to allocate the memory. + +1. [`CUDAContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L99), which has a data member [`int gpu_id_`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L202). This looks very similar to class `majel::GPUPlace`, who also has an `int id_` data member. `CUDAContext::New(size_t)` calls [`g_cub_allocator->DeviceAllocate(&ptr, nbytes)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.cu#L355) to allocate the memory. + +### Majel + +In Majel, there are basically two allocator types: + +1. `cpu::SystemAllocator`, which has similar functionality to `caffe2::CPUContext::New/Delete`. +1. `gpu::SystemAllocator`, which has similar functionality to `caffe2::CUDAContext::New/Delete`. + +However, memory allocation is not via these two allocators. Instead, these two allocators are defined in hidden namespaces. + +In Majel there are hidden global variables like: + +1. `cpu::SystemAllocator g_cpu_allocator`, and +1. `vector g_gpu_allocators(NUM_GPUS)`. + +Programs allocate memory via a BuddyAllocator, which can take the `g_cpu_allocator` or a `g_gpu_allocators[gpu_id]` as its *fallback allocator*, so that if BuddyAllocator cannot find a block in its memory pool, it extends its memory pool by calling the fallback allocator's `New(size_t)`.