|
|
|
@ -126,20 +126,48 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
|
|
|
|
|
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void Set<typename GPUPlace, typename float>(const int n, const float alpha,
|
|
|
|
|
float* output,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
|
|
|
|
|
framework::EigenVector::Type<T> out(output, n);
|
|
|
|
|
out.device(*(cuda_context->eigen_device())) = t.constant(T(alpha));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void UniformShift(const int n, const T min, const T max, T* x) {
|
|
|
|
|
float scale = max - min;
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
x[i] = x[i] * scale + min;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void RandUniform<platform::GPUPlace, float>(const int n, const float min,
|
|
|
|
|
const float max, float* output,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
|
|
|
|
|
thrust::uniform_real_distribution<float> distribution(min, max);
|
|
|
|
|
thrust::minstd_rand engine = cuda_context->rand_enigne();
|
|
|
|
|
engine->discard(n);
|
|
|
|
|
|
|
|
|
|
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
curandGenerateUniform(cuda_context->curand_generator(), output, n));
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (n + block - 1) / block;
|
|
|
|
|
UniformShift<float><<<grid, block, 0, cuda_context->stream()>>>(n, min, max,
|
|
|
|
|
output);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
thrust::transform(thrust::cuda::par.on(cuda_context->stream()),
|
|
|
|
|
index_sequence_begin, index_sequence_begin + n,
|
|
|
|
|
thrust::device_ptr<float>(output), distribution(engine));
|
|
|
|
|
template <typename T>
|
|
|
|
|
int HandleOddLengthRandGaussian(const int n, const T mean, const T std,
|
|
|
|
|
T* output, CUDADeviceContext* context) {
|
|
|
|
|
if (n % 2 == 1) {
|
|
|
|
|
std::default_random_engine generator;
|
|
|
|
|
std::normal_distribution<T> distribution(mean, std);
|
|
|
|
|
const T random_value = distribution(generator);
|
|
|
|
|
Set<T, platform::GPUPlace>(1, random_value, output + (n - 1), context);
|
|
|
|
|
return n - 1;
|
|
|
|
|
}
|
|
|
|
|
return n;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -147,15 +175,11 @@ void RandGaussian<platform::GPUPlace, float>(const int n, const float mean,
|
|
|
|
|
const float std, float* output,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
|
|
|
|
|
thrust::normal_distribution<float> distribution(mean, std);
|
|
|
|
|
thrust::minstd_rand engine = cuda_context->rand_enigne();
|
|
|
|
|
engine->discard(n);
|
|
|
|
|
|
|
|
|
|
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
|
|
|
|
|
|
|
|
|
|
thrust::transform(thrust::cuda::par.on(cuda_context->stream()),
|
|
|
|
|
index_sequence_begin, index_sequence_begin + n,
|
|
|
|
|
thrust::device_ptr<float>(output), distribution(engine));
|
|
|
|
|
const int even_n =
|
|
|
|
|
HandleOddLengthRandGaussian<float>(n, mean, std, output, cuda_context);
|
|
|
|
|
PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output,
|
|
|
|
|
even_n, mean, std));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|