|
|
|
@ -44,6 +44,9 @@ __global__ void Pnorm(const T* x, const int pre,
|
|
|
|
|
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
|
|
|
|
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
|
|
|
|
int num = pre * post;
|
|
|
|
|
auto porder_t = static_cast<T>(porder);
|
|
|
|
|
auto porder_inv = static_cast<T>(1.0 / porder);
|
|
|
|
|
|
|
|
|
|
for (int i = blockIdx.x; i < num; i += gridDim.x) {
|
|
|
|
|
int base = (i / post) * post * axis_n + (i % post);
|
|
|
|
|
|
|
|
|
@ -51,12 +54,12 @@ __global__ void Pnorm(const T* x, const int pre,
|
|
|
|
|
__shared__ T norm;
|
|
|
|
|
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
|
|
|
|
|
const T x_ij = x[base + j * post];
|
|
|
|
|
sum += inline_pow(inline_abs(x_ij), porder);
|
|
|
|
|
sum += inline_pow(inline_abs(x_ij), porder_t);
|
|
|
|
|
}
|
|
|
|
|
T reduce_result = BlockReduce(temp_storage).Sum(sum);
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
norm = inline_pow(reduce_result, 1.0f / porder);
|
|
|
|
|
norm = inline_pow(reduce_result, porder_inv);
|
|
|
|
|
out_norm[i] = norm;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
@ -100,6 +103,7 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
|
|
|
|
|
__shared__ typename BlockReduce::TempStorage temp_storage_sum;
|
|
|
|
|
// dx = |x|.pow(p-1) / (pnorm_broadcast.pow(p-1) + eps) * norm_dy.broadcast * sign(x)
|
|
|
|
|
int num = pre * post;
|
|
|
|
|
auto porder_grad = static_cast<T>(porder - 1.0f);
|
|
|
|
|
for (int i = blockIdx.x; i < num; i += gridDim.x) {
|
|
|
|
|
T sum = 0.0;
|
|
|
|
|
__shared__ T row_sum;
|
|
|
|
@ -128,8 +132,8 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
|
|
|
|
|
int index = base + j * post;
|
|
|
|
|
const T x_ij = inline_abs(x[index]);
|
|
|
|
|
const T dy_ij = y_grad[index];
|
|
|
|
|
x_grad[index] = inline_pow(x_ij, porder - 1.0f) /
|
|
|
|
|
(inline_pow(pnorm_i, porder - 1.0f) + eps) * yout_i *
|
|
|
|
|
x_grad[index] = inline_pow(x_ij, porder_grad) /
|
|
|
|
|
(inline_pow(pnorm_i, porder_grad) + eps) * yout_i *
|
|
|
|
|
inline_sign(x[index]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|