FasterRCNN Anchor Generator Op (#11218)
* Add anchor generator operator for Faster-RCNN. * Add unittest testing. * Add Python API.analysis/code-clean
parent
5f79c7fbb6
commit
5056d3ec56
@ -0,0 +1,154 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/detection/anchor_generator_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class AnchorGeneratorOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
||||
"Input(Input) of AnchorGeneratorOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Anchors"),
|
||||
"Output(Anchors) of AnchorGeneratorOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("Variances"),
|
||||
"Output(Variances) of AnchorGeneratorOp should not be null.");
|
||||
|
||||
auto input_dims = ctx->GetInputDim("Input");
|
||||
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
||||
|
||||
auto anchor_sizes = ctx->Attrs().Get<std::vector<float>>("anchor_sizes");
|
||||
auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
|
||||
auto stride = ctx->Attrs().Get<std::vector<float>>("stride");
|
||||
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
|
||||
|
||||
size_t num_anchors = aspect_ratios.size() * anchor_sizes.size();
|
||||
|
||||
std::vector<int64_t> dim_vec(4);
|
||||
dim_vec[0] = input_dims[2];
|
||||
dim_vec[1] = input_dims[3];
|
||||
dim_vec[2] = num_anchors;
|
||||
dim_vec[3] = 4;
|
||||
ctx->SetOutputDim("Anchors", framework::make_ddim(dim_vec));
|
||||
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class AnchorGeneratorOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Input",
|
||||
"(Tensor, default Tensor<float>), "
|
||||
"the input feature is a tensor with a rank of 4. "
|
||||
"The layout is NCHW.");
|
||||
AddOutput("Anchors",
|
||||
"(Tensor, default Tensor<float>), the output is a "
|
||||
"tensor with a rank of 4. The layout is [H, W, num_anchors, 4]. "
|
||||
"H is the height of input, W is the width of input, num_anchors "
|
||||
"is the box count of each position. "
|
||||
"Each anchor is in (xmin, ymin, xmax, ymax) format");
|
||||
AddOutput("Variances",
|
||||
"(Tensor, default Tensor<float>), the expanded variances for "
|
||||
"normalizing bbox regression targets. The layout is [H, W, "
|
||||
"num_anchors, 4]. "
|
||||
"H is the height of input, W is the width of input, num_anchors "
|
||||
"is the box count of each position. "
|
||||
"Each variance is in (xcenter, ycenter, w, h) format");
|
||||
|
||||
AddAttr<std::vector<float>>(
|
||||
"anchor_sizes",
|
||||
"(vector<float>) List of Region Proposal Network(RPN) anchor sizes "
|
||||
" given in absolute pixels e.g. (64, 128, 256, 512)."
|
||||
" For instance, the anchor size of 64 means the area of this anchor "
|
||||
"equals to 64**2.")
|
||||
.AddCustomChecker([](const std::vector<float>& anchor_sizes) {
|
||||
PADDLE_ENFORCE_GT(anchor_sizes.size(), 0,
|
||||
"Size of anchor_sizes must be at least 1.");
|
||||
for (size_t i = 0; i < anchor_sizes.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(anchor_sizes[i], 0.0,
|
||||
"anchor_sizes[%d] must be positive.", i);
|
||||
}
|
||||
});
|
||||
AddAttr<std::vector<float>>(
|
||||
"aspect_ratios",
|
||||
"(vector<float>) List of Region Proposal Network(RPN) anchor aspect "
|
||||
"ratios, e.g. (0.5, 1, 2)."
|
||||
"For instacne, the aspect ratio of 0.5 means the height / width of "
|
||||
"this anchor equals 0.5.");
|
||||
|
||||
AddAttr<std::vector<float>>("variances",
|
||||
"(vector<float>) List of variances to be used "
|
||||
"in box regression deltas")
|
||||
.AddCustomChecker([](const std::vector<float>& variances) {
|
||||
PADDLE_ENFORCE_EQ(variances.size(), 4,
|
||||
"Must and only provide 4 variance.");
|
||||
for (size_t i = 0; i < variances.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(variances[i], 0.0,
|
||||
"variance[%d] must be greater than 0.", i);
|
||||
}
|
||||
});
|
||||
|
||||
AddAttr<std::vector<float>>("stride",
|
||||
"Anchors stride across width and height, "
|
||||
"with a default of (16, 16)")
|
||||
.SetDefault(std::vector<float>(2, 16.0))
|
||||
.AddCustomChecker([](const std::vector<float>& stride) {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
stride.size(), 2,
|
||||
"Must and only provide 2 stride for width and height.");
|
||||
for (size_t i = 0; i < stride.size(); ++i) {
|
||||
PADDLE_ENFORCE_GT(stride[i], 0.0,
|
||||
"stride[%d] should be larger than 0.", i);
|
||||
}
|
||||
});
|
||||
|
||||
AddAttr<float>("offset",
|
||||
"(float) "
|
||||
"Anchor center offset, with a default of 0.5")
|
||||
.SetDefault(0.5);
|
||||
AddComment(R"DOC(
|
||||
AnchorGenerator operator
|
||||
Generates anchors for Faster RCNN, FPN etc. algorithm.
|
||||
Each position of the input produce N anchors, N =
|
||||
size(anchor_sizes) * size(aspect_ratios).
|
||||
|
||||
Please get more information from the following papers:
|
||||
https://arxiv.org/abs/1506.01497.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(anchor_generator, ops::AnchorGeneratorOp,
|
||||
ops::AnchorGeneratorOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(anchor_generator, ops::AnchorGeneratorOpKernel<float>,
|
||||
ops::AnchorGeneratorOpKernel<double>);
|
@ -0,0 +1,132 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/detection/anchor_generator_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
__global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num,
|
||||
const T* anchor_sizes, const int as_num,
|
||||
const T* stride, const int sd_num, const int height,
|
||||
const int width, const T offset) {
|
||||
int num_anchors = as_num * ar_num;
|
||||
int box_num = height * width * num_anchors;
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
|
||||
i += blockDim.x * gridDim.x) {
|
||||
int h_idx = i / (num_anchors * width);
|
||||
int w_idx = (i / num_anchors) % width;
|
||||
T stride_width = stride[0];
|
||||
T stride_height = stride[1];
|
||||
T x_ctr = (w_idx * stride_width) + offset * (stride_width - 1);
|
||||
T y_ctr = (h_idx * stride_height) + offset * (stride_height - 1);
|
||||
T area, area_ratios;
|
||||
T base_w, base_h;
|
||||
T scale_w, scale_h;
|
||||
T anchor_width, anchor_height;
|
||||
int anch_idx = i % num_anchors;
|
||||
int ar_idx = anch_idx / as_num;
|
||||
int as_idx = anch_idx % as_num;
|
||||
T aspect_ratio = aspect_ratios[ar_idx];
|
||||
T anchor_size = anchor_sizes[as_idx];
|
||||
area = stride_width * stride_height;
|
||||
area_ratios = area / aspect_ratio;
|
||||
base_w = round(sqrt(area_ratios));
|
||||
base_h = round(base_w * aspect_ratio);
|
||||
scale_w = anchor_size / stride_width;
|
||||
scale_h = anchor_size / stride_height;
|
||||
anchor_width = scale_w * base_w;
|
||||
anchor_height = scale_h * base_h;
|
||||
|
||||
T xmin = (x_ctr - 0.5 * (anchor_width - 1));
|
||||
T ymin = (y_ctr - 0.5 * (anchor_height - 1));
|
||||
T xmax = (x_ctr + 0.5 * (anchor_width - 1));
|
||||
T ymax = (y_ctr + 0.5 * (anchor_height - 1));
|
||||
out[i * 4] = xmin;
|
||||
out[i * 4 + 1] = ymin;
|
||||
out[i * 4 + 2] = xmax;
|
||||
out[i * 4 + 3] = ymax;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void SetVariance(T* out, const T* var, const int vnum,
|
||||
const int num) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
|
||||
i += blockDim.x * gridDim.x) {
|
||||
out[i] = var[i % vnum];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class AnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
|
||||
auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
|
||||
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
|
||||
|
||||
auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes");
|
||||
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
|
||||
auto stride = ctx.Attr<std::vector<float>>("stride");
|
||||
auto variances = ctx.Attr<std::vector<float>>("variances");
|
||||
|
||||
T offset = static_cast<T>(ctx.Attr<float>("offset"));
|
||||
|
||||
auto width = input->dims()[3];
|
||||
auto height = input->dims()[2];
|
||||
|
||||
int num_anchors = aspect_ratios.size() * anchor_sizes.size();
|
||||
|
||||
int box_num = width * height * num_anchors;
|
||||
|
||||
int block = 512;
|
||||
int grid = (box_num + block - 1) / block;
|
||||
|
||||
auto stream =
|
||||
ctx.template device_context<platform::CUDADeviceContext>().stream();
|
||||
|
||||
anchors->mutable_data<T>(ctx.GetPlace());
|
||||
vars->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
framework::Tensor ar;
|
||||
framework::TensorFromVector(aspect_ratios, ctx.device_context(), &ar);
|
||||
|
||||
framework::Tensor as;
|
||||
framework::TensorFromVector(anchor_sizes, ctx.device_context(), &as);
|
||||
|
||||
framework::Tensor sd;
|
||||
framework::TensorFromVector(stride, ctx.device_context(), &sd);
|
||||
|
||||
GenAnchors<T><<<grid, block, 0, stream>>>(
|
||||
anchors->data<T>(), ar.data<T>(), aspect_ratios.size(), as.data<T>(),
|
||||
anchor_sizes.size(), sd.data<T>(), stride.size(), height, width,
|
||||
offset);
|
||||
|
||||
framework::Tensor v;
|
||||
framework::TensorFromVector(variances, ctx.device_context(), &v);
|
||||
grid = (box_num * 4 + block - 1) / block;
|
||||
SetVariance<T><<<grid, block, 0, stream>>>(vars->data<T>(), v.data<T>(),
|
||||
variances.size(), box_num * 4);
|
||||
}
|
||||
}; // namespace operators
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(anchor_generator,
|
||||
ops::AnchorGeneratorOpCUDAKernel<float>,
|
||||
ops::AnchorGeneratorOpCUDAKernel<double>);
|
@ -0,0 +1,109 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
#include "paddle/fluid/platform/transform.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
class AnchorGeneratorOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
|
||||
auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
|
||||
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
|
||||
|
||||
auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes");
|
||||
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
|
||||
auto stride = ctx.Attr<std::vector<float>>("stride");
|
||||
auto variances = ctx.Attr<std::vector<float>>("variances");
|
||||
|
||||
T offset = static_cast<T>(ctx.Attr<float>("offset"));
|
||||
|
||||
auto feature_width = input->dims()[3];
|
||||
auto feature_height = input->dims()[2];
|
||||
|
||||
T stride_width, stride_height;
|
||||
stride_width = stride[0];
|
||||
stride_height = stride[1];
|
||||
|
||||
int num_anchors = aspect_ratios.size() * anchor_sizes.size();
|
||||
|
||||
anchors->mutable_data<T>(ctx.GetPlace());
|
||||
vars->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
auto e_anchors = framework::EigenTensor<T, 4>::From(*anchors);
|
||||
for (int h_idx = 0; h_idx < feature_height; ++h_idx) {
|
||||
for (int w_idx = 0; w_idx < feature_width; ++w_idx) {
|
||||
T x_ctr = (w_idx * stride_width) + offset * (stride_width - 1);
|
||||
T y_ctr = (h_idx * stride_height) + offset * (stride_height - 1);
|
||||
T area, area_ratios;
|
||||
T base_w, base_h;
|
||||
T scale_w, scale_h;
|
||||
T anchor_width, anchor_height;
|
||||
int idx = 0;
|
||||
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
|
||||
auto ar = aspect_ratios[r];
|
||||
for (size_t s = 0; s < anchor_sizes.size(); ++s) {
|
||||
auto anchor_size = anchor_sizes[s];
|
||||
area = stride_width * stride_height;
|
||||
area_ratios = area / ar;
|
||||
base_w = round(sqrt(area_ratios));
|
||||
base_h = round(base_w * ar);
|
||||
scale_w = anchor_size / stride_width;
|
||||
scale_h = anchor_size / stride_height;
|
||||
anchor_width = scale_w * base_w;
|
||||
anchor_height = scale_h * base_h;
|
||||
e_anchors(h_idx, w_idx, idx, 0) =
|
||||
(x_ctr - 0.5 * (anchor_width - 1));
|
||||
e_anchors(h_idx, w_idx, idx, 1) =
|
||||
(y_ctr - 0.5 * (anchor_height - 1));
|
||||
e_anchors(h_idx, w_idx, idx, 2) =
|
||||
(x_ctr + 0.5 * (anchor_width - 1));
|
||||
e_anchors(h_idx, w_idx, idx, 3) =
|
||||
(y_ctr + 0.5 * (anchor_height - 1));
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
framework::Tensor var_t;
|
||||
var_t.mutable_data<T>(
|
||||
framework::make_ddim({1, static_cast<int>(variances.size())}),
|
||||
ctx.GetPlace());
|
||||
auto var_et = framework::EigenTensor<T, 2>::From(var_t);
|
||||
for (size_t i = 0; i < variances.size(); ++i) {
|
||||
var_et(0, i) = variances[i];
|
||||
}
|
||||
|
||||
int anchor_num = feature_height * feature_width * num_anchors;
|
||||
auto var_dim = vars->dims();
|
||||
vars->Resize({anchor_num, static_cast<int>(variances.size())});
|
||||
|
||||
auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
|
||||
e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(anchor_num, 1));
|
||||
|
||||
vars->Resize(var_dim);
|
||||
}
|
||||
}; // namespace operators
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,110 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://w_idxw.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def anchor_generator_in_python(input_feat, anchor_sizes, aspect_ratios,
|
||||
variances, stride, offset):
|
||||
num_anchors = len(aspect_ratios) * len(anchor_sizes)
|
||||
layer_h = input_feat.shape[2]
|
||||
layer_w = input_feat.shape[3]
|
||||
out_dim = (layer_h, layer_w, num_anchors, 4)
|
||||
out_anchors = np.zeros(out_dim).astype('float32')
|
||||
|
||||
for h_idx in range(layer_h):
|
||||
for w_idx in range(layer_w):
|
||||
x_ctr = (w_idx * stride[0]) + offset * (stride[0] - 1)
|
||||
y_ctr = (h_idx * stride[1]) + offset * (stride[1] - 1)
|
||||
idx = 0
|
||||
for r in range(len(aspect_ratios)):
|
||||
ar = aspect_ratios[r]
|
||||
for s in range(len(anchor_sizes)):
|
||||
anchor_size = anchor_sizes[s]
|
||||
area = stride[0] * stride[1]
|
||||
area_ratios = area / ar
|
||||
base_w = np.round(np.sqrt(area_ratios))
|
||||
base_h = np.round(base_w * ar)
|
||||
scale_w = anchor_size / stride[0]
|
||||
scale_h = anchor_size / stride[1]
|
||||
w = scale_w * base_w
|
||||
h = scale_h * base_h
|
||||
out_anchors[h_idx, w_idx, idx, :] = [
|
||||
(x_ctr - 0.5 * (w - 1)), (y_ctr - 0.5 * (h - 1)),
|
||||
(x_ctr + 0.5 * (w - 1)), (y_ctr + 0.5 * (h - 1))
|
||||
]
|
||||
idx += 1
|
||||
|
||||
# set the variance.
|
||||
out_var = np.tile(variances, (layer_h, layer_w, num_anchors, 1))
|
||||
out_anchors = out_anchors.astype('float32')
|
||||
out_var = out_var.astype('float32')
|
||||
return out_anchors, out_var
|
||||
|
||||
|
||||
class TestAnchorGeneratorOp(OpTest):
|
||||
def set_data(self):
|
||||
self.init_test_params()
|
||||
self.init_test_input()
|
||||
self.init_test_output()
|
||||
self.inputs = {'Input': self.input}
|
||||
|
||||
self.attrs = {
|
||||
'anchor_sizes': self.anchor_sizes,
|
||||
'aspect_ratios': self.aspect_ratios,
|
||||
'stride': self.stride,
|
||||
'offset': self.offset,
|
||||
'variances': self.variances,
|
||||
}
|
||||
|
||||
self.outputs = {'Anchors': self.out_anchors, 'Variances': self.out_var}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "anchor_generator"
|
||||
self.set_data()
|
||||
|
||||
def init_test_params(self):
|
||||
self.batch_size = 1
|
||||
self.input_channels = 2
|
||||
self.layer_h = 2
|
||||
self.layer_w = 2
|
||||
|
||||
self.anchor_sizes = [64., 128., 256., 512.]
|
||||
self.aspect_ratios = [0.5, 1., 2.]
|
||||
self.stride = [16., 16.]
|
||||
|
||||
self.offset = 0.5
|
||||
|
||||
self.variances = [0.1, 0.1, 0.2, 0.2]
|
||||
|
||||
def init_test_input(self):
|
||||
self.input = np.random.random(
|
||||
(self.batch_size, self.input_channels, self.layer_h,
|
||||
self.layer_w)).astype('float32')
|
||||
|
||||
def init_test_output(self):
|
||||
self.out_anchors, self.out_var = anchor_generator_in_python(
|
||||
self.input, self.anchor_sizes, self.aspect_ratios, self.variances,
|
||||
self.stride, self.offset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue