Add density_prior_box_op (#14226)
Density prior box operator for image detection model.panyx0718-patch-1
parent
9a6e239281
commit
4a55fb5f5b
@ -0,0 +1,175 @@
|
||||
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/detection/density_prior_box_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class DensityPriorBoxOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
||||
"Input(Input) of DensityPriorBoxOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Image"),
|
||||
"Input(Image) of DensityPriorBoxOp should not be null.");
|
||||
|
||||
auto image_dims = ctx->GetInputDim("Image");
|
||||
auto input_dims = ctx->GetInputDim("Input");
|
||||
PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW.");
|
||||
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
||||
|
||||
PADDLE_ENFORCE_LT(input_dims[2], image_dims[2],
|
||||
"The height of input must smaller than image.");
|
||||
|
||||
PADDLE_ENFORCE_LT(input_dims[3], image_dims[3],
|
||||
"The width of input must smaller than image.");
|
||||
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
|
||||
|
||||
auto fixed_sizes = ctx->Attrs().Get<std::vector<float>>("fixed_sizes");
|
||||
auto fixed_ratios = ctx->Attrs().Get<std::vector<float>>("fixed_ratios");
|
||||
auto densities = ctx->Attrs().Get<std::vector<int>>("densities");
|
||||
|
||||
PADDLE_ENFORCE_EQ(fixed_sizes.size(), densities.size(),
|
||||
"The number of fixed_sizes and densities must be equal.");
|
||||
size_t num_priors = 0;
|
||||
if ((fixed_sizes.size() > 0) && (densities.size() > 0)) {
|
||||
for (size_t i = 0; i < densities.size(); ++i) {
|
||||
if (fixed_ratios.size() > 0) {
|
||||
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> dim_vec(4);
|
||||
dim_vec[0] = input_dims[2];
|
||||
dim_vec[1] = input_dims[3];
|
||||
dim_vec[2] = num_priors;
|
||||
dim_vec[3] = 4;
|
||||
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
|
||||
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
|
||||
platform::CPUPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class DensityPriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput(
|
||||
"Input",
|
||||
"(Tensor, default Tensor<float>), "
|
||||
"the input feature data of DensityPriorBoxOp, the layout is NCHW.");
|
||||
AddInput("Image",
|
||||
"(Tensor, default Tensor<float>), "
|
||||
"the input image data of DensityPriorBoxOp, the layout is NCHW.");
|
||||
AddOutput("Boxes",
|
||||
"(Tensor, default Tensor<float>), the output prior boxes of "
|
||||
"DensityPriorBoxOp. The layout is [H, W, num_priors, 4]. "
|
||||
"H is the height of input, W is the width of input, num_priors "
|
||||
"is the box count of each position.");
|
||||
AddOutput("Variances",
|
||||
"(Tensor, default Tensor<float>), the expanded variances of "
|
||||
"DensityPriorBoxOp. The layout is [H, W, num_priors, 4]. "
|
||||
"H is the height of input, W is the width of input, num_priors "
|
||||
"is the box count of each position.");
|
||||
AddAttr<std::vector<float>>("variances",
|
||||
"(vector<float>) List of variances to be "
|
||||
"encoded in density prior boxes.")
|
||||
.AddCustomChecker([](const std::vector<float>& variances) {
|
||||
PADDLE_ENFORCE_EQ(variances.size(), 4,
|
||||
"Must and only provide 4 variance.");
|
||||
for (size_t i = 0; i < variances.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(variances[i], 0.0,
|
||||
"variance[%d] must be greater than 0.", i);
|
||||
}
|
||||
});
|
||||
AddAttr<bool>("clip", "(bool) Whether to clip out-of-boundary boxes.")
|
||||
.SetDefault(true);
|
||||
|
||||
AddAttr<float>(
|
||||
"step_w",
|
||||
"Density prior boxes step across width, 0.0 for auto calculation.")
|
||||
.SetDefault(0.0)
|
||||
.AddCustomChecker([](const float& step_w) {
|
||||
PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should be larger than 0.");
|
||||
});
|
||||
AddAttr<float>(
|
||||
"step_h",
|
||||
"Density prior boxes step across height, 0.0 for auto calculation.")
|
||||
.SetDefault(0.0)
|
||||
.AddCustomChecker([](const float& step_h) {
|
||||
PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should be larger than 0.");
|
||||
});
|
||||
|
||||
AddAttr<float>("offset",
|
||||
"(float) "
|
||||
"Density prior boxes center offset.")
|
||||
.SetDefault(0.5);
|
||||
AddAttr<std::vector<float>>("fixed_sizes",
|
||||
"(vector<float>) List of fixed sizes "
|
||||
"of generated density prior boxes.")
|
||||
.SetDefault(std::vector<float>{})
|
||||
.AddCustomChecker([](const std::vector<float>& fixed_sizes) {
|
||||
for (size_t i = 0; i < fixed_sizes.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(fixed_sizes[i], 0.0,
|
||||
"fixed_sizes[%d] should be larger than 0.", i);
|
||||
}
|
||||
});
|
||||
|
||||
AddAttr<std::vector<float>>("fixed_ratios",
|
||||
"(vector<float>) List of fixed ratios "
|
||||
"of generated density prior boxes.")
|
||||
.SetDefault(std::vector<float>{})
|
||||
.AddCustomChecker([](const std::vector<float>& fixed_ratios) {
|
||||
for (size_t i = 0; i < fixed_ratios.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(fixed_ratios[i], 0.0,
|
||||
"fixed_ratios[%d] should be larger than 0.", i);
|
||||
}
|
||||
});
|
||||
|
||||
AddAttr<std::vector<int>>("densities",
|
||||
"(vector<float>) List of densities "
|
||||
"of generated density prior boxes.")
|
||||
.SetDefault(std::vector<int>{})
|
||||
.AddCustomChecker([](const std::vector<int>& densities) {
|
||||
for (size_t i = 0; i < densities.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(densities[i], 0,
|
||||
"densities[%d] should be larger than 0.", i);
|
||||
}
|
||||
});
|
||||
AddComment(R"DOC(
|
||||
Density Prior box operator
|
||||
Each position of the input produce N density prior boxes, N is determined by
|
||||
the count of fixed_ratios, densities, the calculation of N is as follows:
|
||||
for density in densities:
|
||||
N += size(fixed_ratios)*density^2
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(density_prior_box, ops::DensityPriorBoxOp,
|
||||
ops::DensityPriorBoxOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(density_prior_box, ops::DensityPriorBoxOpKernel<float>,
|
||||
ops::DensityPriorBoxOpKernel<double>);
|
@ -0,0 +1,146 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/operators/detection/prior_box_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
|
||||
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
|
||||
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
|
||||
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
|
||||
|
||||
auto variances = ctx.Attr<std::vector<float>>("variances");
|
||||
auto clip = ctx.Attr<bool>("clip");
|
||||
|
||||
auto fixed_sizes = ctx.Attr<std::vector<float>>("fixed_sizes");
|
||||
auto fixed_ratios = ctx.Attr<std::vector<float>>("fixed_ratios");
|
||||
auto densities = ctx.Attr<std::vector<int>>("densities");
|
||||
|
||||
T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
|
||||
T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
|
||||
T offset = static_cast<T>(ctx.Attr<float>("offset"));
|
||||
|
||||
auto img_width = image->dims()[3];
|
||||
auto img_height = image->dims()[2];
|
||||
|
||||
auto feature_width = input->dims()[3];
|
||||
auto feature_height = input->dims()[2];
|
||||
|
||||
T step_width, step_height;
|
||||
if (step_w == 0 || step_h == 0) {
|
||||
step_width = static_cast<T>(img_width) / feature_width;
|
||||
step_height = static_cast<T>(img_height) / feature_height;
|
||||
} else {
|
||||
step_width = step_w;
|
||||
step_height = step_h;
|
||||
}
|
||||
int num_priors = 0;
|
||||
if (fixed_sizes.size() > 0 && densities.size() > 0) {
|
||||
for (size_t i = 0; i < densities.size(); ++i) {
|
||||
if (fixed_ratios.size() > 0) {
|
||||
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
boxes->mutable_data<T>(ctx.GetPlace());
|
||||
vars->mutable_data<T>(ctx.GetPlace());
|
||||
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes).setConstant(0.0);
|
||||
|
||||
int step_average = static_cast<int>((step_width + step_height) * 0.5);
|
||||
|
||||
for (int h = 0; h < feature_height; ++h) {
|
||||
for (int w = 0; w < feature_width; ++w) {
|
||||
T center_x = (w + offset) * step_width;
|
||||
T center_y = (h + offset) * step_height;
|
||||
int idx = 0;
|
||||
// Generate density prior boxes with fixed sizes.
|
||||
for (size_t s = 0; s < fixed_sizes.size(); ++s) {
|
||||
auto fixed_size = fixed_sizes[s];
|
||||
int density = densities[s];
|
||||
// Generate density prior boxes with fixed ratios.
|
||||
if (fixed_ratios.size() > 0) {
|
||||
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
|
||||
float ar = fixed_ratios[r];
|
||||
int shift = step_average / density;
|
||||
float box_width_ratio = fixed_size * sqrt(ar);
|
||||
float box_height_ratio = fixed_size / sqrt(ar);
|
||||
for (int di = 0; di < density; ++di) {
|
||||
for (int dj = 0; dj < density; ++dj) {
|
||||
float center_x_temp =
|
||||
center_x - step_average / 2. + shift / 2. + dj * shift;
|
||||
float center_y_temp =
|
||||
center_y - step_average / 2. + shift / 2. + di * shift;
|
||||
e_boxes(h, w, idx, 0) =
|
||||
(center_x_temp - box_width_ratio / 2.) / img_width >= 0
|
||||
? (center_x_temp - box_width_ratio / 2.) / img_width
|
||||
: 0;
|
||||
e_boxes(h, w, idx, 1) =
|
||||
(center_y_temp - box_height_ratio / 2.) / img_height >= 0
|
||||
? (center_y_temp - box_height_ratio / 2.) / img_height
|
||||
: 0;
|
||||
e_boxes(h, w, idx, 2) =
|
||||
(center_x_temp + box_width_ratio / 2.) / img_width <= 1
|
||||
? (center_x_temp + box_width_ratio / 2.) / img_width
|
||||
: 1;
|
||||
e_boxes(h, w, idx, 3) =
|
||||
(center_y_temp + box_height_ratio / 2.) / img_height <= 1
|
||||
? (center_y_temp + box_height_ratio / 2.) / img_height
|
||||
: 1;
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (clip) {
|
||||
platform::Transform<platform::CPUDeviceContext> trans;
|
||||
ClipFunctor<T> clip_func;
|
||||
trans(ctx.template device_context<platform::CPUDeviceContext>(),
|
||||
boxes->data<T>(), boxes->data<T>() + boxes->numel(),
|
||||
boxes->data<T>(), clip_func);
|
||||
}
|
||||
framework::Tensor var_t;
|
||||
var_t.mutable_data<T>(
|
||||
framework::make_ddim({1, static_cast<int>(variances.size())}),
|
||||
ctx.GetPlace());
|
||||
|
||||
auto var_et = framework::EigenTensor<T, 2>::From(var_t);
|
||||
|
||||
for (size_t i = 0; i < variances.size(); ++i) {
|
||||
var_et(0, i) = variances[i];
|
||||
}
|
||||
|
||||
int box_num = feature_height * feature_width * num_priors;
|
||||
auto var_dim = vars->dims();
|
||||
vars->Resize({box_num, static_cast<int>(variances.size())});
|
||||
|
||||
auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
|
||||
|
||||
e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(box_num, 1));
|
||||
|
||||
vars->Resize(var_dim);
|
||||
}
|
||||
}; // namespace operators
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,142 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestDensityPriorBoxOp(OpTest):
|
||||
def set_data(self):
|
||||
self.init_test_params()
|
||||
self.init_test_input()
|
||||
self.init_test_output()
|
||||
self.inputs = {'Input': self.input, 'Image': self.image}
|
||||
|
||||
self.attrs = {
|
||||
'variances': self.variances,
|
||||
'clip': self.clip,
|
||||
'step_w': self.step_w,
|
||||
'step_h': self.step_h,
|
||||
'offset': self.offset,
|
||||
'densities': self.densities,
|
||||
'fixed_sizes': self.fixed_sizes,
|
||||
'fixed_ratios': self.fixed_ratios
|
||||
}
|
||||
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "density_prior_box"
|
||||
self.set_data()
|
||||
|
||||
def set_density(self):
|
||||
self.densities = []
|
||||
self.fixed_sizes = []
|
||||
self.fixed_ratios = []
|
||||
|
||||
def init_test_params(self):
|
||||
self.layer_w = 32
|
||||
self.layer_h = 32
|
||||
|
||||
self.image_w = 40
|
||||
self.image_h = 40
|
||||
|
||||
self.step_w = float(self.image_w) / float(self.layer_w)
|
||||
self.step_h = float(self.image_h) / float(self.layer_h)
|
||||
|
||||
self.input_channels = 2
|
||||
self.image_channels = 3
|
||||
self.batch_size = 10
|
||||
|
||||
self.variances = [0.1, 0.1, 0.2, 0.2]
|
||||
self.variances = np.array(self.variances, dtype=np.float).flatten()
|
||||
|
||||
self.set_density()
|
||||
|
||||
self.clip = True
|
||||
self.num_priors = 0
|
||||
if len(self.fixed_sizes) > 0 and len(self.densities) > 0:
|
||||
for density in self.densities:
|
||||
if len(self.fixed_ratios) > 0:
|
||||
self.num_priors += len(self.fixed_ratios) * (pow(density,
|
||||
2))
|
||||
self.offset = 0.5
|
||||
|
||||
def init_test_input(self):
|
||||
self.image = np.random.random(
|
||||
(self.batch_size, self.image_channels, self.image_w,
|
||||
self.image_h)).astype('float32')
|
||||
|
||||
self.input = np.random.random(
|
||||
(self.batch_size, self.input_channels, self.layer_w,
|
||||
self.layer_h)).astype('float32')
|
||||
|
||||
def init_test_output(self):
|
||||
out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
|
||||
out_boxes = np.zeros(out_dim).astype('float32')
|
||||
out_var = np.zeros(out_dim).astype('float32')
|
||||
|
||||
step_average = int((self.step_w + self.step_h) * 0.5)
|
||||
for h in range(self.layer_h):
|
||||
for w in range(self.layer_w):
|
||||
idx = 0
|
||||
c_x = (w + self.offset) * self.step_w
|
||||
c_y = (h + self.offset) * self.step_h
|
||||
# Generate density prior boxes with fixed size
|
||||
for density, fixed_size in zip(self.densities,
|
||||
self.fixed_sizes):
|
||||
if (len(self.fixed_ratios) > 0):
|
||||
for ar in self.fixed_ratios:
|
||||
shift = int(step_average / density)
|
||||
box_width_ratio = fixed_size * math.sqrt(ar)
|
||||
box_height_ratio = fixed_size / math.sqrt(ar)
|
||||
for di in range(density):
|
||||
for dj in range(density):
|
||||
c_x_temp = c_x - step_average / 2.0 + shift / 2.0 + dj * shift
|
||||
c_y_temp = c_y - step_average / 2.0 + shift / 2.0 + di * shift
|
||||
out_boxes[h, w, idx, :] = [
|
||||
max((c_x_temp - box_width_ratio / 2.0) /
|
||||
self.image_w, 0),
|
||||
max((c_y_temp - box_height_ratio / 2.0)
|
||||
/ self.image_h, 0),
|
||||
min((c_x_temp + box_width_ratio / 2.0) /
|
||||
self.image_w, 1),
|
||||
min((c_y_temp + box_height_ratio / 2.0)
|
||||
/ self.image_h, 1)
|
||||
]
|
||||
idx += 1
|
||||
if self.clip:
|
||||
out_boxes = np.clip(out_boxes, 0.0, 1.0)
|
||||
out_var = np.tile(self.variances,
|
||||
(self.layer_h, self.layer_w, self.num_priors, 1))
|
||||
self.out_boxes = out_boxes.astype('float32')
|
||||
self.out_var = out_var.astype('float32')
|
||||
|
||||
|
||||
class TestDensityPriorBox(TestDensityPriorBoxOp):
|
||||
def set_density(self):
|
||||
self.densities = [3, 4]
|
||||
self.fixed_sizes = [1.0, 2.0]
|
||||
self.fixed_ratios = [1.0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue