@@ -40,12 +40,12 @@ using DataLayout = framework::DataLayout;
 // (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
 // np.sum(dy,
 // axis=(n,h,w)) * (x - mean) *
-// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var -
+// (np.mean(ddx, axis=(n,h,w)) - ddx)) + ddr * (dy * inv_var -
 // inv_var
 // *
 // np.mean(dy, axis=(n,h,w)) -
 // inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
-// axis=(n,h,w))))
+// axis=(n,h,w)))
 
 template <typename T, int BlockDim, framework::DataLayout layout>
 __global__ void DoubleGradComputeDX(const T *x, const T *mean,
@@ -138,7 +138,7 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
               ? (j / sample_size * C + i) * sample_size + j % sample_size
               : j * outer_size + i;
       dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val -
-                    (x[index] - mean_val) * var_val *
+                    (x[index] - mean_val) * var_val * var_val *
                     dy_mul_x_sub_mean_sum_val * var_val / inner_size) *
                    ddscale[i];
     }
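Here the last term inside the parentheses carries three factors of var_val, i.e. the inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean), axis=(n,h,w)) piece of the dx formula in the header comment (dy_mul_x_sub_mean_sum_val / inner_size is that mean). A minimal scalar sketch of this term, with hypothetical names that are not taken from this file:

// Sketch only: the term subtracted inside the parentheses before scaling by
// ddscale. inv_var stands for 1 / sqrt(var + eps); mean_dy_x_sub_mean stands
// for np.mean(dy * (x - mean)) over (n, h, w).
template <typename T>
T InvVarCubedTerm(T x, T mean, T inv_var, T mean_dy_x_sub_mean) {
  return (x - mean) * inv_var * inv_var * inv_var * mean_dy_x_sub_mean;
}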
@@ -326,19 +326,57 @@ __global__ void DoubleGradComputeDScaleWithGlobal(
 }
 
 // math: dx = ddscale * dy * inv_var
-// math: ddy = scale * ddx * inv_var
 template <typename T, framework::DataLayout layout>
-__global__ void DoubleGradComputeDataWithGlobal(
-    const T *dy, const T *scale, const T *variance, const double epsilon,
-    const int C, const int sample_size, const int num, T *dx) {
+__global__ void DoubleGradComputeDXWithGlobal(const T *dy, const T *ddscale,
+                                              const T *variance,
+                                              const double epsilon, const int C,
+                                              const int sample_size,
+                                              const int num, T *dx) {
+  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+
+  if (ddscale != nullptr) {
+    for (int i = gid; i < num; i += stride) {
+      const int c =
+          layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
+      T inv_var = 1.0 / sqrt(variance[c] + epsilon);
+      dx[i] = dy[i] * ddscale[c] * inv_var;
+    }
+  }
+}
+
+// math: ddy = scale * ddx * inv_var + ddbias +
+//             ddscale * (x - mean) * inv_var
+template <typename T, framework::DataLayout layout>
+__global__ void DoubleGradComputeDDYWithGlobal(
+    const T *ddx, const T *scale, const T *mean, const T *variance, const T *x,
+    const T *ddbias, const T *ddscale, const double epsilon, const int C,
+    const int sample_size, const int num, T *ddy) {
   int gid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
-  if (scale != nullptr) {
+
+  if (ddx != nullptr) {
     for (int i = gid; i < num; i += stride) {
       const int c =
           layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
       T inv_var = 1.0 / sqrt(variance[c] + epsilon);
-      dx[i] = dy[i] * scale[c] * inv_var;
+      ddy[i] += ddx[i] * scale[c] * inv_var;
     }
   }
+  __syncthreads();
+  if (ddscale != nullptr) {
+    for (int i = gid; i < num; i += stride) {
+      const int c =
+          layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
+      T inv_var = 1.0 / sqrt(variance[c] + epsilon);
+      ddy[i] += (x[i] - mean[c]) * inv_var * ddscale[c];
+    }
+  }
+  __syncthreads();
+  if (ddbias != nullptr) {
+    for (int i = gid; i < num; i += stride) {
+      const int c =
+          layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
+      ddy[i] += ddbias[c];
+    }
+  }
 }
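For reference, the math these two new kernels implement under global statistics can be checked against a plain per-element loop. The sketch below is a hypothetical CPU reference (NHWC channel indexing, invented names, all optional inputs assumed present, ddy assumed zero-filled by the caller), not part of this patch:

#include <cmath>
#include <vector>

// CPU reference for the global-stats double-grad outputs:
//   dx[i]  = dy[i] * ddscale[c] * inv_var[c]
//   ddy[i] = ddx[i] * scale[c] * inv_var[c]
//            + (x[i] - mean[c]) * inv_var[c] * ddscale[c] + ddbias[c]
// where c = i % C for NHWC data.
void RefDoubleGradWithGlobal(const std::vector<float> &x,
                             const std::vector<float> &dy,
                             const std::vector<float> &ddx,
                             const std::vector<float> &scale,
                             const std::vector<float> &ddscale,
                             const std::vector<float> &ddbias,
                             const std::vector<float> &mean,
                             const std::vector<float> &variance, float epsilon,
                             int C, std::vector<float> *dx,
                             std::vector<float> *ddy) {
  for (size_t i = 0; i < x.size(); ++i) {
    const int c = static_cast<int>(i) % C;  // NHWC: channel is innermost
    const float inv_var = 1.0f / std::sqrt(variance[c] + epsilon);
    (*dx)[i] = dy[i] * ddscale[c] * inv_var;
    (*ddy)[i] = ddx[i] * scale[c] * inv_var +
                (x[i] - mean[c]) * inv_var * ddscale[c] + ddbias[c];
  }
}

Comparing this loop against the outputs of DoubleGradComputeDXWithGlobal / DoubleGradComputeDDYWithGlobal on the same inputs is a quick sanity check for the split.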
@@ -383,8 +421,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
   const T *mean_data, *variance_data;
   if (use_global_stats) {
+    const auto *running_mean = ctx.Input<Tensor>("Mean");
     const auto *running_var = ctx.Input<Tensor>("Variance");
+    const auto *running_mean_data = running_mean->template data<T>();
     const auto *running_var_data = running_var->template data<T>();
+    mean_data = running_mean_data;
     variance_data = running_var_data;
   } else {
     const T *smean_data = Saved_mean->data<T>();
@@ -398,12 +439,12 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
     set_constant(dev_ctx, dX, static_cast<T>(0));
     if (use_global_stats) {
       if (data_layout == DataLayout::kNHWC) {
-        DoubleGradComputeDataWithGlobal<
+        DoubleGradComputeDXWithGlobal<
             T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
             dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
             dx_data);
       } else {
-        DoubleGradComputeDataWithGlobal<
+        DoubleGradComputeDXWithGlobal<
             T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
             dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
             dx_data);
@@ -456,15 +497,15 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
     set_constant(dev_ctx, ddY, static_cast<T>(0));
     if (use_global_stats) {
      if (data_layout == DataLayout::kNHWC) {
-        DoubleGradComputeDataWithGlobal<
+        DoubleGradComputeDDYWithGlobal<
             T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
-            ddx_data, scale_data, variance_data, epsilon, C, sample_size, num,
-            ddy_data);
+            ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
+            ddscale_data, epsilon, C, sample_size, num, ddy_data);
       } else {
-        DoubleGradComputeDataWithGlobal<
+        DoubleGradComputeDDYWithGlobal<
             T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
-            ddx_data, scale_data, variance_data, epsilon, C, sample_size, num,
-            ddy_data);
+            ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
+            ddscale_data, epsilon, C, sample_size, num, ddy_data);
       }
     } else {
       if (data_layout == DataLayout::kNHWC) {
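The layout dispatch above only selects the template parameter that controls how the channel of a flattened element is computed inside the kernels. A minimal sketch of that indexing rule, as a hypothetical helper that is not part of the file:

// For NCHW, each channel owns a contiguous block of sample_size (= H * W)
// elements per sample, so i / sample_size walks (n, c) pairs; for NHWC the
// channel is the innermost index.
inline int ChannelOf(int i, int C, int sample_size, bool nchw) {
  return nchw ? (i / sample_size) % C : i % C;
}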