|
|
@ -42,7 +42,7 @@ struct TruncatedNormal {
|
|
|
|
rng.discard(n);
|
|
|
|
rng.discard(n);
|
|
|
|
T value = dist(rng);
|
|
|
|
T value = dist(rng);
|
|
|
|
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
|
|
|
|
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
|
|
|
|
return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std;
|
|
|
|
return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
auto* tensor = context.Output<framework::Tensor>("Out");
|
|
|
|
auto* tensor = context.Output<framework::Tensor>("Out");
|
|
|
|
T* data = tensor->mutable_data<T>(context.GetPlace());
|
|
|
|
T* data = tensor->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
|
|
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
|
|
|
|
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
|
|
|
|
if (seed == 0) {
|
|
|
|
if (seed == 0) {
|
|
|
|
std::random_device rd;
|
|
|
|
std::random_device rd;
|
|
|
|