parent
430fdc52a8
commit
72eccb238e
@ -0,0 +1,106 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/box_coder_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class BoxCoderOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
|
||||
"Input(PriorBox) of BoxCoderOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("PriorBoxVar"),
|
||||
"Input(PriorBoxVar) of BoxCoderOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
|
||||
"Input(TargetBox) of BoxCoderOp should not be null.");
|
||||
|
||||
auto prior_box_dims = ctx->GetInputDim("PriorBox");
|
||||
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
|
||||
auto target_box_dims = ctx->GetInputDim("TargetBox");
|
||||
|
||||
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2UL,
|
||||
"The shape of PriorBox is [N, 4]");
|
||||
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4UL,
|
||||
"The shape of PriorBox is [N, 4]");
|
||||
PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 2UL,
|
||||
"The shape of PriorBoxVar is [N, 4]");
|
||||
PADDLE_ENFORCE_EQ(prior_box_var_dims[1], 4UL,
|
||||
"The shape of PriorBoxVar is [N, 4]");
|
||||
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2UL,
|
||||
"The shape of TargetBox is [M, 4]");
|
||||
PADDLE_ENFORCE_EQ(target_box_dims[1], 4UL,
|
||||
"The shape of TargetBox is [M, 4]");
|
||||
|
||||
GetBoxCodeType(ctx->Attrs().Get<std::string>("code_type"));
|
||||
|
||||
ctx->SetOutputDim("OutputBox", framework::make_ddim({target_box_dims[0],
|
||||
target_box_dims[1]}));
|
||||
}
|
||||
};
|
||||
|
||||
class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput(
|
||||
"PriorBox",
|
||||
"(Tensor, default Tensor<float>) "
|
||||
"Box list PriorBox is a 2-D Tensor with shape [M, 4] holds N boxes, "
|
||||
"each box is represented as [xmin, ymin, xmax, ymax], "
|
||||
"[xmin, ymin] is the left top coordinate of the anchor box, "
|
||||
"if the input is image feature map, they are close to the origin "
|
||||
"of the coordinate system. [xmax, ymax] is the right bottom "
|
||||
"coordinate of the anchor box.");
|
||||
AddInput("PriorBoxVar",
|
||||
"(Tensor, default Tensor<float>) "
|
||||
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds N group "
|
||||
"of variance.");
|
||||
AddInput(
|
||||
"TargetBox",
|
||||
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
|
||||
"[N, 4], each box is represented as [xmin, ymin, xmax, ymax], "
|
||||
"[xmin, ymin] is the left top coordinate of the box if the input "
|
||||
"is image feature map, they are close to the origin of the coordinate "
|
||||
"system. [xmax, ymax] is the right bottom coordinate of the box. "
|
||||
"This tensor can contain LoD information to represent a batch "
|
||||
"of inputs. One instance of this batch can contain different "
|
||||
"numbers of entities.");
|
||||
AddAttr<std::string>("code_type",
|
||||
"(string, default encode_center_size) "
|
||||
"the code type used with the target box")
|
||||
.SetDefault("encode_center_size")
|
||||
.InEnum({"encode_center_size", "decode_center_size"});
|
||||
AddOutput(
|
||||
"OutputBox",
|
||||
"(Tensor, default Tensor<float>)"
|
||||
"(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] "
|
||||
"representing the result of N target boxes encoded/decoded with "
|
||||
"M Prior boxes and variances.");
|
||||
|
||||
AddComment(R"DOC(
|
||||
Bounding Box Coder Operator.
|
||||
Encode/Decode the priorbox information with the target bounding box.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(box_coder, ops::BoxCoderKernel<float>,
|
||||
ops::BoxCoderKernel<double>);
|
@ -0,0 +1,145 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/box_coder_op.h"
|
||||
#include "paddle/platform/cuda_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using platform::PADDLE_CUDA_NUM_THREADS;
|
||||
|
||||
template <typename T>
|
||||
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
|
||||
const T* prior_box_var_data,
|
||||
const T* target_box_data, int row,
|
||||
int col, T* output) {
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx < row * col) {
|
||||
const int row_idx = idx / col;
|
||||
const int col_idx = idx % col;
|
||||
T prior_box_width =
|
||||
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
|
||||
T prior_box_height =
|
||||
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
|
||||
T prior_box_center_x =
|
||||
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
|
||||
T prior_box_center_y =
|
||||
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;
|
||||
|
||||
T target_box_center_x =
|
||||
(target_box_data[row_idx * 4 + 2] + target_box_data[row_idx * 4]) / 2;
|
||||
T target_box_center_y =
|
||||
(target_box_data[row_idx * 4 + 3] + target_box_data[row_idx * 4 + 1]) /
|
||||
2;
|
||||
T target_box_width =
|
||||
target_box_data[row_idx * 4 + 2] - target_box_data[row_idx * 4];
|
||||
T target_box_height =
|
||||
target_box_data[row_idx * 4 + 3] - target_box_data[row_idx * 4 + 1];
|
||||
|
||||
output[idx * 4] = (target_box_center_x - prior_box_center_x) /
|
||||
prior_box_width / prior_box_var_data[col_idx * 4];
|
||||
output[idx * 4 + 1] = (target_box_center_y - prior_box_center_y) /
|
||||
prior_box_height /
|
||||
prior_box_var_data[col_idx * 4 + 1];
|
||||
output[idx * 4 + 2] = log(fabs(target_box_width / prior_box_width)) /
|
||||
prior_box_var_data[col_idx * 4 + 2];
|
||||
output[idx * 4 + 3] = log(fabs(target_box_height / prior_box_height)) /
|
||||
prior_box_var_data[col_idx * 4 + 3];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
|
||||
const T* prior_box_var_data,
|
||||
const T* target_box_data, int row,
|
||||
int col, T* output) {
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx < row * col) {
|
||||
const int row_idx = idx / col;
|
||||
const int col_idx = idx % col;
|
||||
T prior_box_width =
|
||||
prior_box_data[col_idx * 4 + 2] - prior_box_data[col_idx * 4];
|
||||
T prior_box_height =
|
||||
prior_box_data[col_idx * 4 + 3] - prior_box_data[col_idx * 4 + 1];
|
||||
T prior_box_center_x =
|
||||
(prior_box_data[col_idx * 4 + 2] + prior_box_data[col_idx * 4]) / 2;
|
||||
T prior_box_center_y =
|
||||
(prior_box_data[col_idx * 4 + 3] + prior_box_data[col_idx * 4 + 1]) / 2;
|
||||
|
||||
T target_box_width = exp(prior_box_var_data[col_idx * 4 + 2] *
|
||||
target_box_data[row_idx * 4 + 2]) *
|
||||
prior_box_width;
|
||||
T target_box_height = exp(prior_box_var_data[col_idx * 4 + 3] *
|
||||
target_box_data[row_idx * 4 + 3]) *
|
||||
prior_box_height;
|
||||
T target_box_center_x = prior_box_var_data[col_idx * 4] *
|
||||
target_box_data[row_idx * 4] * prior_box_width +
|
||||
prior_box_center_x;
|
||||
T target_box_center_y = prior_box_var_data[col_idx * 4 + 1] *
|
||||
target_box_data[row_idx * 4 + 1] *
|
||||
prior_box_height +
|
||||
prior_box_center_y;
|
||||
|
||||
output[idx * 4] = target_box_center_x - target_box_width / 2;
|
||||
output[idx * 4 + 1] = target_box_center_y - target_box_height / 2;
|
||||
output[idx * 4 + 2] = target_box_center_x + target_box_width / 2;
|
||||
output[idx * 4 + 3] = target_box_center_y + target_box_height / 2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class BoxCoderCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
auto* prior_box = context.Input<framework::Tensor>("PriorBox");
|
||||
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
|
||||
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
|
||||
auto* output_box = context.Output<Tensor>("OutputBox");
|
||||
|
||||
if (target_box->lod().size()) {
|
||||
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
|
||||
"Only support 1 level of LoD.");
|
||||
}
|
||||
auto row = target_box->dims()[0];
|
||||
auto col = prior_box->dims()[0];
|
||||
int block = 512;
|
||||
int grid = (row * col + block - 1) / block;
|
||||
auto& device_ctx = context.cuda_device_context();
|
||||
|
||||
const T* prior_box_data = prior_box->data<T>();
|
||||
const T* prior_box_var_data = prior_box_var->data<T>();
|
||||
const T* target_box_data = target_box->data<T>();
|
||||
|
||||
output_box->mutable_data<T>({row, col, 4}, context.GetPlace());
|
||||
T* output = output_box->data<T>();
|
||||
|
||||
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
|
||||
if (code_type == BoxCodeType::kEncodeCenterSize) {
|
||||
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
|
||||
prior_box_data, prior_box_var_data, target_box_data, row, col,
|
||||
output);
|
||||
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
|
||||
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
|
||||
prior_box_data, prior_box_var_data, target_box_data, row, col,
|
||||
output);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(box_coder, ops::BoxCoderCUDAKernel<float>,
|
||||
ops::BoxCoderCUDAKernel<double>);
|
@ -0,0 +1,163 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
enum class BoxCodeType { kEncodeCenterSize = 0, kDecodeCenterSize = 1 };
|
||||
|
||||
inline BoxCodeType GetBoxCodeType(const std::string& type) {
|
||||
if (type == "encode_center_size") {
|
||||
return BoxCodeType::kEncodeCenterSize;
|
||||
} else if (type == "decode_center_size") {
|
||||
return BoxCodeType::kDecodeCenterSize;
|
||||
}
|
||||
PADDLE_THROW("Not support type %s.", type);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class BoxCoderKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void EncodeCenterSize(const Tensor& target_box, const Tensor& prior_box,
|
||||
const Tensor& prior_box_var, T* output) const {
|
||||
PADDLE_ENFORCE_EQ(target_box.dims().size(), 2,
|
||||
"The rank of target_box must be 2.");
|
||||
PADDLE_ENFORCE_EQ(prior_box.dims().size(), 2,
|
||||
"The rank of prior_box must be 2.");
|
||||
PADDLE_ENFORCE_EQ(prior_box_var.dims().size(), 2,
|
||||
"The rank of prior_box_var must be 2.");
|
||||
PADDLE_ENFORCE_EQ(prior_box.dims()[0], prior_box_var.dims()[0],
|
||||
"The dims of prior_box must equal to prior_box_var.");
|
||||
|
||||
int64_t row = target_box.dims()[0];
|
||||
int64_t col = prior_box.dims()[0];
|
||||
auto* target_box_data = target_box.data<T>();
|
||||
auto* prior_box_data = prior_box.data<T>();
|
||||
auto* prior_box_var_data = prior_box_var.data<T>();
|
||||
|
||||
for (int64_t i = 0; i < row; ++i) {
|
||||
for (int64_t j = 0; j < col; ++j) {
|
||||
T prior_box_width = prior_box_data[j * 4 + 2] - prior_box_data[j * 4];
|
||||
T prior_box_height =
|
||||
prior_box_data[j * 4 + 3] - prior_box_data[j * 4 + 1];
|
||||
T prior_box_center_x =
|
||||
(prior_box_data[j * 4 + 2] + prior_box_data[j * 4]) / 2;
|
||||
T prior_box_center_y =
|
||||
(prior_box_data[j * 4 + 3] + prior_box_data[j * 4 + 1]) / 2;
|
||||
|
||||
T target_box_center_x =
|
||||
(target_box_data[i * 4 + 2] + target_box_data[i * 4]) / 2;
|
||||
T target_box_center_y =
|
||||
(target_box_data[i * 4 + 3] + target_box_data[i * 4 + 1]) / 2;
|
||||
T target_box_width =
|
||||
target_box_data[i * 4 + 2] - target_box_data[i * 4];
|
||||
T target_box_height =
|
||||
target_box_data[i * 4 + 3] - target_box_data[i * 4 + 1];
|
||||
|
||||
size_t offset = i * col * 4 + j * 4;
|
||||
output[offset] = (target_box_center_x - prior_box_center_x) /
|
||||
prior_box_width / prior_box_var_data[j * 4];
|
||||
output[offset + 1] = (target_box_center_y - prior_box_center_y) /
|
||||
prior_box_height / prior_box_var_data[j * 4 + 1];
|
||||
output[offset + 2] =
|
||||
std::log(std::fabs(target_box_width / prior_box_width)) /
|
||||
prior_box_var_data[j * 4 + 2];
|
||||
output[offset + 3] =
|
||||
std::log(std::fabs(target_box_height / prior_box_height)) /
|
||||
prior_box_var_data[j * 4 + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
void DecodeCenterSize(const Tensor& target_box, const Tensor& prior_box,
|
||||
const Tensor& prior_box_var, T* output) const {
|
||||
PADDLE_ENFORCE_EQ(target_box.dims().size(), 2,
|
||||
"The rank of target_box must be 2.");
|
||||
PADDLE_ENFORCE_EQ(prior_box.dims().size(), 2,
|
||||
"The rank of prior_box must be 2.");
|
||||
PADDLE_ENFORCE_EQ(prior_box_var.dims().size(), 2,
|
||||
"The rank of prior_box_var must be 2.");
|
||||
PADDLE_ENFORCE_EQ(prior_box.dims()[0], prior_box_var.dims()[0],
|
||||
"The dims of prior_box must equal to prior_box_var.");
|
||||
|
||||
int64_t row = target_box.dims()[0];
|
||||
int64_t col = prior_box.dims()[0];
|
||||
|
||||
auto* target_box_data = target_box.data<T>();
|
||||
auto* prior_box_data = prior_box.data<T>();
|
||||
auto* prior_box_var_data = prior_box_var.data<T>();
|
||||
|
||||
for (int64_t i = 0; i < row; ++i) {
|
||||
for (int64_t j = 0; j < col; ++j) {
|
||||
T prior_box_width = prior_box_data[j * 4 + 2] - prior_box_data[j * 4];
|
||||
T prior_box_height =
|
||||
prior_box_data[j * 4 + 3] - prior_box_data[j * 4 + 1];
|
||||
T prior_box_center_x =
|
||||
(prior_box_data[j * 4 + 2] + prior_box_data[j * 4]) / 2;
|
||||
T prior_box_center_y =
|
||||
(prior_box_data[j * 4 + 3] + prior_box_data[j * 4 + 1]) / 2;
|
||||
|
||||
T target_box_center_x = prior_box_var_data[j * 4] *
|
||||
target_box_data[i * 4] * prior_box_width +
|
||||
prior_box_center_x;
|
||||
T target_box_center_y = prior_box_var_data[j * 4 + 1] *
|
||||
target_box_data[i * 4 + 1] *
|
||||
prior_box_height +
|
||||
prior_box_center_y;
|
||||
T target_box_width = std::exp(prior_box_var_data[j * 4 + 2] *
|
||||
target_box_data[i * 4 + 2]) *
|
||||
prior_box_width;
|
||||
T target_box_height = std::exp(prior_box_var_data[j * 4 + 3] *
|
||||
target_box_data[i * 4 + 3]) *
|
||||
prior_box_height;
|
||||
|
||||
size_t offset = i * col * 4 + j * 4;
|
||||
output[offset] = target_box_center_x - target_box_width / 2;
|
||||
output[offset + 1] = target_box_center_y - target_box_height / 2;
|
||||
output[offset + 2] = target_box_center_x + target_box_width / 2;
|
||||
output[offset + 3] = target_box_center_y + target_box_height / 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* prior_box = context.Input<framework::Tensor>("PriorBox");
|
||||
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
|
||||
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
|
||||
auto* output_box = context.Output<Tensor>("OutputBox");
|
||||
|
||||
if (target_box->lod().size()) {
|
||||
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1UL,
|
||||
"Only support 1 level of LoD.");
|
||||
}
|
||||
auto row = target_box->dims()[0];
|
||||
auto col = prior_box->dims()[0];
|
||||
|
||||
output_box->mutable_data<T>({row, col, 4}, context.GetPlace());
|
||||
|
||||
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
|
||||
T* output = output_box->data<T>();
|
||||
if (code_type == BoxCodeType::kEncodeCenterSize) {
|
||||
EncodeCenterSize(*target_box, *prior_box, *prior_box_var, output);
|
||||
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
|
||||
DecodeCenterSize(*target_box, *prior_box, *prior_box_var, output);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,117 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def box_coder(target_box, prior_box, prior_box_var, output_box, code_type):
|
||||
prior_box_x = (prior_box[:, 2] + prior_box[:, 0]) / 2
|
||||
prior_box_y = (prior_box[:, 3] + prior_box[:, 1]) / 2
|
||||
prior_box_width = (prior_box[:, 2] - prior_box[:, 0])
|
||||
prior_box_height = (prior_box[:, 3] - prior_box[:, 1])
|
||||
|
||||
if (code_type == "EncodeCenterSize"):
|
||||
target_box_x = (target_box[:, 2] + target_box[:, 0]) / 2
|
||||
target_box_y = (target_box[:, 3] + target_box[:, 1]) / 2
|
||||
target_box_width = (target_box[:, 2] - target_box[:, 0])
|
||||
target_box_height = (target_box[:, 3] - target_box[:, 1])
|
||||
|
||||
for i in range(target_box.shape[0]):
|
||||
output_box[i,:,0] = (target_box_x[i] - prior_box_x) / prior_box_width / \
|
||||
prior_box_var[:,0]
|
||||
output_box[i,:,1] = (target_box_y[i] - prior_box_y) / prior_box_height / \
|
||||
prior_box_var[:,1]
|
||||
output_box[i,:,2] = np.log(np.fabs(target_box_width[i] / prior_box_width)) / \
|
||||
prior_box_var[:,2]
|
||||
output_box[i,:,3] = np.log(np.fabs(target_box_height[i] / prior_box_height)) / \
|
||||
prior_box_var[:,3]
|
||||
|
||||
elif (code_type == "DecodeCenterSize"):
|
||||
for i in range(target_box.shape[0]):
|
||||
target_box_x = prior_box_var[:,0] * target_box[i][0] * \
|
||||
prior_box_width[:] + prior_box_x[:]
|
||||
target_box_y = prior_box_var[:,1] * target_box[i][1] * \
|
||||
prior_box_height[:] + prior_box_y[:]
|
||||
target_box_width = np.exp(prior_box_var[:,2] * target_box[i][2]) * \
|
||||
prior_box_width[:]
|
||||
target_box_height = np.exp(prior_box_var[:,3] * target_box[i][3]) * \
|
||||
prior_box_height[:]
|
||||
output_box[i, :, 0] = target_box_x - target_box_width / 2
|
||||
output_box[i, :, 1] = target_box_y - target_box_height / 2
|
||||
output_box[i, :, 2] = target_box_x + target_box_width / 2
|
||||
output_box[i, :, 3] = target_box_y + target_box_height / 2
|
||||
|
||||
|
||||
def batch_box_coder(prior_box, prior_box_var, target_box, lod, code_type):
|
||||
n = target_box.shape[0]
|
||||
m = prior_box.shape[0]
|
||||
output_box = np.zeros((n, m, 4), dtype=np.float32)
|
||||
for i in range(len(lod) - 1):
|
||||
box_coder(target_box[lod[i]:lod[i + 1], :], prior_box, prior_box_var,
|
||||
output_box[lod[i]:lod[i + 1], :, :], code_type)
|
||||
return output_box
|
||||
|
||||
|
||||
class TestBoxCoderOp(OpTest):
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "box_coder"
|
||||
lod = [[0, 20]]
|
||||
prior_box = np.random.random((10, 4)).astype('float32')
|
||||
prior_box_var = np.random.random((10, 4)).astype('float32')
|
||||
target_box = np.random.random((20, 4)).astype('float32')
|
||||
code_type = "DecodeCenterSize"
|
||||
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
|
||||
lod[0], code_type)
|
||||
|
||||
self.inputs = {
|
||||
'PriorBox': prior_box,
|
||||
'PriorBoxVar': prior_box_var,
|
||||
'TargetBox': target_box,
|
||||
}
|
||||
self.attrs = {'code_type': 'decode_center_size'}
|
||||
self.outputs = {'OutputBox': output_box}
|
||||
|
||||
|
||||
class TestBoxCoderOpWithLoD(OpTest):
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "box_coder"
|
||||
lod = [[0, 4, 12, 20]]
|
||||
prior_box = np.random.random((10, 4)).astype('float32')
|
||||
prior_box_var = np.random.random((10, 4)).astype('float32')
|
||||
target_box = np.random.random((20, 4)).astype('float32')
|
||||
code_type = "EncodeCenterSize"
|
||||
output_box = batch_box_coder(prior_box, prior_box_var, target_box,
|
||||
lod[0], code_type)
|
||||
|
||||
self.inputs = {
|
||||
'PriorBox': prior_box,
|
||||
'PriorBoxVar': prior_box_var,
|
||||
'TargetBox': (target_box, lod),
|
||||
}
|
||||
self.attrs = {'code_type': 'encode_center_size'}
|
||||
self.outputs = {'OutputBox': output_box}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue