From ee0113af31a0ac678ae62190cf62fbc7c3c098d6 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Fri, 1 Dec 2017 14:04:37 +0800 Subject: [PATCH 1/5] implement of prior box operator for ssd --- paddle/operators/prior_box_op.cc | 167 +++++++++++++++ paddle/operators/prior_box_op.cu | 20 ++ paddle/operators/prior_box_op.h | 199 ++++++++++++++++++ .../v2/fluid/tests/test_prior_box_op.py | 179 ++++++++++++++++ 4 files changed, 565 insertions(+) create mode 100644 paddle/operators/prior_box_op.cc create mode 100755 paddle/operators/prior_box_op.cu create mode 100644 paddle/operators/prior_box_op.h create mode 100644 python/paddle/v2/fluid/tests/test_prior_box_op.py diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc new file mode 100644 index 0000000000..fe1ccceb06 --- /dev/null +++ b/paddle/operators/prior_box_op.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/prior_box_op.h" + +namespace paddle { +namespace operators { + +class PriorBoxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(X) of SequenceSliceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Image"), + "Input(Offset) of SequenceSliceOp should not be null."); + + auto image_dims = ctx->GetInputDim("Image"); + auto input_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE(image_dims.size() == 4, + "The format of input tensor is NCHW."); + + auto min_sizes = ctx->Attrs().Get>("min_sizes"); + auto max_sizes = ctx->Attrs().Get>("max_sizes"); + auto variances = ctx->Attrs().Get>("variances"); + auto input_aspect_ratio = + ctx->Attrs().Get>("aspect_ratios"); + bool flip = ctx->Attrs().Get("flip"); + + PADDLE_ENFORCE_GT(min_sizes.size(), 0, "must provide min_size."); + for (size_t i = 0; i < min_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i); + } + + std::vector aspect_ratios; + expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios); + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (max_sizes.size() > 0) { + PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(), + "The length of min_size and max_size must be equal."); + for (size_t i = 0; i < min_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i], + "max_size[%d] must be greater than min_size[%d].", i, + i); + num_priors += 1; + } + } + + if (variances.size() > 1) { + PADDLE_ENFORCE_EQ(variances.size(), 4, + "Must and only provide 4 variance."); + for (size_t i = 0; i < variances.size(); ++i) { + PADDLE_ENFORCE_GT(variances[i], 0.0, + "variance[%d] must be greater than 0.", i); + } + } else if (variances.size() == 1) { + PADDLE_ENFORCE_GT(variances[0], 0.0, + "variance[0] must be greater than 0."); + } + + const int img_h = ctx->Attrs().Get("img_h"); + PADDLE_ENFORCE_GT(img_h, 0, "img_h should be larger than 0."); + const int img_w = ctx->Attrs().Get("img_w"); + PADDLE_ENFORCE_GT(img_w, 0, "img_w should be larger than 0."); + + const float step_h = ctx->Attrs().Get("step_h"); + PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0."); + const float step_w = ctx->Attrs().Get("step_w"); + PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0."); + + const int layer_height = input_dims[3]; + const int layer_width = input_dims[2]; + + std::vector dim_vec(3); + // Since all images in a batch has same height and width, we only need to + // generate one set of priors which can be shared across all images. + dim_vec[0] = 1; + // 2 channels. First channel stores the mean of each prior coordinate. + // Second channel stores the variance of each prior coordinate. + dim_vec[1] = 2; + dim_vec[2] = layer_width * layer_height * num_priors * 4; + PADDLE_ENFORCE_GT(dim_vec[2], 0, + "output_dim[2] must larger than 0." + "check your data dims"); + auto output_dim = framework::make_ddim(dim_vec); + ctx->SetOutputDim("Out", output_dim); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Image")->type()), + ctx.device_context()); + } +}; + +class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PriorBoxOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(Tensor), " + "the input feature data of PriorBoxOp."); + AddInput("Image", + "(Tensor), " + "the input image data of PriorBoxOp."); + AddOutput("Out", "(Tensor), the output prior boxes of PriorBoxOp."); + AddAttr>("min_sizes", "(vector) ", + "List of min sizes of generated prior boxes."); + AddAttr>("max_sizes", "(vector) ", + "List of max sizes of generated prior boxes."); + AddAttr>( + "aspect_ratios", "(vector) ", + "List of aspect ratios of generated prior boxes.") + .SetDefault({}); + AddAttr>( + "variances", "(vector) ", + "List of variances to be encoded in prior boxes.") + .SetDefault({0.1}); + AddAttr("flip", "(bool) ", "Whether to flip aspect ratios.") + .SetDefault(true); + AddAttr("clip", "(bool) ", "Whether to clip out-of-boundary boxes.") + .SetDefault(true); + AddAttr("img_w", "").SetDefault(0); + AddAttr("img_h", "").SetDefault(0); + AddAttr("step_w", + "Prior boxes step across width, 0 for auto calculation.") + .SetDefault(0.0); + AddAttr("step_h", + "Prior boxes step across height, 0 for auto calculation.") + .SetDefault(0.0); + AddAttr("offset", + "(float) " + "Prior boxes center offset.") + .SetDefault(0.5); + AddComment(R"DOC( +Prior box operator +Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. +Please get more information from the following papers: +https://arxiv.org/abs/1512.02325. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker); +REGISTER_OP_CPU_KERNEL( + prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/operators/prior_box_op.cu b/paddle/operators/prior_box_op.cu new file mode 100755 index 0000000000..d1928462a2 --- /dev/null +++ b/paddle/operators/prior_box_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/prior_box_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h new file mode 100644 index 0000000000..6dabba5265 --- /dev/null +++ b/paddle/operators/prior_box_op.h @@ -0,0 +1,199 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +// #include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { + +inline void expand_aspect_ratios(const std::vector input_aspect_ratior, + bool flip, + std::vector& output_aspect_ratior) { + constexpr float eps = 1e-6; + output_aspect_ratior.clear(); + output_aspect_ratior.push_back(1.); + for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { + float ar = input_aspect_ratior[i]; + bool already_exist = false; + for (size_t j = 0; j < output_aspect_ratior.size(); ++j) { + if (fabs(ar - output_aspect_ratior[j]) < eps) { + already_exist = true; + break; + } + } + if (!already_exist) { + output_aspect_ratior.push_back(ar); + if (flip) { + output_aspect_ratior.push_back(1. / ar); + } + } + } +} + +template +class PriorBoxOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* out = ctx.Output("Out"); + + auto min_sizes = ctx.Attr>("min_sizes"); + auto max_sizes = ctx.Attr>("max_sizes"); + auto input_aspect_ratio = ctx.Attr>("aspect_ratios"); + auto variances = ctx.Attr>("variances"); + auto flip = ctx.Attr("flip"); + auto clip = ctx.Attr("clip"); + + std::vector aspect_ratios; + expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios); + + auto img_w = ctx.Attr("img_w"); + auto img_h = ctx.Attr("img_h"); + auto step_w = ctx.Attr("step_w"); + auto step_h = ctx.Attr("step_h"); + auto offset = ctx.Attr("offset"); + + int img_width, img_height; + if (img_h == 0 || img_w == 0) { + img_width = image->dims()[2]; + img_height = image->dims()[3]; + } else { + img_width = img_w; + img_height = img_h; + } + + const int layer_width = input->dims()[2]; + const int layer_height = input->dims()[3]; + + float step_width, step_height; + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / layer_width; + step_height = static_cast(img_height) / layer_height; + } else { + step_width = step_w; + step_height = step_h; + } + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (max_sizes.size() > 0) { + num_priors += max_sizes.size(); + } + + int dim = layer_height * layer_width * num_priors * 4; + + T* output_data = nullptr; + framework::Tensor output_cpu; + out->mutable_data(ctx.GetPlace()); + if (platform::is_gpu_place(ctx.GetPlace())) { + output_data = + output_cpu.mutable_data(out->dims(), platform::CPUPlace()); + } else { + output_data = out->mutable_data(ctx.GetPlace()); + } + + int idx = 0; + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + float center_x = (w + offset) * step_width; + float center_y = (h + offset) * step_height; + float box_width, box_height; + for (size_t s = 0; s < min_sizes.size(); ++s) { + int min_size = min_sizes[s]; + // first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size; + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + + if (max_sizes.size() > 0) { + int max_size = max_sizes[s]; + // second prior: aspect_ratio = 1, + // size = sqrt(min_size * max_size) + box_width = box_height = sqrt(min_size * max_size); + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + } + + // rest of priors + for (size_t r = 0; r < aspect_ratios.size(); ++r) { + float ar = aspect_ratios[r]; + if (fabs(ar - 1.) < 1e-6) { + continue; + } + box_width = min_size * sqrt(ar); + box_height = min_size / sqrt(ar); + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + } + } + } + } + + // clip the prior's coordidate such that it is within [0, 1] + if (clip) { + for (int d = 0; d < dim; ++d) { + output_data[d] = std::min(std::max(output_data[d], 0.), 1.); + } + } + + // set the variance. + auto output_stride = framework::stride(out->dims()); + output_data += output_stride[1]; + if (variances.size() == 1) { + for (int i = 0; i < dim; ++i) { + output_data[i] = variances[0]; + } + } else { + int count = 0; + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + for (int i = 0; i < num_priors; ++i) { + for (int j = 0; j < 4; ++j) { + output_data[count] = variances[j]; + ++count; + } + } + } + } + } + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::CopyFrom(output_cpu, platform::CPUPlace(), + ctx.device_context(), out); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py new file mode 100644 index 0000000000..2f82188952 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -0,0 +1,179 @@ +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestPriorBoxOp(OpTest): + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'min_sizes': self.min_sizes, + 'max_sizes': self.max_sizes, + 'aspect_ratios': self.aspect_ratios, + 'variances': self.variances, + 'flip': self.flip, + 'clip': self.clip, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'img_w': self.image_w, + 'img_h': self.image_h, + 'offset': self.offset + } + + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + return + + def setUp(self): + self.op_type = "prior_box" + self.set_data() + + def init_test_params(self): + self.layer_w = 4 + self.layer_h = 4 + + self.image_w = 20 + self.image_h = 20 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.min_sizes = [2, 4] + self.min_sizes = np.array(self.min_sizes).astype('int64') + self.max_sizes = [5, 10] + self.max_sizes = np.array(self.max_sizes).astype('int64') + self.aspect_ratios = [2.0, 3.0] + self.flip = True + self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] + self.aspect_ratios = np.array( + self.aspect_ratios, dtype=np.float).flatten() + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float).flatten() + + self.clip = True + + self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) + if len(self.max_sizes) > 1: + self.num_priors += len(self.max_sizes) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, + self.image_h)).astype('float32') + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, + self.layer_h)).astype('float32') + + def init_test_output(self): + dim = self.layer_w * self.layer_h * self.num_priors * 4 + out_dim = (1, 2, dim) + output = np.zeros(out_dim).astype('float32') + + idx = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + center_x = (w + self.offset) * self.step_w + center_y = (h + self.offset) * self.step_h + for s in range(len(self.min_sizes)): + min_size = self.min_sizes[s] + # first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + # size = sqrt(min_size * max_size) + box_width = box_height = math.sqrt(min_size * max_size) + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + if math.fabs(ar - 1.) < 1e-6: + continue + box_width = min_size * math.sqrt(ar) + box_height = min_size / math.sqrt(ar) + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + # clip the prior's coordidate such that it is within[0, 1] + if self.clip: + for d in range(dim): + output[0, 0, d] = min(max(output[0, 0, d], 0), 1) + # set the variance. + if len(self.variances) == 1: + for i in range(dim): + output[0, 1, i] = self.variances[0] + else: + count = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + for i in range(self.num_priors): + for j in range(4): + output[0, 1, count] = self.variances[j] + count += 1 + self.output = output.astype('float32') + + +if __name__ == '__main__': + unittest.main() From 99a6c5d40edd29c4fedc8f50fe4db75177fb255d Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 9 Jan 2018 20:22:17 +0800 Subject: [PATCH 2/5] change output shape to [2, layer_height, layer_width, num_priors, 4] --- paddle/operators/prior_box_op.cc | 20 +++--- paddle/operators/prior_box_op.h | 71 ++++++++++--------- .../v2/fluid/tests/test_prior_box_op.py | 58 +++++++-------- 3 files changed, 72 insertions(+), 77 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index 04182cb1b7..2ffea67bdd 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -93,17 +93,12 @@ class PriorBoxOp : public framework::OperatorWithKernel { const int layer_height = input_dims[2]; const int layer_width = input_dims[3]; - std::vector dim_vec(3); - // Since all images in a batch has same height and width, we only need to - // generate one set of priors which can be shared across all images. - dim_vec[0] = 1; - // 2 channels. First channel stores the mean of each prior coordinate. - // Second channel stores the variance of each prior coordinate. - dim_vec[1] = 2; - dim_vec[2] = layer_width * layer_height * num_priors * 4; - PADDLE_ENFORCE_GT(dim_vec[2], 0, - "output_dim[2] must larger than 0." - "check your data dims"); + std::vector dim_vec(5); + dim_vec[0] = 2; + dim_vec[1] = layer_height; + dim_vec[2] = layer_width; + dim_vec[3] = num_priors; + dim_vec[4] = 4; auto output_dim = framework::make_ddim(dim_vec); ctx->SetOutputDim("Out", output_dim); } @@ -130,7 +125,8 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "the input image data of PriorBoxOp, The format is NCHW."); AddOutput("Out", "(Tensor, default Tensor), the output prior boxes of " - "PriorBoxOp."); + "PriorBoxOp. The format is [2, layer_height, layer_width, " + "num_priors, 4]"); AddAttr>("min_sizes", "(vector) ", "List of min sizes of generated prior boxes."); AddAttr>("max_sizes", "(vector) ", diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index 142e738a93..86399b53c3 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" -// #include "paddle/operators/strided_memcpy.h" namespace paddle { namespace operators { @@ -94,50 +93,52 @@ class PriorBoxOpKernel : public framework::OpKernel { num_priors += max_sizes.size(); } - int dim = layer_height * layer_width * num_priors * 4; - T* output_data = nullptr; framework::Tensor output_cpu; + framework::Tensor* output_tensor; out->mutable_data(ctx.GetPlace()); if (platform::is_gpu_place(ctx.GetPlace())) { - output_data = - output_cpu.mutable_data(out->dims(), platform::CPUPlace()); + output_cpu.mutable_data(out->dims(), platform::CPUPlace()); + output_tensor = &output_cpu; } else { - output_data = out->mutable_data(ctx.GetPlace()); + output_tensor = out; } - int idx = 0; + auto e_out = framework::EigenTensor::From(*output_tensor); for (int h = 0; h < layer_height; ++h) { for (int w = 0; w < layer_width; ++w) { float center_x = (w + offset) * step_width; float center_y = (h + offset) * step_height; float box_width, box_height; + int idx = 0; for (size_t s = 0; s < min_sizes.size(); ++s) { int min_size = min_sizes[s]; // first prior: aspect_ratio = 1, size = min_size box_width = box_height = min_size; // xmin - output_data[idx++] = (center_x - box_width / 2.) / img_width; + e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width; // ymin - output_data[idx++] = (center_y - box_height / 2.) / img_height; + e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height; // xmax - output_data[idx++] = (center_x + box_width / 2.) / img_width; + e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width; // ymax - output_data[idx++] = (center_y + box_height / 2.) / img_height; + e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + idx++; if (max_sizes.size() > 0) { int max_size = max_sizes[s]; // second prior: aspect_ratio = 1, // size = sqrt(min_size * max_size) box_width = box_height = sqrt(min_size * max_size); // xmin - output_data[idx++] = (center_x - box_width / 2.) / img_width; + e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width; // ymin - output_data[idx++] = (center_y - box_height / 2.) / img_height; + e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height; // xmax - output_data[idx++] = (center_x + box_width / 2.) / img_width; + e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width; // ymax - output_data[idx++] = (center_y + box_height / 2.) / img_height; + e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + idx++; } // rest of priors @@ -149,13 +150,14 @@ class PriorBoxOpKernel : public framework::OpKernel { box_width = min_size * sqrt(ar); box_height = min_size / sqrt(ar); // xmin - output_data[idx++] = (center_x - box_width / 2.) / img_width; + e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width; // ymin - output_data[idx++] = (center_y - box_height / 2.) / img_height; + e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height; // xmax - output_data[idx++] = (center_x + box_width / 2.) / img_width; + e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width; // ymax - output_data[idx++] = (center_y + box_height / 2.) / img_height; + e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height; + idx++; } } } @@ -163,26 +165,31 @@ class PriorBoxOpKernel : public framework::OpKernel { // clip the prior's coordidate such that it is within [0, 1] if (clip) { - for (int d = 0; d < dim; ++d) { - output_data[d] = std::min(std::max(output_data[d], 0.), 1.); + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + for (int i = 0; i < num_priors; ++i) { + for (int j = 0; j < 4; ++j) { + e_out(0, h, w, i, j) = + std::min(std::max(e_out(0, h, w, i, j), 0.), 1.); + } + } + } } - } - // set the variance. - auto output_stride = framework::stride(out->dims()); - output_data += output_stride[1]; - if (variances.size() == 1) { - for (int i = 0; i < dim; ++i) { - output_data[i] = variances[0]; + // set the variance. + auto output_stride = framework::stride(out->dims()); + output_data += output_stride[1]; + if (variances.size() == 1) { + variances.resize(4); + variances[1] = variances[0]; + variances[2] = variances[0]; + variances[3] = variances[0]; } - } else { - int count = 0; for (int h = 0; h < layer_height; ++h) { for (int w = 0; w < layer_width; ++w) { for (int i = 0; i < num_priors; ++i) { for (int j = 0; j < 4; ++j) { - output_data[count] = variances[j]; - ++count; + e_out(1, h, w, i, j) = variances[j]; } } } diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py index 2f82188952..e00bc4bae4 100644 --- a/python/paddle/v2/fluid/tests/test_prior_box_op.py +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -81,8 +81,7 @@ class TestPriorBoxOp(OpTest): self.layer_h)).astype('float32') def init_test_output(self): - dim = self.layer_w * self.layer_h * self.num_priors * 4 - out_dim = (1, 2, dim) + out_dim = (2, self.layer_h, self.layer_w, self.num_priors, 4) output = np.zeros(out_dim).astype('float32') idx = 0 @@ -90,24 +89,22 @@ class TestPriorBoxOp(OpTest): for w in range(self.layer_w): center_x = (w + self.offset) * self.step_w center_y = (h + self.offset) * self.step_h + idx = 0 for s in range(len(self.min_sizes)): min_size = self.min_sizes[s] # first prior: aspect_ratio = 1, size = min_size box_width = box_height = min_size # xmin - output[0, 0, idx] = ( + output[0, h, w, idx, 0] = ( center_x - box_width / 2.) / self.image_w - idx += 1 # ymin - output[0, 0, idx] = ( + output[0, h, w, idx, 1] = ( center_y - box_height / 2.) / self.image_h - idx += 1 # xmax - output[0, 0, idx] = ( + output[0, h, w, idx, 2] = ( center_x + box_width / 2.) / self.image_w - idx += 1 # ymax - output[0, 0, idx] = ( + output[0, h, w, idx, 3] = ( center_y + box_height / 2.) / self.image_h idx += 1 @@ -117,19 +114,16 @@ class TestPriorBoxOp(OpTest): # size = sqrt(min_size * max_size) box_width = box_height = math.sqrt(min_size * max_size) # xmin - output[0, 0, idx] = ( + output[0, h, w, idx, 0] = ( center_x - box_width / 2.) / self.image_w - idx += 1 # ymin - output[0, 0, idx] = ( + output[0, h, w, idx, 1] = ( center_y - box_height / 2.) / self.image_h - idx += 1 # xmax - output[0, 0, idx] = ( + output[0, h, w, idx, 2] = ( center_x + box_width / 2.) / self.image_w - idx += 1 # ymax - output[0, 0, idx] = ( + output[0, h, w, idx, 3] = ( center_y + box_height / 2.) / self.image_h idx += 1 @@ -141,37 +135,35 @@ class TestPriorBoxOp(OpTest): box_width = min_size * math.sqrt(ar) box_height = min_size / math.sqrt(ar) # xmin - output[0, 0, idx] = ( + output[0, h, w, idx, 0] = ( center_x - box_width / 2.) / self.image_w - idx += 1 # ymin - output[0, 0, idx] = ( + output[0, h, w, idx, 1] = ( center_y - box_height / 2.) / self.image_h - idx += 1 # xmax - output[0, 0, idx] = ( + output[0, h, w, idx, 2] = ( center_x + box_width / 2.) / self.image_w - idx += 1 # ymax - output[0, 0, idx] = ( + output[0, h, w, idx, 3] = ( center_y + box_height / 2.) / self.image_h idx += 1 # clip the prior's coordidate such that it is within[0, 1] if self.clip: - for d in range(dim): - output[0, 0, d] = min(max(output[0, 0, d], 0), 1) - # set the variance. - if len(self.variances) == 1: - for i in range(dim): - output[0, 1, i] = self.variances[0] - else: - count = 0 for h in range(self.layer_h): for w in range(self.layer_w): for i in range(self.num_priors): for j in range(4): - output[0, 1, count] = self.variances[j] - count += 1 + output[0, h, w, i, j] = min( + max(output[0, h, w, i, j], 0), 1) + # set the variance. + for h in range(self.layer_h): + for w in range(self.layer_w): + for i in range(self.num_priors): + for j in range(4): + if len(self.variances) == 1: + output[1, h, w, i, j] = self.variances[0] + else: + output[1, h, w, i, j] = self.variances[j] self.output = output.astype('float32') From 142f6328865300b58134117db2d39190e972dac5 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 22 Jan 2018 15:28:32 +0800 Subject: [PATCH 3/5] update code --- paddle/operators/prior_box_op.cc | 37 ++++----- paddle/operators/prior_box_op.h | 18 ++--- .../v2/fluid/tests/test_prior_box_op.py | 76 ++++++++----------- 3 files changed, 57 insertions(+), 74 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index b492c05082..286afa8b4f 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -23,14 +23,14 @@ class PriorBoxOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(X) of PriorBoxOp should not be null."); + "Input(Input) of PriorBoxOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Image"), - "Input(Offset) of PriorBoxOp should not be null."); + "Input(Image) of PriorBoxOp should not be null."); auto image_dims = ctx->GetInputDim("Image"); auto input_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE(image_dims.size() == 4, "The format of image is NCHW."); - PADDLE_ENFORCE(input_dims.size() == 4, "The format of input is NCHW."); + PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW."); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); PADDLE_ENFORCE_LT(input_dims[2], image_dims[2], "The height of input must smaller than image."); @@ -45,7 +45,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { bool flip = ctx->Attrs().Get("flip"); PADDLE_ENFORCE_GT(min_sizes.size(), 0, - "Size of min_size must be at least 1."); + "Size of min_sizes must be at least 1."); for (size_t i = 0; i < min_sizes.size(); ++i) { PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i); } @@ -56,7 +56,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { int num_priors = aspect_ratios_vec.size() * min_sizes.size(); if (max_sizes.size() > 0) { PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(), - "The length of min_size and max_size must be equal."); + "The number of min_size and max_size must be equal."); for (size_t i = 0; i < min_sizes.size(); ++i) { PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i], "max_size[%d] must be greater than min_size[%d].", i, @@ -65,13 +65,10 @@ class PriorBoxOp : public framework::OperatorWithKernel { } } - if (variances.size() > 1) { - PADDLE_ENFORCE_EQ(variances.size(), 4, - "Must and only provide 4 variance."); - for (size_t i = 0; i < variances.size(); ++i) { - PADDLE_ENFORCE_GT(variances[i], 0.0, - "variance[%d] must be greater than 0.", i); - } + PADDLE_ENFORCE_EQ(variances.size(), 4, "Must and only provide 4 variance."); + for (size_t i = 0; i < variances.size(); ++i) { + PADDLE_ENFORCE_GT(variances[i], 0.0, + "variance[%d] must be greater than 0.", i); } const float step_h = ctx->Attrs().Get("step_h"); @@ -95,19 +92,19 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", "(Tensor, default Tensor), " - "the input feature data of PriorBoxOp, The format is NCHW."); + "the input feature data of PriorBoxOp, The layout is NCHW."); AddInput("Image", "(Tensor, default Tensor), " - "the input image data of PriorBoxOp, The format is NCHW."); + "the input image data of PriorBoxOp, The layout is NCHW."); AddOutput("Boxes", "(Tensor, default Tensor), the output prior boxes of " - "PriorBoxOp. The format is [layer_height, layer_width, " + "PriorBoxOp. The layout is [layer_height, layer_width, " "num_priors, 4]. layer_height is the height of input, " "layer_width is the width of input, num_priors is the box " "count of each position."); AddOutput("Variances", "(Tensor, default Tensor), the expanded variances of " - "PriorBoxOp. The format is [layer_height, layer_width, " + "PriorBoxOp. The layout is [layer_height, layer_width, " "num_priors, 4]. layer_height is the height of input, " "layer_width is the width of input, num_priors is the box " "count of each position."); @@ -117,12 +114,10 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "List of max sizes of generated prior boxes."); AddAttr>( "aspect_ratios", "(vector) ", - "List of aspect ratios of generated prior boxes.") - .SetDefault({}); + "List of aspect ratios of generated prior boxes."); AddAttr>( "variances", "(vector) ", - "List of variances to be encoded in prior boxes.") - .SetDefault({0.1}); + "List of variances to be encoded in prior boxes."); AddAttr("flip", "(bool) ", "Whether to flip aspect ratios.") .SetDefault(true); AddAttr("clip", "(bool) ", "Whether to clip out-of-boundary boxes.") diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index 9dcd4d8a2f..5483dd7bbe 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -70,9 +70,9 @@ class PriorBoxOpKernel : public framework::OpKernel { std::vector aspect_ratios; ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios); - auto step_w = ctx.Attr("step_w"); - auto step_h = ctx.Attr("step_h"); - auto offset = ctx.Attr("offset"); + T step_w = static_cast(ctx.Attr("step_w")); + T step_h = static_cast(ctx.Attr("step_h")); + T offset = static_cast(ctx.Attr("offset")); auto img_width = image->dims()[3]; auto img_height = image->dims()[2]; @@ -80,10 +80,10 @@ class PriorBoxOpKernel : public framework::OpKernel { auto layer_width = input->dims()[3]; auto layer_height = input->dims()[2]; - float step_width, step_height; + T step_width, step_height; if (step_w == 0 || step_h == 0) { - step_width = static_cast(img_width) / layer_width; - step_height = static_cast(img_height) / layer_height; + step_width = static_cast(img_width) / layer_width; + step_height = static_cast(img_height) / layer_height; } else { step_width = step_w; step_height = step_h; @@ -100,9 +100,9 @@ class PriorBoxOpKernel : public framework::OpKernel { auto e_boxes = framework::EigenTensor::From(*boxes); for (int h = 0; h < layer_height; ++h) { for (int w = 0; w < layer_width; ++w) { - float center_x = (w + offset) * step_width; - float center_y = (h + offset) * step_height; - float box_width, box_height; + T center_x = (w + offset) * step_width; + T center_y = (h + offset) * step_height; + T box_width, box_height; int idx = 0; for (size_t s = 0; s < min_sizes.size(); ++s) { int min_size = min_sizes[s]; diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py index 86e4ab76b5..ca8d2bca74 100644 --- a/python/paddle/v2/fluid/tests/test_prior_box_op.py +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -1,3 +1,17 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import numpy as np import sys @@ -86,44 +100,26 @@ class TestPriorBoxOp(OpTest): idx = 0 for h in range(self.layer_h): for w in range(self.layer_w): - center_x = (w + self.offset) * self.step_w - center_y = (h + self.offset) * self.step_h + c_x = (w + self.offset) * self.step_w + c_y = (h + self.offset) * self.step_h idx = 0 for s in range(len(self.min_sizes)): min_size = self.min_sizes[s] - # first prior: aspect_ratio = 1, size = min_size - box_width = box_height = min_size - # xmin - out_boxes[h, w, idx, 0] = ( - center_x - box_width / 2.) / self.image_w - # ymin - out_boxes[h, w, idx, 1] = ( - center_y - box_height / 2.) / self.image_h - # xmax - out_boxes[h, w, idx, 2] = ( - center_x + box_width / 2.) / self.image_w - # ymax - out_boxes[h, w, idx, 3] = ( - center_y + box_height / 2.) / self.image_h + c_w = c_h = min_size / 2. + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h + ] idx += 1 if len(self.max_sizes) > 0: max_size = self.max_sizes[s] # second prior: aspect_ratio = 1, - # size = sqrt(min_size * max_size) - box_width = box_height = math.sqrt(min_size * max_size) - # xmin - out_boxes[h, w, idx, 0] = ( - center_x - box_width / 2.) / self.image_w - # ymin - out_boxes[h, w, idx, 1] = ( - center_y - box_height / 2.) / self.image_h - # xmax - out_boxes[h, w, idx, 2] = ( - center_x + box_width / 2.) / self.image_w - # ymax - out_boxes[h, w, idx, 3] = ( - center_y + box_height / 2.) / self.image_h + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] idx += 1 # rest of priors @@ -131,20 +127,12 @@ class TestPriorBoxOp(OpTest): ar = self.real_aspect_ratios[r] if math.fabs(ar - 1.) < 1e-6: continue - box_width = min_size * math.sqrt(ar) - box_height = min_size / math.sqrt(ar) - # xmin - out_boxes[h, w, idx, 0] = ( - center_x - box_width / 2.) / self.image_w - # ymin - out_boxes[h, w, idx, 1] = ( - center_y - box_height / 2.) / self.image_h - # xmax - out_boxes[h, w, idx, 2] = ( - center_x + box_width / 2.) / self.image_w - # ymax - out_boxes[h, w, idx, 3] = ( - center_y + box_height / 2.) / self.image_h + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h] idx += 1 # clip the prior's coordidate such that it is within[0, 1] if self.clip: From 0e165032a84fc06623f7f91a52eb31134526a29b Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 22 Jan 2018 17:33:38 +0800 Subject: [PATCH 4/5] update code --- paddle/operators/prior_box_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index 5483dd7bbe..1869807dc8 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -165,7 +165,7 @@ class PriorBoxOpKernel : public framework::OpKernel { } Eigen::Tensor var_et(1, variances.size()); - for (int i = 0; i < variances.size(); ++i) { + for (size_t i = 0; i < variances.size(); ++i) { var_et(0, i) = variances[i]; } From ca2e96f270e3a9b23aeb77baac169d24f68f3529 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 23 Jan 2018 14:35:34 +0800 Subject: [PATCH 5/5] update code --- paddle/operators/prior_box_op.cc | 14 ++++++-------- paddle/operators/prior_box_op.h | 20 ++++++++++++-------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc index 286afa8b4f..105ff4ac3e 100644 --- a/paddle/operators/prior_box_op.cc +++ b/paddle/operators/prior_box_op.cc @@ -98,16 +98,14 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "the input image data of PriorBoxOp, The layout is NCHW."); AddOutput("Boxes", "(Tensor, default Tensor), the output prior boxes of " - "PriorBoxOp. The layout is [layer_height, layer_width, " - "num_priors, 4]. layer_height is the height of input, " - "layer_width is the width of input, num_priors is the box " - "count of each position."); + "PriorBoxOp. The layout is [H, W, num_priors, 4]. " + "H is the height of input, W is the width of input, num_priors " + "is the box count of each position."); AddOutput("Variances", "(Tensor, default Tensor), the expanded variances of " - "PriorBoxOp. The layout is [layer_height, layer_width, " - "num_priors, 4]. layer_height is the height of input, " - "layer_width is the width of input, num_priors is the box " - "count of each position."); + "PriorBoxOp. The layout is [H, W, num_priors, 4]. " + "H is the height of input, W is the width of input, num_priors " + "is the box count of each position."); AddAttr>("min_sizes", "(vector) ", "List of min sizes of generated prior boxes."); AddAttr>("max_sizes", "(vector) ", diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h index 1869807dc8..e0a663ace8 100644 --- a/paddle/operators/prior_box_op.h +++ b/paddle/operators/prior_box_op.h @@ -77,13 +77,13 @@ class PriorBoxOpKernel : public framework::OpKernel { auto img_width = image->dims()[3]; auto img_height = image->dims()[2]; - auto layer_width = input->dims()[3]; - auto layer_height = input->dims()[2]; + auto feature_width = input->dims()[3]; + auto feature_height = input->dims()[2]; T step_width, step_height; if (step_w == 0 || step_h == 0) { - step_width = static_cast(img_width) / layer_width; - step_height = static_cast(img_height) / layer_height; + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; } else { step_width = step_w; step_height = step_h; @@ -98,8 +98,8 @@ class PriorBoxOpKernel : public framework::OpKernel { vars->mutable_data(ctx.GetPlace()); auto e_boxes = framework::EigenTensor::From(*boxes); - for (int h = 0; h < layer_height; ++h) { - for (int w = 0; w < layer_width; ++w) { + for (int h = 0; h < feature_height; ++h) { + for (int w = 0; w < feature_width; ++w) { T center_x = (w + offset) * step_width; T center_y = (h + offset) * step_height; T box_width, box_height; @@ -164,12 +164,16 @@ class PriorBoxOpKernel : public framework::OpKernel { boxes->data(), clip_func); } - Eigen::Tensor var_et(1, variances.size()); + framework::Tensor var_t; + var_t.mutable_data( + framework::make_ddim({1, static_cast(variances.size())}), + ctx.GetPlace()); + auto var_et = framework::EigenTensor::From(var_t); for (size_t i = 0; i < variances.size(); ++i) { var_et(0, i) = variances[i]; } - int box_num = layer_height * layer_width * num_priors; + int box_num = feature_height * feature_width * num_priors; auto var_dim = vars->dims(); vars->Resize({box_num, static_cast(variances.size())});