commit
81be9cef47
@@ -0,0 +1,154 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/prior_box_op.h"

namespace paddle {
namespace operators {

class PriorBoxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of PriorBoxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Image"),
                   "Input(Image) of PriorBoxOp should not be null.");

    auto image_dims = ctx->GetInputDim("Image");
    auto input_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW.");
    PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");

    PADDLE_ENFORCE_LT(input_dims[2], image_dims[2],
                      "The height of input must be smaller than that of image.");

    PADDLE_ENFORCE_LT(input_dims[3], image_dims[3],
                      "The width of input must be smaller than that of image.");

    auto min_sizes = ctx->Attrs().Get<std::vector<int>>("min_sizes");
    auto max_sizes = ctx->Attrs().Get<std::vector<int>>("max_sizes");
    auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
    auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
    bool flip = ctx->Attrs().Get<bool>("flip");

    PADDLE_ENFORCE_GT(min_sizes.size(), 0,
                      "Size of min_sizes must be at least 1.");
    for (size_t i = 0; i < min_sizes.size(); ++i) {
      PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i);
    }

    std::vector<float> aspect_ratios_vec;
    ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec);

    int num_priors = aspect_ratios_vec.size() * min_sizes.size();
    if (max_sizes.size() > 0) {
      PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(),
                        "The number of min_size and max_size must be equal.");
      for (size_t i = 0; i < min_sizes.size(); ++i) {
        PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i],
                          "max_size[%d] must be greater than min_size[%d].", i,
                          i);
        num_priors += 1;
      }
    }
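    // For example, with min_sizes = {2, 4}, aspect_ratios = {2., 3.} and
    // flip = true (the settings used by the Python unit test below), the
    // expanded ratios are {1., 2., 1/2., 3., 1/3.}, so num_priors is
    // 5 * 2 = 10, plus one extra prior per max_size, i.e. 12 in total.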

    PADDLE_ENFORCE_EQ(variances.size(), 4,
                      "Must provide exactly 4 variance values.");
    for (size_t i = 0; i < variances.size(); ++i) {
      PADDLE_ENFORCE_GT(variances[i], 0.0,
                        "variance[%d] must be greater than 0.", i);
    }

    const float step_h = ctx->Attrs().Get<float>("step_h");
    PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should not be negative.");
    const float step_w = ctx->Attrs().Get<float>("step_w");
    PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should not be negative.");

    std::vector<int64_t> dim_vec(4);
    dim_vec[0] = input_dims[2];
    dim_vec[1] = input_dims[3];
    dim_vec[2] = num_priors;
    dim_vec[3] = 4;
    ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
    ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
  }
};

class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PriorBoxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(Tensor, default Tensor<float>), "
             "the input feature data of PriorBoxOp. The layout is NCHW.");
    AddInput("Image",
             "(Tensor, default Tensor<float>), "
             "the input image data of PriorBoxOp. The layout is NCHW.");
    AddOutput("Boxes",
              "(Tensor, default Tensor<float>), the output prior boxes of "
              "PriorBoxOp. The layout is [H, W, num_priors, 4]. "
              "H is the height of input, W is the width of input, num_priors "
              "is the box count of each position.");
    AddOutput("Variances",
              "(Tensor, default Tensor<float>), the expanded variances of "
              "PriorBoxOp. The layout is [H, W, num_priors, 4]. "
              "H is the height of input, W is the width of input, num_priors "
              "is the box count of each position.");
    AddAttr<std::vector<int>>("min_sizes",
                              "(vector<int>) "
                              "List of min sizes of generated prior boxes.");
    AddAttr<std::vector<int>>("max_sizes",
                              "(vector<int>) "
                              "List of max sizes of generated prior boxes.");
    AddAttr<std::vector<float>>(
        "aspect_ratios",
        "(vector<float>) "
        "List of aspect ratios of generated prior boxes.");
    AddAttr<std::vector<float>>(
        "variances",
        "(vector<float>) "
        "List of variances to be encoded in prior boxes.");
    AddAttr<bool>("flip", "(bool) Whether to flip aspect ratios.")
        .SetDefault(true);
    AddAttr<bool>("clip", "(bool) Whether to clip out-of-boundary boxes.")
        .SetDefault(true);
    AddAttr<float>("step_w",
                   "(float) "
                   "Prior boxes step across width, 0 for auto calculation.")
        .SetDefault(0.0);
    AddAttr<float>("step_h",
                   "(float) "
                   "Prior boxes step across height, 0 for auto calculation.")
        .SetDefault(0.0);
    AddAttr<float>("offset",
                   "(float) "
                   "Prior boxes center offset.")
        .SetDefault(0.5);
    AddComment(R"DOC(
Prior box operator
Generate prior boxes for the SSD (Single Shot MultiBox Detector) algorithm.
Each position of the input produces N prior boxes, where N is determined by
the count of min_sizes, max_sizes and aspect_ratios. The size of each box is
in the interval (min_size, max_size), and the boxes are generated in sequence
according to the aspect_ratios.

Please get more information from the following paper:
https://arxiv.org/abs/1512.02325.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker);
REGISTER_OP_CPU_KERNEL(
    prior_box, ops::PriorBoxOpKernel<paddle::platform::CPUPlace, float>,
    ops::PriorBoxOpKernel<paddle::platform::CPUPlace, double>);
@@ -0,0 +1,188 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <cmath>
#include <vector>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/transform.h"

namespace paddle {
namespace operators {

inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratios,
                               bool flip,
                               std::vector<float>& output_aspect_ratios) {
  constexpr float epsilon = 1e-6;
  output_aspect_ratios.clear();
  output_aspect_ratios.push_back(1.);
  for (size_t i = 0; i < input_aspect_ratios.size(); ++i) {
    float ar = input_aspect_ratios[i];
    bool already_exist = false;
    for (size_t j = 0; j < output_aspect_ratios.size(); ++j) {
      if (fabs(ar - output_aspect_ratios[j]) < epsilon) {
        already_exist = true;
        break;
      }
    }
    if (!already_exist) {
      output_aspect_ratios.push_back(ar);
      if (flip) {
        output_aspect_ratios.push_back(1. / ar);
      }
    }
  }
}
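// For instance, an input of {2., 3.} with flip = true expands to
// {1., 2., 1. / 2., 3., 1. / 3.}; ratios that already exist (within 1e-6)
// are skipped.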

template <typename T>
struct ClipFunctor {
  HOSTDEVICE T operator()(T in) const {
    return std::min<T>(std::max<T>(in, 0.), 1.);
  }
};
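// ClipFunctor is applied element-wise (via platform::Transform below) to clamp
// the normalized box coordinates into [0, 1] when the "clip" attribute is set.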

template <typename Place, typename T>
class PriorBoxOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<paddle::framework::Tensor>("Input");
    auto* image = ctx.Input<paddle::framework::Tensor>("Image");
    auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
    auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");

    auto min_sizes = ctx.Attr<std::vector<int>>("min_sizes");
    auto max_sizes = ctx.Attr<std::vector<int>>("max_sizes");
    auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
    auto variances = ctx.Attr<std::vector<float>>("variances");
    auto flip = ctx.Attr<bool>("flip");
    auto clip = ctx.Attr<bool>("clip");

    std::vector<float> aspect_ratios;
    ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios);

    T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
    T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
    T offset = static_cast<T>(ctx.Attr<float>("offset"));

    auto img_width = image->dims()[3];
    auto img_height = image->dims()[2];

    auto feature_width = input->dims()[3];
    auto feature_height = input->dims()[2];

    T step_width, step_height;
    if (step_w == 0 || step_h == 0) {
      step_width = static_cast<T>(img_width) / feature_width;
      step_height = static_cast<T>(img_height) / feature_height;
    } else {
      step_width = step_w;
      step_height = step_h;
    }
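    // E.g. a 20x20 image over a 4x4 feature map gives a step of 5 in each
    // direction, matching the step_w/step_h values the Python test below sets
    // explicitly.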

    int num_priors = aspect_ratios.size() * min_sizes.size();
    if (max_sizes.size() > 0) {
      num_priors += max_sizes.size();
    }

    boxes->mutable_data<T>(ctx.GetPlace());
    vars->mutable_data<T>(ctx.GetPlace());

    auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes);
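    // Walk every feature-map position and emit num_priors boxes per position;
    // all coordinates are divided by the image width/height, so they are
    // expressed as fractions of the image size.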
    for (int h = 0; h < feature_height; ++h) {
      for (int w = 0; w < feature_width; ++w) {
        T center_x = (w + offset) * step_width;
        T center_y = (h + offset) * step_height;
        T box_width, box_height;
        int idx = 0;
        for (size_t s = 0; s < min_sizes.size(); ++s) {
          int min_size = min_sizes[s];
          // first prior: aspect_ratio = 1, size = min_size
          box_width = box_height = min_size;
          // xmin
          e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
          // ymin
          e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
          // xmax
          e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
          // ymax
          e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;

          idx++;
          if (max_sizes.size() > 0) {
            int max_size = max_sizes[s];
            // second prior: aspect_ratio = 1,
            // size = sqrt(min_size * max_size)
            box_width = box_height = sqrt(min_size * max_size);
            // xmin
            e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
            // ymin
            e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
            // xmax
            e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
            // ymax
            e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
            idx++;
          }

          // rest of priors
          for (size_t r = 0; r < aspect_ratios.size(); ++r) {
            float ar = aspect_ratios[r];
            if (fabs(ar - 1.) < 1e-6) {
              continue;
            }
            box_width = min_size * sqrt(ar);
            box_height = min_size / sqrt(ar);
            // xmin
            e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
            // ymin
            e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
            // xmax
            e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
            // ymax
            e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
            idx++;
          }
        }
      }
    }

    if (clip) {
      platform::Transform<platform::CPUDeviceContext> trans;
      ClipFunctor<T> clip_func;
      trans(ctx.template device_context<platform::CPUDeviceContext>(),
            boxes->data<T>(), boxes->data<T>() + boxes->numel(),
            boxes->data<T>(), clip_func);
    }
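    // Broadcast the 4 per-coordinate variances to every generated box: view
    // the Variances tensor temporarily as [box_num, 4], fill it with an Eigen
    // broadcast, then restore its original [H, W, num_priors, 4] shape.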

    framework::Tensor var_t;
    var_t.mutable_data<T>(
        framework::make_ddim({1, static_cast<int>(variances.size())}),
        ctx.GetPlace());
    auto var_et = framework::EigenTensor<T, 2>::From(var_t);
    for (size_t i = 0; i < variances.size(); ++i) {
      var_et(0, i) = variances[i];
    }

    int box_num = feature_height * feature_width * num_priors;
    auto var_dim = vars->dims();
    vars->Resize({box_num, static_cast<int>(variances.size())});

    auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
    e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(box_num, 1));

    vars->Resize(var_dim);
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,148 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import sys
import math
from op_test import OpTest


class TestPriorBoxOp(OpTest):
    def set_data(self):
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()
        self.inputs = {'Input': self.input, 'Image': self.image}

        self.attrs = {
            'min_sizes': self.min_sizes,
            'max_sizes': self.max_sizes,
            'aspect_ratios': self.aspect_ratios,
            'variances': self.variances,
            'flip': self.flip,
            'clip': self.clip,
            'step_w': self.step_w,
            'step_h': self.step_h,
            'offset': self.offset
        }

        self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        return

    def setUp(self):
        self.op_type = "prior_box"
        self.set_data()

    def init_test_params(self):
        self.layer_w = 4
        self.layer_h = 4

        self.image_w = 20
        self.image_h = 20

        self.step_w = float(self.image_w) / float(self.layer_w)
        self.step_h = float(self.image_h) / float(self.layer_h)

        self.input_channels = 2
        self.image_channels = 3
        self.batch_size = 10

        self.min_sizes = [2, 4]
        self.min_sizes = np.array(self.min_sizes).astype('int64')
        self.max_sizes = [5, 10]
        self.max_sizes = np.array(self.max_sizes).astype('int64')
        self.aspect_ratios = [2.0, 3.0]
        self.flip = True
        self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
        self.aspect_ratios = np.array(
            self.aspect_ratios, dtype=np.float).flatten()
        self.variances = [0.1, 0.1, 0.2, 0.2]
        self.variances = np.array(self.variances, dtype=np.float).flatten()

        self.clip = True

        self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
        if len(self.max_sizes) > 0:
            self.num_priors += len(self.max_sizes)
        self.offset = 0.5
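        # With these settings there are 5 expanded aspect ratios and 2 min
        # sizes, plus one extra prior per max size:
        # num_priors = 5 * 2 + 2 = 12.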

    def init_test_input(self):
        self.image = np.random.random(
            (self.batch_size, self.image_channels, self.image_h,
             self.image_w)).astype('float32')

        self.input = np.random.random(
            (self.batch_size, self.input_channels, self.layer_h,
             self.layer_w)).astype('float32')

    def init_test_output(self):
        out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
        out_boxes = np.zeros(out_dim).astype('float32')
        out_var = np.zeros(out_dim).astype('float32')

        idx = 0
        for h in range(self.layer_h):
            for w in range(self.layer_w):
                c_x = (w + self.offset) * self.step_w
                c_y = (h + self.offset) * self.step_h
                idx = 0
                for s in range(len(self.min_sizes)):
                    min_size = self.min_sizes[s]
                    c_w = c_h = min_size / 2.
                    out_boxes[h, w, idx, :] = [
                        (c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h,
                        (c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h
                    ]
                    idx += 1

                    if len(self.max_sizes) > 0:
                        max_size = self.max_sizes[s]
                        # second prior: aspect_ratio = 1,
                        # size = sqrt(min_size * max_size)
                        c_w = c_h = math.sqrt(min_size * max_size) / 2
                        out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
                                                   (c_y - c_h) / self.image_h,
                                                   (c_x + c_w) / self.image_w,
                                                   (c_y + c_h) / self.image_h]
                        idx += 1

                    # rest of priors
                    for r in range(len(self.real_aspect_ratios)):
                        ar = self.real_aspect_ratios[r]
                        if math.fabs(ar - 1.) < 1e-6:
                            continue
                        c_w = min_size * math.sqrt(ar) / 2
                        c_h = (min_size / math.sqrt(ar)) / 2
                        out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
                                                   (c_y - c_h) / self.image_h,
                                                   (c_x + c_w) / self.image_w,
                                                   (c_y + c_h) / self.image_h]
                        idx += 1
        # clip the priors' coordinates so that they lie within [0, 1]
        if self.clip:
            out_boxes = np.clip(out_boxes, 0.0, 1.0)
        # set the variances
        out_var = np.tile(self.variances, (self.layer_h, self.layer_w,
                                           self.num_priors, 1))
        self.out_boxes = out_boxes.astype('float32')
        self.out_var = out_var.astype('float32')


if __name__ == '__main__':
    unittest.main()