add adaptive mode for pool.

for_weibo
dengkaipeng 6 years ago
parent 1213e2838f
commit eab4745965

File diff suppressed because it is too large Load Diff

@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
const int ksize_width, const int stride_height,
const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_process,
bool exclusive, T* output_data) {
bool exclusive, bool adaptive, T* output_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
@ -37,13 +37,21 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
int c = (index / output_width / output_height) % channels;
int batch_idx = index / output_width / output_height / channels;
int hstart = ph * stride_height - padding_height;
int hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
if (adaptive) {
int hstart = ADAPT_START_INDEX(ph, input_height, output_height);
int hend = ADAPT_END_INDEX(ph, input_height, output_height);
int wstart = pw * stride_width - padding_width;
int wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
int wstart = ADAPT_START_INDEX(pw, input_width, output_width);
int wend = ADAPT_END_INDEX(pw, input_width, output_width);
} else {
int hstart = ph * stride_height - padding_height;
int hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
int wstart = pw * stride_width - padding_width;
int wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
}
input_data += (batch_idx * channels + c) * input_height * input_width;
T ele = pool_process.initial();
@ -52,8 +60,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
pool_process.compute(input_data[h * input_width + w], &ele);
}
}
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[index] = ele;
}

Loading…
Cancel
Save