@@ -74,6 +74,48 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
 
 template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
 
+template <typename T>
+__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
+                                        T* out) {
+  int tid = threadIdx.x;
+  int channel_size = n / c;
+  const T* in_c = in + blockIdx.x * channel_size;
+  extern __shared__ T shared_max_data[];
+  shared_max_data[tid] = T(0);
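+  // Each thread scans its strided slice of the channel for a local abs-max.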
+  for (int i = tid; i < channel_size; i += blockDim.x) {
+    T tmp = fabs(in_c[i]);
+    if (tmp > shared_max_data[tid]) {
+      shared_max_data[tid] = tmp;
+    }
+  }
+  __syncthreads();
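+  // Tree reduction in shared memory: halve the active threads each step.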
+  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
+    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
+      shared_max_data[tid] = shared_max_data[tid + i];
+    }
+    __syncthreads();
+  }
+  if (tid == 0) {
+    out[blockIdx.x] = shared_max_data[0];
+  }
+}
+
+template <typename T>
+struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
+                  const int num, const int channel, T* out) {
+    int block = 1024;
+    int grid = channel;
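+    // One block per channel; dynamic shared memory holds one T per thread.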
+    FindChannelAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
+        in, num, channel, out);
+  }
+};
+
+template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
+
 template <typename T>
 __global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                    const int bin_cnt, const int n, T* out) {
@@ -82,14 +124,79 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
 
   T s = scale[0];
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    T x = in[bid];
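+    // Index with the loop variable i, not bid.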
+    T x = in[i];
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt / s * v;
-    out[bid] = round(v);
+    out[i] = round(v);
   }
 }
 
+template <typename T>
+struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    int num = in.numel();
+    int block = 1024;
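+    // Ceiling division: enough blocks to cover all num elements.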
+    int grid = (block - 1 + num) / block;
+
+    const T* in_data = in.data<T>();
+    const T* scale_data = scale.data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
+        in_data, scale_data, bin_cnt, num, out_data);
+  }
+};
+
+template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
+
+template <typename T>
+__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
+                                          const int bin_cnt, const int n,
+                                          const int c, T* out) {
+  int tid = threadIdx.x;
+
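+  // One block per channel: offset input and output to this channel's slice.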
+  int channel_size = n / c;
+  const T* in_c = in + blockIdx.x * channel_size;
+  T* out_c = out + blockIdx.x * channel_size;
+
+  T s = scale[blockIdx.x];
+  for (int i = tid; i < channel_size; i += blockDim.x) {
+    T x = in_c[i];
+    T v = x > s ? s : x;
+    v = v < -s ? -s : v;
+    v = bin_cnt / s * v;
+    out_c[i] = round(v);
+  }
+}
+
+template <typename T>
+struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, const int channel,
+                  framework::Tensor* out) {
+    int num = in.numel();
+    int block = 1024;
+    int grid = channel;
+
+    const T* in_data = in.data<T>();
+    const T* scale_data = scale.data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    ChannelClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
+        in_data, scale_data, bin_cnt, num, channel, out_data);
+  }
+};
+
+template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
+                                               float>;
+
 template <typename T>
 __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                             const T* last_scale,
@@ -182,26 +289,6 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
 
 template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
                                                float>;
 
-template <typename T>
-struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
-  void operator()(const platform::CUDADeviceContext& ctx,
-                  const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, framework::Tensor* out) {
-    int num = in.numel();
-    int block = 1024;
-    int grid = (block - 1 + num) / block;
-
-    const T* in_data = in.data<T>();
-    const T* scale_data = scale.data<T>();
-    T* out_data = out->mutable_data<T>(ctx.GetPlace());
-
-    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
-        in_data, scale_data, bin_cnt, num, out_data);
-  }
-};
-
-template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
-
 }  // namespace operators
 }  // namespace paddle