|
|
@ -30,16 +30,15 @@ struct MaskGenerator {
|
|
|
|
__host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
|
|
|
|
__host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
|
|
|
|
: dropout_prob(dropout_prob), seed(seed) {}
|
|
|
|
: dropout_prob(dropout_prob), seed(seed) {}
|
|
|
|
|
|
|
|
|
|
|
|
__host__ __device__ T operator()(const unsigned int n) const {
|
|
|
|
inline __host__ __device__ T operator()(const unsigned int n) const {
|
|
|
|
thrust::minstd_rand rng;
|
|
|
|
thrust::minstd_rand rng;
|
|
|
|
rng.seed(seed);
|
|
|
|
rng.seed(seed);
|
|
|
|
thrust::uniform_real_distribution<AttrType> dist(0, 1);
|
|
|
|
thrust::uniform_real_distribution<AttrType> dist(0, 1);
|
|
|
|
rng.discard(n);
|
|
|
|
rng.discard(n);
|
|
|
|
if (dist(rng) < dropout_prob) {
|
|
|
|
if (dist(rng) < dropout_prob) {
|
|
|
|
return static_cast<T>(0);
|
|
|
|
return static_cast<T>(0);
|
|
|
|
} else {
|
|
|
|
|
|
|
|
return static_cast<T>(1);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return static_cast<T>(1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|