From ee0113af31a0ac678ae62190cf62fbc7c3c098d6 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Fri, 1 Dec 2017 14:04:37 +0800 Subject: [PATCH 001/110] implement of prior box operator for ssd --- paddle/operators/prior_box_op.cc | 167 +++++++++++++++ paddle/operators/prior_box_op.cu | 20 ++ paddle/operators/prior_box_op.h | 199 ++++++++++++++++++ .../v2/fluid/tests/test_prior_box_op.py | 179 ++++++++++++++++ 4 files changed, 565 insertions(+) create mode 100644 paddle/operators/prior_box_op.cc create mode 100755 paddle/operators/prior_box_op.cu create mode 100644 paddle/operators/prior_box_op.h create mode 100644 python/paddle/v2/fluid/tests/test_prior_box_op.py diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc new file mode 100644 index 0000000000..fe1ccceb06 --- /dev/null +++ b/paddle/operators/prior_box_op.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/prior_box_op.h" + +namespace paddle { +namespace operators { + +class PriorBoxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(X) of SequenceSliceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Image"), + "Input(Offset) of SequenceSliceOp should not be null."); + + auto image_dims = ctx->GetInputDim("Image"); + auto input_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE(image_dims.size() == 4, + "The format of input tensor is NCHW."); + + auto min_sizes = ctx->Attrs().Get>("min_sizes"); + auto max_sizes = ctx->Attrs().Get>("max_sizes"); + auto variances = ctx->Attrs().Get>("variances"); + auto input_aspect_ratio = + ctx->Attrs().Get>("aspect_ratios"); + bool flip = ctx->Attrs().Get("flip"); + + PADDLE_ENFORCE_GT(min_sizes.size(), 0, "must provide min_size."); + for (size_t i = 0; i < min_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i); + } + + std::vector aspect_ratios; + expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios); + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (max_sizes.size() > 0) { + PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(), + "The length of min_size and max_size must be equal."); + for (size_t i = 0; i < min_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i], + "max_size[%d] must be greater than min_size[%d].", i, + i); + num_priors += 1; + } + } + + if (variances.size() > 1) { + PADDLE_ENFORCE_EQ(variances.size(), 4, + "Must and only provide 4 variance."); + for (size_t i = 0; i < variances.size(); ++i) { + PADDLE_ENFORCE_GT(variances[i], 0.0, + "variance[%d] must be greater than 0.", i); + } + } else if (variances.size() == 1) { + PADDLE_ENFORCE_GT(variances[0], 0.0, + "variance[0] must be greater than 0."); + } + + const int img_h = ctx->Attrs().Get("img_h"); + PADDLE_ENFORCE_GT(img_h, 0, "img_h should be larger than 0."); + const int img_w = ctx->Attrs().Get("img_w"); + PADDLE_ENFORCE_GT(img_w, 0, "img_w should be larger than 0."); + + const float step_h = ctx->Attrs().Get("step_h"); + PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0."); + const float step_w = ctx->Attrs().Get("step_w"); + PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0."); + + const int layer_height = input_dims[3]; + const int layer_width = input_dims[2]; + + std::vector dim_vec(3); + // Since all images in a batch has same height and width, we only need to + // generate one set of priors which can be shared across all images. + dim_vec[0] = 1; + // 2 channels. First channel stores the mean of each prior coordinate. + // Second channel stores the variance of each prior coordinate. + dim_vec[1] = 2; + dim_vec[2] = layer_width * layer_height * num_priors * 4; + PADDLE_ENFORCE_GT(dim_vec[2], 0, + "output_dim[2] must larger than 0." + "check your data dims"); + auto output_dim = framework::make_ddim(dim_vec); + ctx->SetOutputDim("Out", output_dim); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Image")->type()), + ctx.device_context()); + } +}; + +class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PriorBoxOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(Tensor), " + "the input feature data of PriorBoxOp."); + AddInput("Image", + "(Tensor), " + "the input image data of PriorBoxOp."); + AddOutput("Out", "(Tensor), the output prior boxes of PriorBoxOp."); + AddAttr>("min_sizes", "(vector) ", + "List of min sizes of generated prior boxes."); + AddAttr>("max_sizes", "(vector) ", + "List of max sizes of generated prior boxes."); + AddAttr>( + "aspect_ratios", "(vector) ", + "List of aspect ratios of generated prior boxes.") + .SetDefault({}); + AddAttr>( + "variances", "(vector) ", + "List of variances to be encoded in prior boxes.") + .SetDefault({0.1}); + AddAttr("flip", "(bool) ", "Whether to flip aspect ratios.") + .SetDefault(true); + AddAttr("clip", "(bool) ", "Whether to clip out-of-boundary boxes.") + .SetDefault(true); + AddAttr("img_w", "").SetDefault(0); + AddAttr("img_h", "").SetDefault(0); + AddAttr("step_w", + "Prior boxes step across width, 0 for auto calculation.") + .SetDefault(0.0); + AddAttr("step_h", + "Prior boxes step across height, 0 for auto calculation.") + .SetDefault(0.0); + AddAttr("offset", + "(float) " + "Prior boxes center offset.") + .SetDefault(0.5); + AddComment(R"DOC( +Prior box operator +Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. +Please get more information from the following papers: +https://arxiv.org/abs/1512.02325. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker); +REGISTER_OP_CPU_KERNEL( + prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/operators/prior_box_op.cu b/paddle/operators/prior_box_op.cu new file mode 100755 index 0000000000..d1928462a2 --- /dev/null +++ b/paddle/operators/prior_box_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/prior_box_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h new file mode 100644 index 0000000000..6dabba5265 --- /dev/null +++ b/paddle/operators/prior_box_op.h @@ -0,0 +1,199 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +// #include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { + +inline void expand_aspect_ratios(const std::vector input_aspect_ratior, + bool flip, + std::vector& output_aspect_ratior) { + constexpr float eps = 1e-6; + output_aspect_ratior.clear(); + output_aspect_ratior.push_back(1.); + for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { + float ar = input_aspect_ratior[i]; + bool already_exist = false; + for (size_t j = 0; j < output_aspect_ratior.size(); ++j) { + if (fabs(ar - output_aspect_ratior[j]) < eps) { + already_exist = true; + break; + } + } + if (!already_exist) { + output_aspect_ratior.push_back(ar); + if (flip) { + output_aspect_ratior.push_back(1. / ar); + } + } + } +} + +template +class PriorBoxOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* out = ctx.Output("Out"); + + auto min_sizes = ctx.Attr>("min_sizes"); + auto max_sizes = ctx.Attr>("max_sizes"); + auto input_aspect_ratio = ctx.Attr>("aspect_ratios"); + auto variances = ctx.Attr>("variances"); + auto flip = ctx.Attr("flip"); + auto clip = ctx.Attr("clip"); + + std::vector aspect_ratios; + expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios); + + auto img_w = ctx.Attr("img_w"); + auto img_h = ctx.Attr("img_h"); + auto step_w = ctx.Attr("step_w"); + auto step_h = ctx.Attr("step_h"); + auto offset = ctx.Attr("offset"); + + int img_width, img_height; + if (img_h == 0 || img_w == 0) { + img_width = image->dims()[2]; + img_height = image->dims()[3]; + } else { + img_width = img_w; + img_height = img_h; + } + + const int layer_width = input->dims()[2]; + const int layer_height = input->dims()[3]; + + float step_width, step_height; + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / layer_width; + step_height = static_cast(img_height) / layer_height; + } else { + step_width = step_w; + step_height = step_h; + } + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (max_sizes.size() > 0) { + num_priors += max_sizes.size(); + } + + int dim = layer_height * layer_width * num_priors * 4; + + T* output_data = nullptr; + framework::Tensor output_cpu; + out->mutable_data(ctx.GetPlace()); + if (platform::is_gpu_place(ctx.GetPlace())) { + output_data = + output_cpu.mutable_data(out->dims(), platform::CPUPlace()); + } else { + output_data = out->mutable_data(ctx.GetPlace()); + } + + int idx = 0; + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + float center_x = (w + offset) * step_width; + float center_y = (h + offset) * step_height; + float box_width, box_height; + for (size_t s = 0; s < min_sizes.size(); ++s) { + int min_size = min_sizes[s]; + // first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size; + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + + if (max_sizes.size() > 0) { + int max_size = max_sizes[s]; + // second prior: aspect_ratio = 1, + // size = sqrt(min_size * max_size) + box_width = box_height = sqrt(min_size * max_size); + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + } + + // rest of priors + for (size_t r = 0; r < aspect_ratios.size(); ++r) { + float ar = aspect_ratios[r]; + if (fabs(ar - 1.) < 1e-6) { + continue; + } + box_width = min_size * sqrt(ar); + box_height = min_size / sqrt(ar); + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + } + } + } + } + + // clip the prior's coordidate such that it is within [0, 1] + if (clip) { + for (int d = 0; d < dim; ++d) { + output_data[d] = std::min(std::max(output_data[d], 0.), 1.); + } + } + + // set the variance. + auto output_stride = framework::stride(out->dims()); + output_data += output_stride[1]; + if (variances.size() == 1) { + for (int i = 0; i < dim; ++i) { + output_data[i] = variances[0]; + } + } else { + int count = 0; + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + for (int i = 0; i < num_priors; ++i) { + for (int j = 0; j < 4; ++j) { + output_data[count] = variances[j]; + ++count; + } + } + } + } + } + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::CopyFrom(output_cpu, platform::CPUPlace(), + ctx.device_context(), out); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py new file mode 100644 index 0000000000..2f82188952 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -0,0 +1,179 @@ +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestPriorBoxOp(OpTest): + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'min_sizes': self.min_sizes, + 'max_sizes': self.max_sizes, + 'aspect_ratios': self.aspect_ratios, + 'variances': self.variances, + 'flip': self.flip, + 'clip': self.clip, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'img_w': self.image_w, + 'img_h': self.image_h, + 'offset': self.offset + } + + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + return + + def setUp(self): + self.op_type = "prior_box" + self.set_data() + + def init_test_params(self): + self.layer_w = 4 + self.layer_h = 4 + + self.image_w = 20 + self.image_h = 20 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.min_sizes = [2, 4] + self.min_sizes = np.array(self.min_sizes).astype('int64') + self.max_sizes = [5, 10] + self.max_sizes = np.array(self.max_sizes).astype('int64') + self.aspect_ratios = [2.0, 3.0] + self.flip = True + self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] + self.aspect_ratios = np.array( + self.aspect_ratios, dtype=np.float).flatten() + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float).flatten() + + self.clip = True + + self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) + if len(self.max_sizes) > 1: + self.num_priors += len(self.max_sizes) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, + self.image_h)).astype('float32') + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, + self.layer_h)).astype('float32') + + def init_test_output(self): + dim = self.layer_w * self.layer_h * self.num_priors * 4 + out_dim = (1, 2, dim) + output = np.zeros(out_dim).astype('float32') + + idx = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + center_x = (w + self.offset) * self.step_w + center_y = (h + self.offset) * self.step_h + for s in range(len(self.min_sizes)): + min_size = self.min_sizes[s] + # first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + # size = sqrt(min_size * max_size) + box_width = box_height = math.sqrt(min_size * max_size) + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + if math.fabs(ar - 1.) < 1e-6: + continue + box_width = min_size * math.sqrt(ar) + box_height = min_size / math.sqrt(ar) + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + # clip the prior's coordidate such that it is within[0, 1] + if self.clip: + for d in range(dim): + output[0, 0, d] = min(max(output[0, 0, d], 0), 1) + # set the variance. + if len(self.variances) == 1: + for i in range(dim): + output[0, 1, i] = self.variances[0] + else: + count = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + for i in range(self.num_priors): + for j in range(4): + output[0, 1, count] = self.variances[j] + count += 1 + self.output = output.astype('float32') + + +if __name__ == '__main__': + unittest.main() From 99a6c5d40edd29c4fedc8f50fe4db75177fb255d Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 9 Jan 2018 20:22:17 +0800 Subject: [PATCH 002/110] change output shape to [2, layer_height, layer_width, num_priors, 4] --- paddle/operators/prior_box_op.cc | 20 +++--- paddle/operators/prior_box_op.h | 71 ++++++++++--------- .../v2/fluid/tests/test_prior_box_op.py | 58 +++++++-------- 3 files changed, 72 insertions(+), 77 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index 04182cb1b7..2ffea67bdd 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -93,17 +93,12 @@ class PriorBoxOp : public framework::OperatorWithKernel { const int layer_height = input_dims[2]; const int layer_width = input_dims[3]; - std::vector dim_vec(3); - // Since all images in a batch has same height and width, we only need to - // generate one set of priors which can be shared across all images. - dim_vec[0] = 1; - // 2 channels. First channel stores the mean of each prior coordinate. - // Second channel stores the variance of each prior coordinate. - dim_vec[1] = 2; - dim_vec[2] = layer_width * layer_height * num_priors * 4; - PADDLE_ENFORCE_GT(dim_vec[2], 0, - "output_dim[2] must larger than 0." - "check your data dims"); + std::vector dim_vec(5); + dim_vec[0] = 2; + dim_vec[1] = layer_height; + dim_vec[2] = layer_width; + dim_vec[3] = num_priors; + dim_vec[4] = 4; auto output_dim = framework::make_ddim(dim_vec); ctx->SetOutputDim("Out", output_dim); } @@ -130,7 +125,8 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "the input image data of PriorBoxOp, The format is NCHW."); AddOutput("Out", "(Tensor, default Tensor), the output prior boxes of " - "PriorBoxOp."); + "PriorBoxOp. The format is [2, layer_height, layer_width, " + "num_priors, 4]"); AddAttr>("min_sizes", "(vector) ", "List of min sizes of generated prior boxes."); AddAttr>("max_sizes", "(vector) ", diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index 142e738a93..86399b53c3 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" -// #include "paddle/operators/strided_memcpy.h" namespace paddle { namespace operators { @@ -94,50 +93,52 @@ class PriorBoxOpKernel : public framework::OpKernel { num_priors += max_sizes.size(); } - int dim = layer_height * layer_width * num_priors * 4; - T* output_data = nullptr; framework::Tensor output_cpu; + framework::Tensor* output_tensor; out->mutable_data(ctx.GetPlace()); if (platform::is_gpu_place(ctx.GetPlace())) { - output_data = - output_cpu.mutable_data(out->dims(), platform::CPUPlace()); + output_cpu.mutable_data(out->dims(), platform::CPUPlace()); + output_tensor = &output_cpu; } else { - output_data = out->mutable_data(ctx.GetPlace()); + output_tensor = out; } - int idx = 0; + auto e_out = framework::EigenTensor::From(*output_tensor); for (int h = 0; h < layer_height; ++h) { for (int w = 0; w < layer_width; ++w) { float center_x = (w + offset) * step_width; float center_y = (h + offset) * step_height; float box_width, box_height; + int idx = 0; for (size_t s = 0; s < min_sizes.size(); ++s) { int min_size = min_sizes[s]; // first prior: aspect_ratio = 1, size = min_size box_width = box_height = min_size; // xmin - output_data[idx++] = (center_x - box_width / 2.) / img_width; + e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width; // ymin - output_data[idx++] = (center_y - box_height / 2.) / img_height; + e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height; // xmax - output_data[idx++] = (center_x + box_width / 2.) / img_width; + e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width; // ymax - output_data[idx++] = (center_y + box_height / 2.) / img_height; + e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + idx++; if (max_sizes.size() > 0) { int max_size = max_sizes[s]; // second prior: aspect_ratio = 1, // size = sqrt(min_size * max_size) box_width = box_height = sqrt(min_size * max_size); // xmin - output_data[idx++] = (center_x - box_width / 2.) / img_width; + e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width; // ymin - output_data[idx++] = (center_y - box_height / 2.) / img_height; + e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height; // xmax - output_data[idx++] = (center_x + box_width / 2.) / img_width; + e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width; // ymax - output_data[idx++] = (center_y + box_height / 2.) / img_height; + e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + idx++; } // rest of priors @@ -149,13 +150,14 @@ class PriorBoxOpKernel : public framework::OpKernel { box_width = min_size * sqrt(ar); box_height = min_size / sqrt(ar); // xmin - output_data[idx++] = (center_x - box_width / 2.) / img_width; + e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width; // ymin - output_data[idx++] = (center_y - box_height / 2.) / img_height; + e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height; // xmax - output_data[idx++] = (center_x + box_width / 2.) / img_width; + e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width; // ymax - output_data[idx++] = (center_y + box_height / 2.) / img_height; + e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + idx++; } } } @@ -163,26 +165,31 @@ class PriorBoxOpKernel : public framework::OpKernel { // clip the prior's coordidate such that it is within [0, 1] if (clip) { - for (int d = 0; d < dim; ++d) { - output_data[d] = std::min(std::max(output_data[d], 0.), 1.); + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + for (int i = 0; i < num_priors; ++i) { + for (int j = 0; j < 4; ++j) { + e_out(0, h, w, i, j) = + std::min(std::max(e_out(0, h, w, i, j), 0.), 1.); + } + } + } } - } - // set the variance. - auto output_stride = framework::stride(out->dims()); - output_data += output_stride[1]; - if (variances.size() == 1) { - for (int i = 0; i < dim; ++i) { - output_data[i] = variances[0]; + // set the variance. + auto output_stride = framework::stride(out->dims()); + output_data += output_stride[1]; + if (variances.size() == 1) { + variances.resize(4); + variances[1] = variances[0]; + variances[2] = variances[0]; + variances[3] = variances[0]; } - } else { - int count = 0; for (int h = 0; h < layer_height; ++h) { for (int w = 0; w < layer_width; ++w) { for (int i = 0; i < num_priors; ++i) { for (int j = 0; j < 4; ++j) { - output_data[count] = variances[j]; - ++count; + e_out(1, h, w, i, j) = variances[j]; } } } diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py index 2f82188952..e00bc4bae4 100644 --- a/python/paddle/v2/fluid/tests/test_prior_box_op.py +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -81,8 +81,7 @@ class TestPriorBoxOp(OpTest): self.layer_h)).astype('float32') def init_test_output(self): - dim = self.layer_w * self.layer_h * self.num_priors * 4 - out_dim = (1, 2, dim) + out_dim = (2, self.layer_h, self.layer_w, self.num_priors, 4) output = np.zeros(out_dim).astype('float32') idx = 0 @@ -90,24 +89,22 @@ class TestPriorBoxOp(OpTest): for w in range(self.layer_w): center_x = (w + self.offset) * self.step_w center_y = (h + self.offset) * self.step_h + idx = 0 for s in range(len(self.min_sizes)): min_size = self.min_sizes[s] # first prior: aspect_ratio = 1, size = min_size box_width = box_height = min_size # xmin - output[0, 0, idx] = ( + output[0, h, w, idx, 0] = ( center_x - box_width / 2.) / self.image_w - idx += 1 # ymin - output[0, 0, idx] = ( + output[0, h, w, idx, 1] = ( center_y - box_height / 2.) / self.image_h - idx += 1 # xmax - output[0, 0, idx] = ( + output[0, h, w, idx, 2] = ( center_x + box_width / 2.) / self.image_w - idx += 1 # ymax - output[0, 0, idx] = ( + output[0, h, w, idx, 3] = ( center_y + box_height / 2.) / self.image_h idx += 1 @@ -117,19 +114,16 @@ class TestPriorBoxOp(OpTest): # size = sqrt(min_size * max_size) box_width = box_height = math.sqrt(min_size * max_size) # xmin - output[0, 0, idx] = ( + output[0, h, w, idx, 0] = ( center_x - box_width / 2.) / self.image_w - idx += 1 # ymin - output[0, 0, idx] = ( + output[0, h, w, idx, 1] = ( center_y - box_height / 2.) / self.image_h - idx += 1 # xmax - output[0, 0, idx] = ( + output[0, h, w, idx, 2] = ( center_x + box_width / 2.) / self.image_w - idx += 1 # ymax - output[0, 0, idx] = ( + output[0, h, w, idx, 3] = ( center_y + box_height / 2.) / self.image_h idx += 1 @@ -141,37 +135,35 @@ class TestPriorBoxOp(OpTest): box_width = min_size * math.sqrt(ar) box_height = min_size / math.sqrt(ar) # xmin - output[0, 0, idx] = ( + output[0, h, w, idx, 0] = ( center_x - box_width / 2.) / self.image_w - idx += 1 # ymin - output[0, 0, idx] = ( + output[0, h, w, idx, 1] = ( center_y - box_height / 2.) / self.image_h - idx += 1 # xmax - output[0, 0, idx] = ( + output[0, h, w, idx, 2] = ( center_x + box_width / 2.) / self.image_w - idx += 1 # ymax - output[0, 0, idx] = ( + output[0, h, w, idx, 3] = ( center_y + box_height / 2.) / self.image_h idx += 1 # clip the prior's coordidate such that it is within[0, 1] if self.clip: - for d in range(dim): - output[0, 0, d] = min(max(output[0, 0, d], 0), 1) - # set the variance. - if len(self.variances) == 1: - for i in range(dim): - output[0, 1, i] = self.variances[0] - else: - count = 0 for h in range(self.layer_h): for w in range(self.layer_w): for i in range(self.num_priors): for j in range(4): - output[0, 1, count] = self.variances[j] - count += 1 + output[0, h, w, i, j] = min( + max(output[0, h, w, i, j], 0), 1) + # set the variance. + for h in range(self.layer_h): + for w in range(self.layer_w): + for i in range(self.num_priors): + for j in range(4): + if len(self.variances) == 1: + output[1, h, w, i, j] = self.variances[0] + else: + output[1, h, w, i, j] = self.variances[j] self.output = output.astype('float32') From 2ad5a6f0d155a3a7224819b70c86f91e78fde5f1 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 16 Jan 2018 18:37:13 +0800 Subject: [PATCH 003/110] add iou similarity operator --- paddle/operators/iou_similarity_op.cc | 74 ++++++++++++++++ paddle/operators/iou_similarity_op.h | 87 +++++++++++++++++++ .../v2/fluid/tests/test_iou_similarity_op.py | 36 ++++++++ 3 files changed, 197 insertions(+) create mode 100755 paddle/operators/iou_similarity_op.cc create mode 100644 paddle/operators/iou_similarity_op.h create mode 100755 python/paddle/v2/fluid/tests/test_iou_similarity_op.py diff --git a/paddle/operators/iou_similarity_op.cc b/paddle/operators/iou_similarity_op.cc new file mode 100755 index 0000000000..247549a8ff --- /dev/null +++ b/paddle/operators/iou_similarity_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/iou_similarity_op.h" + +namespace paddle { +namespace operators { + +class IOUSimilarityOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The shape of X is [N, 4]"); + PADDLE_ENFORCE_EQ(x_dims[1], 4UL, "The shape of X is [N, 4]"); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The shape of Y is [M, 4]"); + PADDLE_ENFORCE_EQ(y_dims[1], 4UL, "The shape of Y is [M, 4]"); + + ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], y_dims[0]})); + } +}; + +class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor, default Tensor) " + "BoxList X holding N boxes, each box is " + "represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4]."); + AddInput( + "Y", + "(Tensor, default Tensor) " + "BoxList Y holding M boxes, each box is " + "represented as [xmin, ymin, xmax, ymax], the shape of X is [N, 4]."); + + AddOutput( + "Out", + "(Tensor) The output of iou_similarity op, a tensor with shape [N, M] " + "representing pairwise iou scores."); + + AddComment(R"DOC( +IOU Similarity Operator. +Computes pairwise intersection-over-union between box collections. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(iou_similarity, ops::IOUSimilarityOp, + ops::IOUSimilarityOpMaker); + +REGISTER_OP_CPU_KERNEL( + iou_similarity, + ops::IOUSimilarityKernel, + ops::IOUSimilarityKernel); diff --git a/paddle/operators/iou_similarity_op.h b/paddle/operators/iou_similarity_op.h new file mode 100644 index 0000000000..b3b6219415 --- /dev/null +++ b/paddle/operators/iou_similarity_op.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/platform/for_range.h" + +template +inline T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, T ymin2, + T xmax2, T ymax2) { + T area1 = (ymax1 - ymin1) * (xmax1 - xmin1); + T area2 = (ymax2 - ymin2) * (xmax2 - xmin2); + T inter_xmax = std::min(xmax1, xmax2); + T inter_ymax = std::min(ymax1, ymax2); + T inter_xmin = std::max(xmin1, xmin2); + T inter_ymin = std::max(ymin1, ymin2); + T inter_height = std::max(inter_ymax - inter_ymin, static_cast(0)); + T inter_width = std::max(inter_xmax - inter_xmin, static_cast(0)); + T inter_area = inter_width * inter_height; + T union_area = area1 + area2 - inter_area; + T sim_score = inter_area / union_area; + return sim_score; +} + +template +struct IOUSimilarityFunctor { + IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols) + : x_(x), y_(y), z_(z), cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + T x_min1 = x_[row_id * 4]; + T y_min1 = x_[row_id * 4 + 1]; + T x_max1 = x_[row_id * 4 + 2]; + T y_max1 = x_[row_id * 4 + 3]; + for (int i = 0; i < cols_; ++i) { + T x_min2 = y_[i * 4]; + T y_min2 = y_[i * 4 + 1]; + T x_max2 = y_[i * 4 + 2]; + T y_max2 = y_[i * 4 + 3]; + + T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2, + x_max2, y_max2); + + z_[row_id * cols_ + i] = sim; + } + } + const T* x_; + const T* y_; + T* z_; + const size_t cols_; +}; + +namespace paddle { +namespace operators { + +template +class IOUSimilarityKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor* in_x = ctx.Input("X"); + const framework::Tensor* in_y = ctx.Input("Y"); + framework::Tensor* out = ctx.Output("Out"); + + int x_n = in_x->dims()[0]; + int y_n = in_y->dims()[0]; + IOUSimilarityFunctor functor(in_x->data(), in_y->data(), + out->mutable_data(ctx.GetPlace()), y_n); + + platform::ForRange for_range( + static_cast(ctx.device_context()), x_n); + for_range(functor); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_iou_similarity_op.py b/python/paddle/v2/fluid/tests/test_iou_similarity_op.py new file mode 100755 index 0000000000..a4eb500020 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_iou_similarity_op.py @@ -0,0 +1,36 @@ +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestIOUSimilarityOp(OpTest): + def set_data(self): + self.init_test_data() + self.inputs = {'X': self.boxes1, 'Y': self.boxes2} + + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + return + + def setUp(self): + self.op_type = "iou_similarity" + self.set_data() + + def init_test_data(self): + self.boxes1 = np.array( + [[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]).astype('float32') + self.boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], + [0.0, 0.0, 20.0, 20.0]]).astype('float32') + self.output = np.array( + [[2.0 / 16.0, 0, 6.0 / 400.0], + [1.0 / 16.0, 0.0, 5.0 / 400.0]]).astype('float32') + + +if __name__ == '__main__': + unittest.main() From 8266fcc3be869cfef42ab0ec597b2b4ce08dd37d Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 16 Jan 2018 21:31:08 +0800 Subject: [PATCH 004/110] Add pyton wrapper for row conv operator. --- doc/api/v2/fluid/layers.rst | 5 +++ python/paddle/v2/fluid/layers/nn.py | 65 ++++++++++++++++++++++++++--- 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index 62c154e65d..ad3c70a6f1 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -493,3 +493,8 @@ swish ------ .. autofunction:: paddle.v2.fluid.layers.swish :noindex: + +row_conv +-------- +.. autofunction:: paddle.v2.fluid.layers.row_conv + :noindex: diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 4e8fd407c9..7c694ed777 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -50,6 +50,7 @@ __all__ = [ 'sequence_last_step', 'dropout', 'split', + 'row_conv', ] @@ -1547,13 +1548,13 @@ def split(input, num_or_sections, dim=-1): Args: input (Variable): The input variable which is a Tensor or LoDTensor. - num_or_sections (int|list): If :attr:`num_or_sections` is an integer, - then the integer indicates the number of equal sized sub-tensors - that the tensor will be divided into. If :attr:`num_or_sections` - is a list of integers, the length of list indicates the number of - sub-tensors and the integers indicate the sizes of sub-tensors' + num_or_sections (int|list): If :attr:`num_or_sections` is an integer, + then the integer indicates the number of equal sized sub-tensors + that the tensor will be divided into. If :attr:`num_or_sections` + is a list of integers, the length of list indicates the number of + sub-tensors and the integers indicate the sizes of sub-tensors' :attr:`dim` dimension orderly. - dim (int): The dimension along which to split. If :math:`dim < 0`, the + dim (int): The dimension along which to split. If :math:`dim < 0`, the dimension to split along is :math:`rank(input) + dim`. Returns: @@ -1597,3 +1598,55 @@ def split(input, num_or_sections, dim=-1): 'axis': dim }) return outs + + +def row_conv(input, future_context_size, param_attr=None, act=None): + """Row Conv Operator. This layer will apply lookahead convolution to + **input**. The input variable should be a 2D LoDTensor with shape [T, D]. + Parameters with shape [future_context_size + 1, D] will be created. The math + equation of row convolution is as following: + + .. math:: + Out_{i} = \sum_{j = i} ^ {i + \\tau} X_{j} \odot W_{i - j} + + In the above equation: + + * :math:`Out_{i}`: The i-th row of output variable with shape [1, D]. + * :math:`\\tau`: Future context size. + * :math:`X_{j}`: The j-th row of input variable with shape [1, D]. + * :math:`W_{i-j}`: The (i-j)-th row of parameters with shape [1, D]. + + More details about row_conv please refer to the paper \ + (http://www.cs.cmu.edu/~dyogatam/papers/wang+etal.iclrworkshop2016.pdf) and + the design document \ + (https://github.com/PaddlePaddle/Paddle/issues/2228#issuecomment-303903645). + + Args: + input (Variable): Input variable, a 2D LoDTensor with shape [T, D]. + future_context_size (int): Future context size. + param_attr (ParamAttr): Attributes of parameters, including + name, initializer etc. + act (str): Non-linear activation to be applied to output variable. + + Returns: + Variable: The output tensor with same shape as input tensor. + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[16], + dtype='float32', lod_level=1) + out = fluid.layers.row_conv(input=x, future_context_size=2) + """ + helper = LayerHelper('row_conv', **locals()) + dtype = helper.input_dtype() + filter_shape = [future_context_size + 1, input.shape[1]] + filter_param = helper.create_parameter( + attr=helper.param_attr, shape=filter_shape, dtype=dtype) + out = helper.create_tmp_variable(dtype) + helper.append_op( + type='row_conv', + inputs={'X': [input], + 'Filter': [filter_param]}, + outputs={'Out': [out]}) + return out From 2a0a576130b3b04f3555479f1850fd91dbba4d9a Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 16 Jan 2018 21:40:34 +0800 Subject: [PATCH 005/110] Add non-linear activation. --- python/paddle/v2/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 7c694ed777..4546616d1a 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1649,4 +1649,4 @@ def row_conv(input, future_context_size, param_attr=None, act=None): inputs={'X': [input], 'Filter': [filter_param]}, outputs={'Out': [out]}) - return out + return helper.append_activation(out) From d2a70243f1179654fd7224a4114cff5d984d424e Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 16 Jan 2018 13:33:13 +0800 Subject: [PATCH 006/110] Refine profiler and expose to Python. --- cmake/external/pybind11.cmake | 2 +- paddle/framework/CMakeLists.txt | 3 +- paddle/framework/executor.cc | 6 ++ paddle/platform/profiler.cc | 37 +++++++--- paddle/platform/profiler.h | 22 ++++-- paddle/platform/profiler_test.cc | 10 ++- paddle/pybind/CMakeLists.txt | 2 +- paddle/pybind/protobuf.cc | 70 +++---------------- paddle/pybind/protobuf.h | 1 + paddle/pybind/pybind.cc | 27 ++++++- python/paddle/v2/fluid/profiler.py | 45 ++++++++++++ python/paddle/v2/fluid/tests/test_profiler.py | 37 +++++++++- 12 files changed, 171 insertions(+), 91 deletions(-) diff --git a/cmake/external/pybind11.cmake b/cmake/external/pybind11.cmake index 4e87dc49d8..ab23663695 100644 --- a/cmake/external/pybind11.cmake +++ b/cmake/external/pybind11.cmake @@ -26,7 +26,7 @@ ExternalProject_Add( extern_pybind ${EXTERNAL_PROJECT_LOG_ARGS} GIT_REPOSITORY "https://github.com/pybind/pybind11.git" - GIT_TAG "v2.1.1" + GIT_TAG "v2.2.1" PREFIX ${PYBIND_SOURCE_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 597ea959f2..9bf712250d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -68,7 +68,8 @@ cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) -cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope +framework_proto backward glog lod_rank_table profiler) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index c0418c9266..d7233882e7 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/op_registry.h" #include "paddle/platform/place.h" +#include "paddle/platform/profiler.h" DEFINE_bool(check_nan_inf, false, "Checking whether operator produce NAN/INF or not. It will be " @@ -116,6 +117,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, for (auto& op_desc : block.AllOps()) { auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); VLOG(3) << op->DebugStringEx(local_scope); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto dev_ctx = const_cast(pool.Get(place_)); + platform::RecordEvent record_event(op->Type(), dev_ctx); + op->Run(*local_scope, place_); if (FLAGS_check_nan_inf) { for (auto& vname : op->OutputVars(true)) { diff --git a/paddle/platform/profiler.cc b/paddle/platform/profiler.cc index 7e2e2d968e..8175b827c3 100644 --- a/paddle/platform/profiler.cc +++ b/paddle/platform/profiler.cc @@ -163,14 +163,17 @@ void EnableProfiler(ProfilerState state) { Mark("_start_profiler_", nullptr); } -std::vector> DisableProfiler() { - PADDLE_ENFORCE(g_state != ProfilerState::kDisabled, - "Can't disable profiling, since it's not starting."); - // Mark the profiling stop. - Mark("_stop_profiler_", nullptr); - g_state = ProfilerState::kDisabled; - std::vector> result; +void ResetProfiler() { std::lock_guard guard(g_all_event_lists_mutex); + for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end(); + ++it) { + (*it)->Clear(); + } +} + +std::vector> GetAllEvents() { + std::lock_guard guard(g_all_event_lists_mutex); + std::vector> result; for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end(); ++it) { result.emplace_back((*it)->Reduce()); @@ -178,6 +181,18 @@ std::vector> DisableProfiler() { return result; } +void DisableProfiler(EventSortingKey sorted_key) { + PADDLE_ENFORCE(g_state != ProfilerState::kDisabled, + "Can't disable profiling, since it's not starting."); + // Mark the profiling stop. + Mark("_stop_profiler_", nullptr); + g_state = ProfilerState::kDisabled; + + std::vector> all_events = GetAllEvents(); + ParseEvents(all_events, sorted_key); + ResetProfiler(); +} + void ParseEvents(std::vector>& events, EventSortingKey sorted_by) { if (g_profiler_place == "") return; @@ -291,12 +306,12 @@ void ParseEvents(std::vector>& events, } // Print report - PrintProfilingReport(events_table, sorted_domain, max_name_width + 4, 12); + PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12); } -void PrintProfilingReport(std::vector>& events_table, - std::string& sorted_domain, const size_t name_width, - const size_t data_width) { +void PrintProfiler(std::vector>& events_table, + std::string& sorted_domain, const size_t name_width, + const size_t data_width) { // Output header information std::cout << "\n------------------------->" << " Profiling Report " diff --git a/paddle/platform/profiler.h b/paddle/platform/profiler.h index 6df48ef880..85823af1d7 100644 --- a/paddle/platform/profiler.h +++ b/paddle/platform/profiler.h @@ -84,6 +84,8 @@ struct EventList { return result; } + void Clear() { event_blocks.clear(); } + std::forward_list> event_blocks; }; @@ -110,12 +112,9 @@ struct RecordEvent { std::string name_; }; -// Enable the profiling function. -void EnableProfiler(ProfilerState state); - // Return the event list of all threads. Asummed the returned value calls // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. -std::vector> DisableProfiler(); +std::vector> GetAllEvents(); // The information of each event given in the profiling report struct EventItem { @@ -130,13 +129,22 @@ struct EventItem { // Candidate keys to sort the profiling report enum EventSortingKey { kDefault, kCalls, kTotal, kMin, kMax, kAve }; +// Enable the profiling function. +void EnableProfiler(ProfilerState state); + +// Clear the g_all_event_lists, which is total event lists of all threads. +void ResetProfiler(); + +void DisableProfiler(EventSortingKey sorted_key); + // Parse the event list and output the profiling report void ParseEvents(std::vector>&, EventSortingKey sorted_by = EventSortingKey::kDefault); // Print results -void PrintProfilingReport(std::vector>& events_table, - std::string& sorted_domain, const size_t name_width, - const size_t data_width); +void PrintProfiler(std::vector>& events_table, + std::string& sorted_domain, const size_t name_width, + const size_t data_width); + } // namespace platform } // namespace paddle diff --git a/paddle/platform/profiler_test.cc b/paddle/platform/profiler_test.cc index 13dea713c7..81f10c9134 100644 --- a/paddle/platform/profiler_test.cc +++ b/paddle/platform/profiler_test.cc @@ -103,18 +103,14 @@ TEST(RecordEvent, RecordEvent) { // Bad Usage: PushEvent("event_without_pop", dev_ctx); PopEvent("event_without_push", dev_ctx); - std::vector> events = paddle::platform::DisableProfiler(); - // Will remove parsing-related code from test later - ParseEvents(events, EventSortingKey::kTotal); + std::vector> events = paddle::platform::GetAllEvents(); int cuda_startup_count = 0; int start_profiler_count = 0; - int stop_profiler_count = 0; for (size_t i = 0; i < events.size(); ++i) { for (size_t j = 0; j < events[i].size(); ++j) { if (events[i][j].name() == "_cuda_startup_") ++cuda_startup_count; if (events[i][j].name() == "_start_profiler_") ++start_profiler_count; - if (events[i][j].name() == "_stop_profiler_") ++stop_profiler_count; if (events[i][j].name() == "push") { EXPECT_EQ(events[i][j + 1].name(), "pop"); #ifdef PADDLE_WITH_CUDA @@ -127,5 +123,7 @@ TEST(RecordEvent, RecordEvent) { } EXPECT_EQ(cuda_startup_count % 5, 0); EXPECT_EQ(start_profiler_count, 1); - EXPECT_EQ(stop_profiler_count, 1); + + // Will remove parsing-related code from test later + DisableProfiler(EventSortingKey::kTotal); } diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 7b37430707..e78673e0ba 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED SRCS pybind.cc exception.cc protobuf.cc const_value.cc - DEPS pybind python backward proto_desc paddle_memory executor prune init + DEPS pybind python backward proto_desc paddle_memory executor prune init profiler ${GLOB_OP_LIB}) if(NOT APPLE AND NOT ANDROID) target_link_libraries(paddle_pybind rt) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 4f95948153..d80f6b71e9 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -21,74 +21,24 @@ limitations under the License. */ #include "paddle/framework/program_desc.h" #include "paddle/framework/var_desc.h" -// Cast boost::variant for PyBind. -// Copy from -// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199 +using boost::variant; + namespace pybind11 { namespace detail { -// Can be replaced by a generic lambda in C++14 -struct variant_caster_visitor : public boost::static_visitor { - return_value_policy policy; - handle parent; - - variant_caster_visitor(return_value_policy policy, handle parent) - : policy(policy), parent(parent) {} - - template - handle operator()(T const &src) const { - return make_caster::cast(src, policy, parent); - } -}; - -template -struct variant_caster; - -template