add deformable conv v1 op and cpu version of deformable conv v2 (#18500)
* add deformable conv v1 op, test=developexpand_as_op_1
parent
40c66f8df9
commit
00efd1d8a9
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Part of the following code in this file refs to
|
||||||
|
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
|
||||||
|
//
|
||||||
|
// Copyright (c) 2017 Microsoft
|
||||||
|
// Licensed under The Apache-2.0 License [see LICENSE for details]
|
||||||
|
// \file deformable_psroi_pooling.cu
|
||||||
|
// \brief
|
||||||
|
// \author Yi Li, Guodong Zhang, Jifeng Dai
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "paddle/fluid/operators/math/blas.h"
|
||||||
|
#include "paddle/fluid/operators/math/math_function.h"
|
||||||
|
|
||||||
|
// Adds a filter-gradient slice into the running filter gradient:
//   filter_grad[i] = filter_grad[i] + dweight_3d[i]   for i in [0, nthreads).
//
// Parameters:
//   nthreads     number of elements to accumulate; values <= 0 are a no-op.
//   n, height,
//   width        not read by this kernel; kept so existing launch sites
//                keep compiling against the same signature.
//   dweight_3d   source buffer, read at indices [0, nthreads).
//   filter_grad  destination buffer, updated in place at the same indices.
//
// Launch layout: only the .x components of blockIdx/blockDim/gridDim are
// used. Each thread starts at its global index and advances by the total
// number of launched threads, so any 1-D grid/block size covers the range.
template <typename T>
__global__ void FilterGradAddupCUDAKernel(const int nthreads, const int n,
                                          const int height, const int width,
                                          const T* dweight_3d, T* filter_grad) {
  // Reject a non-positive count up front: converting a negative int to the
  // unsigned loop bound below would otherwise yield a huge range.
  if (nthreads <= 0) {
    return;
  }

  // Do all index arithmetic in size_t so the start index and the stride are
  // not narrowed through int and the loop comparison is unsigned on both
  // sides.
  const size_t total = static_cast<size_t>(nthreads);
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (size_t i = first; i < total; i += stride) {
    filter_grad[i] = filter_grad[i] + dweight_3d[i];
  }
}
|
@ -0,0 +1,149 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Part of the following code in this file refs to
|
||||||
|
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
|
||||||
|
//
|
||||||
|
// Copyright (c) 2017 Microsoft
|
||||||
|
// Licensed under The Apache-2.0 License [see LICENSE for details]
|
||||||
|
// \file deformable_psroi_pooling.cu
|
||||||
|
// \brief
|
||||||
|
// \author Yi Li, Guodong Zhang, Jifeng Dai
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "paddle/fluid/operators/math/blas.h"
|
||||||
|
#include "paddle/fluid/operators/math/math_function.h"
|
||||||
|
#include "paddle/fluid/platform/hostdevice.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h,
|
||||||
|
const int w, const int height,
|
||||||
|
const int width) {
|
||||||
|
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
||||||
|
argmax_w >= width) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int argmax_h_low = floor(argmax_h);
|
||||||
|
int argmax_w_low = floor(argmax_w);
|
||||||
|
int argmax_h_high = argmax_h_low + 1;
|
||||||
|
int argmax_w_high = argmax_w_low + 1;
|
||||||
|
|
||||||
|
T weight = 0;
|
||||||
|
|
||||||
|
weight = (h == argmax_h_low && w == argmax_w_low)
|
||||||
|
? (h + 1 - argmax_h) * (w + 1 - argmax_w)
|
||||||
|
: weight;
|
||||||
|
weight = (h == argmax_h_low && w == argmax_w_high)
|
||||||
|
? (h + 1 - argmax_h) * (argmax_w + 1 - w)
|
||||||
|
: weight;
|
||||||
|
weight = (h == argmax_h_high && w == argmax_w_low)
|
||||||
|
? (argmax_h + 1 - h) * (w + 1 - argmax_w)
|
||||||
|
: weight;
|
||||||
|
weight = (h == argmax_h_high && w == argmax_w_high)
|
||||||
|
? (argmax_h + 1 - h) * (argmax_w + 1 - w)
|
||||||
|
: weight;
|
||||||
|
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height,
|
||||||
|
const int width, const T* im_data,
|
||||||
|
const int data_width, const int bp_dir) {
|
||||||
|
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
|
||||||
|
argmax_w >= width) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int argmax_h_low = floor(argmax_h);
|
||||||
|
int argmax_w_low = floor(argmax_w);
|
||||||
|
int argmax_h_high = argmax_h_low + 1;
|
||||||
|
int argmax_w_high = argmax_w_low + 1;
|
||||||
|
|
||||||
|
T weight = 0;
|
||||||
|
|
||||||
|
if (bp_dir == 0) {
|
||||||
|
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||||
|
? -1 * (argmax_w_low + 1 - argmax_w) *
|
||||||
|
im_data[argmax_h_low * data_width + argmax_w_low]
|
||||||
|
: 0;
|
||||||
|
|
||||||
|
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||||
|
? -1 * (argmax_w - argmax_w_low) *
|
||||||
|
im_data[argmax_h_low * data_width + argmax_w_high]
|
||||||
|
: 0;
|
||||||
|
|
||||||
|
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||||
|
? (argmax_w_low + 1 - argmax_w) *
|
||||||
|
im_data[argmax_h_high * data_width + argmax_w_low]
|
||||||
|
: 0;
|
||||||
|
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||||
|
? (argmax_w - argmax_w_low) *
|
||||||
|
im_data[argmax_h_high * data_width + argmax_w_high]
|
||||||
|
: 0;
|
||||||
|
} else if (bp_dir == 1) {
|
||||||
|
weight += (argmax_h_low >= 0 && argmax_w_low >= 0)
|
||||||
|
? -1 * (argmax_h_low + 1 - argmax_h) *
|
||||||
|
im_data[argmax_h_low * data_width + argmax_w_low]
|
||||||
|
: 0;
|
||||||
|
weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
||||||
|
? (argmax_h_low + 1 - argmax_h) *
|
||||||
|
im_data[argmax_h_low * data_width + argmax_w_high]
|
||||||
|
: 0;
|
||||||
|
weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
||||||
|
? -1 * (argmax_h - argmax_h_low) *
|
||||||
|
im_data[argmax_h_high * data_width + argmax_w_low]
|
||||||
|
: 0;
|
||||||
|
weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
||||||
|
? (argmax_h - argmax_h_low) *
|
||||||
|
im_data[argmax_h_high * data_width + argmax_w_high]
|
||||||
|
: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
|
||||||
|
const int height, const int width, T h, T w) {
|
||||||
|
int h_low = floor(h);
|
||||||
|
int w_low = floor(w);
|
||||||
|
int h_high = h_low + 1;
|
||||||
|
int w_high = w_low + 1;
|
||||||
|
|
||||||
|
T lh = h - h_low;
|
||||||
|
T lw = w - w_low;
|
||||||
|
T hh = 1 - lh;
|
||||||
|
T hw = 1 - lw;
|
||||||
|
|
||||||
|
T v1 =
|
||||||
|
(h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
|
||||||
|
T v2 = (h_low >= 0 && w_high <= width - 1)
|
||||||
|
? bottom_data[h_low * data_width + w_high]
|
||||||
|
: 0;
|
||||||
|
T v3 = (h_high <= height - 1 && w_low >= 0)
|
||||||
|
? bottom_data[h_high * data_width + w_low]
|
||||||
|
: 0;
|
||||||
|
T v4 = (h_high <= height - 1 && w_high <= width - 1)
|
||||||
|
? bottom_data[h_high * data_width + w_high]
|
||||||
|
: 0;
|
||||||
|
|
||||||
|
T w1 = hh * hw;
|
||||||
|
T w2 = hh * lw;
|
||||||
|
T w3 = lh * hw;
|
||||||
|
T w4 = lh * lw;
|
||||||
|
|
||||||
|
return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,240 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import paddle.fluid.core as core
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
def dmc_bilinear(data_im, height, width, h, w):
    """Bilinearly interpolate ``data_im`` at the fractional position (h, w).

    ``data_im`` is indexed as ``data_im[row, col]``. The four neighbours are
    ``floor(h)``/``floor(h) + 1`` and ``floor(w)``/``floor(w) + 1``. A
    neighbour contributes 0 when its low index is negative or its high index
    exceeds ``height - 1`` / ``width - 1``.
    """
    top = int(np.floor(h))
    left = int(np.floor(w))
    below, right = top + 1, left + 1

    # Fractional distance from the top/left neighbour and its complement.
    frac_h, frac_w = h - top, w - left
    rest_h, rest_w = 1 - frac_h, 1 - frac_w

    top_ok = top >= 0
    left_ok = left >= 0
    below_ok = below <= height - 1
    right_ok = right <= width - 1

    val_tl = data_im[top, left] if top_ok and left_ok else 0
    val_tr = data_im[top, right] if top_ok and right_ok else 0
    val_bl = data_im[below, left] if below_ok and left_ok else 0
    val_br = data_im[below, right] if below_ok and right_ok else 0

    wgt_tl = rest_h * rest_w
    wgt_tr = rest_h * frac_w
    wgt_bl = frac_h * rest_w
    wgt_br = frac_h * frac_w

    return wgt_tl * val_tl + wgt_tr * val_tr + wgt_bl * val_bl + \
        wgt_br * val_br
|
||||||
|
|
||||||
|
|
||||||
|
def dconv_im2col_gemm(input, offset, filter, group, conv_param):
    """NumPy reference for deformable convolution (v1): im2col, then GEMM.

    Args:
        input: array of shape (in_n, in_c, in_h, in_w).
        offset: array of shape (in_n, 2 * f_h * f_w, in_h, in_w). Even
            channels are read as vertical shifts and odd channels as
            horizontal shifts, one pair per kernel tap.
        filter: array of shape (out_c, f_c, f_h, f_w) with
            ``f_c * group == in_c`` and ``out_c`` divisible by ``group``.
        group: number of convolution groups.
        conv_param: dict with two-element 'stride', 'pad' and 'dilation'
            entries, ordered (vertical, horizontal).

    Returns:
        Array of shape (in_n, out_c, out_h, out_w).

    Raises:
        AssertionError: if the shapes are inconsistent or the output spatial
            size differs from the input spatial size.
    """
    in_n, in_c, in_h, in_w = input.shape
    out_c, f_c, f_h, f_w = filter.shape

    assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
    assert f_c * group == in_c
    assert np.mod(out_c, group) == 0

    stride = conv_param['stride']
    pad = conv_param['pad']
    dilation = conv_param['dilation']
    out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
    out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
    assert out_h == in_h
    assert out_w == in_w

    col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
    for n in range(in_n):
        for h in range(out_h):
            for w in range(out_w):
                # The offsets depend only on (n, h, w), so build the per-tap
                # tables once instead of once per channel and tap.
                offset_h_table = offset[n, ::2, h, w].reshape(f_h, f_w)
                offset_w_table = offset[n, 1::2, h, w].reshape(f_h, f_w)
                for kh in range(f_h):
                    for kw in range(f_w):
                        offset_h = offset_h_table[kh, kw]
                        offset_w = offset_w_table[kh, kw]
                        im_h = h * stride[0] + kh * dilation[0] \
                            + offset_h - pad[0]
                        # The horizontal coordinate uses the horizontal
                        # stride/dilation (index 1), not the vertical ones.
                        im_w = w * stride[1] + kw * dilation[1] \
                            + offset_w - pad[1]
                        # The column is bounded by the input width, not the
                        # height; samples outside stay at the buffer's zero.
                        if not (im_h > -1 and im_w > -1 and
                                im_h < in_h and im_w < in_w):
                            continue
                        for c in range(in_c):
                            col_buffer[n, c * f_h * f_w + kh * f_w + kw,
                                       h * in_w + w] = dmc_bilinear(
                                           input[n, c], in_h, in_w, im_h, im_w)

    # Grouped GEMM: each group multiplies its filters with its column slice.
    out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
    weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
    col_buffer = col_buffer.reshape(
        (in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
    for n in range(in_n):
        for g in range(group):
            out[n, g] = np.matmul(weight[g], col_buffer[n, g])
    out = out.reshape(in_n, out_c, out_h, out_w)
    return out
|
||||||
|
|
||||||
|
|
||||||
|
class TestModulatedDeformableConvOp(OpTest):
    """Compares the ``deformable_conv_v1`` operator with the NumPy reference.

    A case is customised by overriding ``init_group``, ``init_dilation``
    and/or ``init_test_case``. ``setUp`` calls them in that order, so
    ``init_test_case`` may read ``self.groups``.
    """

    def setUp(self):
        """Build random inputs, the reference output and the op attributes."""
        self.op_type = "deformable_conv_v1"
        self.dtype = np.float32
        self.init_group()
        self.init_dilation()
        self.init_test_case()

        conv_param = {
            'stride': self.stride,
            'pad': self.pad,
            'dilation': self.dilations
        }

        input = np.random.random(self.input_size).astype(self.dtype)
        # Offsets are scaled to [0, 10) so that sampling positions can land
        # well away from the regular grid, including outside the input.
        offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
        filter = np.random.random(self.filter_size).astype(self.dtype)

        output = dconv_im2col_gemm(input, offset, filter, self.groups,
                                   conv_param)
        output = output.astype(self.dtype)
        self.inputs = {
            'Input': OpTest.np_dtype_to_fluid_dtype(input),
            'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
        }
        self.attrs = {
            'strides': self.stride,
            'paddings': self.pad,
            'groups': self.groups,
            'deformable_groups': self.deformable_groups,
            'im2col_step': self.im2col_step,
            'dilations': self.dilations,
        }
        self.outputs = {'Output': output}

    def test_check_output(self):
        """Run the forward check provided by OpTest."""
        self.check_output()

    def test_check_grad(self):
        """Run the OpTest gradient check for all three inputs."""
        self.check_grad(
            ['Input', 'Offset', 'Filter'], 'Output', max_relative_error=0.05)

    def test_check_grad_no_filter(self):
        """Run the OpTest gradient check with 'Filter' in the no-grad set."""
        self.check_grad(
            ['Input', 'Offset'],
            'Output',
            max_relative_error=0.1,
            no_grad_set=set(['Filter']))

    def init_test_case(self):
        """Set padding, stride and tensor sizes for the default 3x3 case.

        Dilation is deliberately not set here: ``init_dilation`` owns it and
        runs earlier, so assigning it again would discard a subclass
        override of ``init_dilation``.
        """
        self.pad = [1, 1]
        self.stride = [1, 1]
        self.input_size = [2, 4, 4, 4]  # NCHW
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [4, f_c, 3, 3]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels (vertical, horizontal) per kernel tap.
        offset_c = 2 * self.deformable_groups * self.filter_size[
            2] * self.filter_size[3]
        self.offset_size = [
            self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
        ]

    def init_dilation(self):
        """Set the default dilation (none)."""
        self.dilations = [1, 1]

    def init_group(self):
        """Set the default number of convolution groups."""
        self.groups = 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithStride(TestModulatedDeformableConvOp):
    """Case with stride 2 and padding 3 on a 5x5 input with 3 channels."""

    def init_test_case(self):
        batch, channels, in_h, in_w = 2, 3, 5, 5
        kernel_h, kernel_w = 3, 3

        self.pad = [3, 3]
        self.stride = [2, 2]
        self.input_size = [batch, channels, in_h, in_w]  # NCHW
        assert np.mod(self.input_size[1], self.groups) == 0
        self.filter_size = [6, channels // self.groups, kernel_h, kernel_w]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels per kernel tap and deformable group.
        offset_channels = 2 * self.deformable_groups * kernel_h * kernel_w
        self.offset_size = [batch, offset_channels, in_h, in_w]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithDilation(TestModulatedDeformableConvOp):
    """Case with dilation 2 and padding 2 on a 4x4 input with 3 channels."""

    def init_test_case(self):
        batch, channels, in_h, in_w = 2, 3, 4, 4
        kernel_h, kernel_w = 3, 3

        self.pad = [2, 2]
        self.stride = [1, 1]
        self.input_size = [batch, channels, in_h, in_w]  # NCHW
        assert np.mod(self.input_size[1], self.groups) == 0
        self.filter_size = [6, channels // self.groups, kernel_h, kernel_w]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels per kernel tap and deformable group.
        offset_channels = 2 * self.deformable_groups * kernel_h * kernel_w
        self.offset_size = [batch, offset_channels, in_h, in_w]

    def init_dilation(self):
        # Dilated kernel: taps are two pixels apart in both directions.
        self.dilations = [2, 2]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWith1x1(TestModulatedDeformableConvOp):
    """Case with a 1x1 kernel and no padding on a 5x5 input, 3 channels."""

    def init_test_case(self):
        batch, channels, in_h, in_w = 2, 3, 5, 5
        kernel_h, kernel_w = 1, 1

        self.pad = [0, 0]
        self.stride = [1, 1]
        self.input_size = [batch, channels, in_h, in_w]  # NCHW
        assert np.mod(self.input_size[1], self.groups) == 0
        self.filter_size = [6, channels // self.groups, kernel_h, kernel_w]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels per kernel tap and deformable group.
        offset_channels = 2 * self.deformable_groups * kernel_h * kernel_w
        self.offset_size = [batch, offset_channels, in_h, in_w]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWithGroup(TestModulatedDeformableConvOp):
    """Default case run with two convolution groups instead of one."""

    def init_group(self):
        group_count = 2
        self.groups = group_count
|
||||||
|
|
||||||
|
|
||||||
|
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue