|
|
|
@ -109,12 +109,12 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
thrust::device_ptr<T>(data),
|
|
|
|
|
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
|
|
|
|
|
seed_offset.first, gen_offset));
|
|
|
|
|
} else {
|
|
|
|
|
thrust::transform(
|
|
|
|
|
index_sequence_begin, index_sequence_begin + size,
|
|
|
|
|
thrust::device_ptr<T>(data),
|
|
|
|
|
TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
thrust::transform(
|
|
|
|
|
index_sequence_begin, index_sequence_begin + size,
|
|
|
|
|
thrust::device_ptr<T>(data),
|
|
|
|
|
TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|