|
|
|
@ -49,20 +49,70 @@ __global__ void Pnorm(const T* x, const int pre,
|
|
|
|
|
|
|
|
|
|
for (int i = blockIdx.x; i < num; i += gridDim.x) {
|
|
|
|
|
int base = (i / post) * post * axis_n + (i % post);
|
|
|
|
|
|
|
|
|
|
T sum = 0.0;
|
|
|
|
|
__shared__ T norm;
|
|
|
|
|
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
|
|
|
|
|
const T x_ij = x[base + j * post];
|
|
|
|
|
sum += inline_pow(inline_abs(x_ij), porder_t);
|
|
|
|
|
}
|
|
|
|
|
T reduce_result = BlockReduce(temp_storage).Sum(sum);
|
|
|
|
|
if (threadIdx.x == 0) out_norm[i] = inline_pow(reduce_result, porder_inv);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
norm = inline_pow(reduce_result, porder_inv);
|
|
|
|
|
out_norm[i] = norm;
|
|
|
|
|
// Computes the L0 "norm" (count of nonzero entries) along the reduced axis.
// Launch contract: 1-D grid-stride loop over the pre*post reduction groups;
// each block reduces one group of axis_n strided elements via
// cub::BlockReduce, and thread 0 writes the scalar result to out_norm[i].
// NOTE(review): "ZeorNorm" is a typo for "ZeroNorm"; the name is kept
// unchanged because callers reference this symbol.
template <typename T, int BlockDim>
__global__ void ZeorNorm(const T* x, const int pre,
                         const int axis_n,  // dim in axis
                         const int post, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    // Flattened offset of the first element of the i-th reduction group.
    int base = (i / post) * post * axis_n + (i % post);
    // FIX: use a T-typed zero instead of the double literal 0.0, which
    // forced a silent double->T conversion in a templated float kernel.
    T sum = static_cast<T>(0);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const T x_ij = x[base + j * post];
      sum += static_cast<T>(x_ij != 0);
    }
    // Barrier before the reduction so temp_storage left over from the
    // previous grid-stride iteration is safe to reuse (CUB requirement).
    __syncthreads();
    T reduce_result = BlockReduce(temp_storage).Sum(sum);
    if (threadIdx.x == 0) out_norm[i] = reduce_result;
  }
}
|
|
|
|
|
|
|
|
|
|
// Computes the +infinity norm (max of absolute values) along the reduced
// axis. Launch contract: 1-D grid-stride loop over the pre*post reduction
// groups; each block reduces one group of axis_n strided elements via
// cub::BlockReduce, and thread 0 writes the scalar result to out_norm[i].
template <typename T, int BlockDim>
__global__ void InfNorm(const T* x, const int pre,
                        const int axis_n,  // dim in axis
                        const int post, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    int base = (i / post) * post * axis_n + (i % post);
    // Seed with a real element of the group so threads whose strided loop
    // executes zero iterations still contribute a valid candidate.
    T cur_max = inline_abs(x[base]);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      T x_ij_abs = inline_abs(x[base + j * post]);
      if (cur_max < x_ij_abs) cur_max = x_ij_abs;
    }
    // FIX: CUB requires a __syncthreads() between block reductions that
    // share the same TempStorage; without this barrier, reuse of
    // temp_storage across grid-stride iterations is a data race.
    __syncthreads();
    T reduce_result = BlockReduce(temp_storage).Reduce(cur_max, cub::Max());
    if (threadIdx.x == 0) out_norm[i] = reduce_result;
  }
}
|
|
|
|
|
|
|
|
|
|
// Computes the -infinity norm (min of absolute values) along the reduced
// axis. Launch contract: 1-D grid-stride loop over the pre*post reduction
// groups; each block reduces one group of axis_n strided elements via
// cub::BlockReduce, and thread 0 writes the scalar result to out_norm[i].
template <typename T, int BlockDim>
__global__ void NegInfNorm(const T* x, const int pre,
                           const int axis_n,  // dim in axis
                           const int post, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    int base = (i / post) * post * axis_n + (i % post);
    // Seed with a real element of the group so threads whose strided loop
    // executes zero iterations still contribute a valid candidate.
    T cur_min = inline_abs(x[base]);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      T x_ij_abs = inline_abs(x[base + j * post]);
      if (cur_min > x_ij_abs) cur_min = x_ij_abs;
    }
    // FIX: CUB requires a __syncthreads() between block reductions that
    // share the same TempStorage; without this barrier, reuse of
    // temp_storage across grid-stride iterations is a data race.
    __syncthreads();
    T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min());
    if (threadIdx.x == 0) out_norm[i] = reduce_result;
  }
}
|
|
|
|
|
|
|
|
|
@ -89,9 +139,20 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
|
|
|
|
|
const int max_blocks = std::max(max_threads / block, 1);
|
|
|
|
|
int grid = std::min(max_blocks, pre * post);
|
|
|
|
|
if (porder == 0) {
|
|
|
|
|
ZeorNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
|
|
|
|
|
norm);
|
|
|
|
|
} else if (porder == INFINITY) {
|
|
|
|
|
InfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
|
|
|
|
|
norm);
|
|
|
|
|
} else if (porder == -INFINITY) {
|
|
|
|
|
NegInfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n,
|
|
|
|
|
post, norm);
|
|
|
|
|
} else {
|
|
|
|
|
Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
|
|
|
|
|
porder, norm);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T, int BlockDim>
|
|
|
|
@ -112,7 +173,6 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
|
|
|
|
|
pnorm_i = x_norm[i];
|
|
|
|
|
yout_i = y_grad[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
|
|
|
|
@ -125,6 +185,33 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Backward pass for the +/-infinity norm: routes the upstream gradient
// y_grad[i] to every element whose absolute value equals the group's norm
// x_norm[i] (signed by the element), and zero elsewhere.
// Launch contract: 1-D grid-stride loop over the pre*post reduction groups;
// the whole block cooperates on one group per iteration, with the group's
// norm and upstream gradient staged in shared memory by thread 0.
template <typename T, int BlockDim>
__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
                                const int pre, const int axis_n, const int post,
                                T* x_grad) {
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    __shared__ T pnorm_i;
    __shared__ T yout_i;
    auto base = (i / post) * post * axis_n + (i % post);
    if (threadIdx.x == 0) {
      pnorm_i = x_norm[i];
      yout_i = y_grad[i];
    }
    // Publish pnorm_i / yout_i to the whole block before any thread reads.
    __syncthreads();

    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      int index = base + j * post;
      const T x_ij = inline_abs(x[index]);
      if (x_ij == pnorm_i) {
        x_grad[index] = inline_sign(x[index]) * yout_i;
      } else {
        x_grad[index] = static_cast<T>(0);
      }
    }
    // FIX: barrier at the end of the iteration so thread 0 of the next
    // grid-stride iteration cannot overwrite pnorm_i / yout_i while other
    // threads are still reading them in the loop above.
    __syncthreads();
  }
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T, typename AttrType = T>
|
|
|
|
|
class PnormGradCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -153,9 +240,18 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
|
|
|
|
|
const int max_blocks = std::max(max_threads / block, 1);
|
|
|
|
|
int grid = std::min(max_blocks, pre * post);
|
|
|
|
|
if (porder == 0) {
|
|
|
|
|
math::SetConstant<DeviceContext, T> set_zero;
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
set_zero(dev_ctx, out_dx, static_cast<T>(0));
|
|
|
|
|
} else if (porder == INFINITY || porder == -INFINITY) {
|
|
|
|
|
InfNormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
|
|
|
|
|
x, x_norm, norm_dy, pre, n, post, dx);
|
|
|
|
|
} else {
|
|
|
|
|
PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
|
|
|
|
|
x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|