@@ -234,6 +234,63 @@ static __global__ void KeBNBackwardData(const T *dy,
   }
 }
 
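+// Computes only d_x of batch-norm grad: one CUDA block per channel (strided
+// by gridDim.x), with CUB block reductions producing the per-channel sums
+// that the d_x formula needs.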
+template <typename T, int BlockDim, framework::DataLayout layout>
+static __global__ void BNBackwardData(const T *dy,
+                                      const BatchNormParamType<T> *scale,
+                                      const BatchNormParamType<T> *mean,
+                                      const T *x,
+                                      const BatchNormParamType<T> *variance,
+                                      const int C, const int N, const int HxW,
+                                      T *dx) {
+  const int outer_size = C;
+  const int inner_size = N * HxW;
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage dy_storage;
+  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
+  __shared__ BatchNormParamType<T> dy_sum_val;
+  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;
+
+  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
+    BatchNormParamType<T> inv_var_i = variance[i];
+    BatchNormParamType<T> mean_i = mean[i];
+    BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
+    BatchNormParamType<T> dy_x_sub_mean_sum =
+        static_cast<BatchNormParamType<T>>(0);
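+    // First pass: each thread accumulates partial sums of dy and
+    // dy * (x - mean) over this channel's elements.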
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      BatchNormParamType<T> dy_i =
+          static_cast<BatchNormParamType<T>>(dy[index]);
+      dy_sum += dy_i;
+      dy_x_sub_mean_sum +=
+          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
+    }
+
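+    // Reduce the per-thread partials across the block; thread 0 then
+    // publishes the channel totals through shared memory.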
+    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
+    dy_x_sub_mean_sum = BlockReduce(dy_x_sub_mean_storage)
+                            .Reduce(dy_x_sub_mean_sum, cub::Sum());
+
+    if (threadIdx.x == 0) {
+      dy_sum_val = dy_sum;
+      dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
+    }
+    __syncthreads();
+
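+    // Second pass: apply the batch-norm input-gradient formula
+    //   dx = (dy - mean(dy) - (x - mu) * mean(dy * (x - mu)) * inv_var^2)
+    //        * scale * inv_var,
+    // with means taken over the inner_size elements of the channel.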
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      dx[index] =
+          (static_cast<BatchNormParamType<T>>(dy[index]) -
+           dy_sum_val / static_cast<BatchNormParamType<T>>(inner_size) -
+           (static_cast<BatchNormParamType<T>>(x[index]) - mean_i) *
+               dy_x_sub_mean_sum_val * inv_var_i * inv_var_i / inner_size) *
+          scale[i] * inv_var_i;
+    }
+  }
+}
+
 template <typename T>
 class BatchNormGradKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
@@ -282,6 +339,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     }
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
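+    // Launch config shared by the kernels below: 512-thread blocks; grid1
+    // covers all num elements with one thread each, grid2 gives one block per
+    // channel, capped by the device's physical thread capacity.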
+    const int num = x->numel();
+    const int block = 512;
+    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+    const int max_blocks = std::max(max_threads / block, 1);
+    int grid1 = (num + block - 1) / block;
+    int grid2 = std::min(C, max_blocks);
+
     if (!use_global_stats) {
       if ((N * H * W * D) == 1) {
         framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
@@ -325,21 +389,43 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
       const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
-      const void *saved_mean_data =
+      const auto *saved_mean_data =
           saved_mean->template data<BatchNormParamType<T>>();
-      const void *saved_var_data =
+      const auto *saved_var_data =
           saved_var->template data<BatchNormParamType<T>>();
 
-      CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
-          dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
-          CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
-          CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
-          data_desc_, d_y->template data<T>(), data_desc_,
-          d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<BatchNormParamType<T>>(),
-          d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
-          d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
-          epsilon, saved_mean_data, saved_var_data));
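+      // cuDNN's cudnnBatchNormalizationBackward computes dScale/dBias
+      // unconditionally and rejects null result pointers, so the custom
+      // BNBackwardData kernel handles the case where only d_x is requested.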
+      if (d_scale && d_bias) {
+        CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
+            dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
+            CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
+            CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
+            data_desc_, d_y->template data<T>(), data_desc_,
+            d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
+            scale->template data<BatchNormParamType<T>>(),
+            d_scale->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            d_bias->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            epsilon, saved_mean_data, saved_var_data));
+      } else {
+        if (data_layout == framework::DataLayout::kNCHW) {
+          if (d_x) {
+            BNBackwardData<T, block, framework::DataLayout::kNCHW><<<
+                grid2, block, 0, dev_ctx.stream()>>>(
+                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
+                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
+                d_x->data<T>());
+          }
+        } else {
+          if (d_x) {
+            BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
+                grid2, block, 0, dev_ctx.stream()>>>(
+                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
+                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
+                d_x->data<T>());
+          }
+        }
+      }
 
       // clean when exit.
       CUDNN_ENFORCE(
@@ -355,13 +441,6 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       const auto *running_var_data =
           running_var->template data<BatchNormParamType<T>>();
 
-      const int num = x->numel();
-      const int block = 512;
-      int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
-      const int max_blocks = std::max(max_threads / block, 1);
-      int grid1 = (num + block - 1) / block;
-      int grid2 = std::min(C, max_blocks);
-
       if (data_layout == framework::DataLayout::kNCHW) {
         if (d_x) {
           KeBNBackwardData<T, framework::DataLayout::kNCHW><<<