|
|
|
@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <thrust/device_vector.h>
|
|
|
|
|
#include <thrust/host_vector.h>
|
|
|
|
|
#include "paddle/fluid/operators/detection/box_coder_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/cuda_primitives.h"
|
|
|
|
|
|
|
|
|
@ -16,12 +18,11 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void EncodeCenterSizeKernel(const T* prior_box_data,
|
|
|
|
|
const T* prior_box_var_data,
|
|
|
|
|
const T* target_box_data, const int row,
|
|
|
|
|
const int col, const int len,
|
|
|
|
|
const bool normalized,
|
|
|
|
|
const T prior_box_var_size, T* output) {
|
|
|
|
|
__global__ void EncodeCenterSizeKernel(
|
|
|
|
|
const T* prior_box_data, const T* prior_box_var_data,
|
|
|
|
|
const T* target_box_data, const int row, const int col, const int len,
|
|
|
|
|
const bool normalized, const T prior_box_var_size, const float* variance,
|
|
|
|
|
const int var_size, T* output) {
|
|
|
|
|
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
|
|
|
|
if (idx < row * col) {
|
|
|
|
|
const int row_idx = idx / col;
|
|
|
|
@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
|
|
|
|
|
output[idx * len + 1] /= prior_box_var_data[prior_var_offset + 1];
|
|
|
|
|
output[idx * len + 2] /= prior_box_var_data[prior_var_offset + 2];
|
|
|
|
|
output[idx * len + 3] /= prior_box_var_data[prior_var_offset + 3];
|
|
|
|
|
} else if (var_size == 4) {
|
|
|
|
|
for (int k = 0; k < 4; ++k) {
|
|
|
|
|
output[idx * len + k] /= static_cast<T>(variance[k]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
|
|
|
|
|
const T* prior_box_var_data,
|
|
|
|
|
const T* target_box_data, const int row,
|
|
|
|
|
const int col, const int len,
|
|
|
|
|
const bool normalized,
|
|
|
|
|
const T prior_box_var_size,
|
|
|
|
|
const int axis, T* output) {
|
|
|
|
|
__global__ void DecodeCenterSizeKernel(
|
|
|
|
|
const T* prior_box_data, const T* prior_box_var_data,
|
|
|
|
|
const T* target_box_data, const int row, const int col, const int len,
|
|
|
|
|
const bool normalized, const T prior_box_var_size, const float* variance,
|
|
|
|
|
const int var_size, const int axis, T* output) {
|
|
|
|
|
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
|
|
|
|
int prior_box_offset = 0;
|
|
|
|
|
if (idx < row * col) {
|
|
|
|
@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
|
|
|
|
|
target_box_data[idx * len + 1] *
|
|
|
|
|
prior_box_height +
|
|
|
|
|
prior_box_center_y;
|
|
|
|
|
} else if (var_size == 4) {
|
|
|
|
|
target_box_width =
|
|
|
|
|
exp(static_cast<T>(variance[2]) * target_box_data[idx * len + 2]) *
|
|
|
|
|
prior_box_width;
|
|
|
|
|
target_box_height =
|
|
|
|
|
exp(static_cast<T>(variance[3]) * target_box_data[idx * len + 3]) *
|
|
|
|
|
prior_box_height;
|
|
|
|
|
target_box_center_x = static_cast<T>(variance[0]) *
|
|
|
|
|
target_box_data[idx * len] * prior_box_width +
|
|
|
|
|
prior_box_center_x;
|
|
|
|
|
target_box_center_y = static_cast<T>(variance[1]) *
|
|
|
|
|
target_box_data[idx * len + 1] *
|
|
|
|
|
prior_box_height +
|
|
|
|
|
prior_box_center_y;
|
|
|
|
|
} else {
|
|
|
|
|
target_box_width = exp(target_box_data[idx * len + 2]) * prior_box_width;
|
|
|
|
|
target_box_height =
|
|
|
|
@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
|
|
|
|
|
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
|
|
|
|
|
auto* output_box = context.Output<framework::Tensor>("OutputBox");
|
|
|
|
|
|
|
|
|
|
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
|
|
|
|
|
const T* prior_box_data = prior_box->data<T>();
|
|
|
|
|
const T* target_box_data = target_box->data<T>();
|
|
|
|
|
const T* prior_box_var_data = nullptr;
|
|
|
|
|
auto prior_box_var_size = 0;
|
|
|
|
|
if (prior_box_var) {
|
|
|
|
|
PADDLE_ENFORCE(variance.empty(),
|
|
|
|
|
"Input 'PriorBoxVar' and attribute 'variance' should not"
|
|
|
|
|
"be used at the same time.");
|
|
|
|
|
prior_box_var_data = prior_box_var->data<T>();
|
|
|
|
|
prior_box_var_size = prior_box_var->dims().size();
|
|
|
|
|
}
|
|
|
|
|
if (!(variance.empty())) {
|
|
|
|
|
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 4,
|
|
|
|
|
"Size of attribute 'variance' should be 4");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (target_box->lod().size()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(target_box->lod().size(), 1,
|
|
|
|
|
"Only support 1 level of LoD.");
|
|
|
|
|
}
|
|
|
|
|
const int var_size = static_cast<T>(variance.size());
|
|
|
|
|
thrust::device_vector<float> dev_variance(variance.begin(), variance.end());
|
|
|
|
|
const float* dev_var_data = thrust::raw_pointer_cast(dev_variance.data());
|
|
|
|
|
auto code_type = GetBoxCodeType(context.Attr<std::string>("code_type"));
|
|
|
|
|
bool normalized = context.Attr<bool>("box_normalized");
|
|
|
|
|
int axis = context.Attr<int>("axis");
|
|
|
|
@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (code_type == BoxCodeType::kEncodeCenterSize) {
|
|
|
|
|
EncodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
|
|
|
|
|
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
|
|
|
|
|
normalized, prior_box_var_size, output);
|
|
|
|
|
normalized, prior_box_var_size, dev_var_data, var_size, output);
|
|
|
|
|
} else if (code_type == BoxCodeType::kDecodeCenterSize) {
|
|
|
|
|
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
|
|
|
|
|
prior_box_data, prior_box_var_data, target_box_data, row, col, len,
|
|
|
|
|
normalized, prior_box_var_size, axis, output);
|
|
|
|
|
normalized, prior_box_var_size, dev_var_data, var_size, axis, output);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|