Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fpn_ops
	
		
	
				
					
				
			
						commit
						847bb6a279
					
				| @ -0,0 +1,169 @@ | |||||||
|  | /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. */ | ||||||
|  | 
 | ||||||
|  | #include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h" | ||||||
|  | 
 | ||||||
|  | namespace paddle { | ||||||
|  | namespace operators { | ||||||
|  | 
 | ||||||
|  | using LoDTensor = framework::LoDTensor; | ||||||
|  | 
 | ||||||
|  | class BoxDecoderAndAssignOp : public framework::OperatorWithKernel { | ||||||
|  |  public: | ||||||
|  |   using framework::OperatorWithKernel::OperatorWithKernel; | ||||||
|  | 
 | ||||||
|  |  protected: | ||||||
|  |   void InferShape(framework::InferShapeContext *ctx) const override { | ||||||
|  |     PADDLE_ENFORCE( | ||||||
|  |         ctx->HasInput("PriorBox"), | ||||||
|  |         "Input(PriorBox) of BoxDecoderAndAssignOp should not be null."); | ||||||
|  |     PADDLE_ENFORCE( | ||||||
|  |         ctx->HasInput("PriorBoxVar"), | ||||||
|  |         "Input(PriorBoxVar) of BoxDecoderAndAssignOp should not be null."); | ||||||
|  |     PADDLE_ENFORCE( | ||||||
|  |         ctx->HasInput("TargetBox"), | ||||||
|  |         "Input(TargetBox) of BoxDecoderAndAssignOp should not be null."); | ||||||
|  |     PADDLE_ENFORCE( | ||||||
|  |         ctx->HasInput("BoxScore"), | ||||||
|  |         "Input(BoxScore) of BoxDecoderAndAssignOp should not be null."); | ||||||
|  |     PADDLE_ENFORCE( | ||||||
|  |         ctx->HasOutput("DecodeBox"), | ||||||
|  |         "Output(DecodeBox) of BoxDecoderAndAssignOp should not be null."); | ||||||
|  |     PADDLE_ENFORCE( | ||||||
|  |         ctx->HasOutput("OutputAssignBox"), | ||||||
|  |         "Output(OutputAssignBox) of BoxDecoderAndAssignOp should not be null."); | ||||||
|  | 
 | ||||||
|  |     auto prior_box_dims = ctx->GetInputDim("PriorBox"); | ||||||
|  |     auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); | ||||||
|  |     auto target_box_dims = ctx->GetInputDim("TargetBox"); | ||||||
|  |     auto box_score_dims = ctx->GetInputDim("BoxScore"); | ||||||
|  | 
 | ||||||
|  |     PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, | ||||||
|  |                       "The rank of Input of PriorBox must be 2"); | ||||||
|  |     PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); | ||||||
|  |     PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 1, | ||||||
|  |                       "The rank of Input of PriorBoxVar must be 1"); | ||||||
|  |     PADDLE_ENFORCE_EQ(prior_box_var_dims[0], 4, | ||||||
|  |                       "The shape of PriorBoxVar is [4]"); | ||||||
|  |     PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, | ||||||
|  |                       "The rank of Input of TargetBox must be 2"); | ||||||
|  |     PADDLE_ENFORCE_EQ(box_score_dims.size(), 2, | ||||||
|  |                       "The rank of Input of BoxScore must be 2"); | ||||||
|  |     PADDLE_ENFORCE_EQ(prior_box_dims[0], target_box_dims[0], | ||||||
|  |                       "The first dim of prior_box and target_box is roi nums " | ||||||
|  |                       "and should be same!"); | ||||||
|  |     PADDLE_ENFORCE_EQ(prior_box_dims[0], box_score_dims[0], | ||||||
|  |                       "The first dim of prior_box and box_score is roi nums " | ||||||
|  |                       "and should be same!"); | ||||||
|  |     PADDLE_ENFORCE_EQ(target_box_dims[1], box_score_dims[1] * prior_box_dims[1], | ||||||
|  |                       "The shape of target_box is [N, classnum * 4], The shape " | ||||||
|  |                       "of box_score is [N, classnum], The shape of prior_box " | ||||||
|  |                       "is [N, 4]"); | ||||||
|  | 
 | ||||||
|  |     ctx->SetOutputDim("DecodeBox", framework::make_ddim({target_box_dims[0], | ||||||
|  |                                                          target_box_dims[1]})); | ||||||
|  |     ctx->ShareLoD("TargetBox", /*->*/ "DecodeBox"); | ||||||
|  |     ctx->SetOutputDim( | ||||||
|  |         "OutputAssignBox", | ||||||
|  |         framework::make_ddim({prior_box_dims[0], prior_box_dims[1]})); | ||||||
|  |     ctx->ShareLoD("PriorBox", /*->*/ "OutputAssignBox"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | class BoxDecoderAndAssignOpMaker : public framework::OpProtoAndCheckerMaker { | ||||||
|  |  public: | ||||||
|  |   void Make() override { | ||||||
|  |     AddInput( | ||||||
|  |         "PriorBox", | ||||||
|  |         "(Tensor, default Tensor<float>) " | ||||||
|  |         "Box list PriorBox is a 2-D Tensor with shape [N, 4] which holds N " | ||||||
|  |         "boxes and each box is represented as [xmin, ymin, xmax, ymax], " | ||||||
|  |         "[xmin, ymin] is the left top coordinate of the anchor box, " | ||||||
|  |         "if the input is image feature map, they are close to the origin " | ||||||
|  |         "of the coordinate system. [xmax, ymax] is the right bottom " | ||||||
|  |         "coordinate of the anchor box."); | ||||||
|  |     AddInput("PriorBoxVar", | ||||||
|  |              "(Tensor, default Tensor<float>, optional) " | ||||||
|  |              "PriorBoxVar is a 2-D Tensor with shape [N, 4] which holds N " | ||||||
|  |              "group of variance. PriorBoxVar will set all elements to 1 by " | ||||||
|  |              "default.") | ||||||
|  |         .AsDispensable(); | ||||||
|  |     AddInput("TargetBox", | ||||||
|  |              "(LoDTensor or Tensor) " | ||||||
|  |              "This input can be a 2-D LoDTensor with shape " | ||||||
|  |              "[N, classnum*4]. It holds N targets for N boxes."); | ||||||
|  |     AddInput("BoxScore", | ||||||
|  |              "(LoDTensor or Tensor) " | ||||||
|  |              "This input can be a 2-D LoDTensor with shape " | ||||||
|  |              "[N, classnum], each box is represented as [classnum] which is " | ||||||
|  |              "the classification probabilities."); | ||||||
|  |     AddAttr<float>("box_clip", | ||||||
|  |                    "(float, default 4.135, np.log(1000. / 16.)) " | ||||||
|  |                    "clip box to prevent overflowing") | ||||||
|  |         .SetDefault(4.135f); | ||||||
|  |     AddOutput("DecodeBox", | ||||||
|  |               "(LoDTensor or Tensor) " | ||||||
|  |               "the output tensor of op with shape [N, classnum * 4] " | ||||||
|  |               "representing the result of N target boxes decoded with " | ||||||
|  |               "M Prior boxes and variances for each class."); | ||||||
|  |     AddOutput("OutputAssignBox", | ||||||
|  |               "(LoDTensor or Tensor) " | ||||||
|  |               "the output tensor of op with shape [N, 4] " | ||||||
|  |               "representing the result of N target boxes decoded with " | ||||||
|  |               "M Prior boxes and variances with the best non-background class " | ||||||
|  |               "by BoxScore."); | ||||||
|  |     AddComment(R"DOC( | ||||||
|  | 
 | ||||||
|  | Bounding Box Coder. | ||||||
|  | 
 | ||||||
|  | Decode the target bounding box with the prior_box information. | ||||||
|  | 
 | ||||||
|  | The Decoding schema is described below: | ||||||
|  | 
 | ||||||
|  |     $$ | ||||||
|  |     ox = (pw \\times pxv \\times tx + px) - \\frac{tw}{2}  | ||||||
|  |     $$ | ||||||
|  |     $$ | ||||||
|  |     oy = (ph \\times pyv \\times ty + py) - \\frac{th}{2} | ||||||
|  |     $$ | ||||||
|  |     $$ | ||||||
|  |     ow = \\exp (pwv \\times tw) \\times pw + \\frac{tw}{2} | ||||||
|  |     $$ | ||||||
|  |     $$ | ||||||
|  |     oh = \\exp (phv \\times th) \\times ph + \\frac{th}{2} | ||||||
|  |     $$ | ||||||
|  | 
 | ||||||
|  | where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width | ||||||
|  | and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the | ||||||
|  | prior_box's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`, | ||||||
|  | `phv` denote the variance of the prior_box and `ox`, `oy`, `ow`, `oh` denote the | ||||||
|  | decoded coordinates, width and height in decode_box.  | ||||||
|  | 
 | ||||||
|  | decode_box is obtained after box decode, then assigning schema is described below: | ||||||
|  | 
 | ||||||
|  | For each prior_box, use the best non-background class's decoded values to  | ||||||
|  | update the prior_box locations and get output_assign_box. So, the shape of | ||||||
|  | output_assign_box is the same as PriorBox. | ||||||
|  | )DOC"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace operators
 | ||||||
|  | }  // namespace paddle
 | ||||||
|  | 
 | ||||||
|  | namespace ops = paddle::operators; | ||||||
|  | REGISTER_OPERATOR(box_decoder_and_assign, ops::BoxDecoderAndAssignOp, | ||||||
|  |                   ops::BoxDecoderAndAssignOpMaker, | ||||||
|  |                   paddle::framework::EmptyGradOpMaker); | ||||||
|  | REGISTER_OP_CPU_KERNEL( | ||||||
|  |     box_decoder_and_assign, | ||||||
|  |     ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, float>, | ||||||
|  |     ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, double>); | ||||||
| @ -0,0 +1,147 @@ | |||||||
|  | /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. */ | ||||||
|  | 
 | ||||||
|  | #include "paddle/fluid/memory/memcpy.h" | ||||||
|  | #include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h" | ||||||
|  | #include "paddle/fluid/platform/cuda_primitives.h" | ||||||
|  | 
 | ||||||
|  | namespace paddle { | ||||||
|  | namespace operators { | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | __global__ void DecodeBoxKernel(const T* prior_box_data, | ||||||
|  |                                 const T* prior_box_var_data, | ||||||
|  |                                 const T* target_box_data, const int roi_num, | ||||||
|  |                                 const int class_num, const T box_clip, | ||||||
|  |                                 T* output_box_data) { | ||||||
|  |   const int idx = threadIdx.x + blockIdx.x * blockDim.x; | ||||||
|  |   if (idx < roi_num * class_num) { | ||||||
|  |     int i = idx / class_num; | ||||||
|  |     int j = idx % class_num; | ||||||
|  |     T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1; | ||||||
|  |     T prior_box_height = | ||||||
|  |         prior_box_data[i * 4 + 3] - prior_box_data[i * 4 + 1] + 1; | ||||||
|  |     T prior_box_center_x = prior_box_data[i * 4] + prior_box_width / 2; | ||||||
|  |     T prior_box_center_y = prior_box_data[i * 4 + 1] + prior_box_height / 2; | ||||||
|  | 
 | ||||||
|  |     int offset = i * class_num * 4 + j * 4; | ||||||
|  |     T dw = prior_box_var_data[2] * target_box_data[offset + 2]; | ||||||
|  |     T dh = prior_box_var_data[3] * target_box_data[offset + 3]; | ||||||
|  |     if (dw > box_clip) { | ||||||
|  |       dw = box_clip; | ||||||
|  |     } | ||||||
|  |     if (dh > box_clip) { | ||||||
|  |       dh = box_clip; | ||||||
|  |     } | ||||||
|  |     T target_box_center_x = 0, target_box_center_y = 0; | ||||||
|  |     T target_box_width = 0, target_box_height = 0; | ||||||
|  |     target_box_center_x = | ||||||
|  |         prior_box_var_data[0] * target_box_data[offset] * prior_box_width + | ||||||
|  |         prior_box_center_x; | ||||||
|  |     target_box_center_y = | ||||||
|  |         prior_box_var_data[1] * target_box_data[offset + 1] * prior_box_height + | ||||||
|  |         prior_box_center_y; | ||||||
|  |     target_box_width = expf(dw) * prior_box_width; | ||||||
|  |     target_box_height = expf(dh) * prior_box_height; | ||||||
|  | 
 | ||||||
|  |     output_box_data[offset] = target_box_center_x - target_box_width / 2; | ||||||
|  |     output_box_data[offset + 1] = target_box_center_y - target_box_height / 2; | ||||||
|  |     output_box_data[offset + 2] = | ||||||
|  |         target_box_center_x + target_box_width / 2 - 1; | ||||||
|  |     output_box_data[offset + 3] = | ||||||
|  |         target_box_center_y + target_box_height / 2 - 1; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | __global__ void AssignBoxKernel(const T* prior_box_data, | ||||||
|  |                                 const T* box_score_data, T* output_box_data, | ||||||
|  |                                 const int roi_num, const int class_num, | ||||||
|  |                                 T* output_assign_box_data) { | ||||||
|  |   const int idx = threadIdx.x + blockIdx.x * blockDim.x; | ||||||
|  |   if (idx < roi_num) { | ||||||
|  |     int i = idx; | ||||||
|  |     T max_score = -1; | ||||||
|  |     int max_j = -1; | ||||||
|  |     for (int j = 0; j < class_num; ++j) { | ||||||
|  |       T score = box_score_data[i * class_num + j]; | ||||||
|  |       if (score > max_score && j > 0) { | ||||||
|  |         max_score = score; | ||||||
|  |         max_j = j; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     if (max_j > 0) { | ||||||
|  |       for (int pno = 0; pno < 4; pno++) { | ||||||
|  |         output_assign_box_data[i * 4 + pno] = | ||||||
|  |             output_box_data[i * class_num * 4 + max_j * 4 + pno]; | ||||||
|  |       } | ||||||
|  |     } else { | ||||||
|  |       for (int pno = 0; pno < 4; pno++) { | ||||||
|  |         output_assign_box_data[i * 4 + pno] = prior_box_data[i * 4 + pno]; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | template <typename DeviceContext, typename T> | ||||||
|  | class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> { | ||||||
|  |  public: | ||||||
|  |   void Compute(const framework::ExecutionContext& context) const override { | ||||||
|  |     PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), | ||||||
|  |                    "This kernel only runs on GPU device."); | ||||||
|  |     auto* prior_box = context.Input<framework::LoDTensor>("PriorBox"); | ||||||
|  |     auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); | ||||||
|  |     auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); | ||||||
|  |     auto* box_score = context.Input<framework::LoDTensor>("BoxScore"); | ||||||
|  |     auto* output_box = context.Output<framework::Tensor>("DecodeBox"); | ||||||
|  |     auto* output_assign_box = | ||||||
|  |         context.Output<framework::Tensor>("OutputAssignBox"); | ||||||
|  | 
 | ||||||
|  |     auto roi_num = target_box->dims()[0]; | ||||||
|  |     auto class_num = box_score->dims()[1]; | ||||||
|  |     auto* target_box_data = target_box->data<T>(); | ||||||
|  |     auto* prior_box_data = prior_box->data<T>(); | ||||||
|  |     auto* prior_box_var_data = prior_box_var->data<T>(); | ||||||
|  |     auto* box_score_data = box_score->data<T>(); | ||||||
|  |     output_box->mutable_data<T>({roi_num, class_num * 4}, context.GetPlace()); | ||||||
|  |     output_assign_box->mutable_data<T>({roi_num, 4}, context.GetPlace()); | ||||||
|  |     T* output_box_data = output_box->data<T>(); | ||||||
|  |     T* output_assign_box_data = output_assign_box->data<T>(); | ||||||
|  | 
 | ||||||
|  |     int block = 512; | ||||||
|  |     int grid = (roi_num * class_num + block - 1) / block; | ||||||
|  |     auto& device_ctx = context.cuda_device_context(); | ||||||
|  | 
 | ||||||
|  |     const T box_clip = context.Attr<T>("box_clip"); | ||||||
|  | 
 | ||||||
|  |     DecodeBoxKernel<T><<<grid, block, 0, device_ctx.stream()>>>( | ||||||
|  |         prior_box_data, prior_box_var_data, target_box_data, roi_num, class_num, | ||||||
|  |         box_clip, output_box_data); | ||||||
|  | 
 | ||||||
|  |     context.device_context().Wait(); | ||||||
|  |     int assign_grid = (roi_num + block - 1) / block; | ||||||
|  |     AssignBoxKernel<T><<<assign_grid, block, 0, device_ctx.stream()>>>( | ||||||
|  |         prior_box_data, box_score_data, output_box_data, roi_num, class_num, | ||||||
|  |         output_assign_box_data); | ||||||
|  |     context.device_context().Wait(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace operators | ||||||
|  | }  // namespace paddle | ||||||
|  | 
 | ||||||
|  | namespace ops = paddle::operators; | ||||||
|  | REGISTER_OP_CUDA_KERNEL( | ||||||
|  |     box_decoder_and_assign, | ||||||
|  |     ops::BoxDecoderAndAssignCUDAKernel<paddle::platform::CUDADeviceContext, | ||||||
|  |                                        float>, | ||||||
|  |     ops::BoxDecoderAndAssignCUDAKernel<paddle::platform::CUDADeviceContext, | ||||||
|  |                                        double>); | ||||||
| @ -0,0 +1,103 @@ | |||||||
|  | /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. */ | ||||||
|  | 
 | ||||||
|  | #pragma once | ||||||
|  | #include <algorithm> | ||||||
|  | #include <string> | ||||||
|  | #include <vector> | ||||||
|  | #include "paddle/fluid/framework/op_registry.h" | ||||||
|  | #include "paddle/fluid/operators/math/math_function.h" | ||||||
|  | 
 | ||||||
|  | namespace paddle { | ||||||
|  | namespace operators { | ||||||
|  | 
 | ||||||
|  | template <typename DeviceContext, typename T> | ||||||
|  | class BoxDecoderAndAssignKernel : public framework::OpKernel<T> { | ||||||
|  |  public: | ||||||
|  |   void Compute(const framework::ExecutionContext& context) const override { | ||||||
|  |     auto* prior_box = context.Input<framework::LoDTensor>("PriorBox"); | ||||||
|  |     auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar"); | ||||||
|  |     auto* target_box = context.Input<framework::LoDTensor>("TargetBox"); | ||||||
|  |     auto* box_score = context.Input<framework::LoDTensor>("BoxScore"); | ||||||
|  |     auto* output_box = context.Output<framework::Tensor>("DecodeBox"); | ||||||
|  |     auto* output_assign_box = | ||||||
|  |         context.Output<framework::Tensor>("OutputAssignBox"); | ||||||
|  |     int roi_num = target_box->dims()[0]; | ||||||
|  |     int class_num = box_score->dims()[1]; | ||||||
|  |     auto* target_box_data = target_box->data<T>(); | ||||||
|  |     auto* prior_box_data = prior_box->data<T>(); | ||||||
|  |     auto* prior_box_var_data = prior_box_var->data<T>(); | ||||||
|  |     auto* box_score_data = box_score->data<T>(); | ||||||
|  |     output_box->mutable_data<T>({roi_num, class_num * 4}, context.GetPlace()); | ||||||
|  |     output_assign_box->mutable_data<T>({roi_num, 4}, context.GetPlace()); | ||||||
|  |     T* output_box_data = output_box->data<T>(); | ||||||
|  |     T* output_assign_box_data = output_assign_box->data<T>(); | ||||||
|  |     const T bbox_clip = context.Attr<T>("box_clip"); | ||||||
|  | 
 | ||||||
|  |     for (int i = 0; i < roi_num; ++i) { | ||||||
|  |       T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1; | ||||||
|  |       T prior_box_height = | ||||||
|  |           prior_box_data[i * 4 + 3] - prior_box_data[i * 4 + 1] + 1; | ||||||
|  |       T prior_box_center_x = prior_box_data[i * 4] + prior_box_width / 2; | ||||||
|  |       T prior_box_center_y = prior_box_data[i * 4 + 1] + prior_box_height / 2; | ||||||
|  |       for (int j = 0; j < class_num; ++j) { | ||||||
|  |         int64_t offset = i * class_num * 4 + j * 4; | ||||||
|  |         T dw = std::min(prior_box_var_data[2] * target_box_data[offset + 2], | ||||||
|  |                         bbox_clip); | ||||||
|  |         T dh = std::min(prior_box_var_data[3] * target_box_data[offset + 3], | ||||||
|  |                         bbox_clip); | ||||||
|  |         T target_box_center_x = 0, target_box_center_y = 0; | ||||||
|  |         T target_box_width = 0, target_box_height = 0; | ||||||
|  |         target_box_center_x = | ||||||
|  |             prior_box_var_data[0] * target_box_data[offset] * prior_box_width + | ||||||
|  |             prior_box_center_x; | ||||||
|  |         target_box_center_y = prior_box_var_data[1] * | ||||||
|  |                                   target_box_data[offset + 1] * | ||||||
|  |                                   prior_box_height + | ||||||
|  |                               prior_box_center_y; | ||||||
|  |         target_box_width = std::exp(dw) * prior_box_width; | ||||||
|  |         target_box_height = std::exp(dh) * prior_box_height; | ||||||
|  | 
 | ||||||
|  |         output_box_data[offset] = target_box_center_x - target_box_width / 2; | ||||||
|  |         output_box_data[offset + 1] = | ||||||
|  |             target_box_center_y - target_box_height / 2; | ||||||
|  |         output_box_data[offset + 2] = | ||||||
|  |             target_box_center_x + target_box_width / 2 - 1; | ||||||
|  |         output_box_data[offset + 3] = | ||||||
|  |             target_box_center_y + target_box_height / 2 - 1; | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       T max_score = -1; | ||||||
|  |       int max_j = -1; | ||||||
|  |       for (int j = 0; j < class_num; ++j) { | ||||||
|  |         T score = box_score_data[i * class_num + j]; | ||||||
|  |         if (score > max_score && j > 0) { | ||||||
|  |           max_score = score; | ||||||
|  |           max_j = j; | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       if (max_j > 0) { | ||||||
|  |         for (int pno = 0; pno < 4; pno++) { | ||||||
|  |           output_assign_box_data[i * 4 + pno] = | ||||||
|  |               output_box_data[i * class_num * 4 + max_j * 4 + pno]; | ||||||
|  |         } | ||||||
|  |       } else { | ||||||
|  |         for (int pno = 0; pno < 4; pno++) { | ||||||
|  |           output_assign_box_data[i * 4 + pno] = prior_box_data[i * 4 + pno]; | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace operators
 | ||||||
|  | }  // namespace paddle
 | ||||||
| @ -0,0 +1,91 @@ | |||||||
|  | /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. */ | ||||||
|  | 
 | ||||||
|  | #include "paddle/fluid/operators/jit/gen/vbroadcast.h" | ||||||
|  | #include <memory> | ||||||
|  | #include <vector> | ||||||
|  | #include "paddle/fluid/operators/jit/registry.h" | ||||||
|  | #include "paddle/fluid/platform/cpu_info.h" | ||||||
|  | 
 | ||||||
|  | namespace paddle { | ||||||
|  | namespace operators { | ||||||
|  | namespace jit { | ||||||
|  | namespace gen { | ||||||
|  | 
 | ||||||
|  | void VBroadcastJitCode::genCode() { | ||||||
|  |   preCode(); | ||||||
|  |   constexpr int block = YMM_FLOAT_BLOCK; | ||||||
|  |   constexpr int max_num_regs = 16; | ||||||
|  |   const int num_block = w_ / block; | ||||||
|  |   const int num_groups = num_block / max_num_regs; | ||||||
|  |   const size_t block_size = sizeof(float) * block; | ||||||
|  |   std::vector<int> groups(num_groups, max_num_regs); | ||||||
|  |   int rest_num_regs = num_block % max_num_regs; | ||||||
|  |   if (rest_num_regs > 0) { | ||||||
|  |     groups.push_back(rest_num_regs); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // protect param_h
 | ||||||
|  |   mov(reg_height, param_h); | ||||||
|  |   Label l_next_h; | ||||||
|  |   xor_(reg_h_i, reg_h_i); | ||||||
|  |   mov(reg_ptr_dst_i, param_dst); | ||||||
|  |   L(l_next_h); | ||||||
|  |   { | ||||||
|  |     mov(reg_ptr_src_i, param_src); | ||||||
|  |     for (int num_regs : groups) { | ||||||
|  |       size_t w_offset = 0; | ||||||
|  |       for (int reg_i = 0; reg_i < num_regs; ++reg_i) { | ||||||
|  |         vmovups(ymm_t(reg_i), ptr[reg_ptr_src_i + w_offset]); | ||||||
|  |         w_offset += block_size; | ||||||
|  |       } | ||||||
|  |       add(reg_ptr_src_i, num_regs * block_size); | ||||||
|  | 
 | ||||||
|  |       w_offset = 0; | ||||||
|  |       for (int reg_i = 0; reg_i < num_regs; ++reg_i) { | ||||||
|  |         vmovups(ptr[reg_ptr_dst_i + w_offset], ymm_t(reg_i)); | ||||||
|  |         w_offset += block_size; | ||||||
|  |       } | ||||||
|  |       add(reg_ptr_dst_i, num_regs * block_size); | ||||||
|  |     }  // end of groups
 | ||||||
|  |     inc(reg_h_i); | ||||||
|  |     cmp(reg_h_i, reg_height); | ||||||
|  |     jl(l_next_h, T_NEAR); | ||||||
|  |   }  // end of l_next_h
 | ||||||
|  | 
 | ||||||
|  |   postCode(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | class VBroadcastCreator : public JitCodeCreator<int64_t> { | ||||||
|  |  public: | ||||||
|  |   bool UseMe(const int64_t& w) const override { | ||||||
|  |     return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0; | ||||||
|  |   } | ||||||
|  |   size_t CodeSize(const int64_t& w) const override { | ||||||
|  |     return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8; | ||||||
|  |   } | ||||||
|  |   std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override { | ||||||
|  |     PADDLE_ENFORCE_GT(w, 0); | ||||||
|  |     return make_unique<VBroadcastJitCode>(w, CodeSize(w)); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace gen
 | ||||||
|  | }  // namespace jit
 | ||||||
|  | }  // namespace operators
 | ||||||
|  | }  // namespace paddle
 | ||||||
|  | 
 | ||||||
|  | namespace gen = paddle::operators::jit::gen; | ||||||
|  | 
 | ||||||
|  | REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator); | ||||||
| @ -0,0 +1,53 @@ | |||||||
|  | /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  * http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. */ | ||||||
|  | 
 | ||||||
|  | #pragma once | ||||||
|  | 
 | ||||||
|  | #include <string> | ||||||
|  | #include "glog/logging.h" | ||||||
|  | #include "paddle/fluid/operators/jit/gen/jitcode.h" | ||||||
|  | 
 | ||||||
|  | namespace paddle { | ||||||
|  | namespace operators { | ||||||
|  | namespace jit { | ||||||
|  | namespace gen { | ||||||
|  | 
 | ||||||
|  | class VBroadcastJitCode : public JitCode { | ||||||
|  |  public: | ||||||
|  |   explicit VBroadcastJitCode(const int64_t& w, size_t code_size = 256 * 1024, | ||||||
|  |                              void* code_ptr = nullptr) | ||||||
|  |       : JitCode(code_size, code_ptr), w_(w) { | ||||||
|  |     this->genCode(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   DECLARE_JIT_CODE(VBroadcastJitCode); | ||||||
|  |   void genCode() override; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   int w_; | ||||||
|  |   reg64_t param_src{abi_param1}; | ||||||
|  |   reg64_t param_dst{abi_param2}; | ||||||
|  |   reg64_t param_h{abi_param3}; | ||||||
|  |   reg64_t param_w{abi_param4}; | ||||||
|  | 
 | ||||||
|  |   reg64_t reg_height{r9}; | ||||||
|  |   reg64_t reg_h_i{r10}; | ||||||
|  |   reg64_t reg_ptr_src_i{r11}; | ||||||
|  |   reg64_t reg_ptr_dst_i{r12}; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace gen
 | ||||||
|  | }  // namespace jit
 | ||||||
|  | }  // namespace operators
 | ||||||
|  | }  // namespace paddle
 | ||||||
Some files were not shown because too many files have changed in this diff Show More
					Loading…
					
					
				
		Reference in new issue