Add roi_align for GPU

ce
jerrywgz 7 years ago
parent 2f5a80174e
commit 8c79071d6a

File diff suppressed because it is too large. (Load Diff)

@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {

File diff suppressed because it is too large. (Load Diff)

@ -21,6 +21,8 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kROISize = 4;
template <class T>
void pre_calc_for_bilinear_interpolate(
const platform::DeviceContext& ctx, const int height, const int width,
@ -44,9 +46,9 @@ void pre_calc_for_bilinear_interpolate(
static_cast<T>(roi_bin_grid_w);
// handle coordinates that fall outside the bounds of the feature map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
for (int i = 0; i < 4; ++i) {
pre_pos_data[i + pre_calc_index * 4] = 0;
pre_w_data[i + pre_calc_index * 4] = 0;
for (int i = 0; i < kROISize; ++i) {
pre_pos_data[i + pre_calc_index * kROISize] = 0;
pre_w_data[i + pre_calc_index * kROISize] = 0;
}
pre_calc_index += 1;
continue;
@ -76,14 +78,14 @@ void pre_calc_for_bilinear_interpolate(
}
T ly = y - y_low, lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
pre_pos_data[pre_calc_index * 4] = y_low * width + x_low;
pre_pos_data[pre_calc_index * 4 + 1] = y_low * width + x_high;
pre_pos_data[pre_calc_index * 4 + 2] = y_high * width + x_low;
pre_pos_data[pre_calc_index * 4 + 3] = y_high * width + x_high;
pre_w_data[pre_calc_index * 4] = hy * hx;
pre_w_data[pre_calc_index * 4 + 1] = hy * lx;
pre_w_data[pre_calc_index * 4 + 2] = ly * hx;
pre_w_data[pre_calc_index * 4 + 3] = ly * lx;
pre_pos_data[pre_calc_index * kROISize] = y_low * width + x_low;
pre_pos_data[pre_calc_index * kROISize + 1] = y_low * width + x_high;
pre_pos_data[pre_calc_index * kROISize + 2] = y_high * width + x_low;
pre_pos_data[pre_calc_index * kROISize + 3] = y_high * width + x_high;
pre_w_data[pre_calc_index * kROISize] = hy * hx;
pre_w_data[pre_calc_index * kROISize + 1] = hy * lx;
pre_w_data[pre_calc_index * kROISize + 2] = ly * hx;
pre_w_data[pre_calc_index * kROISize + 3] = ly * lx;
pre_calc_index += 1;
}
}
@ -155,11 +157,11 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto in_dims = in->dims();
int64_t batch_size = in_dims[0];
int64_t channels = in_dims[1];
int64_t height = in_dims[2];
int64_t width = in_dims[3];
int64_t rois_num = rois->dims()[0];
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims());
@ -209,8 +211,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
Tensor pre_pos;
Tensor pre_w;
int pre_size = count * out_stride[1];
pre_pos.Resize({pre_size, 4});
pre_w.Resize({pre_size, 4});
pre_pos.Resize({pre_size, kROISize});
pre_w.Resize({pre_size, kROISize});
pre_calc_for_bilinear_interpolate(
dev_ctx, height, width, pooled_height, pooled_width, roi_bin_grid_h,
@ -226,9 +228,9 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
T output_val = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
for (int i = 0; i < 4; i++) {
int pos = pre_pos_data[pre_calc_index * 4 + i];
T w = pre_w_data[pre_calc_index * 4 + i];
for (int i = 0; i < kROISize; i++) {
int pos = pre_pos_data[pre_calc_index * kROISize + i];
T w = pre_w_data[pre_calc_index * kROISize + i];
output_val += w * batch_data[pos];
}
pre_calc_index += 1;
@ -263,11 +265,11 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto in_dims = in->dims();
if (in_grad) {
int64_t channels = in_dims[1];
int64_t height = in_dims[2];
int64_t width = in_dims[3];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
framework::Tensor roi_batch_id_list;
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());

Loading…
Cancel
Save