parent
dcf3ffd980
commit
ee0113af31
@ -0,0 +1,167 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/prior_box_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class PriorBoxOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
||||
"Input(X) of SequenceSliceOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Image"),
|
||||
"Input(Offset) of SequenceSliceOp should not be null.");
|
||||
|
||||
auto image_dims = ctx->GetInputDim("Image");
|
||||
auto input_dims = ctx->GetInputDim("Input");
|
||||
PADDLE_ENFORCE(image_dims.size() == 4,
|
||||
"The format of input tensor is NCHW.");
|
||||
|
||||
auto min_sizes = ctx->Attrs().Get<std::vector<int>>("min_sizes");
|
||||
auto max_sizes = ctx->Attrs().Get<std::vector<int>>("max_sizes");
|
||||
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
|
||||
auto input_aspect_ratio =
|
||||
ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
|
||||
bool flip = ctx->Attrs().Get<bool>("flip");
|
||||
|
||||
PADDLE_ENFORCE_GT(min_sizes.size(), 0, "must provide min_size.");
|
||||
for (size_t i = 0; i < min_sizes.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i);
|
||||
}
|
||||
|
||||
std::vector<float> aspect_ratios;
|
||||
expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios);
|
||||
|
||||
int num_priors = aspect_ratios.size() * min_sizes.size();
|
||||
if (max_sizes.size() > 0) {
|
||||
PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(),
|
||||
"The length of min_size and max_size must be equal.");
|
||||
for (size_t i = 0; i < min_sizes.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i],
|
||||
"max_size[%d] must be greater than min_size[%d].", i,
|
||||
i);
|
||||
num_priors += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (variances.size() > 1) {
|
||||
PADDLE_ENFORCE_EQ(variances.size(), 4,
|
||||
"Must and only provide 4 variance.");
|
||||
for (size_t i = 0; i < variances.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(variances[i], 0.0,
|
||||
"variance[%d] must be greater than 0.", i);
|
||||
}
|
||||
} else if (variances.size() == 1) {
|
||||
PADDLE_ENFORCE_GT(variances[0], 0.0,
|
||||
"variance[0] must be greater than 0.");
|
||||
}
|
||||
|
||||
const int img_h = ctx->Attrs().Get<int>("img_h");
|
||||
PADDLE_ENFORCE_GT(img_h, 0, "img_h should be larger than 0.");
|
||||
const int img_w = ctx->Attrs().Get<int>("img_w");
|
||||
PADDLE_ENFORCE_GT(img_w, 0, "img_w should be larger than 0.");
|
||||
|
||||
const float step_h = ctx->Attrs().Get<float>("step_h");
|
||||
PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0.");
|
||||
const float step_w = ctx->Attrs().Get<float>("step_w");
|
||||
PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0.");
|
||||
|
||||
const int layer_height = input_dims[3];
|
||||
const int layer_width = input_dims[2];
|
||||
|
||||
std::vector<int64_t> dim_vec(3);
|
||||
// Since all images in a batch has same height and width, we only need to
|
||||
// generate one set of priors which can be shared across all images.
|
||||
dim_vec[0] = 1;
|
||||
// 2 channels. First channel stores the mean of each prior coordinate.
|
||||
// Second channel stores the variance of each prior coordinate.
|
||||
dim_vec[1] = 2;
|
||||
dim_vec[2] = layer_width * layer_height * num_priors * 4;
|
||||
PADDLE_ENFORCE_GT(dim_vec[2], 0,
|
||||
"output_dim[2] must larger than 0."
|
||||
"check your data dims");
|
||||
auto output_dim = framework::make_ddim(dim_vec);
|
||||
ctx->SetOutputDim("Out", output_dim);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::LoDTensor>("Image")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the inputs, output and attributes of the prior_box operator.
//
// Each attribute description is passed as ONE string (adjacent literals are
// concatenated by the compiler). Passing the description as a separate third
// argument would bind it to AddAttr's boolean `generated` parameter instead.
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PriorBoxOpMaker(framework::OpProto* proto,
                  framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(Tensor), "
             "the input feature data of PriorBoxOp.");
    AddInput("Image",
             "(Tensor), "
             "the input image data of PriorBoxOp.");
    AddOutput("Out", "(Tensor), the output prior boxes of PriorBoxOp.");
    AddAttr<std::vector<int>>("min_sizes",
                              "(vector<int>) "
                              "List of min sizes of generated prior boxes.");
    AddAttr<std::vector<int>>("max_sizes",
                              "(vector<int>) "
                              "List of max sizes of generated prior boxes.");
    AddAttr<std::vector<float>>(
        "aspect_ratios",
        "(vector<float>) "
        "List of aspect ratios of generated prior boxes.")
        .SetDefault({});
    AddAttr<std::vector<float>>(
        "variances",
        "(vector<float>) "
        "List of variances to be encoded in prior boxes.")
        .SetDefault({0.1});
    AddAttr<bool>("flip",
                  "(bool) "
                  "Whether to flip aspect ratios.")
        .SetDefault(true);
    AddAttr<bool>("clip",
                  "(bool) "
                  "Whether to clip out-of-boundary boxes.")
        .SetDefault(true);
    AddAttr<int>("img_w",
                 "(int) "
                 "Width of the image, 0 for reading it from Input(Image).")
        .SetDefault(0);
    AddAttr<int>("img_h",
                 "(int) "
                 "Height of the image, 0 for reading it from Input(Image).")
        .SetDefault(0);
    AddAttr<float>("step_w",
                   "Prior boxes step across width, 0 for auto calculation.")
        .SetDefault(0.0);
    AddAttr<float>("step_h",
                   "Prior boxes step across height, 0 for auto calculation.")
        .SetDefault(0.0);
    AddAttr<float>("offset",
                   "(float) "
                   "Prior boxes center offset.")
        .SetDefault(0.5);
    AddComment(R"DOC(
Prior box operator
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
Please get more information from the following papers:
https://arxiv.org/abs/1512.02325.
)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
// Register the prior_box operator (it has no gradient operator) together with
// its CPU kernels for float and double.
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
prior_box, ops::PriorBoxOpKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::PriorBoxOpKernel<paddle::platform::CPUPlace, double>);
|
@ -0,0 +1,20 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/prior_box_op.h"
|
||||
|
||||
// Register the GPU kernels of the prior_box operator for float and double.
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(
|
||||
prior_box, ops::PriorBoxOpKernel<paddle::platform::GPUPlace, float>,
|
||||
ops::PriorBoxOpKernel<paddle::platform::GPUPlace, double>);
|
@ -0,0 +1,199 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
// #include "paddle/operators/strided_memcpy.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Builds the list of aspect ratios used to generate prior boxes.
//
// The result always starts with 1. Each input ratio that is not already in
// the result (within a tolerance of 1e-6) is appended; when `flip` is set its
// reciprocal is appended right after it. Any previous content of
// `output_aspect_ratior` is discarded.
inline void expand_aspect_ratios(const std::vector<float> input_aspect_ratior,
                                 bool flip,
                                 std::vector<float>& output_aspect_ratior) {
  constexpr float kTolerance = 1e-6;

  // True when `ratio` matches a ratio that has already been collected.
  auto already_collected = [&](float ratio) {
    for (float known : output_aspect_ratior) {
      if (fabs(ratio - known) < kTolerance) {
        return true;
      }
    }
    return false;
  };

  output_aspect_ratior.clear();
  output_aspect_ratior.push_back(1.);
  for (const float ratio : input_aspect_ratior) {
    if (already_collected(ratio)) {
      continue;
    }
    output_aspect_ratior.push_back(ratio);
    if (flip) {
      output_aspect_ratior.push_back(1. / ratio);
    }
  }
}
|
||||
|
||||
template <typename Place, typename T>
|
||||
class PriorBoxOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
|
||||
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
|
||||
auto* out = ctx.Output<paddle::framework::Tensor>("Out");
|
||||
|
||||
auto min_sizes = ctx.Attr<std::vector<int>>("min_sizes");
|
||||
auto max_sizes = ctx.Attr<std::vector<int>>("max_sizes");
|
||||
auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
|
||||
auto variances = ctx.Attr<std::vector<float>>("variances");
|
||||
auto flip = ctx.Attr<bool>("flip");
|
||||
auto clip = ctx.Attr<bool>("clip");
|
||||
|
||||
std::vector<float> aspect_ratios;
|
||||
expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios);
|
||||
|
||||
auto img_w = ctx.Attr<int>("img_w");
|
||||
auto img_h = ctx.Attr<int>("img_h");
|
||||
auto step_w = ctx.Attr<float>("step_w");
|
||||
auto step_h = ctx.Attr<float>("step_h");
|
||||
auto offset = ctx.Attr<float>("offset");
|
||||
|
||||
int img_width, img_height;
|
||||
if (img_h == 0 || img_w == 0) {
|
||||
img_width = image->dims()[2];
|
||||
img_height = image->dims()[3];
|
||||
} else {
|
||||
img_width = img_w;
|
||||
img_height = img_h;
|
||||
}
|
||||
|
||||
const int layer_width = input->dims()[2];
|
||||
const int layer_height = input->dims()[3];
|
||||
|
||||
float step_width, step_height;
|
||||
if (step_w == 0 || step_h == 0) {
|
||||
step_width = static_cast<float>(img_width) / layer_width;
|
||||
step_height = static_cast<float>(img_height) / layer_height;
|
||||
} else {
|
||||
step_width = step_w;
|
||||
step_height = step_h;
|
||||
}
|
||||
|
||||
int num_priors = aspect_ratios.size() * min_sizes.size();
|
||||
if (max_sizes.size() > 0) {
|
||||
num_priors += max_sizes.size();
|
||||
}
|
||||
|
||||
int dim = layer_height * layer_width * num_priors * 4;
|
||||
|
||||
T* output_data = nullptr;
|
||||
framework::Tensor output_cpu;
|
||||
out->mutable_data<T>(ctx.GetPlace());
|
||||
if (platform::is_gpu_place(ctx.GetPlace())) {
|
||||
output_data =
|
||||
output_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
|
||||
} else {
|
||||
output_data = out->mutable_data<T>(ctx.GetPlace());
|
||||
}
|
||||
|
||||
int idx = 0;
|
||||
for (int h = 0; h < layer_height; ++h) {
|
||||
for (int w = 0; w < layer_width; ++w) {
|
||||
float center_x = (w + offset) * step_width;
|
||||
float center_y = (h + offset) * step_height;
|
||||
float box_width, box_height;
|
||||
for (size_t s = 0; s < min_sizes.size(); ++s) {
|
||||
int min_size = min_sizes[s];
|
||||
// first prior: aspect_ratio = 1, size = min_size
|
||||
box_width = box_height = min_size;
|
||||
// xmin
|
||||
output_data[idx++] = (center_x - box_width / 2.) / img_width;
|
||||
// ymin
|
||||
output_data[idx++] = (center_y - box_height / 2.) / img_height;
|
||||
// xmax
|
||||
output_data[idx++] = (center_x + box_width / 2.) / img_width;
|
||||
// ymax
|
||||
output_data[idx++] = (center_y + box_height / 2.) / img_height;
|
||||
|
||||
if (max_sizes.size() > 0) {
|
||||
int max_size = max_sizes[s];
|
||||
// second prior: aspect_ratio = 1,
|
||||
// size = sqrt(min_size * max_size)
|
||||
box_width = box_height = sqrt(min_size * max_size);
|
||||
// xmin
|
||||
output_data[idx++] = (center_x - box_width / 2.) / img_width;
|
||||
// ymin
|
||||
output_data[idx++] = (center_y - box_height / 2.) / img_height;
|
||||
// xmax
|
||||
output_data[idx++] = (center_x + box_width / 2.) / img_width;
|
||||
// ymax
|
||||
output_data[idx++] = (center_y + box_height / 2.) / img_height;
|
||||
}
|
||||
|
||||
// rest of priors
|
||||
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
|
||||
float ar = aspect_ratios[r];
|
||||
if (fabs(ar - 1.) < 1e-6) {
|
||||
continue;
|
||||
}
|
||||
box_width = min_size * sqrt(ar);
|
||||
box_height = min_size / sqrt(ar);
|
||||
// xmin
|
||||
output_data[idx++] = (center_x - box_width / 2.) / img_width;
|
||||
// ymin
|
||||
output_data[idx++] = (center_y - box_height / 2.) / img_height;
|
||||
// xmax
|
||||
output_data[idx++] = (center_x + box_width / 2.) / img_width;
|
||||
// ymax
|
||||
output_data[idx++] = (center_y + box_height / 2.) / img_height;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clip the prior's coordidate such that it is within [0, 1]
|
||||
if (clip) {
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
output_data[d] = std::min<T>(std::max<T>(output_data[d], 0.), 1.);
|
||||
}
|
||||
}
|
||||
|
||||
// set the variance.
|
||||
auto output_stride = framework::stride(out->dims());
|
||||
output_data += output_stride[1];
|
||||
if (variances.size() == 1) {
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
output_data[i] = variances[0];
|
||||
}
|
||||
} else {
|
||||
int count = 0;
|
||||
for (int h = 0; h < layer_height; ++h) {
|
||||
for (int w = 0; w < layer_width; ++w) {
|
||||
for (int i = 0; i < num_priors; ++i) {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
output_data[count] = variances[j];
|
||||
++count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (platform::is_gpu_place(ctx.GetPlace())) {
|
||||
framework::CopyFrom(output_cpu, platform::CPUPlace(),
|
||||
ctx.device_context(), out);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,179 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestPriorBoxOp(OpTest):
    """Checks the prior_box operator against a NumPy reference.

    The reference mirrors the operator: for every feature-map cell it emits
    the normalized (xmin, ymin, xmax, ymax) of each prior box into channel 0
    of the output and the variances into channel 1.
    """

    def set_data(self):
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()
        self.inputs = {'Input': self.input, 'Image': self.image}

        self.attrs = {
            'min_sizes': self.min_sizes,
            'max_sizes': self.max_sizes,
            'aspect_ratios': self.aspect_ratios,
            'variances': self.variances,
            'flip': self.flip,
            'clip': self.clip,
            'step_w': self.step_w,
            'step_h': self.step_h,
            'img_w': self.image_w,
            'img_h': self.image_h,
            'offset': self.offset
        }

        self.outputs = {'Out': self.output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # prior_box is registered without a gradient op: nothing to check.
        return

    def setUp(self):
        self.op_type = "prior_box"
        self.set_data()

    def init_test_params(self):
        self.layer_w = 4
        self.layer_h = 4

        self.image_w = 20
        self.image_h = 20

        self.step_w = float(self.image_w) / float(self.layer_w)
        self.step_h = float(self.image_h) / float(self.layer_h)

        self.input_channels = 2
        self.image_channels = 3
        self.batch_size = 10

        self.min_sizes = np.array([2, 4]).astype('int64')
        self.max_sizes = np.array([5, 10]).astype('int64')
        self.flip = True
        # Expected expansion of aspect_ratios [2.0, 3.0] with flip enabled:
        # 1 first, then each ratio followed by its reciprocal.
        self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
        # np.float64 is what the removed alias np.float used to mean.
        self.aspect_ratios = np.array([2.0, 3.0], dtype=np.float64).flatten()
        self.variances = np.array(
            [0.1, 0.1, 0.2, 0.2], dtype=np.float64).flatten()

        self.clip = True

        self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
        # The operator adds one prior per max size as soon as any max size is
        # given, so a single max size counts as well.
        if len(self.max_sizes) > 0:
            self.num_priors += len(self.max_sizes)
        self.offset = 0.5

    def init_test_input(self):
        # Both tensors are NCHW: height comes before width.
        self.image = np.random.random(
            (self.batch_size, self.image_channels, self.image_h,
             self.image_w)).astype('float32')

        self.input = np.random.random(
            (self.batch_size, self.input_channels, self.layer_h,
             self.layer_w)).astype('float32')

    def _box(self, center_x, center_y, box_width, box_height):
        """Returns [xmin, ymin, xmax, ymax] normalized by the image size."""
        return [
            (center_x - box_width / 2.) / self.image_w,
            (center_y - box_height / 2.) / self.image_h,
            (center_x + box_width / 2.) / self.image_w,
            (center_y + box_height / 2.) / self.image_h,
        ]

    def init_test_output(self):
        dim = self.layer_w * self.layer_h * self.num_priors * 4
        out_dim = (1, 2, dim)
        output = np.zeros(out_dim).astype('float32')

        coords = []
        for h in range(self.layer_h):
            for w in range(self.layer_w):
                center_x = (w + self.offset) * self.step_w
                center_y = (h + self.offset) * self.step_h
                for s in range(len(self.min_sizes)):
                    min_size = self.min_sizes[s]
                    # first prior: aspect_ratio = 1, size = min_size
                    coords.extend(
                        self._box(center_x, center_y, min_size, min_size))

                    if len(self.max_sizes) > 0:
                        max_size = self.max_sizes[s]
                        # second prior: aspect_ratio = 1,
                        # size = sqrt(min_size * max_size)
                        side = math.sqrt(min_size * max_size)
                        coords.extend(
                            self._box(center_x, center_y, side, side))

                    # rest of priors; ratio 1 is covered by the first prior
                    for ar in self.real_aspect_ratios:
                        if math.fabs(ar - 1.) < 1e-6:
                            continue
                        box_width = min_size * math.sqrt(ar)
                        box_height = min_size / math.sqrt(ar)
                        coords.extend(
                            self._box(center_x, center_y, box_width,
                                      box_height))

        # Raises if the number of generated values disagrees with num_priors.
        output[0, 0, :] = coords

        # clip the prior's coordinate such that it is within [0, 1]
        if self.clip:
            output[0, 0, :] = np.clip(output[0, 0, :], 0, 1)

        # set the variance.
        if len(self.variances) == 1:
            output[0, 1, :] = self.variances[0]
        else:
            # Four variances, one per box coordinate, repeated for each prior.
            output[0, 1, :] = np.tile(self.variances, dim // 4)
        self.output = output.astype('float32')
|
||||
|
||||
|
||||
# Run the unit tests when this file is executed as a script.
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue